// cubecl_core/frontend/element/float.rs

1use cubecl_ir::{ConstantValue, Scope, StorageType, Type};
2use half::{bf16, f16};
3
4use crate::{
5    self as cubecl,
6    ir::{ElemType, FloatKind},
7    prelude::*,
8};
9
10use super::Numeric;
11
12mod fp4;
13mod fp6;
14mod fp8;
15mod relaxed;
16mod tensor_float;
17
/// Floating point numbers. Used as input in float kernels.
///
/// A `Float` type provides every operation trait listed in the supertrait bounds below. It also
/// provides the `min`/`max`/`clamp` helpers from `FloatOps`, negation, comparison, and a set of
/// associated constants named after those on Rust's primitive float types.
///
/// The primitive implementations in this file (`f16`, `bf16`, `f32`, `f64`) forward each
/// constant to the primitive's same-named inherent constant. For example, `Float::EPSILON` for
/// `f32` is `f32::EPSILON`.
pub trait Float:
    // Operation traits brought in through the prelude; each one is defined elsewhere in the crate.
    Numeric
    + FloatOps
    + Exp
    + Log
    + Log1p
    + Cos
    + Sin
    + Tan
    + Tanh
    + Sinh
    + Cosh
    + ArcCos
    + ArcSin
    + ArcTan
    + ArcSinh
    + ArcCosh
    + ArcTanh
    + Degrees
    + Radians
    + ArcTan2
    + Powf
    + Powi<i32>
    + Hypot
    + Rhypot
    + Sqrt
    + InverseSqrt
    + Round
    + Floor
    + Ceil
    + Trunc
    + Erf
    + Recip
    + Magnitude
    + Normalize
    + Dot
    + IsNan
    + IsInf
    // Conversion of a plain value into its expand representation, negation and comparison.
    + Into<Self::ExpandType>
    + core::ops::Neg<Output = Self>
    + core::cmp::PartialOrd
    + core::cmp::PartialEq
{
    /// Approximate number of significant decimal digits (same meaning as `f32::DIGITS`).
    const DIGITS: u32;
    /// Difference between `1.0` and the next larger representable value (same meaning as
    /// `f32::EPSILON`).
    const EPSILON: Self;
    /// Positive infinity.
    const INFINITY: Self;
    /// Number of significant digits in base 2 (same meaning as `f32::MANTISSA_DIGITS`).
    const MANTISSA_DIGITS: u32;
    /// Largest `x` such that `10^x` is a normal number (same meaning as `f32::MAX_10_EXP`).
    const MAX_10_EXP: i32;
    /// Maximum power-of-two exponent, in std's convention (same meaning as `f32::MAX_EXP`).
    const MAX_EXP: i32;
    /// Smallest `x` such that `10^x` is a normal number (same meaning as `f32::MIN_10_EXP`).
    const MIN_10_EXP: i32;
    /// One greater than the smallest normal power-of-two exponent (same meaning as
    /// `f32::MIN_EXP`).
    const MIN_EXP: i32;
    /// Smallest positive normal value (same meaning as `f32::MIN_POSITIVE`).
    const MIN_POSITIVE: Self;
    /// Not a Number.
    const NAN: Self;
    /// Negative infinity.
    const NEG_INFINITY: Self;
    /// Radix (base) of the internal representation (same meaning as `f32::RADIX`).
    const RADIX: u32;

    /// Creates a value of this type from an `f32`.
    ///
    /// The argument is always an `f32`, whatever `Self` is. Wider implementations such as
    /// `f64` therefore only ever receive values already rounded to `f32` precision here.
    fn new(val: f32) -> Self;
    /// Counterpart of [`Float::new`] that produces the expand representation in `scope`.
    ///
    /// Delegates to a free `__expand_new` function that is in scope here but defined outside
    /// this view. NOTE(review): the `__expand_*` naming suggests this is what `#[cube]`-generated
    /// code calls when `new` is used inside a kernel — confirm.
    fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
        __expand_new(scope, val)
    }
}
80
/// `min`, `max` and `clamp`, with default bodies that delegate to the prelude's free functions
/// of the same names.
///
/// Every [`Float`] implements this trait through the blanket impl in this file. For expanded
/// values, the `FloatOpsExpand` impl for `NativeExpand<T>` in this file supplies
/// `__expand_min_method`, `__expand_max_method` and `__expand_clamp_method`. Those forward to
/// `min::expand`, `max::expand` and `clamp::expand`.
/// NOTE(review): `FloatOpsExpand` is not declared in this file — presumably generated by
/// `#[cube]`; confirm.
#[cube]
pub trait FloatOps: CubePrimitive + PartialOrd + Sized {
    /// Delegates to `cubecl::prelude::min(self, other)`.
    fn min(self, other: Self) -> Self {
        // Fully qualified here, unlike the unqualified `clamp` call below.
        // NOTE(review): presumably to make clear the free function (not this method) is meant —
        // confirm whether the qualification is actually required by `#[cube]`.
        cubecl::prelude::min(self, other)
    }

    /// Delegates to `cubecl::prelude::max(self, other)`.
    fn max(self, other: Self) -> Self {
        cubecl::prelude::max(self, other)
    }

    /// Delegates to the free `clamp(self, min, max)` function in scope.
    fn clamp(self, min: Self, max: Self) -> Self {
        // `min` / `max` here are this method's parameters, not the free functions.
        clamp(self, min, max)
    }
}
95
96impl<T: Float> FloatOps for T {}
97impl<T: FloatOps + CubePrimitive> FloatOpsExpand for NativeExpand<T> {
98    fn __expand_min_method(self, scope: &mut Scope, other: Self) -> Self {
99        min::expand(scope, self, other)
100    }
101
102    fn __expand_max_method(self, scope: &mut Scope, other: Self) -> Self {
103        max::expand(scope, self, other)
104    }
105
106    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
107        clamp::expand(scope, self, min, max)
108    }
109}
110
/// Implements the CubeCL scalar / primitive / numeric / float traits for a host float type.
///
/// Three forms are accepted:
/// - `impl_float!(half f16, F16)`: a `half`-crate type, built from an `f64` with `from_f64`.
/// - `impl_float!(f32, F32)`: a native float, built from an `f64` with an `as` cast.
/// - `impl_float!(ty, Kind, conv)`: the general form both of the above expand to. Here `conv`
///   is an expression callable with an `f64` that yields the target type.
///
/// `Kind` names the `FloatKind` variant reported as the type's GPU element type.
macro_rules! impl_float {
    (half $ty:ident, $float_kind:ident) => {
        impl_float!($ty, $float_kind, |x| $ty::from_f64(x));
    };
    ($ty:ident, $float_kind:ident) => {
        impl_float!($ty, $float_kind, |x| x as $ty);
    };
    ($ty:ident, $float_kind:ident, $from_f64:expr) => {
        impl Scalar for $ty {}

        impl CubeType for $ty {
            type ExpandType = NativeExpand<$ty>;
        }

        impl CubePrimitive for $ty {
            type Scalar = Self;
            type Size = Const<1>;
            type WithScalar<S: Scalar> = S;

            /// Element type used for this primitive on the GPU: a scalar float of the given kind.
            fn as_type_native() -> Option<Type> {
                let storage = StorageType::Scalar(ElemType::Float(FloatKind::$float_kind));
                Some(storage.into())
            }

            fn from_const_value(value: ConstantValue) -> Self {
                match value {
                    ConstantValue::Float(x) => $from_f64(x),
                    // Any non-float constant is treated as a caller bug.
                    _ => unreachable!(),
                }
            }
        }

        impl IntoRuntime for $ty {
            fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
                self.into()
            }
        }

        impl IntoMut for $ty {
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl NativeAssign for $ty {}

        impl Numeric for $ty {
            // `num_traits::Float::min_value` is the most negative finite value,
            // not the smallest positive one.
            fn min_value() -> Self {
                <$ty as num_traits::Float>::min_value()
            }

            fn max_value() -> Self {
                <$ty as num_traits::Float>::max_value()
            }
        }

        impl Float for $ty {
            const DIGITS: u32 = $ty::DIGITS;
            const EPSILON: Self = $ty::EPSILON;
            const INFINITY: Self = $ty::INFINITY;
            const MANTISSA_DIGITS: u32 = $ty::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $ty::MAX_10_EXP;
            const MAX_EXP: i32 = $ty::MAX_EXP;
            const MIN_10_EXP: i32 = $ty::MIN_10_EXP;
            const MIN_EXP: i32 = $ty::MIN_EXP;
            const MIN_POSITIVE: Self = $ty::MIN_POSITIVE;
            const NAN: Self = $ty::NAN;
            const NEG_INFINITY: Self = $ty::NEG_INFINITY;
            const RADIX: u32 = $ty::RADIX;

            fn new(val: f32) -> Self {
                // Widening `f32` to `f64` is exact, so only the final conversion can round.
                $from_f64(val as f64)
            }
        }
    };
}
185
186impl_float!(half f16, F16);
187impl_float!(half bf16, BF16);
188impl_float!(f32, F32);
189impl_float!(f64, F64);