cubecl_core/frontend/element/float/fp8.rs

1use cubecl_common::{e4m3, e5m2, ue8m0};
2use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType};
3
4use crate::prelude::{
5    CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, Numeric,
6    into_mut_expand_element, into_runtime_expand_element,
7};
8
9impl CubeType for e4m3 {
10    type ExpandType = ExpandElementTyped<e4m3>;
11}
12
13impl CubePrimitive for e4m3 {
14    /// Return the element type to use on GPU
15    fn as_type_native() -> Option<StorageType> {
16        Some(ElemType::Float(FloatKind::E4M3).into())
17    }
18
19    fn from_const_value(value: ConstantScalarValue) -> Self {
20        let ConstantScalarValue::Float(value, _) = value else {
21            unreachable!()
22        };
23        e4m3::from_f64(value)
24    }
25}
26
27impl IntoRuntime for e4m3 {
28    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
29        let elem: ExpandElementTyped<Self> = self.into();
30        into_runtime_expand_element(scope, elem).into()
31    }
32}
33
34impl Numeric for e4m3 {
35    fn min_value() -> Self {
36        Self::from_f64(Self::MIN)
37    }
38    fn max_value() -> Self {
39        Self::from_f64(Self::MAX)
40    }
41}
42
43impl ExpandElementIntoMut for e4m3 {
44    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
45        into_mut_expand_element(scope, elem)
46    }
47}
48
49impl CubeType for e5m2 {
50    type ExpandType = ExpandElementTyped<e5m2>;
51}
52
53impl CubePrimitive for e5m2 {
54    /// Return the element type to use on GPU
55    fn as_type_native() -> Option<StorageType> {
56        Some(ElemType::Float(FloatKind::E5M2).into())
57    }
58
59    fn from_const_value(value: ConstantScalarValue) -> Self {
60        let ConstantScalarValue::Float(value, _) = value else {
61            unreachable!()
62        };
63        e5m2::from_f64(value)
64    }
65}
66
67impl IntoRuntime for e5m2 {
68    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
69        let elem: ExpandElementTyped<Self> = self.into();
70        into_runtime_expand_element(scope, elem).into()
71    }
72}
73
74impl Numeric for e5m2 {
75    fn min_value() -> Self {
76        Self::from_f64(Self::MIN)
77    }
78    fn max_value() -> Self {
79        Self::from_f64(Self::MAX)
80    }
81}
82
83impl ExpandElementIntoMut for e5m2 {
84    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
85        into_mut_expand_element(scope, elem)
86    }
87}
88
89impl CubeType for ue8m0 {
90    type ExpandType = ExpandElementTyped<ue8m0>;
91}
92
93impl CubePrimitive for ue8m0 {
94    /// Return the element type to use on GPU
95    fn as_type_native() -> Option<StorageType> {
96        Some(ElemType::Float(FloatKind::UE8M0).into())
97    }
98
99    fn from_const_value(value: ConstantScalarValue) -> Self {
100        let ConstantScalarValue::Float(value, _) = value else {
101            unreachable!()
102        };
103        ue8m0::from_f64(value)
104    }
105}
106
107impl IntoRuntime for ue8m0 {
108    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
109        let elem: ExpandElementTyped<Self> = self.into();
110        into_runtime_expand_element(scope, elem).into()
111    }
112}
113
114impl Numeric for ue8m0 {
115    fn min_value() -> Self {
116        Self::from_f64(Self::MIN)
117    }
118    fn max_value() -> Self {
119        Self::from_f64(Self::MAX)
120    }
121}
122
123impl ExpandElementIntoMut for ue8m0 {
124    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
125        into_mut_expand_element(scope, elem)
126    }
127}