cubecl_common/float/
relaxed.rs

1use core::f32;
2use core::{
3    cmp::Ordering,
4    ops::{Div, DivAssign, Mul, MulAssign, Rem, RemAssign},
5};
6
7use bytemuck::{Pod, Zeroable};
8use derive_more::derive::{
9    Add, AddAssign, Display, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11use num_traits::{Num, NumCast, One, ToPrimitive, Zero};
12
13/// A floating point type with relaxed precision, minimum [`f16`], max [`f32`].
14///
15#[allow(non_camel_case_types)]
16#[repr(transparent)]
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18#[derive(
19    Clone,
20    Copy,
21    Default,
22    Zeroable,
23    Pod,
24    PartialEq,
25    PartialOrd,
26    Neg,
27    Add,
28    Sub,
29    Mul,
30    Div,
31    Rem,
32    AddAssign,
33    SubAssign,
34    MulAssign,
35    DivAssign,
36    RemAssign,
37    Debug,
38    Display,
39)]
40pub struct flex32(f32);
41
42impl flex32 {
43    /// Minimum positive flex32 value
44    pub const MIN_POSITIVE: Self = Self(half::f16::MIN_POSITIVE.to_f32_const());
45
46    /// Create a `flex32` from [`prim@f32`]
47    pub const fn from_f32(val: f32) -> Self {
48        flex32(val)
49    }
50
51    /// Create a `flex32` from [`prim@f64`]
52    pub const fn from_f64(val: f64) -> Self {
53        flex32(val as f32)
54    }
55
56    /// Turn a `flex32` into [`prim@f32`]
57    pub const fn to_f32(self) -> f32 {
58        self.0
59    }
60
61    /// Turn a `flex32` into [`prim@f64`]
62    pub const fn to_f64(self) -> f64 {
63        self.0 as f64
64    }
65
66    /// Compare two flex32 numbers
67    pub fn total_cmp(&self, other: &flex32) -> Ordering {
68        self.0.total_cmp(&other.0)
69    }
70
71    /// Whether this flex32 represents `NaN`
72    pub fn is_nan(&self) -> bool {
73        self.0.is_nan()
74    }
75}
76
77impl Mul for flex32 {
78    type Output = flex32;
79
80    fn mul(self, rhs: Self) -> Self::Output {
81        flex32(self.0 * rhs.0)
82    }
83}
84
85impl Div for flex32 {
86    type Output = flex32;
87
88    fn div(self, rhs: Self) -> Self::Output {
89        flex32(self.0 / rhs.0)
90    }
91}
92
93impl Rem for flex32 {
94    type Output = flex32;
95
96    fn rem(self, rhs: Self) -> Self::Output {
97        flex32(self.0 % rhs.0)
98    }
99}
100
101impl MulAssign for flex32 {
102    fn mul_assign(&mut self, rhs: Self) {
103        self.0 *= rhs.0;
104    }
105}
106
107impl DivAssign for flex32 {
108    fn div_assign(&mut self, rhs: Self) {
109        self.0 /= rhs.0;
110    }
111}
112
113impl RemAssign for flex32 {
114    fn rem_assign(&mut self, rhs: Self) {
115        self.0 %= rhs.0;
116    }
117}
118
119impl From<f32> for flex32 {
120    fn from(value: f32) -> Self {
121        Self::from_f32(value)
122    }
123}
124
125impl From<flex32> for f32 {
126    fn from(val: flex32) -> Self {
127        val.to_f32()
128    }
129}
130
131impl ToPrimitive for flex32 {
132    fn to_i64(&self) -> Option<i64> {
133        Some((*self).to_f32() as i64)
134    }
135
136    fn to_u64(&self) -> Option<u64> {
137        Some((*self).to_f32() as u64)
138    }
139
140    fn to_f32(&self) -> Option<f32> {
141        Some((*self).to_f32())
142    }
143
144    fn to_f64(&self) -> Option<f64> {
145        Some((*self).to_f32() as f64)
146    }
147}
148
149impl NumCast for flex32 {
150    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
151        Some(flex32::from_f32(n.to_f32()?))
152    }
153}
154
155impl num_traits::Float for flex32 {
156    fn nan() -> Self {
157        flex32(f32::nan())
158    }
159
160    fn infinity() -> Self {
161        flex32(f32::infinity())
162    }
163
164    fn neg_infinity() -> Self {
165        flex32(f32::neg_infinity())
166    }
167
168    fn neg_zero() -> Self {
169        flex32(f32::neg_zero())
170    }
171
172    fn min_value() -> Self {
173        flex32(<f32 as num_traits::Float>::min_value())
174    }
175
176    fn min_positive_value() -> Self {
177        flex32(f32::min_positive_value())
178    }
179
180    fn max_value() -> Self {
181        flex32(<f32 as num_traits::Float>::max_value())
182    }
183
184    fn is_nan(self) -> bool {
185        self.0.is_nan()
186    }
187
188    fn is_infinite(self) -> bool {
189        self.0.is_infinite()
190    }
191
192    fn is_finite(self) -> bool {
193        self.0.is_finite()
194    }
195
196    fn is_normal(self) -> bool {
197        self.0.is_normal()
198    }
199
200    fn classify(self) -> core::num::FpCategory {
201        self.0.classify()
202    }
203
204    fn floor(self) -> Self {
205        flex32(self.0.floor())
206    }
207
208    fn ceil(self) -> Self {
209        flex32(self.0.ceil())
210    }
211
212    fn round(self) -> Self {
213        flex32(self.0.round())
214    }
215
216    fn trunc(self) -> Self {
217        flex32(self.0.trunc())
218    }
219
220    fn fract(self) -> Self {
221        flex32(self.0.fract())
222    }
223
224    fn abs(self) -> Self {
225        flex32(self.0.abs())
226    }
227
228    fn signum(self) -> Self {
229        flex32(self.0.signum())
230    }
231
232    fn is_sign_positive(self) -> bool {
233        self.0.is_sign_positive()
234    }
235
236    fn is_sign_negative(self) -> bool {
237        self.0.is_sign_negative()
238    }
239
240    fn mul_add(self, a: Self, b: Self) -> Self {
241        flex32(self.0.mul_add(a.0, b.0))
242    }
243
244    fn recip(self) -> Self {
245        flex32(self.0.recip())
246    }
247
248    fn powi(self, n: i32) -> Self {
249        flex32(self.0.powi(n))
250    }
251
252    fn powf(self, n: Self) -> Self {
253        flex32(self.0.powf(n.0))
254    }
255
256    fn sqrt(self) -> Self {
257        flex32(self.0.sqrt())
258    }
259
260    fn exp(self) -> Self {
261        flex32(self.0.exp())
262    }
263
264    fn exp2(self) -> Self {
265        flex32(self.0.exp2())
266    }
267
268    fn ln(self) -> Self {
269        flex32(self.0.ln())
270    }
271
272    fn log(self, base: Self) -> Self {
273        flex32(self.0.log(base.0))
274    }
275
276    fn log2(self) -> Self {
277        flex32(self.0.log2())
278    }
279
280    fn log10(self) -> Self {
281        flex32(self.0.log10())
282    }
283
284    fn max(self, other: Self) -> Self {
285        flex32(self.0.max(other.0))
286    }
287
288    fn min(self, other: Self) -> Self {
289        flex32(self.0.min(other.0))
290    }
291
292    fn abs_sub(self, other: Self) -> Self {
293        flex32((self.0 - other.0).abs())
294    }
295
296    fn cbrt(self) -> Self {
297        flex32(self.0.cbrt())
298    }
299
300    fn hypot(self, other: Self) -> Self {
301        flex32(self.0.hypot(other.0))
302    }
303
304    fn sin(self) -> Self {
305        flex32(self.0.sin())
306    }
307
308    fn cos(self) -> Self {
309        flex32(self.0.cos())
310    }
311
312    fn tan(self) -> Self {
313        flex32(self.0.tan())
314    }
315
316    fn asin(self) -> Self {
317        flex32(self.0.asin())
318    }
319
320    fn acos(self) -> Self {
321        flex32(self.0.acos())
322    }
323
324    fn atan(self) -> Self {
325        flex32(self.0.atan())
326    }
327
328    fn atan2(self, other: Self) -> Self {
329        flex32(self.0.atan2(other.0))
330    }
331
332    fn sin_cos(self) -> (Self, Self) {
333        let (a, b) = self.0.sin_cos();
334        (flex32(a), flex32(b))
335    }
336
337    fn exp_m1(self) -> Self {
338        flex32(self.0.exp_m1())
339    }
340
341    fn ln_1p(self) -> Self {
342        flex32(self.0.ln_1p())
343    }
344
345    fn sinh(self) -> Self {
346        flex32(self.0.sinh())
347    }
348
349    fn cosh(self) -> Self {
350        flex32(self.0.cosh())
351    }
352
353    fn tanh(self) -> Self {
354        flex32(self.0.tanh())
355    }
356
357    fn asinh(self) -> Self {
358        flex32(self.0.asinh())
359    }
360
361    fn acosh(self) -> Self {
362        flex32(self.0.acosh())
363    }
364
365    fn atanh(self) -> Self {
366        flex32(self.0.atanh())
367    }
368
369    fn integer_decode(self) -> (u64, i16, i8) {
370        self.0.integer_decode()
371    }
372}
373
374impl Num for flex32 {
375    type FromStrRadixErr = <f32 as Num>::FromStrRadixErr;
376
377    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
378        Ok(flex32(f32::from_str_radix(str, radix)?))
379    }
380}
381
382impl One for flex32 {
383    fn one() -> Self {
384        flex32(1.0)
385    }
386}
387
388impl Zero for flex32 {
389    fn zero() -> Self {
390        flex32(0.0)
391    }
392
393    fn is_zero(&self) -> bool {
394        self.0 == 0.0
395    }
396}