use crate::{E8M0, F4E2M1, F4E2M1x2};
/// A block of 32 four-bit `F4E2M1` values stored as 16 packed pairs, together
/// with a single `E8M0` scale that applies to every element of the block.
///
/// `to_f32_array` decodes element `k` as `value_k * scale`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MXFP4Block {
// Packed element pairs: pair `i` holds elements `2 * i` (read back via `lo()`)
// and `2 * i + 1` (read back via `hi()`).
block: [F4E2M1x2; 16],
// Scale shared by all 32 elements in this block.
scale: E8M0,
}
// Compile-time layout check: the build fails unless an `MXFP4Block` occupies
// exactly 17 bytes.
const _: () = assert!(std::mem::size_of::<MXFP4Block>() == 17);
impl MXFP4Block {
    /// Packs 32 four-bit values and a shared scale into a block.
    ///
    /// Despite the name, the input is an array of already-encoded `F4E2M1`
    /// values rather than `f32`s; no quantization happens here. Elements
    /// `xs[2 * i]` and `xs[2 * i + 1]` become the first and second argument
    /// of `F4E2M1x2::new` for pair `i`.
    #[inline(always)]
    pub fn from_f32_slice(xs: [F4E2M1; 32], scale: E8M0) -> Self {
        let mut block = [F4E2M1x2::ZERO; 16];
        // 32 elements split into exactly 16 chunks of two, one per packed pair.
        for (packed, elems) in block.iter_mut().zip(xs.chunks_exact(2)) {
            *packed = F4E2M1x2::new(elems[0], elems[1]);
        }
        Self { block, scale }
    }

    /// Unpacks the block into its 32 individual `F4E2M1` values.
    ///
    /// For pair `i`, `lo()` is written to index `2 * i` and `hi()` to index
    /// `2 * i + 1`.
    #[inline(always)]
    pub fn to_f4_array(&self) -> [F4E2M1; 32] {
        let mut result = [F4E2M1::from_bits(0); 32];
        for (slots, pair) in result.chunks_exact_mut(2).zip(self.block.iter()) {
            slots[0] = pair.lo();
            slots[1] = pair.hi();
        }
        result
    }

    /// Returns the scale shared by every element of the block.
    #[inline(always)]
    pub fn scale(&self) -> E8M0 {
        self.scale
    }

    /// Decodes the block into 32 `f32` values.
    ///
    /// Each element is multiplied by the block scale in `f64` and the product
    /// is then narrowed to `f32`.
    #[inline(always)]
    pub fn to_f32_array(&self) -> [f32; 32] {
        let scale = self.scale.to_f64();
        self.to_f4_array()
            .map(|value| (value.to_f64() * scale) as f32)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packing then unpacking must reproduce every element's bit pattern and
    /// keep the scale intact.
    #[test]
    fn test_pack_unpack_roundtrip() {
        let mut values = [F4E2M1::from_bits(0); 32];
        // Walk through all 16 bit patterns twice: 0..=15, 0..=15.
        for (slot, bits) in values.iter_mut().zip((0u8..16).cycle()) {
            *slot = F4E2M1::from_bits(bits);
        }

        let block = MXFP4Block::from_f32_slice(values, E8M0::from_f64(4.0f64));
        assert_eq!(block.scale().to_f64(), 4.0);

        let unpacked = block.to_f4_array();
        for (got, want) in unpacked.iter().zip(values.iter()) {
            assert_eq!(got.to_bits(), want.to_bits());
        }
    }

    /// Decoding multiplies each element by the block scale; untouched
    /// elements stay at zero.
    #[test]
    fn test_to_f32_array() {
        let inputs = [1.0, 2.0, 0.5];
        let expected = [2.0f32, 4.0, 1.0];

        let mut values = [F4E2M1::from_bits(0); 32];
        for (slot, &x) in values.iter_mut().zip(inputs.iter()) {
            *slot = F4E2M1::from_f64(x);
        }

        let block = MXFP4Block::from_f32_slice(values, E8M0::from_f64(2.0f64));
        let f32_array = block.to_f32_array();

        let (head, tail) = f32_array.split_at(expected.len());
        for (got, want) in head.iter().zip(expected.iter()) {
            assert_eq!(*got, *want);
        }
        for value in tail {
            assert_eq!(*value, 0.0);
        }
    }

    /// The even-indexed element of each pair lands in the low nibble of the
    /// packed byte and the odd-indexed element in the high nibble.
    #[test]
    fn test_packing_layout() {
        let nibbles = [0x5, 0xA, 0x3, 0xC];

        let mut values = [F4E2M1::from_bits(0); 32];
        for (slot, &bits) in values.iter_mut().zip(nibbles.iter()) {
            *slot = F4E2M1::from_bits(bits);
        }

        let block = MXFP4Block::from_f32_slice(values, E8M0::from_f64(1.0f64));
        assert_eq!(block.block[0].to_bits(), 0xA5);
        assert_eq!(block.block[1].to_bits(), 0xC3);
        for pair in block.block[2..].iter() {
            assert_eq!(pair.to_bits(), 0x00);
        }
    }
}