/// Number of consecutive elements along `k` that form one scale group.
///
/// [`nvfp4_scale_bytes`] reserves one byte per group of this many elements
/// (the final group along `k` may be partial).
pub const NVFP4_GROUP_SIZE: usize = 16;

/// Decoded `f32` value of every 4-bit FP4 E2M1 code, indexed by the code.
///
/// Bit 3 of the code is the sign bit. Entries `0..8` are the non-negative
/// magnitudes, and entries `8..16` are the same magnitudes negated. Entry 8
/// is `-0.0`.
pub const FP4_E2M1_LUT: [f32; 16] = [
    // sign bit clear
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    // sign bit set
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
];
/// Decodes the low four bits of `nibble` as an FP4 E2M1 value.
///
/// Bits above the low nibble are masked off and ignored. The lookup goes
/// through [`FP4_E2M1_LUT`], so this function can never disagree with the table.
#[inline]
pub fn fp4_e2m1_to_f32(nibble: u8) -> f32 {
    let code = usize::from(nibble & 0x0F);
    FP4_E2M1_LUT[code]
}
/// Decodes one FP8 E4M3 byte to `f32`.
///
/// The byte layout is 1 sign bit, 4 exponent bits and 3 mantissa bits, with
/// exponent bias 7:
/// - exponent field `0` is zero or subnormal: `mant * 2^-9`, which equals
///   `(mant / 8) * 2^-6`;
/// - exponent field `1..=15` is normal: `(8 + mant) * 2^(exp - 10)`, which
///   equals `(1 + mant / 8) * 2^(exp - 7)`.
///
/// The all-ones pattern `S.1111.111` (E4M3's NaN encoding) is decoded as a
/// signed zero, not NaN. As a result, every input byte yields a finite value.
/// NOTE(review): presumably intended so a bad scale cannot inject NaN into
/// dequantized data — confirm with callers.
#[inline]
pub fn fp8_e4m3_scale_to_f32(byte: u8) -> f32 {
    let negative = byte & 0x80 != 0;
    let exp_bits = i32::from((byte >> 3) & 0x0F);
    let mant_bits = byte & 0x07;
    let mant = f32::from(mant_bits);

    let magnitude = match exp_bits {
        0 => mant * 2f32.powi(-9),
        15 if mant_bits == 0x07 => 0.0,
        e => (8.0 + mant) * 2f32.powi(e - 10),
    };

    // Negation (rather than multiplying by -1) yields the same -0.0 for zero magnitudes.
    if negative {
        -magnitude
    } else {
        magnitude
    }
}
/// Number of bytes needed to hold `k * n` 4-bit codes packed two per byte.
///
/// An odd element count rounds up to one extra, half-used trailing byte.
#[inline]
pub const fn nvfp4_weight_bytes(k: usize, n: usize) -> usize {
    let elements = k * n;
    let full_bytes = elements >> 1;
    let trailing_nibble = elements & 1;
    full_bytes + trailing_nibble
}
/// Number of scale bytes for a `k`-by-`n` NVFP4 tensor.
///
/// The `k` axis is split into groups of [`NVFP4_GROUP_SIZE`] elements, and
/// the last group may be partial. One byte is counted per (group, `n`-index)
/// pair.
#[inline]
pub const fn nvfp4_scale_bytes(k: usize, n: usize) -> usize {
    let groups_along_k = k.div_ceil(NVFP4_GROUP_SIZE);
    n * groups_along_k
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Independent E2M1 decoder (1 sign, 2 exponent, 1 mantissa bit, bias 1)
    /// used to cross-check every entry of `FP4_E2M1_LUT`.
    fn reference_e2m1(code: u8) -> f32 {
        let sign = if code & 0x8 != 0 { -1.0 } else { 1.0 };
        let exp = i32::from((code >> 1) & 0x3);
        let mant = f32::from(code & 0x1);
        let magnitude = if exp == 0 {
            mant * 0.5
        } else {
            (1.0 + mant * 0.5) * 2f32.powi(exp - 1)
        };
        sign * magnitude
    }

    #[test]
    fn fp4_lut_matches_ocp() {
        assert_eq!(fp4_e2m1_to_f32(2), 1.0);
        assert_eq!(fp4_e2m1_to_f32(14), -4.0);
        // Check every code, including the sign of zero (code 8 is -0.0).
        for code in 0u8..16 {
            let got = fp4_e2m1_to_f32(code);
            let want = reference_e2m1(code);
            assert_eq!(got, want, "code {code:#x}");
            assert_eq!(
                got.is_sign_negative(),
                want.is_sign_negative(),
                "sign of code {code:#x}"
            );
        }
    }

    #[test]
    fn fp4_ignores_high_nibble() {
        for code in 0u8..16 {
            assert_eq!(
                fp4_e2m1_to_f32(code | 0xF0).to_bits(),
                fp4_e2m1_to_f32(code).to_bits()
            );
        }
    }

    #[test]
    fn fp8_scale_one_is_unity() {
        assert!((fp8_e4m3_scale_to_f32(0x38) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn fp8_scale_boundary_values() {
        assert_eq!(fp8_e4m3_scale_to_f32(0x00), 0.0);
        assert!(fp8_e4m3_scale_to_f32(0x80).is_sign_negative());
        // Smallest subnormal, smallest normal, largest finite value.
        assert_eq!(fp8_e4m3_scale_to_f32(0x01), 2f32.powi(-9));
        assert_eq!(fp8_e4m3_scale_to_f32(0x08), 2f32.powi(-6));
        assert_eq!(fp8_e4m3_scale_to_f32(0x7E), 448.0);
        assert_eq!(fp8_e4m3_scale_to_f32(0xB8), -1.0);
        // The E4M3 NaN encoding is decoded as zero by this implementation.
        assert_eq!(fp8_e4m3_scale_to_f32(0x7F), 0.0);
        assert_eq!(fp8_e4m3_scale_to_f32(0xFF), 0.0);
    }

    #[test]
    fn fp8_scale_is_finite_for_every_byte() {
        for byte in 0..=u8::MAX {
            assert!(fp8_e4m3_scale_to_f32(byte).is_finite(), "byte {byte:#04x}");
        }
    }

    #[test]
    fn buffer_sizes_round_up() {
        assert_eq!(nvfp4_weight_bytes(16, 4), 32);
        assert_eq!(nvfp4_weight_bytes(3, 3), 5);
        assert_eq!(nvfp4_weight_bytes(0, 7), 0);
        assert_eq!(nvfp4_scale_bytes(16, 4), 4);
        assert_eq!(nvfp4_scale_bytes(17, 4), 8);
        assert_eq!(nvfp4_scale_bytes(0, 4), 0);
    }
}