burn_cubecl/
element.rs

1use burn_backend::{Element, bf16, f16};
2use cubecl::{
3    CubeElement as CubeElem, flex32,
4    prelude::{Float, Int, Numeric},
5};
6use cubek::{
7    matmul::definition::{MatmulPrecision, MatrixPrecision},
8    reduce::ReducePrecision,
9};
10
11/// The base element trait for the jit backend.
12pub trait CubeElement: Element + CubeElem + PartialEq + Numeric {}
13
14/// Element that can be used for matrix multiplication. Includes ints and floats.
15pub trait MatmulElement:
16    CubeElement + MatmulPrecision<Acc: MatrixPrecision<Global: CubeElement>>
17{
18}
19
20/// The float element type for the jit backend.
21pub trait FloatElement: MatmulElement + Float {}
22
23/// The int element type for the jit backend.
24pub trait IntElement:
25    MatmulElement + Int + ReducePrecision<EI: CubeElement, EA: CubeElement>
26{
27}
28
29/// The element type for booleans for the jit backend.
30pub trait BoolElement: CubeElement + Int {
31    /// The true value for the boolean element.
32    fn true_val() -> Self {
33        Self::from_int(1)
34    }
35
36    /// The false value for the boolean element.
37    fn false_val() -> Self {
38        Self::from_int(0)
39    }
40
41    /// New bool element from Rust bool.
42    fn new_bool(val: bool) -> Self {
43        match val {
44            true => Self::true_val(),
45            false => Self::false_val(),
46        }
47    }
48}
49
50impl CubeElement for u64 {}
51impl CubeElement for u32 {}
52impl CubeElement for u16 {}
53impl CubeElement for u8 {}
54impl CubeElement for i64 {}
55impl CubeElement for i32 {}
56impl CubeElement for i16 {}
57impl CubeElement for i8 {}
58impl CubeElement for f64 {}
59impl CubeElement for f32 {}
60impl CubeElement for flex32 {}
61impl CubeElement for f16 {}
62impl CubeElement for bf16 {}
63
64impl FloatElement for f64 {}
65impl FloatElement for f32 {}
66impl FloatElement for flex32 {}
67impl FloatElement for bf16 {}
68impl FloatElement for f16 {}
69impl IntElement for i64 {}
70impl IntElement for i32 {}
71impl IntElement for i16 {}
72impl IntElement for i8 {}
73impl IntElement for u64 {}
74impl IntElement for u32 {}
75impl IntElement for u16 {}
76impl IntElement for u8 {}
77
78impl BoolElement for u8 {}
79impl BoolElement for u32 {}
80
81impl MatmulElement for f64 {}
82impl MatmulElement for f32 {}
83impl MatmulElement for flex32 {}
84impl MatmulElement for bf16 {}
85impl MatmulElement for f16 {}
86
87impl MatmulElement for i64 {}
88impl MatmulElement for i32 {}
89impl MatmulElement for i16 {}
90impl MatmulElement for i8 {}
91impl MatmulElement for u64 {}
92impl MatmulElement for u32 {}
93impl MatmulElement for u16 {}
94impl MatmulElement for u8 {}