cubecl_core/frontend/element/
float.rs1use cubecl_ir::{ConstantValue, Scope, StorageType, Type};
2use half::{bf16, f16};
3
4use crate::{
5 self as cubecl,
6 ir::{ElemType, FloatKind},
7 prelude::*,
8};
9
10use super::Numeric;
11
12mod fp4;
13mod fp6;
14mod fp8;
15mod relaxed;
16mod tensor_float;
17
18pub trait Float:
20 Numeric
21 + FloatOps
22 + Exp
23 + Log
24 + Log1p
25 + Cos
26 + Sin
27 + Tan
28 + Tanh
29 + Sinh
30 + Cosh
31 + ArcCos
32 + ArcSin
33 + ArcTan
34 + ArcSinh
35 + ArcCosh
36 + ArcTanh
37 + Degrees
38 + Radians
39 + ArcTan2
40 + Powf
41 + Powi<i32>
42 + Hypot
43 + Rhypot
44 + Sqrt
45 + InverseSqrt
46 + Round
47 + Floor
48 + Ceil
49 + Trunc
50 + Erf
51 + Recip
52 + Magnitude
53 + Normalize
54 + Dot
55 + IsNan
56 + IsInf
57 + Into<Self::ExpandType>
58 + core::ops::Neg<Output = Self>
59 + core::cmp::PartialOrd
60 + core::cmp::PartialEq
61{
62 const DIGITS: u32;
63 const EPSILON: Self;
64 const INFINITY: Self;
65 const MANTISSA_DIGITS: u32;
66 const MAX_10_EXP: i32;
67 const MAX_EXP: i32;
68 const MIN_10_EXP: i32;
69 const MIN_EXP: i32;
70 const MIN_POSITIVE: Self;
71 const NAN: Self;
72 const NEG_INFINITY: Self;
73 const RADIX: u32;
74
75 fn new(val: f32) -> Self;
76 fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
77 __expand_new(scope, val)
78 }
79}
80
81#[cube]
82pub trait FloatOps: CubePrimitive + PartialOrd + Sized {
83 fn min(self, other: Self) -> Self {
84 cubecl::prelude::min(self, other)
85 }
86
87 fn max(self, other: Self) -> Self {
88 cubecl::prelude::max(self, other)
89 }
90
91 fn clamp(self, min: Self, max: Self) -> Self {
92 clamp(self, min, max)
93 }
94}
95
96impl<T: Float> FloatOps for T {}
97impl<T: FloatOps + CubePrimitive> FloatOpsExpand for NativeExpand<T> {
98 fn __expand_min_method(self, scope: &mut Scope, other: Self) -> Self {
99 min::expand(scope, self, other)
100 }
101
102 fn __expand_max_method(self, scope: &mut Scope, other: Self) -> Self {
103 max::expand(scope, self, other)
104 }
105
106 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
107 clamp::expand(scope, self, min, max)
108 }
109}
110
111macro_rules! impl_float {
112 (half $primitive:ident, $kind:ident) => {
113 impl_float!($primitive, $kind, |val| $primitive::from_f64(val));
114 };
115 ($primitive:ident, $kind:ident) => {
116 impl_float!($primitive, $kind, |val| val as $primitive);
117 };
118 ($primitive:ident, $kind:ident, $new:expr) => {
119 impl CubeType for $primitive {
120 type ExpandType = NativeExpand<$primitive>;
121 }
122
123 impl Scalar for $primitive {}
124 impl CubePrimitive for $primitive {
125 type Scalar = Self;
126 type Size = Const<1>;
127 type WithScalar<S: Scalar> = S;
128
129 fn as_type_native() -> Option<Type> {
131 Some(StorageType::Scalar(ElemType::Float(FloatKind::$kind)).into())
132 }
133
134 fn from_const_value(value: ConstantValue) -> Self {
135 let ConstantValue::Float(value) = value else {
136 unreachable!()
137 };
138 $new(value)
139 }
140 }
141
142 impl IntoRuntime for $primitive {
143 fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
144 self.into()
145 }
146 }
147
148 impl Numeric for $primitive {
149 fn min_value() -> Self {
150 <Self as num_traits::Float>::min_value()
151 }
152 fn max_value() -> Self {
153 <Self as num_traits::Float>::max_value()
154 }
155 }
156
157 impl NativeAssign for $primitive {}
158
159 impl IntoMut for $primitive {
160 fn into_mut(self, _scope: &mut Scope) -> Self {
161 self
162 }
163 }
164
165 impl Float for $primitive {
166 const DIGITS: u32 = $primitive::DIGITS;
167 const EPSILON: Self = $primitive::EPSILON;
168 const INFINITY: Self = $primitive::INFINITY;
169 const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS;
170 const MAX_10_EXP: i32 = $primitive::MAX_10_EXP;
171 const MAX_EXP: i32 = $primitive::MAX_EXP;
172 const MIN_10_EXP: i32 = $primitive::MIN_10_EXP;
173 const MIN_EXP: i32 = $primitive::MIN_EXP;
174 const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE;
175 const NAN: Self = $primitive::NAN;
176 const NEG_INFINITY: Self = $primitive::NEG_INFINITY;
177 const RADIX: u32 = $primitive::RADIX;
178
179 fn new(val: f32) -> Self {
180 $new(val as f64)
181 }
182 }
183 };
184}
185
186impl_float!(half f16, F16);
187impl_float!(half bf16, BF16);
188impl_float!(f32, F32);
189impl_float!(f64, F64);