// cubecl_core/frontend/element/float.rs

1use half::{bf16, f16};
2
3use crate::{
4    ir::{Elem, FloatKind},
5    prelude::*,
6};
7
8use super::Numeric;
9
10mod relaxed;
11mod tensor_float;
12mod typemap;
13
14pub use relaxed::*;
15pub use tensor_float::*;
16pub use typemap::*;
17
18/// Floating point numbers. Used as input in float kernels
19pub trait Float:
20    Numeric
21    + Exp
22    + Log
23    + Log1p
24    + Cos
25    + Sin
26    + Tanh
27    + Powf
28    + Sqrt
29    + Round
30    + Floor
31    + Ceil
32    + Erf
33    + Recip
34    + Magnitude
35    + Normalize
36    + Dot
37    + Into<Self::ExpandType>
38    + core::ops::Neg<Output = Self>
39    + core::ops::Add<Output = Self>
40    + core::ops::Sub<Output = Self>
41    + core::ops::Mul<Output = Self>
42    + core::ops::Div<Output = Self>
43    + std::ops::AddAssign
44    + std::ops::SubAssign
45    + std::ops::MulAssign
46    + std::ops::DivAssign
47    + std::cmp::PartialOrd
48    + std::cmp::PartialEq
49{
50    const DIGITS: u32;
51    const EPSILON: Self;
52    const INFINITY: Self;
53    const MANTISSA_DIGITS: u32;
54    const MAX_10_EXP: i32;
55    const MAX_EXP: i32;
56    const MIN_10_EXP: i32;
57    const MIN_EXP: i32;
58    const MIN_POSITIVE: Self;
59    const NAN: Self;
60    const NEG_INFINITY: Self;
61    const RADIX: u32;
62
63    fn new(val: f32) -> Self;
64    fn __expand_new(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType {
65        __expand_new(context, val)
66    }
67}
68
// Implements every frontend trait this module provides for a primitive float type.
//
// Accepted forms (tried top to bottom):
// - `impl_float!(half T, Kind)`: `T` is built from an `f32` with `T::from_f32`.
// - `impl_float!(T, Kind)`: `T` is built from an `f32` with an `as` cast.
// - `impl_float!(T, Kind, ctor)`: general form. `ctor` is a callable expression mapping
//   an `f32` to `T`. Both shorthands above forward to this arm.
//
// `Kind` must name a `FloatKind` variant. It selects the element type reported for `T`.
macro_rules! impl_float {
    (half $float:ident, $float_kind:ident) => {
        impl_float!($float, $float_kind, |value| $float::from_f32(value));
    };
    ($float:ident, $float_kind:ident) => {
        impl_float!($float, $float_kind, |value| value as $float);
    };
    ($float:ident, $float_kind:ident, $ctor:expr) => {
        impl CubeType for $float {
            type ExpandType = ExpandElementTyped<$float>;
        }

        impl CubePrimitive for $float {
            /// Element type used to represent this float on the GPU.
            fn as_elem_native() -> Option<Elem> {
                let elem = Elem::Float(FloatKind::$float_kind);
                Some(elem)
            }
        }

        impl Numeric for $float {
            // Lowest / highest finite values, as reported by `num_traits::Float`.
            fn min_value() -> Self {
                <$float as num_traits::Float>::min_value()
            }
            fn max_value() -> Self {
                <$float as num_traits::Float>::max_value()
            }
        }

        impl Float for $float {
            // Each constant is forwarded from the primitive's own associated constant.
            const DIGITS: u32 = $float::DIGITS;
            const EPSILON: Self = $float::EPSILON;
            const INFINITY: Self = $float::INFINITY;
            const MANTISSA_DIGITS: u32 = $float::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $float::MAX_10_EXP;
            const MAX_EXP: i32 = $float::MAX_EXP;
            const MIN_10_EXP: i32 = $float::MIN_10_EXP;
            const MIN_EXP: i32 = $float::MIN_EXP;
            const MIN_POSITIVE: Self = $float::MIN_POSITIVE;
            const NAN: Self = $float::NAN;
            const NEG_INFINITY: Self = $float::NEG_INFINITY;
            const RADIX: u32 = $float::RADIX;

            fn new(value: f32) -> Self {
                $ctor(value)
            }
        }

        impl ExpandElementBaseInit for $float {
            fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
                init_expand_element(context, elem)
            }
        }

        impl IntoRuntime for $float {
            fn __expand_runtime_method(
                self,
                context: &mut CubeContext,
            ) -> ExpandElementTyped<Self> {
                // Convert the plain value into an expand element, then run `Init::init` on it.
                let constant: ExpandElementTyped<Self> = self.into();
                Init::init(constant, context)
            }
        }

        impl LaunchArgExpand for $float {
            type CompilationArg = ();

            fn expand(
                _: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                // Declare a scalar on the builder using this type's element kind.
                let elem = $float::as_elem(&builder.context);
                builder.scalar(elem).into()
            }
        }
    };
}
144
145impl_float!(half f16, F16);
146impl_float!(half bf16, BF16);
147impl_float!(f32, F32);
148impl_float!(f64, F64);
149
150impl ScalarArgSettings for f16 {
151    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
152        settings.register_f16(*self);
153    }
154}
155
156impl ScalarArgSettings for bf16 {
157    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
158        settings.register_bf16(*self);
159    }
160}
161
162impl ScalarArgSettings for f32 {
163    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
164        settings.register_f32(*self);
165    }
166}
167
168impl ScalarArgSettings for f64 {
169    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
170        settings.register_f64(*self);
171    }
172}