1use crate::QuantError;
15use crate::mse::{QuantizedVector, TurboQuantMse};
16use crate::qjl::{QjlResult, QjlTransform};
17
18pub struct TurboQuantProd {
20 mse: TurboQuantMse,
21 qjl: QjlTransform,
22 bits: u8,
24}
25
26#[derive(Debug, Clone)]
28pub struct QuantizedProdVector {
29 pub mse_part: QuantizedVector,
31 pub qjl_part: QjlResult,
33 pub residual_norm: f32,
35}
36
37impl TurboQuantProd {
38 pub fn new(
45 dimension: usize,
46 bits: u8,
47 mse_seed: u64,
48 qjl_seed: u64,
49 ) -> Result<Self, QuantError> {
50 if bits < 2 {
51 return Err(QuantError::UnsupportedBitWidth(bits));
52 }
53
54 let mse = TurboQuantMse::new(dimension, bits - 1, mse_seed)?;
55 let qjl = QjlTransform::new(dimension, qjl_seed);
56
57 Ok(Self { mse, qjl, bits })
58 }
59
60 pub fn dimension(&self) -> usize {
62 self.mse.dimension()
63 }
64
65 pub fn bits(&self) -> u8 {
67 self.bits
68 }
69
70 pub fn quantize(&self, x: &[f32]) -> Result<QuantizedProdVector, QuantError> {
72 let mse_part = self.mse.quantize(x)?;
74
75 let x_mse = self.mse.dequantize(&mse_part)?;
77
78 let residual: Vec<f32> = x.iter().zip(x_mse.iter()).map(|(a, b)| a - b).collect();
80
81 let residual_norm: f32 = residual.iter().map(|v| v * v).sum::<f32>().sqrt();
83
84 let qjl_part = self.qjl.quantize(&residual);
86
87 Ok(QuantizedProdVector {
88 mse_part,
89 qjl_part,
90 residual_norm,
91 })
92 }
93
94 pub fn dequantize(&self, q: &QuantizedProdVector) -> Result<Vec<f32>, QuantError> {
98 let x_mse = self.mse.dequantize(&q.mse_part)?;
100
101 let x_qjl = self.qjl.dequantize(&q.qjl_part, q.residual_norm);
103
104 let result: Vec<f32> = x_mse.iter().zip(x_qjl.iter()).map(|(a, b)| a + b).collect();
106
107 Ok(result)
108 }
109
110 pub fn inner_product_estimate(
114 &self,
115 query: &[f32],
116 q: &QuantizedProdVector,
117 ) -> Result<f32, QuantError> {
118 let x_mse = self.mse.dequantize(&q.mse_part)?;
120 let ip_mse: f32 = query.iter().zip(x_mse.iter()).map(|(a, b)| a * b).sum();
121
122 let ip_qjl = self
124 .qjl
125 .inner_product_estimate(query, &q.qjl_part, q.residual_norm);
126
127 Ok(ip_mse + ip_qjl)
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws `dim` i.i.d. standard-normal samples from an RNG seeded with `seed`,
    /// then scales them to unit L2 norm, so every test vector is reproducible.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        use rand::rngs::StdRng;
        use rand_distr::{Distribution, StandardNormal};

        let mut rng = StdRng::seed_from_u64(seed);
        let samples: Vec<f32> = (0..dim).map(|_| StandardNormal.sample(&mut rng)).collect();
        let norm = samples.iter().map(|s| s * s).sum::<f32>().sqrt();
        samples.into_iter().map(|s| s / norm).collect()
    }

    /// Plain dot product over the overlapping prefix of `a` and `b`.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(p, q)| p * q).sum()
    }

    #[test]
    fn prod_requires_minimum_2_bits() {
        // One bit would leave nothing for the MSE stage; two is the smallest valid budget.
        assert!(TurboQuantProd::new(32, 1, 1, 2).is_err());
        assert!(TurboQuantProd::new(32, 2, 1, 2).is_ok());
    }

    #[test]
    fn prod_quantize_dequantize() {
        const DIM: usize = 128;
        let quantizer = TurboQuantProd::new(DIM, 3, 42, 99).unwrap();
        let original = random_unit_vector(DIM, 7);

        let encoded = quantizer.quantize(&original).unwrap();
        let decoded = quantizer.dequantize(&encoded).unwrap();
        assert_eq!(decoded.len(), DIM);

        // Total squared reconstruction error (not divided by DIM).
        let mse = original
            .iter()
            .zip(&decoded)
            .map(|(a, b)| {
                let diff = a - b;
                diff * diff
            })
            .sum::<f32>();
        assert!(mse < 1.0, "MSE too high: {mse}");
    }

    #[test]
    fn unbiased_inner_product() {
        let dim = 128;
        let x = random_unit_vector(dim, 1);
        let y = random_unit_vector(dim, 2);
        let true_ip = dot(&x, &y);

        // Keep the MSE seed fixed and vary only the QJL seed. Averaging over QJL draws
        // should land close to the true inner product if the estimator is unbiased.
        let n_trials: u64 = 100;
        let total_estimated = (0..n_trials).fold(0.0f32, |acc, trial| {
            let quant = TurboQuantProd::new(dim, 3, 42, trial + 100).unwrap();
            let x_hat = quant.dequantize(&quant.quantize(&x).unwrap()).unwrap();
            acc + dot(&y, &x_hat)
        });

        let avg_estimated = total_estimated / n_trials as f32;
        let bias = (avg_estimated - true_ip).abs();
        assert!(
            bias < 0.1,
            "prod bias too large: avg={avg_estimated}, true={true_ip}, bias={bias}"
        );
    }

    #[test]
    fn inner_product_estimate_close() {
        let dim = 64;
        let quant = TurboQuantProd::new(dim, 3, 42, 99).unwrap();
        let x = random_unit_vector(dim, 1);
        let query = random_unit_vector(dim, 2);
        let q = quant.quantize(&x).unwrap();

        // The direct estimator must match (to 1e-3) an explicit dot product against
        // the fully dequantized vector.
        let explicit = dot(&query, &quant.dequantize(&q).unwrap());
        let estimated = quant.inner_product_estimate(&query, &q).unwrap();

        assert!(
            (explicit - estimated).abs() < 1e-3,
            "explicit={explicit}, estimated={estimated}"
        );
    }

    #[test]
    fn residual_norm_positive() {
        let dim = 64;
        let quant = TurboQuantProd::new(dim, 3, 42, 99).unwrap();
        let q = quant.quantize(&random_unit_vector(dim, 1)).unwrap();

        // A unit vector quantized with a finite budget should leave a non-zero
        // residual that is still shorter than the vector itself.
        assert!(q.residual_norm > 0.0);
        assert!(q.residual_norm < 1.0);
    }
}