cubecl_core/frontend/element/float.rs

1use cubecl_ir::{ConstantScalarValue, Scope, StorageType};
2use half::{bf16, f16};
3
4use crate::{
5    ir::{ElemType, ExpandElement, FloatKind},
6    prelude::*,
7};
8
9use super::Numeric;
10
11mod fp4;
12mod fp6;
13mod fp8;
14mod relaxed;
15mod tensor_float;
16mod typemap;
17
18pub use typemap::*;
19
20/// Floating point numbers. Used as input in float kernels
21pub trait Float:
22    Numeric
23    + Exp
24    + Log
25    + Log1p
26    + Cos
27    + Sin
28    + Tan
29    + Tanh
30    + Sinh
31    + Cosh
32    + ArcCos
33    + ArcSin
34    + ArcTan
35    + ArcSinh
36    + ArcCosh
37    + ArcTanh
38    + Degrees
39    + Radians
40    + ArcTan2
41    + Powf
42    + Powi<i32>
43    + Sqrt
44    + InverseSqrt
45    + Round
46    + Floor
47    + Ceil
48    + Trunc
49    + Erf
50    + Recip
51    + Magnitude
52    + Normalize
53    + Dot
54    + IsNan
55    + IsInf
56    + Into<Self::ExpandType>
57    + core::ops::Neg<Output = Self>
58    + core::ops::Add<Output = Self>
59    + core::ops::Sub<Output = Self>
60    + core::ops::Mul<Output = Self>
61    + core::ops::Div<Output = Self>
62    + std::ops::AddAssign
63    + std::ops::SubAssign
64    + std::ops::MulAssign
65    + std::ops::DivAssign
66    + std::cmp::PartialOrd
67    + std::cmp::PartialEq
68{
69    const DIGITS: u32;
70    const EPSILON: Self;
71    const INFINITY: Self;
72    const MANTISSA_DIGITS: u32;
73    const MAX_10_EXP: i32;
74    const MAX_EXP: i32;
75    const MIN_10_EXP: i32;
76    const MIN_EXP: i32;
77    const MIN_POSITIVE: Self;
78    const NAN: Self;
79    const NEG_INFINITY: Self;
80    const RADIX: u32;
81
82    fn new(val: f32) -> Self;
83    fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
84        __expand_new(scope, val)
85    }
86}
87
// Implements the CubeCL element traits for a host floating-point type:
// `CubeType`, `CubePrimitive`, `IntoRuntime`, `Numeric`,
// `ExpandElementIntoMut`, `IntoMut` and `Float`.
//
// Accepted forms:
// - `impl_float!(half ty, Kind)`: for `half` crate types; host values are
//   built from an `f64` with `ty::from_f64`.
// - `impl_float!(ty, Kind)`: for native floats; host values are built from an
//   `f64` with an `as` cast.
// - `impl_float!(ty, Kind, ctor)`: general form; `ctor` must be callable as
//   `ctor(f64) -> ty`.
//
// `Kind` is the `FloatKind` variant reported as the type's GPU storage type.
macro_rules! impl_float {
    (half $primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |val| $primitive::from_f64(val));
    };
    ($primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |val| val as $primitive);
    };
    ($primitive:ident, $kind:ident, $new:expr) => {
        impl CubeType for $primitive {
            type ExpandType = ExpandElementTyped<$primitive>;
        }

        impl CubePrimitive for $primitive {
            /// Return the element type to use on GPU: a scalar of the
            /// `FloatKind` variant this type was registered with.
            fn as_type_native() -> Option<StorageType> {
                Some(StorageType::Scalar(ElemType::Float(FloatKind::$kind)))
            }

            /// Converts a float IR constant into this host type using the
            /// constructor supplied to `impl_float!`.
            ///
            /// # Panics
            ///
            /// Panics if `value` is not `ConstantScalarValue::Float`.
            fn from_const_value(value: ConstantScalarValue) -> Self {
                let ConstantScalarValue::Float(value, _) = value else {
                    // Name the target type so a contract violation is
                    // diagnosable from the panic message alone.
                    unreachable!(
                        "expected a float constant when converting to `{}`",
                        stringify!($primitive)
                    )
                };
                $new(value)
            }
        }

        impl IntoRuntime for $primitive {
            /// Turns a host constant into a runtime expand element in `scope`.
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let elem: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, elem).into()
            }
        }

        impl Numeric for $primitive {
            /// Most negative finite value (as defined by `num_traits::Float`).
            fn min_value() -> Self {
                <Self as num_traits::Float>::min_value()
            }
            /// Largest finite value (as defined by `num_traits::Float`).
            fn max_value() -> Self {
                <Self as num_traits::Float>::max_value()
            }
        }

        impl ExpandElementIntoMut for $primitive {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        impl IntoMut for $primitive {
            /// Host values are plain copies, so no scope work is needed.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl Float for $primitive {
            // Forward every constant to the primitive type's own definition.
            const DIGITS: u32 = $primitive::DIGITS;
            const EPSILON: Self = $primitive::EPSILON;
            const INFINITY: Self = $primitive::INFINITY;
            const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $primitive::MAX_10_EXP;
            const MAX_EXP: i32 = $primitive::MAX_EXP;
            const MIN_10_EXP: i32 = $primitive::MIN_10_EXP;
            const MIN_EXP: i32 = $primitive::MIN_EXP;
            const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE;
            const NAN: Self = $primitive::NAN;
            const NEG_INFINITY: Self = $primitive::NEG_INFINITY;
            const RADIX: u32 = $primitive::RADIX;

            /// Widens to `f64` (exact) and then applies the type's constructor,
            /// so rounding happens at most once.
            fn new(val: f32) -> Self {
                $new(val as f64)
            }
        }
    };
}
162
163impl_float!(half f16, F16);
164impl_float!(half bf16, BF16);
165impl_float!(f32, F32);
166impl_float!(f64, F64);
167
168impl ScalarArgSettings for f16 {
169    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
170        settings.register_f16(*self);
171    }
172}
173
174impl ScalarArgSettings for bf16 {
175    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
176        settings.register_bf16(*self);
177    }
178}
179
180impl ScalarArgSettings for f32 {
181    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
182        settings.register_f32(*self);
183    }
184}
185
186impl ScalarArgSettings for f64 {
187    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
188        settings.register_f64(*self);
189    }
190}