//! oxirs_vec/quantizer.rs — product quantization: compresses high-dimensional
//! `f32` vectors into one `u8` code per sub-space via per-sub-space k-means
//! codebooks.

/// A single sub-space codebook with up to 256 centroids of size `sub_dimension`.
///
/// Codes index into `centroids`, so the centroid count must stay ≤ 256
/// (codes are `u8` — see `QuantizerConfig::n_clusters`).
#[derive(Debug, Clone)]
pub struct Codebook {
    /// Centroids stored as row-major vectors, each of length `sub_dimension`.
    pub centroids: Vec<Vec<f32>>,
    /// The number of coordinates per centroid.
    pub sub_dimension: usize,
}
9
10impl Codebook {
11    /// Create an empty codebook for a given sub-dimension.
12    pub fn new(sub_dimension: usize) -> Self {
13        Self {
14            centroids: Vec::new(),
15            sub_dimension,
16        }
17    }
18
19    /// Return the index (code) of the centroid nearest to `sub_vector`.
20    /// Uses squared Euclidean distance.
21    pub fn nearest_centroid(&self, sub_vector: &[f32]) -> u8 {
22        let mut best_idx = 0u8;
23        let mut best_dist = f32::INFINITY;
24        for (idx, centroid) in self.centroids.iter().enumerate() {
25            let dist: f32 = sub_vector
26                .iter()
27                .zip(centroid.iter())
28                .map(|(a, b)| (a - b) * (a - b))
29                .sum();
30            if dist < best_dist {
31                best_dist = dist;
32                best_idx = idx as u8;
33            }
34        }
35        best_idx
36    }
37
38    /// Return the centroid for a code, or `None` if the code is out of range.
39    pub fn centroid(&self, code: u8) -> Option<&Vec<f32>> {
40        self.centroids.get(code as usize)
41    }
42}
43
/// A quantized representation of a vector: one byte code per sub-space
/// plus the original dimensionality needed for reconstruction.
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// One byte code per sub-space (length equals the quantizer's `n_subspaces`).
    pub codes: Vec<u8>,
    /// The original vector dimensionality.
    pub original_dim: usize,
}
52
/// An approximately reconstructed vector with the quantization error.
#[derive(Debug, Clone)]
pub struct ReconstructedVector {
    /// The reconstructed approximation (concatenated codebook centroids).
    pub vector: Vec<f32>,
    /// Mean squared reconstruction error.
    /// NOTE: `Quantizer::decode` always reports 0.0 here, since the original
    /// vector is unavailable at decode time.
    pub quantization_error: f32,
}
61
/// Configuration for the product quantizer.
#[derive(Debug, Clone, Copy)]
pub struct QuantizerConfig {
    /// Number of sub-spaces (must divide the vector dimension evenly).
    pub n_subspaces: usize,
    /// Number of clusters per codebook (≤ 256 because codes are `u8`;
    /// `Quantizer::train` rejects larger values).
    pub n_clusters: usize,
}
70
71impl Default for QuantizerConfig {
72    fn default() -> Self {
73        Self {
74            n_subspaces: 4,
75            n_clusters: 256,
76        }
77    }
78}
79
/// Errors from quantizer operations.
#[derive(Debug)]
pub enum QuantizerError {
    /// Encoding or decoding was attempted before training succeeded.
    NotTrained,
    /// The vector dimension is incompatible with the codebooks.
    DimensionMismatch,
    /// Not enough training vectors (need at least `n_clusters` per sub-space);
    /// carries the number of vectors that were actually supplied.
    InsufficientData(usize),
    /// The configuration is invalid (e.g. n_clusters > 256); carries a
    /// human-readable reason.
    InvalidConfig(String),
}
92
93impl std::fmt::Display for QuantizerError {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self {
96            Self::NotTrained => write!(f, "Quantizer is not trained"),
97            Self::DimensionMismatch => write!(f, "Vector dimension mismatch"),
98            Self::InsufficientData(n) => {
99                write!(f, "Insufficient training data: {n} vectors")
100            }
101            Self::InvalidConfig(msg) => write!(f, "Invalid config: {msg}"),
102        }
103    }
104}
105
// Marker impl so `QuantizerError` composes with `Box<dyn Error>` and `?`.
impl std::error::Error for QuantizerError {}
107
/// Product Quantizer that compresses high-dimensional vectors.
///
/// Holds one trained `Codebook` per sub-space; `codebooks` is empty until
/// `train` succeeds.
#[derive(Debug)]
pub struct Quantizer {
    config: QuantizerConfig,
    codebooks: Vec<Codebook>,
}
114
115impl Quantizer {
116    /// Create a new, un-trained quantizer with the given configuration.
117    pub fn new(config: QuantizerConfig) -> Self {
118        Self {
119            config,
120            codebooks: Vec::new(),
121        }
122    }
123
124    /// Train the quantizer using k-means on the provided data vectors.
125    ///
126    /// All vectors must have the same dimensionality, which must be divisible by `n_subspaces`.
127    pub fn train(&mut self, data: &[Vec<f32>]) -> Result<(), QuantizerError> {
128        // Validate configuration
129        if self.config.n_clusters > 256 {
130            return Err(QuantizerError::InvalidConfig(
131                "n_clusters must be ≤ 256".to_string(),
132            ));
133        }
134        if self.config.n_subspaces == 0 {
135            return Err(QuantizerError::InvalidConfig(
136                "n_subspaces must be > 0".to_string(),
137            ));
138        }
139        if data.is_empty() {
140            return Err(QuantizerError::InsufficientData(0));
141        }
142        if data.len() < self.config.n_clusters {
143            return Err(QuantizerError::InsufficientData(data.len()));
144        }
145
146        let dim = data[0].len();
147        if dim % self.config.n_subspaces != 0 {
148            return Err(QuantizerError::InvalidConfig(format!(
149                "Dimension {dim} is not divisible by n_subspaces {}",
150                self.config.n_subspaces
151            )));
152        }
153        let sub_dim = dim / self.config.n_subspaces;
154
155        // Validate all vectors have the same dimension
156        for v in data {
157            if v.len() != dim {
158                return Err(QuantizerError::DimensionMismatch);
159            }
160        }
161
162        // Train one codebook per sub-space
163        let actual_k = self.config.n_clusters.min(data.len());
164        let mut codebooks = Vec::with_capacity(self.config.n_subspaces);
165        for sub in 0..self.config.n_subspaces {
166            let start = sub * sub_dim;
167            let end = start + sub_dim;
168            // Extract sub-vectors for this sub-space
169            let sub_vecs: Vec<Vec<f32>> = data.iter().map(|v| v[start..end].to_vec()).collect();
170            let cb = kmeans_train(&sub_vecs, actual_k, sub_dim, 50)?;
171            codebooks.push(cb);
172        }
173        self.codebooks = codebooks;
174        Ok(())
175    }
176
177    /// Encode a single vector into a `QuantizedVector`.
178    pub fn encode(&self, vector: &[f32]) -> Result<QuantizedVector, QuantizerError> {
179        if self.codebooks.is_empty() {
180            return Err(QuantizerError::NotTrained);
181        }
182        let dim = vector.len();
183        let expected_dim = self.codebooks.len() * self.codebooks[0].sub_dimension;
184        if dim != expected_dim {
185            return Err(QuantizerError::DimensionMismatch);
186        }
187        let sub_dim = self.codebooks[0].sub_dimension;
188        let codes: Vec<u8> = self
189            .codebooks
190            .iter()
191            .enumerate()
192            .map(|(i, cb)| {
193                let start = i * sub_dim;
194                let end = start + sub_dim;
195                cb.nearest_centroid(&vector[start..end])
196            })
197            .collect();
198        Ok(QuantizedVector {
199            codes,
200            original_dim: dim,
201        })
202    }
203
204    /// Decode a `QuantizedVector` back to an approximate vector.
205    pub fn decode(&self, qv: &QuantizedVector) -> Result<ReconstructedVector, QuantizerError> {
206        if self.codebooks.is_empty() {
207            return Err(QuantizerError::NotTrained);
208        }
209        if qv.codes.len() != self.codebooks.len() {
210            return Err(QuantizerError::DimensionMismatch);
211        }
212        let mut result = Vec::with_capacity(qv.original_dim);
213        for (cb, &code) in self.codebooks.iter().zip(qv.codes.iter()) {
214            match cb.centroid(code) {
215                Some(c) => result.extend_from_slice(c),
216                None => return Err(QuantizerError::DimensionMismatch),
217            }
218        }
219        let error = 0.0_f32; // error not tracked at decode time without original
220        Ok(ReconstructedVector {
221            vector: result,
222            quantization_error: error,
223        })
224    }
225
226    /// Encode a batch of vectors.
227    pub fn encode_batch(
228        &self,
229        vectors: &[Vec<f32>],
230    ) -> Result<Vec<QuantizedVector>, QuantizerError> {
231        vectors.iter().map(|v| self.encode(v)).collect()
232    }
233
234    /// Return true if the quantizer has been trained.
235    pub fn is_trained(&self) -> bool {
236        !self.codebooks.is_empty()
237    }
238
239    /// Compute the compression ratio: bytes of original / bytes of compressed.
240    ///
241    /// Original: `original_dim * 4` bytes (f32).
242    /// Compressed: `n_subspaces` bytes (one u8 code per sub-space).
243    pub fn compression_ratio(&self, original_dim: usize) -> f32 {
244        let n_sub = self.config.n_subspaces;
245        if n_sub == 0 {
246            return 1.0;
247        }
248        (original_dim as f32 * 4.0) / n_sub as f32
249    }
250
251    /// Return the number of codebooks (equals `n_subspaces` after training).
252    pub fn codebook_count(&self) -> usize {
253        self.codebooks.len()
254    }
255}
256
257// ---- K-means implementation ----
258
/// Run Lloyd's k-means algorithm on `sub_vecs` for `k` clusters and `max_iters` iterations.
///
/// Returns a `Codebook` whose centroids are the final cluster means. A cluster
/// that ends an iteration with no assigned points keeps its previous centroid
/// rather than being re-seeded.
fn kmeans_train(
    sub_vecs: &[Vec<f32>],
    k: usize,
    sub_dim: usize,
    max_iters: usize,
) -> Result<Codebook, QuantizerError> {
    let n = sub_vecs.len();
    if k == 0 || n == 0 {
        return Err(QuantizerError::InvalidConfig(
            "k and n must be > 0".to_string(),
        ));
    }

    // Initialise centroids using k-means++ style seeded selection.
    let mut centroids = kmeans_init(sub_vecs, k, sub_dim);

    for _iter in 0..max_iters {
        // Assignment step: map each sub-vector to its nearest centroid index.
        let assignments: Vec<usize> = sub_vecs
            .iter()
            .map(|v| nearest_centroid_idx(&centroids, v))
            .collect();

        // Update step: accumulate per-cluster coordinate sums in f64 to limit
        // rounding error before averaging back down to f32.
        let mut sums: Vec<Vec<f64>> = vec![vec![0.0_f64; sub_dim]; k];
        let mut counts: Vec<usize> = vec![0; k];

        for (v, &a) in sub_vecs.iter().zip(assignments.iter()) {
            for (d, &x) in v.iter().enumerate() {
                sums[a][d] += x as f64;
            }
            counts[a] += 1;
        }

        // Recompute means; converged when no coordinate moves by more than 1e-6.
        let mut converged = true;
        for (ci, centroid) in centroids.iter_mut().enumerate() {
            if counts[ci] == 0 {
                // Empty cluster: leave the stale centroid in place.
                continue;
            }
            for d in 0..sub_dim {
                let new_val = (sums[ci][d] / counts[ci] as f64) as f32;
                if (new_val - centroid[d]).abs() > 1e-6 {
                    converged = false;
                }
                centroid[d] = new_val;
            }
        }
        if converged {
            break;
        }
    }

    Ok(Codebook {
        centroids,
        sub_dimension: sub_dim,
    })
}
317
/// K-means++-style farthest-first initialisation.
///
/// The first centroid index is derived deterministically from the input sizes
/// (keeps tests reproducible without an RNG); each subsequent centroid is the
/// point farthest from its nearest already-chosen centroid (D² selection with
/// greedy max instead of probabilistic sampling).
///
/// Note: if `k` exceeds the number of distinct points, indices may repeat and
/// duplicate centroids are returned.
fn kmeans_init(data: &[Vec<f32>], k: usize, sub_dim: usize) -> Vec<Vec<f32>> {
    let n = data.len();
    let mut chosen_indices: Vec<usize> = Vec::with_capacity(k);

    // Deterministic seed index; `% n` keeps it in bounds.
    let first_idx = (sub_dim * 7 + n * 3) % n;
    chosen_indices.push(first_idx);

    // distances[i] = squared distance from point i to its nearest chosen centroid.
    let mut distances: Vec<f32> = vec![f32::INFINITY; n];
    for _ in 1..k {
        // Only the most recently chosen centroid can lower a point's distance,
        // so compare against it alone. Hoisted out of the per-point loop: it is
        // invariant across all points (the original recomputed it per point).
        let last = &data[*chosen_indices.last().expect("at least one centroid chosen")];
        for (i, v) in data.iter().enumerate() {
            let dist: f32 = v
                .iter()
                .zip(last.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            if dist < distances[i] {
                distances[i] = dist;
            }
        }
        // Choose next: the point farthest from its nearest chosen centroid.
        let next_idx = distances
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        chosen_indices.push(next_idx);
    }

    // Every chosen index came from `% n` or `enumerate`, so it is in bounds;
    // the original's defensive `i % n` was redundant.
    chosen_indices.into_iter().map(|i| data[i].clone()).collect()
}
358
/// Find the index of the nearest centroid using squared Euclidean distance.
/// Returns 0 when `centroids` is empty; ties keep the earliest index
/// (matching `Iterator::min_by`'s first-minimum semantics).
fn nearest_centroid_idx(centroids: &[Vec<f32>], v: &[f32]) -> usize {
    let mut best_idx = 0usize;
    let mut best_dist = f32::INFINITY;
    for (i, c) in centroids.iter().enumerate() {
        // Single pass: each distance is computed exactly once. The previous
        // `min_by` version recomputed both operands' distances on every
        // comparison (~2x the work in the k-means hot loop).
        let dist: f32 = c.iter().zip(v.iter()).map(|(x, y)| (x - y) * (x - y)).sum();
        if dist < best_dist {
            best_dist = dist;
            best_idx = i;
        }
    }
    best_idx
}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `QuantizerConfig` in tests.
    fn make_config(n_subspaces: usize, n_clusters: usize) -> QuantizerConfig {
        QuantizerConfig {
            n_subspaces,
            n_clusters,
        }
    }

    /// Generate n vectors of dimension d with distinct cluster patterns.
    fn make_data(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| (0..dim).map(|d| (i as f32 * 0.1) + d as f32).collect())
            .collect()
    }

    // --- is_trained ---

    #[test]
    fn test_not_trained_initially() {
        let q = Quantizer::new(make_config(4, 8));
        assert!(!q.is_trained());
    }

    #[test]
    fn test_is_trained_after_train() {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data).unwrap();
        assert!(q.is_trained());
    }

    // --- train errors ---

    #[test]
    fn test_train_empty_data_error() {
        let mut q = Quantizer::new(make_config(4, 8));
        let err = q.train(&[]);
        assert!(matches!(err, Err(QuantizerError::InsufficientData(0))));
    }

    #[test]
    fn test_train_insufficient_data_error() {
        let mut q = Quantizer::new(make_config(2, 10));
        let data = make_data(5, 4); // 5 < 10 clusters
        let err = q.train(&data);
        assert!(matches!(err, Err(QuantizerError::InsufficientData(_))));
    }

    #[test]
    fn test_train_n_clusters_over_256() {
        let mut q = Quantizer::new(make_config(2, 300));
        let data = make_data(400, 8);
        let err = q.train(&data);
        assert!(matches!(err, Err(QuantizerError::InvalidConfig(_))));
    }

    #[test]
    fn test_train_dimension_not_divisible() {
        let mut q = Quantizer::new(make_config(3, 4)); // 3 does not divide 8
        let data = make_data(20, 8);
        let err = q.train(&data);
        assert!(matches!(err, Err(QuantizerError::InvalidConfig(_))));
    }

    // --- encode / decode ---

    #[test]
    fn test_encode_not_trained_error() {
        let q = Quantizer::new(make_config(4, 8));
        let v = vec![0.0f32; 8];
        assert!(matches!(q.encode(&v), Err(QuantizerError::NotTrained)));
    }

    #[test]
    fn test_encode_dimension_mismatch() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data).unwrap();
        let v = vec![0.0f32; 4]; // wrong dim
        assert!(matches!(
            q.encode(&v),
            Err(QuantizerError::DimensionMismatch)
        ));
    }

    #[test]
    fn test_encode_codes_length() {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data).unwrap();
        let v = vec![1.0f32; 8];
        let qv = q.encode(&v).unwrap();
        assert_eq!(qv.codes.len(), 4); // one code per sub-space
    }

    #[test]
    fn test_encode_original_dim_stored() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data).unwrap();
        let v = vec![0.0f32; 8];
        let qv = q.encode(&v).unwrap();
        assert_eq!(qv.original_dim, 8);
    }

    #[test]
    fn test_decode_produces_correct_dim() {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data).unwrap();
        let v = vec![0.5f32; 8];
        let qv = q.encode(&v).unwrap();
        let rv = q.decode(&qv).unwrap();
        assert_eq!(rv.vector.len(), 8);
    }

    #[test]
    fn test_decode_not_trained_error() {
        let q = Quantizer::new(make_config(4, 8));
        let qv = QuantizedVector {
            codes: vec![0; 4],
            original_dim: 8,
        };
        assert!(matches!(q.decode(&qv), Err(QuantizerError::NotTrained)));
    }

    #[test]
    fn test_encode_decode_approximates_original() {
        // A simple test: training on clustered data should reconstruct well
        let mut q = Quantizer::new(make_config(2, 4));
        // 4 clusters in 2 sub-spaces, each sub-space has 4 coords
        let mut data: Vec<Vec<f32>> = Vec::new();
        for c in 0..4 {
            for _ in 0..8 {
                let v: Vec<f32> = (0..8).map(|d| (c as f32 * 10.0) + d as f32 * 0.1).collect();
                data.push(v);
            }
        }
        q.train(&data).unwrap();
        let test = data[0].clone();
        let qv = q.encode(&test).unwrap();
        let rv = q.decode(&qv).unwrap();
        // Reconstruction should be within 5.0 of original for each dim
        for (&orig, &rec) in test.iter().zip(rv.vector.iter()) {
            assert!((orig - rec).abs() < 5.0, "orig={orig}, rec={rec}");
        }
    }

    // --- encode_batch ---

    #[test]
    fn test_encode_batch_empty() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data).unwrap();
        let result = q.encode_batch(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_encode_batch_multiple() {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data).unwrap();
        let batch = data.clone();
        let result = q.encode_batch(&batch).unwrap();
        assert_eq!(result.len(), data.len());
    }

    // --- compression_ratio ---

    #[test]
    fn test_compression_ratio_basic() {
        let q = Quantizer::new(make_config(4, 8));
        // original: 128 * 4 = 512 bytes; compressed: 4 bytes
        let ratio = q.compression_ratio(128);
        assert!((ratio - 128.0).abs() < 0.001);
    }

    #[test]
    fn test_compression_ratio_formula() {
        let q = Quantizer::new(make_config(8, 256));
        // 64-dim vector: 64*4=256 bytes original, 8 bytes compressed = ratio 32
        let ratio = q.compression_ratio(64);
        assert!((ratio - 32.0).abs() < 0.001);
    }

    // --- codebook_count ---

    #[test]
    fn test_codebook_count_before_training() {
        let q = Quantizer::new(make_config(4, 8));
        assert_eq!(q.codebook_count(), 0);
    }

    #[test]
    fn test_codebook_count_after_training_matches_n_subspaces() {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data).unwrap();
        assert_eq!(q.codebook_count(), 4);
    }

    // --- Codebook ---

    #[test]
    fn test_codebook_new() {
        let cb = Codebook::new(4);
        assert_eq!(cb.sub_dimension, 4);
        assert!(cb.centroids.is_empty());
    }

    #[test]
    fn test_nearest_centroid_single() {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        let code = cb.nearest_centroid(&[1.0, 1.0]);
        assert_eq!(code, 0); // closer to (0,0)
    }

    #[test]
    fn test_nearest_centroid_second() {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        let code = cb.nearest_centroid(&[9.0, 9.0]);
        assert_eq!(code, 1); // closer to (10,10)
    }

    #[test]
    fn test_centroid_valid_code() {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![1.0, 2.0]];
        let c = cb.centroid(0).unwrap();
        assert_eq!(c[0], 1.0);
    }

    #[test]
    fn test_centroid_out_of_range() {
        let cb = Codebook::new(2);
        assert!(cb.centroid(5).is_none());
    }

    // --- error display ---

    #[test]
    fn test_not_trained_display() {
        let e = QuantizerError::NotTrained;
        assert!(format!("{e}").contains("trained"));
    }

    #[test]
    fn test_dimension_mismatch_display() {
        let e = QuantizerError::DimensionMismatch;
        assert!(format!("{e}").contains("mismatch"));
    }

    #[test]
    fn test_insufficient_data_display() {
        let e = QuantizerError::InsufficientData(3);
        assert!(format!("{e}").contains("3"));
    }

    #[test]
    fn test_invalid_config_display() {
        let e = QuantizerError::InvalidConfig("bad".to_string());
        assert!(format!("{e}").contains("bad"));
    }
}
643}