Skip to main content

engine/
sq.rs

//! Scalar Quantization (SQ) for memory-efficient vector storage
//!
//! Scalar Quantization reduces memory usage by mapping each `f32` dimension
//! onto a small unsigned integer code, using a per-dimension `[min, max]`
//! range learned during training. The default 8-bit mode (SQ8) stores one
//! byte per dimension, a ~4x memory reduction with minimal accuracy loss for
//! most use cases.
//!
//! # Supported Quantization Types
//! - SQ4: 4-bit quantization (8x compression, lower accuracy)
//! - SQ8: 8-bit quantization (4x compression, good accuracy)
//! - SQ16: 16-bit quantization (2x compression, high accuracy)
11
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15use common::DistanceMetric;
16
17/// Quantization bit depth
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
19pub enum QuantizationType {
20    /// 4-bit quantization (16 levels per dimension)
21    SQ4,
22    /// 8-bit quantization (256 levels per dimension)
23    #[default]
24    SQ8,
25    /// 16-bit quantization (65536 levels per dimension)
26    SQ16,
27}
28
29/// Configuration for Scalar Quantization
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct SQConfig {
32    /// Quantization type (SQ4, SQ8, SQ16)
33    pub quantization_type: QuantizationType,
34    /// Number of dimensions in vectors
35    pub dimensions: usize,
36    /// Distance metric to use
37    pub metric: DistanceMetric,
38    /// Whether to store original vectors for rescoring
39    pub store_originals: bool,
40}
41
42impl Default for SQConfig {
43    fn default() -> Self {
44        Self {
45            quantization_type: QuantizationType::SQ8,
46            dimensions: 0,
47            metric: DistanceMetric::Cosine,
48            store_originals: false,
49        }
50    }
51}
52
53/// Statistics for quantization
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct SQStats {
56    /// Number of vectors indexed
57    pub num_vectors: usize,
58    /// Original memory usage in bytes
59    pub original_memory_bytes: usize,
60    /// Quantized memory usage in bytes
61    pub quantized_memory_bytes: usize,
62    /// Compression ratio
63    pub compression_ratio: f32,
64    /// Quantization type used
65    pub quantization_type: QuantizationType,
66}
67
68/// Per-dimension quantization parameters
69#[derive(Debug, Clone, Serialize, Deserialize)]
70struct DimensionParams {
71    min_val: f32,
72    max_val: f32,
73    scale: f32,
74}
75
76/// Scalar Quantization Index
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct SQIndex {
79    config: SQConfig,
80    /// Per-dimension quantization parameters
81    dimension_params: Vec<DimensionParams>,
82    /// Quantized vectors stored as bytes (SQ8) or packed (SQ4)
83    quantized_vectors: Vec<Vec<u8>>,
84    /// Vector IDs
85    ids: Vec<String>,
86    /// Optional original vectors for rescoring
87    original_vectors: Option<Vec<Vec<f32>>>,
88    /// ID to index mapping
89    id_to_index: HashMap<String, usize>,
90    /// Whether the index has been trained
91    trained: bool,
92}
93
/// A single hit returned by [`SQIndex::search`].
#[derive(Clone, Debug)]
pub struct SQSearchResult {
    /// ID the vector was added under.
    pub id: String,
    /// Final similarity: exact when originals are stored, otherwise equal to
    /// `quantized_score`.
    pub score: f32,
    /// Similarity estimated from the quantized representation.
    pub quantized_score: f32,
}
101
102impl SQIndex {
103    /// Create a new SQ index with the given configuration
104    pub fn new(config: SQConfig) -> Self {
105        Self {
106            dimension_params: Vec::new(),
107            quantized_vectors: Vec::new(),
108            ids: Vec::new(),
109            original_vectors: if config.store_originals {
110                Some(Vec::new())
111            } else {
112                None
113            },
114            id_to_index: HashMap::new(),
115            trained: false,
116            config,
117        }
118    }
119
120    /// Train the quantizer on a set of vectors to determine min/max ranges
121    pub fn train(&mut self, vectors: &[Vec<f32>]) -> Result<(), String> {
122        if vectors.is_empty() {
123            return Err("Cannot train on empty vector set".to_string());
124        }
125
126        let dimensions = vectors[0].len();
127        if self.config.dimensions == 0 {
128            self.config.dimensions = dimensions;
129        } else if self.config.dimensions != dimensions {
130            return Err(format!(
131                "Dimension mismatch: expected {}, got {}",
132                self.config.dimensions, dimensions
133            ));
134        }
135
136        // Calculate min/max for each dimension
137        let mut dimension_params = Vec::with_capacity(dimensions);
138
139        for dim in 0..dimensions {
140            let mut min_val = f32::MAX;
141            let mut max_val = f32::MIN;
142
143            for vector in vectors {
144                let val = vector[dim];
145                min_val = min_val.min(val);
146                max_val = max_val.max(val);
147            }
148
149            // Add small epsilon to prevent division by zero
150            let range = (max_val - min_val).max(1e-10);
151            let scale = self.get_max_quantized_value() / range;
152
153            dimension_params.push(DimensionParams {
154                min_val,
155                max_val,
156                scale,
157            });
158        }
159
160        self.dimension_params = dimension_params;
161        self.trained = true;
162        Ok(())
163    }
164
165    /// Get the maximum quantized value based on quantization type
166    fn get_max_quantized_value(&self) -> f32 {
167        match self.config.quantization_type {
168            QuantizationType::SQ4 => 15.0,
169            QuantizationType::SQ8 => 255.0,
170            QuantizationType::SQ16 => 65535.0,
171        }
172    }
173
174    /// Quantize a single vector
175    fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
176        match self.config.quantization_type {
177            QuantizationType::SQ8 => self.quantize_sq8(vector),
178            QuantizationType::SQ4 => self.quantize_sq4(vector),
179            QuantizationType::SQ16 => self.quantize_sq16(vector),
180        }
181    }
182
183    /// SQ8: Each dimension becomes one byte
184    fn quantize_sq8(&self, vector: &[f32]) -> Vec<u8> {
185        vector
186            .iter()
187            .enumerate()
188            .map(|(i, &val)| {
189                let params = &self.dimension_params[i];
190                let normalized = (val - params.min_val) * params.scale;
191                normalized.clamp(0.0, 255.0) as u8
192            })
193            .collect()
194    }
195
196    /// SQ4: Two dimensions packed into one byte
197    fn quantize_sq4(&self, vector: &[f32]) -> Vec<u8> {
198        let mut result = Vec::with_capacity(vector.len().div_ceil(2));
199
200        for chunk in vector.chunks(2) {
201            let low = {
202                let params = &self.dimension_params[result.len() * 2];
203                let normalized = (chunk[0] - params.min_val) * params.scale;
204                (normalized.clamp(0.0, 15.0) as u8) & 0x0F
205            };
206
207            let high = if chunk.len() > 1 {
208                let params = &self.dimension_params[result.len() * 2 + 1];
209                let normalized = (chunk[1] - params.min_val) * params.scale;
210                ((normalized.clamp(0.0, 15.0) as u8) & 0x0F) << 4
211            } else {
212                0
213            };
214
215            result.push(low | high);
216        }
217
218        result
219    }
220
221    /// SQ16: Each dimension becomes two bytes (little-endian)
222    fn quantize_sq16(&self, vector: &[f32]) -> Vec<u8> {
223        let mut result = Vec::with_capacity(vector.len() * 2);
224
225        for (i, &val) in vector.iter().enumerate() {
226            let params = &self.dimension_params[i];
227            let normalized = (val - params.min_val) * params.scale;
228            let quantized = normalized.clamp(0.0, 65535.0) as u16;
229            result.extend_from_slice(&quantized.to_le_bytes());
230        }
231
232        result
233    }
234
235    /// Dequantize a vector back to floats (approximate)
236    pub fn dequantize_vector(&self, quantized: &[u8]) -> Vec<f32> {
237        match self.config.quantization_type {
238            QuantizationType::SQ8 => self.dequantize_sq8(quantized),
239            QuantizationType::SQ4 => self.dequantize_sq4(quantized),
240            QuantizationType::SQ16 => self.dequantize_sq16(quantized),
241        }
242    }
243
244    fn dequantize_sq8(&self, quantized: &[u8]) -> Vec<f32> {
245        quantized
246            .iter()
247            .enumerate()
248            .map(|(i, &val)| {
249                let params = &self.dimension_params[i];
250                params.min_val + (val as f32 / params.scale)
251            })
252            .collect()
253    }
254
255    fn dequantize_sq4(&self, quantized: &[u8]) -> Vec<f32> {
256        let mut result = Vec::with_capacity(self.config.dimensions);
257
258        for (byte_idx, &byte) in quantized.iter().enumerate() {
259            let dim_idx = byte_idx * 2;
260            if dim_idx < self.config.dimensions {
261                let low = byte & 0x0F;
262                let params = &self.dimension_params[dim_idx];
263                result.push(params.min_val + (low as f32 / params.scale));
264            }
265
266            if dim_idx + 1 < self.config.dimensions {
267                let high = (byte >> 4) & 0x0F;
268                let params = &self.dimension_params[dim_idx + 1];
269                result.push(params.min_val + (high as f32 / params.scale));
270            }
271        }
272
273        result
274    }
275
276    fn dequantize_sq16(&self, quantized: &[u8]) -> Vec<f32> {
277        quantized
278            .chunks(2)
279            .enumerate()
280            .map(|(i, bytes)| {
281                let val = u16::from_le_bytes([bytes[0], bytes[1]]);
282                let params = &self.dimension_params[i];
283                params.min_val + (val as f32 / params.scale)
284            })
285            .collect()
286    }
287
288    /// Add vectors to the index
289    pub fn add(&mut self, ids: &[String], vectors: &[Vec<f32>]) -> Result<(), String> {
290        if !self.trained {
291            // Auto-train on first batch
292            self.train(vectors)?;
293        }
294
295        for (id, vector) in ids.iter().zip(vectors.iter()) {
296            if vector.len() != self.config.dimensions {
297                return Err(format!(
298                    "Dimension mismatch for {}: expected {}, got {}",
299                    id,
300                    self.config.dimensions,
301                    vector.len()
302                ));
303            }
304
305            // Check for duplicate
306            if let Some(&existing_idx) = self.id_to_index.get(id) {
307                // Update existing vector
308                self.quantized_vectors[existing_idx] = self.quantize_vector(vector);
309                if let Some(ref mut originals) = self.original_vectors {
310                    originals[existing_idx] = vector.clone();
311                }
312            } else {
313                // Add new vector
314                let idx = self.quantized_vectors.len();
315                self.quantized_vectors.push(self.quantize_vector(vector));
316                self.ids.push(id.clone());
317                self.id_to_index.insert(id.clone(), idx);
318
319                if let Some(ref mut originals) = self.original_vectors {
320                    originals.push(vector.clone());
321                }
322            }
323        }
324
325        Ok(())
326    }
327
328    /// Search for similar vectors
329    pub fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<SQSearchResult>, String> {
330        if !self.trained {
331            return Err("Index not trained".to_string());
332        }
333
334        if query.len() != self.config.dimensions {
335            return Err(format!(
336                "Query dimension mismatch: expected {}, got {}",
337                self.config.dimensions,
338                query.len()
339            ));
340        }
341
342        // Quantize query
343        let quantized_query = self.quantize_vector(query);
344
345        // Calculate distances to all vectors
346        let mut scores: Vec<(usize, f32)> = self
347            .quantized_vectors
348            .iter()
349            .enumerate()
350            .map(|(idx, qv)| {
351                let score = self.quantized_distance(&quantized_query, qv);
352                (idx, score)
353            })
354            .collect();
355
356        // Sort by score (higher is better for similarity)
357        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
358
359        // Take top_k and optionally rescore with original vectors
360        let results: Vec<SQSearchResult> = scores
361            .into_iter()
362            .take(top_k)
363            .map(|(idx, quantized_score)| {
364                let final_score = if let Some(ref originals) = self.original_vectors {
365                    self.float_similarity(query, &originals[idx])
366                } else {
367                    quantized_score
368                };
369
370                SQSearchResult {
371                    id: self.ids[idx].clone(),
372                    score: final_score,
373                    quantized_score,
374                }
375            })
376            .collect();
377
378        Ok(results)
379    }
380
381    /// Calculate distance between two quantized vectors
382    fn quantized_distance(&self, a: &[u8], b: &[u8]) -> f32 {
383        match self.config.quantization_type {
384            QuantizationType::SQ8 => self.sq8_distance(a, b),
385            QuantizationType::SQ4 => self.sq4_distance(a, b),
386            QuantizationType::SQ16 => self.sq16_distance(a, b),
387        }
388    }
389
390    /// SQ8 distance computation (optimized for SIMD)
391    fn sq8_distance(&self, a: &[u8], b: &[u8]) -> f32 {
392        match self.config.metric {
393            DistanceMetric::Cosine | DistanceMetric::DotProduct => {
394                // Dot product on quantized values
395                let dot: i32 = a
396                    .iter()
397                    .zip(b.iter())
398                    .map(|(&x, &y)| x as i32 * y as i32)
399                    .sum();
400
401                // Normalize for cosine
402                let norm_a: i32 = a.iter().map(|&x| x as i32 * x as i32).sum();
403                let norm_b: i32 = b.iter().map(|&x| x as i32 * x as i32).sum();
404
405                let denom = ((norm_a as f32).sqrt() * (norm_b as f32).sqrt()).max(1e-10);
406                dot as f32 / denom
407            }
408            DistanceMetric::Euclidean => {
409                // Negative euclidean (so higher = more similar)
410                let dist_sq: i32 = a
411                    .iter()
412                    .zip(b.iter())
413                    .map(|(&x, &y)| {
414                        let diff = x as i32 - y as i32;
415                        diff * diff
416                    })
417                    .sum();
418                -(dist_sq as f32).sqrt()
419            }
420        }
421    }
422
423    /// SQ4 distance computation
424    fn sq4_distance(&self, a: &[u8], b: &[u8]) -> f32 {
425        // Unpack and compute
426        let a_unpacked = self.unpack_sq4(a);
427        let b_unpacked = self.unpack_sq4(b);
428        self.sq8_distance(&a_unpacked, &b_unpacked)
429    }
430
431    fn unpack_sq4(&self, packed: &[u8]) -> Vec<u8> {
432        let mut result = Vec::with_capacity(self.config.dimensions);
433        for &byte in packed {
434            result.push(byte & 0x0F);
435            if result.len() < self.config.dimensions {
436                result.push((byte >> 4) & 0x0F);
437            }
438        }
439        result
440    }
441
442    /// SQ16 distance computation
443    fn sq16_distance(&self, a: &[u8], b: &[u8]) -> f32 {
444        match self.config.metric {
445            DistanceMetric::Cosine | DistanceMetric::DotProduct => {
446                let mut dot: i64 = 0;
447                let mut norm_a: i64 = 0;
448                let mut norm_b: i64 = 0;
449
450                for i in (0..a.len()).step_by(2) {
451                    let va = u16::from_le_bytes([a[i], a[i + 1]]) as i64;
452                    let vb = u16::from_le_bytes([b[i], b[i + 1]]) as i64;
453                    dot += va * vb;
454                    norm_a += va * va;
455                    norm_b += vb * vb;
456                }
457
458                let denom = ((norm_a as f64).sqrt() * (norm_b as f64).sqrt()).max(1e-10);
459                (dot as f64 / denom) as f32
460            }
461            DistanceMetric::Euclidean => {
462                let mut dist_sq: i64 = 0;
463                for i in (0..a.len()).step_by(2) {
464                    let va = u16::from_le_bytes([a[i], a[i + 1]]) as i64;
465                    let vb = u16::from_le_bytes([b[i], b[i + 1]]) as i64;
466                    let diff = va - vb;
467                    dist_sq += diff * diff;
468                }
469                -((dist_sq as f64).sqrt() as f32)
470            }
471        }
472    }
473
474    /// Calculate similarity using original float vectors
475    fn float_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
476        match self.config.metric {
477            DistanceMetric::Cosine => {
478                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
479                let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
480                let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
481                dot / (norm_a * norm_b).max(1e-10)
482            }
483            DistanceMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
484            DistanceMetric::Euclidean => {
485                let dist_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
486                -dist_sq.sqrt()
487            }
488        }
489    }
490
491    /// Delete vectors by ID
492    pub fn delete(&mut self, ids: &[String]) -> usize {
493        let mut deleted = 0;
494
495        for id in ids {
496            if let Some(idx) = self.id_to_index.remove(id) {
497                // Mark as deleted (we'll compact later)
498                // For now, we swap-remove for efficiency
499                let last_idx = self.quantized_vectors.len() - 1;
500
501                if idx != last_idx {
502                    // Swap with last element
503                    self.quantized_vectors.swap(idx, last_idx);
504                    self.ids.swap(idx, last_idx);
505                    if let Some(ref mut originals) = self.original_vectors {
506                        originals.swap(idx, last_idx);
507                    }
508                    // Update index of swapped element
509                    self.id_to_index.insert(self.ids[idx].clone(), idx);
510                }
511
512                self.quantized_vectors.pop();
513                self.ids.pop();
514                if let Some(ref mut originals) = self.original_vectors {
515                    originals.pop();
516                }
517
518                deleted += 1;
519            }
520        }
521
522        deleted
523    }
524
525    /// Get index statistics
526    pub fn stats(&self) -> SQStats {
527        let bytes_per_quantized = match self.config.quantization_type {
528            QuantizationType::SQ4 => self.config.dimensions.div_ceil(2),
529            QuantizationType::SQ8 => self.config.dimensions,
530            QuantizationType::SQ16 => self.config.dimensions * 2,
531        };
532
533        let original_memory = self.quantized_vectors.len() * self.config.dimensions * 4;
534        let quantized_memory = self.quantized_vectors.len() * bytes_per_quantized;
535
536        SQStats {
537            num_vectors: self.quantized_vectors.len(),
538            original_memory_bytes: original_memory,
539            quantized_memory_bytes: quantized_memory,
540            compression_ratio: if quantized_memory > 0 {
541                original_memory as f32 / quantized_memory as f32
542            } else {
543                0.0
544            },
545            quantization_type: self.config.quantization_type,
546        }
547    }
548
549    /// Get number of vectors in the index
550    pub fn len(&self) -> usize {
551        self.quantized_vectors.len()
552    }
553
554    /// Check if the index is empty
555    pub fn is_empty(&self) -> bool {
556        self.quantized_vectors.is_empty()
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine-metric config with the given bit depth, width and rescoring flag.
    fn config(
        quantization_type: QuantizationType,
        dimensions: usize,
        store_originals: bool,
    ) -> SQConfig {
        SQConfig {
            quantization_type,
            dimensions,
            metric: DistanceMetric::Cosine,
            store_originals,
        }
    }

    /// IDs "v0", "v1", ... for `count` vectors.
    fn sequential_ids(count: usize) -> Vec<String> {
        (0..count).map(|i| format!("v{}", i)).collect()
    }

    /// Five 4-D fixtures: three unit axes plus two half-and-half blends.
    fn create_test_vectors() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0, 0.0],
            vec![0.0, 0.5, 0.5, 0.0],
        ]
    }

    /// A 4-D index of the given type with the standard fixtures already added.
    fn loaded_index(
        quantization_type: QuantizationType,
        store_originals: bool,
    ) -> (SQIndex, Vec<Vec<f32>>) {
        let vectors = create_test_vectors();
        let mut index = SQIndex::new(config(quantization_type, 4, store_originals));
        index.add(&sequential_ids(vectors.len()), &vectors).unwrap();
        (index, vectors)
    }

    #[test]
    fn test_sq8_basic() {
        let (index, vectors) = loaded_index(QuantizationType::SQ8, false);
        assert_eq!(index.len(), 5);

        let results = index.search(&vectors[0], 3).unwrap();
        assert_eq!(results.len(), 3);
        // A stored vector should be its own nearest neighbour.
        assert_eq!(results[0].id, "v0");
    }

    #[test]
    fn test_sq4_compression() {
        let mut index = SQIndex::new(config(QuantizationType::SQ4, 8, false));
        let vectors = vec![
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
            vec![0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
        ];
        index
            .add(&["a".to_string(), "b".to_string()], &vectors)
            .unwrap();

        // 4-bit codes versus 32-bit floats: ideally 8x.
        assert!(index.stats().compression_ratio > 6.0);
    }

    #[test]
    fn test_sq16_accuracy() {
        let (index, vectors) = loaded_index(QuantizationType::SQ16, true);

        // Rescoring with the stored originals gives a near-exact self-similarity.
        let results = index.search(&vectors[0], 2).unwrap();
        assert!(results[0].score > 0.99);
    }

    #[test]
    fn test_delete() {
        let (mut index, _) = loaded_index(QuantizationType::SQ8, false);
        assert_eq!(index.len(), 5);

        let removed = index.delete(&["v0".to_string(), "v2".to_string()]);
        assert_eq!(removed, 2);
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_dequantize_roundtrip() {
        let mut index = SQIndex::new(config(QuantizationType::SQ8, 4, false));
        let sample = vec![0.1, 0.5, 0.3, 0.9];

        index.train(std::slice::from_ref(&sample)).unwrap();
        let restored = index.dequantize_vector(&index.quantize_vector(&sample));

        // Reconstruction must stay within quantization error.
        for (expected, actual) in sample.iter().zip(&restored) {
            assert!(
                (expected - actual).abs() < 0.05,
                "Dequantized value too different"
            );
        }
    }

    #[test]
    fn test_update_existing() {
        let mut index = SQIndex::new(config(QuantizationType::SQ8, 4, false));
        let ids = vec!["v1".to_string()];

        index.add(&ids, &[vec![1.0, 0.0, 0.0, 0.0]]).unwrap();
        assert_eq!(index.len(), 1);

        // Re-adding under the same ID replaces the vector instead of appending.
        index.add(&ids, &[vec![0.0, 1.0, 0.0, 0.0]]).unwrap();
        assert_eq!(index.len(), 1);

        let results = index.search(&[0.0, 1.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(results[0].id, "v1");
    }
}