cubecl_core/frontend/element/float/
fp8.rs

1use cubecl_common::{e4m3, e5m2, ue8m0};
2use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType};
3
4use crate::{
5    Runtime,
6    compute::KernelLauncher,
7    prelude::{
8        CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, Numeric,
9        ScalarArgSettings, into_mut_expand_element, into_runtime_expand_element,
10    },
11};
12
13impl CubeType for e4m3 {
14    type ExpandType = ExpandElementTyped<e4m3>;
15}
16
17impl CubePrimitive for e4m3 {
18    /// Return the element type to use on GPU
19    fn as_type_native() -> Option<StorageType> {
20        Some(ElemType::Float(FloatKind::E4M3).into())
21    }
22
23    fn from_const_value(value: ConstantScalarValue) -> Self {
24        let ConstantScalarValue::Float(value, _) = value else {
25            unreachable!()
26        };
27        e4m3::from_f64(value)
28    }
29}
30
31impl IntoRuntime for e4m3 {
32    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
33        let elem: ExpandElementTyped<Self> = self.into();
34        into_runtime_expand_element(scope, elem).into()
35    }
36}
37
38impl Numeric for e4m3 {
39    fn min_value() -> Self {
40        Self::from_f64(Self::MIN)
41    }
42    fn max_value() -> Self {
43        Self::from_f64(Self::MAX)
44    }
45}
46
47impl ExpandElementIntoMut for e4m3 {
48    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
49        into_mut_expand_element(scope, elem)
50    }
51}
52
53impl ScalarArgSettings for e4m3 {
54    fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
55        todo!("Not yet supported for scalars")
56    }
57}
58
59impl CubeType for e5m2 {
60    type ExpandType = ExpandElementTyped<e5m2>;
61}
62
63impl CubePrimitive for e5m2 {
64    /// Return the element type to use on GPU
65    fn as_type_native() -> Option<StorageType> {
66        Some(ElemType::Float(FloatKind::E5M2).into())
67    }
68
69    fn from_const_value(value: ConstantScalarValue) -> Self {
70        let ConstantScalarValue::Float(value, _) = value else {
71            unreachable!()
72        };
73        e5m2::from_f64(value)
74    }
75}
76
77impl IntoRuntime for e5m2 {
78    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
79        let elem: ExpandElementTyped<Self> = self.into();
80        into_runtime_expand_element(scope, elem).into()
81    }
82}
83
84impl Numeric for e5m2 {
85    fn min_value() -> Self {
86        Self::from_f64(Self::MIN)
87    }
88    fn max_value() -> Self {
89        Self::from_f64(Self::MAX)
90    }
91}
92
93impl ExpandElementIntoMut for e5m2 {
94    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
95        into_mut_expand_element(scope, elem)
96    }
97}
98
99impl ScalarArgSettings for e5m2 {
100    fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
101        todo!("Not yet supported for scalars")
102    }
103}
104
105impl CubeType for ue8m0 {
106    type ExpandType = ExpandElementTyped<ue8m0>;
107}
108
109impl CubePrimitive for ue8m0 {
110    /// Return the element type to use on GPU
111    fn as_type_native() -> Option<StorageType> {
112        Some(ElemType::Float(FloatKind::UE8M0).into())
113    }
114
115    fn from_const_value(value: ConstantScalarValue) -> Self {
116        let ConstantScalarValue::Float(value, _) = value else {
117            unreachable!()
118        };
119        ue8m0::from_f64(value)
120    }
121}
122
123impl IntoRuntime for ue8m0 {
124    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
125        let elem: ExpandElementTyped<Self> = self.into();
126        into_runtime_expand_element(scope, elem).into()
127    }
128}
129
130impl Numeric for ue8m0 {
131    fn min_value() -> Self {
132        Self::from_f64(Self::MIN)
133    }
134    fn max_value() -> Self {
135        Self::from_f64(Self::MAX)
136    }
137}
138
139impl ExpandElementIntoMut for ue8m0 {
140    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
141        into_mut_expand_element(scope, elem)
142    }
143}
144
145impl ScalarArgSettings for ue8m0 {
146    fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
147        todo!("Not yet supported for scalars")
148    }
149}