oxirs_vec/
sq.rs

1//! Scalar Quantization (SQ) for efficient vector compression
2//!
3//! This module implements Scalar Quantization, a simpler and faster alternative to
4//! Product Quantization (PQ) that quantizes each vector dimension independently.
5//!
6//! SQ is particularly useful when:
7//! - Training time needs to be minimal
8//! - Simple, predictable compression is preferred
9//! - Memory reduction is more important than extreme accuracy
10//! - Real-time index updates are required (no retraining needed)
11
12use anyhow::{anyhow, Result};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
/// Scalar quantization configuration
///
/// Supplied to `SqIndex::new`; the `Default` implementation selects 8-bit,
/// `Uniform`-mode quantization without normalization and a 10,000-vector
/// training cap.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqConfig {
    /// Number of bits per scalar (4 or 8)
    ///
    /// Determines the number of quantization levels (`2^bits - 1`) used when the
    /// per-dimension scale is derived. Each code is held in its own `u8`.
    pub bits: u8,
    /// Quantization mode
    ///
    /// Selects how the value range of each dimension is estimated during training.
    pub mode: QuantizationMode,
    /// Whether to normalize vectors before quantization
    ///
    /// NOTE(review): nothing in this file reads this flag - confirm whether
    /// normalization is applied by a caller or is still unimplemented.
    pub normalize: bool,
    /// Number of training vectors to use for range estimation
    ///
    /// Acts as an upper bound: training uses at most this many vectors, taken
    /// from the front of the slice it is given.
    pub training_samples: usize,
}
28
29impl Default for SqConfig {
30    fn default() -> Self {
31        Self {
32            bits: 8,
33            mode: QuantizationMode::Uniform,
34            normalize: false,
35            training_samples: 10_000,
36        }
37    }
38}
39
/// Quantization mode
///
/// Chooses how `SqIndex::train` estimates the value range that is mapped onto
/// the available quantization levels.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum QuantizationMode {
    /// Uniform quantization across global min/max
    ///
    /// A single min/max is taken over every component of every training vector
    /// and the resulting parameters are reused for all dimensions.
    Uniform,
    /// Per-dimension quantization with individual min/max
    PerDimension,
    /// Quantization using mean and standard deviation (more robust to outliers)
    ///
    /// Each dimension's range is `mean ± 3 * std`; values outside that range are
    /// clamped when quantized.
    MeanStd,
}
50
/// Quantization parameters for a single dimension
///
/// A value `v` is encoded as `(v - offset) * scale` (clamped and truncated to a
/// `u8`) and decoded as `code / scale + offset`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Minimum value in the dimension
    pub min: f32,
    /// Maximum value in the dimension
    pub max: f32,
    /// Scale factor for quantization
    ///
    /// Codes per unit of input value: `levels / (max - min)`, or `1.0` when the
    /// range is too small (`<= 1e-8`) to divide by safely.
    pub scale: f32,
    /// Offset for quantization
    ///
    /// Subtracted from a value before scaling; `from_range` sets it to `min`.
    pub offset: f32,
}
63
64impl QuantizationParams {
65    /// Create quantization parameters from min/max values
66    pub fn from_range(min: f32, max: f32, bits: u8) -> Self {
67        let levels = (1 << bits) - 1;
68        let range = max - min;
69        let scale = if range > 1e-8 {
70            levels as f32 / range
71        } else {
72            1.0
73        };
74
75        Self {
76            min,
77            max,
78            scale,
79            offset: min,
80        }
81    }
82
83    /// Create parameters from mean and standard deviation (3-sigma range)
84    pub fn from_mean_std(mean: f32, std: f32, bits: u8) -> Self {
85        let min = mean - 3.0 * std;
86        let max = mean + 3.0 * std;
87        Self::from_range(min, max, bits)
88    }
89
90    /// Quantize a value
91    pub fn quantize(&self, value: f32) -> u8 {
92        let normalized = (value - self.offset) * self.scale;
93        normalized.clamp(0.0, 255.0) as u8
94    }
95
96    /// Dequantize a value
97    pub fn dequantize(&self, quantized: u8) -> f32 {
98        (quantized as f32 / self.scale) + self.offset
99    }
100}
101
/// Scalar quantization index statistics
///
/// Snapshot produced by `SqIndex::stats`; all values describe the index at the
/// moment of the call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SqStats {
    /// Number of vectors in index
    pub vector_count: usize,
    /// Vector dimensionality
    pub dimensions: usize,
    /// Number of bits per scalar
    pub bits: u8,
    /// Compression ratio achieved
    ///
    /// Size of the same vectors stored as `f32` divided by `memory_bytes`;
    /// `0.0` when the index holds no vectors.
    pub compression_ratio: f32,
    /// Memory usage in bytes
    ///
    /// Covers the quantized codes only.
    pub memory_bytes: usize,
    /// Average quantization error
    ///
    /// Estimated from a bounded sample of stored vectors and expressed in code
    /// units rather than in the original value scale.
    pub avg_quantization_error: f32,
}
118
/// Scalar quantization vector index
///
/// Stores every vector as one `u8` code per dimension and answers
/// nearest-neighbour queries by scanning all stored vectors. Original `f32`
/// vectors are not retained.
pub struct SqIndex {
    /// Settings supplied at construction.
    config: SqConfig,
    /// Number of components every trained, added or queried vector must have.
    dimensions: usize,
    /// One parameter set per dimension; empty until training has run.
    quantization_params: Vec<QuantizationParams>,
    /// Quantized vectors, indexed by internal id (insertion order).
    quantized_vectors: Vec<Vec<u8>>,
    /// URI -> internal id lookup.
    uri_to_id: HashMap<String, usize>,
    /// Internal id -> URI lookup (inverse of `uri_to_id`).
    id_to_uri: Vec<String>,
}
128
129impl SqIndex {
130    /// Create a new SQ index
131    pub fn new(config: SqConfig, dimensions: usize) -> Self {
132        Self {
133            config,
134            dimensions,
135            quantization_params: Vec::new(),
136            quantized_vectors: Vec::new(),
137            uri_to_id: HashMap::new(),
138            id_to_uri: Vec::new(),
139        }
140    }
141
142    /// Train quantization parameters from training vectors
143    pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
144        if training_vectors.is_empty() {
145            return Err(anyhow!("No training vectors provided"));
146        }
147
148        let dim = training_vectors[0].len();
149        if dim != self.dimensions {
150            return Err(anyhow!(
151                "Training vector dimensions ({}) don't match index dimensions ({})",
152                dim,
153                self.dimensions
154            ));
155        }
156
157        // Limit training samples
158        let sample_count = training_vectors.len().min(self.config.training_samples);
159        let samples = &training_vectors[..sample_count];
160
161        match self.config.mode {
162            QuantizationMode::Uniform => {
163                self.train_uniform(samples)?;
164            }
165            QuantizationMode::PerDimension => {
166                self.train_per_dimension(samples)?;
167            }
168            QuantizationMode::MeanStd => {
169                self.train_mean_std(samples)?;
170            }
171        }
172
173        tracing::info!(
174            "Trained SQ index: mode={:?}, bits={}, samples={}, dimensions={}",
175            self.config.mode,
176            self.config.bits,
177            sample_count,
178            self.dimensions
179        );
180
181        Ok(())
182    }
183
184    /// Train uniform quantization (single global range)
185    fn train_uniform(&mut self, samples: &[Vec<f32>]) -> Result<()> {
186        let mut global_min = f32::INFINITY;
187        let mut global_max = f32::NEG_INFINITY;
188
189        for vector in samples {
190            for &value in vector {
191                global_min = global_min.min(value);
192                global_max = global_max.max(value);
193            }
194        }
195
196        let params = QuantizationParams::from_range(global_min, global_max, self.config.bits);
197        self.quantization_params = vec![params; self.dimensions];
198
199        Ok(())
200    }
201
202    /// Train per-dimension quantization
203    fn train_per_dimension(&mut self, samples: &[Vec<f32>]) -> Result<()> {
204        let mut dim_mins = vec![f32::INFINITY; self.dimensions];
205        let mut dim_maxs = vec![f32::NEG_INFINITY; self.dimensions];
206
207        for vector in samples {
208            for (d, &value) in vector.iter().enumerate() {
209                dim_mins[d] = dim_mins[d].min(value);
210                dim_maxs[d] = dim_maxs[d].max(value);
211            }
212        }
213
214        self.quantization_params = dim_mins
215            .into_iter()
216            .zip(dim_maxs)
217            .map(|(min, max)| QuantizationParams::from_range(min, max, self.config.bits))
218            .collect();
219
220        Ok(())
221    }
222
223    /// Train using mean and standard deviation
224    fn train_mean_std(&mut self, samples: &[Vec<f32>]) -> Result<()> {
225        let n = samples.len() as f32;
226        let mut dim_means = vec![0.0; self.dimensions];
227        let mut dim_stds = vec![0.0; self.dimensions];
228
229        // Calculate means
230        for vector in samples {
231            for (d, &value) in vector.iter().enumerate() {
232                dim_means[d] += value;
233            }
234        }
235        for mean in &mut dim_means {
236            *mean /= n;
237        }
238
239        // Calculate standard deviations
240        for vector in samples {
241            for (d, &value) in vector.iter().enumerate() {
242                let diff = value - dim_means[d];
243                dim_stds[d] += diff * diff;
244            }
245        }
246        for std in &mut dim_stds {
247            *std = (*std / n).sqrt();
248        }
249
250        self.quantization_params = dim_means
251            .into_iter()
252            .zip(dim_stds)
253            .map(|(mean, std)| QuantizationParams::from_mean_std(mean, std, self.config.bits))
254            .collect();
255
256        Ok(())
257    }
258
259    /// Add a vector to the index
260    pub fn add(&mut self, uri: String, vector: Vec<f32>) -> Result<()> {
261        if vector.len() != self.dimensions {
262            return Err(anyhow!(
263                "Vector dimensions ({}) don't match index dimensions ({})",
264                vector.len(),
265                self.dimensions
266            ));
267        }
268
269        if self.quantization_params.is_empty() {
270            return Err(anyhow!(
271                "Index not trained. Call train() before adding vectors."
272            ));
273        }
274
275        let quantized = self.quantize_vector(&vector);
276        let id = self.quantized_vectors.len();
277
278        self.uri_to_id.insert(uri.clone(), id);
279        self.id_to_uri.push(uri);
280        self.quantized_vectors.push(quantized);
281
282        Ok(())
283    }
284
285    /// Quantize a vector
286    fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
287        vector
288            .iter()
289            .zip(&self.quantization_params)
290            .map(|(&value, params)| params.quantize(value))
291            .collect()
292    }
293
294    /// Dequantize a vector
295    fn dequantize_vector(&self, quantized: &[u8]) -> Vec<f32> {
296        quantized
297            .iter()
298            .zip(&self.quantization_params)
299            .map(|(&q, params)| params.dequantize(q))
300            .collect()
301    }
302
303    /// Search for k nearest neighbors
304    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
305        if query.len() != self.dimensions {
306            return Err(anyhow!(
307                "Query dimensions ({}) don't match index dimensions ({})",
308                query.len(),
309                self.dimensions
310            ));
311        }
312
313        if self.quantized_vectors.is_empty() {
314            return Ok(Vec::new());
315        }
316
317        // Quantize query
318        let query_quantized = self.quantize_vector(query);
319
320        // Compute distances
321        let mut distances: Vec<(usize, f32)> = self
322            .quantized_vectors
323            .iter()
324            .enumerate()
325            .map(|(id, vec)| {
326                let dist = self.asymmetric_distance(&query_quantized, vec);
327                (id, dist)
328            })
329            .collect();
330
331        // Sort by distance
332        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
333
334        // Return top k with URIs
335        Ok(distances
336            .into_iter()
337            .take(k)
338            .map(|(id, dist)| (self.id_to_uri[id].clone(), dist))
339            .collect())
340    }
341
342    /// Compute asymmetric distance (query vs quantized vector)
343    /// This provides better accuracy than symmetric distance
344    fn asymmetric_distance(&self, query_quantized: &[u8], db_quantized: &[u8]) -> f32 {
345        query_quantized
346            .iter()
347            .zip(db_quantized)
348            .zip(&self.quantization_params)
349            .map(|((&q1, &q2), params)| {
350                let v1 = params.dequantize(q1);
351                let v2 = params.dequantize(q2);
352                let diff = v1 - v2;
353                diff * diff
354            })
355            .sum::<f32>()
356            .sqrt()
357    }
358
359    /// Get statistics about the index
360    pub fn stats(&self) -> SqStats {
361        let vector_count = self.quantized_vectors.len();
362        let bits_per_vector = self.dimensions * self.config.bits as usize;
363        let bytes_per_vector = (bits_per_vector + 7) / 8;
364        let memory_bytes = vector_count * bytes_per_vector;
365
366        let original_bytes = vector_count * self.dimensions * 4; // f32 = 4 bytes
367        let compression_ratio = if memory_bytes > 0 {
368            original_bytes as f32 / memory_bytes as f32
369        } else {
370            0.0
371        };
372
373        SqStats {
374            vector_count,
375            dimensions: self.dimensions,
376            bits: self.config.bits,
377            compression_ratio,
378            memory_bytes,
379            avg_quantization_error: self.estimate_quantization_error(),
380        }
381    }
382
383    /// Estimate average quantization error
384    fn estimate_quantization_error(&self) -> f32 {
385        if self.quantized_vectors.is_empty() {
386            return 0.0;
387        }
388
389        let sample_size = self.quantized_vectors.len().min(100);
390        let mut total_error = 0.0;
391
392        for quantized in self.quantized_vectors.iter().take(sample_size) {
393            let dequantized = self.dequantize_vector(quantized);
394            let reconstructed_quantized = self.quantize_vector(&dequantized);
395
396            // Error is difference between original quantized and re-quantized
397            let error: f32 = quantized
398                .iter()
399                .zip(&reconstructed_quantized)
400                .map(|(&a, &b)| (a as f32 - b as f32).abs())
401                .sum();
402
403            total_error += error / self.dimensions as f32;
404        }
405
406        total_error / sample_size as f32
407    }
408
409    /// Get vector by URI
410    pub fn get(&self, uri: &str) -> Option<Vec<f32>> {
411        self.uri_to_id
412            .get(uri)
413            .and_then(|&id| self.quantized_vectors.get(id))
414            .map(|q| self.dequantize_vector(q))
415    }
416
    /// Get number of vectors
    ///
    /// Counts the quantized vectors currently stored in the index.
    pub fn len(&self) -> usize {
        self.quantized_vectors.len()
    }
421
    /// Check if empty
    ///
    /// Returns `true` when no vectors are stored; training state is not
    /// considered.
    pub fn is_empty(&self) -> bool {
        self.quantized_vectors.is_empty()
    }
426
    /// Get configuration
    ///
    /// Borrows the configuration the index was created with.
    pub fn config(&self) -> &SqConfig {
        &self.config
    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Three 4-dimensional samples that grow by one per sample and per
    /// dimension: `[0,1,2,3]`, `[1,2,3,4]`, `[2,3,4,5]`.
    fn ramp_samples() -> Vec<Vec<f32>> {
        (0..3u8)
            .map(|row| (0..4u8).map(|col| f32::from(row + col)).collect())
            .collect()
    }

    /// Endpoints map to the extreme codes, the midpoint truncates to 127, and
    /// decoding 127 lands close to 0.5.
    #[test]
    fn test_quantization_params() {
        let params = QuantizationParams::from_range(0.0, 1.0, 8);

        for &(input, expected) in &[(0.0_f32, 0_u8), (1.0, 255), (0.5, 127)] {
            assert_eq!(params.quantize(input), expected);
        }

        let restored = params.dequantize(127);
        assert!((restored - 0.5).abs() < 0.01);
    }

    /// A new index remembers its dimensionality and starts out empty.
    #[test]
    fn test_sq_index_creation() {
        let index = SqIndex::new(SqConfig::default(), 128);

        assert_eq!(index.dimensions, 128);
        assert!(index.is_empty());
    }

    /// Per-dimension training yields one parameter set per dimension.
    #[test]
    fn test_sq_training() {
        let config = SqConfig {
            mode: QuantizationMode::PerDimension,
            bits: 8,
            ..SqConfig::default()
        };
        let mut index = SqIndex::new(config, 4);

        let outcome = index.train(&ramp_samples());

        assert!(outcome.is_ok());
        assert_eq!(index.quantization_params.len(), 4);
    }

    /// After training on constant vectors 0, 1 and 2, a zero query finds the
    /// vector closest to the origin first.
    #[test]
    fn test_sq_add_and_search() {
        let mut index = SqIndex::new(SqConfig::default(), 4);

        let training_data: Vec<Vec<f32>> = [0.0_f32, 1.0, 2.0]
            .iter()
            .map(|&level| vec![level; 4])
            .collect();
        index.train(&training_data).unwrap();

        for &(uri, level) in &[("vec1", 0.1_f32), ("vec2", 0.9), ("vec3", 1.8)] {
            index.add(uri.to_string(), vec![level; 4]).unwrap();
        }

        let results = index.search(&[0.0; 4], 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "vec1");
    }

    /// Statistics echo the index shape and report a compression ratio above 1.
    #[test]
    fn test_sq_stats() {
        let config = SqConfig {
            bits: 4,
            ..Default::default()
        };
        let mut index = SqIndex::new(config, 128);

        let training_data = vec![vec![0.5_f32; 128]; 100];
        index.train(&training_data).unwrap();

        (0..10).for_each(|i| index.add(format!("vec{}", i), vec![0.5; 128]).unwrap());

        let SqStats {
            vector_count,
            dimensions,
            bits,
            compression_ratio,
            ..
        } = index.stats();

        assert_eq!(vector_count, 10);
        assert_eq!(dimensions, 128);
        assert_eq!(bits, 4);
        assert!(compression_ratio > 1.0);
    }

    /// Training succeeds in every quantization mode on the same data.
    #[test]
    fn test_different_quantization_modes() {
        let dimensions = 4;
        let training_data = ramp_samples();

        let modes = [
            QuantizationMode::Uniform,
            QuantizationMode::PerDimension,
            QuantizationMode::MeanStd,
        ];

        for &mode in &modes {
            let config = SqConfig {
                mode,
                ..Default::default()
            };
            let mut index = SqIndex::new(config, dimensions);
            assert!(index.train(&training_data).is_ok());
        }
    }
}