Skip to main content

oxirs_vec/
quantizer.rs

/// A single sub-space codebook with `n_clusters` centroids of size `sub_dimension`.
#[derive(Debug, Clone, PartialEq)]
pub struct Codebook {
    /// One centroid per code; each centroid has `sub_dimension` coordinates.
    pub centroids: Vec<Vec<f32>>,
    /// The number of coordinates per centroid.
    pub sub_dimension: usize,
}

impl Codebook {
    /// Create an empty codebook for a given sub-dimension.
    pub fn new(sub_dimension: usize) -> Self {
        Self {
            centroids: Vec::new(),
            sub_dimension,
        }
    }

    /// Return the index (code) of the centroid nearest to `sub_vector`.
    ///
    /// Uses squared Euclidean distance; on ties the lowest index wins. Only the first
    /// 256 centroids are searched because a code is a `u8`: a larger index cannot be
    /// represented, and casting it would silently wrap to an unrelated centroid.
    /// Returns `0` when the codebook has no centroids.
    pub fn nearest_centroid(&self, sub_vector: &[f32]) -> u8 {
        let mut best_idx = 0u8;
        let mut best_dist = f32::INFINITY;
        // `0..=u8::MAX` yields exactly the 256 representable codes; `zip` stops at the
        // shorter side, so small codebooks are searched in full.
        for (code, centroid) in (0..=u8::MAX).zip(self.centroids.iter()) {
            let dist: f32 = sub_vector
                .iter()
                .zip(centroid.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            if dist < best_dist {
                best_dist = dist;
                best_idx = code;
            }
        }
        best_idx
    }

    /// Return the centroid for a code, or `None` if the code is out of range.
    pub fn centroid(&self, code: u8) -> Option<&Vec<f32>> {
        self.centroids.get(usize::from(code))
    }
}
43
/// A quantized representation of a vector.
///
/// Holds only small integer data, so it supports equality and hashing (e.g. for
/// de-duplicating identical codes).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QuantizedVector {
    /// One byte code per sub-space, in sub-space order.
    pub codes: Vec<u8>,
    /// The original vector dimensionality.
    pub original_dim: usize,
}
52
/// An approximate vector rebuilt from quantization codes, plus an error estimate.
#[derive(Clone, Debug)]
pub struct ReconstructedVector {
    /// Concatenation of the selected centroids, one per sub-space.
    pub vector: Vec<f32>,
    /// Mean squared reconstruction error. `Quantizer::decode` has no access to the
    /// original vector and always stores `0.0` here.
    pub quantization_error: f32,
}
61
/// Configuration for the product quantizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct QuantizerConfig {
    /// Number of sub-spaces (must be > 0 and divide the vector dimension evenly).
    pub n_subspaces: usize,
    /// Number of clusters per codebook (1..=256, because codes are `u8`).
    pub n_clusters: usize,
}

impl Default for QuantizerConfig {
    /// Four sub-spaces with the maximum of 256 clusters each (a full byte per code).
    fn default() -> Self {
        Self {
            n_subspaces: 4,
            n_clusters: 256,
        }
    }
}
79
/// Errors from quantizer operations.
#[derive(Debug)]
pub enum QuantizerError {
    /// Encoding or decoding was attempted before the quantizer was trained.
    NotTrained,
    /// A vector's length does not match the codebooks (or the other training vectors),
    /// or a code being decoded has no matching centroid.
    DimensionMismatch,
    /// Not enough training vectors (need at least `n_clusters`); carries the count given.
    InsufficientData(usize),
    /// The configuration is invalid (e.g. `n_clusters > 256`); carries a description.
    InvalidConfig(String),
}

impl std::fmt::Display for QuantizerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotTrained => write!(f, "Quantizer is not trained"),
            Self::DimensionMismatch => write!(f, "Vector dimension mismatch"),
            Self::InsufficientData(n) => {
                write!(f, "Insufficient training data: {n} vectors")
            }
            Self::InvalidConfig(msg) => write!(f, "Invalid config: {msg}"),
        }
    }
}

impl std::error::Error for QuantizerError {}
107
108/// Product Quantizer that compresses high-dimensional vectors.
109#[derive(Debug)]
110pub struct Quantizer {
111    config: QuantizerConfig,
112    codebooks: Vec<Codebook>,
113}
114
115impl Quantizer {
116    /// Create a new, un-trained quantizer with the given configuration.
117    pub fn new(config: QuantizerConfig) -> Self {
118        Self {
119            config,
120            codebooks: Vec::new(),
121        }
122    }
123
124    /// Train the quantizer using k-means on the provided data vectors.
125    ///
126    /// All vectors must have the same dimensionality, which must be divisible by `n_subspaces`.
127    pub fn train(&mut self, data: &[Vec<f32>]) -> Result<(), QuantizerError> {
128        // Validate configuration
129        if self.config.n_clusters > 256 {
130            return Err(QuantizerError::InvalidConfig(
131                "n_clusters must be ≤ 256".to_string(),
132            ));
133        }
134        if self.config.n_subspaces == 0 {
135            return Err(QuantizerError::InvalidConfig(
136                "n_subspaces must be > 0".to_string(),
137            ));
138        }
139        if data.is_empty() {
140            return Err(QuantizerError::InsufficientData(0));
141        }
142        if data.len() < self.config.n_clusters {
143            return Err(QuantizerError::InsufficientData(data.len()));
144        }
145
146        let dim = data[0].len();
147        if dim % self.config.n_subspaces != 0 {
148            return Err(QuantizerError::InvalidConfig(format!(
149                "Dimension {dim} is not divisible by n_subspaces {}",
150                self.config.n_subspaces
151            )));
152        }
153        let sub_dim = dim / self.config.n_subspaces;
154
155        // Validate all vectors have the same dimension
156        for v in data {
157            if v.len() != dim {
158                return Err(QuantizerError::DimensionMismatch);
159            }
160        }
161
162        // Train one codebook per sub-space
163        let actual_k = self.config.n_clusters.min(data.len());
164        let mut codebooks = Vec::with_capacity(self.config.n_subspaces);
165        for sub in 0..self.config.n_subspaces {
166            let start = sub * sub_dim;
167            let end = start + sub_dim;
168            // Extract sub-vectors for this sub-space
169            let sub_vecs: Vec<Vec<f32>> = data.iter().map(|v| v[start..end].to_vec()).collect();
170            let cb = kmeans_train(&sub_vecs, actual_k, sub_dim, 50)?;
171            codebooks.push(cb);
172        }
173        self.codebooks = codebooks;
174        Ok(())
175    }
176
177    /// Encode a single vector into a `QuantizedVector`.
178    pub fn encode(&self, vector: &[f32]) -> Result<QuantizedVector, QuantizerError> {
179        if self.codebooks.is_empty() {
180            return Err(QuantizerError::NotTrained);
181        }
182        let dim = vector.len();
183        let expected_dim = self.codebooks.len() * self.codebooks[0].sub_dimension;
184        if dim != expected_dim {
185            return Err(QuantizerError::DimensionMismatch);
186        }
187        let sub_dim = self.codebooks[0].sub_dimension;
188        let codes: Vec<u8> = self
189            .codebooks
190            .iter()
191            .enumerate()
192            .map(|(i, cb)| {
193                let start = i * sub_dim;
194                let end = start + sub_dim;
195                cb.nearest_centroid(&vector[start..end])
196            })
197            .collect();
198        Ok(QuantizedVector {
199            codes,
200            original_dim: dim,
201        })
202    }
203
204    /// Decode a `QuantizedVector` back to an approximate vector.
205    pub fn decode(&self, qv: &QuantizedVector) -> Result<ReconstructedVector, QuantizerError> {
206        if self.codebooks.is_empty() {
207            return Err(QuantizerError::NotTrained);
208        }
209        if qv.codes.len() != self.codebooks.len() {
210            return Err(QuantizerError::DimensionMismatch);
211        }
212        let mut result = Vec::with_capacity(qv.original_dim);
213        for (cb, &code) in self.codebooks.iter().zip(qv.codes.iter()) {
214            match cb.centroid(code) {
215                Some(c) => result.extend_from_slice(c),
216                None => return Err(QuantizerError::DimensionMismatch),
217            }
218        }
219        let error = 0.0_f32; // error not tracked at decode time without original
220        Ok(ReconstructedVector {
221            vector: result,
222            quantization_error: error,
223        })
224    }
225
226    /// Encode a batch of vectors.
227    pub fn encode_batch(
228        &self,
229        vectors: &[Vec<f32>],
230    ) -> Result<Vec<QuantizedVector>, QuantizerError> {
231        vectors.iter().map(|v| self.encode(v)).collect()
232    }
233
234    /// Return true if the quantizer has been trained.
235    pub fn is_trained(&self) -> bool {
236        !self.codebooks.is_empty()
237    }
238
239    /// Compute the compression ratio: bytes of original / bytes of compressed.
240    ///
241    /// Original: `original_dim * 4` bytes (f32).
242    /// Compressed: `n_subspaces` bytes (one u8 code per sub-space).
243    pub fn compression_ratio(&self, original_dim: usize) -> f32 {
244        let n_sub = self.config.n_subspaces;
245        if n_sub == 0 {
246            return 1.0;
247        }
248        (original_dim as f32 * 4.0) / n_sub as f32
249    }
250
251    /// Return the number of codebooks (equals `n_subspaces` after training).
252    pub fn codebook_count(&self) -> usize {
253        self.codebooks.len()
254    }
255}
256
257// ---- K-means implementation ----
258
259/// Run Lloyd's k-means algorithm on `sub_vecs` for `k` clusters and `max_iters` iterations.
260fn kmeans_train(
261    sub_vecs: &[Vec<f32>],
262    k: usize,
263    sub_dim: usize,
264    max_iters: usize,
265) -> Result<Codebook, QuantizerError> {
266    let n = sub_vecs.len();
267    if k == 0 || n == 0 {
268        return Err(QuantizerError::InvalidConfig(
269            "k and n must be > 0".to_string(),
270        ));
271    }
272
273    // Initialise centroids using k-means++ style seeded selection.
274    let mut centroids = kmeans_init(sub_vecs, k, sub_dim);
275
276    for _iter in 0..max_iters {
277        // Assignment step
278        let assignments: Vec<usize> = sub_vecs
279            .iter()
280            .map(|v| nearest_centroid_idx(&centroids, v))
281            .collect();
282
283        // Update step
284        let mut sums: Vec<Vec<f64>> = vec![vec![0.0_f64; sub_dim]; k];
285        let mut counts: Vec<usize> = vec![0; k];
286
287        for (v, &a) in sub_vecs.iter().zip(assignments.iter()) {
288            for (d, &x) in v.iter().enumerate() {
289                sums[a][d] += x as f64;
290            }
291            counts[a] += 1;
292        }
293
294        let mut converged = true;
295        for (ci, centroid) in centroids.iter_mut().enumerate() {
296            if counts[ci] == 0 {
297                continue;
298            }
299            for d in 0..sub_dim {
300                let new_val = (sums[ci][d] / counts[ci] as f64) as f32;
301                if (new_val - centroid[d]).abs() > 1e-6 {
302                    converged = false;
303                }
304                centroid[d] = new_val;
305            }
306        }
307        if converged {
308            break;
309        }
310    }
311
312    Ok(Codebook {
313        centroids,
314        sub_dimension: sub_dim,
315    })
316}
317
/// Deterministic farthest-first (maximin) seeding for k-means.
///
/// The first seed is the point at index `(sub_dim * 7) % n`. This is a fixed choice that
/// keeps training reproducible without a random number generator. Each further seed is
/// the point whose squared Euclidean distance to its nearest already-chosen seed is
/// largest; ties go to the highest index. If `data` contains fewer than `k` distinct
/// points, later seeds duplicate earlier ones.
///
/// Returns `k` cloned data points, or an empty `Vec` when `k == 0` or `data` is empty.
fn kmeans_init(data: &[Vec<f32>], k: usize, sub_dim: usize) -> Vec<Vec<f32>> {
    let n = data.len();
    if k == 0 || n == 0 {
        return Vec::new();
    }

    let mut chosen_indices: Vec<usize> = Vec::with_capacity(k);
    // Same index as the former `(sub_dim * 7 + n * 3) % n`: the `n * 3` term is a
    // multiple of `n` and never changed the result.
    let mut last = (sub_dim * 7) % n;
    chosen_indices.push(last);

    // distances[i] = squared distance from data[i] to its nearest chosen seed so far.
    let mut distances: Vec<f32> = vec![f32::INFINITY; n];
    for _ in 1..k {
        // Only the newest seed can shrink a point's nearest-seed distance.
        let seed = &data[last];
        for (slot, v) in distances.iter_mut().zip(data) {
            let dist: f32 = v
                .iter()
                .zip(seed.iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            if dist < *slot {
                *slot = dist;
            }
        }
        // Choose next: the point farthest from its nearest chosen seed.
        let next_idx = distances
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        chosen_indices.push(next_idx);
        last = next_idx;
    }

    chosen_indices.into_iter().map(|i| data[i].clone()).collect()
}
358
/// Index of the centroid closest to `v` by squared Euclidean distance.
///
/// Ties resolve to the lowest index, and NaN distances compare as equal. An empty
/// `centroids` slice yields `0`.
fn nearest_centroid_idx(centroids: &[Vec<f32>], v: &[f32]) -> usize {
    let squared_distance = |c: &[f32]| -> f32 { c.iter().zip(v).map(|(x, y)| (x - y) * (x - y)).sum() };
    centroids
        .iter()
        .map(|c| squared_distance(c.as_slice()))
        .enumerate()
        .min_by(|(_, da), (_, db)| da.partial_cmp(db).unwrap_or(std::cmp::Ordering::Equal))
        .map_or(0, |(idx, _)| idx)
}
372
#[cfg(test)]
mod tests {
    use super::*;

    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    fn make_config(n_subspaces: usize, n_clusters: usize) -> QuantizerConfig {
        QuantizerConfig {
            n_subspaces,
            n_clusters,
        }
    }

    /// Generate `n` distinct vectors of dimension `dim`: vector `i` has coordinates
    /// `i * 0.1 + d` for `d` in `0..dim`.
    fn make_data(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| (0..dim).map(|d| (i as f32 * 0.1) + d as f32).collect())
            .collect()
    }

    // --- is_trained ---

    #[test]
    fn test_not_trained_initially() {
        let q = Quantizer::new(make_config(4, 8));
        assert!(!q.is_trained());
    }

    #[test]
    fn test_is_trained_after_train() -> Result<()> {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data)?;
        assert!(q.is_trained());
        Ok(())
    }

    // --- train errors ---

    #[test]
    fn test_train_empty_data_error() {
        let mut q = Quantizer::new(make_config(4, 8));
        let err = q.train(&[]);
        assert!(matches!(err, Err(QuantizerError::InsufficientData(0))));
    }

    #[test]
    fn test_train_insufficient_data_error() {
        let mut q = Quantizer::new(make_config(2, 10));
        let data = make_data(5, 4); // 5 < 10 clusters
        let err = q.train(&data);
        assert!(matches!(err, Err(QuantizerError::InsufficientData(5))));
    }

    #[test]
    fn test_train_n_clusters_over_256() {
        let mut q = Quantizer::new(make_config(2, 300));
        let data = make_data(400, 8);
        let err = q.train(&data);
        assert!(matches!(err, Err(QuantizerError::InvalidConfig(_))));
    }

    #[test]
    fn test_train_zero_subspaces_error() {
        let mut q = Quantizer::new(make_config(0, 4));
        let data = make_data(16, 8);
        assert!(matches!(q.train(&data), Err(QuantizerError::InvalidConfig(_))));
    }

    #[test]
    fn test_train_dimension_not_divisible() {
        let mut q = Quantizer::new(make_config(3, 4)); // 3 does not divide 8
        let data = make_data(20, 8);
        let err = q.train(&data);
        assert!(matches!(err, Err(QuantizerError::InvalidConfig(_))));
    }

    #[test]
    fn test_train_inconsistent_dimensions_error() {
        let mut q = Quantizer::new(make_config(2, 2));
        let mut data = make_data(4, 8);
        data.push(vec![0.0; 4]); // shorter than the rest
        assert!(matches!(q.train(&data), Err(QuantizerError::DimensionMismatch)));
        assert!(!q.is_trained());
    }

    // --- encode / decode ---

    #[test]
    fn test_encode_not_trained_error() {
        let q = Quantizer::new(make_config(4, 8));
        let v = vec![0.0f32; 8];
        assert!(matches!(q.encode(&v), Err(QuantizerError::NotTrained)));
    }

    #[test]
    fn test_encode_dimension_mismatch() -> Result<()> {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data)?;
        let v = vec![0.0f32; 4]; // wrong dim
        assert!(matches!(
            q.encode(&v),
            Err(QuantizerError::DimensionMismatch)
        ));
        Ok(())
    }

    #[test]
    fn test_encode_codes_length() -> Result<()> {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data)?;
        let v = vec![1.0f32; 8];
        let qv = q.encode(&v)?;
        assert_eq!(qv.codes.len(), 4); // one code per sub-space
        Ok(())
    }

    #[test]
    fn test_encode_original_dim_stored() -> Result<()> {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data)?;
        let v = vec![0.0f32; 8];
        let qv = q.encode(&v)?;
        assert_eq!(qv.original_dim, 8);
        Ok(())
    }

    #[test]
    fn test_decode_produces_correct_dim() -> Result<()> {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data)?;
        let v = vec![0.5f32; 8];
        let qv = q.encode(&v)?;
        let rv = q.decode(&qv)?;
        assert_eq!(rv.vector.len(), 8);
        Ok(())
    }

    #[test]
    fn test_decode_not_trained_error() {
        let q = Quantizer::new(make_config(4, 8));
        let qv = QuantizedVector {
            codes: vec![0; 4],
            original_dim: 8,
        };
        assert!(matches!(q.decode(&qv), Err(QuantizerError::NotTrained)));
    }

    #[test]
    fn test_decode_wrong_code_count_error() -> Result<()> {
        let mut q = Quantizer::new(make_config(2, 4));
        q.train(&make_data(16, 8))?;
        let qv = QuantizedVector {
            codes: vec![0; 3],
            original_dim: 8,
        };
        assert!(matches!(q.decode(&qv), Err(QuantizerError::DimensionMismatch)));
        Ok(())
    }

    #[test]
    fn test_decode_unknown_code_error() -> Result<()> {
        let mut q = Quantizer::new(make_config(2, 4));
        q.train(&make_data(16, 8))?;
        // Only 4 centroids per codebook, so code 200 has no centroid.
        let qv = QuantizedVector {
            codes: vec![200, 0],
            original_dim: 8,
        };
        assert!(matches!(q.decode(&qv), Err(QuantizerError::DimensionMismatch)));
        Ok(())
    }

    #[test]
    fn test_encode_decode_approximates_original() -> Result<()> {
        // Four tight clusters (8 identical copies each) of 8-D vectors, quantized with
        // 2 sub-spaces of 4 coordinates and 4 clusters per sub-space.
        let mut q = Quantizer::new(make_config(2, 4));
        let mut data: Vec<Vec<f32>> = Vec::new();
        for c in 0..4 {
            for _ in 0..8 {
                let v: Vec<f32> = (0..8).map(|d| (c as f32 * 10.0) + d as f32 * 0.1).collect();
                data.push(v);
            }
        }
        q.train(&data)?;
        let test = data[0].clone();
        let qv = q.encode(&test)?;
        let rv = q.decode(&qv)?;
        // Clusters are 10.0 apart, so staying within 5.0 per coordinate means the
        // vector was reconstructed from its own cluster.
        for (&orig, &rec) in test.iter().zip(rv.vector.iter()) {
            assert!((orig - rec).abs() < 5.0, "orig={orig}, rec={rec}");
        }
        Ok(())
    }

    // --- encode_batch ---

    #[test]
    fn test_encode_batch_empty() -> Result<()> {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data)?;
        let result = q.encode_batch(&[])?;
        assert!(result.is_empty());
        Ok(())
    }

    #[test]
    fn test_encode_batch_multiple() -> Result<()> {
        let mut q = Quantizer::new(make_config(2, 4));
        let data = make_data(16, 8);
        q.train(&data)?;
        let result = q.encode_batch(&data)?;
        assert_eq!(result.len(), data.len());
        Ok(())
    }

    // --- compression_ratio ---

    #[test]
    fn test_compression_ratio_basic() {
        let q = Quantizer::new(make_config(4, 8));
        // original: 128 * 4 = 512 bytes; compressed: 4 bytes
        let ratio = q.compression_ratio(128);
        assert!((ratio - 128.0).abs() < 0.001);
    }

    #[test]
    fn test_compression_ratio_formula() {
        let q = Quantizer::new(make_config(8, 256));
        // 64-dim vector: 64*4=256 bytes original, 8 bytes compressed = ratio 32
        let ratio = q.compression_ratio(64);
        assert!((ratio - 32.0).abs() < 0.001);
    }

    // --- codebook_count ---

    #[test]
    fn test_codebook_count_before_training() {
        let q = Quantizer::new(make_config(4, 8));
        assert_eq!(q.codebook_count(), 0);
    }

    #[test]
    fn test_codebook_count_after_training_matches_n_subspaces() -> Result<()> {
        let mut q = Quantizer::new(make_config(4, 8));
        let data = make_data(32, 8);
        q.train(&data)?;
        assert_eq!(q.codebook_count(), 4);
        Ok(())
    }

    // --- Codebook ---

    #[test]
    fn test_codebook_new() {
        let cb = Codebook::new(4);
        assert_eq!(cb.sub_dimension, 4);
        assert!(cb.centroids.is_empty());
    }

    #[test]
    fn test_nearest_centroid_single() {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        let code = cb.nearest_centroid(&[1.0, 1.0]);
        assert_eq!(code, 0); // closer to (0,0)
    }

    #[test]
    fn test_nearest_centroid_second() {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        let code = cb.nearest_centroid(&[9.0, 9.0]);
        assert_eq!(code, 1); // closer to (10,10)
    }

    #[test]
    fn test_centroid_valid_code() -> Result<()> {
        let mut cb = Codebook::new(2);
        cb.centroids = vec![vec![1.0, 2.0]];
        let c = cb.centroid(0).ok_or("centroid at index 0 should exist")?;
        assert_eq!(c[0], 1.0);
        assert_eq!(c[1], 2.0);
        Ok(())
    }

    #[test]
    fn test_centroid_out_of_range() {
        let cb = Codebook::new(2);
        assert!(cb.centroid(5).is_none());
    }

    // --- error display ---

    #[test]
    fn test_not_trained_display() {
        let e = QuantizerError::NotTrained;
        assert!(format!("{e}").contains("trained"));
    }

    #[test]
    fn test_dimension_mismatch_display() {
        let e = QuantizerError::DimensionMismatch;
        assert!(format!("{e}").contains("mismatch"));
    }

    #[test]
    fn test_insufficient_data_display() {
        let e = QuantizerError::InsufficientData(3);
        assert!(format!("{e}").contains("3"));
    }

    #[test]
    fn test_invalid_config_display() {
        let e = QuantizerError::InvalidConfig("bad".to_string());
        assert!(format!("{e}").contains("bad"));
    }
}