cubecl_core/frontend/element/float.rs

1use cubecl_ir::Scope;
2use half::{bf16, f16};
3
4use crate::{
5    ir::{Elem, ExpandElement, FloatKind},
6    prelude::*,
7};
8
9use super::Numeric;
10
11mod relaxed;
12mod tensor_float;
13mod typemap;
14
15pub use typemap::*;
16
17/// Floating point numbers. Used as input in float kernels
18pub trait Float:
19    Numeric
20    + Exp
21    + Log
22    + Log1p
23    + Cos
24    + Sin
25    + Tanh
26    + Powf
27    + Sqrt
28    + Round
29    + Floor
30    + Ceil
31    + Erf
32    + Recip
33    + Magnitude
34    + Normalize
35    + Dot
36    + Into<Self::ExpandType>
37    + core::ops::Neg<Output = Self>
38    + core::ops::Add<Output = Self>
39    + core::ops::Sub<Output = Self>
40    + core::ops::Mul<Output = Self>
41    + core::ops::Div<Output = Self>
42    + std::ops::AddAssign
43    + std::ops::SubAssign
44    + std::ops::MulAssign
45    + std::ops::DivAssign
46    + std::cmp::PartialOrd
47    + std::cmp::PartialEq
48{
49    const DIGITS: u32;
50    const EPSILON: Self;
51    const INFINITY: Self;
52    const MANTISSA_DIGITS: u32;
53    const MAX_10_EXP: i32;
54    const MAX_EXP: i32;
55    const MIN_10_EXP: i32;
56    const MIN_EXP: i32;
57    const MIN_POSITIVE: Self;
58    const NAN: Self;
59    const NEG_INFINITY: Self;
60    const RADIX: u32;
61
62    fn new(val: f32) -> Self;
63    fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
64        __expand_new(scope, val)
65    }
66}
67
/// Implements the cube-side traits (`CubeType`, `CubePrimitive`, `IntoRuntime`, `Numeric`,
/// `ExpandElementBaseInit`, `Float`, `LaunchArgExpand`) for a floating-point primitive.
///
/// Accepted forms:
/// - `impl_float!(half ty, Kind)`: a `half`-crate type built from `f32` with `ty::from_f32`.
/// - `impl_float!(ty, Kind)`: a native float built from `f32` with an `as` cast.
/// - `impl_float!(ty, Kind, ctor)`: an explicit `f32 -> ty` constructor expression.
macro_rules! impl_float {
    (half $ty:ident, $kind:ident) => {
        impl_float!($ty, $kind, |x| $ty::from_f32(x));
    };
    ($ty:ident, $kind:ident) => {
        impl_float!($ty, $kind, |x| x as $ty);
    };
    ($ty:ident, $kind:ident, $ctor:expr) => {
        impl CubeType for $ty {
            type ExpandType = ExpandElementTyped<$ty>;
        }

        impl CubePrimitive for $ty {
            /// Returns the element type used on the GPU for this primitive.
            fn as_elem_native() -> Option<Elem> {
                let kind = FloatKind::$kind;
                Some(Elem::Float(kind))
            }
        }

        impl IntoRuntime for $ty {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                Init::init(Into::<ExpandElementTyped<Self>>::into(self), scope)
            }
        }

        impl Numeric for $ty {
            fn min_value() -> Self {
                <$ty as num_traits::Float>::min_value()
            }
            fn max_value() -> Self {
                <$ty as num_traits::Float>::max_value()
            }
        }

        impl ExpandElementBaseInit for $ty {
            fn init_elem(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                init_expand_element(scope, elem)
            }
        }

        impl Float for $ty {
            // Every limit is taken verbatim from the primitive's own associated constant.
            const DIGITS: u32 = $ty::DIGITS;
            const EPSILON: Self = $ty::EPSILON;
            const INFINITY: Self = $ty::INFINITY;
            const MANTISSA_DIGITS: u32 = $ty::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $ty::MAX_10_EXP;
            const MAX_EXP: i32 = $ty::MAX_EXP;
            const MIN_10_EXP: i32 = $ty::MIN_10_EXP;
            const MIN_EXP: i32 = $ty::MIN_EXP;
            const MIN_POSITIVE: Self = $ty::MIN_POSITIVE;
            const NAN: Self = $ty::NAN;
            const NEG_INFINITY: Self = $ty::NEG_INFINITY;
            const RADIX: u32 = $ty::RADIX;

            fn new(val: f32) -> Self {
                ($ctor)(val)
            }
        }

        impl LaunchArgExpand for $ty {
            type CompilationArg = ();

            fn expand(
                _: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                // Resolve the element type first so the shared borrow of `builder.context`
                // ends before `builder` is borrowed mutably.
                let elem = $ty::as_elem(&builder.context);
                builder.scalar(elem).into()
            }
        }
    };
}
140
141impl_float!(half f16, F16);
142impl_float!(half bf16, BF16);
143impl_float!(f32, F32);
144impl_float!(f64, F64);
145
146impl ScalarArgSettings for f16 {
147    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
148        settings.register_f16(*self);
149    }
150}
151
152impl ScalarArgSettings for bf16 {
153    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
154        settings.register_bf16(*self);
155    }
156}
157
158impl ScalarArgSettings for f32 {
159    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
160        settings.register_f32(*self);
161    }
162}
163
164impl ScalarArgSettings for f64 {
165    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
166        settings.register_f64(*self);
167    }
168}