Skip to main content

mnemonist_quant/
mse.rs

//! TurboQuant_mse: MSE-optimal vector quantizer (Algorithm 1).
//!
//! Quantizes d-dimensional vectors to b bits per coordinate by:
//! 1. Normalizing to unit norm (storing the original norm separately)
//! 2. Applying a random orthogonal rotation Π
//! 3. Scaling by √d and scalar-quantizing each coordinate with a precomputed
//!    Lloyd-Max codebook for N(0,1)
//!
//! Achieves MSE distortion D_mse ≤ (√3·π/2) · 4^(−b) for unit-norm inputs.
9
10use crate::QuantError;
11use crate::codebook::Codebook;
12use crate::pack;
13use crate::rotation::Rotation;
14
15/// MSE-optimal TurboQuant quantizer.
16pub struct TurboQuantMse {
17    rotation: Rotation,
18    codebook: &'static Codebook,
19    bits: u8,
20    /// Scaling factor: codebook centroids are for N(0,1); after rotation,
21    /// unit-sphere coordinates have variance ≈ 1/d. We scale coordinates
22    /// by sqrt(d) before quantization so the codebook applies directly.
23    scale: f32,
24}
25
/// A quantized vector produced by TurboQuant_mse.
///
/// Holds everything needed to decode the vector again, apart from the
/// rotation seed, which lives in the quantizer. Two values compare equal
/// when all four fields are equal; because `norm` is an `f32`, a value whose
/// norm is NaN is not equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizedVector {
    /// Packed b-bit indices, one per coordinate.
    pub packed_indices: Vec<u8>,
    /// Original vector norm (for rescaling on dequantization).
    pub norm: f32,
    /// Bit-width used.
    pub bits: u8,
    /// Number of coordinates (dimension).
    pub dimension: usize,
}
38
39impl TurboQuantMse {
40    /// Create a new MSE-optimal quantizer.
41    ///
42    /// - `dimension`: vector dimensionality
43    /// - `bits`: quantization bit-width (1-4)
44    /// - `seed`: RNG seed for the rotation matrix (must match for quant/dequant)
45    pub fn new(dimension: usize, bits: u8, seed: u64) -> Result<Self, QuantError> {
46        let codebook = Codebook::for_bits(bits)?;
47        let rotation = Rotation::new(dimension, seed);
48        let scale = (dimension as f32).sqrt();
49
50        Ok(Self {
51            rotation,
52            codebook,
53            bits,
54            scale,
55        })
56    }
57
58    /// The dimension this quantizer operates on.
59    pub fn dimension(&self) -> usize {
60        self.rotation.dimension()
61    }
62
63    /// The bit-width per coordinate.
64    pub fn bits(&self) -> u8 {
65        self.bits
66    }
67
68    /// The rotation seed.
69    pub fn seed(&self) -> u64 {
70        self.rotation.seed()
71    }
72
73    /// Quantize a vector.
74    pub fn quantize(&self, x: &[f32]) -> Result<QuantizedVector, QuantError> {
75        let dim = self.rotation.dimension();
76        if x.len() != dim {
77            return Err(QuantError::DimensionMismatch {
78                expected: dim,
79                got: x.len(),
80            });
81        }
82
83        // Compute norm and normalize
84        let norm: f32 = x.iter().map(|v| v * v).sum::<f32>().sqrt();
85        let mut y = if norm > 0.0 {
86            x.iter().map(|v| v / norm).collect::<Vec<_>>()
87        } else {
88            vec![0.0; dim]
89        };
90
91        // Apply rotation: y = Π · x_normalized
92        self.rotation.forward(&mut y);
93
94        // Scale by sqrt(d) to match codebook domain
95        for val in &mut y {
96            *val *= self.scale;
97        }
98
99        // Scalar-quantize each coordinate
100        let indices: Vec<u8> = y
101            .iter()
102            .map(|&v| self.codebook.quantize_scalar(v))
103            .collect();
104
105        let packed_indices = pack::pack_indices(&indices, self.bits)?;
106
107        Ok(QuantizedVector {
108            packed_indices,
109            norm,
110            bits: self.bits,
111            dimension: dim,
112        })
113    }
114
115    /// Dequantize a vector back to approximate floats.
116    pub fn dequantize(&self, q: &QuantizedVector) -> Result<Vec<f32>, QuantError> {
117        let dim = q.dimension;
118        let indices = pack::unpack_indices(&q.packed_indices, q.bits, dim)?;
119
120        // Look up centroids
121        let mut y: Vec<f32> = indices
122            .iter()
123            .map(|&idx| self.codebook.dequantize_scalar(idx))
124            .collect();
125
126        // Unscale
127        let inv_scale = 1.0 / self.scale;
128        for val in &mut y {
129            *val *= inv_scale;
130        }
131
132        // Apply inverse rotation: x = Π^T · y
133        self.rotation.inverse(&mut y);
134
135        // Rescale by original norm
136        for val in &mut y {
137            *val *= q.norm;
138        }
139
140        Ok(y)
141    }
142
143    /// Dequantize into a pre-allocated buffer (avoids allocation).
144    pub fn dequantize_into(&self, q: &QuantizedVector, out: &mut [f32]) -> Result<(), QuantError> {
145        let result = self.dequantize(q)?;
146        out.copy_from_slice(&result);
147        Ok(())
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector of unit norm: i.i.d. standard
    /// normal samples divided by their Euclidean norm.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        use rand::rngs::StdRng;
        use rand_distr::{Distribution, StandardNormal};

        let mut rng = StdRng::seed_from_u64(seed);
        let normal = StandardNormal;
        let mut v: Vec<f32> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut v {
            *x /= norm;
        }
        v
    }

    /// Squared Euclidean distance between two vectors, i.e. the total
    /// squared reconstruction error.
    fn squared_error(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(p, q)| (p - q) * (p - q))
            .sum::<f32>()
    }

    #[test]
    fn quantize_dequantize_roundtrip() {
        let dim = 128;
        let quant = TurboQuantMse::new(dim, 2, 42).unwrap();

        let x = random_unit_vector(dim, 7);
        let q = quant.quantize(&x).unwrap();
        assert_eq!(q.bits, 2);
        assert_eq!(q.dimension, dim);

        let x_hat = quant.dequantize(&q).unwrap();
        assert_eq!(x_hat.len(), dim);

        let mse = squared_error(&x, &x_hat);

        // For 2-bit the theoretical bound is D_mse ≤ (√3·π/2) · 1/4^2 ≈ 0.170
        // in expectation over the rotation; since ||x|| = 1 the squared error
        // is directly comparable. A single sample can exceed the expectation,
        // so the threshold is deliberately loose.
        assert!(mse < 0.5, "MSE too high: {mse} (expected < 0.5 for 2-bit)");
    }

    #[test]
    fn mse_decreases_with_bits() {
        let dim = 256;
        let x = random_unit_vector(dim, 13);
        // (bit-width, error) of the previous iteration; `None` for the first.
        let mut previous: Option<(u8, f32)> = None;

        for bits in 1..=4u8 {
            let quant = TurboQuantMse::new(dim, bits, 42).unwrap();
            let q = quant.quantize(&x).unwrap();
            let x_hat = quant.dequantize(&q).unwrap();

            let mse = squared_error(&x, &x_hat);
            assert!(mse.is_finite(), "{bits}-bit MSE is not finite: {mse}");

            if let Some((prev_bits, prev_mse)) = previous {
                assert!(
                    mse < prev_mse,
                    "{bits}-bit MSE ({mse}) not less than {prev_bits}-bit ({prev_mse})"
                );
            }
            previous = Some((bits, mse));
        }
    }

    #[test]
    fn preserves_norm() {
        let dim = 64;
        let quant = TurboQuantMse::new(dim, 3, 42).unwrap();

        // Non-unit vector
        let x: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) * 0.1).collect();
        let norm_orig: f32 = x.iter().map(|v| v * v).sum::<f32>().sqrt();

        let q = quant.quantize(&x).unwrap();
        let x_hat = quant.dequantize(&q).unwrap();
        let norm_hat: f32 = x_hat.iter().map(|v| v * v).sum::<f32>().sqrt();

        // Norm should be approximately preserved (within quantization error)
        assert!(
            (norm_orig - norm_hat).abs() / norm_orig < 0.3,
            "norm diverged: {norm_orig} → {norm_hat}"
        );
    }

    #[test]
    fn zero_vector() {
        let dim = 32;
        let quant = TurboQuantMse::new(dim, 2, 42).unwrap();

        let x = vec![0.0f32; dim];
        let q = quant.quantize(&x).unwrap();
        assert_eq!(q.norm, 0.0);

        // Decoding multiplies by the stored norm, so every output is zero.
        let x_hat = quant.dequantize(&q).unwrap();
        for v in &x_hat {
            assert_eq!(*v, 0.0);
        }
    }

    #[test]
    fn dimension_mismatch() {
        let quant = TurboQuantMse::new(32, 2, 42).unwrap();
        let x = vec![1.0; 64];
        assert!(quant.quantize(&x).is_err());
    }

    #[test]
    fn dequantize_into_matches_dequantize() {
        let dim = 64;
        let quant = TurboQuantMse::new(dim, 3, 42).unwrap();
        let x = random_unit_vector(dim, 21);
        let q = quant.quantize(&x).unwrap();

        let expected = quant.dequantize(&q).unwrap();
        let mut out = vec![0.0f32; dim];
        quant.dequantize_into(&q, &mut out).unwrap();

        assert_eq!(out, expected);
    }

    #[test]
    fn average_mse_matches_theory() {
        // Average over many random vectors and compare the empirical error
        // with the value reported for b=2 (D_mse ≈ 0.117 on unit vectors).
        let dim = 256;
        let bits = 2;
        let quant = TurboQuantMse::new(dim, bits, 42).unwrap();
        let n_trials = 100;

        let total_mse: f32 = (0..n_trials)
            .map(|seed| {
                let x = random_unit_vector(dim, seed + 1000);
                let q = quant.quantize(&x).unwrap();
                let x_hat = quant.dequantize(&q).unwrap();
                squared_error(&x, &x_hat)
            })
            .sum();

        let avg_mse = total_mse / n_trials as f32;
        // Allow a generous margin over 0.117 for finite-d effects.
        assert!(
            avg_mse < 0.35,
            "average MSE = {avg_mse}, expected < 0.35 for 2-bit"
        );
    }
}