cubecl_core/frontend/element/float/
fp4.rs1use cubecl_common::{e2m1, e2m1x2};
2use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType};
3
4use crate::prelude::{
5 CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime,
6 into_mut_expand_element, into_runtime_expand_element,
7};
8
9impl CubeType for e2m1 {
10 type ExpandType = ExpandElementTyped<e2m1>;
11}
12
13impl CubePrimitive for e2m1 {
14 fn as_type_native() -> Option<StorageType> {
16 Some(StorageType::Scalar(ElemType::Float(FloatKind::E2M1)))
17 }
18
19 fn from_const_value(value: ConstantScalarValue) -> Self {
20 let ConstantScalarValue::Float(value, _) = value else {
21 unreachable!()
22 };
23 e2m1::from_f64(value)
24 }
25}
26
27impl IntoRuntime for e2m1 {
28 fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
29 let elem: ExpandElementTyped<Self> = self.into();
30 into_runtime_expand_element(scope, elem).into()
31 }
32}
33
34impl ExpandElementIntoMut for e2m1 {
35 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
36 into_mut_expand_element(scope, elem)
37 }
38}
39
40impl CubeType for e2m1x2 {
41 type ExpandType = ExpandElementTyped<e2m1x2>;
42}
43
44impl CubePrimitive for e2m1x2 {
45 fn as_type_native() -> Option<StorageType> {
47 Some(StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2))
48 }
49
50 fn from_const_value(value: ConstantScalarValue) -> Self {
51 let ConstantScalarValue::Float(value, _) = value else {
52 unreachable!()
53 };
54 let val = e2m1::from_f64(value).to_bits();
55 e2m1x2::from_bits(val | (val << 4))
57 }
58}
59
60impl IntoRuntime for e2m1x2 {
61 fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
62 let elem: ExpandElementTyped<Self> = self.into();
63 into_runtime_expand_element(scope, elem).into()
64 }
65}
66
67impl ExpandElementIntoMut for e2m1x2 {
68 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
69 into_mut_expand_element(scope, elem)
70 }
71}