mod block;
mod cvt;
mod fp4x2;
mod m8e0;
pub use block::MXFP4Block;
pub use fp4x2::F4E2M1x2;
pub use m8e0::E8M0;
/// Packs a slice of FP4 values two-per-byte into [`F4E2M1x2`] pairs.
///
/// Element `2 * i` becomes the low half of output `i` and element `2 * i + 1`
/// its high half. If `values` has odd length, the high half of the final pair
/// is padded with [`F4E2M1::ZERO`]. An empty slice yields an empty `Vec`.
pub fn pack(values: &[F4E2M1]) -> Vec<F4E2M1x2> {
    let full_pairs = values.chunks_exact(2);
    // `chunks_exact(2)` leaves either zero or one trailing element behind.
    let leftover = full_pairs.remainder().first().copied();

    let mut packed = Vec::with_capacity(values.len() / 2 + values.len() % 2);
    packed.extend(full_pairs.map(|pair| F4E2M1x2::new(pair[0], pair[1])));
    if let Some(lo) = leftover {
        packed.push(F4E2M1x2::new(lo, F4E2M1::ZERO));
    }
    packed
}
/// Expands packed pairs back into a flat sequence of FP4 values.
///
/// Each [`F4E2M1x2`] contributes its low half followed by its high half, so
/// the result is always exactly twice as long as `packed`. Any padding that
/// [`pack`] added for odd-length input is kept in the output.
pub fn unpack(packed: &[F4E2M1x2]) -> Vec<F4E2M1> {
    let mut values = Vec::with_capacity(2 * packed.len());
    values.extend(packed.iter().flat_map(|&pair| [pair.lo(), pair.hi()]));
    values
}
/// A 4-bit floating-point number in E2M1 layout (sign, 2 exponent bits,
/// 1 mantissa bit), held in a single byte.
///
/// Equality and hashing are derived and therefore compare the stored byte:
/// positive zero (`0x0`) and negative zero (`0x8`) are *not* equal under `==`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct F4E2M1(u8);

// A transparent wrapper around `u8` must stay exactly one byte wide.
const _: () = assert!(std::mem::size_of::<F4E2M1>() == 1);
impl F4E2M1 {
    /// Converts an `f64` to an `F4E2M1`.
    ///
    /// The rounding logic lives in `cvt::f64_to_fp4`. This module's tests
    /// expect round-to-nearest with ties to the even-mantissa neighbour.
    /// They also expect out-of-range magnitudes (including ±∞) to saturate to
    /// `±6.0`, and NaN to map to `6.0`.
    #[inline(always)]
    pub const fn from_f64(x: f64) -> Self {
        Self(cvt::f64_to_fp4(x))
    }

    /// Decodes this value to `f64` via `cvt::fp4_to_f64`.
    #[inline(always)]
    pub fn to_f64(&self) -> f64 {
        cvt::fp4_to_f64(self.0)
    }

    /// Builds a value from its raw 4-bit encoding.
    ///
    /// Only the low nibble of `bits` is kept (`bits & 0x0F`); the high nibble
    /// is discarded. Without the mask, `from_bits(0x12)` and `from_bits(0x02)`
    /// would compare unequal under the derived `PartialEq`. A stray high
    /// nibble could also leak into the neighbouring half of an [`F4E2M1x2`]
    /// if its constructor does not mask.
    #[inline(always)]
    pub const fn from_bits(bits: u8) -> Self {
        Self(bits & 0x0F)
    }

    /// Returns the raw encoding stored in this value.
    ///
    /// Bit 3 is the sign, bits 2..=1 the exponent and bit 0 the mantissa.
    /// Values created through [`F4E2M1::from_bits`] always have a clear high
    /// nibble.
    #[inline(always)]
    pub const fn to_bits(&self) -> u8 {
        self.0
    }
}
impl F4E2M1 {
pub const MIN_POSITIVE_NORMAL: F4E2M1 = F4E2M1(0x2);
pub const MIN_POSITIVE: F4E2M1 = F4E2M1(0x1);
pub const MAX: F4E2M1 = F4E2M1(0x7);
pub const MIN: F4E2M1 = F4E2M1(0xF);
pub const ZERO: F4E2M1 = F4E2M1(0x0);
pub const NEG_ZERO: F4E2M1 = F4E2M1(0x8);
pub const ONE: F4E2M1 = F4E2M1(0x2);
pub const NEG_ONE: F4E2M1 = F4E2M1(0xA);
pub const EPSILON: F4E2M1 = F4E2M1(0x1);
}
impl Default for F4E2M1 {
#[inline]
fn default() -> Self {
F4E2M1::ZERO
}
}
impl From<f32> for F4E2M1 {
#[inline]
fn from(value: f32) -> Self {
F4E2M1::from_f64(value as f64)
}
}
impl From<F4E2M1> for f32 {
    /// Decodes to `f32` by narrowing the `f64` decoding.
    ///
    /// The sixteen E2M1 values (±0, ±0.5, …, ±6) are exact in `f32`, so the
    /// narrowing loses nothing.
    #[inline]
    fn from(value: F4E2M1) -> Self {
        let wide: f64 = value.to_f64();
        wide as Self
    }
}
impl From<F4E2M1> for f64 {
#[inline]
fn from(value: F4E2M1) -> Self {
value.to_f64()
}
}
impl std::fmt::Display for F4E2M1 {
    /// Formats the decoded `f64` value.
    ///
    /// Delegates to `f64`'s `Display` so formatter options are honoured, such
    /// as width, fill, alignment, sign and precision (e.g. `{:>6.2}`).
    /// Re-formatting via `write!(f, "{}", ..)` would silently drop them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.to_f64(), f)
    }
}
impl std::fmt::LowerExp for F4E2M1 {
    /// Formats the decoded `f64` value in lower-case scientific notation.
    ///
    /// Delegates to `f64`'s `LowerExp` so width, precision and other
    /// formatter options (e.g. `{:10.3e}`) are honoured rather than dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::LowerExp::fmt(&self.to_f64(), f)
    }
}
impl std::fmt::UpperExp for F4E2M1 {
    /// Formats the decoded `f64` value in upper-case scientific notation.
    ///
    /// Delegates to `f64`'s `UpperExp` so width, precision and other
    /// formatter options (e.g. `{:10.3E}`) are honoured rather than dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::UpperExp::fmt(&self.to_f64(), f)
    }
}
#[cfg(test)]
mod test {
    use crate::F4E2M1;

    /// Expected decoded value of every 4-bit encoding, indexed by encoding.
    const DECODE_TABLE: [f64; 16] = [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ];

    #[test]
    fn test_full_range() {
        for (bits, expected) in (0u8..16).zip(DECODE_TABLE) {
            let fp4 = F4E2M1::from_bits(bits);
            // `from_bits` must agree with the raw tuple constructor.
            assert_eq!(fp4, F4E2M1(bits), "from_bits mismatch for 0x{bits:X}");
            let converted = fp4.to_f64();
            // Compare IEEE bit patterns rather than values: `-0.0 == 0.0`, so a
            // plain float comparison could not detect a lost sign on 0x8.
            assert_eq!(
                converted.to_bits(),
                expected.to_bits(),
                "Failed for bits 0x{bits:X}: got {converted}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_roundtrip() {
        // Numeric comparison: this checks that every representable value
        // survives encode/decode, but does not pin which zero encoding
        // `from_f64(-0.0)` produces.
        for x in DECODE_TABLE {
            let roundtrip = F4E2M1::from_f64(x).to_f64();
            assert_eq!(roundtrip, x, "Roundtrip failed for {x}: got {roundtrip}");
        }
    }

    #[test]
    fn test_rounding() {
        // (input, expected): ties go to the even-mantissa neighbour, and
        // magnitudes beyond 6.0 saturate.
        let test_cases = [
            (0.75, 1.0),
            (1.25, 1.0),
            (1.75, 2.0),
            (2.25, 2.0),
            (2.5, 2.0),
            (2.75, 3.0),
            (3.25, 3.0),
            (3.5, 4.0),
            (4.5, 4.0),
            (5.0, 4.0),
            (5.5, 6.0),
            (7.0, 6.0),
            (10.0, 6.0),
            (-0.75, -1.0),
            (-1.25, -1.0),
            (-1.75, -2.0),
            (-2.25, -2.0),
            (-2.5, -2.0),
            (-2.75, -3.0),
            (-3.25, -3.0),
            (-3.5, -4.0),
            (-4.5, -4.0),
            (-5.0, -4.0),
            (-5.5, -6.0),
            (-7.0, -6.0),
        ];
        for (input, expected) in test_cases {
            let result = F4E2M1::from_f64(input).to_f64();
            assert_eq!(
                result, expected,
                "Rounding failed for {input}: got {result}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_special_values() {
        assert_eq!(F4E2M1::from_f64(f64::INFINITY).to_f64(), 6.0);
        assert_eq!(F4E2M1::from_f64(f64::NEG_INFINITY).to_f64(), -6.0);
        // NaN has no E2M1 encoding; it is expected to saturate to +6.0.
        assert_eq!(F4E2M1::from_f64(f64::NAN).to_f64(), 6.0);
    }
}
#[cfg(test)]
mod pack_tests {
    use crate::{F4E2M1, pack, unpack};

    #[test]
    fn empty_slice() {
        let packed = pack(&[]);
        assert!(packed.is_empty());
        let unpacked = unpack(&packed);
        assert!(unpacked.is_empty());
    }

    #[test]
    fn even_length_roundtrip() {
        let values: Vec<F4E2M1> = (0..8).map(F4E2M1::from_bits).collect();
        let packed = pack(&values);
        assert_eq!(packed.len(), 4);
        // Whole-`Vec` comparison checks both length and every encoding.
        assert_eq!(unpack(&packed), values);
    }

    #[test]
    fn odd_length_pads_with_zero() {
        let values = [
            F4E2M1::from_f64(1.0),
            F4E2M1::from_f64(2.0),
            F4E2M1::from_f64(3.0),
        ];
        let packed = pack(&values);
        assert_eq!(packed.len(), 2);
        let unpacked = unpack(&packed);
        // Three values plus one padding nibble.
        assert_eq!(unpacked.len(), 4);
        assert_eq!(unpacked[0].to_f64(), 1.0);
        assert_eq!(unpacked[1].to_f64(), 2.0);
        assert_eq!(unpacked[2].to_f64(), 3.0);
        // Compare encodings, not values: `-0.0 == 0.0` would hide NEG_ZERO padding.
        assert_eq!(unpacked[3].to_bits(), F4E2M1::ZERO.to_bits());
    }

    #[test]
    fn single_element() {
        let values = [F4E2M1::from_f64(6.0)];
        let packed = pack(&values);
        assert_eq!(packed.len(), 1);
        assert_eq!(packed[0].lo().to_f64(), 6.0);
        assert_eq!(packed[0].hi().to_bits(), F4E2M1::ZERO.to_bits());
    }

    #[test]
    fn all_values_roundtrip() {
        let values: Vec<F4E2M1> = (0..16).map(F4E2M1::from_bits).collect();
        let packed = pack(&values);
        assert_eq!(packed.len(), 8);
        // A zip-based comparison would silently accept a truncated result;
        // comparing whole vectors also checks the length.
        assert_eq!(unpack(&packed), values);
    }
}