cubecl_core/frontend/element/float/relaxed.rs

1use cubecl_common::flex32;
2use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType};
3
4use crate::prelude::{Numeric, into_runtime_expand_element};
5
6use super::{
7    CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, Float, IntoRuntime,
8    KernelLauncher, Runtime, ScalarArgSettings, into_mut_expand_element,
9};
10
11impl CubeType for flex32 {
12    type ExpandType = ExpandElementTyped<flex32>;
13}
14
15impl CubePrimitive for flex32 {
16    /// Return the element type to use on GPU
17    fn as_type_native() -> Option<StorageType> {
18        Some(ElemType::Float(FloatKind::Flex32).into())
19    }
20
21    fn from_const_value(value: ConstantScalarValue) -> Self {
22        let ConstantScalarValue::Float(value, _) = value else {
23            unreachable!()
24        };
25        flex32::from_f64(value)
26    }
27}
28
29impl IntoRuntime for flex32 {
30    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
31        let elem: ExpandElementTyped<Self> = self.into();
32        into_runtime_expand_element(scope, elem).into()
33    }
34}
35
36impl Numeric for flex32 {
37    fn min_value() -> Self {
38        <Self as num_traits::Float>::min_value()
39    }
40    fn max_value() -> Self {
41        <Self as num_traits::Float>::max_value()
42    }
43}
44
45impl ExpandElementIntoMut for flex32 {
46    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
47        into_mut_expand_element(scope, elem)
48    }
49}
50
51impl Float for flex32 {
52    const DIGITS: u32 = 32;
53
54    const EPSILON: Self = flex32::from_f32(half::f16::EPSILON.to_f32_const());
55
56    const INFINITY: Self = flex32::from_f32(f32::INFINITY);
57
58    const MANTISSA_DIGITS: u32 = f32::MANTISSA_DIGITS;
59
60    /// Maximum possible [`flex32`](crate::frontend::flex32) power of 10 exponent
61    const MAX_10_EXP: i32 = f32::MAX_10_EXP;
62    /// Maximum possible [`flex32`](crate::frontend::flex32) power of 2 exponent
63    const MAX_EXP: i32 = f32::MAX_EXP;
64
65    /// Minimum possible normal [`flex32`](crate::frontend::flex32) power of 10 exponent
66    const MIN_10_EXP: i32 = f32::MIN_10_EXP;
67    /// One greater than the minimum possible normal [`flex32`](crate::frontend::flex32) power of 2 exponent
68    const MIN_EXP: i32 = f32::MIN_EXP;
69
70    const MIN_POSITIVE: Self = flex32::from_f32(f32::MIN_POSITIVE);
71
72    const NAN: Self = flex32::from_f32(f32::NAN);
73
74    const NEG_INFINITY: Self = flex32::from_f32(f32::NEG_INFINITY);
75
76    const RADIX: u32 = 2;
77
78    fn new(val: f32) -> Self {
79        flex32::from_f32(val)
80    }
81}
82
83impl ScalarArgSettings for flex32 {
84    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
85        settings.register_f32(self.to_f32());
86    }
87}