cubecl_common/float/
relaxed.rs1use core::f32;
2use core::{
3 cmp::Ordering,
4 ops::{Div, DivAssign, Mul, MulAssign, Rem, RemAssign},
5};
6
7use bytemuck::{Pod, Zeroable};
8use derive_more::derive::{
9 Add, AddAssign, Display, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11use num_traits::{Num, NumCast, One, ToPrimitive, Zero};
12
13#[allow(non_camel_case_types)]
16#[repr(transparent)]
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18#[derive(
19 Clone,
20 Copy,
21 Default,
22 Zeroable,
23 Pod,
24 PartialEq,
25 PartialOrd,
26 Neg,
27 Add,
28 Sub,
29 Mul,
30 Div,
31 Rem,
32 AddAssign,
33 SubAssign,
34 MulAssign,
35 DivAssign,
36 RemAssign,
37 Debug,
38 Display,
39)]
40pub struct flex32(f32);
41
42impl flex32 {
43 pub const MIN_POSITIVE: Self = Self(half::f16::MIN_POSITIVE.to_f32_const());
45
46 pub const fn from_f32(val: f32) -> Self {
48 flex32(val)
49 }
50
51 pub const fn from_f64(val: f64) -> Self {
53 flex32(val as f32)
54 }
55
56 pub const fn to_f32(self) -> f32 {
58 self.0
59 }
60
61 pub const fn to_f64(self) -> f64 {
63 self.0 as f64
64 }
65
66 pub fn total_cmp(&self, other: &flex32) -> Ordering {
68 self.0.total_cmp(&other.0)
69 }
70
71 pub fn is_nan(&self) -> bool {
73 self.0.is_nan()
74 }
75}
76
77impl Mul for flex32 {
78 type Output = flex32;
79
80 fn mul(self, rhs: Self) -> Self::Output {
81 flex32(self.0 * rhs.0)
82 }
83}
84
85impl Div for flex32 {
86 type Output = flex32;
87
88 fn div(self, rhs: Self) -> Self::Output {
89 flex32(self.0 / rhs.0)
90 }
91}
92
93impl Rem for flex32 {
94 type Output = flex32;
95
96 fn rem(self, rhs: Self) -> Self::Output {
97 flex32(self.0 % rhs.0)
98 }
99}
100
101impl MulAssign for flex32 {
102 fn mul_assign(&mut self, rhs: Self) {
103 self.0 *= rhs.0;
104 }
105}
106
107impl DivAssign for flex32 {
108 fn div_assign(&mut self, rhs: Self) {
109 self.0 /= rhs.0;
110 }
111}
112
113impl RemAssign for flex32 {
114 fn rem_assign(&mut self, rhs: Self) {
115 self.0 %= rhs.0;
116 }
117}
118
119impl From<f32> for flex32 {
120 fn from(value: f32) -> Self {
121 Self::from_f32(value)
122 }
123}
124
125impl From<flex32> for f32 {
126 fn from(val: flex32) -> Self {
127 val.to_f32()
128 }
129}
130
131impl ToPrimitive for flex32 {
132 fn to_i64(&self) -> Option<i64> {
133 Some((*self).to_f32() as i64)
134 }
135
136 fn to_u64(&self) -> Option<u64> {
137 Some((*self).to_f32() as u64)
138 }
139
140 fn to_f32(&self) -> Option<f32> {
141 Some((*self).to_f32())
142 }
143
144 fn to_f64(&self) -> Option<f64> {
145 Some((*self).to_f32() as f64)
146 }
147}
148
149impl NumCast for flex32 {
150 fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
151 Some(flex32::from_f32(n.to_f32()?))
152 }
153}
154
155impl num_traits::Float for flex32 {
156 fn nan() -> Self {
157 flex32(f32::nan())
158 }
159
160 fn infinity() -> Self {
161 flex32(f32::infinity())
162 }
163
164 fn neg_infinity() -> Self {
165 flex32(f32::neg_infinity())
166 }
167
168 fn neg_zero() -> Self {
169 flex32(f32::neg_zero())
170 }
171
172 fn min_value() -> Self {
173 flex32(<f32 as num_traits::Float>::min_value())
174 }
175
176 fn min_positive_value() -> Self {
177 flex32(f32::min_positive_value())
178 }
179
180 fn max_value() -> Self {
181 flex32(<f32 as num_traits::Float>::max_value())
182 }
183
184 fn is_nan(self) -> bool {
185 self.0.is_nan()
186 }
187
188 fn is_infinite(self) -> bool {
189 self.0.is_infinite()
190 }
191
192 fn is_finite(self) -> bool {
193 self.0.is_finite()
194 }
195
196 fn is_normal(self) -> bool {
197 self.0.is_normal()
198 }
199
200 fn classify(self) -> core::num::FpCategory {
201 self.0.classify()
202 }
203
204 fn floor(self) -> Self {
205 flex32(self.0.floor())
206 }
207
208 fn ceil(self) -> Self {
209 flex32(self.0.ceil())
210 }
211
212 fn round(self) -> Self {
213 flex32(self.0.round())
214 }
215
216 fn trunc(self) -> Self {
217 flex32(self.0.trunc())
218 }
219
220 fn fract(self) -> Self {
221 flex32(self.0.fract())
222 }
223
224 fn abs(self) -> Self {
225 flex32(self.0.abs())
226 }
227
228 fn signum(self) -> Self {
229 flex32(self.0.signum())
230 }
231
232 fn is_sign_positive(self) -> bool {
233 self.0.is_sign_positive()
234 }
235
236 fn is_sign_negative(self) -> bool {
237 self.0.is_sign_negative()
238 }
239
240 fn mul_add(self, a: Self, b: Self) -> Self {
241 flex32(self.0.mul_add(a.0, b.0))
242 }
243
244 fn recip(self) -> Self {
245 flex32(self.0.recip())
246 }
247
248 fn powi(self, n: i32) -> Self {
249 flex32(self.0.powi(n))
250 }
251
252 fn powf(self, n: Self) -> Self {
253 flex32(self.0.powf(n.0))
254 }
255
256 fn sqrt(self) -> Self {
257 flex32(self.0.sqrt())
258 }
259
260 fn exp(self) -> Self {
261 flex32(self.0.exp())
262 }
263
264 fn exp2(self) -> Self {
265 flex32(self.0.exp2())
266 }
267
268 fn ln(self) -> Self {
269 flex32(self.0.ln())
270 }
271
272 fn log(self, base: Self) -> Self {
273 flex32(self.0.log(base.0))
274 }
275
276 fn log2(self) -> Self {
277 flex32(self.0.log2())
278 }
279
280 fn log10(self) -> Self {
281 flex32(self.0.log10())
282 }
283
284 fn max(self, other: Self) -> Self {
285 flex32(self.0.max(other.0))
286 }
287
288 fn min(self, other: Self) -> Self {
289 flex32(self.0.min(other.0))
290 }
291
292 fn abs_sub(self, other: Self) -> Self {
293 flex32((self.0 - other.0).abs())
294 }
295
296 fn cbrt(self) -> Self {
297 flex32(self.0.cbrt())
298 }
299
300 fn hypot(self, other: Self) -> Self {
301 flex32(self.0.hypot(other.0))
302 }
303
304 fn sin(self) -> Self {
305 flex32(self.0.sin())
306 }
307
308 fn cos(self) -> Self {
309 flex32(self.0.cos())
310 }
311
312 fn tan(self) -> Self {
313 flex32(self.0.tan())
314 }
315
316 fn asin(self) -> Self {
317 flex32(self.0.asin())
318 }
319
320 fn acos(self) -> Self {
321 flex32(self.0.acos())
322 }
323
324 fn atan(self) -> Self {
325 flex32(self.0.atan())
326 }
327
328 fn atan2(self, other: Self) -> Self {
329 flex32(self.0.atan2(other.0))
330 }
331
332 fn sin_cos(self) -> (Self, Self) {
333 let (a, b) = self.0.sin_cos();
334 (flex32(a), flex32(b))
335 }
336
337 fn exp_m1(self) -> Self {
338 flex32(self.0.exp_m1())
339 }
340
341 fn ln_1p(self) -> Self {
342 flex32(self.0.ln_1p())
343 }
344
345 fn sinh(self) -> Self {
346 flex32(self.0.sinh())
347 }
348
349 fn cosh(self) -> Self {
350 flex32(self.0.cosh())
351 }
352
353 fn tanh(self) -> Self {
354 flex32(self.0.tanh())
355 }
356
357 fn asinh(self) -> Self {
358 flex32(self.0.asinh())
359 }
360
361 fn acosh(self) -> Self {
362 flex32(self.0.acosh())
363 }
364
365 fn atanh(self) -> Self {
366 flex32(self.0.atanh())
367 }
368
369 fn integer_decode(self) -> (u64, i16, i8) {
370 self.0.integer_decode()
371 }
372}
373
374impl Num for flex32 {
375 type FromStrRadixErr = <f32 as Num>::FromStrRadixErr;
376
377 fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
378 Ok(flex32(f32::from_str_radix(str, radix)?))
379 }
380}
381
382impl One for flex32 {
383 fn one() -> Self {
384 flex32(1.0)
385 }
386}
387
388impl Zero for flex32 {
389 fn zero() -> Self {
390 flex32(0.0)
391 }
392
393 fn is_zero(&self) -> bool {
394 self.0 == 0.0
395 }
396}