// cubecl_core/frontend/element/float.rs

1use cubecl_ir::{ConstantValue, Scope, StorageType};
2use half::{bf16, f16};
3
4use crate::{
5    self as cubecl,
6    ir::{ElemType, ExpandElement, FloatKind},
7    prelude::*,
8};
9
10use super::Numeric;
11
12mod fp4;
13mod fp6;
14mod fp8;
15mod relaxed;
16mod tensor_float;
17mod typemap;
18
19pub use typemap::*;
20
21/// Floating point numbers. Used as input in float kernels
22pub trait Float:
23    Numeric
24    + FloatOps
25    + Exp
26    + Log
27    + Log1p
28    + Cos
29    + Sin
30    + Tan
31    + Tanh
32    + Sinh
33    + Cosh
34    + ArcCos
35    + ArcSin
36    + ArcTan
37    + ArcSinh
38    + ArcCosh
39    + ArcTanh
40    + Degrees
41    + Radians
42    + ArcTan2
43    + Powf
44    + Powi<i32>
45    + Hypot
46    + Rhypot
47    + Sqrt
48    + InverseSqrt
49    + Round
50    + Floor
51    + Ceil
52    + Trunc
53    + Erf
54    + Recip
55    + Magnitude
56    + Normalize
57    + Dot
58    + IsNan
59    + IsInf
60    + Into<Self::ExpandType>
61    + core::ops::Neg<Output = Self>
62    + std::cmp::PartialOrd
63    + std::cmp::PartialEq
64{
65    const DIGITS: u32;
66    const EPSILON: Self;
67    const INFINITY: Self;
68    const MANTISSA_DIGITS: u32;
69    const MAX_10_EXP: i32;
70    const MAX_EXP: i32;
71    const MIN_10_EXP: i32;
72    const MIN_EXP: i32;
73    const MIN_POSITIVE: Self;
74    const NAN: Self;
75    const NEG_INFINITY: Self;
76    const RADIX: u32;
77
78    fn new(val: f32) -> Self;
79    fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
80        __expand_new(scope, val)
81    }
82}
83
/// `min`, `max` and `clamp` helpers for cube scalar types.
///
/// Every default method forwards to the free function of the same name.
/// Every [`Float`] type receives these methods through the blanket impl that
/// follows this trait.
#[cube]
pub trait FloatOps: CubePrimitive + PartialOrd + Sized {
    /// Forwards to `cubecl::prelude::min(self, other)`.
    fn min(self, other: Self) -> Self {
        cubecl::prelude::min(self, other)
    }

    /// Forwards to `cubecl::prelude::max(self, other)`.
    fn max(self, other: Self) -> Self {
        cubecl::prelude::max(self, other)
    }

    /// Forwards to the free `clamp(self, min, max)` function.
    // Presumably restricts `self` to the range bounded by `min` and `max`;
    // the exact semantics live in the free function — confirm there.
    fn clamp(self, min: Self, max: Self) -> Self {
        clamp(self, min, max)
    }
}
98
99impl<T: Float> FloatOps for T {}
100impl<T: FloatOps + CubePrimitive> FloatOpsExpand for ExpandElementTyped<T> {
101    fn __expand_min_method(self, scope: &mut Scope, other: Self) -> Self {
102        min::expand(scope, self, other)
103    }
104
105    fn __expand_max_method(self, scope: &mut Scope, other: Self) -> Self {
106        max::expand(scope, self, other)
107    }
108
109    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
110        clamp::expand(scope, self, min, max)
111    }
112}
113
/// Implements the cube scalar traits for one host float type.
///
/// The traits are `CubeType`, `CubePrimitive`, `IntoRuntime`, `Numeric`,
/// `ExpandElementIntoMut`, `IntoMut` and `Float`.
///
/// Invocation forms:
/// - `impl_float!(half ty, Kind)`: for `half` crate types. Converts from `f64`
///   with `ty::from_f64`.
/// - `impl_float!(ty, Kind)`: for primitive floats. Converts from `f64` with an
///   `as` cast.
/// - `impl_float!(ty, Kind, conv)`: takes an explicit `f64 -> ty` conversion
///   expression `conv`. The two forms above delegate to this one.
///
/// `Kind` is the `FloatKind` variant reported as the element type on the GPU.
macro_rules! impl_float {
    (half $ty:ident, $float_kind:ident) => {
        impl_float!($ty, $float_kind, |v| $ty::from_f64(v));
    };
    ($ty:ident, $float_kind:ident) => {
        impl_float!($ty, $float_kind, |v| v as $ty);
    };
    ($ty:ident, $float_kind:ident, $from_f64:expr) => {
        impl CubeType for $ty {
            type ExpandType = ExpandElementTyped<$ty>;
        }

        impl CubePrimitive for $ty {
            /// Return the element type to use on GPU
            fn as_type_native() -> Option<StorageType> {
                let kind = FloatKind::$float_kind;
                Some(StorageType::Scalar(ElemType::Float(kind)))
            }

            // Only float constants are expected for a float type; any other
            // constant kind is treated as an internal invariant violation.
            fn from_const_value(value: ConstantValue) -> Self {
                match value {
                    ConstantValue::Float(raw) => $from_f64(raw),
                    _ => unreachable!(),
                }
            }
        }

        impl IntoRuntime for $ty {
            // Lift the host constant into an expand element, then let the
            // scope turn it into a runtime value.
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let constant: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, constant).into()
            }
        }

        impl ExpandElementIntoMut for $ty {
            fn elem_into_mut(scope: &mut Scope, element: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, element)
            }
        }

        impl IntoMut for $ty {
            // Host-side values are plain copies; nothing to register.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl Numeric for $ty {
            fn min_value() -> Self {
                <$ty as num_traits::Float>::min_value()
            }
            fn max_value() -> Self {
                <$ty as num_traits::Float>::max_value()
            }
        }

        impl Float for $ty {
            const DIGITS: u32 = $ty::DIGITS;
            const EPSILON: Self = $ty::EPSILON;
            const INFINITY: Self = $ty::INFINITY;
            const MANTISSA_DIGITS: u32 = $ty::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $ty::MAX_10_EXP;
            const MAX_EXP: i32 = $ty::MAX_EXP;
            const MIN_10_EXP: i32 = $ty::MIN_10_EXP;
            const MIN_EXP: i32 = $ty::MIN_EXP;
            const MIN_POSITIVE: Self = $ty::MIN_POSITIVE;
            const NAN: Self = $ty::NAN;
            const NEG_INFINITY: Self = $ty::NEG_INFINITY;
            const RADIX: u32 = $ty::RADIX;

            // Widen to f64 (lossless) so one conversion expression serves
            // both constants and `new`.
            fn new(val: f32) -> Self {
                $from_f64(f64::from(val))
            }
        }
    };
}
188
189impl_float!(half f16, F16);
190impl_float!(half bf16, BF16);
191impl_float!(f32, F32);
192impl_float!(f64, F64);