cubecl_core/frontend/element/
float.rs1use cubecl_ir::{ConstantValue, Scope, StorageType};
2use half::{bf16, f16};
3
4use crate::{
5 self as cubecl,
6 ir::{ElemType, ExpandElement, FloatKind},
7 prelude::*,
8};
9
10use super::Numeric;
11
12mod fp4;
13mod fp6;
14mod fp8;
15mod relaxed;
16mod tensor_float;
17mod typemap;
18
19pub use typemap::*;
20
21pub trait Float:
23 Numeric
24 + FloatOps
25 + Exp
26 + Log
27 + Log1p
28 + Cos
29 + Sin
30 + Tan
31 + Tanh
32 + Sinh
33 + Cosh
34 + ArcCos
35 + ArcSin
36 + ArcTan
37 + ArcSinh
38 + ArcCosh
39 + ArcTanh
40 + Degrees
41 + Radians
42 + ArcTan2
43 + Powf
44 + Powi<i32>
45 + Hypot
46 + Rhypot
47 + Sqrt
48 + InverseSqrt
49 + Round
50 + Floor
51 + Ceil
52 + Trunc
53 + Erf
54 + Recip
55 + Magnitude
56 + Normalize
57 + Dot
58 + IsNan
59 + IsInf
60 + Into<Self::ExpandType>
61 + core::ops::Neg<Output = Self>
62 + std::cmp::PartialOrd
63 + std::cmp::PartialEq
64{
65 const DIGITS: u32;
66 const EPSILON: Self;
67 const INFINITY: Self;
68 const MANTISSA_DIGITS: u32;
69 const MAX_10_EXP: i32;
70 const MAX_EXP: i32;
71 const MIN_10_EXP: i32;
72 const MIN_EXP: i32;
73 const MIN_POSITIVE: Self;
74 const NAN: Self;
75 const NEG_INFINITY: Self;
76 const RADIX: u32;
77
78 fn new(val: f32) -> Self;
79 fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
80 __expand_new(scope, val)
81 }
82}
83
84#[cube]
85pub trait FloatOps: CubePrimitive + PartialOrd + Sized {
86 fn min(self, other: Self) -> Self {
87 cubecl::prelude::min(self, other)
88 }
89
90 fn max(self, other: Self) -> Self {
91 cubecl::prelude::max(self, other)
92 }
93
94 fn clamp(self, min: Self, max: Self) -> Self {
95 clamp(self, min, max)
96 }
97}
98
99impl<T: Float> FloatOps for T {}
100impl<T: FloatOps + CubePrimitive> FloatOpsExpand for ExpandElementTyped<T> {
101 fn __expand_min_method(self, scope: &mut Scope, other: Self) -> Self {
102 min::expand(scope, self, other)
103 }
104
105 fn __expand_max_method(self, scope: &mut Scope, other: Self) -> Self {
106 max::expand(scope, self, other)
107 }
108
109 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
110 clamp::expand(scope, self, min, max)
111 }
112}
113
/// Implements the cube element traits for one float primitive.
///
/// Generated impls: `CubeType`, `CubePrimitive`, `IntoRuntime`, `Numeric`,
/// `ExpandElementIntoMut`, `IntoMut` and `Float`.
///
/// Invocation forms:
/// - `impl_float!(half $ty, $kind)`: for the `half` crate types. Values are
///   built with `$ty::from_f64`.
/// - `impl_float!($ty, $kind)`: for native floats. Values are built with an
///   `as` cast.
/// - `impl_float!($ty, $kind, $new)`: general form. `$new` is a callable that
///   builds a `$ty`. `Float::new` passes it an `f64`, and `from_const_value`
///   passes it the payload of `ConstantValue::Float`.
///
/// `$kind` names the `FloatKind` variant reported as the native storage type.
macro_rules! impl_float {
    (half $primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |val| $primitive::from_f64(val));
    };
    ($primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |val| val as $primitive);
    };
    ($primitive:ident, $kind:ident, $new:expr) => {
        impl CubeType for $primitive {
            type ExpandType = ExpandElementTyped<$primitive>;
        }

        impl CubePrimitive for $primitive {
            fn as_type_native() -> Option<StorageType> {
                Some(StorageType::Scalar(ElemType::Float(FloatKind::$kind)))
            }

            fn from_const_value(value: ConstantValue) -> Self {
                // Only float constants are expected here. If that invariant
                // breaks, name the target type so the failure is traceable.
                let ConstantValue::Float(value) = value else {
                    unreachable!(
                        "expected a float constant for `{}`",
                        stringify!($primitive)
                    )
                };
                $new(value)
            }
        }

        impl IntoRuntime for $primitive {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                // Wrap the plain value as an expand element, then hand it to
                // `into_runtime_expand_element`.
                let elem: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, elem).into()
            }
        }

        impl Numeric for $primitive {
            fn min_value() -> Self {
                <Self as num_traits::Float>::min_value()
            }
            fn max_value() -> Self {
                <Self as num_traits::Float>::max_value()
            }
        }

        impl ExpandElementIntoMut for $primitive {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        impl IntoMut for $primitive {
            // Plain host values need no conversion; return them unchanged.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl Float for $primitive {
            // Each constant forwards to the primitive's inherent constant.
            const DIGITS: u32 = $primitive::DIGITS;
            const EPSILON: Self = $primitive::EPSILON;
            const INFINITY: Self = $primitive::INFINITY;
            const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $primitive::MAX_10_EXP;
            const MAX_EXP: i32 = $primitive::MAX_EXP;
            const MIN_10_EXP: i32 = $primitive::MIN_10_EXP;
            const MIN_EXP: i32 = $primitive::MIN_EXP;
            const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE;
            const NAN: Self = $primitive::NAN;
            const NEG_INFINITY: Self = $primitive::NEG_INFINITY;
            const RADIX: u32 = $primitive::RADIX;

            fn new(val: f32) -> Self {
                // Widening f32 -> f64 is exact, so `$new` is the only rounding step.
                $new(val as f64)
            }
        }
    };
}
188
189impl_float!(half f16, F16);
190impl_float!(half bf16, BF16);
191impl_float!(f32, F32);
192impl_float!(f64, F64);