1use cubecl::{
2 CubeElement as CubeElem, flex32,
3 matmul::components::{MatmulPrecision, MatrixPrecision},
4 prelude::{Float, Int, Numeric},
5 reduce::ReducePrecision,
6};
7
8pub trait CubeElement: burn_tensor::Element + CubeElem + PartialEq + Numeric {}
10
11pub trait MatmulElement:
13 CubeElement + MatmulPrecision<Acc: MatrixPrecision<Global: CubeElement>>
14{
15}
16
17pub trait FloatElement: MatmulElement + Float {}
19
20pub trait IntElement:
22 MatmulElement + Int + ReducePrecision<EI: CubeElement, EA: CubeElement>
23{
24}
25
26pub trait BoolElement: CubeElement + Int {
28 fn true_val() -> Self {
30 Self::from_int(1)
31 }
32
33 fn false_val() -> Self {
35 Self::from_int(0)
36 }
37
38 fn new_bool(val: bool) -> Self {
40 match val {
41 true => Self::true_val(),
42 false => Self::false_val(),
43 }
44 }
45}
46
47impl CubeElement for u64 {}
48impl CubeElement for u32 {}
49impl CubeElement for u16 {}
50impl CubeElement for u8 {}
51impl CubeElement for i64 {}
52impl CubeElement for i32 {}
53impl CubeElement for i16 {}
54impl CubeElement for i8 {}
55impl CubeElement for f64 {}
56impl CubeElement for f32 {}
57impl CubeElement for flex32 {}
58impl CubeElement for half::f16 {}
59impl CubeElement for half::bf16 {}
60
61impl FloatElement for f64 {}
62impl FloatElement for f32 {}
63impl FloatElement for flex32 {}
64impl FloatElement for half::bf16 {}
65impl FloatElement for half::f16 {}
66impl IntElement for i64 {}
67impl IntElement for i32 {}
68impl IntElement for i16 {}
69impl IntElement for i8 {}
70impl IntElement for u64 {}
71impl IntElement for u32 {}
72impl IntElement for u16 {}
73impl IntElement for u8 {}
74
75impl BoolElement for u8 {}
76impl BoolElement for u32 {}
77
78impl MatmulElement for f64 {}
79impl MatmulElement for f32 {}
80impl MatmulElement for flex32 {}
81impl MatmulElement for half::bf16 {}
82impl MatmulElement for half::f16 {}
83
84impl MatmulElement for i64 {}
85impl MatmulElement for i32 {}
86impl MatmulElement for i16 {}
87impl MatmulElement for i8 {}
88impl MatmulElement for u64 {}
89impl MatmulElement for u32 {}
90impl MatmulElement for u16 {}
91impl MatmulElement for u8 {}