Skip to main content

cubek_reduce/components/
precision.rs

1use cubecl::{
2    define_size, flex32,
3    prelude::{Numeric, Size},
4};
5
6define_size!(pub S);
7
8/// Precision used for the reduction.
9pub trait ReducePrecision: 'static {
10    /// Precision used for the input tensor.
11    type EI: Numeric;
12    /// Size used for the input tensor.
13    type SI: Size;
14    /// Precision used for the accumulation.
15    type EA: Numeric;
16}
17
18impl<EI: Numeric, SI: Size, EA: Numeric> ReducePrecision for (EI, SI, EA) {
19    type EI = EI;
20    type SI = SI;
21    type EA = EA;
22}
23
// The implementations below are suggested defaults for reductions that can accumulate precision
// errors, such as summations: narrow input types (f16, bf16, i8, i16, u8, u16) accumulate in a
// wider 32-bit type.
26
27impl ReducePrecision for f64 {
28    type EI = f64;
29    type EA = f64;
30    type SI = S;
31}
32
33impl ReducePrecision for f32 {
34    type EI = f32;
35    type EA = f32;
36    type SI = S;
37}
38
39impl ReducePrecision for flex32 {
40    type EI = f32;
41    type EA = f32;
42    type SI = S;
43}
44
45impl ReducePrecision for half::f16 {
46    type EI = half::f16;
47    type EA = f32;
48    type SI = S;
49}
50
51impl ReducePrecision for half::bf16 {
52    type EI = half::bf16;
53    type EA = f32;
54    type SI = S;
55}
56
57impl ReducePrecision for i64 {
58    type EI = i64;
59    type EA = i64;
60    type SI = S;
61}
62
63impl ReducePrecision for i32 {
64    type EI = i32;
65    type EA = i32;
66    type SI = S;
67}
68
69impl ReducePrecision for i16 {
70    type EI = i16;
71    type EA = i32;
72    type SI = S;
73}
74
75impl ReducePrecision for i8 {
76    type EI = i8;
77    type EA = i32;
78    type SI = S;
79}
80
81impl ReducePrecision for u64 {
82    type EI = u64;
83    type EA = u64;
84    type SI = S;
85}
86
87impl ReducePrecision for u32 {
88    type EI = u32;
89    type EA = u32;
90    type SI = S;
91}
92
93impl ReducePrecision for u16 {
94    type EI = u16;
95    type EA = u32;
96    type SI = S;
97}
98
99impl ReducePrecision for u8 {
100    type EI = u8;
101    type EA = u32;
102    type SI = S;
103}