1use crate::DiskAnnError;
29use half::f16;
30use serde::{Deserialize, Serialize};
31use std::fs::File;
32use std::io::{BufReader, BufWriter};
33
/// Common interface for schemes that compress `f32` vectors into byte codes.
///
/// Implementors must be `Send + Sync`, so a quantizer can be shared between
/// threads (for example behind a `Box<dyn VectorQuantizer>`).
pub trait VectorQuantizer: Send + Sync {
    /// Compresses `vector` into its byte-code representation.
    fn encode(&self, vector: &[f32]) -> Vec<u8>;

    /// Reconstructs an `f32` vector from `codes` previously produced by
    /// [`encode`](Self::encode). The result is generally an approximation of
    /// the original input.
    fn decode(&self, codes: &[u8]) -> Vec<f32>;

    /// Computes a distance between an uncompressed `query` and an encoded
    /// vector `codes`, without requiring the caller to decode first.
    // NOTE(review): the implementations in this file delegate to `l2_*` SIMD
    // helpers; presumably this is squared L2 — confirm against `crate::simd`.
    fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32;

    /// Returns how many times smaller the encoded form is than raw `f32`
    /// storage for vectors of `dim` components.
    fn compression_ratio(&self, dim: usize) -> f32;
}
48
/// Scalar quantizer that stores every vector component as a 16-bit float
/// (two bytes per component).
///
/// It carries no trained parameters; its only state is the expected vector
/// dimensionality.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct F16Quantizer {
    /// Number of components every encoded/decoded vector must have.
    dim: usize,
}
61
62impl F16Quantizer {
63 pub fn new(dim: usize) -> Self {
65 Self { dim }
66 }
67
68 pub fn dim(&self) -> usize {
70 self.dim
71 }
72
73 pub fn save(&self, path: &str) -> Result<(), DiskAnnError> {
75 let file = File::create(path)?;
76 let writer = BufWriter::new(file);
77 bincode::serialize_into(writer, self)?;
78 Ok(())
79 }
80
81 pub fn load(path: &str) -> Result<Self, DiskAnnError> {
83 let file = File::open(path)?;
84 let reader = BufReader::new(file);
85 let q: Self = bincode::deserialize_from(reader)?;
86 Ok(q)
87 }
88
89 pub fn stats(&self) -> SQStats {
91 SQStats {
92 kind: "F16".to_string(),
93 dim: self.dim,
94 code_size_bytes: self.dim * 2,
95 compression_ratio: 2.0,
96 trained: true, }
98 }
99}
100
101impl VectorQuantizer for F16Quantizer {
102 fn encode(&self, vector: &[f32]) -> Vec<u8> {
103 assert_eq!(vector.len(), self.dim, "Vector dimension mismatch");
104 let mut codes = Vec::with_capacity(self.dim * 2);
105 for &val in vector {
106 codes.extend_from_slice(&f16::from_f32(val).to_le_bytes());
107 }
108 codes
109 }
110
111 fn decode(&self, codes: &[u8]) -> Vec<f32> {
112 assert_eq!(codes.len(), self.dim * 2, "Code length mismatch");
113 let u16_slice: &[u16] = bytemuck::cast_slice(codes);
114 let mut output = vec![0.0f32; self.dim];
115 crate::simd::f16_to_f32_bulk(u16_slice, &mut output);
116 output
117 }
118
119 fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
120 assert_eq!(query.len(), self.dim, "Query dimension mismatch");
121 assert_eq!(codes.len(), self.dim * 2, "Code length mismatch");
122 let u16_slice: &[u16] = bytemuck::cast_slice(codes);
123 crate::simd::l2_f16_vs_f32(u16_slice, query)
124 }
125
126 fn compression_ratio(&self, dim: usize) -> f32 {
127 (dim * 4) as f32 / (dim * 2) as f32
128 }
129}
130
/// Per-dimension affine 8-bit scalar quantizer.
///
/// Component `i` is encoded as `round((x - offsets[i]) / scales[i])` clamped
/// to `0..=255`, and decoded as `code * scales[i] + offsets[i]`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Int8Quantizer {
    /// Number of components every encoded/decoded vector must have.
    dim: usize,
    /// Step size of one code unit, one entry per dimension.
    scales: Vec<f32>,
    /// Value that maps to code 0, one entry per dimension.
    offsets: Vec<f32>,
}
149
150impl Int8Quantizer {
151 pub fn train(vectors: &[Vec<f32>]) -> Result<Self, DiskAnnError> {
155 if vectors.is_empty() {
156 return Err(DiskAnnError::IndexError("No vectors to train on".into()));
157 }
158
159 let dim = vectors[0].len();
160 let mut mins = vec![f32::MAX; dim];
161 let mut maxs = vec![f32::MIN; dim];
162
163 for v in vectors {
164 if v.len() != dim {
165 return Err(DiskAnnError::IndexError(format!(
166 "Dimension mismatch: expected {}, got {}", dim, v.len()
167 )));
168 }
169 for (i, &val) in v.iter().enumerate() {
170 if val < mins[i] { mins[i] = val; }
171 if val > maxs[i] { maxs[i] = val; }
172 }
173 }
174
175 let mut scales = Vec::with_capacity(dim);
176 let mut offsets = Vec::with_capacity(dim);
177
178 for i in 0..dim {
179 let range = maxs[i] - mins[i];
180 let scale = if range.abs() < f32::EPSILON { 1.0 } else { range / 255.0 };
182 scales.push(scale);
183 offsets.push(mins[i]);
184 }
185
186 Ok(Self { dim, scales, offsets })
187 }
188
189 pub fn from_params(dim: usize, scales: Vec<f32>, offsets: Vec<f32>) -> Self {
191 assert_eq!(scales.len(), dim);
192 assert_eq!(offsets.len(), dim);
193 Self { dim, scales, offsets }
194 }
195
196 pub fn dim(&self) -> usize {
198 self.dim
199 }
200
201 pub fn scales(&self) -> &[f32] {
203 &self.scales
204 }
205
206 pub fn offsets(&self) -> &[f32] {
208 &self.offsets
209 }
210
211 pub fn save(&self, path: &str) -> Result<(), DiskAnnError> {
213 let file = File::create(path)?;
214 let writer = BufWriter::new(file);
215 bincode::serialize_into(writer, self)?;
216 Ok(())
217 }
218
219 pub fn load(path: &str) -> Result<Self, DiskAnnError> {
221 let file = File::open(path)?;
222 let reader = BufReader::new(file);
223 let q: Self = bincode::deserialize_from(reader)?;
224 Ok(q)
225 }
226
227 pub fn stats(&self) -> SQStats {
229 SQStats {
230 kind: "Int8".to_string(),
231 dim: self.dim,
232 code_size_bytes: self.dim,
233 compression_ratio: 4.0,
234 trained: true,
235 }
236 }
237}
238
239impl VectorQuantizer for Int8Quantizer {
240 fn encode(&self, vector: &[f32]) -> Vec<u8> {
241 assert_eq!(vector.len(), self.dim, "Vector dimension mismatch");
242 let mut codes = Vec::with_capacity(self.dim);
243 for i in 0..self.dim {
244 let normalized = (vector[i] - self.offsets[i]) / self.scales[i];
245 let clamped = normalized.clamp(0.0, 255.0);
246 codes.push(clamped.round() as u8);
247 }
248 codes
249 }
250
251 fn decode(&self, codes: &[u8]) -> Vec<f32> {
252 assert_eq!(codes.len(), self.dim, "Code length mismatch");
253 let mut output = Vec::with_capacity(self.dim);
254 for i in 0..self.dim {
255 output.push(codes[i] as f32 * self.scales[i] + self.offsets[i]);
256 }
257 output
258 }
259
260 fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
261 assert_eq!(query.len(), self.dim, "Query dimension mismatch");
262 assert_eq!(codes.len(), self.dim, "Code length mismatch");
263 crate::simd::l2_u8_scaled_vs_f32(codes, query, &self.scales, &self.offsets)
264 }
265
266 fn compression_ratio(&self, dim: usize) -> f32 {
267 (dim * 4) as f32 / dim as f32
268 }
269}
270
271impl VectorQuantizer for crate::pq::ProductQuantizer {
276 fn encode(&self, vector: &[f32]) -> Vec<u8> {
277 self.encode(vector)
278 }
279
280 fn decode(&self, codes: &[u8]) -> Vec<f32> {
281 self.decode(codes)
282 }
283
284 fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
285 self.asymmetric_distance(query, codes)
286 }
287
288 fn compression_ratio(&self, _dim: usize) -> f32 {
289 self.stats().compression_ratio
290 }
291}
292
/// Summary of a scalar quantizer's configuration, as returned by the
/// `stats()` methods of the quantizers in this file.
#[derive(Debug, Clone)]
pub struct SQStats {
    /// Quantizer name (the quantizers in this file report `"F16"` or `"Int8"`).
    pub kind: String,
    /// Vector dimensionality.
    pub dim: usize,
    /// Size of one encoded vector in bytes.
    pub code_size_bytes: usize,
    /// Ratio of raw `f32` storage to encoded storage.
    pub compression_ratio: f32,
    /// Whether the quantizer is ready for use; both quantizers in this file
    /// always report `true`.
    pub trained: bool,
}
306
307impl std::fmt::Display for SQStats {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 writeln!(f, "{} Quantizer Stats:", self.kind)?;
310 writeln!(f, " Dimension: {}", self.dim)?;
311 writeln!(f, " Code size: {} bytes", self.code_size_bytes)?;
312 writeln!(f, " Compression ratio: {:.1}x", self.compression_ratio)?;
313 writeln!(f, " Trained: {}", self.trained)
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` reproducible vectors of length `dim`; every component is
    /// `rng.gen::<f32>() * 10.0 - 5.0` drawn from a generator seeded with `seed`.
    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        use rand::prelude::*;
        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.r#gen::<f32>() * 10.0 - 5.0).collect())
            .collect()
    }

    /// Squared Euclidean distance, used as the reference for distance tests.
    fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
    }

    /// Largest absolute component-wise difference between two vectors.
    fn max_abs_error(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max)
    }

    /// Per-process scratch file in the system temp directory, so tests do not
    /// litter the working directory or collide with a concurrent test run.
    fn scratch_file(stem: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{}.bin", stem, std::process::id()))
    }

    // ---- F16 ----

    #[test]
    fn test_f16_encode_decode_round_trip() {
        let q = F16Quantizer::new(4);
        // 3.3 is not exactly representable in f16, so rounding is exercised.
        let vec = vec![1.0f32, -2.5, 0.0, 3.3];
        let codes = q.encode(&vec);
        assert_eq!(codes.len(), 8); // 4 components * 2 bytes
        let decoded = q.decode(&codes);
        assert_eq!(decoded.len(), 4);
        for (orig, dec) in vec.iter().zip(&decoded) {
            assert!((orig - dec).abs() < 0.01, "orig={orig}, dec={dec}");
        }
    }

    #[test]
    fn test_f16_asymmetric_distance() {
        let q = F16Quantizer::new(4);
        let query = vec![1.0f32, 2.0, 3.0, 4.0];
        let target = vec![5.0f32, 6.0, 7.0, 8.0];
        let codes = q.encode(&target);

        let dist = q.asymmetric_distance(&query, &codes);
        let expected = squared_l2(&query, &q.decode(&codes));

        assert!((dist - expected).abs() < 0.1, "dist={dist}, expected={expected}");
    }

    #[test]
    fn test_f16_large_vectors() {
        let q = F16Quantizer::new(128);
        for v in &random_vectors(100, 128, 42) {
            let decoded = q.decode(&q.encode(v));
            let max_err = max_abs_error(v, &decoded);
            assert!(max_err < 0.05, "Max f16 error too high: {max_err}");
        }
    }

    #[test]
    fn test_f16_save_load() {
        let path = scratch_file("test_f16q");
        let path_str = path.to_str().expect("temp path is valid UTF-8");
        let q = F16Quantizer::new(64);
        q.save(path_str).unwrap();
        let loaded = F16Quantizer::load(path_str);
        // Clean up before asserting so a failure does not leak the file.
        std::fs::remove_file(&path).ok();
        assert_eq!(q.dim(), loaded.unwrap().dim());
    }

    #[test]
    fn test_f16_compression_ratio() {
        let q = F16Quantizer::new(128);
        assert!((q.compression_ratio(128) - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_f16_stats() {
        let q = F16Quantizer::new(128);
        let stats = q.stats();
        assert_eq!(stats.dim, 128);
        assert_eq!(stats.code_size_bytes, 256);
        assert!((stats.compression_ratio - 2.0).abs() < 0.01);
    }

    // ---- Int8 ----

    #[test]
    fn test_int8_train_encode_decode() {
        let vectors = random_vectors(500, 32, 42);
        let q = Int8Quantizer::train(&vectors).unwrap();

        let original = &vectors[0];
        let codes = q.encode(original);
        assert_eq!(codes.len(), 32);

        let decoded = q.decode(&codes);
        assert_eq!(decoded.len(), 32);

        let max_err = max_abs_error(original, &decoded);
        assert!(max_err < 0.1, "Max int8 error too high: {max_err}");
    }

    #[test]
    fn test_int8_asymmetric_distance() {
        let vectors = random_vectors(500, 32, 123);
        let q = Int8Quantizer::train(&vectors).unwrap();

        let query = &vectors[0];
        let target = &vectors[100];
        let codes = q.encode(target);

        let asym_dist = q.asymmetric_distance(query, &codes);
        let expected = squared_l2(query, &q.decode(&codes));

        assert!((asym_dist - expected).abs() < 0.1, "asym={asym_dist}, expected={expected}");
    }

    #[test]
    fn test_int8_constant_dimension() {
        // Dimensions 1 and 2 never vary, which exercises the zero-range path.
        let vectors = vec![
            vec![1.0, 5.0, 5.0],
            vec![2.0, 5.0, 5.0],
            vec![3.0, 5.0, 5.0],
        ];
        let q = Int8Quantizer::train(&vectors).unwrap();
        let decoded = q.decode(&q.encode(&vectors[0]));
        assert!((decoded[1] - 5.0).abs() < 0.1);
        assert!((decoded[2] - 5.0).abs() < 0.1);
    }

    #[test]
    fn test_int8_save_load() {
        let path = scratch_file("test_int8q");
        let path_str = path.to_str().expect("temp path is valid UTF-8");
        let vectors = random_vectors(200, 16, 42);
        let q = Int8Quantizer::train(&vectors).unwrap();

        let codes_before = q.encode(&vectors[0]);
        q.save(path_str).unwrap();

        let loaded = Int8Quantizer::load(path_str);
        // Clean up before asserting so a failure does not leak the file.
        std::fs::remove_file(&path).ok();
        let codes_after = loaded.unwrap().encode(&vectors[0]);

        assert_eq!(codes_before, codes_after);
    }

    #[test]
    fn test_int8_compression_ratio() {
        let vectors = random_vectors(100, 128, 42);
        let q = Int8Quantizer::train(&vectors).unwrap();
        assert!((q.compression_ratio(128) - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_int8_stats() {
        let vectors = random_vectors(100, 64, 42);
        let q = Int8Quantizer::train(&vectors).unwrap();
        let stats = q.stats();
        assert_eq!(stats.dim, 64);
        assert_eq!(stats.code_size_bytes, 64);
        assert!((stats.compression_ratio - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_int8_preserves_ordering() {
        let vectors = random_vectors(200, 32, 456);
        let q = Int8Quantizer::train(&vectors).unwrap();

        let query = &vectors[0];

        // Exact neighbours of the query (the query itself is skipped).
        let mut true_dists: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, v)| (i, squared_l2(query, v)))
            .collect();
        true_dists.sort_by(|a, b| a.1.total_cmp(&b.1));

        // Neighbours ranked by the quantized asymmetric distance.
        let codes: Vec<Vec<u8>> = vectors.iter().map(|v| q.encode(v)).collect();
        let mut quant_dists: Vec<(usize, f32)> = codes
            .iter()
            .enumerate()
            .skip(1)
            .map(|(i, c)| (i, q.asymmetric_distance(query, c)))
            .collect();
        quant_dists.sort_by(|a, b| a.1.total_cmp(&b.1));

        let true_top10: std::collections::HashSet<_> =
            true_dists.iter().take(10).map(|(i, _)| *i).collect();
        let quant_top10: std::collections::HashSet<_> =
            quant_dists.iter().take(10).map(|(i, _)| *i).collect();
        let recall = true_top10.intersection(&quant_top10).count() as f32 / 10.0;
        assert!(recall >= 0.6, "Int8 recall@10 too low: {recall}");
    }

    // ---- Trait objects ----

    #[test]
    fn test_trait_object_dispatch() {
        let f16q: Box<dyn VectorQuantizer> = Box::new(F16Quantizer::new(4));
        let vec = vec![1.0f32, 2.0, 3.0, 4.0];
        let codes = f16q.encode(&vec);
        let decoded = f16q.decode(&codes);
        assert_eq!(decoded.len(), 4);

        let vectors = random_vectors(50, 4, 42);
        let int8q: Box<dyn VectorQuantizer> = Box::new(Int8Quantizer::train(&vectors).unwrap());
        let codes2 = int8q.encode(&vec);
        let decoded2 = int8q.decode(&codes2);
        assert_eq!(decoded2.len(), 4);
    }
}