1use cubecl::{
2 CubeElement as CubeElem, flex32,
3 linalg::matmul::components::MatmulPrecision,
4 prelude::{Float, Int, Numeric},
5};
6
7pub trait CubeElement: burn_tensor::Element + CubeElem + PartialEq + Numeric {}
9
10pub trait FloatElement:
12 CubeElement
13 + Float
14 + MatmulPrecision<EI: CubeElement, EO: CubeElement, EA: CubeElement, ES: CubeElement>
15{
16}
17
18pub trait IntElement: CubeElement + Int {}
20
21pub trait BoolElement: CubeElement + Int {
23 fn true_val() -> Self {
25 Self::from_int(1)
26 }
27
28 fn false_val() -> Self {
30 Self::from_int(0)
31 }
32
33 fn new_bool(val: bool) -> Self {
35 match val {
36 true => Self::true_val(),
37 false => Self::false_val(),
38 }
39 }
40}
41
42impl CubeElement for u64 {}
43impl CubeElement for u32 {}
44impl CubeElement for u16 {}
45impl CubeElement for u8 {}
46impl CubeElement for i64 {}
47impl CubeElement for i32 {}
48impl CubeElement for i16 {}
49impl CubeElement for i8 {}
50impl CubeElement for f64 {}
51impl CubeElement for f32 {}
52impl CubeElement for flex32 {}
53impl CubeElement for half::f16 {}
54impl CubeElement for half::bf16 {}
55
56impl FloatElement for f64 {}
57impl FloatElement for f32 {}
58impl FloatElement for flex32 {}
59impl FloatElement for half::bf16 {}
60impl FloatElement for half::f16 {}
61impl IntElement for i64 {}
62impl IntElement for i32 {}
63impl IntElement for i16 {}
64impl IntElement for i8 {}
65impl IntElement for u32 {}
66
67impl BoolElement for u8 {}
68impl BoolElement for u32 {}