Skip to main content

mlx_native/
turboquant.rs

1//! TurboQuant KV cache compression — CPU reference implementation.
2//!
3//! Implements the TurboQuant_mse algorithm:
4//! 1. Walsh-Hadamard rotation for incoherence
5//! 2. Per-head norm extraction
6//! 3. Lloyd-Max scalar quantization against N(0,1) codebooks
7//!
8//! This module is CPU-only math — no Metal GPU dispatch.
9
10// ---- Lloyd-Max Codebooks for N(0,1) ----
11//
12// Precomputed via iterative Lloyd-Max algorithm with convergence tolerance 1e-12.
13// Each codebook is symmetric around zero.
14
/// 2-bit Lloyd-Max centroids for N(0,1): 4 reconstruction levels.
///
/// Sorted ascending and symmetric about zero; the unit tests validate the
/// values against a fresh Lloyd-Max run (`compute_lloyd_max_codebook`).
pub const CODEBOOK_2BIT: [f32; 4] = [
    -1.5104176, -0.4527800, 0.4527800, 1.5104176,
];

/// 3-bit Lloyd-Max centroids for N(0,1): 8 reconstruction levels.
///
/// Sorted ascending and symmetric about zero (validated by the unit tests).
pub const CODEBOOK_3BIT: [f32; 8] = [
    -2.1519457, -1.3439093, -0.7560053, -0.2450942,
    0.2450942, 0.7560053, 1.3439093, 2.1519457,
];

/// 4-bit Lloyd-Max centroids for N(0,1): 16 reconstruction levels.
///
/// Sorted ascending and symmetric about zero (validated by the unit tests).
pub const CODEBOOK_4BIT: [f32; 16] = [
    -2.7325896, -2.0690172, -1.6180464, -1.2562312,
    -0.9423405, -0.6567591, -0.3880483, -0.1283950,
    0.1283950, 0.3880483, 0.6567591, 0.9423405,
    1.2562312, 1.6180464, 2.0690172, 2.7325896,
];
33
34// ---- BitWidth enum ----
35
/// Quantization bit-width for TurboQuant.
///
/// The uniform variants apply one codebook to every coordinate; the mixed
/// variant selects the codebook per coordinate (see `codebook_for_coord`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BitWidth {
    /// 2-bit uniform: all coordinates use 4-level codebook.
    Two,
    /// 3-bit uniform: all coordinates use 8-level codebook.
    Three,
    /// 4-bit uniform: all coordinates use 16-level codebook.
    Four,
    /// 2.5-bit mixed: first d/4 coordinates at 3-bit, remaining 3d/4 at 2-bit.
    // NOTE(review): this split averages (d/4·3 + 3d/4·2)/d = 2.25 bits per
    // coordinate, not 2.5 — confirm the intended boundary fraction.
    TwoPointFive,
}
48
/// Configuration for TurboQuant quantization.
#[derive(Debug, Clone)]
pub struct TurboQuantConfig {
    /// Quantization bit-width.
    pub bit_width: BitWidth,
    /// Head dimension (must be a power of 2: 128, 256, or 512).
    // The power-of-two requirement is enforced by `turboquant_quantize` and
    // `turboquant_dequantize`, which return an error otherwise.
    pub head_dim: usize,
}
57
58// ---- Fast Walsh-Hadamard Transform ----
59
60/// In-place normalized Fast Walsh-Hadamard Transform.
61///
62/// The normalization ensures H * H = I, so the inverse transform is the
63/// same function applied again.
64///
65/// # Arguments
66/// * `x` — mutable slice of length `n` where `n` is a power of 2.
67///
68/// # Returns
69/// `Ok(())` on success, or an error if the length is not a power of 2.
70pub fn fwht_inplace(x: &mut [f32]) -> crate::Result<()> {
71    let n = x.len();
72    if n == 0 || !n.is_power_of_two() {
73        return Err(crate::MlxError::InvalidArgument(format!(
74            "FWHT requires power-of-two length, got {n}"
75        )));
76    }
77
78    let mut h = 1;
79    while h < n {
80        let step = h * 2;
81        let mut i = 0;
82        while i < n {
83            for j in i..i + h {
84                let a = x[j];
85                let b = x[j + h];
86                x[j] = a + b;
87                x[j + h] = a - b;
88            }
89            i += step;
90        }
91        h *= 2;
92    }
93
94    // Normalize so that H * H = I
95    let scale = 1.0 / (n as f32).sqrt();
96    for v in x.iter_mut() {
97        *v *= scale;
98    }
99
100    Ok(())
101}
102
103// ---- Standard Normal PDF / CDF ----
104
/// Standard normal probability density function: phi(x) = exp(-x^2/2) / sqrt(2*pi).
#[inline]
fn std_normal_pdf(x: f64) -> f64 {
    // The 1/sqrt(2*pi) normalizer is folded into a precomputed constant so
    // evaluation is one exp() plus one multiply.
    const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7; // 1/sqrt(2*pi)
    (-0.5 * x * x).exp() * INV_SQRT_2PI
}
111
/// Standard normal CDF using the Abramowitz & Stegun rational approximation
/// (formula 26.2.17, maximum error < 7.5e-8).
#[inline]
fn std_normal_cdf(x: f64) -> f64 {
    // Beyond +/-8 the true CDF differs from 0/1 by far less than the
    // approximation error, so clamp.
    if x < -8.0 {
        return 0.0;
    }
    if x > 8.0 {
        return 1.0;
    }

    const P: f64 = 0.231_641_9;
    const B: [f64; 5] = [
        0.319_381_530,
        -0.356_563_782,
        1.781_477_937,
        -1.821_255_978,
        1.330_274_429,
    ];
    const INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7; // 1/sqrt(2*pi)

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);

    // Accumulate b1*t + b2*t^2 + ... + b5*t^5 in ascending powers of t.
    let mut poly = 0.0;
    let mut t_pow = t;
    for &b in &B {
        poly += b * t_pow;
        t_pow *= t;
    }

    // Upper-tail form for |x|: Phi(|x|) = 1 - phi(|x|) * poly.
    let density = INV_SQRT_2PI * (-0.5 * ax * ax).exp();
    let upper = 1.0 - density * poly;

    // Reflect for negative arguments: Phi(-x) = 1 - Phi(x).
    if x >= 0.0 {
        upper
    } else {
        1.0 - upper
    }
}
151
152// ---- Nearest centroid lookup ----
153
/// Find the index of the nearest centroid by linear scan over the codebook.
///
/// Ties resolve to the lower index; an empty codebook yields index 0.
#[inline]
fn nearest_centroid(value: f32, codebook: &[f32]) -> u8 {
    // Fold over (best_index, best_distance); strict `<` keeps the first of
    // any equally-distant pair, matching a left-to-right scan.
    let (best, _) = codebook.iter().enumerate().fold(
        (0usize, f32::INFINITY),
        |(bi, bd), (i, &c)| {
            let d = (value - c).abs();
            if d < bd {
                (i, d)
            } else {
                (bi, bd)
            }
        },
    );
    best as u8
}
175
176/// Get the codebook for a specific coordinate index under the given config.
177#[inline]
178fn codebook_for_coord(coord_idx: usize, config: &TurboQuantConfig) -> &'static [f32] {
179    match config.bit_width {
180        BitWidth::Two => &CODEBOOK_2BIT,
181        BitWidth::Three => &CODEBOOK_3BIT,
182        BitWidth::Four => &CODEBOOK_4BIT,
183        BitWidth::TwoPointFive => {
184            let boundary = config.head_dim / 4;
185            if coord_idx < boundary {
186                &CODEBOOK_3BIT // first d/4 channels at 3-bit
187            } else {
188                &CODEBOOK_2BIT // remaining 3d/4 at 2-bit
189            }
190        }
191    }
192}
193
194/// Bits per index for a coordinate under the given config.
195#[inline]
196fn bits_for_coord(coord_idx: usize, config: &TurboQuantConfig) -> usize {
197    match config.bit_width {
198        BitWidth::Two => 2,
199        BitWidth::Three => 3,
200        BitWidth::Four => 4,
201        BitWidth::TwoPointFive => {
202            if coord_idx < config.head_dim / 4 {
203                3
204            } else {
205                2
206            }
207        }
208    }
209}
210
211// ---- Pack / Unpack indices ----
212
213/// Pack variable-width indices into a byte vector using bit-packing.
214///
215/// Indices are packed MSB-first into consecutive bytes.
216fn pack_indices(indices: &[u8], config: &TurboQuantConfig) -> Vec<u8> {
217    let total_bits: usize = (0..indices.len())
218        .map(|i| bits_for_coord(i, config))
219        .sum();
220    let num_bytes = (total_bits + 7) / 8;
221    let mut packed = vec![0u8; num_bytes];
222
223    let mut bit_offset = 0usize;
224    for (i, &idx) in indices.iter().enumerate() {
225        let nbits = bits_for_coord(i, config);
226        // Write `nbits` bits of `idx` starting at `bit_offset`
227        for b in (0..nbits).rev() {
228            let bit_val = (idx >> b) & 1;
229            let byte_pos = bit_offset / 8;
230            let bit_pos = 7 - (bit_offset % 8);
231            if byte_pos < packed.len() {
232                packed[byte_pos] |= bit_val << bit_pos;
233            }
234            bit_offset += 1;
235        }
236    }
237
238    packed
239}
240
241/// Unpack variable-width indices from a packed byte vector.
242fn unpack_indices(packed: &[u8], config: &TurboQuantConfig) -> Vec<u8> {
243    let d = config.head_dim;
244    let mut indices = Vec::with_capacity(d);
245
246    let mut bit_offset = 0usize;
247    for i in 0..d {
248        let nbits = bits_for_coord(i, config);
249        let mut val = 0u8;
250        for _ in 0..nbits {
251            let byte_pos = bit_offset / 8;
252            let bit_pos = 7 - (bit_offset % 8);
253            let bit_val = if byte_pos < packed.len() {
254                (packed[byte_pos] >> bit_pos) & 1
255            } else {
256                0
257            };
258            val = (val << 1) | bit_val;
259            bit_offset += 1;
260        }
261        indices.push(val);
262    }
263
264    indices
265}
266
267// ---- Quantize / Dequantize ----
268
269/// Quantize a single head vector using TurboQuant_mse.
270///
271/// Steps:
272/// 1. Apply FWHT (Walsh-Hadamard rotation) for incoherence
273/// 2. Extract L2 norm
274/// 3. Normalize to unit vector
275/// 4. Quantize each coordinate against the appropriate Lloyd-Max codebook
276/// 5. Pack indices
277///
278/// # Arguments
279/// * `x` — input vector of length `config.head_dim`
280/// * `config` — quantization configuration
281///
282/// # Returns
283/// `(packed_indices, norm)` on success.
284pub fn turboquant_quantize(
285    x: &[f32],
286    config: &TurboQuantConfig,
287) -> crate::Result<(Vec<u8>, f32)> {
288    let d = config.head_dim;
289    if x.len() != d {
290        return Err(crate::MlxError::InvalidArgument(format!(
291            "Expected vector of length {d}, got {}",
292            x.len()
293        )));
294    }
295    if !d.is_power_of_two() {
296        return Err(crate::MlxError::InvalidArgument(format!(
297            "head_dim must be power of 2, got {d}"
298        )));
299    }
300
301    // 1. Copy and apply FWHT
302    let mut rotated = x.to_vec();
303    fwht_inplace(&mut rotated)?;
304
305    // 2. Compute L2 norm of rotated vector (same as original since Hadamard is orthogonal)
306    let norm_sq: f32 = rotated.iter().map(|&v| v * v).sum();
307    let norm = norm_sq.sqrt();
308
309    if norm < 1e-30 {
310        // Zero vector: all indices = 0, norm = 0
311        let indices = vec![0u8; d];
312        let packed = pack_indices(&indices, config);
313        return Ok((packed, 0.0));
314    }
315
316    // 3. Normalize to unit vector on S^{d-1}
317    let inv_norm = 1.0 / norm;
318    for v in rotated.iter_mut() {
319        *v *= inv_norm;
320    }
321
322    // 4. Quantize: each coordinate needs to be scaled to N(0,1) domain.
323    // A unit vector on S^{d-1} has coordinates ~ N(0, 1/d) for large d.
324    // Scale by sqrt(d) to map to N(0,1) for codebook lookup.
325    let scale = (d as f32).sqrt();
326    let mut indices = Vec::with_capacity(d);
327    for (i, &v) in rotated.iter().enumerate() {
328        let scaled = v * scale;
329        let cb = codebook_for_coord(i, config);
330        indices.push(nearest_centroid(scaled, cb));
331    }
332
333    // 5. Pack
334    let packed = pack_indices(&indices, config);
335
336    Ok((packed, norm))
337}
338
339/// Dequantize a TurboQuant-compressed head vector.
340///
341/// Steps:
342/// 1. Unpack indices
343/// 2. Look up centroid values, scale back from N(0,1) domain
344/// 3. Multiply by norm
345/// 4. Apply inverse FWHT (same as forward)
346///
347/// # Arguments
348/// * `packed` — packed index bytes
349/// * `norm` — the L2 norm stored during quantization
350/// * `config` — quantization configuration
351///
352/// # Returns
353/// Reconstructed vector of length `config.head_dim`.
354pub fn turboquant_dequantize(
355    packed: &[u8],
356    norm: f32,
357    config: &TurboQuantConfig,
358) -> crate::Result<Vec<f32>> {
359    let d = config.head_dim;
360    if !d.is_power_of_two() {
361        return Err(crate::MlxError::InvalidArgument(format!(
362            "head_dim must be power of 2, got {d}"
363        )));
364    }
365
366    // 1. Unpack indices
367    let indices = unpack_indices(packed, config);
368
369    // 2. Look up centroids and scale back from N(0,1) to unit-sphere scale
370    let inv_scale = 1.0 / (d as f32).sqrt();
371    let mut reconstructed = Vec::with_capacity(d);
372    for (i, &idx) in indices.iter().enumerate() {
373        let cb = codebook_for_coord(i, config);
374        let idx_usize = idx as usize;
375        let centroid = if idx_usize < cb.len() {
376            cb[idx_usize]
377        } else {
378            0.0 // fallback for out-of-range (shouldn't happen)
379        };
380        reconstructed.push(centroid * inv_scale * norm);
381    }
382
383    // 3. Apply inverse FWHT (same as forward since H^{-1} = H with normalization)
384    fwht_inplace(&mut reconstructed)?;
385
386    Ok(reconstructed)
387}
388
389// ---- Lloyd-Max computation utilities (used by tests for validation) ----
390
391/// Compute Lloyd-Max codebook for N(0,1) with the given number of levels.
392///
393/// Returns the sorted centroid array. This is used in tests to validate the
394/// hardcoded codebooks.
395pub fn compute_lloyd_max_codebook(num_levels: usize) -> Vec<f64> {
396    // Initialize with uniform quantile boundaries
397    let mut boundaries = Vec::with_capacity(num_levels + 1);
398    boundaries.push(-10.0_f64); // approx -inf
399    for i in 1..num_levels {
400        let p = i as f64 / num_levels as f64;
401        boundaries.push(quantile_normal(p));
402    }
403    boundaries.push(10.0_f64); // approx +inf
404
405    // Initial centroids from conditional expectations
406    let mut centroids = vec![0.0_f64; num_levels];
407    for i in 0..num_levels {
408        let a = boundaries[i];
409        let b = boundaries[i + 1];
410        let prob = std_normal_cdf(b) - std_normal_cdf(a);
411        if prob > 1e-30 {
412            centroids[i] = (std_normal_pdf(a) - std_normal_pdf(b)) / prob;
413        }
414    }
415
416    // Iterate
417    for _iter in 0..50_000 {
418        let old = centroids.clone();
419
420        // Update boundaries to midpoints
421        boundaries[0] = -10.0;
422        for i in 1..num_levels {
423            boundaries[i] = (centroids[i - 1] + centroids[i]) / 2.0;
424        }
425        *boundaries.last_mut().unwrap_or(&mut 0.0) = 10.0;
426
427        // Update centroids
428        for i in 0..num_levels {
429            let a = boundaries[i];
430            let b = boundaries[i + 1];
431            let prob = std_normal_cdf(b) - std_normal_cdf(a);
432            if prob > 1e-30 {
433                centroids[i] = (std_normal_pdf(a) - std_normal_pdf(b)) / prob;
434            }
435        }
436
437        // Check convergence
438        let max_change = centroids
439            .iter()
440            .zip(old.iter())
441            .map(|(a, b)| (a - b).abs())
442            .fold(0.0_f64, f64::max);
443        if max_change < 1e-12 {
444            break;
445        }
446    }
447
448    centroids
449}
450
/// Approximate quantile (inverse CDF) of N(0,1) using rational approximation.
///
/// Uses Peter Acklam's algorithm: one rational function for the central
/// region and a shared rational function (mirrored) for both tails.
fn quantile_normal(p: f64) -> f64 {
    // Degenerate probabilities clamp to the same +/-10 pseudo-infinity used
    // by the Lloyd-Max boundary tables.
    if p <= 0.0 {
        return -10.0;
    }
    if p >= 1.0 {
        return 10.0;
    }

    const A: [f64; 6] = [
        -3.969683028665376e1,
        2.209460984245205e2,
        -2.759285104469687e2,
        1.383577518672690e2,
        -3.066479806614716e1,
        2.506628277459239e0,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e1,
        1.615858368580409e2,
        -1.556989798598866e2,
        6.680131188771972e1,
        -1.328068155288572e1,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-3,
        -3.223964580411365e-1,
        -2.400758277161838e0,
        -2.549732539343734e0,
        4.374664141464968e0,
        2.938163982698783e0,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-3,
        3.224671290700398e-1,
        2.445134137142996e0,
        3.754408661907416e0,
    ];

    // Split points between the central and tail approximations.
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    // Shared tail evaluator: rational function in q = sqrt(-2 ln p_tail).
    let tail = |q: f64| {
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    if p < P_LOW {
        // Lower tail.
        tail((-2.0 * p.ln()).sqrt())
    } else if p <= P_HIGH {
        // Central region: rational function in r = q^2 around the median.
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        // Upper tail: mirror image of the lower tail.
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    }
}
511
/// Compute Lloyd-Max codebook for Beta((d-1)/2, (d-1)/2) scaled to [-1, 1].
///
/// The exact distribution of a coordinate of a unit vector uniform on S^{d-1}
/// is Beta((d-1)/2, (d-1)/2) on [-1, 1]. For large d this converges to N(0, 1/d).
///
/// Uses numerical integration via trapezoidal rule for the conditional expectations.
///
/// # Arguments
/// * `dim` — dimension d of the sphere S^{d-1} (sets the Beta shape parameter)
/// * `num_levels` — number of quantization levels (centroids) to produce
///
/// # Returns
/// Sorted centroid array of length `num_levels`.
pub fn compute_lloyd_max_beta_codebook(dim: usize, num_levels: usize) -> Vec<f64> {
    // Both shape parameters equal alpha, so the density is symmetric in x.
    let alpha = (dim as f64 - 1.0) / 2.0;

    // Beta PDF on [-1,1] with parameters (alpha, alpha) — symmetric
    // f(x) = C * (1-x^2)^(alpha-1)  for x in [-1, 1]
    // where C normalizes to 1.

    // Use log-space for numerical stability
    let log_norm = log_beta_norm_const(alpha);

    // Evaluate the density; outside the open interval (-1, 1) it is zero.
    let beta_pdf = |x: f64| -> f64 {
        if x <= -1.0 || x >= 1.0 {
            return 0.0;
        }
        let val = 1.0 - x * x;
        if val <= 0.0 {
            return 0.0;
        }
        (log_norm + (alpha - 1.0) * val.ln()).exp()
    };

    // Numerical CDF via cumulative trapezoidal integration
    let n_grid = 10_000;
    let grid_lo = -1.0_f64;
    let grid_hi = 1.0_f64;
    let dx = (grid_hi - grid_lo) / n_grid as f64;

    // Build CDF table
    let mut cdf_vals = vec![0.0_f64; n_grid + 1];
    let mut pdf_vals = vec![0.0_f64; n_grid + 1];
    for i in 0..=n_grid {
        let x = grid_lo + i as f64 * dx;
        pdf_vals[i] = beta_pdf(x);
    }
    for i in 1..=n_grid {
        cdf_vals[i] = cdf_vals[i - 1] + 0.5 * (pdf_vals[i - 1] + pdf_vals[i]) * dx;
    }
    // Normalize CDF to [0, 1]; the PDF table is rescaled by the same factor
    // so it stays consistent with the renormalized CDF.
    let cdf_total = cdf_vals[n_grid];
    if cdf_total > 1e-30 {
        for v in cdf_vals.iter_mut() {
            *v /= cdf_total;
        }
        for v in pdf_vals.iter_mut() {
            *v /= cdf_total;
        }
    }

    // Helper: interpolated CDF and conditional expectation on [a, b]
    // (linear interpolation between adjacent grid samples).
    let interp_cdf = |x: f64| -> f64 {
        let frac = (x - grid_lo) / dx;
        let idx = frac as usize;
        if idx >= n_grid {
            return 1.0;
        }
        let t = frac - idx as f64;
        cdf_vals[idx] * (1.0 - t) + cdf_vals[idx + 1] * t
    };

    let conditional_expectation = |a: f64, b: f64| -> f64 {
        // E[X | a <= X <= b] via numerical integration
        let prob = interp_cdf(b) - interp_cdf(a);
        if prob < 1e-30 {
            // Degenerate cell with no probability mass: use the midpoint.
            return (a + b) / 2.0;
        }

        // Trapezoidal rule over a fixed 500-interval subdivision of [a, b];
        // the PDF is looked up by interpolating the precomputed grid.
        let n_sub = 500;
        let sub_dx = (b - a) / n_sub as f64;
        let mut integral = 0.0_f64;
        for j in 0..=n_sub {
            let x = a + j as f64 * sub_dx;
            let w = if j == 0 || j == n_sub { 0.5 } else { 1.0 };
            let frac = (x - grid_lo) / dx;
            let idx = frac as usize;
            let pdf_val = if idx >= n_grid {
                0.0
            } else {
                let t = frac - idx as f64;
                pdf_vals[idx] * (1.0 - t) + pdf_vals[idx + 1] * t
            };
            integral += w * x * pdf_val * sub_dx;
        }
        integral / prob
    };

    // Initialize with uniform quantile boundaries
    let mut boundaries = Vec::with_capacity(num_levels + 1);
    boundaries.push(-1.0_f64);
    for i in 1..num_levels {
        let target_p = i as f64 / num_levels as f64;
        // Binary search for quantile (interp_cdf is monotone; 100 halvings
        // narrow the bracket far below grid resolution)
        let mut lo = -1.0_f64;
        let mut hi = 1.0_f64;
        for _ in 0..100 {
            let mid = (lo + hi) / 2.0;
            if interp_cdf(mid) < target_p {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        boundaries.push((lo + hi) / 2.0);
    }
    boundaries.push(1.0_f64);

    // Initial centroids
    let mut centroids = vec![0.0_f64; num_levels];
    for i in 0..num_levels {
        centroids[i] = conditional_expectation(boundaries[i], boundaries[i + 1]);
    }

    // Lloyd-Max iteration: alternate midpoint boundaries and conditional-mean
    // centroids until the largest centroid movement drops below tolerance.
    for _iter in 0..5000 {
        let old = centroids.clone();

        // Update boundaries
        boundaries[0] = -1.0;
        for i in 1..num_levels {
            boundaries[i] = (centroids[i - 1] + centroids[i]) / 2.0;
        }
        if let Some(last) = boundaries.last_mut() {
            *last = 1.0;
        }

        // Update centroids
        for i in 0..num_levels {
            centroids[i] = conditional_expectation(boundaries[i], boundaries[i + 1]);
        }

        let max_change = centroids
            .iter()
            .zip(old.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f64, f64::max);
        if max_change < 1e-10 {
            break;
        }
    }

    centroids
}
659
/// Log of the normalization constant for the symmetric Beta PDF on [-1, 1].
///
/// For f(x) = C * (1 - x^2)^(alpha - 1) on [-1, 1], returns ln(C).
fn log_beta_norm_const(alpha: f64) -> f64 {
    // Beta(alpha, alpha) on [0,1] has norm B(alpha, alpha) = Gamma(alpha)^2 / Gamma(2*alpha)
    // On [-1,1] we scale by 1/2, so norm = B(alpha,alpha) * 2^(2*alpha-1)
    // log C = -log(B(alpha,alpha)) - (2*alpha-1)*log(2)
    //       = log(Gamma(2*alpha)) - 2*log(Gamma(alpha)) - (2*alpha-1)*log(2)
    ln_gamma(2.0 * alpha) - 2.0 * ln_gamma(alpha) - (2.0 * alpha - 1.0) * 2.0_f64.ln()
}
668
/// Lanczos approximation for ln(Gamma(x)), x > 0.
fn ln_gamma(x: f64) -> f64 {
    // Lanczos parameters: g = 7 with a 9-term coefficient series.
    const G: f64 = 7.0;
    const COEFF: [f64; 9] = [
        0.999_999_999_999_809_93,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_9,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_571_6e-6,
        1.505_632_735_149_311_6e-7,
    ];

    // Small arguments go through Euler's reflection formula so the Lanczos
    // series is always evaluated in its accurate range.
    if x < 0.5 {
        let pi = std::f64::consts::PI;
        return pi.ln() - (pi * x).sin().ln() - ln_gamma(1.0 - x);
    }

    // Shifted argument z = x - 1, as in the standard Lanczos formulation.
    let z = x - 1.0;
    let mut series = COEFF[0];
    for (k, &c) in COEFF.iter().enumerate().skip(1) {
        series += c / (z + k as f64);
    }

    let t = z + G + 0.5;
    0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
700
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Every hardcoded codebook must be symmetric about zero
    /// (c[i] == -c[n-1-i]), as Lloyd-Max centroids for a symmetric
    /// distribution should be.
    #[test]
    fn test_codebook_symmetry() {
        for (name, cb) in [
            ("2-bit", &CODEBOOK_2BIT[..]),
            ("3-bit", &CODEBOOK_3BIT[..]),
            ("4-bit", &CODEBOOK_4BIT[..]),
        ] {
            let n = cb.len();
            for i in 0..n / 2 {
                // A centroid and its mirror should cancel to f32 precision.
                let sum = cb[i] + cb[n - 1 - i];
                assert!(
                    sum.abs() < 1e-5,
                    "{name} codebook not symmetric: c[{i}]={} + c[{}]={} = {sum}",
                    cb[i],
                    n - 1 - i,
                    cb[n - 1 - i]
                );
            }
        }
    }

    /// The hardcoded f32 codebooks must agree with a fresh f64 Lloyd-Max run.
    #[test]
    fn test_codebook_values_match_lloyd_max() {
        for (bits, hardcoded) in [
            (2, &CODEBOOK_2BIT[..]),
            (3, &CODEBOOK_3BIT[..]),
            (4, &CODEBOOK_4BIT[..]),
        ] {
            // 2^bits reconstruction levels per codebook.
            let computed = compute_lloyd_max_codebook(1 << bits);
            assert_eq!(computed.len(), hardcoded.len());
            for (i, (&h, &c)) in hardcoded.iter().zip(computed.iter()).enumerate() {
                let diff = (h as f64 - c).abs();
                assert!(
                    diff < 1e-4,
                    "{bits}-bit codebook mismatch at {i}: hardcoded={h}, computed={c}, diff={diff}"
                );
            }
        }
    }

    /// Applying the normalized FWHT twice must reproduce the input (H*H = I).
    #[test]
    fn test_fwht_roundtrip() {
        let original: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1 - 6.4).collect();
        let mut data = original.clone();
        fwht_inplace(&mut data).unwrap();
        fwht_inplace(&mut data).unwrap();
        for (i, (&a, &b)) in original.iter().zip(data.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-4,
                "FWHT roundtrip mismatch at {i}: {a} vs {b}"
            );
        }
    }
}