Skip to main content

turboquant/codebook/
mod.rs

1//! Codebook lookup, Beta PDF, and pre-computed tables for Lloyd-Max quantization.
2//!
3//! After applying a random rotation to a *d*-dimensional unit vector, each
4//! coordinate follows a Beta-type distribution on [-1, 1]:
5//!
6//! ```text
7//! f_X(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)
8//! ```
9//!
10//! This module stores pre-computed Lloyd-Max codebooks for common `(bits, dim)`
11//! pairs and provides lookup / nearest-centroid queries.  The generation
12//! algorithm itself lives in the [`gen`] sub-module.
13
14mod gen;
15mod tables;
16
17use crate::error::{require, Result, TurboQuantError};
18use crate::packed::is_valid_bits;
19
20// Re-export `generate_codebook` and `beta_pdf` so that existing callers
21// (including integration tests) can continue to import them from `codebook`.
22pub use gen::{beta_pdf, generate_codebook};
23
/// Lower bound of the support interval [-1, 1] on which the rotated
/// coordinate distribution (see the module docs) is defined.
pub(crate) const SUPPORT_MIN: f64 = -1.0;

/// Upper bound of the support interval [-1, 1] on which the rotated
/// coordinate distribution (see the module docs) is defined.
pub(crate) const SUPPORT_MAX: f64 = 1.0;
29
30// ---------------------------------------------------------------------------
31// Codebook struct
32// ---------------------------------------------------------------------------
33
/// A scalar quantization codebook: sorted centroids and the decision
/// boundaries between them.
///
/// For *k* centroids there are *k-1* interior boundaries; the outer
/// boundaries are implicitly -1 and +1 and are not stored.
///
/// Two codebooks compare equal when both vectors are element-wise equal
/// (`f64` comparison, so a codebook containing NaN never equals itself).
#[derive(Debug, Clone, PartialEq)]
pub struct Codebook {
    /// Sorted centroid values (length = 2^bits).
    pub centroids: Vec<f64>,
    /// Interior decision boundaries (length = 2^bits - 1).
    pub boundaries: Vec<f64>,
}
44
45// ---------------------------------------------------------------------------
// Validation and sizing helpers
47// ---------------------------------------------------------------------------
48
49/// Validate the bit width, returning an error if unsupported.
50///
51/// Pure Integration: only calls `require` and `is_valid_bits` (from `packed`).
52fn validate_bits(bits: u8) -> Result<()> {
53    require(is_valid_bits(bits), TurboQuantError::UnsupportedBits(bits))
54}
55
/// Compute the number of centroids for a given bit width: 2^bits.
///
/// This is a `const fn`, so the count can also be used to size arrays and
/// to define constants at compile time.
///
/// `bits` must be smaller than `usize::BITS`; a larger shift amount
/// overflows the shift.
pub(crate) const fn centroid_count(bits: u8) -> usize {
    1usize << bits
}
60
61// ---------------------------------------------------------------------------
62// Pure Operation: nearest centroid lookup
63// ---------------------------------------------------------------------------
64
/// Locate the bin that `value` falls into, given the interior `boundaries`.
///
/// The result is the number of boundaries strictly below `value`, found by
/// repeatedly halving the candidate window; `boundaries` is expected to be
/// sorted in ascending order.  A value exactly equal to a boundary lands in
/// the lower of the two adjacent bins, an empty slice always yields bin 0,
/// and a NaN `value` compares greater than nothing and therefore yields 0.
///
/// The index is narrowed to `u8` with `as`, so it wraps if more than 255
/// boundaries are supplied.
fn boundary_binary_search(value: f64, boundaries: &[f64]) -> u8 {
    // `window` is the still-undecided part of `boundaries`; `offset` is the
    // index of its first element within the full slice.
    let mut window = boundaries;
    let mut offset: usize = 0;
    while !window.is_empty() {
        let half = window.len() / 2;
        if value > window[half] {
            // Everything up to and including the probe lies below `value`.
            offset += half + 1;
            window = &window[half + 1..];
        } else {
            window = &window[..half];
        }
    }
    offset as u8
}
79
80/// Find the index of the nearest centroid for a scalar `value`.
81///
82/// Uses the interior boundaries for an O(log k) binary search.
83///
84/// Pure Integration: delegates to `boundary_binary_search`.
85pub fn nearest_centroid(value: f64, codebook: &Codebook) -> u8 {
86    boundary_binary_search(value, &codebook.boundaries)
87}
88
89// ---------------------------------------------------------------------------
// Static codebook lookup
91// ---------------------------------------------------------------------------
92
93/// Look up a pre-computed static codebook, returning `None` if the `(bits, dim)`
94/// pair is not in the table.
95///
96/// Pure Integration: delegates lookup to `tables::lookup_static_codebook_ref` and
97/// conversion to `StaticCodebook::to_codebook`.
98fn lookup_static_codebook(bits: u8, dim: usize) -> Option<Codebook> {
99    let sc = tables::lookup_static_codebook_ref(bits, dim)?;
100    Some(sc.to_codebook())
101}
102
103// ---------------------------------------------------------------------------
104// Public API — Pure Integration functions (orchestrate other fns)
105// ---------------------------------------------------------------------------
106
107/// Return a pre-computed [`Codebook`] for the given `(bits, dim)` pair.
108///
109/// Falls back to computing one on the fly if the pair is not in the
110/// pre-computed table.
111///
112/// # Errors
113///
114/// Returns [`TurboQuantError::UnsupportedBits`] if `bits` is not 2, 3, or 4.
115///
116/// Pure Integration: delegates validation to `validate_bits`, lookup to
117/// `lookup_static_codebook`, and generation to `gen::generate_codebook`.
118pub fn get_codebook(bits: u8, dim: usize) -> Result<Codebook> {
119    validate_bits(bits)?;
120    let maybe = lookup_static_codebook(bits, dim);
121    Ok(maybe.unwrap_or_else(|| gen::generate_codebook(bits, dim)))
122}
123
124// ---------------------------------------------------------------------------
125// Unit tests for codebook lookup and Beta PDF
126// ---------------------------------------------------------------------------
127
#[cfg(test)]
mod tests {
    //! Unit tests for the lookup helpers in this module, the re-exported
    //! `beta_pdf`, and a few test-only numeric helpers.

    use super::tables::*;
    use super::*;
    use crate::math::{ln_gamma, simpsons_integrate, HALF};
    use approx::assert_relative_eq;

    // -- Named constants for test parameters --------------------------------

    /// Dimension used in integration / beta-PDF tests.
    const TEST_DIM: usize = 128;
    /// Number of steps for Simpson's rule integration.
    const INTEGRATION_STEPS: usize = 2048;
    /// Number of centroids when using 2-bit quantization (2^2).
    const TEST_CENTROIDS_4: usize = 4;
    /// Number of centroids when using 3-bit quantization (2^3).
    const TEST_CENTROIDS_8: usize = 8;
    /// Number of centroids when using 4-bit quantization (2^4).
    const TEST_CENTROIDS_16: usize = 16;
    /// Bit width for 2-bit quantization in tests.
    const TEST_BITS_2: u8 = 2;
    /// Bit width for 3-bit quantization in tests.
    const TEST_BITS_3: u8 = 3;
    /// Bit width for 4-bit quantization in tests.
    const TEST_BITS_4: u8 = 4;
    /// Dimensions used in beta_pdf symmetry and boundary tests.
    const TEST_DIMS: [usize; 3] = [64, 128, 256];
    /// X-values used in beta_pdf symmetry tests.
    const TEST_X_VALUES: [f64; 5] = [0.0, 0.1, 0.3, 0.5, 0.9];
    /// Known `(bits, dim)` pairs used in lookup tests.
    const KNOWN_CODEBOOK_CONFIGS: [(u8, usize); 12] = [
        (2, 32),
        (2, 64),
        (2, 128),
        (2, 256),
        (3, 32),
        (3, 64),
        (3, 128),
        (3, 256),
        (4, 32),
        (4, 64),
        (4, 128),
        (4, 256),
    ];

    // -- Test-only helper functions (inlined out of production code) ---------

    /// Compute the log-normalization constant for the Beta-type PDF (test-only).
    fn beta_pdf_log_normalization(df: f64) -> f64 {
        let half_df = df * HALF;
        let half_df_minus_one = (df - 1.0) * HALF;
        let half_ln_pi = HALF * core::f64::consts::PI.ln();
        ln_gamma(half_df) - half_ln_pi - ln_gamma(half_df_minus_one)
    }

    /// Compute the midpoint of an interval (test-only).
    fn interval_midpoint(a: f64, b: f64) -> f64 {
        (a + b) * HALF
    }

    /// Small epsilon to guard against division by near-zero values (test-only).
    const EPSILON_ZERO: f64 = 1e-30;

    /// Check whether a denominator is too small for safe division (test-only).
    fn is_near_zero(value: f64) -> bool {
        value.abs() < EPSILON_ZERO
    }

    // -- beta_pdf -----------------------------------------------------------

    #[test]
    fn beta_pdf_integrates_to_approximately_one() {
        let d = TEST_DIM;
        let integral = simpsons_integrate(
            |x| beta_pdf(x, d),
            SUPPORT_MIN,
            SUPPORT_MAX,
            INTEGRATION_STEPS,
        );
        assert_relative_eq!(integral, 1.0, epsilon = 1e-4);
    }

    #[test]
    fn beta_pdf_is_symmetric() {
        for d in TEST_DIMS {
            for &x in &TEST_X_VALUES {
                assert_relative_eq!(beta_pdf(x, d), beta_pdf(-x, d), epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn beta_pdf_zero_at_boundary() {
        for d in TEST_DIMS {
            assert_relative_eq!(beta_pdf(SUPPORT_MAX, d), 0.0, epsilon = 1e-15);
            assert_relative_eq!(beta_pdf(SUPPORT_MIN, d), 0.0, epsilon = 1e-15);
        }
    }

    #[test]
    fn beta_pdf_zero_outside_support() {
        assert_relative_eq!(beta_pdf(1.5, TEST_DIM), 0.0, epsilon = 1e-15);
        assert_relative_eq!(beta_pdf(-2.0, TEST_DIM), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn beta_pdf_zero_for_low_dimension() {
        assert_relative_eq!(beta_pdf(0.0, 2), 0.0, epsilon = 1e-15);
        assert_relative_eq!(beta_pdf(0.0, 1), 0.0, epsilon = 1e-15);
    }

    // -- boundary_binary_search ---------------------------------------------

    #[test]
    fn boundary_binary_search_first_bin() {
        let boundaries = vec![-0.5, 0.0, 0.5];
        assert_eq!(boundary_binary_search(-0.9, &boundaries), 0);
    }

    #[test]
    fn boundary_binary_search_last_bin() {
        let boundaries = vec![-0.5, 0.0, 0.5];
        assert_eq!(boundary_binary_search(0.9, &boundaries), 3);
    }

    #[test]
    fn boundary_binary_search_middle() {
        let boundaries = vec![-0.5, 0.0, 0.5];
        assert_eq!(boundary_binary_search(-0.1, &boundaries), 1);
        assert_eq!(boundary_binary_search(0.1, &boundaries), 2);
    }

    #[test]
    fn boundary_binary_search_tie_goes_to_lower_bin() {
        // The search uses a strict `>` comparison, so a value sitting exactly
        // on a boundary belongs to the bin below it.
        let boundaries = vec![-0.5, 0.0, 0.5];
        assert_eq!(boundary_binary_search(-0.5, &boundaries), 0);
        assert_eq!(boundary_binary_search(0.0, &boundaries), 1);
        assert_eq!(boundary_binary_search(0.5, &boundaries), 2);
    }

    #[test]
    fn boundary_binary_search_empty_boundaries() {
        assert_eq!(boundary_binary_search(0.3, &[]), 0);
    }

    // -- nearest_centroid ---------------------------------------------------

    #[test]
    fn nearest_centroid_uses_codebook_boundaries() {
        let codebook = Codebook {
            centroids: vec![-0.75, -0.25, 0.25, 0.75],
            boundaries: vec![-0.5, 0.0, 0.5],
        };
        assert_eq!(nearest_centroid(-0.9, &codebook), 0);
        assert_eq!(nearest_centroid(-0.3, &codebook), 1);
        assert_eq!(nearest_centroid(0.2, &codebook), 2);
        assert_eq!(nearest_centroid(0.8, &codebook), 3);
    }

    // -- interval_midpoint --------------------------------------------------

    #[test]
    fn interval_midpoint_basic() {
        assert_relative_eq!(interval_midpoint(0.0, 1.0), 0.5, epsilon = 1e-15);
        assert_relative_eq!(interval_midpoint(-1.0, 1.0), 0.0, epsilon = 1e-15);
    }

    // -- is_near_zero -------------------------------------------------------

    #[test]
    fn is_near_zero_true_for_tiny() {
        assert!(is_near_zero(1e-31));
        assert!(is_near_zero(-1e-31));
        assert!(is_near_zero(0.0));
    }

    #[test]
    fn is_near_zero_false_for_normal() {
        assert!(!is_near_zero(1e-10));
        assert!(!is_near_zero(-0.001));
    }

    // -- lookup_static_codebook / lookup_static_codebook_ref ----------------

    #[test]
    fn lookup_known_configs_return_some() {
        for &(bits, dim) in &KNOWN_CODEBOOK_CONFIGS {
            assert!(
                lookup_static_codebook_ref(bits, dim).is_some(),
                "expected Some for ({bits}, {dim})"
            );
            assert!(lookup_static_codebook(bits, dim).is_some());
        }
    }

    #[test]
    fn lookup_unknown_config_returns_none() {
        assert!(lookup_static_codebook_ref(3, 512).is_none());
        assert!(lookup_static_codebook(4, 16).is_none());
    }

    // -- get_codebook -------------------------------------------------------

    #[test]
    fn get_codebook_rejects_unsupported_bits() {
        assert!(get_codebook(0, TEST_DIM).is_err());
        assert!(get_codebook(1, TEST_DIM).is_err());
        assert!(get_codebook(5, TEST_DIM).is_err());
    }

    #[test]
    fn get_codebook_known_configs_have_expected_lengths() {
        for &(bits, dim) in &KNOWN_CODEBOOK_CONFIGS {
            let codebook = match get_codebook(bits, dim) {
                Ok(codebook) => codebook,
                Err(_) => panic!("expected Ok for ({bits}, {dim})"),
            };
            assert_eq!(codebook.centroids.len(), centroid_count(bits));
            assert_eq!(codebook.boundaries.len(), centroid_count(bits) - 1);
        }
    }

    // -- centroid_count -----------------------------------------------------

    #[test]
    fn centroid_count_2bit() {
        assert_eq!(centroid_count(TEST_BITS_2), TEST_CENTROIDS_4);
    }

    #[test]
    fn centroid_count_3bit() {
        assert_eq!(centroid_count(TEST_BITS_3), TEST_CENTROIDS_8);
    }

    #[test]
    fn centroid_count_4bit() {
        assert_eq!(centroid_count(TEST_BITS_4), TEST_CENTROIDS_16);
    }

    // -- to_codebook --------------------------------------------------------

    #[test]
    fn to_codebook_copies_correctly() {
        let sc = &CODEBOOK_3BIT_D64;
        let cb = sc.to_codebook();
        assert_eq!(cb.centroids.len(), sc.centroids.len());
        assert_eq!(cb.boundaries.len(), sc.boundaries.len());
        for (a, b) in cb.centroids.iter().zip(sc.centroids.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-15);
        }
        for (a, b) in cb.boundaries.iter().zip(sc.boundaries.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-15);
        }
    }

    // -- validate_bits / is_valid_bits (from packed) -------------------------

    #[test]
    fn validate_bits_accepts_2_3_and_4() {
        assert!(validate_bits(TEST_BITS_2).is_ok());
        assert!(validate_bits(TEST_BITS_3).is_ok());
        assert!(validate_bits(TEST_BITS_4).is_ok());
    }

    #[test]
    fn validate_bits_rejects_others() {
        assert!(validate_bits(0).is_err());
        assert!(validate_bits(1).is_err());
        assert!(validate_bits(5).is_err());
    }

    // -- beta_pdf_log_normalization -----------------------------------------

    #[test]
    fn beta_pdf_log_normalization_positive_for_high_d() {
        // For d=128, the normalization constant should be > 1 (concentrated PDF),
        // so the log should be positive.
        let ln_norm = beta_pdf_log_normalization(TEST_DIM as f64);
        assert!(ln_norm > 0.0);
    }
}