cubecl_core/frontend/element/
float.rs1use cubecl_ir::{ConstantScalarValue, Scope, StorageType};
2use half::{bf16, f16};
3
4use crate::{
5 ir::{ElemType, ExpandElement, FloatKind},
6 prelude::*,
7};
8
9use super::Numeric;
10
11mod fp4;
12mod fp6;
13mod fp8;
14mod relaxed;
15mod tensor_float;
16mod typemap;
17
18pub use typemap::*;
19
20pub trait Float:
22 Numeric
23 + Exp
24 + Log
25 + Log1p
26 + Cos
27 + Sin
28 + Tan
29 + Tanh
30 + Sinh
31 + Cosh
32 + ArcCos
33 + ArcSin
34 + ArcTan
35 + ArcSinh
36 + ArcCosh
37 + ArcTanh
38 + Degrees
39 + Radians
40 + ArcTan2
41 + Powf
42 + Powi<i32>
43 + Sqrt
44 + InverseSqrt
45 + Round
46 + Floor
47 + Ceil
48 + Trunc
49 + Erf
50 + Recip
51 + Magnitude
52 + Normalize
53 + Dot
54 + IsNan
55 + IsInf
56 + Into<Self::ExpandType>
57 + core::ops::Neg<Output = Self>
58 + core::ops::Add<Output = Self>
59 + core::ops::Sub<Output = Self>
60 + core::ops::Mul<Output = Self>
61 + core::ops::Div<Output = Self>
62 + std::ops::AddAssign
63 + std::ops::SubAssign
64 + std::ops::MulAssign
65 + std::ops::DivAssign
66 + std::cmp::PartialOrd
67 + std::cmp::PartialEq
68{
69 const DIGITS: u32;
70 const EPSILON: Self;
71 const INFINITY: Self;
72 const MANTISSA_DIGITS: u32;
73 const MAX_10_EXP: i32;
74 const MAX_EXP: i32;
75 const MIN_10_EXP: i32;
76 const MIN_EXP: i32;
77 const MIN_POSITIVE: Self;
78 const NAN: Self;
79 const NEG_INFINITY: Self;
80 const RADIX: u32;
81
82 fn new(val: f32) -> Self;
83 fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
84 __expand_new(scope, val)
85 }
86}
87
// Generates every impl a host float type needs to act as a cube primitive
// and as a `Float`.
//
// Invocation forms:
// - `impl_float!(half ty, Kind)`: types from the `half` crate; values are
//   built from `f64` via `ty::from_f64`.
// - `impl_float!(ty, Kind)`: native float types; values are built from `f64`
//   with an `as` cast.
// - `impl_float!(ty, Kind, convert)`: general form; `convert` is any callable
//   expression mapping an `f64` to `ty`.
//
// `Kind` names the `FloatKind` variant reported as the type's storage type.
macro_rules! impl_float {
    (half $ty:ident, $float_kind:ident) => {
        impl_float!($ty, $float_kind, |value| $ty::from_f64(value));
    };
    ($ty:ident, $float_kind:ident) => {
        impl_float!($ty, $float_kind, |value| value as $ty);
    };
    ($ty:ident, $float_kind:ident, $convert:expr) => {
        impl CubeType for $ty {
            type ExpandType = ExpandElementTyped<$ty>;
        }

        impl CubePrimitive for $ty {
            fn as_type_native() -> Option<StorageType> {
                let elem = ElemType::Float(FloatKind::$float_kind);
                Some(StorageType::Scalar(elem))
            }

            // Only float constants are expected here; any other variant is
            // treated as an impossible state.
            fn from_const_value(value: ConstantScalarValue) -> Self {
                match value {
                    ConstantScalarValue::Float(raw, _) => $convert(raw),
                    _ => unreachable!(),
                }
            }
        }

        impl IntoRuntime for $ty {
            // Wrap the host value as a typed expand element, then pass it to
            // `into_runtime_expand_element` within `scope`.
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let constant: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, constant).into()
            }
        }

        impl Numeric for $ty {
            // Extreme finite values, as reported by `num_traits::Float`.
            fn min_value() -> Self {
                <$ty as num_traits::Float>::min_value()
            }
            fn max_value() -> Self {
                <$ty as num_traits::Float>::max_value()
            }
        }

        impl ExpandElementIntoMut for $ty {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        impl IntoMut for $ty {
            // The host value is returned unchanged; the scope is not used.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl Float for $ty {
            // Every constant is forwarded from the type's own inherent constant.
            const DIGITS: u32 = $ty::DIGITS;
            const EPSILON: Self = $ty::EPSILON;
            const INFINITY: Self = $ty::INFINITY;
            const MANTISSA_DIGITS: u32 = $ty::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $ty::MAX_10_EXP;
            const MAX_EXP: i32 = $ty::MAX_EXP;
            const MIN_10_EXP: i32 = $ty::MIN_10_EXP;
            const MIN_EXP: i32 = $ty::MIN_EXP;
            const MIN_POSITIVE: Self = $ty::MIN_POSITIVE;
            const NAN: Self = $ty::NAN;
            const NEG_INFINITY: Self = $ty::NEG_INFINITY;
            const RADIX: u32 = $ty::RADIX;

            // Widen to `f64` (exact for every `f32`) and reuse the same
            // conversion as `from_const_value`.
            fn new(val: f32) -> Self {
                let widened = f64::from(val);
                $convert(widened)
            }
        }
    };
}
162
163impl_float!(half f16, F16);
164impl_float!(half bf16, BF16);
165impl_float!(f32, F32);
166impl_float!(f64, F64);
167
168impl ScalarArgSettings for f16 {
169 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
170 settings.register_f16(*self);
171 }
172}
173
174impl ScalarArgSettings for bf16 {
175 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
176 settings.register_bf16(*self);
177 }
178}
179
180impl ScalarArgSettings for f32 {
181 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
182 settings.register_f32(*self);
183 }
184}
185
186impl ScalarArgSettings for f64 {
187 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
188 settings.register_f64(*self);
189 }
190}