// cubecl_core/frontend/element/float/fp4.rs

1use cubecl_common::{e2m1, e2m1x2};
2use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType};
3
4use crate::{
5    Runtime,
6    compute::KernelLauncher,
7    prelude::{
8        CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime,
9        ScalarArgSettings, into_mut_expand_element, into_runtime_expand_element,
10    },
11};
12
13impl CubeType for e2m1 {
14    type ExpandType = ExpandElementTyped<e2m1>;
15}
16
17impl CubePrimitive for e2m1 {
18    /// Return the element type to use on GPU
19    fn as_type_native() -> Option<StorageType> {
20        Some(StorageType::Scalar(ElemType::Float(FloatKind::E2M1)))
21    }
22
23    fn from_const_value(value: ConstantScalarValue) -> Self {
24        let ConstantScalarValue::Float(value, _) = value else {
25            unreachable!()
26        };
27        e2m1::from_f64(value)
28    }
29}
30
31impl IntoRuntime for e2m1 {
32    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
33        let elem: ExpandElementTyped<Self> = self.into();
34        into_runtime_expand_element(scope, elem).into()
35    }
36}
37
38impl ExpandElementIntoMut for e2m1 {
39    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
40        into_mut_expand_element(scope, elem)
41    }
42}
43
44impl CubeType for e2m1x2 {
45    type ExpandType = ExpandElementTyped<e2m1x2>;
46}
47
48impl CubePrimitive for e2m1x2 {
49    /// Return the element type to use on GPU
50    fn as_type_native() -> Option<StorageType> {
51        Some(StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2))
52    }
53
54    fn from_const_value(value: ConstantScalarValue) -> Self {
55        let ConstantScalarValue::Float(value, _) = value else {
56            unreachable!()
57        };
58        let val = e2m1::from_f64(value).to_bits();
59        // Fill both values, not sure this is ever useful but it works
60        e2m1x2::from_bits(val | (val << 4))
61    }
62}
63
64impl IntoRuntime for e2m1x2 {
65    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
66        let elem: ExpandElementTyped<Self> = self.into();
67        into_runtime_expand_element(scope, elem).into()
68    }
69}
70
71impl ExpandElementIntoMut for e2m1x2 {
72    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
73        into_mut_expand_element(scope, elem)
74    }
75}
76
77impl ScalarArgSettings for e2m1x2 {
78    fn register<R: Runtime>(&self, _settings: &mut KernelLauncher<R>) {
79        todo!("Not yet supported for scalars")
80    }
81}