cubecl_reduce/
precision.rs

1use cubecl_core::{flex32, prelude::Numeric};
2
3/// Precision used for the reduction.
4pub trait ReducePrecision {
5    /// Precision used for the input tensor.
6    type EI: Numeric;
7    /// Precision used for the accumulation.
8    type EA: Numeric;
9}
10
11impl<EI: Numeric, EA: Numeric> ReducePrecision for (EI, EA) {
12    type EI = EI;
13    type EA = EA;
14}
15
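// The implementations below are suggested precisions for reductions that can accumulate
// precision errors, such as summations. Narrow types (`f16`, `bf16`, 8- and 16-bit integers)
// are accumulated in a wider 32-bit type.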
18
19impl ReducePrecision for f64 {
20    type EI = f64;
21    type EA = f64;
22}
23
24impl ReducePrecision for f32 {
25    type EI = f32;
26    type EA = f32;
27}
28
29impl ReducePrecision for flex32 {
30    type EI = f32;
31    type EA = f32;
32}
33
34impl ReducePrecision for half::f16 {
35    type EI = half::f16;
36    type EA = f32;
37}
38
39impl ReducePrecision for half::bf16 {
40    type EI = half::bf16;
41    type EA = f32;
42}
43
44impl ReducePrecision for i64 {
45    type EI = i64;
46    type EA = i64;
47}
48
49impl ReducePrecision for i32 {
50    type EI = i32;
51    type EA = i32;
52}
53
54impl ReducePrecision for i16 {
55    type EI = i16;
56    type EA = i32;
57}
58
59impl ReducePrecision for i8 {
60    type EI = i8;
61    type EA = i32;
62}
63
64impl ReducePrecision for u64 {
65    type EI = u64;
66    type EA = u64;
67}
68
69impl ReducePrecision for u32 {
70    type EI = u32;
71    type EA = u32;
72}
73
74impl ReducePrecision for u16 {
75    type EI = u16;
76    type EA = u32;
77}
78
79impl ReducePrecision for u8 {
80    type EI = u8;
81    type EA = u32;
82}