1use cubecl::{
2 flex32,
3 prelude::{Float, Int, Numeric},
4 CubeElement,
5};
6
7pub trait JitElement: burn_tensor::Element + CubeElement + PartialEq + Numeric {}
9
10pub trait FloatElement: JitElement + Float {}
12
13pub trait IntElement: JitElement + Int {}
15
16pub trait BoolElement: JitElement + Int {
18 fn true_val() -> Self {
20 Self::from_int(1)
21 }
22
23 fn false_val() -> Self {
25 Self::from_int(0)
26 }
27
28 fn new_bool(val: bool) -> Self {
30 match val {
31 true => Self::true_val(),
32 false => Self::false_val(),
33 }
34 }
35}
36
37impl JitElement for u64 {}
38impl JitElement for u32 {}
39impl JitElement for u16 {}
40impl JitElement for u8 {}
41impl JitElement for i64 {}
42impl JitElement for i32 {}
43impl JitElement for i16 {}
44impl JitElement for i8 {}
45impl JitElement for f64 {}
46impl JitElement for f32 {}
47impl JitElement for flex32 {}
48impl JitElement for half::f16 {}
49impl JitElement for half::bf16 {}
50
51impl FloatElement for f64 {}
52impl FloatElement for f32 {}
53impl FloatElement for flex32 {}
54impl FloatElement for half::bf16 {}
55impl FloatElement for half::f16 {}
56impl IntElement for i64 {}
57impl IntElement for i32 {}
58impl IntElement for i16 {}
59impl IntElement for i8 {}
60
61impl BoolElement for u8 {}
62impl BoolElement for u32 {}