// turboquant/codebook/gen.rs
//! Lloyd-Max codebook generation for the Beta distribution.
//!
//! This module contains the Beta-type PDF ([`beta_pdf`]), the Lloyd-Max
//! iterative algorithm, and all generation helpers. It is separated from
//! [`crate::codebook`] (which owns the [`Codebook`] struct and the static
//! lookup tables) to respect the Single Responsibility Principle.
7
8use super::{centroid_count, Codebook, SUPPORT_MAX, SUPPORT_MIN};
9use crate::math::{converge, ln_gamma, simpsons_integrate, HALF};
10
11// ---------------------------------------------------------------------------
12// Constants — generation-specific
13// ---------------------------------------------------------------------------
14
/// Iteration budget handed to [`converge`] by `lloyd_max_iterate`.
/// NOTE(review): presumably `converge` stops after this many steps even when
/// the convergence test never succeeds — confirm against `crate::math`.
const MAX_ITERATIONS: usize = 200;

/// Convergence threshold: `has_converged` reports success when the change in
/// distortion, relative to the previous distortion, drops below this value.
const CONVERGENCE_EPS: f64 = 1e-12;

/// Number of sub-intervals used for Simpson's rule integration; `integrate`
/// forwards it to `simpsons_integrate` for every integral in this module.
const INTEGRATION_STEPS: usize = 1024;

/// Small epsilon to guard against division by near-zero values. It is the
/// cut-off for the denominator in `select_conditional_or_midpoint` and the
/// floor of the reference scale in `has_converged`.
const EPSILON_ZERO: f64 = 1e-30;

/// Minimum dimension for which the Beta-type PDF is well-defined.
/// For d < 3 the exponent (d-3)/2 is negative and the distribution
/// degenerates; `beta_pdf` returns `0.0` for such dimensions.
const MIN_DIMENSION_FOR_PDF: usize = 3;

/// The constant `3` in the Beta-type kernel exponent `(d - 3) / 2`, stored as
/// `f64` so it can be subtracted from the dimension after conversion.
const KERNEL_EXPONENT_OFFSET: f64 = 3.0;
34
35// ---------------------------------------------------------------------------
36// Pure Operation: Beta PDF
37// ---------------------------------------------------------------------------
38
39/// Evaluate the Beta-type PDF of a rotated unit-vector coordinate.
40///
41/// ```text
42/// f_X(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)
43/// ```
44///
45/// Pure Operation: all arithmetic (kernel + normalization) is computed
46/// inline without calls to other project functions.
47pub fn beta_pdf(x: f64, d: usize) -> f64 {
48 // Guard: dimension too low.
49 if d < MIN_DIMENSION_FOR_PDF {
50 return 0.0;
51 }
52 let df = d as f64;
53 let exponent = (df - KERNEL_EXPONENT_OFFSET) * HALF;
54 let one_minus_x2 = 1.0 - x * x;
55 if one_minus_x2 <= 0.0 {
56 return 0.0;
57 }
58 let kernel = one_minus_x2.powf(exponent);
59
60 // Normalization: ln(Gamma(d/2)) - 0.5*ln(pi) - ln(Gamma((d-1)/2))
61 let half_df = df * HALF;
62 let half_df_minus_one = (df - 1.0) * HALF;
63 let half_ln_pi = HALF * core::f64::consts::PI.ln();
64 let log_norm = ln_gamma(half_df) - half_ln_pi - ln_gamma(half_df_minus_one);
65
66 log_norm.exp() * kernel
67}
68
69// ---------------------------------------------------------------------------
70// Pure Operation: initialization
71// ---------------------------------------------------------------------------
72
73/// Place `k` centroids uniformly on `(SUPPORT_MIN, SUPPORT_MAX)` (excluding endpoints).
74fn initialize_centroids(k: usize) -> Vec<f64> {
75 let range = SUPPORT_MAX - SUPPORT_MIN; // 2.0
76 (0..k)
77 .map(|i| SUPPORT_MIN + (range * (i as f64 + HALF)) / k as f64)
78 .collect()
79}
80
81/// Compute midpoint boundaries between adjacent centroids.
82fn midpoint_boundaries(centroids: &[f64]) -> Vec<f64> {
83 centroids.windows(2).map(|w| (w[0] + w[1]) * HALF).collect()
84}
85
86// ---------------------------------------------------------------------------
87// Pure Operation: bin geometry
88// ---------------------------------------------------------------------------
89
90/// Determine the lower bound of the i-th bin given boundaries.
91fn bin_lower_bound(i: usize, boundaries: &[f64]) -> f64 {
92 if i == 0 {
93 SUPPORT_MIN
94 } else {
95 boundaries[i - 1]
96 }
97}
98
99/// Determine the upper bound of the i-th bin given boundaries and total
100/// number of centroids `k`.
101fn bin_upper_bound(i: usize, k: usize, boundaries: &[f64]) -> f64 {
102 if i == k - 1 {
103 SUPPORT_MAX
104 } else {
105 boundaries[i]
106 }
107}
108
109// ---------------------------------------------------------------------------
110// Pure Operation: convergence check & conditional selection
111// ---------------------------------------------------------------------------
112
113/// Check whether the Lloyd-Max iteration has converged by comparing the
114/// relative change in distortion against [`CONVERGENCE_EPS`].
115fn has_converged(prev_distortion: f64, distortion: f64) -> bool {
116 (prev_distortion - distortion).abs() < CONVERGENCE_EPS * prev_distortion.abs().max(EPSILON_ZERO)
117}
118
119/// Select the conditional expectation or the interval midpoint depending
120/// on whether the denominator is near zero.
121///
122/// Pure Operation: only arithmetic and comparison, no calls.
123fn select_conditional_or_midpoint(numerator: f64, denominator: f64, a: f64, b: f64) -> f64 {
124 if denominator.abs() < EPSILON_ZERO {
125 (a + b) * HALF
126 } else {
127 numerator / denominator
128 }
129}
130
131// ---------------------------------------------------------------------------
132// Pure Integration: numerical integration wrappers
133// ---------------------------------------------------------------------------
134
135/// Simpson's rule numerical integration of `f` over `[a, b]`, using the
136/// module-level [`INTEGRATION_STEPS`] constant.
137///
138/// Pure Integration: delegates to [`crate::math::simpsons_integrate`].
139fn integrate<F: Fn(f64) -> f64>(f: F, a: f64, b: f64) -> f64 {
140 simpsons_integrate(f, a, b, INTEGRATION_STEPS)
141}
142
143/// Compute `integral_a^b f(x) dx` where `f(x) = beta_pdf(x, d)`.
144///
145/// Pure Integration: delegates to `integrate` and `beta_pdf`.
146fn integrate_pdf(a: f64, b: f64, d: usize) -> f64 {
147 integrate(|x| beta_pdf(x, d), a, b)
148}
149
150/// Compute `integral_a^b x * f(x) dx` where `f(x) = beta_pdf(x, d)`.
151///
152/// Pure Integration: delegates to `integrate` and `beta_pdf`.
153fn integrate_x_pdf(a: f64, b: f64, d: usize) -> f64 {
154 integrate(|x| x * beta_pdf(x, d), a, b)
155}
156
157/// Conditional expectation `E[X | X in [a, b]]` under the Beta-type PDF.
158///
159/// Pure Integration: delegates computation to `integrate_pdf`,
160/// `integrate_x_pdf`, and `select_conditional_or_midpoint`.
161fn conditional_expectation(a: f64, b: f64, d: usize) -> f64 {
162 let denom = integrate_pdf(a, b, d);
163 let numer = integrate_x_pdf(a, b, d);
164 select_conditional_or_midpoint(numer, denom, a, b)
165}
166
167// ---------------------------------------------------------------------------
168// Pure Integration: distortion computation
169// ---------------------------------------------------------------------------
170
171/// Compute the MSE-distortion contribution of a single bin `[lo, hi]` with
172/// centroid `c` under the Beta PDF for dimension `d`.
173///
174/// Pure Integration: delegates to `integrate` and `beta_pdf`.
175fn bin_distortion(lo: f64, hi: f64, c: f64, d: usize) -> f64 {
176 integrate(|x| (x - c).powi(2) * beta_pdf(x, d), lo, hi)
177}
178
179/// Compute the MSE distortion of the current codebook under the Beta PDF.
180///
181/// Pure Integration: delegates bin bounds to `bin_lower_bound`/`bin_upper_bound`
182/// and per-bin distortion to `bin_distortion`. Uses an iterator chain instead
183/// of explicit loop logic.
184fn compute_distortion(centroids: &[f64], boundaries: &[f64], d: usize) -> f64 {
185 let k = centroids.len();
186 centroids
187 .iter()
188 .enumerate()
189 .map(|(i, ¢roid)| {
190 let lo = bin_lower_bound(i, boundaries);
191 let hi = bin_upper_bound(i, k, boundaries);
192 bin_distortion(lo, hi, centroid, d)
193 })
194 .sum()
195}
196
197// ---------------------------------------------------------------------------
198// Pure Integration: centroid update
199// ---------------------------------------------------------------------------
200
201/// Compute updated centroids for one Lloyd-Max iteration.
202///
203/// Pure Integration: delegates bin bounds to `bin_lower_bound`/`bin_upper_bound`
204/// and centroid updates to `conditional_expectation`. Uses an iterator chain
205/// instead of explicit loop logic.
206fn update_centroids(centroids_len: usize, boundaries: &[f64], d: usize) -> Vec<f64> {
207 (0..centroids_len)
208 .map(|i| {
209 let lo = bin_lower_bound(i, boundaries);
210 let hi = bin_upper_bound(i, centroids_len, boundaries);
211 conditional_expectation(lo, hi, d)
212 })
213 .collect()
214}
215
216// ---------------------------------------------------------------------------
217// Lloyd-Max core — Pure Integration (orchestrates operation helpers)
218// ---------------------------------------------------------------------------
219
220/// Perform one Lloyd-Max iteration step: compute boundaries, update centroids,
221/// measure distortion, and check convergence.
222///
223/// Pure Integration: delegates to `midpoint_boundaries`, `update_centroids`,
224/// `compute_distortion`, and `has_converged`. Returns the new centroids,
225/// the new distortion, and a convergence flag.
226fn lloyd_max_step(centroids: &[f64], prev_distortion: f64, d: usize) -> (Vec<f64>, f64, bool) {
227 let boundaries = midpoint_boundaries(centroids);
228 let new_centroids = update_centroids(centroids.len(), &boundaries, d);
229 let distortion = compute_distortion(&new_centroids, &boundaries, d);
230 let converged = has_converged(prev_distortion, distortion);
231 (new_centroids, distortion, converged)
232}
233
234/// Run Lloyd-Max iterations starting from the given initial `centroids` for
235/// dimension `d`. Returns the converged [`Codebook`].
236///
237/// Pure Integration: delegates each iteration to `lloyd_max_step` and
238/// final boundary computation to `midpoint_boundaries`.
239fn lloyd_max_iterate(mut centroids: Vec<f64>, d: usize) -> Codebook {
240 let mut prev_distortion = f64::MAX;
241
242 converge(MAX_ITERATIONS, || {
243 let (new_centroids, distortion, converged) = lloyd_max_step(¢roids, prev_distortion, d);
244 centroids = new_centroids;
245 prev_distortion = distortion;
246 converged
247 });
248
249 let boundaries = midpoint_boundaries(¢roids);
250 Codebook {
251 centroids,
252 boundaries,
253 }
254}
255
256// ---------------------------------------------------------------------------
257// Public API
258// ---------------------------------------------------------------------------
259
260/// Run the Lloyd-Max algorithm from scratch for arbitrary `(bits, dim)`.
261///
262/// Pure Integration: delegates centroid count to `centroid_count`,
263/// initialization to `initialize_centroids`, and iteration to
264/// `lloyd_max_iterate`.
265pub fn generate_codebook(bits: u8, dim: usize) -> Codebook {
266 let k = centroid_count(bits);
267 let centroids = initialize_centroids(k);
268 lloyd_max_iterate(centroids, dim)
269}
270
271// ---------------------------------------------------------------------------
272// Unit tests for generation helpers
273// ---------------------------------------------------------------------------
274
/// Unit tests for the generation helpers: initialization, bin geometry,
/// convergence check, conditional expectation, distortion, and the
/// Lloyd-Max step / public entry point.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // -- Named constants for test parameters --------------------------------

    /// Dimension used in integration / beta-PDF tests.
    const TEST_DIM: usize = 128;
    /// Number of centroids when using 3-bit quantization (2^3).
    const TEST_CENTROIDS_8: usize = 8;
    /// Number of centroids when using 4-bit quantization (2^4).
    const TEST_CENTROIDS_16: usize = 16;
    /// Bit width for 3-bit quantization in tests.
    const TEST_BITS_3: u8 = 3;
    /// Dimension 64 used in generate_codebook tests.
    const TEST_DIM_64: usize = 64;
    /// Numerator used in select_conditional_or_midpoint tests.
    const TEST_NUMERATOR: f64 = 3.0;
    /// Normal-case denominator for select_conditional_or_midpoint test.
    const TEST_DENOMINATOR: f64 = 2.0;
    /// Near-zero denominator for select_conditional_or_midpoint fallback test.
    const TEST_NEAR_ZERO_DENOM: f64 = 1e-31;

    // -- initialize_centroids -----------------------------------------------

    #[test]
    fn initialize_centroids_correct_count() {
        assert_eq!(
            initialize_centroids(TEST_CENTROIDS_8).len(),
            TEST_CENTROIDS_8
        );
        assert_eq!(
            initialize_centroids(TEST_CENTROIDS_16).len(),
            TEST_CENTROIDS_16
        );
    }

    #[test]
    fn initialize_centroids_sorted() {
        let c = initialize_centroids(TEST_CENTROIDS_8);
        for w in c.windows(2) {
            assert!(w[0] < w[1]);
        }
    }

    #[test]
    fn initialize_centroids_symmetric() {
        let c = initialize_centroids(TEST_CENTROIDS_8);
        let half = TEST_CENTROIDS_8 / 2;
        for i in 0..half {
            assert_relative_eq!(c[i], -c[TEST_CENTROIDS_8 - 1 - i], epsilon = 1e-14);
        }
    }

    #[test]
    fn initialize_centroids_within_support() {
        let c = initialize_centroids(TEST_CENTROIDS_16);
        for &v in &c {
            assert!(v > SUPPORT_MIN && v < SUPPORT_MAX);
        }
    }

    // -- midpoint_boundaries ------------------------------------------------

    #[test]
    fn midpoint_boundaries_correct_values() {
        let centroids = vec![-0.5, 0.0, 0.5];
        let b = midpoint_boundaries(&centroids);
        assert_eq!(b.len(), 2);
        assert_relative_eq!(b[0], -0.25, epsilon = 1e-14);
        assert_relative_eq!(b[1], 0.25, epsilon = 1e-14);
    }

    // -- bin_lower_bound / bin_upper_bound -----------------------------------

    #[test]
    fn bin_lower_bound_first() {
        let boundaries = vec![0.0];
        assert_relative_eq!(
            bin_lower_bound(0, &boundaries),
            SUPPORT_MIN,
            epsilon = 1e-15
        );
    }

    #[test]
    fn bin_lower_bound_second() {
        let boundaries = vec![0.0];
        assert_relative_eq!(bin_lower_bound(1, &boundaries), 0.0, epsilon = 1e-15);
    }

    #[test]
    fn bin_upper_bound_last() {
        let boundaries = vec![0.0];
        assert_relative_eq!(
            bin_upper_bound(1, 2, &boundaries),
            SUPPORT_MAX,
            epsilon = 1e-15
        );
    }

    #[test]
    fn bin_upper_bound_first() {
        let boundaries = vec![0.0];
        assert_relative_eq!(bin_upper_bound(0, 2, &boundaries), 0.0, epsilon = 1e-15);
    }

    // -- has_converged ------------------------------------------------------

    #[test]
    fn has_converged_identical_values() {
        assert!(has_converged(1.0, 1.0));
    }

    #[test]
    fn has_converged_large_change() {
        assert!(!has_converged(1.0, 0.5));
    }

    // -- select_conditional_or_midpoint -------------------------------------

    #[test]
    fn select_conditional_or_midpoint_normal_case() {
        let result = select_conditional_or_midpoint(TEST_NUMERATOR, TEST_DENOMINATOR, 0.0, 1.0);
        assert_relative_eq!(result, TEST_NUMERATOR / TEST_DENOMINATOR, epsilon = 1e-15);
    }

    #[test]
    fn select_conditional_or_midpoint_near_zero_denom() {
        let result =
            select_conditional_or_midpoint(TEST_NUMERATOR, TEST_NEAR_ZERO_DENOM, 0.0, 1.0);
        assert_relative_eq!(result, 0.5, epsilon = 1e-15);
    }

    // -- conditional_expectation --------------------------------------------

    #[test]
    fn conditional_expectation_symmetric_interval() {
        // E[X | X in [-1, 1]] should be 0 by symmetry.
        let result = conditional_expectation(SUPPORT_MIN, SUPPORT_MAX, TEST_DIM);
        assert_relative_eq!(result, 0.0, epsilon = 1e-8);
    }

    // -- compute_distortion -------------------------------------------------

    #[test]
    fn compute_distortion_nonnegative() {
        let centroids = vec![-0.5, 0.0, 0.5];
        let boundaries = vec![-0.25, 0.25];
        let d = compute_distortion(&centroids, &boundaries, TEST_DIM);
        assert!(d >= 0.0);
    }

    // -- update_centroids ---------------------------------------------------

    #[test]
    fn update_centroids_correct_count() {
        let boundaries = midpoint_boundaries(&initialize_centroids(TEST_CENTROIDS_8));
        let updated = update_centroids(TEST_CENTROIDS_8, &boundaries, TEST_DIM);
        assert_eq!(updated.len(), TEST_CENTROIDS_8);
    }

    #[test]
    fn update_centroids_within_support() {
        let boundaries = midpoint_boundaries(&initialize_centroids(TEST_CENTROIDS_8));
        let updated = update_centroids(TEST_CENTROIDS_8, &boundaries, TEST_DIM);
        for &c in &updated {
            assert!((SUPPORT_MIN..=SUPPORT_MAX).contains(&c));
        }
    }

    // -- generate_codebook --------------------------------------------------

    #[test]
    fn generate_codebook_valid_structure() {
        let cb = generate_codebook(TEST_BITS_3, TEST_DIM_64);
        assert_eq!(cb.centroids.len(), TEST_CENTROIDS_8);
        assert_eq!(cb.boundaries.len(), TEST_CENTROIDS_8 - 1);
        for w in cb.centroids.windows(2) {
            assert!(w[0] < w[1]);
        }
    }

    // -- lloyd_max_step -----------------------------------------------------

    #[test]
    fn lloyd_max_step_reduces_distortion() {
        let centroids = initialize_centroids(TEST_CENTROIDS_8);
        let boundaries = midpoint_boundaries(&centroids);
        let initial_dist = compute_distortion(&centroids, &boundaries, TEST_DIM);
        let (new_centroids, new_dist, _) = lloyd_max_step(&centroids, f64::MAX, TEST_DIM);
        // The new distortion should be <= initial (Lloyd-Max is monotonically improving).
        assert!(new_dist <= initial_dist + 1e-15);
        assert_eq!(new_centroids.len(), TEST_CENTROIDS_8);
    }
}
470}