burn_cubecl/
element.rs

1use cubecl::{
2    CubeElement as CubeElem, flex32,
3    matmul::components::{MatmulPrecision, MatrixPrecision},
4    prelude::{Float, Int, Numeric},
5    reduce::ReducePrecision,
6};
7
8/// The base element trait for the jit backend.
9pub trait CubeElement: burn_tensor::Element + CubeElem + PartialEq + Numeric {}
10
11/// Element that can be used for matrix multiplication. Includes ints and floats.
12pub trait MatmulElement:
13    CubeElement + MatmulPrecision<Acc: MatrixPrecision<Global: CubeElement>>
14{
15}
16
17/// The float element type for the jit backend.
18pub trait FloatElement: MatmulElement + Float {}
19
20/// The int element type for the jit backend.
21pub trait IntElement:
22    MatmulElement + Int + ReducePrecision<EI: CubeElement, EA: CubeElement>
23{
24}
25
26/// The element type for booleans for the jit backend.
27pub trait BoolElement: CubeElement + Int {
28    /// The true value for the boolean element.
29    fn true_val() -> Self {
30        Self::from_int(1)
31    }
32
33    /// The false value for the boolean element.
34    fn false_val() -> Self {
35        Self::from_int(0)
36    }
37
38    /// New bool element from Rust bool.
39    fn new_bool(val: bool) -> Self {
40        match val {
41            true => Self::true_val(),
42            false => Self::false_val(),
43        }
44    }
45}
46
47impl CubeElement for u64 {}
48impl CubeElement for u32 {}
49impl CubeElement for u16 {}
50impl CubeElement for u8 {}
51impl CubeElement for i64 {}
52impl CubeElement for i32 {}
53impl CubeElement for i16 {}
54impl CubeElement for i8 {}
55impl CubeElement for f64 {}
56impl CubeElement for f32 {}
57impl CubeElement for flex32 {}
58impl CubeElement for half::f16 {}
59impl CubeElement for half::bf16 {}
60
61impl FloatElement for f64 {}
62impl FloatElement for f32 {}
63impl FloatElement for flex32 {}
64impl FloatElement for half::bf16 {}
65impl FloatElement for half::f16 {}
66impl IntElement for i64 {}
67impl IntElement for i32 {}
68impl IntElement for i16 {}
69impl IntElement for i8 {}
70impl IntElement for u64 {}
71impl IntElement for u32 {}
72impl IntElement for u16 {}
73impl IntElement for u8 {}
74
75impl BoolElement for u8 {}
76impl BoolElement for u32 {}
77
78impl MatmulElement for f64 {}
79impl MatmulElement for f32 {}
80impl MatmulElement for flex32 {}
81impl MatmulElement for half::bf16 {}
82impl MatmulElement for half::f16 {}
83
84impl MatmulElement for i64 {}
85impl MatmulElement for i32 {}
86impl MatmulElement for i16 {}
87impl MatmulElement for i8 {}
88impl MatmulElement for u64 {}
89impl MatmulElement for u32 {}
90impl MatmulElement for u16 {}
91impl MatmulElement for u8 {}