use std::sync::OnceLock;
use oxibonsai_core::{fp8_e4m3_decode, fp8_e5m2_decode};
static FP8_E4M3_LUT: OnceLock<[f32; 256]> = OnceLock::new();
#[inline]
pub fn fp8_e4m3_lut() -> &'static [f32; 256] {
FP8_E4M3_LUT.get_or_init(|| {
let mut lut = [0.0_f32; 256];
for (i, slot) in lut.iter_mut().enumerate() {
*slot = fp8_e4m3_decode(i as u8);
}
lut
})
}
static FP8_E5M2_LUT: OnceLock<[f32; 256]> = OnceLock::new();
#[inline]
pub fn fp8_e5m2_lut() -> &'static [f32; 256] {
FP8_E5M2_LUT.get_or_init(|| {
let mut lut = [0.0_f32; 256];
for (i, slot) in lut.iter_mut().enumerate() {
*slot = fp8_e5m2_decode(i as u8);
}
lut
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `lut` agrees with `decode` on every byte value.
    ///
    /// Where `decode` yields NaN, the entry only has to be NaN (payloads are
    /// not compared). Every other entry must match bit-for-bit. Plain `==`
    /// would accept `-0.0` for `+0.0`, so it could not catch a lost sign bit;
    /// comparing bits also checks the sign of infinities.
    fn assert_lut_matches(lut: &[f32; 256], decode: impl Fn(u8) -> f32) {
        for (byte, &lut_val) in (0..=u8::MAX).zip(lut.iter()) {
            let scalar = decode(byte);
            if scalar.is_nan() {
                assert!(
                    lut_val.is_nan(),
                    "byte {byte:#04x}: expected NaN, got {lut_val}"
                );
            } else {
                assert_eq!(
                    lut_val.to_bits(),
                    scalar.to_bits(),
                    "byte {byte:#04x}: lut={lut_val} vs scalar={scalar}"
                );
            }
        }
    }

    #[test]
    fn lut_zero_byte_is_zero() {
        // Byte 0x00 must decode to +0.0 exactly (not -0.0) in both formats.
        assert_eq!(fp8_e4m3_lut()[0].to_bits(), 0.0_f32.to_bits());
        assert_eq!(fp8_e5m2_lut()[0].to_bits(), 0.0_f32.to_bits());
    }

    #[test]
    fn lut_matches_scalar_decode_e4m3() {
        assert_lut_matches(fp8_e4m3_lut(), fp8_e4m3_decode);
    }

    #[test]
    fn lut_matches_scalar_decode_e5m2() {
        assert_lut_matches(fp8_e5m2_lut(), fp8_e5m2_decode);
    }

    #[test]
    fn lut_is_singleton() {
        assert!(
            std::ptr::eq(fp8_e4m3_lut(), fp8_e4m3_lut()),
            "multiple calls should return same static address"
        );
        assert!(
            std::ptr::eq(fp8_e5m2_lut(), fp8_e5m2_lut()),
            "multiple calls should return same static address"
        );
    }
}