cubecl_core/frontend/element/float.rs

1use cubecl_ir::Scope;
2use half::{bf16, f16};
3
4use crate::{
5    ir::{Elem, ExpandElement, FloatKind},
6    prelude::*,
7};
8
9use super::Numeric;
10
11mod fp4;
12mod fp6;
13mod fp8;
14mod relaxed;
15mod tensor_float;
16mod typemap;
17
18pub use typemap::*;
19
20/// Floating point numbers. Used as input in float kernels
21pub trait Float:
22    Numeric
23    + Exp
24    + Log
25    + Log1p
26    + Cos
27    + Sin
28    + Tanh
29    + Powf
30    + Sqrt
31    + Round
32    + Floor
33    + Ceil
34    + Erf
35    + Recip
36    + Magnitude
37    + Normalize
38    + Dot
39    + Into<Self::ExpandType>
40    + core::ops::Neg<Output = Self>
41    + core::ops::Add<Output = Self>
42    + core::ops::Sub<Output = Self>
43    + core::ops::Mul<Output = Self>
44    + core::ops::Div<Output = Self>
45    + std::ops::AddAssign
46    + std::ops::SubAssign
47    + std::ops::MulAssign
48    + std::ops::DivAssign
49    + std::cmp::PartialOrd
50    + std::cmp::PartialEq
51{
52    const DIGITS: u32;
53    const EPSILON: Self;
54    const INFINITY: Self;
55    const MANTISSA_DIGITS: u32;
56    const MAX_10_EXP: i32;
57    const MAX_EXP: i32;
58    const MIN_10_EXP: i32;
59    const MIN_EXP: i32;
60    const MIN_POSITIVE: Self;
61    const NAN: Self;
62    const NEG_INFINITY: Self;
63    const RADIX: u32;
64
65    fn new(val: f32) -> Self;
66    fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
67        __expand_new(scope, val)
68    }
69}
70
/// Generates every CubeCL trait impl needed to use a primitive as a [`Float`]:
/// `CubeType`, `CubePrimitive`, `IntoRuntime`, `Numeric`, `ExpandElementIntoMut`,
/// `IntoMut`, `Float` and `LaunchArgExpand`.
///
/// Accepted forms:
/// - `impl_float!(half ty, Kind)`: `Float::new` converts with `ty::from_f32`.
/// - `impl_float!(ty, Kind)`: `Float::new` converts with an `as` cast.
/// - `impl_float!(ty, Kind, convert)`: `Float::new` calls the given `f32 -> ty` callable.
///
/// `Kind` is the `FloatKind` variant reported as the native element type.
macro_rules! impl_float {
    (half $ty:ident, $float_kind:ident) => {
        impl_float!($ty, $float_kind, $ty::from_f32);
    };
    ($ty:ident, $float_kind:ident) => {
        impl_float!($ty, $float_kind, |value: f32| value as $ty);
    };
    ($ty:ident, $float_kind:ident, $convert:expr) => {
        impl CubeType for $ty {
            type ExpandType = ExpandElementTyped<Self>;
        }

        impl CubePrimitive for $ty {
            /// Native element type: `Elem::Float` tagged with this type's `FloatKind`.
            fn as_elem_native() -> Option<Elem> {
                let kind = FloatKind::$float_kind;
                Some(Elem::Float(kind))
            }
        }

        impl IntoRuntime for $ty {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let typed: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, typed).into()
            }
        }

        impl Numeric for $ty {
            fn min_value() -> Self {
                // Most negative finite value, as defined by `num_traits`.
                <Self as num_traits::Float>::min_value()
            }

            fn max_value() -> Self {
                <Self as num_traits::Float>::max_value()
            }
        }

        impl ExpandElementIntoMut for $ty {
            fn elem_into_mut(scope: &mut Scope, element: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, element)
            }
        }

        impl IntoMut for $ty {
            // A plain host-side value is already usable as-is; nothing to register.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl Float for $ty {
            const DIGITS: u32 = $ty::DIGITS;
            const EPSILON: Self = $ty::EPSILON;
            const INFINITY: Self = $ty::INFINITY;
            const MANTISSA_DIGITS: u32 = $ty::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $ty::MAX_10_EXP;
            const MAX_EXP: i32 = $ty::MAX_EXP;
            const MIN_10_EXP: i32 = $ty::MIN_10_EXP;
            const MIN_EXP: i32 = $ty::MIN_EXP;
            const MIN_POSITIVE: Self = $ty::MIN_POSITIVE;
            const NAN: Self = $ty::NAN;
            const NEG_INFINITY: Self = $ty::NEG_INFINITY;
            const RADIX: u32 = $ty::RADIX;

            fn new(val: f32) -> Self {
                ($convert)(val)
            }
        }

        impl LaunchArgExpand for $ty {
            type CompilationArg = ();

            fn expand(
                _: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                let elem = $ty::as_elem(&builder.scope);
                builder.scalar(elem).into()
            }
        }
    };
}
149
150impl_float!(half f16, F16);
151impl_float!(half bf16, BF16);
152impl_float!(f32, F32);
153impl_float!(f64, F64);
154
155impl ScalarArgSettings for f16 {
156    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
157        settings.register_f16(*self);
158    }
159}
160
161impl ScalarArgSettings for bf16 {
162    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
163        settings.register_bf16(*self);
164    }
165}
166
167impl ScalarArgSettings for f32 {
168    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
169        settings.register_f32(*self);
170    }
171}
172
173impl ScalarArgSettings for f64 {
174    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
175        settings.register_f64(*self);
176    }
177}