1use crate::QuantError;
11use crate::codebook::Codebook;
12use crate::pack;
13use crate::rotation::Rotation;
14
15pub struct TurboQuantMse {
17 rotation: Rotation,
18 codebook: &'static Codebook,
19 bits: u8,
20 scale: f32,
24}
25
/// Compact encoding of one vector produced by [`TurboQuantMse::quantize`].
///
/// Stores the bit-packed codebook indices of the normalized, rotated input
/// together with the metadata needed by [`TurboQuantMse::dequantize`].
///
/// `PartialEq` compares every field. Because `norm` is an `f32`, a value
/// whose norm is NaN never compares equal to anything, itself included.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizedVector {
    /// Codebook indices packed by `pack::pack_indices` at `bits` bits per index.
    pub packed_indices: Vec<u8>,
    /// L2 norm of the original input; the reconstruction is multiplied by it.
    pub norm: f32,
    /// Bit width each index was packed with.
    pub bits: u8,
    /// Number of coordinates in the original vector.
    pub dimension: usize,
}
38
39impl TurboQuantMse {
40 pub fn new(dimension: usize, bits: u8, seed: u64) -> Result<Self, QuantError> {
46 let codebook = Codebook::for_bits(bits)?;
47 let rotation = Rotation::new(dimension, seed);
48 let scale = (dimension as f32).sqrt();
49
50 Ok(Self {
51 rotation,
52 codebook,
53 bits,
54 scale,
55 })
56 }
57
58 pub fn dimension(&self) -> usize {
60 self.rotation.dimension()
61 }
62
63 pub fn bits(&self) -> u8 {
65 self.bits
66 }
67
68 pub fn seed(&self) -> u64 {
70 self.rotation.seed()
71 }
72
73 pub fn quantize(&self, x: &[f32]) -> Result<QuantizedVector, QuantError> {
75 let dim = self.rotation.dimension();
76 if x.len() != dim {
77 return Err(QuantError::DimensionMismatch {
78 expected: dim,
79 got: x.len(),
80 });
81 }
82
83 let norm: f32 = x.iter().map(|v| v * v).sum::<f32>().sqrt();
85 let mut y = if norm > 0.0 {
86 x.iter().map(|v| v / norm).collect::<Vec<_>>()
87 } else {
88 vec![0.0; dim]
89 };
90
91 self.rotation.forward(&mut y);
93
94 for val in &mut y {
96 *val *= self.scale;
97 }
98
99 let indices: Vec<u8> = y
101 .iter()
102 .map(|&v| self.codebook.quantize_scalar(v))
103 .collect();
104
105 let packed_indices = pack::pack_indices(&indices, self.bits)?;
106
107 Ok(QuantizedVector {
108 packed_indices,
109 norm,
110 bits: self.bits,
111 dimension: dim,
112 })
113 }
114
115 pub fn dequantize(&self, q: &QuantizedVector) -> Result<Vec<f32>, QuantError> {
117 let dim = q.dimension;
118 let indices = pack::unpack_indices(&q.packed_indices, q.bits, dim)?;
119
120 let mut y: Vec<f32> = indices
122 .iter()
123 .map(|&idx| self.codebook.dequantize_scalar(idx))
124 .collect();
125
126 let inv_scale = 1.0 / self.scale;
128 for val in &mut y {
129 *val *= inv_scale;
130 }
131
132 self.rotation.inverse(&mut y);
134
135 for val in &mut y {
137 *val *= q.norm;
138 }
139
140 Ok(y)
141 }
142
143 pub fn dequantize_into(&self, q: &QuantizedVector, out: &mut [f32]) -> Result<(), QuantError> {
145 let result = self.dequantize(q)?;
146 out.copy_from_slice(&result);
147 Ok(())
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|c| c * c).sum::<f32>().sqrt()
    }

    /// Total squared reconstruction error `||lhs - rhs||^2`.
    fn squared_error(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs)
            .map(|(l, r)| (l - r) * (l - r))
            .sum::<f32>()
    }

    /// Quantizes and immediately dequantizes `x`, panicking on any error.
    fn roundtrip(quant: &TurboQuantMse, x: &[f32]) -> Vec<f32> {
        let encoded = quant.quantize(x).unwrap();
        quant.dequantize(&encoded).unwrap()
    }

    /// Samples a seeded Gaussian vector and rescales it to unit length.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f32> {
        use rand::SeedableRng;
        use rand::rngs::StdRng;
        use rand_distr::{Distribution, StandardNormal};

        let mut rng = StdRng::seed_from_u64(seed);
        let mut sample: Vec<f32> = (0..dim).map(|_| StandardNormal.sample(&mut rng)).collect();
        let length = l2_norm(&sample);
        sample.iter_mut().for_each(|c| *c /= length);
        sample
    }

    #[test]
    fn quantize_dequantize_roundtrip() {
        let dim = 128;
        let quant = TurboQuantMse::new(dim, 2, 42).unwrap();
        let x = random_unit_vector(dim, 7);

        let x_hat = roundtrip(&quant, &x);
        assert_eq!(x_hat.len(), dim);

        let mse = squared_error(&x, &x_hat);
        assert!(mse < 0.5, "MSE too high: {mse} (expected < 0.5 for 2-bit)");
    }

    #[test]
    fn mse_decreases_with_bits() {
        let dim = 256;
        let x = random_unit_vector(dim, 13);

        let mut prev_mse = f32::MAX;
        for bits in 1..=4 {
            let quant = TurboQuantMse::new(dim, bits, 42).unwrap();
            let mse = squared_error(&x, &roundtrip(&quant, &x));
            assert!(
                mse < prev_mse,
                "{bits}-bit MSE ({mse}) not less than {}-bit ({prev_mse})",
                bits - 1
            );
            prev_mse = mse;
        }
    }

    #[test]
    fn preserves_norm() {
        let dim = 64;
        let quant = TurboQuantMse::new(dim, 3, 42).unwrap();
        let x: Vec<f32> = (0..dim).map(|i| (i as f32 + 1.0) * 0.1).collect();

        let norm_orig = l2_norm(&x);
        let norm_hat = l2_norm(&roundtrip(&quant, &x));

        assert!(
            (norm_orig - norm_hat).abs() / norm_orig < 0.3,
            "norm diverged: {norm_orig} → {norm_hat}"
        );
    }

    #[test]
    fn zero_vector() {
        let dim = 32;
        let quant = TurboQuantMse::new(dim, 2, 42).unwrap();
        let zeros = vec![0.0f32; dim];

        let encoded = quant.quantize(&zeros).unwrap();
        assert_eq!(encoded.norm, 0.0);

        let decoded = quant.dequantize(&encoded).unwrap();
        decoded.iter().for_each(|&v| assert_eq!(v, 0.0));
    }

    #[test]
    fn dimension_mismatch() {
        let quant = TurboQuantMse::new(32, 2, 42).unwrap();
        assert!(quant.quantize(&[1.0f32; 64]).is_err());
    }

    #[test]
    fn average_mse_matches_theory() {
        let dim = 256;
        let bits = 2;
        let quant = TurboQuantMse::new(dim, bits, 42).unwrap();
        let n_trials: u64 = 100;

        let total_mse: f32 = (0..n_trials)
            .map(|seed| {
                let x = random_unit_vector(dim, seed + 1000);
                squared_error(&x, &roundtrip(&quant, &x))
            })
            .sum();

        let avg_mse = total_mse / n_trials as f32;
        assert!(
            avg_mse < 0.35,
            "average MSE = {avg_mse}, expected < 0.35 for 2-bit"
        );
    }
}