cubecl_core/frontend/element/float/
relaxed.rs1use core::f32;
2use std::{
3 cmp::Ordering,
4 ops::{Div, DivAssign, Mul, MulAssign, Rem, RemAssign},
5};
6
7use bytemuck::{Pod, Zeroable};
8use derive_more::derive::{
9 Add, AddAssign, Display, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
10};
11use num_traits::{Num, NumCast, One, ToPrimitive, Zero};
12use serde::Serialize;
13
14use crate::{
15 ir::{Elem, FloatKind},
16 prelude::Numeric,
17};
18
19use super::{
20 init_expand_element, CubeContext, CubePrimitive, CubeType, ExpandElement,
21 ExpandElementBaseInit, ExpandElementTyped, Float, Init, IntoRuntime, KernelBuilder,
22 KernelLauncher, LaunchArgExpand, Runtime, ScalarArgSettings,
23};
24
25#[allow(non_camel_case_types)]
28#[repr(transparent)]
29#[derive(
30 Clone,
31 Copy,
32 Default,
33 Serialize,
34 Zeroable,
35 Pod,
36 PartialEq,
37 PartialOrd,
38 Neg,
39 Add,
40 Sub,
41 Mul,
42 Div,
43 Rem,
44 AddAssign,
45 SubAssign,
46 MulAssign,
47 DivAssign,
48 RemAssign,
49 Debug,
50 Display,
51)]
52pub struct flex32(f32);
53
54impl flex32 {
55 pub const MIN_POSITIVE: Self = Self(half::f16::MIN_POSITIVE.to_f32_const());
56
57 pub const fn from_f32(val: f32) -> Self {
58 flex32(val)
59 }
60
61 pub const fn from_f64(val: f64) -> Self {
62 flex32(val as f32)
63 }
64
65 pub const fn to_f32(self) -> f32 {
66 self.0
67 }
68
69 pub const fn to_f64(self) -> f64 {
70 self.0 as f64
71 }
72
73 pub fn total_cmp(&self, other: &flex32) -> Ordering {
74 self.0.total_cmp(&other.0)
75 }
76
77 pub fn is_nan(&self) -> bool {
78 self.0.is_nan()
79 }
80}
81
82impl Mul for flex32 {
83 type Output = flex32;
84
85 fn mul(self, rhs: Self) -> Self::Output {
86 flex32(self.0 * rhs.0)
87 }
88}
89
90impl Div for flex32 {
91 type Output = flex32;
92
93 fn div(self, rhs: Self) -> Self::Output {
94 flex32(self.0 / rhs.0)
95 }
96}
97
98impl Rem for flex32 {
99 type Output = flex32;
100
101 fn rem(self, rhs: Self) -> Self::Output {
102 flex32(self.0 % rhs.0)
103 }
104}
105
106impl MulAssign for flex32 {
107 fn mul_assign(&mut self, rhs: Self) {
108 self.0 *= rhs.0;
109 }
110}
111
112impl DivAssign for flex32 {
113 fn div_assign(&mut self, rhs: Self) {
114 self.0 /= rhs.0;
115 }
116}
117
118impl RemAssign for flex32 {
119 fn rem_assign(&mut self, rhs: Self) {
120 self.0 %= rhs.0;
121 }
122}
123
124impl From<f32> for flex32 {
125 fn from(value: f32) -> Self {
126 Self::from_f32(value)
127 }
128}
129
130impl From<flex32> for f32 {
131 fn from(val: flex32) -> Self {
132 val.to_f32()
133 }
134}
135
136impl ToPrimitive for flex32 {
137 fn to_i64(&self) -> Option<i64> {
138 Some((*self).to_f32() as i64)
139 }
140
141 fn to_u64(&self) -> Option<u64> {
142 Some((*self).to_f32() as u64)
143 }
144
145 fn to_f32(&self) -> Option<f32> {
146 Some((*self).to_f32())
147 }
148
149 fn to_f64(&self) -> Option<f64> {
150 Some((*self).to_f32() as f64)
151 }
152}
153
154impl NumCast for flex32 {
155 fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
156 Some(flex32::from_f32(n.to_f32()?))
157 }
158}
159
160impl CubeType for flex32 {
161 type ExpandType = ExpandElementTyped<flex32>;
162}
163
164impl CubePrimitive for flex32 {
165 fn as_elem_native() -> Option<Elem> {
167 Some(Elem::Float(FloatKind::Flex32))
168 }
169}
170
171impl IntoRuntime for flex32 {
172 fn __expand_runtime_method(self, context: &mut CubeContext) -> ExpandElementTyped<Self> {
173 let expand: ExpandElementTyped<Self> = self.into();
174 Init::init(expand, context)
175 }
176}
177
178impl Numeric for flex32 {
179 fn min_value() -> Self {
180 <Self as num_traits::Float>::min_value()
181 }
182 fn max_value() -> Self {
183 <Self as num_traits::Float>::max_value()
184 }
185}
186
187impl ExpandElementBaseInit for flex32 {
188 fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
189 init_expand_element(context, elem)
190 }
191}
192
193impl Float for flex32 {
194 const DIGITS: u32 = 32;
195
196 const EPSILON: Self = flex32::from_f32(half::f16::EPSILON.to_f32_const());
197
198 const INFINITY: Self = flex32::from_f32(f32::INFINITY);
199
200 const MANTISSA_DIGITS: u32 = f32::MANTISSA_DIGITS;
201
202 const MAX_10_EXP: i32 = f32::MAX_10_EXP;
204 const MAX_EXP: i32 = f32::MAX_EXP;
206
207 const MIN_10_EXP: i32 = f32::MIN_10_EXP;
209 const MIN_EXP: i32 = f32::MIN_EXP;
211
212 const MIN_POSITIVE: Self = flex32(f32::MIN_POSITIVE);
213
214 const NAN: Self = flex32::from_f32(f32::NAN);
215
216 const NEG_INFINITY: Self = flex32::from_f32(f32::NEG_INFINITY);
217
218 const RADIX: u32 = 2;
219
220 fn new(val: f32) -> Self {
221 flex32::from_f32(val)
222 }
223}
224
225impl LaunchArgExpand for flex32 {
226 type CompilationArg = ();
227
228 fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
229 builder.scalar(flex32::as_elem(&builder.context)).into()
230 }
231}
232
233impl ScalarArgSettings for flex32 {
234 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
235 settings.register_f32(self.0);
236 }
237}
238
239impl num_traits::Float for flex32 {
240 fn nan() -> Self {
241 flex32(f32::nan())
242 }
243
244 fn infinity() -> Self {
245 flex32(f32::infinity())
246 }
247
248 fn neg_infinity() -> Self {
249 flex32(f32::neg_infinity())
250 }
251
252 fn neg_zero() -> Self {
253 flex32(f32::neg_zero())
254 }
255
256 fn min_value() -> Self {
257 flex32(<f32 as num_traits::Float>::min_value())
258 }
259
260 fn min_positive_value() -> Self {
261 flex32(f32::min_positive_value())
262 }
263
264 fn max_value() -> Self {
265 flex32(<f32 as num_traits::Float>::max_value())
266 }
267
268 fn is_nan(self) -> bool {
269 self.0.is_nan()
270 }
271
272 fn is_infinite(self) -> bool {
273 self.0.is_infinite()
274 }
275
276 fn is_finite(self) -> bool {
277 self.0.is_finite()
278 }
279
280 fn is_normal(self) -> bool {
281 self.0.is_normal()
282 }
283
284 fn classify(self) -> std::num::FpCategory {
285 self.0.classify()
286 }
287
288 fn floor(self) -> Self {
289 flex32(self.0.floor())
290 }
291
292 fn ceil(self) -> Self {
293 flex32(self.0.ceil())
294 }
295
296 fn round(self) -> Self {
297 flex32(self.0.round())
298 }
299
300 fn trunc(self) -> Self {
301 flex32(self.0.trunc())
302 }
303
304 fn fract(self) -> Self {
305 flex32(self.0.fract())
306 }
307
308 fn abs(self) -> Self {
309 flex32(self.0.abs())
310 }
311
312 fn signum(self) -> Self {
313 flex32(self.0.signum())
314 }
315
316 fn is_sign_positive(self) -> bool {
317 self.0.is_sign_positive()
318 }
319
320 fn is_sign_negative(self) -> bool {
321 self.0.is_sign_negative()
322 }
323
324 fn mul_add(self, a: Self, b: Self) -> Self {
325 flex32(self.0.mul_add(a.0, b.0))
326 }
327
328 fn recip(self) -> Self {
329 flex32(self.0.recip())
330 }
331
332 fn powi(self, n: i32) -> Self {
333 flex32(self.0.powi(n))
334 }
335
336 fn powf(self, n: Self) -> Self {
337 flex32(self.0.powf(n.0))
338 }
339
340 fn sqrt(self) -> Self {
341 flex32(self.0.sqrt())
342 }
343
344 fn exp(self) -> Self {
345 flex32(self.0.exp())
346 }
347
348 fn exp2(self) -> Self {
349 flex32(self.0.exp2())
350 }
351
352 fn ln(self) -> Self {
353 flex32(self.0.ln())
354 }
355
356 fn log(self, base: Self) -> Self {
357 flex32(self.0.log(base.0))
358 }
359
360 fn log2(self) -> Self {
361 flex32(self.0.log2())
362 }
363
364 fn log10(self) -> Self {
365 flex32(self.0.log10())
366 }
367
368 fn max(self, other: Self) -> Self {
369 flex32(self.0.max(other.0))
370 }
371
372 fn min(self, other: Self) -> Self {
373 flex32(self.0.min(other.0))
374 }
375
376 fn abs_sub(self, other: Self) -> Self {
377 flex32((self.0 - other.0).abs())
378 }
379
380 fn cbrt(self) -> Self {
381 flex32(self.0.cbrt())
382 }
383
384 fn hypot(self, other: Self) -> Self {
385 flex32(self.0.hypot(other.0))
386 }
387
388 fn sin(self) -> Self {
389 flex32(self.0.sin())
390 }
391
392 fn cos(self) -> Self {
393 flex32(self.0.cos())
394 }
395
396 fn tan(self) -> Self {
397 flex32(self.0.tan())
398 }
399
400 fn asin(self) -> Self {
401 flex32(self.0.asin())
402 }
403
404 fn acos(self) -> Self {
405 flex32(self.0.acos())
406 }
407
408 fn atan(self) -> Self {
409 flex32(self.0.atan())
410 }
411
412 fn atan2(self, other: Self) -> Self {
413 flex32(self.0.atan2(other.0))
414 }
415
416 fn sin_cos(self) -> (Self, Self) {
417 let (a, b) = self.0.sin_cos();
418 (flex32(a), flex32(b))
419 }
420
421 fn exp_m1(self) -> Self {
422 flex32(self.0.exp_m1())
423 }
424
425 fn ln_1p(self) -> Self {
426 flex32(self.0.ln_1p())
427 }
428
429 fn sinh(self) -> Self {
430 flex32(self.0.sinh())
431 }
432
433 fn cosh(self) -> Self {
434 flex32(self.0.cosh())
435 }
436
437 fn tanh(self) -> Self {
438 flex32(self.0.tanh())
439 }
440
441 fn asinh(self) -> Self {
442 flex32(self.0.asinh())
443 }
444
445 fn acosh(self) -> Self {
446 flex32(self.0.acosh())
447 }
448
449 fn atanh(self) -> Self {
450 flex32(self.0.atanh())
451 }
452
453 fn integer_decode(self) -> (u64, i16, i8) {
454 self.0.integer_decode()
455 }
456}
457
458impl Num for flex32 {
459 type FromStrRadixErr = <f32 as Num>::FromStrRadixErr;
460
461 fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
462 Ok(flex32(f32::from_str_radix(str, radix)?))
463 }
464}
465
466impl One for flex32 {
467 fn one() -> Self {
468 flex32(1.0)
469 }
470}
471
472impl Zero for flex32 {
473 fn zero() -> Self {
474 flex32(0.0)
475 }
476
477 fn is_zero(&self) -> bool {
478 self.0 == 0.0
479 }
480}