/// 2-bit scalar codebook: four reconstruction levels, listed in ascending
/// order and symmetric about zero.
///
/// The unit tests in this file compare these values against
/// `compute_lloyd_max_codebook(4)` with a tolerance of `1e-4`.
pub const CODEBOOK_2BIT: [f32; 4] = [
    -1.5104176, -0.4527800, 0.4527800, 1.5104176,
];
19
/// 3-bit scalar codebook: eight reconstruction levels, listed in ascending
/// order and symmetric about zero.
///
/// The unit tests in this file compare these values against
/// `compute_lloyd_max_codebook(8)` with a tolerance of `1e-4`.
pub const CODEBOOK_3BIT: [f32; 8] = [
    -2.1519457, -1.3439093, -0.7560053, -0.2450942,
    0.2450942, 0.7560053, 1.3439093, 2.1519457,
];
25
/// 4-bit scalar codebook: sixteen reconstruction levels, listed in ascending
/// order and symmetric about zero.
///
/// The unit tests in this file compare these values against
/// `compute_lloyd_max_codebook(16)` with a tolerance of `1e-4`.
pub const CODEBOOK_4BIT: [f32; 16] = [
    -2.7325896, -2.0690172, -1.6180464, -1.2562312,
    -0.9423405, -0.6567591, -0.3880483, -0.1283950,
    0.1283950, 0.3880483, 0.6567591, 0.9423405,
    1.2562312, 1.6180464, 2.0690172, 2.7325896,
];
33
/// 32-entry codebook used by `hb_centroid` and `hb_nearest_centroid` when
/// `bits == 5`.
///
/// Entries are listed in ascending order and are symmetric about zero.
// NOTE(review): presumably Lloyd-Max levels for a unit Gaussian like the
// 2/3/4-bit tables, but no test in this file checks that — confirm.
pub const CODEBOOK_HB_5BIT: [f32; 32] = [
    -3.2606790, -2.6910589, -2.3176743, -2.0286608,
    -1.7871646, -1.5761599, -1.3862739, -1.2117410,
    -1.0487242, -0.8945114, -0.7470884, -0.6048936,
    -0.4666676, -0.3313550, -0.1980377, -0.0658849,
    0.0658849, 0.1980377, 0.3313550, 0.4666676,
    0.6048936, 0.7470884, 0.8945114, 1.0487242,
    1.2117410, 1.3862739, 1.5761599, 1.7871646,
    2.0286608, 2.3176743, 2.6910589, 3.2606790,
];
56
/// 64-entry codebook used by `hb_centroid` and `hb_nearest_centroid` when
/// `bits == 6`.
///
/// Entries are listed in ascending order and are symmetric about zero.
// NOTE(review): presumably Lloyd-Max levels for a unit Gaussian like the
// 2/3/4-bit tables, but no test in this file checks that — confirm.
pub const CODEBOOK_HB_6BIT: [f32; 64] = [
    -3.6996161, -3.1907215, -2.8640626, -2.6161277,
    -2.4129324, -2.2388464, -2.0853192, -1.9471373,
    -1.8208742, -1.7041502, -1.5952401, -1.4928497,
    -1.3959804, -1.3038428, -1.2157998, -1.1313277,
    -1.0499889, -0.9714118, -0.8952766, -0.8213046,
    -0.7492492, -0.6788902, -0.6100285, -0.5424819,
    -0.4760822, -0.4106724, -0.3461048, -0.2822386,
    -0.2189392, -0.1560761, -0.0935225, -0.0311537,
    0.0311537, 0.0935225, 0.1560761, 0.2189392,
    0.2822386, 0.3461048, 0.4106724, 0.4760822,
    0.5424819, 0.6100285, 0.6788902, 0.7492492,
    0.8213046, 0.8952766, 0.9714118, 1.0499889,
    1.1313277, 1.2157998, 1.3038428, 1.3959804,
    1.4928497, 1.5952401, 1.7041502, 1.8208742,
    1.9471373, 2.0853192, 2.2388464, 2.4129324,
    2.6161277, 2.8640626, 3.1907215, 3.6996161,
];
77
/// 256-entry codebook used by `hb_centroid` and `hb_nearest_centroid` when
/// `bits == 8`.
///
/// Entries are listed in ascending order and are symmetric about zero, so
/// indices 127 and 128 hold the two values closest to zero.
// NOTE(review): presumably Lloyd-Max levels for a unit Gaussian like the
// 2/3/4-bit tables, but no test in this file checks that — confirm.
pub const CODEBOOK_HB_8BIT: [f32; 256] = [
    -5.0652659, -4.6836997, -4.4467193, -4.2715508,
    -4.1311907, -4.0132856, -3.9111092, -3.8205780,
    -3.7390194, -3.6645851, -3.5959415, -3.5320936,
    -3.4722785, -3.4158977, -3.3624729, -3.3116156,
    -3.2630056, -3.2163758, -3.1715011, -3.1281899,
    -3.0862780, -3.0456229, -3.0061011, -2.9676040,
    -2.9300362, -2.8933131, -2.8573596, -2.8221086,
    -2.7874999, -2.7534795, -2.7199985, -2.6870129,
    -2.6544825, -2.6223710, -2.5906452, -2.5592748,
    -2.5282321, -2.4974918, -2.4670306, -2.4368270,
    -2.4068614, -2.3771157, -2.3475732, -2.3182184,
    -2.2890372, -2.2600165, -2.2311440, -2.2024086,
    -2.1737998, -2.1453081, -2.1169245, -2.0886408,
    -2.0604493, -2.0323430, -2.0043154, -1.9763603,
    -1.9484722, -1.9206458, -1.8928763, -1.8651592,
    -1.8374904, -1.8098662, -1.7822828, -1.7547372,
    -1.7272261, -1.6997469, -1.6722970, -1.6448739,
    -1.6174755, -1.5900996, -1.5627445, -1.5354084,
    -1.5080897, -1.4807869, -1.4534986, -1.4262237,
    -1.3989610, -1.3717093, -1.3444678, -1.3172356,
    -1.2900118, -1.2627956, -1.2355865, -1.2083838,
    -1.1811868, -1.1539951, -1.1268081, -1.0996255,
    -1.0724469, -1.0452718, -1.0180999, -0.9909310,
    -0.9637647, -0.9366008, -0.9094390, -0.8822793,
    -0.8551212, -0.8279648, -0.8008098, -0.7736561,
    -0.7465035, -0.7193520, -0.6922014, -0.6650517,
    -0.6379027, -0.6107544, -0.5836067, -0.5564596,
    -0.5293129, -0.5021667, -0.4750208, -0.4478753,
    -0.4207301, -0.3935852, -0.3664405, -0.3392960,
    -0.3121517, -0.2850076, -0.2578636, -0.2307198,
    -0.2035761, -0.1764324, -0.1492888, -0.1221453,
    -0.0950019, -0.0678584, -0.0407151, -0.0135717,
    0.0135717, 0.0407151, 0.0678584, 0.0950019,
    0.1221453, 0.1492888, 0.1764324, 0.2035761,
    0.2307198, 0.2578636, 0.2850076, 0.3121517,
    0.3392960, 0.3664405, 0.3935852, 0.4207301,
    0.4478753, 0.4750208, 0.5021667, 0.5293129,
    0.5564596, 0.5836067, 0.6107544, 0.6379027,
    0.6650517, 0.6922014, 0.7193520, 0.7465035,
    0.7736561, 0.8008098, 0.8279648, 0.8551212,
    0.8822793, 0.9094390, 0.9366008, 0.9637647,
    0.9909310, 1.0180999, 1.0452718, 1.0724469,
    1.0996255, 1.1268081, 1.1539951, 1.1811868,
    1.2083838, 1.2355865, 1.2627956, 1.2900118,
    1.3172356, 1.3444678, 1.3717093, 1.3989610,
    1.4262237, 1.4534986, 1.4807869, 1.5080897,
    1.5354084, 1.5627445, 1.5900996, 1.6174755,
    1.6448739, 1.6722970, 1.6997469, 1.7272261,
    1.7547372, 1.7822828, 1.8098662, 1.8374904,
    1.8651592, 1.8928763, 1.9206458, 1.9484722,
    1.9763603, 2.0043154, 2.0323430, 2.0604493,
    2.0886408, 2.1169245, 2.1453081, 2.1737998,
    2.2024086, 2.2311440, 2.2600165, 2.2890372,
    2.3182184, 2.3475732, 2.3771157, 2.4068614,
    2.4368270, 2.4670306, 2.4974918, 2.5282321,
    2.5592748, 2.5906452, 2.6223710, 2.6544825,
    2.6870129, 2.7199985, 2.7534795, 2.7874999,
    2.8221086, 2.8573596, 2.8933131, 2.9300362,
    2.9676040, 3.0061011, 3.0456229, 3.0862780,
    3.1281899, 3.1715011, 3.2163758, 3.2630056,
    3.3116156, 3.3624729, 3.4158977, 3.4722785,
    3.5320936, 3.5959415, 3.6645851, 3.7390194,
    3.8205780, 3.9111092, 4.0132856, 4.1311907,
    4.2715508, 4.4467193, 4.6836997, 5.0652659,
];
148
149#[inline]
158pub fn hb_centroid(idx: u8, bits: u32) -> f32 {
159 match bits {
160 5 => CODEBOOK_HB_5BIT[(idx & 0x1F) as usize],
161 6 => CODEBOOK_HB_6BIT[(idx & 0x3F) as usize],
162 8 => CODEBOOK_HB_8BIT[idx as usize],
163 _ => 0.0,
164 }
165}
166
/// Packed sign mask covering 256 elements (32 bytes x 8 bits).
///
/// Consumed by `apply_d1_sign_mask_inplace`: element `j` is negated when bit
/// `j & 7` (least-significant bit first) of byte `j >> 3` is set.
/// `turboquant_hb_encode_d256` applies this mask before the Hadamard
/// transform. The bytes equal the first 32 bytes of `TBQ_SIGNS_512`.
pub const TBQ_SIGNS_256: [u8; 32] = [
    0xa7, 0x3b, 0x91, 0xf4, 0x6d, 0xc2, 0x58, 0x0e,
    0xb3, 0x7f, 0x24, 0xd6, 0x89, 0x45, 0xea, 0x1c,
    0x63, 0xaf, 0xd8, 0x52, 0x97, 0x0b, 0xe1, 0x3d,
    0x76, 0xc4, 0x19, 0xfe, 0x4a, 0x85, 0x2c, 0xdb,
];
181
/// Packed sign mask covering 512 elements (64 bytes x 8 bits), in the same
/// one-bit-per-element layout that `apply_d1_sign_mask_inplace` reads.
///
/// The first 32 bytes are identical to `TBQ_SIGNS_256`; a unit test in this
/// file asserts that prefix relationship.
// NOTE(review): no function in this file consumes this table — presumably it
// serves a head_dim = 512 path defined elsewhere; confirm.
pub const TBQ_SIGNS_512: [u8; 64] = [
    0xa7, 0x3b, 0x91, 0xf4, 0x6d, 0xc2, 0x58, 0x0e,
    0xb3, 0x7f, 0x24, 0xd6, 0x89, 0x45, 0xea, 0x1c,
    0x63, 0xaf, 0xd8, 0x52, 0x97, 0x0b, 0xe1, 0x3d,
    0x76, 0xc4, 0x19, 0xfe, 0x4a, 0x85, 0x2c, 0xdb,
    0xd3, 0x4e, 0xa8, 0x17, 0x9c, 0x5b, 0xe6, 0x31,
    0x72, 0xb9, 0x0d, 0xf5, 0x43, 0x8a, 0x6e, 0xc7,
    0x58, 0x2f, 0x94, 0xe1, 0xb6, 0x3d, 0x0a, 0x7c,
    0xc5, 0x61, 0xd8, 0x4f, 0xa3, 0x97, 0x1e, 0x85,
];
197
/// Negates, in place, every element of `x` whose bit is set in the packed
/// sign mask `signs`.
///
/// Element `j` is controlled by bit `j & 7` (least-significant bit first) of
/// byte `j >> 3`. Applying the same mask twice restores the original values.
///
/// # Panics
///
/// Panics if `signs` holds fewer than `ceil(x.len() / 8)` bytes. The length is
/// validated before any element is touched, so `x` is never left partially
/// sign-flipped when the mask is too short.
#[inline]
pub fn apply_d1_sign_mask_inplace(x: &mut [f32], signs: &[u8]) {
    let needed_bytes = (x.len() + 7) / 8;
    assert!(
        signs.len() >= needed_bytes,
        "sign mask too short: {} element(s) need {} byte(s), got {}",
        x.len(),
        needed_bytes,
        signs.len()
    );

    // Each mask byte governs one run of up to eight consecutive elements; the
    // final run may be shorter when `x.len()` is not a multiple of eight.
    for (lanes, &byte) in x.chunks_mut(8).zip(signs) {
        for (k, v) in lanes.iter_mut().enumerate() {
            if (byte >> k) & 1 == 1 {
                *v = -*v;
            }
        }
    }
}
212
213pub fn turboquant_hb_encode_d256(x: &[f32], bits: u32) -> Result<(Vec<u8>, f32), crate::MlxError> {
232 if x.len() != 256 {
233 return Err(crate::MlxError::InvalidArgument(format!(
234 "turboquant_hb_encode_d256 expects head_dim=256, got {}",
235 x.len()
236 )));
237 }
238 if !matches!(bits, 5 | 6 | 8) {
239 return Err(crate::MlxError::InvalidArgument(format!(
240 "turboquant_hb_encode_d256 bits must be 5, 6, or 8, got {bits}"
241 )));
242 }
243
244 let mut elems = x.to_vec();
246 apply_d1_sign_mask_inplace(&mut elems, &TBQ_SIGNS_256);
247
248 fwht_inplace(&mut elems)?;
250
251 let norm_sq: f32 = elems.iter().map(|&v| v * v).sum();
253 let norm = norm_sq.sqrt();
254
255 let scale: f32 = if norm > 1.0e-10_f32 {
260 (1.0_f32 / norm) * (256.0_f32).sqrt()
261 } else {
262 0.0_f32
263 };
264 for v in elems.iter_mut() {
265 *v *= scale;
266 }
267
268 let mut packed = Vec::with_capacity(256);
270 for &v in elems.iter() {
271 packed.push(hb_nearest_centroid(v, bits));
272 }
273
274 Ok((packed, norm))
275}
276
277pub fn hb_nearest_centroid(value: f32, bits: u32) -> u8 {
287 let cb: &[f32] = match bits {
288 5 => &CODEBOOK_HB_5BIT,
289 6 => &CODEBOOK_HB_6BIT,
290 8 => &CODEBOOK_HB_8BIT,
291 _ => return 0u8,
292 };
293 let mut best_idx: u32 = 0;
294 let mut best_dist: f32 = (value - cb[0]).abs();
295 for (i, &c) in cb.iter().enumerate().skip(1) {
296 let dist = (value - c).abs();
297 if dist < best_dist {
298 best_dist = dist;
299 best_idx = i as u32;
300 }
301 }
302 best_idx as u8
303}
304
/// Bit-allocation scheme selected through `TurboQuantConfig::bit_width`.
///
/// `codebook_for_coord` and `bits_for_coord` in this file translate each
/// variant into a per-coordinate codebook and bit count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BitWidth {
    /// Every coordinate is coded with 2 bits using `CODEBOOK_2BIT`.
    Two,
    /// Every coordinate is coded with 3 bits using `CODEBOOK_3BIT`.
    Three,
    /// Every coordinate is coded with 4 bits using `CODEBOOK_4BIT`.
    Four,
    /// Mixed allocation: coordinates with index below `head_dim / 4` use
    /// 3 bits (`CODEBOOK_3BIT`), all remaining coordinates use 2 bits
    /// (`CODEBOOK_2BIT`).
    // NOTE(review): as implemented this averages 2.25 bits per coordinate
    // (1/4 at 3 bits, 3/4 at 2 bits), not 2.5 — confirm the intended split.
    TwoPointFive,
}
319
/// Parameters shared by `turboquant_quantize` and `turboquant_dequantize`.
#[derive(Debug, Clone)]
pub struct TurboQuantConfig {
    /// Bit-allocation scheme; decides the codebook and bit count used for
    /// each coordinate.
    pub bit_width: BitWidth,
    /// Length of the vectors being quantised. `turboquant_quantize` and
    /// `turboquant_dequantize` reject values that are not a power of two.
    pub head_dim: usize,
}
328
329pub fn fwht_inplace(x: &mut [f32]) -> crate::Result<()> {
342 let n = x.len();
343 if n == 0 || !n.is_power_of_two() {
344 return Err(crate::MlxError::InvalidArgument(format!(
345 "FWHT requires power-of-two length, got {n}"
346 )));
347 }
348
349 let mut h = 1;
350 while h < n {
351 let step = h * 2;
352 let mut i = 0;
353 while i < n {
354 for j in i..i + h {
355 let a = x[j];
356 let b = x[j + h];
357 x[j] = a + b;
358 x[j + h] = a - b;
359 }
360 i += step;
361 }
362 h *= 2;
363 }
364
365 let scale = 1.0 / (n as f32).sqrt();
367 for v in x.iter_mut() {
368 *v *= scale;
369 }
370
371 Ok(())
372}
373
/// Probability density of the standard normal distribution at `x`:
/// `exp(-x^2 / 2) / sqrt(2 * pi)`.
#[inline]
fn std_normal_pdf(x: f64) -> f64 {
    // 1 / sqrt(2 * pi)
    const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
    let exponent = -0.5 * x * x;
    INV_SQRT_2PI * exponent.exp()
}
382
383#[inline]
386fn std_normal_cdf(x: f64) -> f64 {
387 if x < -8.0 {
388 return 0.0;
389 }
390 if x > 8.0 {
391 return 1.0;
392 }
393
394 let sign = if x >= 0.0 { 1.0 } else { -1.0 };
395 let x_abs = x.abs();
396
397 const P: f64 = 0.231_641_9;
399 const B1: f64 = 0.319_381_530;
400 const B2: f64 = -0.356_563_782;
401 const B3: f64 = 1.781_477_937;
402 const B4: f64 = -1.821_255_978;
403 const B5: f64 = 1.330_274_429;
404
405 let t = 1.0 / (1.0 + P * x_abs);
406 let t2 = t * t;
407 let t3 = t2 * t;
408 let t4 = t3 * t;
409 let t5 = t4 * t;
410
411 let poly = B1 * t + B2 * t2 + B3 * t3 + B4 * t4 + B5 * t5;
412 let phi = std_normal_pdf(x_abs);
413
414 let result = 1.0 - phi * poly;
415
416 if sign < 0.0 {
417 1.0 - result
418 } else {
419 result
420 }
421}
422
/// Returns the index of the `codebook` entry closest to `value`, measured by
/// absolute difference.
///
/// An empty or single-entry codebook returns `0`. The scan is linear and an
/// exact tie resolves to the lower index. The index is narrowed to `u8`.
#[inline]
fn nearest_centroid(value: f32, codebook: &[f32]) -> u8 {
    if codebook.len() <= 1 {
        return 0;
    }

    let first_dist = (value - codebook[0]).abs();
    let (winner, _) = codebook.iter().enumerate().skip(1).fold(
        (0usize, first_dist),
        |(best_i, best_d), (i, &c)| {
            let d = (value - c).abs();
            // Strict comparison keeps the earliest index on ties.
            if d < best_d {
                (i, d)
            } else {
                (best_i, best_d)
            }
        },
    );
    winner as u8
}
446
447#[inline]
449fn codebook_for_coord(coord_idx: usize, config: &TurboQuantConfig) -> &'static [f32] {
450 match config.bit_width {
451 BitWidth::Two => &CODEBOOK_2BIT,
452 BitWidth::Three => &CODEBOOK_3BIT,
453 BitWidth::Four => &CODEBOOK_4BIT,
454 BitWidth::TwoPointFive => {
455 let boundary = config.head_dim / 4;
456 if coord_idx < boundary {
457 &CODEBOOK_3BIT } else {
459 &CODEBOOK_2BIT }
461 }
462 }
463}
464
465#[inline]
467fn bits_for_coord(coord_idx: usize, config: &TurboQuantConfig) -> usize {
468 match config.bit_width {
469 BitWidth::Two => 2,
470 BitWidth::Three => 3,
471 BitWidth::Four => 4,
472 BitWidth::TwoPointFive => {
473 if coord_idx < config.head_dim / 4 {
474 3
475 } else {
476 2
477 }
478 }
479 }
480}
481
482fn pack_indices(indices: &[u8], config: &TurboQuantConfig) -> Vec<u8> {
488 let total_bits: usize = (0..indices.len())
489 .map(|i| bits_for_coord(i, config))
490 .sum();
491 let num_bytes = (total_bits + 7) / 8;
492 let mut packed = vec![0u8; num_bytes];
493
494 let mut bit_offset = 0usize;
495 for (i, &idx) in indices.iter().enumerate() {
496 let nbits = bits_for_coord(i, config);
497 for b in (0..nbits).rev() {
499 let bit_val = (idx >> b) & 1;
500 let byte_pos = bit_offset / 8;
501 let bit_pos = 7 - (bit_offset % 8);
502 if byte_pos < packed.len() {
503 packed[byte_pos] |= bit_val << bit_pos;
504 }
505 bit_offset += 1;
506 }
507 }
508
509 packed
510}
511
512fn unpack_indices(packed: &[u8], config: &TurboQuantConfig) -> Vec<u8> {
514 let d = config.head_dim;
515 let mut indices = Vec::with_capacity(d);
516
517 let mut bit_offset = 0usize;
518 for i in 0..d {
519 let nbits = bits_for_coord(i, config);
520 let mut val = 0u8;
521 for _ in 0..nbits {
522 let byte_pos = bit_offset / 8;
523 let bit_pos = 7 - (bit_offset % 8);
524 let bit_val = if byte_pos < packed.len() {
525 (packed[byte_pos] >> bit_pos) & 1
526 } else {
527 0
528 };
529 val = (val << 1) | bit_val;
530 bit_offset += 1;
531 }
532 indices.push(val);
533 }
534
535 indices
536}
537
538pub fn turboquant_quantize(
556 x: &[f32],
557 config: &TurboQuantConfig,
558) -> crate::Result<(Vec<u8>, f32)> {
559 let d = config.head_dim;
560 if x.len() != d {
561 return Err(crate::MlxError::InvalidArgument(format!(
562 "Expected vector of length {d}, got {}",
563 x.len()
564 )));
565 }
566 if !d.is_power_of_two() {
567 return Err(crate::MlxError::InvalidArgument(format!(
568 "head_dim must be power of 2, got {d}"
569 )));
570 }
571
572 let mut rotated = x.to_vec();
574 fwht_inplace(&mut rotated)?;
575
576 let norm_sq: f32 = rotated.iter().map(|&v| v * v).sum();
578 let norm = norm_sq.sqrt();
579
580 if norm < 1e-30 {
581 let indices = vec![0u8; d];
583 let packed = pack_indices(&indices, config);
584 return Ok((packed, 0.0));
585 }
586
587 let inv_norm = 1.0 / norm;
589 for v in rotated.iter_mut() {
590 *v *= inv_norm;
591 }
592
593 let scale = (d as f32).sqrt();
597 let mut indices = Vec::with_capacity(d);
598 for (i, &v) in rotated.iter().enumerate() {
599 let scaled = v * scale;
600 let cb = codebook_for_coord(i, config);
601 indices.push(nearest_centroid(scaled, cb));
602 }
603
604 let packed = pack_indices(&indices, config);
606
607 Ok((packed, norm))
608}
609
610pub fn turboquant_dequantize(
626 packed: &[u8],
627 norm: f32,
628 config: &TurboQuantConfig,
629) -> crate::Result<Vec<f32>> {
630 let d = config.head_dim;
631 if !d.is_power_of_two() {
632 return Err(crate::MlxError::InvalidArgument(format!(
633 "head_dim must be power of 2, got {d}"
634 )));
635 }
636
637 let indices = unpack_indices(packed, config);
639
640 let inv_scale = 1.0 / (d as f32).sqrt();
642 let mut reconstructed = Vec::with_capacity(d);
643 for (i, &idx) in indices.iter().enumerate() {
644 let cb = codebook_for_coord(i, config);
645 let idx_usize = idx as usize;
646 let centroid = if idx_usize < cb.len() {
647 cb[idx_usize]
648 } else {
649 0.0 };
651 reconstructed.push(centroid * inv_scale * norm);
652 }
653
654 fwht_inplace(&mut reconstructed)?;
656
657 Ok(reconstructed)
658}
659
660pub fn compute_lloyd_max_codebook(num_levels: usize) -> Vec<f64> {
667 let mut boundaries = Vec::with_capacity(num_levels + 1);
669 boundaries.push(-10.0_f64); for i in 1..num_levels {
671 let p = i as f64 / num_levels as f64;
672 boundaries.push(quantile_normal(p));
673 }
674 boundaries.push(10.0_f64); let mut centroids = vec![0.0_f64; num_levels];
678 for i in 0..num_levels {
679 let a = boundaries[i];
680 let b = boundaries[i + 1];
681 let prob = std_normal_cdf(b) - std_normal_cdf(a);
682 if prob > 1e-30 {
683 centroids[i] = (std_normal_pdf(a) - std_normal_pdf(b)) / prob;
684 }
685 }
686
687 for _iter in 0..50_000 {
689 let old = centroids.clone();
690
691 boundaries[0] = -10.0;
693 for i in 1..num_levels {
694 boundaries[i] = (centroids[i - 1] + centroids[i]) / 2.0;
695 }
696 *boundaries.last_mut().unwrap_or(&mut 0.0) = 10.0;
697
698 for i in 0..num_levels {
700 let a = boundaries[i];
701 let b = boundaries[i + 1];
702 let prob = std_normal_cdf(b) - std_normal_cdf(a);
703 if prob > 1e-30 {
704 centroids[i] = (std_normal_pdf(a) - std_normal_pdf(b)) / prob;
705 }
706 }
707
708 let max_change = centroids
710 .iter()
711 .zip(old.iter())
712 .map(|(a, b)| (a - b).abs())
713 .fold(0.0_f64, f64::max);
714 if max_change < 1e-12 {
715 break;
716 }
717 }
718
719 centroids
720}
721
/// Approximate inverse CDF (quantile function) of the standard normal.
///
/// `p <= 0` returns `-10.0` and `p >= 1` returns `10.0`. Otherwise a rational
/// approximation is evaluated: one form for the central region
/// `[0.02425, 0.97575]` and another, mirrored, for the two tails.
// NOTE(review): the coefficients match Peter Acklam's published inverse-normal
// approximation (relative error about 1.15e-9) — confirm if precision matters.
fn quantile_normal(p: f64) -> f64 {
    if p <= 0.0 {
        return -10.0;
    }
    if p >= 1.0 {
        return 10.0;
    }

    const A: [f64; 6] = [
        -3.969683028665376e1,
        2.209460984245205e2,
        -2.759285104469687e2,
        1.383577518672690e2,
        -3.066479806614716e1,
        2.506628277459239e0,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e1,
        1.615858368580409e2,
        -1.556989798598866e2,
        6.680131188771972e1,
        -1.328068155288572e1,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-3,
        -3.223964580411365e-1,
        -2.400758277161838e0,
        -2.549732539343734e0,
        4.374664141464968e0,
        2.938163982698783e0,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-3,
        3.224671290700398e-1,
        2.445134137142996e0,
        3.754408661907416e0,
    ];

    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    /// Horner evaluation: `((c[0] * t + c[1]) * t + ...) + c[last]`.
    fn horner(coeffs: &[f64], t: f64) -> f64 {
        coeffs[1..].iter().fold(coeffs[0], |acc, &c| acc * t + c)
    }

    // Shared rational form for both tails; the denominator polynomial has an
    // implicit trailing constant term of 1.
    let tail = |q: f64| -> f64 { horner(&C, q) / (horner(&D, q) * q + 1.0) };

    if p < P_LOW {
        tail((-2.0 * p.ln()).sqrt())
    } else if p <= P_HIGH {
        let q = p - 0.5;
        let r = q * q;
        horner(&A, r) * q / (horner(&B, r) * r + 1.0)
    } else {
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    }
}
782
/// Computes a `num_levels`-entry Lloyd-Max codebook for the density
/// `f(x) = C * (1 - x^2)^(alpha - 1)` on `(-1, 1)`, with
/// `alpha = (dim - 1) / 2` and `ln C = log_beta_norm_const(alpha)`.
///
/// The density and its CDF are tabulated on a uniform 10 000-interval grid;
/// all later CDF and PDF evaluations are linear interpolations of those
/// tables. Boundaries start at equal-probability quantiles and are refined by
/// Lloyd iterations until no centroid moves by `1e-10` or more, or 5000
/// rounds have run.
///
/// Returns the centroids in ascending cell order, on the `[-1, 1]` scale of
/// the density (no further rescaling is applied here).
// NOTE(review): presumably this models one coordinate of a uniformly random
// unit vector in `dim` dimensions — confirm against callers.
// NOTE(review): `dim` and `num_levels` are not validated; behaviour for very
// small `dim` (alpha <= 0) has not been checked.
pub fn compute_lloyd_max_beta_codebook(dim: usize, num_levels: usize) -> Vec<f64> {
    let alpha = (dim as f64 - 1.0) / 2.0;

    let log_norm = log_beta_norm_const(alpha);

    // Density evaluated in log space; zero outside the open interval (-1, 1).
    let beta_pdf = |x: f64| -> f64 {
        if x <= -1.0 || x >= 1.0 {
            return 0.0;
        }
        let val = 1.0 - x * x;
        if val <= 0.0 {
            return 0.0;
        }
        (log_norm + (alpha - 1.0) * val.ln()).exp()
    };

    // Uniform grid over [-1, 1] with `n_grid` intervals (`n_grid + 1` nodes).
    let n_grid = 10_000;
    let grid_lo = -1.0_f64;
    let grid_hi = 1.0_f64;
    let dx = (grid_hi - grid_lo) / n_grid as f64;

    let mut cdf_vals = vec![0.0_f64; n_grid + 1];
    let mut pdf_vals = vec![0.0_f64; n_grid + 1];
    for i in 0..=n_grid {
        let x = grid_lo + i as f64 * dx;
        pdf_vals[i] = beta_pdf(x);
    }
    // Cumulative trapezoid rule.
    for i in 1..=n_grid {
        cdf_vals[i] = cdf_vals[i - 1] + 0.5 * (pdf_vals[i - 1] + pdf_vals[i]) * dx;
    }
    // Renormalise both tables so the tabulated CDF ends at exactly 1.
    let cdf_total = cdf_vals[n_grid];
    if cdf_total > 1e-30 {
        for v in cdf_vals.iter_mut() {
            *v /= cdf_total;
        }
        for v in pdf_vals.iter_mut() {
            *v /= cdf_total;
        }
    }

    // Linearly interpolated CDF; returns 1.0 at or beyond the last grid node.
    let interp_cdf = |x: f64| -> f64 {
        let frac = (x - grid_lo) / dx;
        let idx = frac as usize;
        if idx >= n_grid {
            return 1.0;
        }
        let t = frac - idx as f64;
        cdf_vals[idx] * (1.0 - t) + cdf_vals[idx + 1] * t
    };

    // Approximates E[X | a < X < b]: a 500-interval trapezoid integral of
    // x * pdf(x) over [a, b], divided by the interpolated probability of the
    // cell. A cell with negligible mass falls back to its midpoint.
    let conditional_expectation = |a: f64, b: f64| -> f64 {
        let prob = interp_cdf(b) - interp_cdf(a);
        if prob < 1e-30 {
            return (a + b) / 2.0;
        }

        let n_sub = 500;
        let sub_dx = (b - a) / n_sub as f64;
        let mut integral = 0.0_f64;
        for j in 0..=n_sub {
            let x = a + j as f64 * sub_dx;
            // Trapezoid weights: half weight on the two end points.
            let w = if j == 0 || j == n_sub { 0.5 } else { 1.0 };
            let frac = (x - grid_lo) / dx;
            let idx = frac as usize;
            // Linearly interpolated PDF; zero at or beyond the last grid node.
            let pdf_val = if idx >= n_grid {
                0.0
            } else {
                let t = frac - idx as f64;
                pdf_vals[idx] * (1.0 - t) + pdf_vals[idx + 1] * t
            };
            integral += w * x * pdf_val * sub_dx;
        }
        integral / prob
    };

    // Initial boundaries: equal-probability quantiles, each located with 100
    // bisection steps on the interpolated CDF.
    let mut boundaries = Vec::with_capacity(num_levels + 1);
    boundaries.push(-1.0_f64);
    for i in 1..num_levels {
        let target_p = i as f64 / num_levels as f64;
        let mut lo = -1.0_f64;
        let mut hi = 1.0_f64;
        for _ in 0..100 {
            let mid = (lo + hi) / 2.0;
            if interp_cdf(mid) < target_p {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        boundaries.push((lo + hi) / 2.0);
    }
    boundaries.push(1.0_f64);

    let mut centroids = vec![0.0_f64; num_levels];
    for i in 0..num_levels {
        centroids[i] = conditional_expectation(boundaries[i], boundaries[i + 1]);
    }

    // Lloyd iterations: midpoint boundaries, then conditional-mean centroids.
    for _iter in 0..5000 {
        let old = centroids.clone();

        boundaries[0] = -1.0;
        for i in 1..num_levels {
            boundaries[i] = (centroids[i - 1] + centroids[i]) / 2.0;
        }
        if let Some(last) = boundaries.last_mut() {
            *last = 1.0;
        }

        for i in 0..num_levels {
            centroids[i] = conditional_expectation(boundaries[i], boundaries[i + 1]);
        }

        // Converged once the largest centroid movement drops below 1e-10.
        let max_change = centroids
            .iter()
            .zip(old.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f64, f64::max);
        if max_change < 1e-10 {
            break;
        }
    }

    centroids
}
930
931fn log_beta_norm_const(alpha: f64) -> f64 {
933 ln_gamma(2.0 * alpha) - 2.0 * ln_gamma(alpha) - (2.0 * alpha - 1.0) * 2.0_f64.ln()
938}
939
/// Natural logarithm of the gamma function, via a Lanczos series with
/// `g = 7` and nine coefficients.
///
/// Arguments below `0.5` are handled with the reflection formula
/// `ln Gamma(x) = ln(pi) - ln(sin(pi * x)) - ln Gamma(1 - x)`.
fn ln_gamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    const G: f64 = 7.0;
    const COEFF: [f64; 9] = [
        0.999_999_999_999_809_93,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_9,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_571_6e-6,
        1.505_632_735_149_311_6e-7,
    ];

    if x < 0.5 {
        return PI.ln() - (PI * x).sin().ln() - ln_gamma(1.0 - x);
    }

    // The series is written in terms of z = x - 1.
    let z = x - 1.0;
    let series = COEFF
        .iter()
        .enumerate()
        .skip(1)
        .fold(COEFF[0], |acc, (i, &c)| acc + c / (z + i as f64));

    let base = z + G + 0.5;
    0.5 * (2.0 * PI).ln() + (z + 0.5) * base.ln() - base + series.ln()
}
971
// Unit tests for the codebooks, the Hadamard transform, the sign mask and the
// d=256 encoder.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Each of the 2/3/4-bit codebooks must satisfy `c[i] == -c[n - 1 - i]`
    /// to within `1e-5`.
    #[test]
    fn test_codebook_symmetry() {
        for (name, cb) in [
            ("2-bit", &CODEBOOK_2BIT[..]),
            ("3-bit", &CODEBOOK_3BIT[..]),
            ("4-bit", &CODEBOOK_4BIT[..]),
        ] {
            let n = cb.len();
            for i in 0..n / 2 {
                let sum = cb[i] + cb[n - 1 - i];
                assert!(
                    sum.abs() < 1e-5,
                    "{name} codebook not symmetric: c[{i}]={} + c[{}]={} = {sum}",
                    cb[i],
                    n - 1 - i,
                    cb[n - 1 - i]
                );
            }
        }
    }

    /// The hard-coded 2/3/4-bit tables must agree with
    /// `compute_lloyd_max_codebook(1 << bits)` to within `1e-4`.
    #[test]
    fn test_codebook_values_match_lloyd_max() {
        for (bits, hardcoded) in [
            (2, &CODEBOOK_2BIT[..]),
            (3, &CODEBOOK_3BIT[..]),
            (4, &CODEBOOK_4BIT[..]),
        ] {
            let computed = compute_lloyd_max_codebook(1 << bits);
            assert_eq!(computed.len(), hardcoded.len());
            for (i, (&h, &c)) in hardcoded.iter().zip(computed.iter()).enumerate() {
                let diff = (h as f64 - c).abs();
                assert!(
                    diff < 1e-4,
                    "{bits}-bit codebook mismatch at {i}: hardcoded={h}, computed={c}, diff={diff}"
                );
            }
        }
    }

    /// Applying `fwht_inplace` twice must reproduce the input to within
    /// `1e-4`.
    #[test]
    fn test_fwht_roundtrip() {
        let original: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1 - 6.4).collect();
        let mut data = original.clone();
        fwht_inplace(&mut data).unwrap();
        fwht_inplace(&mut data).unwrap();
        for (i, (&a, &b)) in original.iter().zip(data.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-4,
                "FWHT roundtrip mismatch at {i}: {a} vs {b}"
            );
        }
    }

    /// Generates `n` reproducible pseudo-Gaussian samples from `seed`.
    ///
    /// A 64-bit linear congruential generator supplies uniform values (the
    /// high 32 bits of each state), which are converted to normal deviates
    /// in pairs with the Box–Muller transform.
    fn deterministic_gaussian_test(seed: u64, n: usize) -> Vec<f32> {
        let mut state = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        let next_u32 = |s: &mut u64| -> u32 {
            *s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            (*s >> 32) as u32
        };
        // Maps the 32-bit draw to a float; the +0.5 offset avoids exactly 0
        // before the conversion to f32.
        let next_f32 = |s: &mut u64| -> f32 {
            let bits = next_u32(s);
            ((bits as f64 + 0.5) / (u32::MAX as f64 + 1.0)) as f32
        };
        let mut out = Vec::with_capacity(n);
        while out.len() < n {
            // Clamp u1 away from 0 and 1 so that ln(u1) stays finite.
            let u1 = next_f32(&mut state).max(1e-7).min(1.0 - 1e-7);
            let u2 = next_f32(&mut state);
            let r = (-2.0_f32 * u1.ln()).sqrt();
            let theta = 2.0_f32 * std::f32::consts::PI * u2;
            out.push(r * theta.cos());
            if out.len() < n {
                out.push(r * theta.sin());
            }
        }
        out
    }

    /// Reference decoder for `turboquant_hb_encode_d256`: codebook lookup
    /// scaled by `norm / sqrt(256)`, then `fwht_inplace`, then the
    /// `TBQ_SIGNS_256` sign mask.
    // NOTE(review): the name suggests this mirrors a decode kernel defined
    // elsewhere; that kernel is not visible in this file.
    fn decode_d256_via_kernel_formula(packed: &[u8], norm: f32, bits: u32) -> Vec<f32> {
        let inv_sqrt_dk = 1.0_f32 / (256.0_f32).sqrt();
        let mut decoded: Vec<f32> = packed.iter()
            .map(|&idx| hb_centroid(idx, bits) * norm * inv_sqrt_dk)
            .collect();
        fwht_inplace(&mut decoded).expect("fwht ok");
        apply_d1_sign_mask_inplace(&mut decoded, &TBQ_SIGNS_256);
        decoded
    }

    /// Root-mean-square error of `b` relative to `a`:
    /// `sqrt(sum((a - b)^2) / sum(a^2))`; returns 0 when `a` is (near) zero.
    fn nrmse(a: &[f32], b: &[f32]) -> f32 {
        let mut sse: f64 = 0.0;
        let mut sse_a: f64 = 0.0;
        for (&av, &bv) in a.iter().zip(b.iter()) {
            let d = (av - bv) as f64;
            sse += d * d;
            sse_a += (av as f64) * (av as f64);
        }
        if sse_a < 1e-30 {
            return 0.0;
        }
        (sse / sse_a).sqrt() as f32
    }

    /// Cosine similarity of `a` and `b`, accumulated in f64; returns 1 when
    /// either vector is (near) zero.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let mut dot: f64 = 0.0;
        let mut na: f64 = 0.0;
        let mut nb: f64 = 0.0;
        for (&av, &bv) in a.iter().zip(b.iter()) {
            dot += (av as f64) * (bv as f64);
            na += (av as f64) * (av as f64);
            nb += (bv as f64) * (bv as f64);
        }
        if na < 1e-30 || nb < 1e-30 {
            return 1.0;
        }
        (dot / (na.sqrt() * nb.sqrt())) as f32
    }

    /// 8-bit encode/decode of a Gaussian vector must reach cosine >= 0.998
    /// and NRMSE <= 0.07.
    #[test]
    fn hb_encoder_d256_roundtrip_8bit_meets_gate_a() {
        let x = deterministic_gaussian_test(0xC25EED, 256);
        let (packed, norm) = turboquant_hb_encode_d256(&x, 8).expect("encode");
        let recon = decode_d256_via_kernel_formula(&packed, norm, 8);
        let cos = cosine_similarity(&x, &recon);
        let nrmse_v = nrmse(&x, &recon);
        assert!(cos >= 0.998, "8-bit roundtrip cosine {cos} < 0.998");
        assert!(nrmse_v <= 0.07, "8-bit roundtrip NRMSE {nrmse_v} > 0.07");
    }

    /// 5-bit encode/decode of the same vector must reach cosine >= 0.985.
    #[test]
    fn hb_encoder_d256_roundtrip_5bit_within_band() {
        let x = deterministic_gaussian_test(0xC25EED, 256);
        let (packed, norm) = turboquant_hb_encode_d256(&x, 5).expect("encode");
        let recon = decode_d256_via_kernel_formula(&packed, norm, 5);
        let cos = cosine_similarity(&x, &recon);
        assert!(cos >= 0.985, "5-bit roundtrip cosine {cos} < 0.985");
    }

    /// Encoding the same input twice must give identical indices and a
    /// bit-identical norm.
    #[test]
    fn hb_encoder_d256_is_deterministic() {
        let x = deterministic_gaussian_test(0xBEEF, 256);
        let (p_a, n_a) = turboquant_hb_encode_d256(&x, 8).expect("a");
        let (p_b, n_b) = turboquant_hb_encode_d256(&x, 8).expect("b");
        assert_eq!(p_a, p_b);
        assert_eq!(n_a.to_bits(), n_b.to_bits());
    }

    /// An all-zero input must yield norm 0, indices at the two centre
    /// entries of the 8-bit codebook, and an exactly zero reconstruction.
    #[test]
    fn hb_encoder_d256_zero_vector() {
        let x = vec![0.0_f32; 256];
        let (packed, norm) = turboquant_hb_encode_d256(&x, 8).expect("encode");
        assert_eq!(norm, 0.0);
        for &b in packed.iter() {
            // 127 and 128 are the entries nearest zero in the 8-bit table.
            assert!(b == 127 || b == 128,
                "zero-vec encode produced non-near-zero centroid: {b}");
        }
        let recon = decode_d256_via_kernel_formula(&packed, 0.0, 8);
        for &v in recon.iter() {
            assert_eq!(v, 0.0);
        }
    }

    /// Bit widths other than 5, 6 and 8 must be rejected.
    #[test]
    fn hb_encoder_d256_validates_bits() {
        let x = vec![0.0_f32; 256];
        assert!(turboquant_hb_encode_d256(&x, 4).is_err());
        assert!(turboquant_hb_encode_d256(&x, 7).is_err());
    }

    /// Inputs whose length is not 256 must be rejected.
    #[test]
    fn hb_encoder_d256_validates_size() {
        let x = vec![0.0_f32; 128];
        assert!(turboquant_hb_encode_d256(&x, 8).is_err());
    }

    /// The sign mask must change the data, and applying it a second time
    /// must restore the original values.
    #[test]
    fn d1_sign_mask_is_self_inverse() {
        let mut x = deterministic_gaussian_test(0x123, 256);
        let original = x.clone();
        apply_d1_sign_mask_inplace(&mut x, &TBQ_SIGNS_256);
        let differs = x.iter().zip(original.iter()).any(|(&a, &b)| (a - b).abs() > 1e-6);
        assert!(differs, "D1 sign mask had no effect");
        apply_d1_sign_mask_inplace(&mut x, &TBQ_SIGNS_256);
        for (i, (&a, &b)) in x.iter().zip(original.iter()).enumerate() {
            assert!((a - b).abs() < 1e-6, "D1 sign mask not self-inverse at {i}");
        }
    }

    /// `TBQ_SIGNS_256` must equal the first 32 bytes of `TBQ_SIGNS_512`.
    #[test]
    fn tbq_signs_first_32_bytes_match_512_prefix() {
        for i in 0..32 {
            assert_eq!(TBQ_SIGNS_256[i], TBQ_SIGNS_512[i],
                "TBQ_SIGNS_256/512 prefix mismatch at byte {i}");
        }
    }
}