1use burn_backend::{Element, bf16, f16};
2use cubecl::{
3 CubeElement as CubeElem, flex32,
4 prelude::{Float, Int, Numeric},
5};
6use cubek::{
7 matmul::definition::{MatmulPrecision, MatrixPrecision},
8 reduce::ReducePrecision,
9};
10
/// Base element trait shared by every element type declared in this module.
///
/// It is a pure marker: it adds no items of its own and only bundles the
/// `Element`, `CubeElem` (cubecl's `CubeElement`), `PartialEq` and `Numeric`
/// bounds under a single name.
pub trait CubeElement: Element + CubeElem + PartialEq + Numeric {}
13
/// Marker trait for elements that also satisfy `MatmulPrecision`.
///
/// The accumulator precision (`Acc`) must be a `MatrixPrecision` whose
/// `Global` type is itself a [`CubeElement`]. In this file the trait is
/// implemented for both floating-point and integer types.
pub trait MatmulElement:
    CubeElement + MatmulPrecision<Acc: MatrixPrecision<Global: CubeElement>>
{
}
19
/// Marker trait for floating-point elements: a [`MatmulElement`] that is also
/// a cubecl `Float`.
pub trait FloatElement: MatmulElement + Float {}
22
/// Marker trait for integer elements: a [`MatmulElement`] that is also a
/// cubecl `Int`.
///
/// It additionally requires `ReducePrecision`, with both associated types
/// (`EI` and `EA`) being [`CubeElement`]s.
pub trait IntElement:
    MatmulElement + Int + ReducePrecision<EI: CubeElement, EA: CubeElement>
{
}
28
29pub trait BoolElement: CubeElement + Int {
31 fn true_val() -> Self {
33 Self::from_int(1)
34 }
35
36 fn false_val() -> Self {
38 Self::from_int(0)
39 }
40
41 fn new_bool(val: bool) -> Self {
43 match val {
44 true => Self::true_val(),
45 false => Self::false_val(),
46 }
47 }
48}
49
50impl CubeElement for u64 {}
51impl CubeElement for u32 {}
52impl CubeElement for u16 {}
53impl CubeElement for u8 {}
54impl CubeElement for i64 {}
55impl CubeElement for i32 {}
56impl CubeElement for i16 {}
57impl CubeElement for i8 {}
58impl CubeElement for f64 {}
59impl CubeElement for f32 {}
60impl CubeElement for flex32 {}
61impl CubeElement for f16 {}
62impl CubeElement for bf16 {}
63
64impl FloatElement for f64 {}
65impl FloatElement for f32 {}
66impl FloatElement for flex32 {}
67impl FloatElement for bf16 {}
68impl FloatElement for f16 {}
69impl IntElement for i64 {}
70impl IntElement for i32 {}
71impl IntElement for i16 {}
72impl IntElement for i8 {}
73impl IntElement for u64 {}
74impl IntElement for u32 {}
75impl IntElement for u16 {}
76impl IntElement for u8 {}
77
78impl BoolElement for u8 {}
79impl BoolElement for u32 {}
80
81impl MatmulElement for f64 {}
82impl MatmulElement for f32 {}
83impl MatmulElement for flex32 {}
84impl MatmulElement for bf16 {}
85impl MatmulElement for f16 {}
86
87impl MatmulElement for i64 {}
88impl MatmulElement for i32 {}
89impl MatmulElement for i16 {}
90impl MatmulElement for i8 {}
91impl MatmulElement for u64 {}
92impl MatmulElement for u32 {}
93impl MatmulElement for u16 {}
94impl MatmulElement for u8 {}