burn_jit/
element.rs

1use cubecl::{
2    flex32,
3    prelude::{Float, Int, Numeric},
4    CubeElement,
5};
6
7/// The base element trait for the jit backend.
8pub trait JitElement: burn_tensor::Element + CubeElement + PartialEq + Numeric {}
9
10/// The float element type for the jit backend.
11pub trait FloatElement: JitElement + Float {}
12
13/// The int element type for the jit backend.
14pub trait IntElement: JitElement + Int {}
15
16/// The element type for booleans for the jit backend.
17pub trait BoolElement: JitElement + Int {
18    /// The true value for the boolean element.
19    fn true_val() -> Self {
20        Self::from_int(1)
21    }
22
23    /// The false value for the boolean element.
24    fn false_val() -> Self {
25        Self::from_int(0)
26    }
27
28    /// New bool element from Rust bool.
29    fn new_bool(val: bool) -> Self {
30        match val {
31            true => Self::true_val(),
32            false => Self::false_val(),
33        }
34    }
35}
36
37impl JitElement for u64 {}
38impl JitElement for u32 {}
39impl JitElement for u16 {}
40impl JitElement for u8 {}
41impl JitElement for i64 {}
42impl JitElement for i32 {}
43impl JitElement for i16 {}
44impl JitElement for i8 {}
45impl JitElement for f64 {}
46impl JitElement for f32 {}
47impl JitElement for flex32 {}
48impl JitElement for half::f16 {}
49impl JitElement for half::bf16 {}
50
51impl FloatElement for f64 {}
52impl FloatElement for f32 {}
53impl FloatElement for flex32 {}
54impl FloatElement for half::bf16 {}
55impl FloatElement for half::f16 {}
56impl IntElement for i64 {}
57impl IntElement for i32 {}
58impl IntElement for i16 {}
59impl IntElement for i8 {}
60
61impl BoolElement for u8 {}
62impl BoolElement for u32 {}