/// An unsigned, exponent-only 8-bit scale factor (8 exponent bits, 0 mantissa bits).
///
/// The stored byte `b` decodes to `2^(b - 127)` for `b` in `0x00..=0xFE`, so the
/// representable values are the powers of two from `2^-127` to `2^127`. The byte
/// `0xFF` is reserved for NaN. There is no sign bit and no encoding for zero or
/// infinity.
///
/// Equality and hashing are derived from the raw byte, so unlike floating-point
/// NaN, `E8M0::NAN == E8M0::NAN` is `true`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct E8M0(u8);

// The wrapper must stay exactly one byte wide.
const _: () = assert!(std::mem::size_of::<E8M0>() == 1);

impl E8M0 {
    /// Exponent bias of the encoding: byte `b` represents `2^(b - BIAS)`.
    const BIAS: i32 = 127;
    /// Largest byte that encodes a finite value (`2^127`).
    const MAX_FINITE_BITS: i32 = 0xFE;
    /// Exponent bias of IEEE-754 binary64.
    const F64_EXPONENT_BIAS: i32 = 1023;
    /// Number of explicit fraction bits in IEEE-754 binary64.
    const F64_FRACTION_BITS: u32 = 52;

    /// The NaN encoding (`0xFF`).
    pub const NAN: Self = Self(0xFF);

    /// Encodes `value` as the smallest representable power of two that is `>= value`.
    ///
    /// * NaN maps to [`E8M0::NAN`].
    /// * Zero, negative values (including `-inf`) and anything at or below
    ///   `2^-127` map to `0x00` (`2^-127`).
    /// * Values above `2^127` (including `+inf`) saturate to `0xFE` (`2^127`).
    ///
    /// The exponent is read directly from the IEEE-754 bit pattern rather than via
    /// `log2().ceil()`: the logarithm of a value just above a power of two (for
    /// example `16.0 * (1.0 + f64::EPSILON)`) rounds to the integer itself, which
    /// would make the ceiling one too small and yield a scale below `value`.
    #[inline]
    pub fn from_f64(value: f64) -> Self {
        if value.is_nan() {
            return Self::NAN;
        }
        if value <= 0.0 {
            return Self(0x00);
        }

        // `value` is strictly positive here, so the sign bit is clear.
        let bits = value.to_bits();
        let fraction = bits & ((1u64 << Self::F64_FRACTION_BITS) - 1);
        let raw_exponent = ((bits >> Self::F64_FRACTION_BITS) & 0x7FF) as i32;

        // floor(log2(value)) for normal numbers. Subnormals (raw exponent 0) and
        // infinity (raw exponent 0x7FF) produce exponents far outside the encodable
        // range and are handled by the clamp below.
        let floor_exp = raw_exponent - Self::F64_EXPONENT_BIAS;
        // Any set fraction bit means the value lies strictly above 2^floor_exp.
        let ceil_exp = floor_exp + i32::from(fraction != 0);

        let biased = (ceil_exp + Self::BIAS).clamp(0, Self::MAX_FINITE_BITS);
        // The clamp guarantees `biased` fits in a byte.
        Self(biased as u8)
    }

    /// Decodes the scale to `f64`: `2^(bits - 127)`, or NaN for `0xFF`.
    ///
    /// Every finite encoding is a normal `f64`, so the result is exact.
    #[inline]
    pub fn to_f64(self) -> f64 {
        match self.0 {
            0xFF => f64::NAN,
            exponent_bits => {
                // Re-bias into the binary64 exponent field; the range is 896..=1150,
                // i.e. always a normal number with a zero fraction.
                let field = i32::from(exponent_bits) - Self::BIAS + Self::F64_EXPONENT_BIAS;
                f64::from_bits((field as u64) << Self::F64_FRACTION_BITS)
            }
        }
    }

    /// Computes a shared scale for `values`: the encoding of the largest absolute
    /// value in the slice, rounded up to a power of two.
    ///
    /// Returns [`E8M0::NAN`] if any element is NaN. An empty or all-zero slice
    /// yields `0x00` (`2^-127`); an infinite element saturates to `0xFE` (`2^127`).
    #[inline]
    pub fn from_f32_slice(values: &[f32]) -> Self {
        let mut max_abs = 0.0f32;
        for &value in values {
            let magnitude = value.abs();
            if magnitude.is_nan() {
                // One NaN poisons the whole block; no need to look further.
                return Self::NAN;
            }
            if magnitude > max_abs {
                max_abs = magnitude;
            }
        }
        Self::from_f64(f64::from(max_abs))
    }

    /// Wraps a raw encoded byte without any validation (every byte is a valid encoding).
    #[inline]
    pub const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns the raw encoded byte.
    #[inline]
    pub const fn to_bits(&self) -> u8 {
        self.0
    }
}
impl From<f32> for E8M0 {
#[inline(always)]
fn from(value: f32) -> Self {
Self::from_f64(value as f64)
}
}
impl From<f64> for E8M0 {
#[inline(always)]
fn from(value: f64) -> Self {
Self::from_f64(value)
}
}
impl From<E8M0> for f64 {
#[inline(always)]
fn from(v: E8M0) -> Self {
v.to_f64()
}
}
impl From<u8> for E8M0 {
#[inline(always)]
fn from(b: u8) -> Self {
Self::from_bits(b)
}
}
impl From<E8M0> for u8 {
#[inline(always)]
fn from(v: E8M0) -> Self {
v.to_bits()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks hand-picked encodings, round-trips of every representable power of
    /// two, and that a non-power-of-two input decodes to a value at or above it.
    #[test]
    fn e8m0_conversion_roundtrip_and_edges() {
        // (input, expected raw byte)
        let expectations: [(f64, u8); 9] = [
            (-1.0, 0x00),
            (0.0, 0x00),
            (2f64.powi(-127), 0x00),
            (0.75, 0x7F),
            (1.0, 0x7F),
            (1.5, 0x80),
            (2.0, 0x80),
            (2f64.powi(127), 0xFE),
            (f64::NAN, 0xFF),
        ];
        for &(input, expected) in expectations.iter() {
            assert_eq!(E8M0::from(input).to_bits(), expected);
        }

        // Every power of two in range must survive encode -> decode unchanged.
        for exp in -127..=127 {
            let power = 2f64.powi(exp);
            let decoded = f64::from(E8M0::from(power));
            assert_eq!(power, decoded);
        }

        // 1.3 is not a power of two; the decoded scale must lie in [1.3, 2.0].
        let sample = 1.3_f64;
        let decoded = f64::from(E8M0::from(sample));
        assert!(decoded >= sample && decoded <= 2.0);
    }

    /// Checks the block scale derived from a slice for ordinary, NaN, infinite,
    /// empty, all-zero and single-element inputs.
    #[test]
    fn test_from_f32_slice() {
        // Decoded scale chosen for a slice of values.
        fn scale_of(values: &[f32]) -> f64 {
            E8M0::from_f32_slice(values).to_f64()
        }

        // Largest magnitude is rounded up to a power of two.
        assert_eq!(scale_of(&[1.0, -2.0, 0.5]), 2.0);
        assert_eq!(scale_of(&[1.5, -2.5, 0.5]), 4.0);
        assert_eq!(scale_of(&[0.1, -0.2, 0.15]), 0.25);
        assert_eq!(scale_of(&[100.0, -50.0, 75.0]), 128.0);

        // A NaN element makes the scale NaN.
        assert!(scale_of(&[1.0, f32::NAN, -2.0]).is_nan());

        // Infinity saturates to the largest finite scale.
        assert_eq!(scale_of(&[1.0, f32::INFINITY, -2.0]), 2f64.powi(127));

        // Empty and all-zero inputs yield the smallest scale.
        assert_eq!(scale_of(&[]), 2f64.powi(-127));
        assert_eq!(scale_of(&[0.0, -0.0, 0.0]), 2f64.powi(-127));

        // Single element.
        assert_eq!(scale_of(&[3.0]), 4.0);
    }
}