cubecl_core/frontend/element/float/
relaxed.rs

1use core::f32;
2use std::{
3    cmp::Ordering,
4    ops::{Div, DivAssign, Mul, MulAssign, Rem, RemAssign},
5};
6
7use bytemuck::{Pod, Zeroable};
8use derive_more::derive::{
9    Add, AddAssign, Display, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11use num_traits::{Num, NumCast, One, ToPrimitive, Zero};
12use serde::Serialize;
13
14use crate::{
15    ir::{Elem, FloatKind},
16    prelude::Numeric,
17};
18
19use super::{
20    init_expand_element, CubeContext, CubePrimitive, CubeType, ExpandElement,
21    ExpandElementBaseInit, ExpandElementTyped, Float, Init, IntoRuntime, KernelBuilder,
22    KernelLauncher, LaunchArgExpand, Runtime, ScalarArgSettings,
23};
24
25/// A floating point type with relaxed precision, minimum [`f16`], max [`f32`].
26///
27#[allow(non_camel_case_types)]
28#[repr(transparent)]
29#[derive(
30    Clone,
31    Copy,
32    Default,
33    Serialize,
34    Zeroable,
35    Pod,
36    PartialEq,
37    PartialOrd,
38    Neg,
39    Add,
40    Sub,
41    Mul,
42    Div,
43    Rem,
44    AddAssign,
45    SubAssign,
46    MulAssign,
47    DivAssign,
48    RemAssign,
49    Debug,
50    Display,
51)]
52pub struct flex32(f32);
53
54impl flex32 {
55    pub const MIN_POSITIVE: Self = Self(half::f16::MIN_POSITIVE.to_f32_const());
56
57    pub const fn from_f32(val: f32) -> Self {
58        flex32(val)
59    }
60
61    pub const fn from_f64(val: f64) -> Self {
62        flex32(val as f32)
63    }
64
65    pub const fn to_f32(self) -> f32 {
66        self.0
67    }
68
69    pub const fn to_f64(self) -> f64 {
70        self.0 as f64
71    }
72
73    pub fn total_cmp(&self, other: &flex32) -> Ordering {
74        self.0.total_cmp(&other.0)
75    }
76
77    pub fn is_nan(&self) -> bool {
78        self.0.is_nan()
79    }
80}
81
82impl Mul for flex32 {
83    type Output = flex32;
84
85    fn mul(self, rhs: Self) -> Self::Output {
86        flex32(self.0 * rhs.0)
87    }
88}
89
90impl Div for flex32 {
91    type Output = flex32;
92
93    fn div(self, rhs: Self) -> Self::Output {
94        flex32(self.0 / rhs.0)
95    }
96}
97
98impl Rem for flex32 {
99    type Output = flex32;
100
101    fn rem(self, rhs: Self) -> Self::Output {
102        flex32(self.0 % rhs.0)
103    }
104}
105
106impl MulAssign for flex32 {
107    fn mul_assign(&mut self, rhs: Self) {
108        self.0 *= rhs.0;
109    }
110}
111
112impl DivAssign for flex32 {
113    fn div_assign(&mut self, rhs: Self) {
114        self.0 /= rhs.0;
115    }
116}
117
118impl RemAssign for flex32 {
119    fn rem_assign(&mut self, rhs: Self) {
120        self.0 %= rhs.0;
121    }
122}
123
124impl From<f32> for flex32 {
125    fn from(value: f32) -> Self {
126        Self::from_f32(value)
127    }
128}
129
130impl From<flex32> for f32 {
131    fn from(val: flex32) -> Self {
132        val.to_f32()
133    }
134}
135
136impl ToPrimitive for flex32 {
137    fn to_i64(&self) -> Option<i64> {
138        Some((*self).to_f32() as i64)
139    }
140
141    fn to_u64(&self) -> Option<u64> {
142        Some((*self).to_f32() as u64)
143    }
144
145    fn to_f32(&self) -> Option<f32> {
146        Some((*self).to_f32())
147    }
148
149    fn to_f64(&self) -> Option<f64> {
150        Some((*self).to_f32() as f64)
151    }
152}
153
154impl NumCast for flex32 {
155    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
156        Some(flex32::from_f32(n.to_f32()?))
157    }
158}
159
160impl CubeType for flex32 {
161    type ExpandType = ExpandElementTyped<flex32>;
162}
163
164impl CubePrimitive for flex32 {
165    /// Return the element type to use on GPU
166    fn as_elem_native() -> Option<Elem> {
167        Some(Elem::Float(FloatKind::Flex32))
168    }
169}
170
171impl IntoRuntime for flex32 {
172    fn __expand_runtime_method(self, context: &mut CubeContext) -> ExpandElementTyped<Self> {
173        let expand: ExpandElementTyped<Self> = self.into();
174        Init::init(expand, context)
175    }
176}
177
178impl Numeric for flex32 {
179    fn min_value() -> Self {
180        <Self as num_traits::Float>::min_value()
181    }
182    fn max_value() -> Self {
183        <Self as num_traits::Float>::max_value()
184    }
185}
186
187impl ExpandElementBaseInit for flex32 {
188    fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
189        init_expand_element(context, elem)
190    }
191}
192
193impl Float for flex32 {
194    const DIGITS: u32 = 32;
195
196    const EPSILON: Self = flex32::from_f32(half::f16::EPSILON.to_f32_const());
197
198    const INFINITY: Self = flex32::from_f32(f32::INFINITY);
199
200    const MANTISSA_DIGITS: u32 = f32::MANTISSA_DIGITS;
201
202    /// Maximum possible [`tf32`] power of 10 exponent
203    const MAX_10_EXP: i32 = f32::MAX_10_EXP;
204    /// Maximum possible [`tf32`] power of 2 exponent
205    const MAX_EXP: i32 = f32::MAX_EXP;
206
207    /// Minimum possible normal [`tf32`] power of 10 exponent
208    const MIN_10_EXP: i32 = f32::MIN_10_EXP;
209    /// One greater than the minimum possible normal [`v`] power of 2 exponent
210    const MIN_EXP: i32 = f32::MIN_EXP;
211
212    const MIN_POSITIVE: Self = flex32(f32::MIN_POSITIVE);
213
214    const NAN: Self = flex32::from_f32(f32::NAN);
215
216    const NEG_INFINITY: Self = flex32::from_f32(f32::NEG_INFINITY);
217
218    const RADIX: u32 = 2;
219
220    fn new(val: f32) -> Self {
221        flex32::from_f32(val)
222    }
223}
224
225impl LaunchArgExpand for flex32 {
226    type CompilationArg = ();
227
228    fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
229        builder.scalar(flex32::as_elem(&builder.context)).into()
230    }
231}
232
233impl ScalarArgSettings for flex32 {
234    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
235        settings.register_f32(self.0);
236    }
237}
238
239impl num_traits::Float for flex32 {
240    fn nan() -> Self {
241        flex32(f32::nan())
242    }
243
244    fn infinity() -> Self {
245        flex32(f32::infinity())
246    }
247
248    fn neg_infinity() -> Self {
249        flex32(f32::neg_infinity())
250    }
251
252    fn neg_zero() -> Self {
253        flex32(f32::neg_zero())
254    }
255
256    fn min_value() -> Self {
257        flex32(<f32 as num_traits::Float>::min_value())
258    }
259
260    fn min_positive_value() -> Self {
261        flex32(f32::min_positive_value())
262    }
263
264    fn max_value() -> Self {
265        flex32(<f32 as num_traits::Float>::max_value())
266    }
267
268    fn is_nan(self) -> bool {
269        self.0.is_nan()
270    }
271
272    fn is_infinite(self) -> bool {
273        self.0.is_infinite()
274    }
275
276    fn is_finite(self) -> bool {
277        self.0.is_finite()
278    }
279
280    fn is_normal(self) -> bool {
281        self.0.is_normal()
282    }
283
284    fn classify(self) -> std::num::FpCategory {
285        self.0.classify()
286    }
287
288    fn floor(self) -> Self {
289        flex32(self.0.floor())
290    }
291
292    fn ceil(self) -> Self {
293        flex32(self.0.ceil())
294    }
295
296    fn round(self) -> Self {
297        flex32(self.0.round())
298    }
299
300    fn trunc(self) -> Self {
301        flex32(self.0.trunc())
302    }
303
304    fn fract(self) -> Self {
305        flex32(self.0.fract())
306    }
307
308    fn abs(self) -> Self {
309        flex32(self.0.abs())
310    }
311
312    fn signum(self) -> Self {
313        flex32(self.0.signum())
314    }
315
316    fn is_sign_positive(self) -> bool {
317        self.0.is_sign_positive()
318    }
319
320    fn is_sign_negative(self) -> bool {
321        self.0.is_sign_negative()
322    }
323
324    fn mul_add(self, a: Self, b: Self) -> Self {
325        flex32(self.0.mul_add(a.0, b.0))
326    }
327
328    fn recip(self) -> Self {
329        flex32(self.0.recip())
330    }
331
332    fn powi(self, n: i32) -> Self {
333        flex32(self.0.powi(n))
334    }
335
336    fn powf(self, n: Self) -> Self {
337        flex32(self.0.powf(n.0))
338    }
339
340    fn sqrt(self) -> Self {
341        flex32(self.0.sqrt())
342    }
343
344    fn exp(self) -> Self {
345        flex32(self.0.exp())
346    }
347
348    fn exp2(self) -> Self {
349        flex32(self.0.exp2())
350    }
351
352    fn ln(self) -> Self {
353        flex32(self.0.ln())
354    }
355
356    fn log(self, base: Self) -> Self {
357        flex32(self.0.log(base.0))
358    }
359
360    fn log2(self) -> Self {
361        flex32(self.0.log2())
362    }
363
364    fn log10(self) -> Self {
365        flex32(self.0.log10())
366    }
367
368    fn max(self, other: Self) -> Self {
369        flex32(self.0.max(other.0))
370    }
371
372    fn min(self, other: Self) -> Self {
373        flex32(self.0.min(other.0))
374    }
375
376    fn abs_sub(self, other: Self) -> Self {
377        flex32((self.0 - other.0).abs())
378    }
379
380    fn cbrt(self) -> Self {
381        flex32(self.0.cbrt())
382    }
383
384    fn hypot(self, other: Self) -> Self {
385        flex32(self.0.hypot(other.0))
386    }
387
388    fn sin(self) -> Self {
389        flex32(self.0.sin())
390    }
391
392    fn cos(self) -> Self {
393        flex32(self.0.cos())
394    }
395
396    fn tan(self) -> Self {
397        flex32(self.0.tan())
398    }
399
400    fn asin(self) -> Self {
401        flex32(self.0.asin())
402    }
403
404    fn acos(self) -> Self {
405        flex32(self.0.acos())
406    }
407
408    fn atan(self) -> Self {
409        flex32(self.0.atan())
410    }
411
412    fn atan2(self, other: Self) -> Self {
413        flex32(self.0.atan2(other.0))
414    }
415
416    fn sin_cos(self) -> (Self, Self) {
417        let (a, b) = self.0.sin_cos();
418        (flex32(a), flex32(b))
419    }
420
421    fn exp_m1(self) -> Self {
422        flex32(self.0.exp_m1())
423    }
424
425    fn ln_1p(self) -> Self {
426        flex32(self.0.ln_1p())
427    }
428
429    fn sinh(self) -> Self {
430        flex32(self.0.sinh())
431    }
432
433    fn cosh(self) -> Self {
434        flex32(self.0.cosh())
435    }
436
437    fn tanh(self) -> Self {
438        flex32(self.0.tanh())
439    }
440
441    fn asinh(self) -> Self {
442        flex32(self.0.asinh())
443    }
444
445    fn acosh(self) -> Self {
446        flex32(self.0.acosh())
447    }
448
449    fn atanh(self) -> Self {
450        flex32(self.0.atanh())
451    }
452
453    fn integer_decode(self) -> (u64, i16, i8) {
454        self.0.integer_decode()
455    }
456}
457
458impl Num for flex32 {
459    type FromStrRadixErr = <f32 as Num>::FromStrRadixErr;
460
461    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
462        Ok(flex32(f32::from_str_radix(str, radix)?))
463    }
464}
465
466impl One for flex32 {
467    fn one() -> Self {
468        flex32(1.0)
469    }
470}
471
472impl Zero for flex32 {
473    fn zero() -> Self {
474        flex32(0.0)
475    }
476
477    fn is_zero(&self) -> bool {
478        self.0 == 0.0
479    }
480}