Skip to main content

mlx_native/
turboquant.rs

//! TurboQuant KV cache compression — CPU reference implementation.
//!
//! Implements the TurboQuant_mse algorithm:
//! 1. Walsh-Hadamard rotation for incoherence
//! 2. Per-head norm extraction
//! 3. Lloyd-Max scalar quantization against N(0,1) codebooks
//!
//! This module is CPU-only math — no Metal GPU dispatch.

// ---- Lloyd-Max Codebooks for N(0,1) ----
//
// Precomputed via iterative Lloyd-Max algorithm with convergence tolerance 1e-12.
// Each codebook is symmetric around zero.
/// 2-bit Lloyd-Max centroids for N(0,1): 4 reconstruction levels.
///
/// Sorted ascending and symmetric about zero; index `i` is the reconstruction
/// value for quantization code `i`.
pub const CODEBOOK_2BIT: [f32; 4] = [
    -1.5104176, -0.4527800, 0.4527800, 1.5104176,
];
19
/// 3-bit Lloyd-Max centroids for N(0,1): 8 reconstruction levels.
///
/// Sorted ascending and symmetric about zero; index `i` is the reconstruction
/// value for quantization code `i`.
pub const CODEBOOK_3BIT: [f32; 8] = [
    -2.1519457, -1.3439093, -0.7560053, -0.2450942,
    0.2450942, 0.7560053, 1.3439093, 2.1519457,
];
25
/// 4-bit Lloyd-Max centroids for N(0,1): 16 reconstruction levels.
///
/// Sorted ascending and symmetric about zero; index `i` is the reconstruction
/// value for quantization code `i`.
pub const CODEBOOK_4BIT: [f32; 16] = [
    -2.7325896, -2.0690172, -1.6180464, -1.2562312,
    -0.9423405, -0.6567591, -0.3880483, -0.1283950,
    0.1283950, 0.3880483, 0.6567591, 0.9423405,
    1.2562312, 1.6180464, 2.0690172, 2.7325896,
];
33
// ---------------------------------------------------------------------------
// ADR-007 Path C F-0.1: Higher-bit (HB) codebooks for the 5/6/8-bit byte-packed
// production decode path. These mirror EXACTLY the constants in
// `src/shaders/flash_attn_vec_tq_hb.metal` (lines 52-156) and
// `src/shaders/hadamard_quantize_kv_fast.metal::CODEBOOK_8BIT`.
//
// IMPORTANT: any change here is an on-disk codec change per ADR-007 §F-7.1.
// Bump the codec version in the TQ envelope before changing these bytes.
// ---------------------------------------------------------------------------
43
/// 5-bit Lloyd-Max centroids for N(0,1): 32 reconstruction levels.
/// Mirrors `flash_attn_vec_tq_hb.metal::CODEBOOK_HB_5BIT`.
///
/// Sorted ascending and symmetric about zero; index `i` is the reconstruction
/// value for the 5-bit code `i` (see `hb_centroid`).
pub const CODEBOOK_HB_5BIT: [f32; 32] = [
    -3.2606790, -2.6910589, -2.3176743, -2.0286608,
    -1.7871646, -1.5761599, -1.3862739, -1.2117410,
    -1.0487242, -0.8945114, -0.7470884, -0.6048936,
    -0.4666676, -0.3313550, -0.1980377, -0.0658849,
     0.0658849,  0.1980377,  0.3313550,  0.4666676,
     0.6048936,  0.7470884,  0.8945114,  1.0487242,
     1.2117410,  1.3862739,  1.5761599,  1.7871646,
     2.0286608,  2.3176743,  2.6910589,  3.2606790,
];
56
/// 6-bit Lloyd-Max centroids for N(0,1): 64 reconstruction levels.
/// Mirrors `flash_attn_vec_tq_hb.metal::CODEBOOK_HB_6BIT`.
///
/// Sorted ascending and symmetric about zero; index `i` is the reconstruction
/// value for the 6-bit code `i` (see `hb_centroid`).
pub const CODEBOOK_HB_6BIT: [f32; 64] = [
    -3.6996161, -3.1907215, -2.8640626, -2.6161277,
    -2.4129324, -2.2388464, -2.0853192, -1.9471373,
    -1.8208742, -1.7041502, -1.5952401, -1.4928497,
    -1.3959804, -1.3038428, -1.2157998, -1.1313277,
    -1.0499889, -0.9714118, -0.8952766, -0.8213046,
    -0.7492492, -0.6788902, -0.6100285, -0.5424819,
    -0.4760822, -0.4106724, -0.3461048, -0.2822386,
    -0.2189392, -0.1560761, -0.0935225, -0.0311537,
     0.0311537,  0.0935225,  0.1560761,  0.2189392,
     0.2822386,  0.3461048,  0.4106724,  0.4760822,
     0.5424819,  0.6100285,  0.6788902,  0.7492492,
     0.8213046,  0.8952766,  0.9714118,  1.0499889,
     1.1313277,  1.2157998,  1.3038428,  1.3959804,
     1.4928497,  1.5952401,  1.7041502,  1.8208742,
     1.9471373,  2.0853192,  2.2388464,  2.4129324,
     2.6161277,  2.8640626,  3.1907215,  3.6996161,
];
77
/// 8-bit Lloyd-Max centroids for N(0,1): 256 reconstruction levels.
/// Range: [-5.0652659, +5.0652659]. Symmetry error: 3.41e-10.
/// Mirrors `flash_attn_vec_tq_hb.metal::CODEBOOK_HB_8BIT` and
/// `hadamard_quantize_kv_fast.metal::CODEBOOK_8BIT`.
///
/// Sorted ascending and symmetric about zero; the byte code `i` indexes
/// directly into this table (see `hb_centroid`).
pub const CODEBOOK_HB_8BIT: [f32; 256] = [
    -5.0652659, -4.6836997, -4.4467193, -4.2715508,
    -4.1311907, -4.0132856, -3.9111092, -3.8205780,
    -3.7390194, -3.6645851, -3.5959415, -3.5320936,
    -3.4722785, -3.4158977, -3.3624729, -3.3116156,
    -3.2630056, -3.2163758, -3.1715011, -3.1281899,
    -3.0862780, -3.0456229, -3.0061011, -2.9676040,
    -2.9300362, -2.8933131, -2.8573596, -2.8221086,
    -2.7874999, -2.7534795, -2.7199985, -2.6870129,
    -2.6544825, -2.6223710, -2.5906452, -2.5592748,
    -2.5282321, -2.4974918, -2.4670306, -2.4368270,
    -2.4068614, -2.3771157, -2.3475732, -2.3182184,
    -2.2890372, -2.2600165, -2.2311440, -2.2024086,
    -2.1737998, -2.1453081, -2.1169245, -2.0886408,
    -2.0604493, -2.0323430, -2.0043154, -1.9763603,
    -1.9484722, -1.9206458, -1.8928763, -1.8651592,
    -1.8374904, -1.8098662, -1.7822828, -1.7547372,
    -1.7272261, -1.6997469, -1.6722970, -1.6448739,
    -1.6174755, -1.5900996, -1.5627445, -1.5354084,
    -1.5080897, -1.4807869, -1.4534986, -1.4262237,
    -1.3989610, -1.3717093, -1.3444678, -1.3172356,
    -1.2900118, -1.2627956, -1.2355865, -1.2083838,
    -1.1811868, -1.1539951, -1.1268081, -1.0996255,
    -1.0724469, -1.0452718, -1.0180999, -0.9909310,
    -0.9637647, -0.9366008, -0.9094390, -0.8822793,
    -0.8551212, -0.8279648, -0.8008098, -0.7736561,
    -0.7465035, -0.7193520, -0.6922014, -0.6650517,
    -0.6379027, -0.6107544, -0.5836067, -0.5564596,
    -0.5293129, -0.5021667, -0.4750208, -0.4478753,
    -0.4207301, -0.3935852, -0.3664405, -0.3392960,
    -0.3121517, -0.2850076, -0.2578636, -0.2307198,
    -0.2035761, -0.1764324, -0.1492888, -0.1221453,
    -0.0950019, -0.0678584, -0.0407151, -0.0135717,
     0.0135717,  0.0407151,  0.0678584,  0.0950019,
     0.1221453,  0.1492888,  0.1764324,  0.2035761,
     0.2307198,  0.2578636,  0.2850076,  0.3121517,
     0.3392960,  0.3664405,  0.3935852,  0.4207301,
     0.4478753,  0.4750208,  0.5021667,  0.5293129,
     0.5564596,  0.5836067,  0.6107544,  0.6379027,
     0.6650517,  0.6922014,  0.7193520,  0.7465035,
     0.7736561,  0.8008098,  0.8279648,  0.8551212,
     0.8822793,  0.9094390,  0.9366008,  0.9637647,
     0.9909310,  1.0180999,  1.0452718,  1.0724469,
     1.0996255,  1.1268081,  1.1539951,  1.1811868,
     1.2083838,  1.2355865,  1.2627956,  1.2900118,
     1.3172356,  1.3444678,  1.3717093,  1.3989610,
     1.4262237,  1.4534986,  1.4807869,  1.5080897,
     1.5354084,  1.5627445,  1.5900996,  1.6174755,
     1.6448739,  1.6722970,  1.6997469,  1.7272261,
     1.7547372,  1.7822828,  1.8098662,  1.8374904,
     1.8651592,  1.8928763,  1.9206458,  1.9484722,
     1.9763603,  2.0043154,  2.0323430,  2.0604493,
     2.0886408,  2.1169245,  2.1453081,  2.1737998,
     2.2024086,  2.2311440,  2.2600165,  2.2890372,
     2.3182184,  2.3475732,  2.3771157,  2.4068614,
     2.4368270,  2.4670306,  2.4974918,  2.5282321,
     2.5592748,  2.5906452,  2.6223710,  2.6544825,
     2.6870129,  2.7199985,  2.7534795,  2.7874999,
     2.8221086,  2.8573596,  2.8933131,  2.9300362,
     2.9676040,  3.0061011,  3.0456229,  3.0862780,
     3.1281899,  3.1715011,  3.2163758,  3.2630056,
     3.3116156,  3.3624729,  3.4158977,  3.4722785,
     3.5320936,  3.5959415,  3.6645851,  3.7390194,
     3.8205780,  3.9111092,  4.0132856,  4.1311907,
     4.2715508,  4.4467193,  4.6836997,  5.0652659,
];
148
149/// HB codebook lookup helper.
150///
151/// Returns the centroid for the given byte index under the specified bit-width.
152/// Mirrors `flash_attn_vec_tq_hb.metal::dequant_hb_single`'s codebook switch.
153///
154/// `bits` must be 5, 6, or 8 (returns 0.0 for any other value — caller should
155/// validate). Index masking matches the kernel (`& 0x1F` for 5-bit, `& 0x3F` for
156/// 6-bit, full byte for 8-bit).
157#[inline]
158pub fn hb_centroid(idx: u8, bits: u32) -> f32 {
159    match bits {
160        5 => CODEBOOK_HB_5BIT[(idx & 0x1F) as usize],
161        6 => CODEBOOK_HB_6BIT[(idx & 0x3F) as usize],
162        8 => CODEBOOK_HB_8BIT[idx as usize],
163        _ => 0.0,
164    }
165}
166
/// D1 sign mask for the SRHT pre-multiplication, D=256 path.
///
/// Verbatim mirror of `hadamard_quantize_kv_fast.metal::TBQ_SIGNS_256`
/// (lines 25-30). Source: AmesianX `cpy-utils.cuh:158-163`,
/// sha256=3ef1038e6c232e9519101daa2d6efd637d4c6bfdb29f4ee7101625c39d0ddc89.
///
/// Convention: `bit j = (table[j>>3] >> (j&7)) & 1`; bit=1 → sign = -1,
/// bit=0 → sign = +1 (LSB-first within each byte).
///
/// 32 bytes = 256 sign bits, one per head-dim element.
pub const TBQ_SIGNS_256: [u8; 32] = [
    0xa7, 0x3b, 0x91, 0xf4, 0x6d, 0xc2, 0x58, 0x0e,
    0xb3, 0x7f, 0x24, 0xd6, 0x89, 0x45, 0xea, 0x1c,
    0x63, 0xaf, 0xd8, 0x52, 0x97, 0x0b, 0xe1, 0x3d,
    0x76, 0xc4, 0x19, 0xfe, 0x4a, 0x85, 0x2c, 0xdb,
];
181
/// D1 sign mask for the SRHT pre-multiplication, D=512 path.
///
/// Verbatim mirror of `hadamard_quantize_kv_fast.metal::TBQ_SIGNS_512`
/// (lines 35-44). Source: AmesianX `cpy-utils.cuh:211-220`,
/// sha256=44f13ce9f6db1edac62f558ee054f9de29cd474fd051362cadcaa98a55745f17.
///
/// Same bit convention as `TBQ_SIGNS_256` (LSB-first; bit=1 → sign = -1).
/// Note the first 32 bytes are identical to `TBQ_SIGNS_256`.
pub const TBQ_SIGNS_512: [u8; 64] = [
    0xa7, 0x3b, 0x91, 0xf4, 0x6d, 0xc2, 0x58, 0x0e,
    0xb3, 0x7f, 0x24, 0xd6, 0x89, 0x45, 0xea, 0x1c,
    0x63, 0xaf, 0xd8, 0x52, 0x97, 0x0b, 0xe1, 0x3d,
    0x76, 0xc4, 0x19, 0xfe, 0x4a, 0x85, 0x2c, 0xdb,
    0xd3, 0x4e, 0xa8, 0x17, 0x9c, 0x5b, 0xe6, 0x31,
    0x72, 0xb9, 0x0d, 0xf5, 0x43, 0x8a, 0x6e, 0xc7,
    0x58, 0x2f, 0x94, 0xe1, 0xb6, 0x3d, 0x0a, 0x7c,
    0xc5, 0x61, 0xd8, 0x4f, 0xa3, 0x97, 0x1e, 0x85,
];
197
/// Apply the D1 sign mask in-place per the SRHT convention.
///
/// Bit layout: `bit j = (signs[j >> 3] >> (j & 7)) & 1` (LSB-first within
/// each byte). Sign flip: bit=1 → `x[j] *= -1`; bit=0 → `x[j]` unchanged.
///
/// # Panics
/// Panics if `signs` carries fewer than `x.len()` bits, i.e. fewer than
/// `ceil(x.len() / 8)` bytes. (The previous contract stated `x.len() / 8`,
/// which under-counts by one byte whenever `x.len()` is not a multiple of 8.)
#[inline]
pub fn apply_d1_sign_mask_inplace(x: &mut [f32], signs: &[u8]) {
    // One sign bit per element; fail fast with a clear message in debug builds.
    debug_assert!(
        signs.len() * 8 >= x.len(),
        "sign mask too short: {} bytes for {} elements",
        signs.len(),
        x.len()
    );
    for (j, v) in x.iter_mut().enumerate() {
        // bit=1 → negate element j.
        if (signs[j >> 3] >> (j & 7)) & 1 == 1 {
            *v = -*v;
        }
    }
}
212
/// Higher-bit (5/6/8-bit) CPU encoder for D=256 — byte-equivalent mirror of
/// `hadamard_quantize_kv_fast.metal::hadamard_quantize_kv_hb<256>`.
///
/// Path C F-0.2 deliverable: produces the exact byte layout that the GPU
/// kernel writes given the same input vector, so divergence between the
/// flash_attn_vec_tq_hb GPU kernel and the F-0.1 CPU oracle isolates the
/// SDPA math (not the codec math).
///
/// Steps (mirroring the kernel byte-for-byte):
/// 1. Apply D1 sign mask (`TBQ_SIGNS_256`).
/// 2. Apply normalized FWHT (butterfly + 1/sqrt(d) — `fwht_inplace`).
/// 3. Compute L2 norm of the rotated vector.
/// 4. If norm > 1e-10: scale elems by `(1/norm) * sqrt(d)` (lift to N(0,1)).
///    If norm ≤ 1e-10: scale = 0 (matches kernel `inv_norm = 0` branch).
/// 5. Quantize each element to nearest centroid in the HB codebook for `bits`
///    (5/6/8). Returns 1 byte per element (byte-packed).
///
/// # Errors
/// `InvalidArgument` when `x.len() != 256` or `bits` is not one of 5/6/8.
///
/// Returns `(packed_indices, norm)`.
pub fn turboquant_hb_encode_d256(x: &[f32], bits: u32) -> Result<(Vec<u8>, f32), crate::MlxError> {
    if x.len() != 256 {
        return Err(crate::MlxError::InvalidArgument(format!(
            "turboquant_hb_encode_d256 expects head_dim=256, got {}",
            x.len()
        )));
    }
    if !matches!(bits, 5 | 6 | 8) {
        return Err(crate::MlxError::InvalidArgument(format!(
            "turboquant_hb_encode_d256 bits must be 5, 6, or 8, got {bits}"
        )));
    }

    // Step 1: D1 sign pre-multiplication.
    let mut elems = x.to_vec();
    apply_d1_sign_mask_inplace(&mut elems, &TBQ_SIGNS_256);

    // Step 2: normalized FWHT (butterfly + 1/sqrt(d)).
    fwht_inplace(&mut elems)?;

    // Step 3: L2 norm of the rotated vector.
    let norm_sq: f32 = elems.iter().map(|&v| v * v).sum();
    let norm = norm_sq.sqrt();

    // Step 4: scale to N(0,1). Kernel uses inv_norm * sqrt(d). Since fwht_inplace
    // already applied 1/sqrt(d) scaling to elems, the "inv_norm * sqrt(d)" here
    // means we multiply the post-fwht element by sqrt(d)/norm.
    // ↳ Match kernel exactly: when norm ≤ 1e-10, scale := 0 (zeros out output).
    let scale: f32 = if norm > 1.0e-10_f32 {
        (1.0_f32 / norm) * (256.0_f32).sqrt()
    } else {
        0.0_f32
    };
    for v in elems.iter_mut() {
        *v *= scale;
    }

    // Step 5: nearest centroid per element, byte-packed (1 byte per element).
    let mut packed = Vec::with_capacity(256);
    for &v in elems.iter() {
        packed.push(hb_nearest_centroid(v, bits));
    }

    Ok((packed, norm))
}
276
277/// HB nearest-centroid encoder (CPU-side mirror of the Metal encoder kernel).
278///
279/// Returns the byte index of the nearest centroid in the codebook for the given
280/// bit-width. Used only for the F-0.1 oracle's encode path and for codec
281/// roundtrip tests; production encode goes through `hadamard_quantize_kv_hb_d*`
282/// Metal kernels.
283///
284/// Returns `0u8` (closest-to-zero centroid) for unsupported bit-widths so the
285/// function stays no-panic; callers are expected to pre-validate `bits`.
286pub fn hb_nearest_centroid(value: f32, bits: u32) -> u8 {
287    let cb: &[f32] = match bits {
288        5 => &CODEBOOK_HB_5BIT,
289        6 => &CODEBOOK_HB_6BIT,
290        8 => &CODEBOOK_HB_8BIT,
291        _ => return 0u8,
292    };
293    let mut best_idx: u32 = 0;
294    let mut best_dist: f32 = (value - cb[0]).abs();
295    for (i, &c) in cb.iter().enumerate().skip(1) {
296        let dist = (value - c).abs();
297        if dist < best_dist {
298            best_dist = dist;
299            best_idx = i as u32;
300        }
301    }
302    best_idx as u8
303}
304
// ---- BitWidth enum ----

/// Quantization bit-width for TurboQuant.
///
/// Selects which Lloyd-Max codebook(s) each coordinate is quantized against
/// (see `codebook_for_coord` / `bits_for_coord`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BitWidth {
    /// 2-bit uniform: all coordinates use 4-level codebook.
    Two,
    /// 3-bit uniform: all coordinates use 8-level codebook.
    Three,
    /// 4-bit uniform: all coordinates use 16-level codebook.
    Four,
    /// 2.5-bit mixed: first d/4 coordinates at 3-bit, remaining 3d/4 at 2-bit.
    TwoPointFive,
}
319
/// Configuration for TurboQuant quantization.
#[derive(Debug, Clone)]
pub struct TurboQuantConfig {
    /// Quantization bit-width.
    pub bit_width: BitWidth,
    /// Head dimension (must be a power of 2: 128, 256, or 512).
    /// Validated by `turboquant_quantize` / `turboquant_dequantize`.
    pub head_dim: usize,
}
328
329// ---- Fast Walsh-Hadamard Transform ----
330
331/// In-place normalized Fast Walsh-Hadamard Transform.
332///
333/// The normalization ensures H * H = I, so the inverse transform is the
334/// same function applied again.
335///
336/// # Arguments
337/// * `x` — mutable slice of length `n` where `n` is a power of 2.
338///
339/// # Returns
340/// `Ok(())` on success, or an error if the length is not a power of 2.
341pub fn fwht_inplace(x: &mut [f32]) -> crate::Result<()> {
342    let n = x.len();
343    if n == 0 || !n.is_power_of_two() {
344        return Err(crate::MlxError::InvalidArgument(format!(
345            "FWHT requires power-of-two length, got {n}"
346        )));
347    }
348
349    let mut h = 1;
350    while h < n {
351        let step = h * 2;
352        let mut i = 0;
353        while i < n {
354            for j in i..i + h {
355                let a = x[j];
356                let b = x[j + h];
357                x[j] = a + b;
358                x[j + h] = a - b;
359            }
360            i += step;
361        }
362        h *= 2;
363    }
364
365    // Normalize so that H * H = I
366    let scale = 1.0 / (n as f32).sqrt();
367    for v in x.iter_mut() {
368        *v *= scale;
369    }
370
371    Ok(())
372}
373
// ---- Standard Normal PDF / CDF ----

/// Standard normal probability density function: phi(x) = exp(-x^2/2) / sqrt(2*pi).
#[inline]
fn std_normal_pdf(x: f64) -> f64 {
    // 1/sqrt(2*pi), precomputed to full f64 precision.
    const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
    (-0.5 * x * x).exp() * INV_SQRT_2PI
}
382
383/// Standard normal CDF using the Abramowitz & Stegun rational approximation
384/// (formula 26.2.17, maximum error < 7.5e-8).
385#[inline]
386fn std_normal_cdf(x: f64) -> f64 {
387    if x < -8.0 {
388        return 0.0;
389    }
390    if x > 8.0 {
391        return 1.0;
392    }
393
394    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
395    let x_abs = x.abs();
396
397    // Horner form of the rational approximation
398    const P: f64 = 0.231_641_9;
399    const B1: f64 = 0.319_381_530;
400    const B2: f64 = -0.356_563_782;
401    const B3: f64 = 1.781_477_937;
402    const B4: f64 = -1.821_255_978;
403    const B5: f64 = 1.330_274_429;
404
405    let t = 1.0 / (1.0 + P * x_abs);
406    let t2 = t * t;
407    let t3 = t2 * t;
408    let t4 = t3 * t;
409    let t5 = t4 * t;
410
411    let poly = B1 * t + B2 * t2 + B3 * t3 + B4 * t4 + B5 * t5;
412    let phi = std_normal_pdf(x_abs);
413
414    let result = 1.0 - phi * poly;
415
416    if sign < 0.0 {
417        1.0 - result
418    } else {
419        result
420    }
421}
422
// ---- Nearest centroid lookup ----

/// Find the index of the nearest centroid in a codebook by linear scan.
///
/// Ties resolve to the lower index (strict `<` keeps the first minimum
/// scanned). Codebooks with at most one entry always map to index 0.
#[inline]
fn nearest_centroid(value: f32, codebook: &[f32]) -> u8 {
    if codebook.len() <= 1 {
        return 0;
    }

    let mut winner = 0usize;
    let mut winner_dist = (value - codebook[0]).abs();
    for i in 1..codebook.len() {
        let d = (value - codebook[i]).abs();
        if d < winner_dist {
            winner_dist = d;
            winner = i;
        }
    }
    winner as u8
}
446
447/// Get the codebook for a specific coordinate index under the given config.
448#[inline]
449fn codebook_for_coord(coord_idx: usize, config: &TurboQuantConfig) -> &'static [f32] {
450    match config.bit_width {
451        BitWidth::Two => &CODEBOOK_2BIT,
452        BitWidth::Three => &CODEBOOK_3BIT,
453        BitWidth::Four => &CODEBOOK_4BIT,
454        BitWidth::TwoPointFive => {
455            let boundary = config.head_dim / 4;
456            if coord_idx < boundary {
457                &CODEBOOK_3BIT // first d/4 channels at 3-bit
458            } else {
459                &CODEBOOK_2BIT // remaining 3d/4 at 2-bit
460            }
461        }
462    }
463}
464
465/// Bits per index for a coordinate under the given config.
466#[inline]
467fn bits_for_coord(coord_idx: usize, config: &TurboQuantConfig) -> usize {
468    match config.bit_width {
469        BitWidth::Two => 2,
470        BitWidth::Three => 3,
471        BitWidth::Four => 4,
472        BitWidth::TwoPointFive => {
473            if coord_idx < config.head_dim / 4 {
474                3
475            } else {
476                2
477            }
478        }
479    }
480}
481
482// ---- Pack / Unpack indices ----
483
484/// Pack variable-width indices into a byte vector using bit-packing.
485///
486/// Indices are packed MSB-first into consecutive bytes.
487fn pack_indices(indices: &[u8], config: &TurboQuantConfig) -> Vec<u8> {
488    let total_bits: usize = (0..indices.len())
489        .map(|i| bits_for_coord(i, config))
490        .sum();
491    let num_bytes = (total_bits + 7) / 8;
492    let mut packed = vec![0u8; num_bytes];
493
494    let mut bit_offset = 0usize;
495    for (i, &idx) in indices.iter().enumerate() {
496        let nbits = bits_for_coord(i, config);
497        // Write `nbits` bits of `idx` starting at `bit_offset`
498        for b in (0..nbits).rev() {
499            let bit_val = (idx >> b) & 1;
500            let byte_pos = bit_offset / 8;
501            let bit_pos = 7 - (bit_offset % 8);
502            if byte_pos < packed.len() {
503                packed[byte_pos] |= bit_val << bit_pos;
504            }
505            bit_offset += 1;
506        }
507    }
508
509    packed
510}
511
512/// Unpack variable-width indices from a packed byte vector.
513fn unpack_indices(packed: &[u8], config: &TurboQuantConfig) -> Vec<u8> {
514    let d = config.head_dim;
515    let mut indices = Vec::with_capacity(d);
516
517    let mut bit_offset = 0usize;
518    for i in 0..d {
519        let nbits = bits_for_coord(i, config);
520        let mut val = 0u8;
521        for _ in 0..nbits {
522            let byte_pos = bit_offset / 8;
523            let bit_pos = 7 - (bit_offset % 8);
524            let bit_val = if byte_pos < packed.len() {
525                (packed[byte_pos] >> bit_pos) & 1
526            } else {
527                0
528            };
529            val = (val << 1) | bit_val;
530            bit_offset += 1;
531        }
532        indices.push(val);
533    }
534
535    indices
536}
537
538// ---- Quantize / Dequantize ----
539
540/// Quantize a single head vector using TurboQuant_mse.
541///
542/// Steps:
543/// 1. Apply FWHT (Walsh-Hadamard rotation) for incoherence
544/// 2. Extract L2 norm
545/// 3. Normalize to unit vector
546/// 4. Quantize each coordinate against the appropriate Lloyd-Max codebook
547/// 5. Pack indices
548///
549/// # Arguments
550/// * `x` — input vector of length `config.head_dim`
551/// * `config` — quantization configuration
552///
553/// # Returns
554/// `(packed_indices, norm)` on success.
555pub fn turboquant_quantize(
556    x: &[f32],
557    config: &TurboQuantConfig,
558) -> crate::Result<(Vec<u8>, f32)> {
559    let d = config.head_dim;
560    if x.len() != d {
561        return Err(crate::MlxError::InvalidArgument(format!(
562            "Expected vector of length {d}, got {}",
563            x.len()
564        )));
565    }
566    if !d.is_power_of_two() {
567        return Err(crate::MlxError::InvalidArgument(format!(
568            "head_dim must be power of 2, got {d}"
569        )));
570    }
571
572    // 1. Copy and apply FWHT
573    let mut rotated = x.to_vec();
574    fwht_inplace(&mut rotated)?;
575
576    // 2. Compute L2 norm of rotated vector (same as original since Hadamard is orthogonal)
577    let norm_sq: f32 = rotated.iter().map(|&v| v * v).sum();
578    let norm = norm_sq.sqrt();
579
580    if norm < 1e-30 {
581        // Zero vector: all indices = 0, norm = 0
582        let indices = vec![0u8; d];
583        let packed = pack_indices(&indices, config);
584        return Ok((packed, 0.0));
585    }
586
587    // 3. Normalize to unit vector on S^{d-1}
588    let inv_norm = 1.0 / norm;
589    for v in rotated.iter_mut() {
590        *v *= inv_norm;
591    }
592
593    // 4. Quantize: each coordinate needs to be scaled to N(0,1) domain.
594    // A unit vector on S^{d-1} has coordinates ~ N(0, 1/d) for large d.
595    // Scale by sqrt(d) to map to N(0,1) for codebook lookup.
596    let scale = (d as f32).sqrt();
597    let mut indices = Vec::with_capacity(d);
598    for (i, &v) in rotated.iter().enumerate() {
599        let scaled = v * scale;
600        let cb = codebook_for_coord(i, config);
601        indices.push(nearest_centroid(scaled, cb));
602    }
603
604    // 5. Pack
605    let packed = pack_indices(&indices, config);
606
607    Ok((packed, norm))
608}
609
610/// Dequantize a TurboQuant-compressed head vector.
611///
612/// Steps:
613/// 1. Unpack indices
614/// 2. Look up centroid values, scale back from N(0,1) domain
615/// 3. Multiply by norm
616/// 4. Apply inverse FWHT (same as forward)
617///
618/// # Arguments
619/// * `packed` — packed index bytes
620/// * `norm` — the L2 norm stored during quantization
621/// * `config` — quantization configuration
622///
623/// # Returns
624/// Reconstructed vector of length `config.head_dim`.
625pub fn turboquant_dequantize(
626    packed: &[u8],
627    norm: f32,
628    config: &TurboQuantConfig,
629) -> crate::Result<Vec<f32>> {
630    let d = config.head_dim;
631    if !d.is_power_of_two() {
632        return Err(crate::MlxError::InvalidArgument(format!(
633            "head_dim must be power of 2, got {d}"
634        )));
635    }
636
637    // 1. Unpack indices
638    let indices = unpack_indices(packed, config);
639
640    // 2. Look up centroids and scale back from N(0,1) to unit-sphere scale
641    let inv_scale = 1.0 / (d as f32).sqrt();
642    let mut reconstructed = Vec::with_capacity(d);
643    for (i, &idx) in indices.iter().enumerate() {
644        let cb = codebook_for_coord(i, config);
645        let idx_usize = idx as usize;
646        let centroid = if idx_usize < cb.len() {
647            cb[idx_usize]
648        } else {
649            0.0 // fallback for out-of-range (shouldn't happen)
650        };
651        reconstructed.push(centroid * inv_scale * norm);
652    }
653
654    // 3. Apply inverse FWHT (same as forward since H^{-1} = H with normalization)
655    fwht_inplace(&mut reconstructed)?;
656
657    Ok(reconstructed)
658}
659
660// ---- Lloyd-Max computation utilities (used by tests for validation) ----
661
662/// Compute Lloyd-Max codebook for N(0,1) with the given number of levels.
663///
664/// Returns the sorted centroid array. This is used in tests to validate the
665/// hardcoded codebooks.
666pub fn compute_lloyd_max_codebook(num_levels: usize) -> Vec<f64> {
667    // Initialize with uniform quantile boundaries
668    let mut boundaries = Vec::with_capacity(num_levels + 1);
669    boundaries.push(-10.0_f64); // approx -inf
670    for i in 1..num_levels {
671        let p = i as f64 / num_levels as f64;
672        boundaries.push(quantile_normal(p));
673    }
674    boundaries.push(10.0_f64); // approx +inf
675
676    // Initial centroids from conditional expectations
677    let mut centroids = vec![0.0_f64; num_levels];
678    for i in 0..num_levels {
679        let a = boundaries[i];
680        let b = boundaries[i + 1];
681        let prob = std_normal_cdf(b) - std_normal_cdf(a);
682        if prob > 1e-30 {
683            centroids[i] = (std_normal_pdf(a) - std_normal_pdf(b)) / prob;
684        }
685    }
686
687    // Iterate
688    for _iter in 0..50_000 {
689        let old = centroids.clone();
690
691        // Update boundaries to midpoints
692        boundaries[0] = -10.0;
693        for i in 1..num_levels {
694            boundaries[i] = (centroids[i - 1] + centroids[i]) / 2.0;
695        }
696        *boundaries.last_mut().unwrap_or(&mut 0.0) = 10.0;
697
698        // Update centroids
699        for i in 0..num_levels {
700            let a = boundaries[i];
701            let b = boundaries[i + 1];
702            let prob = std_normal_cdf(b) - std_normal_cdf(a);
703            if prob > 1e-30 {
704                centroids[i] = (std_normal_pdf(a) - std_normal_pdf(b)) / prob;
705            }
706        }
707
708        // Check convergence
709        let max_change = centroids
710            .iter()
711            .zip(old.iter())
712            .map(|(a, b)| (a - b).abs())
713            .fold(0.0_f64, f64::max);
714        if max_change < 1e-12 {
715            break;
716        }
717    }
718
719    centroids
720}
721
/// Approximate quantile (inverse CDF) of N(0,1) using Peter Acklam's
/// rational approximation (relative error below ~1.2e-9 on (0, 1)).
///
/// Inputs at or outside (0, 1) clamp to ±10, matching the ±10 "infinity"
/// convention used by the Lloyd-Max boundary tables.
fn quantile_normal(p: f64) -> f64 {
    if p <= 0.0 {
        return -10.0;
    }
    if p >= 1.0 {
        return 10.0;
    }

    // Central-region numerator (A) / denominator (B) coefficients.
    const A: [f64; 6] = [
        -3.969683028665376e1,
        2.209460984245205e2,
        -2.759285104469687e2,
        1.383577518672690e2,
        -3.066479806614716e1,
        2.506628277459239e0,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e1,
        1.615858368580409e2,
        -1.556989798598866e2,
        6.680131188771972e1,
        -1.328068155288572e1,
    ];
    // Tail-region numerator (C) / denominator (D) coefficients.
    const C: [f64; 6] = [
        -7.784894002430293e-3,
        -3.223964580411365e-1,
        -2.400758277161838e0,
        -2.549732539343734e0,
        4.374664141464968e0,
        2.938163982698783e0,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-3,
        3.224671290700398e-1,
        2.445134137142996e0,
        3.754408661907416e0,
    ];

    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    if (P_LOW..=P_HIGH).contains(&p) {
        // Central region: rational function in r = (p - 1/2)^2, scaled by q.
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else if p < P_LOW {
        // Lower tail: rational function in q = sqrt(-2 ln p).
        let q = (-2.0 * p.ln()).sqrt();
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    } else {
        // Upper tail: mirror of the lower tail via 1 - p.
        let q = (-2.0 * (1.0 - p).ln()).sqrt();
        -(((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    }
}
782
783/// Compute Lloyd-Max codebook for Beta((d-1)/2, (d-1)/2) scaled to [-1, 1].
784///
785/// The exact distribution of a coordinate of a unit vector uniform on S^{d-1}
786/// is Beta((d-1)/2, (d-1)/2) on [-1, 1]. For large d this converges to N(0, 1/d).
787///
788/// Uses numerical integration via trapezoidal rule for the conditional expectations.
789pub fn compute_lloyd_max_beta_codebook(dim: usize, num_levels: usize) -> Vec<f64> {
790    let alpha = (dim as f64 - 1.0) / 2.0;
791
792    // Beta PDF on [-1,1] with parameters (alpha, alpha) — symmetric
793    // f(x) = C * (1-x^2)^(alpha-1)  for x in [-1, 1]
794    // where C normalizes to 1.
795
796    // Use log-space for numerical stability
797    let log_norm = log_beta_norm_const(alpha);
798
799    let beta_pdf = |x: f64| -> f64 {
800        if x <= -1.0 || x >= 1.0 {
801            return 0.0;
802        }
803        let val = 1.0 - x * x;
804        if val <= 0.0 {
805            return 0.0;
806        }
807        (log_norm + (alpha - 1.0) * val.ln()).exp()
808    };
809
810    // Numerical CDF via cumulative trapezoidal integration
811    let n_grid = 10_000;
812    let grid_lo = -1.0_f64;
813    let grid_hi = 1.0_f64;
814    let dx = (grid_hi - grid_lo) / n_grid as f64;
815
816    // Build CDF table
817    let mut cdf_vals = vec![0.0_f64; n_grid + 1];
818    let mut pdf_vals = vec![0.0_f64; n_grid + 1];
819    for i in 0..=n_grid {
820        let x = grid_lo + i as f64 * dx;
821        pdf_vals[i] = beta_pdf(x);
822    }
823    for i in 1..=n_grid {
824        cdf_vals[i] = cdf_vals[i - 1] + 0.5 * (pdf_vals[i - 1] + pdf_vals[i]) * dx;
825    }
826    // Normalize CDF to [0, 1]
827    let cdf_total = cdf_vals[n_grid];
828    if cdf_total > 1e-30 {
829        for v in cdf_vals.iter_mut() {
830            *v /= cdf_total;
831        }
832        for v in pdf_vals.iter_mut() {
833            *v /= cdf_total;
834        }
835    }
836
837    // Helper: interpolated CDF and conditional expectation on [a, b]
838    let interp_cdf = |x: f64| -> f64 {
839        let frac = (x - grid_lo) / dx;
840        let idx = frac as usize;
841        if idx >= n_grid {
842            return 1.0;
843        }
844        let t = frac - idx as f64;
845        cdf_vals[idx] * (1.0 - t) + cdf_vals[idx + 1] * t
846    };
847
848    let conditional_expectation = |a: f64, b: f64| -> f64 {
849        // E[X | a <= X <= b] via numerical integration
850        let prob = interp_cdf(b) - interp_cdf(a);
851        if prob < 1e-30 {
852            return (a + b) / 2.0;
853        }
854
855        let n_sub = 500;
856        let sub_dx = (b - a) / n_sub as f64;
857        let mut integral = 0.0_f64;
858        for j in 0..=n_sub {
859            let x = a + j as f64 * sub_dx;
860            let w = if j == 0 || j == n_sub { 0.5 } else { 1.0 };
861            let frac = (x - grid_lo) / dx;
862            let idx = frac as usize;
863            let pdf_val = if idx >= n_grid {
864                0.0
865            } else {
866                let t = frac - idx as f64;
867                pdf_vals[idx] * (1.0 - t) + pdf_vals[idx + 1] * t
868            };
869            integral += w * x * pdf_val * sub_dx;
870        }
871        integral / prob
872    };
873
874    // Initialize with uniform quantile boundaries
875    let mut boundaries = Vec::with_capacity(num_levels + 1);
876    boundaries.push(-1.0_f64);
877    for i in 1..num_levels {
878        let target_p = i as f64 / num_levels as f64;
879        // Binary search for quantile
880        let mut lo = -1.0_f64;
881        let mut hi = 1.0_f64;
882        for _ in 0..100 {
883            let mid = (lo + hi) / 2.0;
884            if interp_cdf(mid) < target_p {
885                lo = mid;
886            } else {
887                hi = mid;
888            }
889        }
890        boundaries.push((lo + hi) / 2.0);
891    }
892    boundaries.push(1.0_f64);
893
894    // Initial centroids
895    let mut centroids = vec![0.0_f64; num_levels];
896    for i in 0..num_levels {
897        centroids[i] = conditional_expectation(boundaries[i], boundaries[i + 1]);
898    }
899
900    // Lloyd-Max iteration
901    for _iter in 0..5000 {
902        let old = centroids.clone();
903
904        // Update boundaries
905        boundaries[0] = -1.0;
906        for i in 1..num_levels {
907            boundaries[i] = (centroids[i - 1] + centroids[i]) / 2.0;
908        }
909        if let Some(last) = boundaries.last_mut() {
910            *last = 1.0;
911        }
912
913        // Update centroids
914        for i in 0..num_levels {
915            centroids[i] = conditional_expectation(boundaries[i], boundaries[i + 1]);
916        }
917
918        let max_change = centroids
919            .iter()
920            .zip(old.iter())
921            .map(|(a, b)| (a - b).abs())
922            .fold(0.0_f64, f64::max);
923        if max_change < 1e-10 {
924            break;
925        }
926    }
927
928    centroids
929}
930
931/// Log of the normalization constant for the symmetric Beta PDF on [-1, 1].
932fn log_beta_norm_const(alpha: f64) -> f64 {
933    // Beta(alpha, alpha) on [0,1] has norm B(alpha, alpha) = Gamma(alpha)^2 / Gamma(2*alpha)
934    // On [-1,1] we scale by 1/2, so norm = B(alpha,alpha) * 2^(2*alpha-1)
935    // log C = -log(B(alpha,alpha)) - (2*alpha-1)*log(2)
936    //       = log(Gamma(2*alpha)) - 2*log(Gamma(alpha)) - (2*alpha-1)*log(2)
937    ln_gamma(2.0 * alpha) - 2.0 * ln_gamma(alpha) - (2.0 * alpha - 1.0) * 2.0_f64.ln()
938}
939
/// Lanczos approximation for ln(Gamma(x)), x > 0.
///
/// Uses the g=7, n=9 Lanczos coefficient set (accurate to roughly 15
/// significant digits in double precision). Arguments below 0.5 are routed
/// through the reflection formula so the shifted series is only evaluated
/// on its numerically stable range.
fn ln_gamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    // Lanczos approximation with g=7, n=9
    const G: f64 = 7.0;
    const COEFF: [f64; 9] = [
        0.999_999_999_999_809_93,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_9,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_571_6e-6,
        1.505_632_735_149_311_6e-7,
    ];

    if x < 0.5 {
        // Reflection: Gamma(x) * Gamma(1 - x) = pi / sin(pi*x), in log form.
        return PI.ln() - (PI * x).sin().ln() - ln_gamma(1.0 - x);
    }

    // Shift down by one so the series below yields ln(Gamma(z + 1)) = ln(Gamma(x)).
    let z = x - 1.0;
    let mut series = COEFF[0];
    for (i, &c) in COEFF.iter().enumerate().skip(1) {
        series += c / (z + i as f64);
    }

    let t = z + G + 0.5;
    0.5 * (2.0 * PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
971
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A Lloyd-Max codebook for a symmetric source must itself be symmetric:
    /// every centroid is mirrored by its opposite-end partner.
    #[test]
    fn test_codebook_symmetry() {
        let cases: [(&str, &[f32]); 3] = [
            ("2-bit", &CODEBOOK_2BIT),
            ("3-bit", &CODEBOOK_3BIT),
            ("4-bit", &CODEBOOK_4BIT),
        ];
        for (name, cb) in cases {
            let n = cb.len();
            for i in 0..n / 2 {
                let j = n - 1 - i;
                let sum = cb[i] + cb[j];
                assert!(
                    sum.abs() < 1e-5,
                    "{name} codebook not symmetric: c[{i}]={} + c[{}]={} = {sum}",
                    cb[i],
                    j,
                    cb[j]
                );
            }
        }
    }

    /// The shipped constants must agree with a fresh Lloyd-Max run to 1e-4.
    #[test]
    fn test_codebook_values_match_lloyd_max() {
        for (bits, hardcoded) in [
            (2, &CODEBOOK_2BIT[..]),
            (3, &CODEBOOK_3BIT[..]),
            (4, &CODEBOOK_4BIT[..]),
        ] {
            let computed = compute_lloyd_max_codebook(1 << bits);
            assert_eq!(computed.len(), hardcoded.len());
            for i in 0..hardcoded.len() {
                let h = hardcoded[i];
                let c = computed[i];
                let diff = (h as f64 - c).abs();
                assert!(
                    diff < 1e-4,
                    "{bits}-bit codebook mismatch at {i}: hardcoded={h}, computed={c}, diff={diff}"
                );
            }
        }
    }

    /// The normalized FWHT is an involution: applying it twice recovers the
    /// input to within f32 roundoff.
    #[test]
    fn test_fwht_roundtrip() {
        // Ramp covering negative and positive values across 128 lanes.
        let original: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1 - 6.4).collect();
        let mut data = original.clone();
        fwht_inplace(&mut data).unwrap();
        fwht_inplace(&mut data).unwrap();
        for i in 0..original.len() {
            let (a, b) = (original[i], data[i]);
            assert!(
                (a - b).abs() < 1e-4,
                "FWHT roundtrip mismatch at {i}: {a} vs {b}"
            );
        }
    }

    // ----- ADR-007 Path C F-0.2 tests: HB encoder mirror correctness -----

    /// Deterministic N(0,1) samples from a 64-bit LCG plus Box-Muller, so the
    /// encoder tests are reproducible without pulling in an RNG crate.
    fn gaussian_fixture(seed: u64, n: usize) -> Vec<f32> {
        const MUL: u64 = 6364136223846793005;
        const INC: u64 = 1442695040888963407;

        // Advance the LCG once and map the (better-mixed) high 32 bits
        // into the open interval (0, 1).
        fn uniform(s: &mut u64) -> f32 {
            *s = s.wrapping_mul(MUL).wrapping_add(INC);
            let bits = (*s >> 32) as u32;
            ((bits as f64 + 0.5) / (u32::MAX as f64 + 1.0)) as f32
        }

        let mut state = seed.wrapping_mul(MUL).wrapping_add(INC);
        let mut out = Vec::with_capacity(n);
        while out.len() < n {
            // Box-Muller: two uniforms yield two independent Gaussians.
            let u1 = uniform(&mut state).max(1e-7).min(1.0 - 1e-7);
            let u2 = uniform(&mut state);
            let r = (-2.0_f32 * u1.ln()).sqrt();
            let theta = 2.0_f32 * std::f32::consts::PI * u2;
            out.push(r * theta.cos());
            if out.len() < n {
                out.push(r * theta.sin());
            }
        }
        out
    }

    /// Decode a single packed-byte row back to F32 via the same dequant formula
    /// the GPU kernel uses on the read side, then invert SRHT (FWHT + sign mask).
    /// Used only by tests to verify encoder roundtrip.
    fn decode_via_kernel_formula(packed: &[u8], norm: f32, bits: u32) -> Vec<f32> {
        // Step 1: codebook lookup × norm × inv_sqrt(256), per kernel decoder math.
        let inv_sqrt_dk = 1.0_f32 / (256.0_f32).sqrt();
        let mut decoded = Vec::with_capacity(packed.len());
        for &idx in packed {
            decoded.push(hb_centroid(idx, bits) * norm * inv_sqrt_dk);
        }
        // Step 2: inverse normalized FWHT = same FWHT (involution under H * H = I).
        fwht_inplace(&mut decoded).expect("fwht ok");
        // Step 3: invert D1 sign mask (sign flip is its own inverse).
        apply_d1_sign_mask_inplace(&mut decoded, &TBQ_SIGNS_256);
        decoded
    }

    /// Normalized RMSE of `b` against reference `a`, accumulated in f64.
    fn nrmse(a: &[f32], b: &[f32]) -> f32 {
        let mut sse = 0.0_f64;
        let mut sse_a = 0.0_f64;
        for (&av, &bv) in a.iter().zip(b) {
            let d = (av - bv) as f64;
            sse += d * d;
            sse_a += (av as f64) * (av as f64);
        }
        if sse_a < 1e-30 {
            0.0
        } else {
            (sse / sse_a).sqrt() as f32
        }
    }

    /// Cosine similarity in f64; near-zero inputs are treated as identical.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let (dot, na, nb) = a.iter().zip(b).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(d, x, y), (&av, &bv)| {
                let (av, bv) = (av as f64, bv as f64);
                (d + av * bv, x + av * av, y + bv * bv)
            },
        );
        if na < 1e-30 || nb < 1e-30 {
            return 1.0;
        }
        (dot / (na.sqrt() * nb.sqrt())) as f32
    }

    /// Encoder roundtrip via the kernel's dequant formula: encode → dequant →
    /// compare to original. Cosine ≥ ADR-007 Gate A threshold (≥0.999).
    /// 8-bit close-section measured 0.9998 mean / 0.9986 p1.
    #[test]
    fn hb_encoder_d256_roundtrip_8bit_meets_gate_a() {
        // 8-bit Gate A close-section measurement: cosine mean 0.9998.
        // Synthetic Gaussian sample isn't the production distribution but should
        // clear the strict spec on a single vector. Threshold ≥0.998 leaves
        // headroom for sampling noise on a single 256-vector.
        let x = gaussian_fixture(0xC25EED, 256);
        let (packed, norm) = turboquant_hb_encode_d256(&x, 8).expect("encode");
        let recon = decode_via_kernel_formula(&packed, norm, 8);
        let cos = cosine_similarity(&x, &recon);
        let nrmse_v = nrmse(&x, &recon);
        assert!(cos >= 0.998, "8-bit roundtrip cosine {cos} < 0.998");
        assert!(nrmse_v <= 0.07, "8-bit roundtrip NRMSE {nrmse_v} > 0.07");
    }

    #[test]
    fn hb_encoder_d256_roundtrip_5bit_within_band() {
        // 5-bit close-section: not shippable as default; expected wider gap.
        let x = gaussian_fixture(0xC25EED, 256);
        let (packed, norm) = turboquant_hb_encode_d256(&x, 5).expect("encode");
        let recon = decode_via_kernel_formula(&packed, norm, 5);
        let cos = cosine_similarity(&x, &recon);
        // 5-bit Lloyd-Max MSE ≈ 0.0095 → cosine ≈ 0.99. Allow small headroom.
        assert!(cos >= 0.985, "5-bit roundtrip cosine {cos} < 0.985");
    }

    /// Byte-for-byte (and bit-for-bit on the norm) reproducibility of the
    /// encoder on identical input.
    #[test]
    fn hb_encoder_d256_is_deterministic() {
        let x = gaussian_fixture(0xBEEF, 256);
        let (p_a, n_a) = turboquant_hb_encode_d256(&x, 8).expect("a");
        let (p_b, n_b) = turboquant_hb_encode_d256(&x, 8).expect("b");
        assert_eq!(p_a, p_b);
        assert_eq!(n_a.to_bits(), n_b.to_bits());
    }

    #[test]
    fn hb_encoder_d256_zero_vector() {
        // Mantra: if norm <= 1e-10 kernel sets scale = 0. Then every elem = 0,
        // which dequants to centroid index 127 or 128 (closest-to-zero 8-bit).
        // The norm written is 0. Decode should yield ~0 vector.
        let x = vec![0.0_f32; 256];
        let (packed, norm) = turboquant_hb_encode_d256(&x, 8).expect("encode");
        assert_eq!(norm, 0.0);
        // All packed bytes should be the centroid closest to zero (idx 127 or 128).
        for &b in packed.iter() {
            assert!(
                matches!(b, 127 | 128),
                "zero-vec encode produced non-near-zero centroid: {b}"
            );
        }
        // Roundtrip: norm=0 means decoder produces all-zero output (× 0 = 0).
        let recon = decode_via_kernel_formula(&packed, 0.0, 8);
        for &v in recon.iter() {
            assert_eq!(v, 0.0);
        }
    }

    /// Only 5/6/8 are valid HB bit widths; everything else must be rejected.
    #[test]
    fn hb_encoder_d256_validates_bits() {
        let x = vec![0.0_f32; 256];
        assert!(turboquant_hb_encode_d256(&x, 4).is_err()); // 4-bit not HB
        assert!(turboquant_hb_encode_d256(&x, 7).is_err()); // invalid
    }

    /// The D=256 encoder must refuse any other input length.
    #[test]
    fn hb_encoder_d256_validates_size() {
        let x = vec![0.0_f32; 128]; // wrong size
        assert!(turboquant_hb_encode_d256(&x, 8).is_err());
    }

    #[test]
    fn d1_sign_mask_is_self_inverse() {
        let mut x = gaussian_fixture(0x123, 256);
        let original = x.clone();
        apply_d1_sign_mask_inplace(&mut x, &TBQ_SIGNS_256);
        // One application must change something, else the mask is a no-op.
        let differs = x
            .iter()
            .zip(original.iter())
            .any(|(&a, &b)| (a - b).abs() > 1e-6);
        assert!(differs, "D1 sign mask had no effect");
        // A second application must restore the input (sign flip is its own inverse).
        apply_d1_sign_mask_inplace(&mut x, &TBQ_SIGNS_256);
        for (i, (&a, &b)) in x.iter().zip(original.iter()).enumerate() {
            assert!((a - b).abs() < 1e-6, "D1 sign mask not self-inverse at {i}");
        }
    }

    #[test]
    fn tbq_signs_first_32_bytes_match_512_prefix() {
        // The shader's two sign tables share their first 32 bytes (verified
        // visually in hadamard_quantize_kv_fast.metal:25-30 vs 35-44). This
        // is load-bearing for cross-D=256/D=512 codec equivalence proofs.
        for (i, (&s256, &s512)) in TBQ_SIGNS_256
            .iter()
            .zip(TBQ_SIGNS_512.iter())
            .take(32)
            .enumerate()
        {
            assert_eq!(s256, s512, "TBQ_SIGNS_256/512 prefix mismatch at byte {i}");
        }
    }
}