#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
use mlx_native::turboquant::{fwht_inplace, CODEBOOK_4BIT};
struct Xoshiro256 {
s: [u64; 4],
}
impl Xoshiro256 {
fn new(seed: u64) -> Self {
let mut z = seed;
let mut s = [0u64; 4];
for si in s.iter_mut() {
z = z.wrapping_add(0x9E3779B97F4A7C15);
let mut x = z;
x = (x ^ (x >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
x = (x ^ (x >> 27)).wrapping_mul(0x94D049BB133111EB);
*si = x ^ (x >> 31);
}
Xoshiro256 { s }
}
fn next_u64(&mut self) -> u64 {
let result = (self.s[1].wrapping_mul(5)).rotate_left(7).wrapping_mul(9);
let t = self.s[1] << 17;
self.s[2] ^= self.s[0];
self.s[3] ^= self.s[1];
self.s[1] ^= self.s[2];
self.s[0] ^= self.s[3];
self.s[2] ^= t;
self.s[3] = self.s[3].rotate_left(45);
result
}
fn next_f64(&mut self) -> f64 {
(self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
}
}
fn randn_pair(rng: &mut Xoshiro256) -> (f64, f64) {
loop {
let u1 = rng.next_f64();
let u2 = rng.next_f64();
if u1 > 1e-30 {
let r = (-2.0 * u1.ln()).sqrt();
let theta = 2.0 * std::f64::consts::PI * u2;
return (r * theta.cos(), r * theta.sin());
}
}
}
fn random_f32_vec(rng: &mut Xoshiro256, n: usize) -> Vec<f32> {
let mut v = Vec::with_capacity(n);
while v.len() < n {
let (a, b) = randn_pair(rng);
v.push(a as f32);
if v.len() < n {
v.push(b as f32);
}
}
v
}
fn std_normal_pdf_f64(x: f64) -> f64 {
const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
INV_SQRT_2PI * (-0.5 * x * x).exp()
}
fn std_normal_cdf_f64(x: f64) -> f64 {
if x < -8.0 {
return 0.0;
}
if x > 8.0 {
return 1.0;
}
let sign = if x >= 0.0 { 1.0 } else { -1.0 };
let x_abs = x.abs();
const P: f64 = 0.231_641_9;
const B1: f64 = 0.319_381_530;
const B2: f64 = -0.356_563_782;
const B3: f64 = 1.781_477_937;
const B4: f64 = -1.821_255_978;
const B5: f64 = 1.330_274_429;
let t = 1.0 / (1.0 + P * x_abs);
let t2 = t * t;
let t3 = t2 * t;
let t4 = t3 * t;
let t5 = t4 * t;
let poly = B1 * t + B2 * t2 + B3 * t3 + B4 * t4 + B5 * t5;
let phi = std_normal_pdf_f64(x_abs);
let result = 1.0 - phi * poly;
if sign < 0.0 { 1.0 - result } else { result }
}
fn probit(p: f64) -> f64 {
let p = p.clamp(1e-10, 1.0 - 1e-10);
let q = p - 0.5;
if q.abs() <= 0.425 {
let r = 0.180_625 - q * q;
let num = (((((((2.509_080_928_730_123_e3_f64
* r + 3.343_057_558_358_813_e4)
* r + 6.726_577_092_700_870_e4)
* r + 4.592_195_393_154_987_e4)
* r + 1.373_169_376_550_946_e4)
* r + 1.971_590_950_306_551_e3)
* r + 1.330_285_827_340_910_e2)
* r + 3.387_132_872_796_366_e0)
* q;
let den = ((((((( 5.226_495_278_852_854_e3_f64
* r + 2.872_907_153_820_395_e4)
* r + 3.930_789_572_802_049_e4)
* r + 2.209_460_984_245_205_e4)
* r + 6.792_516_341_301_421_e3)
* r + 1.239_599_965_839_292_e3)
* r + 9.600_214_660_615_377_e1)
* r + 1.0);
return num / den;
}
let r = if q < 0.0 { p } else { 1.0 - p };
let r = (-r.ln()).sqrt();
let (num, den) = if r <= 5.0 {
let r = r - 1.6;
let num = (((((((7.745_450_142_783_414_e-4_f64
* r + 2.272_384_498_926_918_e-2)
* r + 1.676_384_830_183_804_e-1)
* r + 4.730_910_623_086_661_e-1)
* r + 5.974_930_791_05_e-1)
* r + 3.249_500_703_074_491_e-1)
* r + 5.202_650_510_717_898_e-2)
* r + 1.484_304_326_677_517_e-3);
let den = (((((((1.616_905_798_252_959_e-4_f64
* r + 5.943_246_777_970_694_e-3)
* r + 1.224_907_765_723_928_e-1)
* r + 6.997_278_757_900_750_e-1)
* r + 1.764_463_030_588_847_e0)
* r + 1.637_576_533_857_614_e0)
* r + 6.315_691_484_369_885_e-1)
* r + 1.0);
(num, den)
} else {
let r = r - 5.0;
let num = (((((((2.010_333_749_292_548_e-7_f64
* r + 2.711_555_568_743_488_e-5)
* r + 1.242_660_947_388_078_e-3)
* r + 1.826_131_040_741_387_e-2)
* r + 1.517_200_950_956_772_e-1)
* r + 5.353_579_800_793_019_e-1)
* r + 6.657_904_987_121_134_e-1)
* r + 1.274_615_176_740_628_e-1);
let den = (((((((2.044_263_103_389_939_e-7_f64
* r + 2.994_060_444_495_762_e-5)
* r + 1.426_277_767_232_788_e-3)
* r + 2.189_682_887_974_272_e-2)
* r + 1.846_318_317_510_054_e-1)
* r + 7.445_883_287_266_231_e-1)
* r + 1.0)
* r + 1.0);
(num, den)
};
let x = num / den;
if q < 0.0 { -x } else { x }
}
fn truncated_gaussian_mean(lo: f64, hi: f64) -> f64 {
let phi_lo = if lo == f64::NEG_INFINITY { 0.0 } else { std_normal_pdf_f64(lo) };
let phi_hi = if hi == f64::INFINITY { 0.0 } else { std_normal_pdf_f64(hi) };
let cdf_lo = if lo == f64::NEG_INFINITY { 0.0 } else { std_normal_cdf_f64(lo) };
let cdf_hi = if hi == f64::INFINITY { 1.0 } else { std_normal_cdf_f64(hi) };
let denom = cdf_hi - cdf_lo;
if denom < 1e-30 {
return (lo + hi) / 2.0; }
(phi_lo - phi_hi) / denom
}
struct LloydMaxResult {
centroids: Vec<f64>,
iteration_count: usize,
final_max_change: f64,
}
fn lloyd_max_codebook(n_levels: usize) -> LloydMaxResult {
assert!(n_levels >= 2 && n_levels.is_power_of_two(),
"n_levels must be >= 2 and a power of two, got {}", n_levels);
let mut centroids: Vec<f64> = (0..n_levels)
.map(|i| probit((i as f64 + 0.5) / n_levels as f64))
.collect();
const MAX_ITERS: usize = 10000;
const TOL: f64 = 1e-8;
let mut final_max_change = f64::MAX;
let mut iteration_count = 0usize;
for iter in 0..MAX_ITERS {
let mut boundaries: Vec<f64> = Vec::with_capacity(n_levels - 1);
for i in 0..n_levels - 1 {
boundaries.push((centroids[i] + centroids[i + 1]) / 2.0);
}
let mut new_centroids: Vec<f64> = Vec::with_capacity(n_levels);
for i in 0..n_levels {
let lo = if i == 0 { f64::NEG_INFINITY } else { boundaries[i - 1] };
let hi = if i == n_levels - 1 { f64::INFINITY } else { boundaries[i] };
new_centroids.push(truncated_gaussian_mean(lo, hi));
}
let max_change = centroids
.iter()
.zip(new_centroids.iter())
.map(|(old, new)| (old - new).abs())
.fold(0.0_f64, f64::max);
centroids = new_centroids;
final_max_change = max_change;
iteration_count = iter + 1;
if max_change < TOL {
break;
}
}
if final_max_change >= TOL {
panic!(
"Lloyd-Max did NOT converge for n_levels={}: \
exited after {} iterations with max_change={:.4e} (tolerance={:.1e}). \
Raise MAX_ITERS or debug the truncated-Gaussian numerics.",
n_levels, iteration_count, final_max_change, TOL
);
}
eprintln!(
"lloyd_max_codebook(n_levels={}): converged in {} iterations, \
final_max_change={:.4e}",
n_levels, iteration_count, final_max_change
);
LloydMaxResult { centroids, iteration_count, final_max_change }
}
fn codebook_f32(cb_f64: &[f64]) -> Vec<f32> {
cb_f64.iter().map(|&v| v as f32).collect()
}
fn boundaries(cb: &[f32]) -> Vec<f32> {
(0..cb.len() - 1)
.map(|i| (cb[i] + cb[i + 1]) / 2.0)
.collect()
}
fn nearest_centroid(value: f32, bounds: &[f32]) -> u8 {
let mut idx: u8 = 0;
for &b in bounds {
if value > b {
idx += 1;
}
}
idx
}
fn encode_with_fwht(x: &[f32], head_dim: usize, bounds: &[f32]) -> (Vec<u8>, f32) {
let mut rotated = x.to_vec();
fwht_inplace(&mut rotated).unwrap();
let norm: f32 = rotated.iter().map(|v| v * v).sum::<f32>().sqrt();
if norm < 1e-30 {
return (vec![0u8; head_dim], 0.0);
}
let inv_norm = 1.0 / norm;
let scale = (head_dim as f32).sqrt();
let indices: Vec<u8> = (0..head_dim)
.map(|c| nearest_centroid(rotated[c] * inv_norm * scale, bounds))
.collect();
(indices, norm)
}
fn decode_with_fwht(indices: &[u8], norm: f32, head_dim: usize, cb: &[f32]) -> Vec<f32> {
let inv_scale = 1.0 / (head_dim as f32).sqrt();
let mut rotated: Vec<f32> = indices
.iter()
.map(|&idx| cb[idx as usize] * inv_scale * norm)
.collect();
fwht_inplace(&mut rotated).unwrap();
rotated
}
fn encode_quant_only(x: &[f32], head_dim: usize, bounds: &[f32]) -> (Vec<u8>, f32) {
let norm: f32 = x.iter().map(|v| v * v).sum::<f32>().sqrt();
if norm < 1e-30 {
return (vec![0u8; head_dim], 0.0);
}
let inv_norm = 1.0 / norm;
let scale = (head_dim as f32).sqrt();
let indices: Vec<u8> = (0..head_dim)
.map(|c| nearest_centroid(x[c] * inv_norm * scale, bounds))
.collect();
(indices, norm)
}
fn decode_quant_only(indices: &[u8], norm: f32, head_dim: usize, cb: &[f32]) -> Vec<f32> {
let inv_scale = 1.0 / (head_dim as f32).sqrt();
indices
.iter()
.map(|&idx| cb[idx as usize] * inv_scale * norm)
.collect()
}
fn nrmse_f32(original: &[f32], reconstructed: &[f32]) -> f32 {
let num: f32 = original
.iter()
.zip(reconstructed.iter())
.map(|(a, b)| (a - b) * (a - b))
.sum();
let den: f32 = original.iter().map(|a| a * a).sum();
if den < 1e-30 {
return 0.0;
}
(num / den).sqrt()
}
fn nrmse_f64(original: &[f32], reconstructed: &[f32]) -> f64 {
nrmse_f32(original, reconstructed) as f64
}
fn nrmse_vec_f64(original: &[f32], reconstructed: &[f32]) -> f64 {
let num: f64 = original
.iter()
.zip(reconstructed.iter())
.map(|(a, b)| {
let d = (*a as f64) - (*b as f64);
d * d
})
.sum();
let den: f64 = original.iter().map(|a| (*a as f64) * (*a as f64)).sum();
if den < 1e-30 {
return 0.0;
}
(num / den).sqrt()
}
fn rms_norm_vec(v: &[f32]) -> Vec<f32> {
const RMS_NORM_EPS: f32 = 1e-6;
let n = v.len() as f32;
let mean_sq: f32 = v.iter().map(|&x| x * x).sum::<f32>() / n;
let rms = (mean_sq + RMS_NORM_EPS).sqrt();
v.iter().map(|&x| x / rms).collect()
}
fn quant_dequant(v: &[f32], bounds: &[f32], cb: &[f32]) -> Vec<f32> {
let (indices, norm) = encode_quant_only(v, v.len(), bounds);
decode_quant_only(&indices, norm, v.len(), cb)
}
fn synthetic_sdpa(
q: &[f32], k: &[f32], v: &[f32], num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
kv_seq_len: usize,
) -> Vec<f32> {
let heads_per_kv = num_heads / num_kv_heads;
let mut output = vec![0.0f32; num_heads * head_dim];
for h in 0..num_heads {
let kv_h = h / heads_per_kv;
let q_off = h * head_dim;
let mut scores = Vec::<f32>::with_capacity(kv_seq_len);
for pos in 0..kv_seq_len {
let k_off = (kv_h * kv_seq_len + pos) * head_dim;
let mut dot = 0.0f32;
for c in 0..head_dim {
dot += q[q_off + c] * k[k_off + c];
}
scores.push(dot);
}
let mut running_max = f32::NEG_INFINITY;
let mut running_sum = 0.0f32;
let o_off = h * head_dim;
let mut out_h = vec![0.0f32; head_dim];
for pos in 0..kv_seq_len {
let s = scores[pos];
let old_max = running_max;
running_max = running_max.max(s);
let correction = (old_max - running_max).exp();
let weight = (s - running_max).exp();
running_sum = running_sum * correction + weight;
let v_off = (kv_h * kv_seq_len + pos) * head_dim;
for c in 0..head_dim {
out_h[c] = out_h[c] * correction + weight * v[v_off + c];
}
}
let inv_sum = if running_sum > 0.0 { 1.0 / running_sum } else { 0.0 };
for c in 0..head_dim {
output[o_off + c] = out_h[c] * inv_sum;
}
}
output
}
fn sdpa_nrmse(reference: &[f32], quantized: &[f32]) -> f64 {
let num: f64 = reference
.iter()
.zip(quantized.iter())
.map(|(a, b)| {
let d = (*a as f64) - (*b as f64);
d * d
})
.sum();
let den: f64 = reference.iter().map(|a| (*a as f64) * (*a as f64)).sum();
if den < 1e-30 {
return 0.0;
}
(num / den).sqrt()
}
fn mean(v: &[f64]) -> f64 {
v.iter().sum::<f64>() / v.len() as f64
}
fn std_dev(v: &[f64]) -> f64 {
let m = mean(v);
let var: f64 = v.iter().map(|&x| (x - m) * (x - m)).sum::<f64>() / v.len() as f64;
var.sqrt()
}
#[derive(Debug)]
struct RtCell {
bit_width: usize,
case: char,
head_dim: usize,
nrmse_mean: f64,
nrmse_std: f64,
}
#[derive(Debug)]
struct SdpaCell {
bit_width: usize,
kv_seq_len: usize,
n_trials: usize,
nrmse_mean: f64,
nrmse_std: f64,
}
#[test]
fn bitwidth_ab() {
const SEED: u64 = 0xC25EED;
const N_VECTORS: usize = 1000;
const N_TRIALS_SDPA: usize = 100;
const BIT_WIDTHS: [usize; 3] = [4, 5, 6];
const HEAD_DIMS: [usize; 3] = [128, 256, 512];
const KV_SEQ_LENS: [usize; 3] = [64, 512, 1024];
const NUM_HEADS: usize = 16;
const NUM_KV_HEADS: usize = 8;
const SDPA_HEAD_DIM: usize = 256;
const DECISION_THRESHOLD: f64 = 0.1;
const DECISION_THRESHOLD_ALT: f64 = 0.15;
println!();
println!("=== T1a: Lloyd-Max Codebook Generation ===");
println!();
let lm_4bit = lloyd_max_codebook(16);
let codebook_4bit_f64 = lm_4bit.centroids;
let iteration_count_4bit = lm_4bit.iteration_count;
let final_max_change_4bit = lm_4bit.final_max_change;
let codebook_4bit_f32 = codebook_f32(&codebook_4bit_f64);
println!(
" 4-bit Lloyd-Max: {} iterations, final_max_change={:.4e}",
iteration_count_4bit, final_max_change_4bit
);
{
let diffs: Vec<f32> = codebook_4bit_f32
.iter()
.zip(CODEBOOK_4BIT.iter())
.map(|(generated, prod)| (generated - prod).abs())
.collect();
let max_diff = diffs.iter().copied().fold(0.0f32, f32::max);
if max_diff >= 1e-4 {
panic!(
"SANITY GATE #1 FAILED: 4-bit generated codebook does not match production CODEBOOK_4BIT.\n\
max |diff| = {:.6e} (must be < 1e-4)\n\
diffs = {:?}\n\
generated = {:?}\n\
production = {:?}",
max_diff, diffs, codebook_4bit_f32, CODEBOOK_4BIT
);
}
println!("SANITY GATE #1 PASSED: 4-bit codebook max diff vs production = {:.2e}", max_diff);
let n = codebook_4bit_f64.len();
for i in 0..n {
let sym_err = (codebook_4bit_f64[i] + codebook_4bit_f64[n - 1 - i]).abs();
assert!(
sym_err < 1e-6,
"4-bit codebook not symmetric at i={}: c[{}]={:.6}, c[{}]={:.6}, sum={:.2e}",
i, i, codebook_4bit_f64[i], n - 1 - i, codebook_4bit_f64[n - 1 - i], sym_err
);
}
for i in 1..n {
assert!(
codebook_4bit_f64[i] > codebook_4bit_f64[i - 1],
"4-bit codebook not monotonically increasing at i={}", i
);
}
println!(" 4-bit codebook: symmetric and monotonically increasing [OK]");
}
let lm_5bit = lloyd_max_codebook(32);
let codebook_5bit_f64 = lm_5bit.centroids;
let iteration_count_5bit = lm_5bit.iteration_count;
let final_max_change_5bit = lm_5bit.final_max_change;
let lm_6bit = lloyd_max_codebook(64);
let codebook_6bit_f64 = lm_6bit.centroids;
let iteration_count_6bit = lm_6bit.iteration_count;
let final_max_change_6bit = lm_6bit.final_max_change;
let codebook_5bit_f32 = codebook_f32(&codebook_5bit_f64);
let codebook_6bit_f32 = codebook_f32(&codebook_6bit_f64);
println!(
" 5-bit Lloyd-Max: {} iterations, final_max_change={:.4e}",
iteration_count_5bit, final_max_change_5bit
);
println!(
" 6-bit Lloyd-Max: {} iterations, final_max_change={:.4e}",
iteration_count_6bit, final_max_change_6bit
);
println!();
for (label, cb) in [("5-bit", &codebook_5bit_f64), ("6-bit", &codebook_6bit_f64)] {
let n = cb.len();
for i in 0..n {
let sym_err = (cb[i] + cb[n - 1 - i]).abs();
assert!(
sym_err < 1e-6,
"{} codebook not symmetric at i={}: c[{}]={:.6}, c[{}]={:.6}, sum={:.2e}",
label, i, i, cb[i], n - 1 - i, cb[n - 1 - i], sym_err
);
}
for i in 1..n {
assert!(
cb[i] > cb[i - 1],
"{} codebook not monotonically increasing at i={}", label, i
);
}
println!(" {} codebook: {} centroids, symmetric and monotonically increasing [OK]", label, n);
}
println!();
println!(" 4-bit centroids (generated): {:?}", codebook_4bit_f32);
println!(" 4-bit centroids (production): {:?}", CODEBOOK_4BIT);
println!();
println!(" 5-bit centroids (first 8 / last 8): {:?} ... {:?}",
&codebook_5bit_f32[..8], &codebook_5bit_f32[24..]);
println!(" 6-bit centroids (first 8 / last 8): {:?} ... {:?}",
&codebook_6bit_f32[..8], &codebook_6bit_f32[56..]);
println!();
let codebook_4bit_max_diff: f64 = codebook_4bit_f32
.iter()
.zip(CODEBOOK_4BIT.iter())
.map(|(generated, prod)| (generated - prod).abs() as f64)
.fold(0.0f64, f64::max);
println!("=== T1b: Extended Round-Trip Triad (27 cells) ===");
println!();
let mut rt_cells: Vec<RtCell> = Vec::with_capacity(27);
for &bw in &BIT_WIDTHS {
let cb_f32 = match bw {
4 => &codebook_4bit_f32,
5 => &codebook_5bit_f32,
6 => &codebook_6bit_f32,
_ => unreachable!(),
};
let bounds = boundaries(cb_f32);
for &head_dim in &HEAD_DIMS {
let mut rng = Xoshiro256::new(SEED);
let mut nrmse_a = Vec::with_capacity(N_VECTORS);
let mut nrmse_b = Vec::with_capacity(N_VECTORS);
let mut nrmse_c = Vec::with_capacity(N_VECTORS);
for _ in 0..N_VECTORS {
let x = random_f32_vec(&mut rng, head_dim);
let (indices_a, norm_a) = encode_with_fwht(&x, head_dim, &bounds);
let x_recon_a = decode_with_fwht(&indices_a, norm_a, head_dim, cb_f32);
nrmse_a.push(nrmse_f64(&x, &x_recon_a));
let (indices_b, norm_b) = encode_quant_only(&x, head_dim, &bounds);
let x_recon_b = decode_quant_only(&indices_b, norm_b, head_dim, cb_f32);
nrmse_b.push(nrmse_f64(&x, &x_recon_b));
let mut x_fwht = x.clone();
fwht_inplace(&mut x_fwht).unwrap();
fwht_inplace(&mut x_fwht).unwrap();
nrmse_c.push(nrmse_f64(&x, &x_fwht));
}
rt_cells.push(RtCell { bit_width: bw, case: 'A', head_dim, nrmse_mean: mean(&nrmse_a), nrmse_std: std_dev(&nrmse_a) });
rt_cells.push(RtCell { bit_width: bw, case: 'B', head_dim, nrmse_mean: mean(&nrmse_b), nrmse_std: std_dev(&nrmse_b) });
rt_cells.push(RtCell { bit_width: bw, case: 'C', head_dim, nrmse_mean: mean(&nrmse_c), nrmse_std: std_dev(&nrmse_c) });
}
}
println!("| bit_width | case | head_dim | n_vectors | nrmse_mean | nrmse_std |");
println!("|-----------|------|----------|-----------|------------|-----------|");
for c in &rt_cells {
println!("| {} | {} | {} | {} | {:.6} | {:.6} |",
c.bit_width, c.case, c.head_dim, N_VECTORS, c.nrmse_mean, c.nrmse_std);
}
println!();
println!("=== T1c: Synthetic SDPA Amplification (9 cells) ===");
println!("Shape: num_heads={}, num_kv_heads={}, head_dim={}", NUM_HEADS, NUM_KV_HEADS, SDPA_HEAD_DIM);
println!("Seeds: Q=0xC25EED, K=0xC25EED^0x11, V=0xC25EED^0x22");
println!();
let cb_5bit_sdpa = &codebook_5bit_f32;
let cb_6bit_sdpa = &codebook_6bit_f32;
let bounds_4bit_sdpa = boundaries(&codebook_4bit_f32);
let bounds_5bit_sdpa = boundaries(cb_5bit_sdpa);
let bounds_6bit_sdpa = boundaries(cb_6bit_sdpa);
let mut sdpa_cells: Vec<SdpaCell> = Vec::with_capacity(9);
for &bw in &BIT_WIDTHS {
let (cb_sdpa, bounds_sdpa): (&[f32], &[f32]) = match bw {
4 => (&codebook_4bit_f32, &bounds_4bit_sdpa),
5 => (cb_5bit_sdpa, &bounds_5bit_sdpa),
6 => (cb_6bit_sdpa, &bounds_6bit_sdpa),
_ => unreachable!(),
};
for &kv_seq_len in &KV_SEQ_LENS {
let q_seed = SEED;
let k_seed = SEED ^ 0x11;
let v_seed = SEED ^ 0x22;
let mut rng_q = Xoshiro256::new(q_seed);
let mut rng_k = Xoshiro256::new(k_seed);
let mut rng_v = Xoshiro256::new(v_seed);
let mut trial_nrmses = Vec::with_capacity(N_TRIALS_SDPA);
for _trial in 0..N_TRIALS_SDPA {
let q_raw: Vec<f32> = random_f32_vec(&mut rng_q, NUM_HEADS * SDPA_HEAD_DIM);
let k_raw: Vec<f32> = random_f32_vec(&mut rng_k, NUM_KV_HEADS * kv_seq_len * SDPA_HEAD_DIM);
let v_raw: Vec<f32> = random_f32_vec(&mut rng_v, NUM_KV_HEADS * kv_seq_len * SDPA_HEAD_DIM);
let mut q_normed = Vec::with_capacity(NUM_HEADS * SDPA_HEAD_DIM);
for h in 0..NUM_HEADS {
let off = h * SDPA_HEAD_DIM;
let head_slice = &q_raw[off..off + SDPA_HEAD_DIM];
q_normed.extend_from_slice(&rms_norm_vec(head_slice));
}
let mut k_normed = Vec::with_capacity(NUM_KV_HEADS * kv_seq_len * SDPA_HEAD_DIM);
for kvh in 0..NUM_KV_HEADS {
for pos in 0..kv_seq_len {
let off = (kvh * kv_seq_len + pos) * SDPA_HEAD_DIM;
let tok_slice = &k_raw[off..off + SDPA_HEAD_DIM];
k_normed.extend_from_slice(&rms_norm_vec(tok_slice));
}
}
let mut v_normed = Vec::with_capacity(NUM_KV_HEADS * kv_seq_len * SDPA_HEAD_DIM);
for kvh in 0..NUM_KV_HEADS {
for pos in 0..kv_seq_len {
let off = (kvh * kv_seq_len + pos) * SDPA_HEAD_DIM;
let tok_slice = &v_raw[off..off + SDPA_HEAD_DIM];
v_normed.extend_from_slice(&rms_norm_vec(tok_slice));
}
}
let out_ref = synthetic_sdpa(
&q_normed, &k_normed, &v_normed,
NUM_HEADS, NUM_KV_HEADS, SDPA_HEAD_DIM, kv_seq_len,
);
for (i, &x) in out_ref.iter().enumerate() {
assert!(
x.is_finite(),
"Reference SDPA output NaN/Inf at index {} (bw={}, kvl={})",
i, bw, kv_seq_len
);
}
let mut k_quant = Vec::with_capacity(NUM_KV_HEADS * kv_seq_len * SDPA_HEAD_DIM);
for kvh in 0..NUM_KV_HEADS {
for pos in 0..kv_seq_len {
let off = (kvh * kv_seq_len + pos) * SDPA_HEAD_DIM;
let tok_slice = &k_normed[off..off + SDPA_HEAD_DIM];
let dq = quant_dequant(tok_slice, bounds_sdpa, cb_sdpa);
k_quant.extend_from_slice(&dq);
}
}
let mut v_quant = Vec::with_capacity(NUM_KV_HEADS * kv_seq_len * SDPA_HEAD_DIM);
for kvh in 0..NUM_KV_HEADS {
for pos in 0..kv_seq_len {
let off = (kvh * kv_seq_len + pos) * SDPA_HEAD_DIM;
let tok_slice = &v_normed[off..off + SDPA_HEAD_DIM];
let dq = quant_dequant(tok_slice, bounds_sdpa, cb_sdpa);
v_quant.extend_from_slice(&dq);
}
}
let out_q = synthetic_sdpa(
&q_normed, &k_quant, &v_quant,
NUM_HEADS, NUM_KV_HEADS, SDPA_HEAD_DIM, kv_seq_len,
);
trial_nrmses.push(sdpa_nrmse(&out_ref, &out_q));
}
let cell = SdpaCell {
bit_width: bw,
kv_seq_len,
n_trials: N_TRIALS_SDPA,
nrmse_mean: mean(&trial_nrmses),
nrmse_std: std_dev(&trial_nrmses),
};
sdpa_cells.push(cell);
}
}
println!("| bit_width | kv_seq_len | n_trials | nrmse_mean | nrmse_std |");
println!("|-----------|------------|----------|------------|-----------|");
for c in &sdpa_cells {
println!("| {} | {} | {} | {:.6} | {:.6} |",
c.bit_width, c.kv_seq_len, c.n_trials, c.nrmse_mean, c.nrmse_std);
}
println!();
let gate2_cell = sdpa_cells
.iter()
.find(|c| c.bit_width == 4 && c.kv_seq_len == 512)
.expect("4-bit kvl=512 cell not found");
let gate2_val = gate2_cell.nrmse_mean;
if gate2_val < 0.10 || gate2_val > 0.60 {
panic!(
"SANITY GATE #2 FAILED: synthetic SDPA 4-bit at kvl=512 nrmse = {:.5} is outside [0.10, 0.60].\n\
This means the synthetic SDPA does not match hf2q's amplification behavior.\n\
Check: scale=1.0, RMS-norm Q+K+V (iter 2: V now normed), no softcap, GQA 2:1, f32 softmax.\n\
Note: iter 1 range was [0.30, 0.60] with unnormed V; iter 2 V-norm may lower the floor.",
gate2_val
);
}
println!("SANITY GATE #2 PASSED: 4-bit at kvl=512 nrmse = {:.5} in [0.10, 0.60]", gate2_val);
println!();
println!("=== T1d: Verdict ===");
println!();
let worst_nrmse = |bw: usize| -> f64 {
sdpa_cells
.iter()
.filter(|c| c.bit_width == bw)
.map(|c| c.nrmse_mean)
.fold(f64::NEG_INFINITY, f64::max)
};
let worst_4bit = worst_nrmse(4);
let worst_5bit = worst_nrmse(5);
let worst_6bit = worst_nrmse(6);
let verdict = if worst_4bit <= DECISION_THRESHOLD {
"bit_width_4_sufficient"
} else if worst_5bit <= DECISION_THRESHOLD {
"bit_width_5_sufficient"
} else if worst_6bit <= DECISION_THRESHOLD {
"bit_width_6_sufficient"
} else {
"all_bit_widths_fail_pivot_to_mixed_precision"
};
let verdict_at_015 = if worst_4bit <= DECISION_THRESHOLD_ALT {
"bit_width_4_sufficient"
} else if worst_5bit <= DECISION_THRESHOLD_ALT {
"bit_width_5_sufficient"
} else if worst_6bit <= DECISION_THRESHOLD_ALT {
"bit_width_6_sufficient"
} else {
"all_bit_widths_fail_pivot_to_mixed_precision"
};
let verdict_rationale = format!(
"Decision threshold = {:.2} nrmse. Worst-case (max over kv_seq_lens): \
4-bit={:.4}, 5-bit={:.4}, 6-bit={:.4}. \
Verdict at 0.10 '{}' selected: {}. \
Verdict at 0.15 '{}'.",
DECISION_THRESHOLD,
worst_4bit, worst_5bit, worst_6bit,
verdict,
match verdict {
"bit_width_4_sufficient" =>
format!("4-bit worst nrmse {:.4} <= threshold {:.2}", worst_4bit, DECISION_THRESHOLD),
"bit_width_5_sufficient" =>
format!("4-bit ({:.4}) exceeds threshold; 5-bit worst nrmse {:.4} <= threshold {:.2}", worst_4bit, worst_5bit, DECISION_THRESHOLD),
"bit_width_6_sufficient" =>
format!("4-bit ({:.4}) and 5-bit ({:.4}) exceed threshold; 6-bit worst nrmse {:.4} <= threshold {:.2}", worst_4bit, worst_5bit, worst_6bit, DECISION_THRESHOLD),
_ =>
format!("All bit-widths exceed threshold: 4-bit={:.4}, 5-bit={:.4}, 6-bit={:.4}; pivot to mixed-precision", worst_4bit, worst_5bit, worst_6bit),
},
verdict_at_015,
);
println!("Verdict (threshold=0.10): {}", verdict);
println!("Verdict (threshold=0.15): {}", verdict_at_015);
println!("Rationale: {}", verdict_rationale);
println!();
let out_dir = "/tmp/cfa-20260422-C4t1-bitwidth-ab";
std::fs::create_dir_all(out_dir).expect("create output dir");
let codebook_4bit_json: String = codebook_4bit_f32.iter()
.map(|v| format!("{:.8}", v))
.collect::<Vec<_>>()
.join(", ");
let codebook_5bit_json: String = codebook_5bit_f32.iter()
.map(|v| format!("{:.8}", v))
.collect::<Vec<_>>()
.join(", ");
let codebook_6bit_json: String = codebook_6bit_f32.iter()
.map(|v| format!("{:.8}", v))
.collect::<Vec<_>>()
.join(", ");
let rt_cells_json: String = rt_cells.iter().enumerate().map(|(i, c)| {
let comma = if i > 0 { ",\n" } else { "\n" };
format!(
"{} {{\"bit_width\": {}, \"case\": \"{}\", \"head_dim\": {}, \"nrmse_mean\": {:.8}, \"nrmse_std\": {:.8}}}",
comma, c.bit_width, c.case, c.head_dim, c.nrmse_mean, c.nrmse_std
)
}).collect();
let sdpa_cells_json: String = sdpa_cells.iter().enumerate().map(|(i, c)| {
let comma = if i > 0 { ",\n" } else { "\n" };
format!(
"{} {{\"bit_width\": {}, \"kv_seq_len\": {}, \"n_trials\": {}, \"nrmse_mean\": {:.8}, \"nrmse_std\": {:.8}}}",
comma, c.bit_width, c.kv_seq_len, c.n_trials, c.nrmse_mean, c.nrmse_std
)
}).collect();
let result_json = format!(
r#"{{
"session": "cfa-20260422-C4t1-bitwidth-ab",
"iteration": 2,
"seed": "0xC25EED",
"fixes_applied": [
"1: lloyd_max MAX_ITERS raised to 10000 + panic-on-not-converged",
"2: V RMS-normalized per head per token (matches forward_mlx.rs:1167-1211)",
"3: softmax accumulators changed from f64 to f32 (matches flash_attn_vec.metal)",
"4: DECISION_THRESHOLD documented + verdict_at_threshold_015 added",
"5: numeric regression bands added",
"6: rerun with corrected numbers"
],
"lloyd_max_convergence": {{
"4bit": {{ "iteration_count": {iter4}, "final_max_change": {mc4:.4e} }},
"5bit": {{ "iteration_count": {iter5}, "final_max_change": {mc5:.4e} }},
"6bit": {{ "iteration_count": {iter6}, "final_max_change": {mc6:.4e} }}
}},
"codebooks": {{
"4_bit": [{cb4}],
"5_bit": [{cb5}],
"6_bit": [{cb6}]
}},
"codebook_4bit_sanity_gate": {{
"max_diff_vs_production": {gate1:.8},
"passed": true
}},
"round_trip_cells": [{rt_cells}
],
"sdpa_cells": [{sdpa_cells}
],
"sdpa_sanity_gate": {{
"4bit_at_kvl_512": {gate2:.8},
"passed": true
}},
"verdict": "{verdict}",
"verdict_at_threshold_015": "{verdict015}",
"verdict_rationale": "{rationale}",
"worst_nrmse": {{
"4bit": {w4:.8},
"5bit": {w5:.8},
"6bit": {w6:.8}
}},
"shape": {{
"num_heads": {num_heads},
"num_kv_heads": {num_kv_heads},
"head_dim": {head_dim}
}},
"kv_seq_lens": [64, 512, 1024],
"bit_widths": [4, 5, 6],
"n_trials_per_sdpa_cell": {n_trials},
"decision_threshold_nrmse": {threshold},
"decision_threshold_alt_nrmse": {threshold_alt}
}}"#,
iter4 = iteration_count_4bit,
mc4 = final_max_change_4bit,
iter5 = iteration_count_5bit,
mc5 = final_max_change_5bit,
iter6 = iteration_count_6bit,
mc6 = final_max_change_6bit,
cb4 = codebook_4bit_json,
cb5 = codebook_5bit_json,
cb6 = codebook_6bit_json,
gate1 = codebook_4bit_max_diff,
rt_cells = rt_cells_json,
sdpa_cells = sdpa_cells_json,
gate2 = gate2_val,
verdict = verdict,
verdict015 = verdict_at_015,
rationale = verdict_rationale.replace('"', "'"),
w4 = worst_4bit,
w5 = worst_5bit,
w6 = worst_6bit,
num_heads = NUM_HEADS,
num_kv_heads = NUM_KV_HEADS,
head_dim = SDPA_HEAD_DIM,
n_trials = N_TRIALS_SDPA,
threshold = DECISION_THRESHOLD,
threshold_alt = DECISION_THRESHOLD_ALT,
);
let mut rt_md_rows = String::new();
for c in &rt_cells {
rt_md_rows.push_str(&format!(
"| {} | {} | {} | {} | {:.7} | {:.7} |\n",
c.bit_width, c.case, c.head_dim, N_VECTORS, c.nrmse_mean, c.nrmse_std
));
}
let mut sdpa_md_rows = String::new();
for c in &sdpa_cells {
sdpa_md_rows.push_str(&format!(
"| {} | {} | {} | {:.7} | {:.7} |\n",
c.bit_width, c.kv_seq_len, c.n_trials, c.nrmse_mean, c.nrmse_std
));
}
let result_md = format!(
"# C-4 T1 Higher-Bit Codebook A/B Results (iter 2)\n\n\
Session: cfa-20260422-C4t1-bitwidth-ab \n\
Iteration: 2 (fixes: MAX_ITERS=10000 + panic, V RMS-norm, f32 softmax, dual-threshold, numeric bands) \n\
Seed: 0xC25EED \n\
Shape: num_heads={NUM_HEADS}, num_kv_heads={NUM_KV_HEADS}, head_dim={SDPA_HEAD_DIM} \n\
Decision threshold: {DECISION_THRESHOLD} nrmse (alt: {DECISION_THRESHOLD_ALT}) \n\n\
## Lloyd-Max Convergence\n\n\
| bit_width | iteration_count | final_max_change |\n\
|-----------|-----------------|------------------|\n\
| 4 | {iter4} | {mc4:.4e} |\n\
| 5 | {iter5} | {mc5:.4e} |\n\
| 6 | {iter6} | {mc6:.4e} |\n\n\
## Codebook Sanity Gate\n\n\
- 4-bit generated vs production max |diff|: {codebook_4bit_max_diff:.2e} \n\
- PASSED: max diff < 1e-4 \n\n\
## T1b: Round-Trip Triad (27 cells)\n\n\
| bit_width | case | head_dim | n_vectors | nrmse_mean | nrmse_std |\n\
|-----------|------|----------|-----------|------------|-----------|\n\
{rt_md_rows}\n\
## T1c: Synthetic SDPA Amplification (9 cells)\n\n\
Shape: GQA {NUM_HEADS}Q:{NUM_KV_HEADS}KV heads, head_dim={SDPA_HEAD_DIM}, scale=1.0, \n\
RMS-norm Q+K+V (iter 2: V now normed per forward_mlx.rs:1167-1211), f32 softmax \n\n\
| bit_width | kv_seq_len | n_trials | nrmse_mean | nrmse_std |\n\
|-----------|------------|----------|-----------|-----------|\n\
{sdpa_md_rows}\n\
SDPA Sanity Gate (4-bit, kvl=512): nrmse={gate2_val:.5} — PASSED \n\n\
## T1d: Verdict\n\n\
**Verdict (threshold=0.10): {verdict}** \n\
**Verdict (threshold=0.15): {verdict015}** \n\
Worst-case nrmse: 4-bit={worst_4bit:.4}, 5-bit={worst_5bit:.4}, 6-bit={worst_6bit:.4} \n\n\
{verdict_rationale}\n",
NUM_HEADS = NUM_HEADS,
NUM_KV_HEADS = NUM_KV_HEADS,
SDPA_HEAD_DIM = SDPA_HEAD_DIM,
DECISION_THRESHOLD = DECISION_THRESHOLD,
DECISION_THRESHOLD_ALT = DECISION_THRESHOLD_ALT,
iter4 = iteration_count_4bit,
mc4 = final_max_change_4bit,
iter5 = iteration_count_5bit,
mc5 = final_max_change_5bit,
iter6 = iteration_count_6bit,
mc6 = final_max_change_6bit,
codebook_4bit_max_diff = codebook_4bit_max_diff,
rt_md_rows = rt_md_rows,
sdpa_md_rows = sdpa_md_rows,
gate2_val = gate2_val,
verdict = verdict,
verdict015 = verdict_at_015,
worst_4bit = worst_4bit,
worst_5bit = worst_5bit,
worst_6bit = worst_6bit,
verdict_rationale = verdict_rationale,
);
std::fs::write(format!("{}/result.json", out_dir), &result_json)
.expect("write result.json");
std::fs::write(format!("{}/result.md", out_dir), &result_md)
.expect("write result.md");
println!("Written: {}/result.json", out_dir);
println!("Written: {}/result.md", out_dir);
println!();
assert!(
codebook_4bit_max_diff < 1e-4,
"REGRESSION: 4-bit Lloyd-Max max diff vs production = {:.2e} >= 1e-4",
codebook_4bit_max_diff
);
let cell_4b_b_256 = rt_cells.iter()
.find(|c| c.bit_width == 4 && c.case == 'B' && c.head_dim == 256)
.expect("4-bit Case B head_dim=256 cell");
assert!(
(cell_4b_b_256.nrmse_mean - 0.0975).abs() < 0.005,
"REGRESSION: 4-bit Case B head_dim=256 nrmse={:.5} not within 0.005 of C-3 reference 0.0975",
cell_4b_b_256.nrmse_mean
);
for &hd in &HEAD_DIMS {
let cell = rt_cells.iter()
.find(|c| c.bit_width == 5 && c.case == 'B' && c.head_dim == hd)
.expect("5-bit Case B cell");
assert!(
cell.nrmse_mean >= 0.047 && cell.nrmse_mean <= 0.056,
"REGRESSION: 5-bit Case B head_dim={} nrmse={:.5} outside [0.047, 0.056]",
hd, cell.nrmse_mean
);
}
for &hd in &HEAD_DIMS {
let cell = rt_cells.iter()
.find(|c| c.bit_width == 6 && c.case == 'B' && c.head_dim == hd)
.expect("6-bit Case B cell");
assert!(
cell.nrmse_mean >= 0.024 && cell.nrmse_mean <= 0.032,
"REGRESSION: 6-bit Case B head_dim={} nrmse={:.5} outside [0.024, 0.032]",
hd, cell.nrmse_mean
);
}
assert!(
iteration_count_4bit >= 200 && iteration_count_4bit <= 800,
"REGRESSION: 4-bit Lloyd-Max iteration_count={} outside [200, 800] (queen baseline ~415)",
iteration_count_4bit
);
assert!(
iteration_count_5bit >= 1000 && iteration_count_5bit <= 2500,
"REGRESSION: 5-bit Lloyd-Max iteration_count={} outside [1000, 2500] (queen baseline ~1454)",
iteration_count_5bit
);
assert!(
iteration_count_6bit >= 3500 && iteration_count_6bit <= 7000,
"REGRESSION: 6-bit Lloyd-Max iteration_count={} outside [3500, 7000] (queen baseline ~5110)",
iteration_count_6bit
);
assert!(
worst_4bit > 0.28 && worst_4bit < 0.50,
"REGRESSION: worst 4-bit SDPA nrmse={:.5} outside (0.28, 0.50) [iter2 measured: 0.3880]",
worst_4bit
);
assert!(
worst_5bit > 0.15 && worst_5bit < 0.28,
"REGRESSION: worst 5-bit SDPA nrmse={:.5} outside (0.15, 0.28) [iter2 measured: 0.2110]",
worst_5bit
);
assert!(
worst_6bit > 0.07 && worst_6bit < 0.14,
"REGRESSION: worst 6-bit SDPA nrmse={:.5} outside (0.07, 0.14) [iter2 measured: 0.1026]",
worst_6bit
);
assert!(
gate2_val >= 0.10 && gate2_val <= 0.60,
"REGRESSION: 4-bit SDPA kvl=512 nrmse={:.5} outside [0.10, 0.60]",
gate2_val
);
assert!(
matches!(verdict,
"bit_width_4_sufficient" |
"bit_width_5_sufficient" |
"bit_width_6_sufficient" |
"all_bit_widths_fail_pivot_to_mixed_precision"),
"REGRESSION: verdict '{}' is not one of the four allowed enum values", verdict
);
assert_eq!(
verdict, "all_bit_widths_fail_pivot_to_mixed_precision",
"REGRESSION: verdict changed from iter 2 locked value; re-examine if codebook or SDPA changed"
);
for &hd in &HEAD_DIMS {
let cell = rt_cells.iter()
.find(|c| c.bit_width == 4 && c.case == 'C' && c.head_dim == hd)
.expect("Case C cell");
assert!(
cell.nrmse_mean < 1e-5,
"REGRESSION: FWHT non-reversible at head_dim={} (nrmse={:.3e})",
hd, cell.nrmse_mean
);
}
for &hd in &HEAD_DIMS {
let b4 = rt_cells.iter().find(|c| c.bit_width == 4 && c.case == 'B' && c.head_dim == hd).unwrap().nrmse_mean;
let b5 = rt_cells.iter().find(|c| c.bit_width == 5 && c.case == 'B' && c.head_dim == hd).unwrap().nrmse_mean;
let b6 = rt_cells.iter().find(|c| c.bit_width == 6 && c.case == 'B' && c.head_dim == hd).unwrap().nrmse_mean;
assert!(
b6 < b5 && b5 < b4,
"REGRESSION: monotonicity violated at head_dim={}: 6-bit nrmse={:.5} 5-bit={:.5} 4-bit={:.5} (expected 6<5<4)",
hd, b6, b5, b4
);
}
for &kvl in &KV_SEQ_LENS {
let s4 = sdpa_cells.iter().find(|c| c.bit_width == 4 && c.kv_seq_len == kvl).unwrap().nrmse_mean;
let s5 = sdpa_cells.iter().find(|c| c.bit_width == 5 && c.kv_seq_len == kvl).unwrap().nrmse_mean;
let s6 = sdpa_cells.iter().find(|c| c.bit_width == 6 && c.kv_seq_len == kvl).unwrap().nrmse_mean;
assert!(
s6 < s5 && s5 < s4,
"REGRESSION: SDPA monotonicity violated at kvl={}: 6-bit={:.5} 5-bit={:.5} 4-bit={:.5} (expected 6<5<4)",
kvl, s6, s5, s4
);
}
println!("All regression asserts PASSED.");
println!("Verdict (threshold=0.10): {}", verdict);
println!("Verdict (threshold=0.15): {}", verdict_at_015);
}