cubecl_common/float/
fp8.rs

1use bytemuck::{Pod, Zeroable};
2
3/// A 8-bit floating point type with 4 exponent bits and 3 mantissa bits.
4///
5/// [`Minifloat`]: https://en.wikipedia.org/wiki/Minifloat
6#[allow(non_camel_case_types)]
7#[repr(transparent)]
8#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
9#[derive(Clone, Copy, Default, Zeroable, Pod, PartialEq, PartialOrd)]
10pub struct e4m3(u8);
11
12/// A 8-bit floating point type with 5 exponent bits and 2 mantissa bits.
13///
14/// [`Minifloat`]: https://en.wikipedia.org/wiki/Minifloat
15#[allow(non_camel_case_types)]
16#[repr(transparent)]
17#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
18#[derive(Clone, Copy, Default, Zeroable, Pod, PartialEq, PartialOrd)]
19pub struct e5m2(u8);
20
21/// An 8-bit unsigned floating point type with 8 exponent bits and no mantissa bits.
22/// Used for scaling factors.
23///
24/// [`Minifloat`]: https://en.wikipedia.org/wiki/Minifloat
25#[allow(non_camel_case_types)]
26#[repr(transparent)]
27#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
28#[derive(Clone, Copy, Default, Zeroable, Pod, PartialEq, PartialOrd)]
29pub struct ue8m0(u8);
30
31impl e4m3 {
32    /// Maximum representable value
33    pub const MAX: f64 = 240.0;
34    /// Minimum representable value
35    pub const MIN: f64 = -240.0;
36}
37
38impl e5m2 {
39    /// Maximum representable value
40    pub const MAX: f64 = 57344.0;
41    /// Minimum representable value
42    pub const MIN: f64 = -57344.0;
43}
44
45impl ue8m0 {
46    /// Maximum representable value
47    pub const MAX: f64 = f64::from_bits(0x47E0000000000000);
48    /// Minimum representable value
49    pub const MIN: f64 = 0.0;
50}