cubecl_common/float/
fp4.rs

1use bytemuck::{Pod, Zeroable};
2
3/// A 4-bit floating point type with 2 exponent bits and 1 mantissa bit.
4///
5/// [`Minifloat`]: https://en.wikipedia.org/wiki/Minifloat
6#[allow(non_camel_case_types)]
7#[repr(transparent)]
8#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
9#[derive(Clone, Copy, Default, Zeroable, PartialEq, PartialOrd)]
10pub struct e2m1(u8);
11
12/// A 4-bit floating point type with 2 exponent bits and 1 mantissa bit. Packed with two elements
13/// per value, to allow for conversion to/from bytes. Care must be taken to ensure the shape is
14/// adjusted appropriately.
15///
16/// [`Minifloat`]: https://en.wikipedia.org/wiki/Minifloat
17#[allow(non_camel_case_types)]
18#[repr(transparent)]
19#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
20#[derive(Clone, Copy, Default, Zeroable, Pod, PartialEq, PartialOrd)]
21pub struct e2m1x2(u8);
22
23impl e2m1 {
24    /// Maximum representable value
25    pub const MAX: f64 = 3.0;
26    /// Minimum representable value
27    pub const MIN: f64 = -3.0;
28}