cubecl_core/frontend/element/float/fp4.rs

1use cubecl_common::{e2m1, e2m1x2};
2use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType};
3
4use crate::prelude::{
5    CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime,
6    into_mut_expand_element, into_runtime_expand_element,
7};
8
9impl CubeType for e2m1 {
10    type ExpandType = ExpandElementTyped<e2m1>;
11}
12
13impl CubePrimitive for e2m1 {
14    /// Return the element type to use on GPU
15    fn as_type_native() -> Option<StorageType> {
16        Some(StorageType::Scalar(ElemType::Float(FloatKind::E2M1)))
17    }
18
19    fn from_const_value(value: ConstantScalarValue) -> Self {
20        let ConstantScalarValue::Float(value, _) = value else {
21            unreachable!()
22        };
23        e2m1::from_f64(value)
24    }
25}
26
27impl IntoRuntime for e2m1 {
28    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
29        let elem: ExpandElementTyped<Self> = self.into();
30        into_runtime_expand_element(scope, elem).into()
31    }
32}
33
34impl ExpandElementIntoMut for e2m1 {
35    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
36        into_mut_expand_element(scope, elem)
37    }
38}
39
40impl CubeType for e2m1x2 {
41    type ExpandType = ExpandElementTyped<e2m1x2>;
42}
43
44impl CubePrimitive for e2m1x2 {
45    /// Return the element type to use on GPU
46    fn as_type_native() -> Option<StorageType> {
47        Some(StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2))
48    }
49
50    fn from_const_value(value: ConstantScalarValue) -> Self {
51        let ConstantScalarValue::Float(value, _) = value else {
52            unreachable!()
53        };
54        let val = e2m1::from_f64(value).to_bits();
55        // Fill both values, not sure this is ever useful but it works
56        e2m1x2::from_bits(val | (val << 4))
57    }
58}
59
60impl IntoRuntime for e2m1x2 {
61    fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
62        let elem: ExpandElementTyped<Self> = self.into();
63        into_runtime_expand_element(scope, elem).into()
64    }
65}
66
67impl ExpandElementIntoMut for e2m1x2 {
68    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
69        into_mut_expand_element(scope, elem)
70    }
71}