// diskann_rs/sq.rs
//! # Scalar Quantization (F16 and Int8)
//!
//! Standalone composable quantizers following the same pattern as [`ProductQuantizer`](crate::pq::ProductQuantizer).
//!
//! ## F16Quantizer
//!
//! Lossless-ish compression: f32 -> f16 (2 bytes per dimension).
//! Uses hardware F16C / NEON conversion when available.
//!
//! ## Int8Quantizer
//!
//! Affine per-dimension quantization: f32 -> u8 (1 byte per dimension).
//! Trained on sample data to learn per-dimension min/max scales.
//!
//! ## VectorQuantizer trait
//!
//! Shared interface for PQ, F16, and Int8 quantizers:
//!
//! ```ignore
//! use diskann_rs::sq::{VectorQuantizer, F16Quantizer, Int8Quantizer};
//!
//! let f16q = F16Quantizer::new(128);
//! let codes = f16q.encode(&my_vector);
//! let decoded = f16q.decode(&codes);
//! let dist = f16q.asymmetric_distance(&query, &codes);
//! ```

use std::fs::File;
use std::io::{BufReader, BufWriter};

use half::f16;
use serde::{Deserialize, Serialize};

use crate::DiskAnnError;

/// Shared interface for vector quantizers (PQ, F16, Int8).
///
/// Implementors compress `f32` vectors into opaque byte codes and support
/// asymmetric distance (exact query vs. quantized database vector), so
/// candidates can be scored without fully decompressing them.
///
/// `Send + Sync` bounds allow a quantizer to be shared across search threads.
pub trait VectorQuantizer: Send + Sync {
    /// Encode a float vector into compressed bytes.
    fn encode(&self, vector: &[f32]) -> Vec<u8>;

    /// Decode compressed bytes back to an approximate float vector.
    fn decode(&self, codes: &[u8]) -> Vec<f32>;

    /// Compute asymmetric distance: exact query vs quantized database vector.
    fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32;

    /// Compression ratio for a given dimension (original bytes / compressed bytes).
    fn compression_ratio(&self, dim: usize) -> f32;
}

// =============================================================================
// F16 Quantizer
// =============================================================================

/// Half-precision (f16) quantizer.
///
/// Each f32 dimension is stored as f16 (2 bytes), giving 2x compression.
/// Uses SIMD-accelerated conversion and fused distance when available.
/// Stateless apart from the dimension — no training step is required.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct F16Quantizer {
    // Expected vector dimension; encode/decode assert inputs against it.
    dim: usize,
}

62impl F16Quantizer {
63    /// Create a new F16 quantizer for vectors of the given dimension.
64    pub fn new(dim: usize) -> Self {
65        Self { dim }
66    }
67
68    /// Get the vector dimension.
69    pub fn dim(&self) -> usize {
70        self.dim
71    }
72
73    /// Save to file.
74    pub fn save(&self, path: &str) -> Result<(), DiskAnnError> {
75        let file = File::create(path)?;
76        let writer = BufWriter::new(file);
77        bincode::serialize_into(writer, self)?;
78        Ok(())
79    }
80
81    /// Load from file.
82    pub fn load(path: &str) -> Result<Self, DiskAnnError> {
83        let file = File::open(path)?;
84        let reader = BufReader::new(file);
85        let q: Self = bincode::deserialize_from(reader)?;
86        Ok(q)
87    }
88
89    /// Get stats.
90    pub fn stats(&self) -> SQStats {
91        SQStats {
92            kind: "F16".to_string(),
93            dim: self.dim,
94            code_size_bytes: self.dim * 2,
95            compression_ratio: 2.0,
96            trained: true, // F16 needs no training
97        }
98    }
99}
100
101impl VectorQuantizer for F16Quantizer {
102    fn encode(&self, vector: &[f32]) -> Vec<u8> {
103        assert_eq!(vector.len(), self.dim, "Vector dimension mismatch");
104        let mut codes = Vec::with_capacity(self.dim * 2);
105        for &val in vector {
106            codes.extend_from_slice(&f16::from_f32(val).to_le_bytes());
107        }
108        codes
109    }
110
111    fn decode(&self, codes: &[u8]) -> Vec<f32> {
112        assert_eq!(codes.len(), self.dim * 2, "Code length mismatch");
113        let u16_slice: &[u16] = bytemuck::cast_slice(codes);
114        let mut output = vec![0.0f32; self.dim];
115        crate::simd::f16_to_f32_bulk(u16_slice, &mut output);
116        output
117    }
118
119    fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
120        assert_eq!(query.len(), self.dim, "Query dimension mismatch");
121        assert_eq!(codes.len(), self.dim * 2, "Code length mismatch");
122        let u16_slice: &[u16] = bytemuck::cast_slice(codes);
123        crate::simd::l2_f16_vs_f32(u16_slice, query)
124    }
125
126    fn compression_ratio(&self, dim: usize) -> f32 {
127        (dim * 4) as f32 / (dim * 2) as f32
128    }
129}
130
// =============================================================================
// Int8 Quantizer
// =============================================================================

/// Per-dimension affine Int8 quantizer.
///
/// Maps each dimension independently: `u8_val = round((f32_val - offset) / scale * 255)`
/// where `offset = min` and `scale = max - min` per dimension.
///
/// Trained from sample vectors to learn the per-dimension ranges.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Int8Quantizer {
    // Expected vector dimension; also the length of `scales` and `offsets`.
    dim: usize,
    /// Per-dimension scale: (max - min) / 255.0
    scales: Vec<f32>,
    /// Per-dimension offset: min value
    offsets: Vec<f32>,
}

150impl Int8Quantizer {
151    /// Train an Int8 quantizer from sample vectors.
152    ///
153    /// Computes per-dimension min/max to establish the affine mapping.
154    pub fn train(vectors: &[Vec<f32>]) -> Result<Self, DiskAnnError> {
155        if vectors.is_empty() {
156            return Err(DiskAnnError::IndexError("No vectors to train on".into()));
157        }
158
159        let dim = vectors[0].len();
160        let mut mins = vec![f32::MAX; dim];
161        let mut maxs = vec![f32::MIN; dim];
162
163        for v in vectors {
164            if v.len() != dim {
165                return Err(DiskAnnError::IndexError(format!(
166                    "Dimension mismatch: expected {}, got {}", dim, v.len()
167                )));
168            }
169            for (i, &val) in v.iter().enumerate() {
170                if val < mins[i] { mins[i] = val; }
171                if val > maxs[i] { maxs[i] = val; }
172            }
173        }
174
175        let mut scales = Vec::with_capacity(dim);
176        let mut offsets = Vec::with_capacity(dim);
177
178        for i in 0..dim {
179            let range = maxs[i] - mins[i];
180            // Avoid division by zero for constant dimensions
181            let scale = if range.abs() < f32::EPSILON { 1.0 } else { range / 255.0 };
182            scales.push(scale);
183            offsets.push(mins[i]);
184        }
185
186        Ok(Self { dim, scales, offsets })
187    }
188
189    /// Create from pre-computed scales and offsets.
190    pub fn from_params(dim: usize, scales: Vec<f32>, offsets: Vec<f32>) -> Self {
191        assert_eq!(scales.len(), dim);
192        assert_eq!(offsets.len(), dim);
193        Self { dim, scales, offsets }
194    }
195
196    /// Get the vector dimension.
197    pub fn dim(&self) -> usize {
198        self.dim
199    }
200
201    /// Get the per-dimension scales.
202    pub fn scales(&self) -> &[f32] {
203        &self.scales
204    }
205
206    /// Get the per-dimension offsets (min values).
207    pub fn offsets(&self) -> &[f32] {
208        &self.offsets
209    }
210
211    /// Save to file.
212    pub fn save(&self, path: &str) -> Result<(), DiskAnnError> {
213        let file = File::create(path)?;
214        let writer = BufWriter::new(file);
215        bincode::serialize_into(writer, self)?;
216        Ok(())
217    }
218
219    /// Load from file.
220    pub fn load(path: &str) -> Result<Self, DiskAnnError> {
221        let file = File::open(path)?;
222        let reader = BufReader::new(file);
223        let q: Self = bincode::deserialize_from(reader)?;
224        Ok(q)
225    }
226
227    /// Get stats.
228    pub fn stats(&self) -> SQStats {
229        SQStats {
230            kind: "Int8".to_string(),
231            dim: self.dim,
232            code_size_bytes: self.dim,
233            compression_ratio: 4.0,
234            trained: true,
235        }
236    }
237}
238
239impl VectorQuantizer for Int8Quantizer {
240    fn encode(&self, vector: &[f32]) -> Vec<u8> {
241        assert_eq!(vector.len(), self.dim, "Vector dimension mismatch");
242        let mut codes = Vec::with_capacity(self.dim);
243        for i in 0..self.dim {
244            let normalized = (vector[i] - self.offsets[i]) / self.scales[i];
245            let clamped = normalized.clamp(0.0, 255.0);
246            codes.push(clamped.round() as u8);
247        }
248        codes
249    }
250
251    fn decode(&self, codes: &[u8]) -> Vec<f32> {
252        assert_eq!(codes.len(), self.dim, "Code length mismatch");
253        let mut output = Vec::with_capacity(self.dim);
254        for i in 0..self.dim {
255            output.push(codes[i] as f32 * self.scales[i] + self.offsets[i]);
256        }
257        output
258    }
259
260    fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
261        assert_eq!(query.len(), self.dim, "Query dimension mismatch");
262        assert_eq!(codes.len(), self.dim, "Code length mismatch");
263        crate::simd::l2_u8_scaled_vs_f32(codes, query, &self.scales, &self.offsets)
264    }
265
266    fn compression_ratio(&self, dim: usize) -> f32 {
267        (dim * 4) as f32 / dim as f32
268    }
269}
270
// =============================================================================
// VectorQuantizer impl for ProductQuantizer
// =============================================================================

// Adapter so a trained `ProductQuantizer` can be used anywhere a
// `VectorQuantizer` trait object (or bound) is expected. Each trait method
// forwards to the same-named inherent method: inherent methods take
// precedence during method resolution, so these calls are not
// self-recursive. NOTE(review): relies on `crate::pq::ProductQuantizer`
// exposing inherent `encode`/`decode`/`asymmetric_distance`/`stats`.
impl VectorQuantizer for crate::pq::ProductQuantizer {
    /// Delegates to the inherent `ProductQuantizer::encode`.
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        self.encode(vector)
    }

    /// Delegates to the inherent `ProductQuantizer::decode`.
    fn decode(&self, codes: &[u8]) -> Vec<f32> {
        self.decode(codes)
    }

    /// Delegates to the inherent `ProductQuantizer::asymmetric_distance`.
    fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        self.asymmetric_distance(query, codes)
    }

    /// `_dim` is ignored: the PQ already knows its configured geometry and
    /// reports the ratio through its own stats.
    fn compression_ratio(&self, _dim: usize) -> f32 {
        self.stats().compression_ratio
    }
}

// =============================================================================
// Stats
// =============================================================================

/// Statistics for a scalar quantizer.
///
/// Returned by the quantizers' `stats()` methods; human-readable via the
/// `Display` impl below.
#[derive(Debug, Clone)]
pub struct SQStats {
    /// Quantizer kind label, e.g. "F16" or "Int8".
    pub kind: String,
    /// Vector dimension the quantizer was built for.
    pub dim: usize,
    /// Size of one encoded vector in bytes.
    pub code_size_bytes: usize,
    /// Original bytes / compressed bytes (2.0 for F16, 4.0 for Int8).
    pub compression_ratio: f32,
    /// Whether the quantizer has been trained (always true for F16).
    pub trained: bool,
}

307impl std::fmt::Display for SQStats {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        writeln!(f, "{} Quantizer Stats:", self.kind)?;
310        writeln!(f, "  Dimension: {}", self.dim)?;
311        writeln!(f, "  Code size: {} bytes", self.code_size_bytes)?;
312        writeln!(f, "  Compression ratio: {:.1}x", self.compression_ratio)?;
313        writeln!(f, "  Trained: {}", self.trained)
314    }
315}
316
#[cfg(test)]
mod tests {
    //! Unit tests for the F16/Int8 quantizers and the `VectorQuantizer` trait.
    use super::*;

    /// Generate `n` seeded pseudo-random vectors of length `dim` with
    /// components in [-5, 5), reproducible across runs.
    fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        use rand::prelude::*;
        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.r#gen::<f32>() * 10.0 - 5.0).collect())
            .collect()
    }

    // ---- F16 tests ----

    #[test]
    fn test_f16_encode_decode_round_trip() {
        let q = F16Quantizer::new(4);
        let vec = vec![1.0f32, -2.5, 0.0, 3.14];
        let codes = q.encode(&vec);
        assert_eq!(codes.len(), 8); // 4 dims * 2 bytes
        let decoded = q.decode(&codes);
        assert_eq!(decoded.len(), 4);
        for (orig, dec) in vec.iter().zip(&decoded) {
            assert!((orig - dec).abs() < 0.01, "orig={orig}, dec={dec}");
        }
    }

    #[test]
    fn test_f16_asymmetric_distance() {
        let q = F16Quantizer::new(4);
        let query = vec![1.0f32, 2.0, 3.0, 4.0];
        let target = vec![5.0f32, 6.0, 7.0, 8.0];
        let codes = q.encode(&target);

        // The fused kernel should agree with decode-then-L2.
        let dist = q.asymmetric_distance(&query, &codes);
        let decoded = q.decode(&codes);
        let expected: f32 = query.iter().zip(&decoded).map(|(a, b)| (a - b) * (a - b)).sum();

        assert!((dist - expected).abs() < 0.1, "dist={dist}, expected={expected}");
    }

    #[test]
    fn test_f16_large_vectors() {
        let q = F16Quantizer::new(128);
        let vectors = random_vectors(100, 128, 42);
        for v in &vectors {
            let codes = q.encode(v);
            let decoded = q.decode(&codes);
            let max_err: f32 = v.iter().zip(&decoded).map(|(a, b)| (a - b).abs()).fold(0.0, f32::max);
            assert!(max_err < 0.05, "Max f16 error too high: {max_err}");
        }
    }

    #[test]
    fn test_f16_save_load() {
        let path = "test_f16q.bin";
        let q = F16Quantizer::new(64);
        q.save(path).unwrap();
        let loaded = F16Quantizer::load(path).unwrap();
        assert_eq!(q.dim(), loaded.dim());
        std::fs::remove_file(path).ok();
    }

    #[test]
    fn test_f16_compression_ratio() {
        let q = F16Quantizer::new(128);
        assert!((q.compression_ratio(128) - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_f16_stats() {
        let q = F16Quantizer::new(128);
        let stats = q.stats();
        assert_eq!(stats.dim, 128);
        assert_eq!(stats.code_size_bytes, 256);
        assert!((stats.compression_ratio - 2.0).abs() < 0.01);
    }

    // ---- Int8 tests ----

    #[test]
    fn test_int8_train_encode_decode() {
        let vectors = random_vectors(500, 32, 42);
        let q = Int8Quantizer::train(&vectors).unwrap();

        let original = &vectors[0];
        let codes = q.encode(original);
        assert_eq!(codes.len(), 32);

        let decoded = q.decode(&codes);
        assert_eq!(decoded.len(), 32);

        // Reconstruction error should be small relative to range
        let max_err: f32 = original.iter().zip(&decoded).map(|(a, b)| (a - b).abs()).fold(0.0, f32::max);
        // Each dimension has range ~10 (from -5 to 5), quantized to 256 levels
        // So max error should be ~10/255 ≈ 0.04
        assert!(max_err < 0.1, "Max int8 error too high: {max_err}");
    }

    #[test]
    fn test_int8_asymmetric_distance() {
        let vectors = random_vectors(500, 32, 123);
        let q = Int8Quantizer::train(&vectors).unwrap();

        let query = &vectors[0];
        let target = &vectors[100];
        let codes = q.encode(target);

        let asym_dist = q.asymmetric_distance(query, &codes);
        let decoded = q.decode(&codes);
        let expected: f32 = query.iter().zip(&decoded).map(|(a, b)| (a - b) * (a - b)).sum();

        // Should be very close since both use same dequantization
        assert!((asym_dist - expected).abs() < 0.1, "asym={asym_dist}, expected={expected}");
    }

    #[test]
    fn test_int8_constant_dimension() {
        // Edge case: a dimension with all same values
        let vectors = vec![
            vec![1.0, 5.0, 5.0],
            vec![2.0, 5.0, 5.0],
            vec![3.0, 5.0, 5.0],
        ];
        let q = Int8Quantizer::train(&vectors).unwrap();
        let codes = q.encode(&vectors[0]);
        let decoded = q.decode(&codes);
        // Constant dim should decode back accurately
        assert!((decoded[1] - 5.0).abs() < 0.1);
        assert!((decoded[2] - 5.0).abs() < 0.1);
    }

    #[test]
    fn test_int8_save_load() {
        let path = "test_int8q.bin";
        let vectors = random_vectors(200, 16, 42);
        let q = Int8Quantizer::train(&vectors).unwrap();

        let codes_before = q.encode(&vectors[0]);
        q.save(path).unwrap();

        let loaded = Int8Quantizer::load(path).unwrap();
        let codes_after = loaded.encode(&vectors[0]);

        // A round-tripped quantizer must produce identical codes.
        assert_eq!(codes_before, codes_after);
        std::fs::remove_file(path).ok();
    }

    #[test]
    fn test_int8_compression_ratio() {
        let vectors = random_vectors(100, 128, 42);
        let q = Int8Quantizer::train(&vectors).unwrap();
        assert!((q.compression_ratio(128) - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_int8_stats() {
        let vectors = random_vectors(100, 64, 42);
        let q = Int8Quantizer::train(&vectors).unwrap();
        let stats = q.stats();
        assert_eq!(stats.dim, 64);
        assert_eq!(stats.code_size_bytes, 64);
        assert!((stats.compression_ratio - 4.0).abs() < 0.01);
    }

    /// Quantized distances should preserve nearest-neighbor ordering well
    /// enough for ANN search (recall@10 against exact L2).
    #[test]
    fn test_int8_preserves_ordering() {
        let vectors = random_vectors(200, 32, 456);
        let q = Int8Quantizer::train(&vectors).unwrap();

        let query = &vectors[0];

        // True distances
        let mut true_dists: Vec<(usize, f32)> = vectors.iter()
            .enumerate()
            .skip(1)
            .map(|(i, v)| {
                let d: f32 = query.iter().zip(v).map(|(a, b)| (a - b) * (a - b)).sum();
                (i, d)
            })
            .collect();
        true_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Quantized distances
        let codes: Vec<Vec<u8>> = vectors.iter().map(|v| q.encode(v)).collect();
        let mut quant_dists: Vec<(usize, f32)> = codes.iter()
            .enumerate()
            .skip(1)
            .map(|(i, c)| (i, q.asymmetric_distance(query, c)))
            .collect();
        quant_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Check recall@10
        let true_top10: std::collections::HashSet<_> = true_dists.iter().take(10).map(|(i, _)| *i).collect();
        let quant_top10: std::collections::HashSet<_> = quant_dists.iter().take(10).map(|(i, _)| *i).collect();
        let recall = true_top10.intersection(&quant_top10).count() as f32 / 10.0;
        assert!(recall >= 0.6, "Int8 recall@10 too low: {recall}");
    }

    // ---- VectorQuantizer trait usage ----

    /// Both quantizers must be usable behind a `dyn VectorQuantizer`.
    #[test]
    fn test_trait_object_dispatch() {
        let f16q: Box<dyn VectorQuantizer> = Box::new(F16Quantizer::new(4));
        let vec = vec![1.0f32, 2.0, 3.0, 4.0];
        let codes = f16q.encode(&vec);
        let decoded = f16q.decode(&codes);
        assert_eq!(decoded.len(), 4);

        let vectors = random_vectors(50, 4, 42);
        let int8q: Box<dyn VectorQuantizer> = Box::new(Int8Quantizer::train(&vectors).unwrap());
        let codes2 = int8q.encode(&vec);
        let decoded2 = int8q.decode(&codes2);
        assert_eq!(decoded2.len(), 4);
    }
}