Skip to main content

mnemonist_quant/
codebook.rs

//! Precomputed Lloyd-Max optimal scalar quantizer codebooks.
//!
//! For high-dimensional vectors (d ≥ 64), after random rotation each coordinate
//! follows a Beta(d/2, d/2) distribution that closely approximates N(0, 1/d).
//! The codebooks split that coordinate's range into 2^b buckets minimizing MSE.
//! The tables in this file hold the Gaussian-approximation values in
//! unit-variance (N(0, 1)) units, so they extend beyond [-1, 1].
//!
//! These are computed offline by solving the continuous 1-D k-means problem
//! (Eq. 4 in the TurboQuant paper) using the Lloyd-Max algorithm.
9
10use crate::QuantError;
11
/// A scalar quantizer codebook for a given bit-width.
///
/// Shared instances of the precomputed tables are handed out by
/// [`Codebook::for_bits`].
#[derive(Debug, Clone)]
pub struct Codebook {
    /// Bit-width `b` of the quantizer; [`Codebook::for_bits`] accepts 1 through 4.
    pub bits: u8,
    /// Reconstruction centroids in ascending order (2^b values).
    pub centroids: &'static [f32],
    /// Decision boundaries in ascending order (2^b - 1 midpoints between
    /// consecutive centroids). A value exactly equal to a boundary is assigned
    /// to the lower of the two adjacent buckets by [`Codebook::quantize_scalar`].
    pub boundaries: &'static [f32],
}
22
23impl Codebook {
24    /// Get the codebook for a given bit-width.
25    pub fn for_bits(bits: u8) -> Result<&'static Codebook, QuantError> {
26        match bits {
27            1 => Ok(&CODEBOOK_1BIT),
28            2 => Ok(&CODEBOOK_2BIT),
29            3 => Ok(&CODEBOOK_3BIT),
30            4 => Ok(&CODEBOOK_4BIT),
31            _ => Err(QuantError::UnsupportedBitWidth(bits)),
32        }
33    }
34
35    /// Find the index of the nearest centroid for a scalar value.
36    #[inline]
37    pub fn quantize_scalar(&self, x: f32) -> u8 {
38        // Binary search on boundaries (they are sorted ascending).
39        let mut idx = 0u8;
40        for &b in self.boundaries {
41            if x > b {
42                idx += 1;
43            } else {
44                break;
45            }
46        }
47        idx
48    }
49
50    /// Look up the centroid for a given index.
51    #[inline]
52    pub fn dequantize_scalar(&self, idx: u8) -> f32 {
53        self.centroids[idx as usize]
54    }
55}
56
// ─── Precomputed codebooks ──────────────────────────────────────────────────
//
// These are optimal Lloyd-Max centroids for the Beta distribution that arises
// from randomly rotating unit-sphere vectors. In high dimensions, Beta(d/2, d/2)
// on [-1, 1] converges to N(0, 1/d). For moderate d, the paper uses numerical
// optimization. The values below are for the Gaussian approximation and are
// stored in unit-variance (N(0, 1)) units; they are scaled by 1/sqrt(d) at
// runtime. Since the rotation normalizes vectors to the unit sphere, raw
// coordinates land in roughly [-3/sqrt(d), 3/sqrt(d)], which corresponds to
// roughly [-3, 3] in the unit-variance domain of the tables.
//
// For b=1: optimal centroids are ±√(2/π) ≈ ±0.7979 (Gaussian quantizer)
// For b=2: ±0.4528, ±1.5104 (standard 2-bit Gaussian Lloyd-Max)
// For b=3,4: standard Lloyd-Max centroids for N(0,1), likewise not rescaled
//
// We use the paper's high-d Gaussian approximation centroids, which give
// near-optimal MSE for d ≥ 64.
73
// In every table below the centroids are symmetric about zero and listed in
// ascending order, and each boundary is the arithmetic mean of the two
// centroids on either side of it.

// 1-bit: 2 centroids, 1 boundary.
static CENTROIDS_1BIT: [f32; 2] = [-0.7979, 0.7979];
static BOUNDARIES_1BIT: [f32; 1] = [0.0];

// 2-bit: 4 centroids, 3 boundaries.
static CENTROIDS_2BIT: [f32; 4] = [-1.5104, -0.4528, 0.4528, 1.5104];
static BOUNDARIES_2BIT: [f32; 3] = [-0.9816, 0.0, 0.9816];

// 3-bit: 8 centroids, 7 boundaries.
static CENTROIDS_3BIT: [f32; 8] = [
    -2.1520, -1.3440, -0.7560, -0.2450, 0.2450, 0.7560, 1.3440, 2.1520,
];
static BOUNDARIES_3BIT: [f32; 7] = [-1.7480, -1.0500, -0.5005, 0.0, 0.5005, 1.0500, 1.7480];

// 4-bit: 16 centroids, 15 boundaries.
static CENTROIDS_4BIT: [f32; 16] = [
    -2.7326, -2.0690, -1.6180, -1.2562, -0.9424, -0.6568, -0.3880, -0.1284, 0.1284, 0.3880, 0.6568,
    0.9424, 1.2562, 1.6180, 2.0690, 2.7326,
];
static BOUNDARIES_4BIT: [f32; 15] = [
    -2.4008, -1.8435, -1.4371, -1.0993, -0.7996, -0.5224, -0.2582, 0.0, 0.2582, 0.5224, 0.7996,
    1.0993, 1.4371, 1.8435, 2.4008,
];
93
/// 1-bit codebook (2 centroids); returned by `Codebook::for_bits(1)`.
static CODEBOOK_1BIT: Codebook = Codebook {
    bits: 1,
    centroids: &CENTROIDS_1BIT,
    boundaries: &BOUNDARIES_1BIT,
};

/// 2-bit codebook (4 centroids); returned by `Codebook::for_bits(2)`.
static CODEBOOK_2BIT: Codebook = Codebook {
    bits: 2,
    centroids: &CENTROIDS_2BIT,
    boundaries: &BOUNDARIES_2BIT,
};

/// 3-bit codebook (8 centroids); returned by `Codebook::for_bits(3)`.
static CODEBOOK_3BIT: Codebook = Codebook {
    bits: 3,
    centroids: &CENTROIDS_3BIT,
    boundaries: &BOUNDARIES_3BIT,
};

/// 4-bit codebook (16 centroids); returned by `Codebook::for_bits(4)`.
static CODEBOOK_4BIT: Codebook = Codebook {
    bits: 4,
    centroids: &CENTROIDS_4BIT,
    boundaries: &BOUNDARIES_4BIT,
};
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Lazily yields `(bits, codebook)` for bit-widths 1 through 4, panicking
    /// if a lookup fails.
    fn supported() -> impl Iterator<Item = (u8, &'static Codebook)> {
        (1..=4u8).map(|bits| (bits, Codebook::for_bits(bits).unwrap()))
    }

    /// Each codebook has 2^b centroids and one fewer boundary.
    #[test]
    fn codebook_lookup_valid() {
        for (bits, cb) in supported() {
            let levels = 1usize << bits;
            assert_eq!(cb.centroids.len(), levels);
            assert_eq!(cb.boundaries.len(), levels - 1);
        }
    }

    /// Bit-widths just outside the supported range are rejected.
    #[test]
    fn codebook_invalid_bits() {
        let below = Codebook::for_bits(0);
        assert!(below.is_err());
        let above = Codebook::for_bits(5);
        assert!(above.is_err());
    }

    /// The single 1-bit boundary splits negatives from positives.
    #[test]
    fn quantize_scalar_1bit() {
        let cb = Codebook::for_bits(1).unwrap();
        // (input, expected bucket); the last case sits on the boundary → lower bucket.
        let cases: [(f32, u8); 3] = [(-0.5, 0), (0.5, 1), (0.0, 0)];
        for &(input, bucket) in &cases {
            assert_eq!(cb.quantize_scalar(input), bucket);
        }
    }

    /// One probe per 2-bit bucket, including values beyond the outer centroids.
    #[test]
    fn quantize_scalar_2bit() {
        let cb = Codebook::for_bits(2).unwrap();
        let cases: [(f32, u8); 4] = [(-2.0, 0), (-0.5, 1), (0.5, 2), (2.0, 3)];
        for &(input, bucket) in &cases {
            assert_eq!(cb.quantize_scalar(input), bucket);
        }
    }

    /// Centroids and boundaries are strictly increasing in every codebook.
    #[test]
    fn centroids_are_sorted() {
        for (bits, cb) in supported() {
            let centroids_ascending = cb.centroids.windows(2).all(|pair| pair[0] < pair[1]);
            assert!(centroids_ascending, "centroids not sorted for {bits}-bit codebook");

            let boundaries_ascending = cb.boundaries.windows(2).all(|pair| pair[0] < pair[1]);
            assert!(boundaries_ascending, "boundaries not sorted for {bits}-bit codebook");
        }
    }

    /// Boundary `i` lies strictly between centroid `i` and centroid `i + 1`.
    #[test]
    fn boundaries_between_centroids() {
        for (bits, cb) in supported() {
            for (i, &b) in cb.boundaries.iter().enumerate() {
                let lower = cb.centroids[i];
                let upper = cb.centroids[i + 1];
                assert!(
                    lower < b && b < upper,
                    "boundary {b} not between centroids {} and {} for {bits}-bit",
                    lower,
                    upper
                );
            }
        }
    }

    /// Quantizing a 2-bit centroid returns the index it was looked up with.
    #[test]
    fn dequantize_roundtrip() {
        let cb = Codebook::for_bits(2).unwrap();
        (0..4u8).for_each(|i| {
            let val = cb.dequantize_scalar(i);
            let idx = cb.quantize_scalar(val);
            assert_eq!(idx, i, "dequantize({i}) = {val}, quantize({val}) = {idx}");
        });
    }
}