use super::cpu::kernels::nf4::NF4_CODEBOOK;
/// Signed 16-entry lookup table, one value per 4-bit index, stored in strictly ascending order.
///
/// NOTE(review): presumably the IQ4_NL non-linear 4-bit codebook (values appear to match
/// llama.cpp's `kvalues_iq4nl`) — confirm before relying on that correspondence.
pub const KVALUES_IQ4NL: [i8; 16] = [
    -127, -104, -83, -65, -49, -35, -22, -10, // indices 0..=7 (negative half)
    1, 13, 25, 38, 53, 69, 89, 113, // indices 8..=15 (positive half)
];
pub fn quantize_nf4(x: f32) -> u8 {
let mut best_idx = 0u8;
let mut best_dist = f32::MAX;
for (i, &val) in NF4_CODEBOOK.iter().enumerate() {
let dist = (x - val).abs();
if dist < best_dist {
best_dist = dist;
best_idx = i as u8;
}
}
best_idx
}
/// Returns the `NF4_CODEBOOK` value stored at the 4-bit index `idx`.
///
/// Only the low nibble of `idx` is used; the upper four bits are masked off before indexing.
pub fn dequantize_nf4(idx: u8) -> f32 {
    let nibble = usize::from(idx & 0x0F);
    NF4_CODEBOOK[nibble]
}
/// Converts an `f32` to the raw bit pattern of an 8-bit E4M3 float.
///
/// Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits. This is the "FN" flavour of
/// E4M3: the all-ones exponent still encodes finite values and only `S.1111.111`
/// (`0x7F` / `0xFF`) is NaN, so the largest finite magnitude is `0x7E` = 448.
///
/// - Rounds to nearest, ties to even.
/// - Finite values too large for the format, and ±infinity, saturate to ±448 (`0x7E` / `0xFE`).
/// - Magnitudes below the smallest normal (2^-6) are encoded as subnormals (`m * 2^-9`);
///   anything that rounds below the smallest subnormal becomes a zero carrying the input sign.
/// - NaN maps to the E4M3 NaN pattern `0x7F`, carrying the input sign bit.
pub fn f32_to_fp8_e4m3(x: f32) -> u8 {
    // Largest finite E4M3 magnitude: exponent field 15, mantissa 6 => 1.75 * 2^8 = 448.
    const MAX_FINITE: u32 = 0x7E;

    // Adds one ULP to `q` when the discarded bits `rem` exceed half an ULP (`half`), or equal it
    // and `q` is odd (round half to even). Because the exponent field sits directly above the
    // mantissa, a carry out of the mantissa correctly bumps the exponent.
    let round_half_even =
        |q: u32, rem: u32, half: u32| q + u32::from(rem > half || (rem == half && (q & 1) == 1));

    let bits = x.to_bits();
    let sign = ((bits >> 31) << 7) as u8;
    if x.is_nan() {
        return sign | 0x7F;
    }

    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x7F_FFFF;
    if exp == 0 {
        // Zero or an f32 subnormal (< 2^-126): far below half the smallest E4M3 subnormal.
        return sign;
    }

    // Unbiased exponent; infinity (exp == 0xFF) gives 128 and takes the saturating branch.
    let e = exp - 127;
    let magnitude = if e > 8 {
        // |x| >= 512: beyond the largest finite value.
        MAX_FINITE
    } else if e >= -6 {
        // Normal range: keep the top 3 of 23 mantissa bits and round on the 20 dropped bits.
        let truncated = (((e + 7) as u32) << 3) | (mant >> 20);
        // Rounding 0x7E up would land on the NaN pattern, so saturate instead.
        round_half_even(truncated, mant & 0xF_FFFF, 1 << 19).min(MAX_FINITE)
    } else {
        // Subnormal range: value = sig * 2^(e - 23) = m * 2^-9, hence m = sig >> (14 - e).
        let shift = (14 - e) as u32;
        if shift > 24 {
            // value < 2^-10 (half the smallest subnormal): rounds to zero.
            0
        } else {
            let sig = mant | 0x80_0000;
            round_half_even(sig >> shift, sig & ((1 << shift) - 1), 1 << (shift - 1))
        }
    };
    sign | magnitude as u8
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every codebook entry must quantize back to its own index.
    #[test]
    fn test_nf4_round_trip() {
        for i in 0..16u8 {
            let val = dequantize_nf4(i);
            let back = quantize_nf4(val);
            assert_eq!(back, i, "NF4 round-trip failed for index {i}: val={val}");
        }
    }

    /// Spot-checks nearest-entry lookup at 0.0 and at the +/-1.0 extremes.
    #[test]
    fn test_nf4_nearest() {
        assert_eq!(quantize_nf4(0.0), 0);
        assert_eq!(quantize_nf4(1.0), 15);
        assert_eq!(quantize_nf4(-1.0), 1);
    }

    /// Pins exact E4M3 encodings (sign << 7 | (exponent + 7) << 3 | mantissa) for exactly
    /// representable values and signed zero, so exponent-bias or sign mistakes are caught.
    #[test]
    fn test_fp8_e4m3_basic() {
        assert_eq!(f32_to_fp8_e4m3(0.0), 0x00);
        assert_eq!(f32_to_fp8_e4m3(-0.0), 0x80);
        assert_eq!(f32_to_fp8_e4m3(1.0), 0x38);
        assert_eq!(f32_to_fp8_e4m3(-1.0), 0xB8);
        assert_eq!(f32_to_fp8_e4m3(0.5), 0x30);
        assert_eq!(f32_to_fp8_e4m3(1.5), 0x3C);
        assert_eq!(f32_to_fp8_e4m3(2.0), 0x40);
        assert_eq!(f32_to_fp8_e4m3(-3.0), 0xC4);
    }
}