cubecl_core/frontend/element/float.rs

1use cubecl_ir::{ConstantScalarValue, Scope, StorageType};
2use half::{bf16, f16};
3
4use crate::{
5    ir::{ElemType, ExpandElement, FloatKind},
6    prelude::*,
7};
8
9use super::Numeric;
10
11mod fp4;
12mod fp6;
13mod fp8;
14mod relaxed;
15mod tensor_float;
16mod typemap;
17
18pub use typemap::*;
19
20/// Floating point numbers. Used as input in float kernels
21pub trait Float:
22    Numeric
23    + Exp
24    + Log
25    + Log1p
26    + Cos
27    + Sin
28    + Tanh
29    + Powf
30    + Powi<i32>
31    + Sqrt
32    + InverseSqrt
33    + Round
34    + Floor
35    + Ceil
36    + Trunc
37    + Erf
38    + Recip
39    + Magnitude
40    + Normalize
41    + Dot
42    + IsNan
43    + IsInf
44    + Into<Self::ExpandType>
45    + core::ops::Neg<Output = Self>
46    + core::ops::Add<Output = Self>
47    + core::ops::Sub<Output = Self>
48    + core::ops::Mul<Output = Self>
49    + core::ops::Div<Output = Self>
50    + std::ops::AddAssign
51    + std::ops::SubAssign
52    + std::ops::MulAssign
53    + std::ops::DivAssign
54    + std::cmp::PartialOrd
55    + std::cmp::PartialEq
56{
57    const DIGITS: u32;
58    const EPSILON: Self;
59    const INFINITY: Self;
60    const MANTISSA_DIGITS: u32;
61    const MAX_10_EXP: i32;
62    const MAX_EXP: i32;
63    const MIN_10_EXP: i32;
64    const MIN_EXP: i32;
65    const MIN_POSITIVE: Self;
66    const NAN: Self;
67    const NEG_INFINITY: Self;
68    const RADIX: u32;
69
70    fn new(val: f32) -> Self;
71    fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
72        __expand_new(scope, val)
73    }
74}
75
/// Implements the CubeCL scalar traits for a host floating-point type.
///
/// The traits implemented are `CubeType`, `CubePrimitive`, `IntoRuntime`,
/// `Numeric`, `ExpandElementIntoMut`, `IntoMut` and `Float`.
///
/// Accepted forms:
/// - `impl_float!(half T, Kind)`: `half` crate types, built from `f64` via `T::from_f64`.
/// - `impl_float!(T, Kind)`: native floats, built from `f64` with an `as` cast.
/// - `impl_float!(T, Kind, conv)`: `conv` is any callable expression turning an
///   `f64` into `T`.
///
/// `Kind` is the `FloatKind` variant reported as the GPU storage type.
macro_rules! impl_float {
    (half $float:ident, $kind:ident) => {
        impl_float!($float, $kind, |input: f64| $float::from_f64(input));
    };
    ($float:ident, $kind:ident) => {
        impl_float!($float, $kind, |input: f64| input as $float);
    };
    ($float:ident, $kind:ident, $from_f64:expr) => {
        impl CubeType for $float {
            type ExpandType = ExpandElementTyped<$float>;
        }

        impl CubePrimitive for $float {
            /// Return the element type to use on GPU
            fn as_type_native() -> Option<StorageType> {
                let elem = ElemType::Float(FloatKind::$kind);
                Some(StorageType::Scalar(elem))
            }

            // Only float constants are expected here; any other variant panics.
            fn from_const_value(value: ConstantScalarValue) -> Self {
                match value {
                    ConstantScalarValue::Float(raw_value, _) => $from_f64(raw_value),
                    _ => unreachable!(),
                }
            }
        }

        impl IntoRuntime for $float {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let typed: ExpandElementTyped<Self> = self.into();
                let runtime = into_runtime_expand_element(scope, typed);
                runtime.into()
            }
        }

        // Bounds come from `num_traits::Float`: the most negative and the most
        // positive finite values.
        impl Numeric for $float {
            fn min_value() -> Self {
                <$float as num_traits::Float>::min_value()
            }
            fn max_value() -> Self {
                <$float as num_traits::Float>::max_value()
            }
        }

        impl ExpandElementIntoMut for $float {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        // A plain host scalar is already a value, so no scope work is needed.
        impl IntoMut for $float {
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl Float for $float {
            const DIGITS: u32 = $float::DIGITS;
            const EPSILON: Self = $float::EPSILON;
            const INFINITY: Self = $float::INFINITY;
            const MANTISSA_DIGITS: u32 = $float::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $float::MAX_10_EXP;
            const MAX_EXP: i32 = $float::MAX_EXP;
            const MIN_10_EXP: i32 = $float::MIN_10_EXP;
            const MIN_EXP: i32 = $float::MIN_EXP;
            const MIN_POSITIVE: Self = $float::MIN_POSITIVE;
            const NAN: Self = $float::NAN;
            const NEG_INFINITY: Self = $float::NEG_INFINITY;
            const RADIX: u32 = $float::RADIX;

            // Widen losslessly to f64, then reuse the shared f64 -> Self conversion.
            fn new(val: f32) -> Self {
                $from_f64(f64::from(val))
            }
        }
    };
}
150
151impl_float!(half f16, F16);
152impl_float!(half bf16, BF16);
153impl_float!(f32, F32);
154impl_float!(f64, F64);
155
156impl ScalarArgSettings for f16 {
157    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
158        settings.register_f16(*self);
159    }
160}
161
162impl ScalarArgSettings for bf16 {
163    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
164        settings.register_bf16(*self);
165    }
166}
167
168impl ScalarArgSettings for f32 {
169    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
170        settings.register_f32(*self);
171    }
172}
173
174impl ScalarArgSettings for f64 {
175    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
176        settings.register_f64(*self);
177    }
178}