Skip to main content

mnemonist_quant/
prod.rs

//! TurboQuant_prod: unbiased inner-product quantizer (Algorithm 2).
//!
//! Combines TurboQuant_mse (at b-1 bits) with a 1-bit QJL transform applied to
//! the MSE residual, producing an unbiased inner-product estimator at b total bits.
//!
//! Properties (Theorem 2):
//! - Unbiased: E[⟨y, x̃⟩] = ⟨y, x⟩
//! - Inner-product distortion: D_prod ≤ (√3·π²·‖y‖²/d) · 1/4^b
//!
//! The total bit budget is b bits per coordinate:
//! - b-1 bits for the MSE component
//! - 1 bit for the QJL residual sign
//!
//! In addition, each quantized vector stores the residual's L2 norm γ as one
//! `f32` per vector; this scalar is not counted in the per-coordinate budget.

14use crate::QuantError;
15use crate::mse::{QuantizedVector, TurboQuantMse};
16use crate::qjl::{QjlResult, QjlTransform};
17
18/// Unbiased inner-product TurboQuant quantizer.
19pub struct TurboQuantProd {
20    mse: TurboQuantMse,
21    qjl: QjlTransform,
22    /// Total bit-width (mse uses bits-1, qjl uses 1).
23    bits: u8,
24}
25
26/// A quantized vector produced by TurboQuant_prod.
27#[derive(Debug, Clone)]
28pub struct QuantizedProdVector {
29    /// The MSE-quantized component (b-1 bits per coordinate).
30    pub mse_part: QuantizedVector,
31    /// QJL sign bits of the residual (1 bit per coordinate).
32    pub qjl_part: QjlResult,
33    /// L2 norm of the residual vector (γ in the paper).
34    pub residual_norm: f32,
35}
36
37impl TurboQuantProd {
38    /// Create a new inner-product quantizer.
39    ///
40    /// - `dimension`: vector dimensionality
41    /// - `bits`: total bit-width per coordinate (must be ≥ 2)
42    /// - `mse_seed`: seed for the MSE rotation matrix
43    /// - `qjl_seed`: seed for the QJL projection (must differ from mse_seed)
44    pub fn new(
45        dimension: usize,
46        bits: u8,
47        mse_seed: u64,
48        qjl_seed: u64,
49    ) -> Result<Self, QuantError> {
50        if bits < 2 {
51            return Err(QuantError::UnsupportedBitWidth(bits));
52        }
53
54        let mse = TurboQuantMse::new(dimension, bits - 1, mse_seed)?;
55        let qjl = QjlTransform::new(dimension, qjl_seed);
56
57        Ok(Self { mse, qjl, bits })
58    }
59
60    /// The dimension this quantizer operates on.
61    pub fn dimension(&self) -> usize {
62        self.mse.dimension()
63    }
64
65    /// The total bit-width per coordinate.
66    pub fn bits(&self) -> u8 {
67        self.bits
68    }
69
70    /// Quantize a vector.
71    pub fn quantize(&self, x: &[f32]) -> Result<QuantizedProdVector, QuantError> {
72        // Step 1: MSE quantize at b-1 bits
73        let mse_part = self.mse.quantize(x)?;
74
75        // Step 2: Dequantize to get MSE approximation
76        let x_mse = self.mse.dequantize(&mse_part)?;
77
78        // Step 3: Compute residual r = x - x̃_mse
79        let residual: Vec<f32> = x.iter().zip(x_mse.iter()).map(|(a, b)| a - b).collect();
80
81        // Step 4: Compute residual norm γ = ||r||
82        let residual_norm: f32 = residual.iter().map(|v| v * v).sum::<f32>().sqrt();
83
84        // Step 5: Apply QJL to residual
85        let qjl_part = self.qjl.quantize(&residual);
86
87        Ok(QuantizedProdVector {
88            mse_part,
89            qjl_part,
90            residual_norm,
91        })
92    }
93
94    /// Dequantize a vector.
95    ///
96    /// Returns x̃ = x̃_mse + x̃_qjl where x̃_qjl = (√(π/2)/d) · γ · S^T · qjl
97    pub fn dequantize(&self, q: &QuantizedProdVector) -> Result<Vec<f32>, QuantError> {
98        // Dequantize MSE component
99        let x_mse = self.mse.dequantize(&q.mse_part)?;
100
101        // Dequantize QJL component
102        let x_qjl = self.qjl.dequantize(&q.qjl_part, q.residual_norm);
103
104        // Sum
105        let result: Vec<f32> = x_mse.iter().zip(x_qjl.iter()).map(|(a, b)| a + b).collect();
106
107        Ok(result)
108    }
109
110    /// Estimate ⟨query, quantized_x⟩ without full dequantization.
111    ///
112    /// Computes: ⟨query, x̃_mse⟩ + QJL_estimate(query, qjl_bits, γ)
113    pub fn inner_product_estimate(
114        &self,
115        query: &[f32],
116        q: &QuantizedProdVector,
117    ) -> Result<f32, QuantError> {
118        // MSE component inner product (requires dequantization)
119        let x_mse = self.mse.dequantize(&q.mse_part)?;
120        let ip_mse: f32 = query.iter().zip(x_mse.iter()).map(|(a, b)| a * b).sum();
121
122        // QJL component inner product (fast estimate)
123        let ip_qjl = self
124            .qjl
125            .inner_product_estimate(query, &q.qjl_part, q.residual_norm);
126
127        Ok(ip_mse + ip_qjl)
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Draw a standard-normal vector from a seeded RNG and scale it to unit length.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        use rand::rngs::StdRng;
        use rand_distr::{Distribution, StandardNormal};

        let mut rng = StdRng::seed_from_u64(seed);
        let mut v: Vec<f32> = (0..dim).map(|_| StandardNormal.sample(&mut rng)).collect();
        let norm = v.iter().map(|c| c * c).sum::<f32>().sqrt();
        v.iter_mut().for_each(|c| *c /= norm);
        v
    }

    /// Plain dot product of two equal-length slices.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(p, q)| p * q).sum()
    }

    #[test]
    fn prod_requires_minimum_2_bits() {
        let one_bit = TurboQuantProd::new(32, 1, 1, 2);
        assert!(one_bit.is_err());

        let two_bits = TurboQuantProd::new(32, 2, 1, 2);
        assert!(two_bits.is_ok());
    }

    #[test]
    fn prod_quantize_dequantize() {
        let dim = 128;
        let quant = TurboQuantProd::new(dim, 3, 42, 99).unwrap();
        let x = random_unit_vector(dim, 7);

        let x_hat = quant.dequantize(&quant.quantize(&x).unwrap()).unwrap();
        assert_eq!(x_hat.len(), dim);

        // Total squared reconstruction error of a unit vector should stay below 1.
        let mse: f32 = x
            .iter()
            .zip(&x_hat)
            .map(|(a, b)| (a - b) * (a - b))
            .sum::<f32>();
        assert!(mse < 1.0, "MSE too high: {mse}");
    }

    #[test]
    fn unbiased_inner_product() {
        // Averaged over many QJL seeds (MSE seed fixed), ⟨y, x̃⟩ should approach ⟨y, x⟩.
        let dim = 128;
        let x = random_unit_vector(dim, 1);
        let y = random_unit_vector(dim, 2);
        let true_ip = dot(&x, &y);

        let n_trials = 100;
        let total_estimated = (0..n_trials).fold(0.0f32, |acc, trial| {
            let quant = TurboQuantProd::new(dim, 3, 42, trial as u64 + 100).unwrap();
            let x_hat = quant.dequantize(&quant.quantize(&x).unwrap()).unwrap();
            acc + dot(&y, &x_hat)
        });

        let avg_estimated = total_estimated / n_trials as f32;
        let bias = (avg_estimated - true_ip).abs();
        assert!(
            bias < 0.1,
            "prod bias too large: avg={avg_estimated}, true={true_ip}, bias={bias}"
        );
    }

    #[test]
    fn inner_product_estimate_close() {
        let dim = 64;
        let quant = TurboQuantProd::new(dim, 3, 42, 99).unwrap();
        let x = random_unit_vector(dim, 1);
        let query = random_unit_vector(dim, 2);
        let q = quant.quantize(&x).unwrap();

        // Reference value: materialize x̃ and take the dot product directly.
        let explicit = dot(&query, &quant.dequantize(&q).unwrap());
        // Value under test: the dedicated estimator.
        let estimated = quant.inner_product_estimate(&query, &q).unwrap();

        assert!(
            (explicit - estimated).abs() < 1e-3,
            "explicit={explicit}, estimated={estimated}"
        );
    }

    #[test]
    fn residual_norm_positive() {
        let dim = 64;
        let quant = TurboQuantProd::new(dim, 3, 42, 99).unwrap();
        let x = random_unit_vector(dim, 1);
        let q = quant.quantize(&x).unwrap();

        // A 2-bit MSE stage cannot reproduce a unit vector exactly, but what it
        // misses must be shorter than the vector itself.
        assert!(q.residual_norm > 0.0);
        assert!(q.residual_norm < 1.0);
    }
}