cubecl_core/frontend/element/
float.rs1use cubecl_ir::{ConstantScalarValue, Scope, StorageType};
2use half::{bf16, f16};
3
4use crate::{
5 ir::{ElemType, ExpandElement, FloatKind},
6 prelude::*,
7};
8
9use super::Numeric;
10
11mod fp4;
12mod fp6;
13mod fp8;
14mod relaxed;
15mod tensor_float;
16mod typemap;
17
18pub use typemap::*;
19
20pub trait Float:
22 Numeric
23 + Exp
24 + Log
25 + Log1p
26 + Cos
27 + Sin
28 + Tanh
29 + Powf
30 + Powi<i32>
31 + Sqrt
32 + InverseSqrt
33 + Round
34 + Floor
35 + Ceil
36 + Trunc
37 + Erf
38 + Recip
39 + Magnitude
40 + Normalize
41 + Dot
42 + IsNan
43 + IsInf
44 + Into<Self::ExpandType>
45 + core::ops::Neg<Output = Self>
46 + core::ops::Add<Output = Self>
47 + core::ops::Sub<Output = Self>
48 + core::ops::Mul<Output = Self>
49 + core::ops::Div<Output = Self>
50 + std::ops::AddAssign
51 + std::ops::SubAssign
52 + std::ops::MulAssign
53 + std::ops::DivAssign
54 + std::cmp::PartialOrd
55 + std::cmp::PartialEq
56{
57 const DIGITS: u32;
58 const EPSILON: Self;
59 const INFINITY: Self;
60 const MANTISSA_DIGITS: u32;
61 const MAX_10_EXP: i32;
62 const MAX_EXP: i32;
63 const MIN_10_EXP: i32;
64 const MIN_EXP: i32;
65 const MIN_POSITIVE: Self;
66 const NAN: Self;
67 const NEG_INFINITY: Self;
68 const RADIX: u32;
69
70 fn new(val: f32) -> Self;
71 fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
72 __expand_new(scope, val)
73 }
74}
75
/// Generates the frontend trait implementations for one primitive float type.
///
/// Invocation forms:
/// * `impl_float!(half T, Kind)` - values are built with `T::from_f64`.
/// * `impl_float!(T, Kind)` - values are built with an `as T` cast.
/// * `impl_float!(T, Kind, conv)` - `conv` is an expression callable with an
///   `f64`; the two shorter forms forward to this one.
///
/// `Kind` is the `FloatKind` variant reported by `as_type_native`. For `T` the
/// macro emits `CubeType`, `CubePrimitive`, `Numeric`, `IntoRuntime`,
/// `IntoMut`, `ExpandElementIntoMut` and `Float`.
macro_rules! impl_float {
    (half $ty:ident, $variant:ident) => {
        impl_float!($ty, $variant, |raw| $ty::from_f64(raw));
    };
    ($ty:ident, $variant:ident) => {
        impl_float!($ty, $variant, |raw| raw as $ty);
    };
    ($ty:ident, $variant:ident, $from_f64:expr) => {
        impl CubeType for $ty {
            type ExpandType = ExpandElementTyped<$ty>;
        }

        impl CubePrimitive for $ty {
            /// Always reports the scalar float storage type for `$variant`.
            fn as_type_native() -> Option<StorageType> {
                let elem = ElemType::Float(FloatKind::$variant);
                Some(StorageType::Scalar(elem))
            }

            /// Converts a float constant through the supplied conversion.
            ///
            /// Any other `ConstantScalarValue` variant hits `unreachable!()`.
            fn from_const_value(value: ConstantScalarValue) -> Self {
                match value {
                    ConstantScalarValue::Float(inner, _) => $from_f64(inner),
                    _ => unreachable!(),
                }
            }
        }

        impl Numeric for $ty {
            // Both limits are taken from the `num_traits::Float` implementation.
            fn min_value() -> Self {
                <Self as num_traits::Float>::min_value()
            }
            fn max_value() -> Self {
                <Self as num_traits::Float>::max_value()
            }
        }

        impl IntoRuntime for $ty {
            /// Wraps `self` in its typed expand element, then passes it through
            /// `into_runtime_expand_element`.
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                into_runtime_expand_element(scope, Into::<ExpandElementTyped<Self>>::into(self))
                    .into()
            }
        }

        impl IntoMut for $ty {
            /// Returns the value unchanged; the scope is not used.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl ExpandElementIntoMut for $ty {
            /// Forwards to `into_mut_expand_element`.
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        impl Float for $ty {
            // Every limit is forwarded from the primitive's own constant.
            const RADIX: u32 = $ty::RADIX;
            const DIGITS: u32 = $ty::DIGITS;
            const MANTISSA_DIGITS: u32 = $ty::MANTISSA_DIGITS;
            const EPSILON: Self = $ty::EPSILON;
            const MIN_POSITIVE: Self = $ty::MIN_POSITIVE;
            const MIN_EXP: i32 = $ty::MIN_EXP;
            const MAX_EXP: i32 = $ty::MAX_EXP;
            const MIN_10_EXP: i32 = $ty::MIN_10_EXP;
            const MAX_10_EXP: i32 = $ty::MAX_10_EXP;
            const NAN: Self = $ty::NAN;
            const INFINITY: Self = $ty::INFINITY;
            const NEG_INFINITY: Self = $ty::NEG_INFINITY;

            /// Widens the `f32` to `f64` (lossless) and applies the conversion.
            fn new(val: f32) -> Self {
                $from_f64(f64::from(val))
            }
        }
    };
}
150
151impl_float!(half f16, F16);
152impl_float!(half bf16, BF16);
153impl_float!(f32, F32);
154impl_float!(f64, F64);
155
156impl ScalarArgSettings for f16 {
157 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
158 settings.register_f16(*self);
159 }
160}
161
162impl ScalarArgSettings for bf16 {
163 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
164 settings.register_bf16(*self);
165 }
166}
167
168impl ScalarArgSettings for f32 {
169 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
170 settings.register_f32(*self);
171 }
172}
173
174impl ScalarArgSettings for f64 {
175 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
176 settings.register_f64(*self);
177 }
178}