/// Lloyd-Max (minimum-MSE) centroids for 2-bit (4-level) scalar quantization
/// of N(0,1) samples; symmetric about zero, sorted ascending.
/// Verified against `compute_lloyd_max_codebook(4)` in the tests below.
pub const CODEBOOK_2BIT: [f32; 4] = [
    -1.5104176, -0.4527800, 0.4527800, 1.5104176,
];

/// Lloyd-Max centroids for 3-bit (8-level) quantization of N(0,1) samples.
/// Verified against `compute_lloyd_max_codebook(8)` in the tests below.
pub const CODEBOOK_3BIT: [f32; 8] = [
    -2.1519457, -1.3439093, -0.7560053, -0.2450942,
    0.2450942, 0.7560053, 1.3439093, 2.1519457,
];

/// Lloyd-Max centroids for 4-bit (16-level) quantization of N(0,1) samples.
/// Verified against `compute_lloyd_max_codebook(16)` in the tests below.
pub const CODEBOOK_4BIT: [f32; 16] = [
    -2.7325896, -2.0690172, -1.6180464, -1.2562312,
    -0.9423405, -0.6567591, -0.3880483, -0.1283950,
    0.1283950, 0.3880483, 0.6567591, 0.9423405,
    1.2562312, 1.6180464, 2.0690172, 2.7325896,
];
33
/// Number of quantization bits used per coordinate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BitWidth {
    /// 2 bits per coordinate (4-level codebook).
    Two,
    /// 3 bits per coordinate (8-level codebook).
    Three,
    /// 4 bits per coordinate (16-level codebook).
    Four,
    /// Mixed precision: 3 bits for the first `head_dim / 4` coordinates,
    /// 2 bits for the rest (see `bits_for_coord`).
    /// NOTE(review): that split averages 2.25 bits/coordinate, not 2.5 as
    /// the name suggests; a `head_dim / 2` boundary would give 2.5 —
    /// confirm intent.
    TwoPointFive,
}
48
/// Configuration for TurboQuant-style per-vector quantization.
#[derive(Debug, Clone)]
pub struct TurboQuantConfig {
    /// Bits per coordinate (fixed or mixed; see [`BitWidth`]).
    pub bit_width: BitWidth,
    /// Vector length; must be a power of two (required by the FWHT).
    pub head_dim: usize,
}
57
58pub fn fwht_inplace(x: &mut [f32]) -> crate::Result<()> {
71 let n = x.len();
72 if n == 0 || !n.is_power_of_two() {
73 return Err(crate::MlxError::InvalidArgument(format!(
74 "FWHT requires power-of-two length, got {n}"
75 )));
76 }
77
78 let mut h = 1;
79 while h < n {
80 let step = h * 2;
81 let mut i = 0;
82 while i < n {
83 for j in i..i + h {
84 let a = x[j];
85 let b = x[j + h];
86 x[j] = a + b;
87 x[j + h] = a - b;
88 }
89 i += step;
90 }
91 h *= 2;
92 }
93
94 let scale = 1.0 / (n as f32).sqrt();
96 for v in x.iter_mut() {
97 *v *= scale;
98 }
99
100 Ok(())
101}
102
/// Standard normal density: phi(x) = exp(-x^2 / 2) / sqrt(2 * pi).
#[inline]
fn std_normal_pdf(x: f64) -> f64 {
    // 1 / sqrt(2 * pi)
    const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
    INV_SQRT_2PI * (-0.5 * x * x).exp()
}
111
112#[inline]
115fn std_normal_cdf(x: f64) -> f64 {
116 if x < -8.0 {
117 return 0.0;
118 }
119 if x > 8.0 {
120 return 1.0;
121 }
122
123 let sign = if x >= 0.0 { 1.0 } else { -1.0 };
124 let x_abs = x.abs();
125
126 const P: f64 = 0.231_641_9;
128 const B1: f64 = 0.319_381_530;
129 const B2: f64 = -0.356_563_782;
130 const B3: f64 = 1.781_477_937;
131 const B4: f64 = -1.821_255_978;
132 const B5: f64 = 1.330_274_429;
133
134 let t = 1.0 / (1.0 + P * x_abs);
135 let t2 = t * t;
136 let t3 = t2 * t;
137 let t4 = t3 * t;
138 let t5 = t4 * t;
139
140 let poly = B1 * t + B2 * t2 + B3 * t3 + B4 * t4 + B5 * t5;
141 let phi = std_normal_pdf(x_abs);
142
143 let result = 1.0 - phi * poly;
144
145 if sign < 0.0 {
146 1.0 - result
147 } else {
148 result
149 }
150}
151
/// Index of the codebook entry closest to `value` (absolute distance).
/// Ties keep the lower index; an empty or singleton codebook maps to 0.
#[inline]
fn nearest_centroid(value: f32, codebook: &[f32]) -> u8 {
    if codebook.len() <= 1 {
        return 0;
    }

    // (index, distance) of the best candidate seen so far.
    let mut winner = (0u8, (value - codebook[0]).abs());
    for (i, &c) in codebook.iter().enumerate().skip(1) {
        let d = (value - c).abs();
        // Strict `<` keeps the earliest index on ties.
        if d < winner.1 {
            winner = (i as u8, d);
        }
    }
    winner.0
}
175
176#[inline]
178fn codebook_for_coord(coord_idx: usize, config: &TurboQuantConfig) -> &'static [f32] {
179 match config.bit_width {
180 BitWidth::Two => &CODEBOOK_2BIT,
181 BitWidth::Three => &CODEBOOK_3BIT,
182 BitWidth::Four => &CODEBOOK_4BIT,
183 BitWidth::TwoPointFive => {
184 let boundary = config.head_dim / 4;
185 if coord_idx < boundary {
186 &CODEBOOK_3BIT } else {
188 &CODEBOOK_2BIT }
190 }
191 }
192}
193
194#[inline]
196fn bits_for_coord(coord_idx: usize, config: &TurboQuantConfig) -> usize {
197 match config.bit_width {
198 BitWidth::Two => 2,
199 BitWidth::Three => 3,
200 BitWidth::Four => 4,
201 BitWidth::TwoPointFive => {
202 if coord_idx < config.head_dim / 4 {
203 3
204 } else {
205 2
206 }
207 }
208 }
209}
210
211fn pack_indices(indices: &[u8], config: &TurboQuantConfig) -> Vec<u8> {
217 let total_bits: usize = (0..indices.len())
218 .map(|i| bits_for_coord(i, config))
219 .sum();
220 let num_bytes = (total_bits + 7) / 8;
221 let mut packed = vec![0u8; num_bytes];
222
223 let mut bit_offset = 0usize;
224 for (i, &idx) in indices.iter().enumerate() {
225 let nbits = bits_for_coord(i, config);
226 for b in (0..nbits).rev() {
228 let bit_val = (idx >> b) & 1;
229 let byte_pos = bit_offset / 8;
230 let bit_pos = 7 - (bit_offset % 8);
231 if byte_pos < packed.len() {
232 packed[byte_pos] |= bit_val << bit_pos;
233 }
234 bit_offset += 1;
235 }
236 }
237
238 packed
239}
240
241fn unpack_indices(packed: &[u8], config: &TurboQuantConfig) -> Vec<u8> {
243 let d = config.head_dim;
244 let mut indices = Vec::with_capacity(d);
245
246 let mut bit_offset = 0usize;
247 for i in 0..d {
248 let nbits = bits_for_coord(i, config);
249 let mut val = 0u8;
250 for _ in 0..nbits {
251 let byte_pos = bit_offset / 8;
252 let bit_pos = 7 - (bit_offset % 8);
253 let bit_val = if byte_pos < packed.len() {
254 (packed[byte_pos] >> bit_pos) & 1
255 } else {
256 0
257 };
258 val = (val << 1) | bit_val;
259 bit_offset += 1;
260 }
261 indices.push(val);
262 }
263
264 indices
265}
266
267pub fn turboquant_quantize(
285 x: &[f32],
286 config: &TurboQuantConfig,
287) -> crate::Result<(Vec<u8>, f32)> {
288 let d = config.head_dim;
289 if x.len() != d {
290 return Err(crate::MlxError::InvalidArgument(format!(
291 "Expected vector of length {d}, got {}",
292 x.len()
293 )));
294 }
295 if !d.is_power_of_two() {
296 return Err(crate::MlxError::InvalidArgument(format!(
297 "head_dim must be power of 2, got {d}"
298 )));
299 }
300
301 let mut rotated = x.to_vec();
303 fwht_inplace(&mut rotated)?;
304
305 let norm_sq: f32 = rotated.iter().map(|&v| v * v).sum();
307 let norm = norm_sq.sqrt();
308
309 if norm < 1e-30 {
310 let indices = vec![0u8; d];
312 let packed = pack_indices(&indices, config);
313 return Ok((packed, 0.0));
314 }
315
316 let inv_norm = 1.0 / norm;
318 for v in rotated.iter_mut() {
319 *v *= inv_norm;
320 }
321
322 let scale = (d as f32).sqrt();
326 let mut indices = Vec::with_capacity(d);
327 for (i, &v) in rotated.iter().enumerate() {
328 let scaled = v * scale;
329 let cb = codebook_for_coord(i, config);
330 indices.push(nearest_centroid(scaled, cb));
331 }
332
333 let packed = pack_indices(&indices, config);
335
336 Ok((packed, norm))
337}
338
339pub fn turboquant_dequantize(
355 packed: &[u8],
356 norm: f32,
357 config: &TurboQuantConfig,
358) -> crate::Result<Vec<f32>> {
359 let d = config.head_dim;
360 if !d.is_power_of_two() {
361 return Err(crate::MlxError::InvalidArgument(format!(
362 "head_dim must be power of 2, got {d}"
363 )));
364 }
365
366 let indices = unpack_indices(packed, config);
368
369 let inv_scale = 1.0 / (d as f32).sqrt();
371 let mut reconstructed = Vec::with_capacity(d);
372 for (i, &idx) in indices.iter().enumerate() {
373 let cb = codebook_for_coord(i, config);
374 let idx_usize = idx as usize;
375 let centroid = if idx_usize < cb.len() {
376 cb[idx_usize]
377 } else {
378 0.0 };
380 reconstructed.push(centroid * inv_scale * norm);
381 }
382
383 fwht_inplace(&mut reconstructed)?;
385
386 Ok(reconstructed)
387}
388
389pub fn compute_lloyd_max_codebook(num_levels: usize) -> Vec<f64> {
396 let mut boundaries = Vec::with_capacity(num_levels + 1);
398 boundaries.push(-10.0_f64); for i in 1..num_levels {
400 let p = i as f64 / num_levels as f64;
401 boundaries.push(quantile_normal(p));
402 }
403 boundaries.push(10.0_f64); let mut centroids = vec![0.0_f64; num_levels];
407 for i in 0..num_levels {
408 let a = boundaries[i];
409 let b = boundaries[i + 1];
410 let prob = std_normal_cdf(b) - std_normal_cdf(a);
411 if prob > 1e-30 {
412 centroids[i] = (std_normal_pdf(a) - std_normal_pdf(b)) / prob;
413 }
414 }
415
416 for _iter in 0..50_000 {
418 let old = centroids.clone();
419
420 boundaries[0] = -10.0;
422 for i in 1..num_levels {
423 boundaries[i] = (centroids[i - 1] + centroids[i]) / 2.0;
424 }
425 *boundaries.last_mut().unwrap_or(&mut 0.0) = 10.0;
426
427 for i in 0..num_levels {
429 let a = boundaries[i];
430 let b = boundaries[i + 1];
431 let prob = std_normal_cdf(b) - std_normal_cdf(a);
432 if prob > 1e-30 {
433 centroids[i] = (std_normal_pdf(a) - std_normal_pdf(b)) / prob;
434 }
435 }
436
437 let max_change = centroids
439 .iter()
440 .zip(old.iter())
441 .map(|(a, b)| (a - b).abs())
442 .fold(0.0_f64, f64::max);
443 if max_change < 1e-12 {
444 break;
445 }
446 }
447
448 centroids
449}
450
/// Inverse CDF (quantile function) of the standard normal, using Acklam's
/// rational approximation (relative error ~1.15e-9).
///
/// Inputs at or outside (0, 1) clamp to ±10, matching the ±infinity
/// stand-in used by the Lloyd-Max routines.
fn quantile_normal(p: f64) -> f64 {
    if p <= 0.0 {
        return -10.0;
    }
    if p >= 1.0 {
        return 10.0;
    }

    // Acklam's coefficients: A/B for the central region, C/D for the tails.
    const A: [f64; 6] = [
        -3.969683028665376e1,
        2.209460984245205e2,
        -2.759285104469687e2,
        1.383577518672690e2,
        -3.066479806614716e1,
        2.506628277459239e0,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e1,
        1.615858368580409e2,
        -1.556989798598866e2,
        6.680131188771972e1,
        -1.328068155288572e1,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-3,
        -3.223964580411365e-1,
        -2.400758277161838e0,
        -2.549732539343734e0,
        4.374664141464968e0,
        2.938163982698783e0,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-3,
        3.224671290700398e-1,
        2.445134137142996e0,
        3.754408661907416e0,
    ];

    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    // Tail rational approximation in q = sqrt(-2 ln p_tail); negated for
    // the upper tail by symmetry.
    let tail = |q: f64| -> f64 {
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    if p < P_LOW {
        tail((-2.0 * p.ln()).sqrt())
    } else if p <= P_HIGH {
        // Central region: rational function in r = (p - 1/2)^2.
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    }
}
511
/// Computes a `num_levels`-level Lloyd-Max codebook for a symmetric Beta-type
/// density proportional to (1 - x^2)^(alpha - 1) on [-1, 1], with
/// alpha = (dim - 1) / 2.
///
/// NOTE(review): that alpha matches the marginal distribution of one
/// coordinate of a uniformly random unit vector in `dim` dimensions —
/// confirm this is the intended distribution.
///
/// The CDF and per-cell conditional expectations are evaluated numerically
/// (trapezoidal rule on a 10k-point grid), then refined by Lloyd iteration.
pub fn compute_lloyd_max_beta_codebook(dim: usize, num_levels: usize) -> Vec<f64> {
    let alpha = (dim as f64 - 1.0) / 2.0;

    // Log normalizing constant, computed once and reused in the pdf closure.
    let log_norm = log_beta_norm_const(alpha);

    // Density on (-1, 1); defined as zero outside (and guarded against
    // ln(<=0) at the endpoints).
    let beta_pdf = |x: f64| -> f64 {
        if x <= -1.0 || x >= 1.0 {
            return 0.0;
        }
        let val = 1.0 - x * x;
        if val <= 0.0 {
            return 0.0;
        }
        (log_norm + (alpha - 1.0) * val.ln()).exp()
    };

    // Uniform grid over [-1, 1] for tabulating pdf and cdf.
    let n_grid = 10_000;
    let grid_lo = -1.0_f64;
    let grid_hi = 1.0_f64;
    let dx = (grid_hi - grid_lo) / n_grid as f64;

    // Tabulate the pdf, then accumulate the cdf by the trapezoidal rule.
    let mut cdf_vals = vec![0.0_f64; n_grid + 1];
    let mut pdf_vals = vec![0.0_f64; n_grid + 1];
    for i in 0..=n_grid {
        let x = grid_lo + i as f64 * dx;
        pdf_vals[i] = beta_pdf(x);
    }
    for i in 1..=n_grid {
        cdf_vals[i] = cdf_vals[i - 1] + 0.5 * (pdf_vals[i - 1] + pdf_vals[i]) * dx;
    }
    // Renormalize so the tabulated cdf ends at exactly 1 (compensates for
    // discretization error in the trapezoidal integral).
    let cdf_total = cdf_vals[n_grid];
    if cdf_total > 1e-30 {
        for v in cdf_vals.iter_mut() {
            *v /= cdf_total;
        }
        for v in pdf_vals.iter_mut() {
            *v /= cdf_total;
        }
    }

    // Linearly interpolated CDF lookup; clamps to 1 past the grid end.
    let interp_cdf = |x: f64| -> f64 {
        let frac = (x - grid_lo) / dx;
        let idx = frac as usize;
        if idx >= n_grid {
            return 1.0;
        }
        let t = frac - idx as f64;
        cdf_vals[idx] * (1.0 - t) + cdf_vals[idx + 1] * t
    };

    // E[X | a < X < b], via a 500-point trapezoidal integral of x * pdf(x)
    // over [a, b] divided by the cell probability. Near-empty cells fall
    // back to the interval midpoint.
    let conditional_expectation = |a: f64, b: f64| -> f64 {
        let prob = interp_cdf(b) - interp_cdf(a);
        if prob < 1e-30 {
            return (a + b) / 2.0;
        }

        let n_sub = 500;
        let sub_dx = (b - a) / n_sub as f64;
        let mut integral = 0.0_f64;
        for j in 0..=n_sub {
            let x = a + j as f64 * sub_dx;
            // Trapezoidal weights: half at the endpoints.
            let w = if j == 0 || j == n_sub { 0.5 } else { 1.0 };
            let frac = (x - grid_lo) / dx;
            let idx = frac as usize;
            let pdf_val = if idx >= n_grid {
                0.0
            } else {
                let t = frac - idx as f64;
                pdf_vals[idx] * (1.0 - t) + pdf_vals[idx + 1] * t
            };
            integral += w * x * pdf_val * sub_dx;
        }
        integral / prob
    };

    // Initial boundaries at the distribution's quantiles, found by
    // bisection on the interpolated CDF (100 iterations ≈ 2^-100 width).
    let mut boundaries = Vec::with_capacity(num_levels + 1);
    boundaries.push(-1.0_f64);
    for i in 1..num_levels {
        let target_p = i as f64 / num_levels as f64;
        let mut lo = -1.0_f64;
        let mut hi = 1.0_f64;
        for _ in 0..100 {
            let mid = (lo + hi) / 2.0;
            if interp_cdf(mid) < target_p {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        boundaries.push((lo + hi) / 2.0);
    }
    boundaries.push(1.0_f64);

    // Initial centroids: conditional means of the initial cells.
    let mut centroids = vec![0.0_f64; num_levels];
    for i in 0..num_levels {
        centroids[i] = conditional_expectation(boundaries[i], boundaries[i + 1]);
    }

    // Lloyd iteration: boundaries to centroid midpoints, centroids to
    // conditional means; stop on convergence or after 5000 rounds.
    for _iter in 0..5000 {
        let old = centroids.clone();

        boundaries[0] = -1.0;
        for i in 1..num_levels {
            boundaries[i] = (centroids[i - 1] + centroids[i]) / 2.0;
        }
        if let Some(last) = boundaries.last_mut() {
            *last = 1.0;
        }

        for i in 0..num_levels {
            centroids[i] = conditional_expectation(boundaries[i], boundaries[i + 1]);
        }

        let max_change = centroids
            .iter()
            .zip(old.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f64, f64::max);
        if max_change < 1e-10 {
            break;
        }
    }

    centroids
}
659
660fn log_beta_norm_const(alpha: f64) -> f64 {
662 ln_gamma(2.0 * alpha) - 2.0 * ln_gamma(alpha) - (2.0 * alpha - 1.0) * 2.0_f64.ln()
667}
668
/// Natural logarithm of the gamma function, via the Lanczos approximation
/// (g = 7, nine coefficients). Uses the reflection formula for x < 0.5.
fn ln_gamma(x: f64) -> f64 {
    const G: f64 = 7.0;
    const COEFF: [f64; 9] = [
        0.999_999_999_999_809_93,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_9,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_571_6e-6,
        1.505_632_735_149_311_6e-7,
    ];

    // Reflection: Gamma(x) Gamma(1-x) = pi / sin(pi x).
    if x < 0.5 {
        let pi = std::f64::consts::PI;
        return pi.ln() - (pi * x).sin().ln() - ln_gamma(1.0 - x);
    }

    let z = x - 1.0;
    let mut series = COEFF[0];
    for (k, &c) in COEFF.iter().enumerate().skip(1) {
        series += c / (z + k as f64);
    }

    let t = z + G + 0.5;
    0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
700
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Every hardcoded codebook must be odd-symmetric: c[i] == -c[n-1-i].
    #[test]
    fn test_codebook_symmetry() {
        for (name, cb) in [
            ("2-bit", &CODEBOOK_2BIT[..]),
            ("3-bit", &CODEBOOK_3BIT[..]),
            ("4-bit", &CODEBOOK_4BIT[..]),
        ] {
            let n = cb.len();
            // Pair each entry with its mirror from the other end.
            for (i, (&lo, &hi)) in cb.iter().zip(cb.iter().rev()).enumerate().take(n / 2) {
                let sum = lo + hi;
                assert!(
                    sum.abs() < 1e-5,
                    "{name} codebook not symmetric: c[{i}]={} + c[{}]={} = {sum}",
                    lo,
                    n - 1 - i,
                    hi
                );
            }
        }
    }

    /// The hardcoded f32 codebooks must agree with freshly computed
    /// Lloyd-Max centroids for N(0,1).
    #[test]
    fn test_codebook_values_match_lloyd_max() {
        for (bits, hardcoded) in [
            (2, &CODEBOOK_2BIT[..]),
            (3, &CODEBOOK_3BIT[..]),
            (4, &CODEBOOK_4BIT[..]),
        ] {
            let computed = compute_lloyd_max_codebook(1 << bits);
            assert_eq!(computed.len(), hardcoded.len());
            for (i, (&h, &c)) in hardcoded.iter().zip(computed.iter()).enumerate() {
                let diff = (h as f64 - c).abs();
                assert!(
                    diff < 1e-4,
                    "{bits}-bit codebook mismatch at {i}: hardcoded={h}, computed={c}, diff={diff}"
                );
            }
        }
    }

    /// The orthonormal FWHT is involutive: applying it twice must
    /// reproduce the input up to f32 rounding.
    #[test]
    fn test_fwht_roundtrip() {
        let original: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1 - 6.4).collect();
        let mut data = original.clone();
        fwht_inplace(&mut data).unwrap();
        fwht_inplace(&mut data).unwrap();
        for (i, (&a, &b)) in original.iter().zip(data.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-4,
                "FWHT roundtrip mismatch at {i}: {a} vs {b}"
            );
        }
    }
}