Skip to main content

cubecl_core/frontend/element/float/
fp4.rs

1use cubecl_common::{e2m1, e2m1x2};
2use cubecl_ir::{ConstantValue, ElemType, FloatKind, Scope, StorageType, Type};
3
4use crate::prelude::*;
5
6impl CubeType for e2m1 {
7    type ExpandType = NativeExpand<e2m1>;
8}
9
10impl Scalar for e2m1 {}
11impl CubePrimitive for e2m1 {
12    type Scalar = Self;
13    type Size = Const<1>;
14    type WithScalar<S: Scalar> = S;
15
16    /// Return the element type to use on GPU
17    fn as_type_native() -> Option<Type> {
18        Some(StorageType::Scalar(ElemType::Float(FloatKind::E2M1)).into())
19    }
20
21    fn from_const_value(value: ConstantValue) -> Self {
22        let ConstantValue::Float(value) = value else {
23            unreachable!()
24        };
25        e2m1::from_f64(value)
26    }
27}
28
29impl IntoRuntime for e2m1 {
30    fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
31        self.into()
32    }
33}
34
35impl NativeAssign for e2m1 {}
36
37impl CubeType for e2m1x2 {
38    type ExpandType = NativeExpand<e2m1x2>;
39}
40
41// Considered a scalar because it's really just a `u8` in a trenchcoat, and should be possible to
42// store in a `Vector`.
43impl Scalar for e2m1x2 {}
44impl CubePrimitive for e2m1x2 {
45    type Scalar = Self;
46    type Size = Const<1>;
47    type WithScalar<S: Scalar> = S;
48
49    /// Return the element type to use on GPU
50    fn as_type_native() -> Option<Type> {
51        Some(StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2).into())
52    }
53
54    fn from_const_value(value: ConstantValue) -> Self {
55        let ConstantValue::Float(value) = value else {
56            unreachable!()
57        };
58        let val = e2m1::from_f64(value).to_bits();
59        // Fill both values, not sure this is ever useful but it works
60        e2m1x2::from_bits(val | (val << 4))
61    }
62}
63
64impl IntoRuntime for e2m1x2 {
65    fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
66        self.into()
67    }
68}
69
70impl NativeAssign for e2m1x2 {}