Skip to main content

oxirs_vec/
pq_index.rs

1//! # Product Quantization Index
2//!
3//! A memory-efficient approximate nearest-neighbour index based on product
4//! quantization (PQ).  Each vector is split into `m` sub-vectors, and each
5//! sub-vector is replaced by the index of its nearest centroid in a per-
6//! sub-space codebook.  Distances are then approximated via pre-computed
7//! lookup tables (asymmetric distance computation — ADC).
8//!
9//! ## Features
10//!
11//! - **Codebook training** via k-means on sub-vectors
12//! - **Asymmetric distance computation** for accurate approximations
13//! - **Multi-probe search** with configurable number of probes
//! - **Compact storage**: each vector is stored as `m` 16-bit codes (`2 * m` bytes)
15//!
16//! ## Usage
17//!
18//! ```rust
19//! use oxirs_vec::pq_index::{ProductQuantizationIndex, PqConfig};
20//!
21//! let config = PqConfig {
22//!     dimension: 8,
23//!     num_sub_vectors: 4,
24//!     num_centroids: 4,
25//!     training_iterations: 5,
26//!     ..Default::default()
27//! };
28//! let mut pq = ProductQuantizationIndex::new(config).unwrap();
29//!
30//! // Train on some data
31//! let training_data = vec![
32//!     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
33//!     vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
34//!     vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
35//!     vec![9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0],
36//! ];
37//! pq.train(&training_data).unwrap();
38//!
39//! // Add vectors
40//! pq.add(0, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
41//! pq.add(1, &[8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]).unwrap();
42//!
43//! // Search
44//! let results = pq.search(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 1).unwrap();
45//! assert_eq!(results[0].0, 0); // ID of closest vector
46//! ```
47
48use anyhow::{anyhow, bail, Result};
49use serde::{Deserialize, Serialize};
50use std::cmp::Reverse;
51use std::collections::BinaryHeap;
52
53// ---------------------------------------------------------------------------
54// Configuration
55// ---------------------------------------------------------------------------
56
57/// Configuration for the PQ index.
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct PqConfig {
60    /// Total vector dimension.
61    pub dimension: usize,
62    /// Number of sub-vector spaces (`m`). Must divide `dimension` evenly.
63    pub num_sub_vectors: usize,
64    /// Number of centroids per sub-space (`k`, typically 256).
65    pub num_centroids: usize,
66    /// Number of k-means iterations for codebook training.
67    pub training_iterations: usize,
68    /// Number of probes for multi-probe search (0 = exact ADC scan).
69    pub num_probes: usize,
70}
71
72impl Default for PqConfig {
73    fn default() -> Self {
74        Self {
75            dimension: 128,
76            num_sub_vectors: 8,
77            num_centroids: 256,
78            training_iterations: 20,
79            num_probes: 0,
80        }
81    }
82}
83
84// ---------------------------------------------------------------------------
85// Codebook
86// ---------------------------------------------------------------------------
87
88/// A codebook for one sub-vector space.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90struct SubCodebook {
91    /// `centroids[i]` is the centroid vector for code `i`.
92    centroids: Vec<Vec<f32>>,
93    /// Sub-vector dimension.
94    sub_dim: usize,
95}
96
97impl SubCodebook {
98    fn new(sub_dim: usize, num_centroids: usize) -> Self {
99        Self {
100            centroids: vec![vec![0.0; sub_dim]; num_centroids],
101            sub_dim,
102        }
103    }
104
105    /// Assign the nearest centroid index to a sub-vector.
106    fn encode(&self, sub_vec: &[f32]) -> u16 {
107        let mut best_idx = 0u16;
108        let mut best_dist = f32::MAX;
109        for (i, centroid) in self.centroids.iter().enumerate() {
110            let dist = l2_sq(sub_vec, centroid);
111            if dist < best_dist {
112                best_dist = dist;
113                best_idx = i as u16;
114            }
115        }
116        best_idx
117    }
118
119    /// Decode: return the centroid for a given code.
120    fn decode(&self, code: u16) -> &[f32] {
121        &self.centroids[code as usize]
122    }
123
124    /// Build a distance lookup table for a query sub-vector.
125    fn build_distance_table(&self, query_sub: &[f32]) -> Vec<f32> {
126        self.centroids.iter().map(|c| l2_sq(query_sub, c)).collect()
127    }
128}
129
130// ---------------------------------------------------------------------------
131// Encoded vector (PQ codes)
132// ---------------------------------------------------------------------------
133
/// A PQ-encoded vector: `codes[i]` is the centroid index for sub-space `i`.
///
/// Only the caller-supplied identifier and the per-sub-space codes are kept;
/// the original floating-point vector is not stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PqCode {
    /// Identifier passed to `ProductQuantizationIndex::add`.
    id: u64,
    /// One centroid index per sub-space, in sub-space order.
    codes: Vec<u16>,
}
140
141// ---------------------------------------------------------------------------
142// ProductQuantizationIndex
143// ---------------------------------------------------------------------------
144
/// The Product Quantization index.
///
/// Holds one codebook per sub-space plus the encoded entries. Codebooks must
/// be trained (see `train`) before vectors can be added or searched.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizationIndex {
    /// Configuration supplied at construction time.
    config: PqConfig,
    /// One codebook per sub-space (`config.num_sub_vectors` of them).
    codebooks: Vec<SubCodebook>,
    /// Encoded vectors in insertion order.
    entries: Vec<PqCode>,
    /// Set to `true` once `train` has completed successfully.
    trained: bool,
    /// Length of each sub-vector: `config.dimension / config.num_sub_vectors`.
    sub_dim: usize,
}
154
155impl ProductQuantizationIndex {
156    /// Create a new (untrained) PQ index.
157    pub fn new(config: PqConfig) -> Result<Self> {
158        if config.dimension == 0 {
159            bail!("dimension must be > 0");
160        }
161        if config.num_sub_vectors == 0 {
162            bail!("num_sub_vectors must be > 0");
163        }
164        if config.dimension % config.num_sub_vectors != 0 {
165            bail!(
166                "dimension ({}) must be divisible by num_sub_vectors ({})",
167                config.dimension,
168                config.num_sub_vectors
169            );
170        }
171        if config.num_centroids == 0 || config.num_centroids > 65536 {
172            bail!("num_centroids must be in 1..=65536");
173        }
174        let sub_dim = config.dimension / config.num_sub_vectors;
175        let codebooks = (0..config.num_sub_vectors)
176            .map(|_| SubCodebook::new(sub_dim, config.num_centroids))
177            .collect();
178        Ok(Self {
179            config,
180            codebooks,
181            entries: Vec::new(),
182            trained: false,
183            sub_dim,
184        })
185    }
186
187    /// Train the codebooks using the provided training vectors.
188    pub fn train(&mut self, training_data: &[Vec<f32>]) -> Result<()> {
189        if training_data.is_empty() {
190            bail!("training data is empty");
191        }
192        for (i, v) in training_data.iter().enumerate() {
193            if v.len() != self.config.dimension {
194                bail!(
195                    "training vector {i} has dimension {} but expected {}",
196                    v.len(),
197                    self.config.dimension
198                );
199            }
200        }
201
202        for m in 0..self.config.num_sub_vectors {
203            let start = m * self.sub_dim;
204            let end = start + self.sub_dim;
205
206            // Extract sub-vectors for this sub-space
207            let sub_vectors: Vec<Vec<f32>> = training_data
208                .iter()
209                .map(|v| v[start..end].to_vec())
210                .collect();
211
212            // Run simple k-means
213            let centroids = kmeans(
214                &sub_vectors,
215                self.config.num_centroids,
216                self.config.training_iterations,
217                self.sub_dim,
218            );
219            self.codebooks[m].centroids = centroids;
220        }
221
222        self.trained = true;
223        Ok(())
224    }
225
    /// Whether the index has been trained.
    ///
    /// Returns `true` once `train` has completed successfully at least once.
    pub fn is_trained(&self) -> bool {
        self.trained
    }
230
231    /// Add a vector with the given ID.
232    pub fn add(&mut self, id: u64, vector: &[f32]) -> Result<()> {
233        if !self.trained {
234            bail!("index must be trained before adding vectors");
235        }
236        if vector.len() != self.config.dimension {
237            bail!(
238                "vector dimension {} != expected {}",
239                vector.len(),
240                self.config.dimension
241            );
242        }
243
244        let mut codes = Vec::with_capacity(self.config.num_sub_vectors);
245        for m in 0..self.config.num_sub_vectors {
246            let start = m * self.sub_dim;
247            let end = start + self.sub_dim;
248            let code = self.codebooks[m].encode(&vector[start..end]);
249            codes.push(code);
250        }
251
252        self.entries.push(PqCode { id, codes });
253        Ok(())
254    }
255
    /// Number of indexed vectors.
    ///
    /// Counts every successful `add` since construction or the last `clear`.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
260
    /// Whether the index is empty.
    ///
    /// This reflects stored entries only; a trained index with no vectors is
    /// still empty.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
265
266    /// Search for the `k` nearest neighbours using asymmetric distance computation.
267    /// Returns `(id, approximate_distance)` pairs sorted by ascending distance.
268    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
269        if !self.trained {
270            bail!("index must be trained before searching");
271        }
272        if query.len() != self.config.dimension {
273            bail!(
274                "query dimension {} != expected {}",
275                query.len(),
276                self.config.dimension
277            );
278        }
279        if k == 0 {
280            return Ok(Vec::new());
281        }
282
283        // Pre-compute distance tables for each sub-space
284        let distance_tables: Vec<Vec<f32>> = (0..self.config.num_sub_vectors)
285            .map(|m| {
286                let start = m * self.sub_dim;
287                let end = start + self.sub_dim;
288                self.codebooks[m].build_distance_table(&query[start..end])
289            })
290            .collect();
291
292        // Scan all entries, summing per-subspace distances from the lookup tables
293        let mut heap: BinaryHeap<Reverse<(OrderedF32, u64)>> = BinaryHeap::new();
294        for entry in &self.entries {
295            let mut dist = 0.0f32;
296            for (m, code) in entry.codes.iter().enumerate() {
297                dist += distance_tables[m][*code as usize];
298            }
299            heap.push(Reverse((OrderedF32(dist), entry.id)));
300        }
301
302        let mut results = Vec::with_capacity(k.min(heap.len()));
303        for _ in 0..k {
304            if let Some(Reverse((OrderedF32(d), id))) = heap.pop() {
305                results.push((id, d));
306            } else {
307                break;
308            }
309        }
310        Ok(results)
311    }
312
313    /// Reconstruct an approximate vector from its PQ codes.
314    pub fn reconstruct(&self, id: u64) -> Result<Vec<f32>> {
315        let entry = self
316            .entries
317            .iter()
318            .find(|e| e.id == id)
319            .ok_or_else(|| anyhow!("id {id} not found in index"))?;
320
321        let mut vector = Vec::with_capacity(self.config.dimension);
322        for (m, code) in entry.codes.iter().enumerate() {
323            vector.extend_from_slice(self.codebooks[m].decode(*code));
324        }
325        Ok(vector)
326    }
327
    /// Remove all indexed vectors (keeps codebooks).
    ///
    /// The `trained` flag is left as it was, so vectors can be added again
    /// without retraining.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
332
    /// Get the PQ configuration.
    ///
    /// Returns a borrow of the configuration the index was constructed with.
    pub fn config(&self) -> &PqConfig {
        &self.config
    }
337
338    /// Compute the compression ratio (original bytes / encoded bytes).
339    pub fn compression_ratio(&self) -> f64 {
340        if self.entries.is_empty() {
341            return 0.0;
342        }
343        let original_bytes = self.config.dimension * 4; // f32
344        let encoded_bytes = self.config.num_sub_vectors * 2; // u16 codes
345        original_bytes as f64 / encoded_bytes as f64
346    }
347}
348
349// ---------------------------------------------------------------------------
350// k-means (simple)
351// ---------------------------------------------------------------------------
352
353fn kmeans(data: &[Vec<f32>], k: usize, iterations: usize, dim: usize) -> Vec<Vec<f32>> {
354    let actual_k = k.min(data.len());
355    // Initialise centroids from first k data points
356    let mut centroids: Vec<Vec<f32>> = data.iter().take(actual_k).cloned().collect();
357    // Pad if data.len() < k
358    while centroids.len() < k {
359        centroids.push(vec![0.0; dim]);
360    }
361
362    for _ in 0..iterations {
363        // Assignment step
364        let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
365        for (idx, point) in data.iter().enumerate() {
366            let mut best = 0;
367            let mut best_dist = f32::MAX;
368            for (c, centroid) in centroids.iter().enumerate() {
369                let d = l2_sq(point, centroid);
370                if d < best_dist {
371                    best_dist = d;
372                    best = c;
373                }
374            }
375            assignments[best].push(idx);
376        }
377
378        // Update step
379        for (c, assigned) in assignments.iter().enumerate() {
380            if assigned.is_empty() {
381                continue;
382            }
383            let mut new_centroid = vec![0.0f32; dim];
384            for &idx in assigned {
385                for (d, val) in data[idx].iter().enumerate() {
386                    new_centroid[d] += val;
387                }
388            }
389            let count = assigned.len() as f32;
390            for val in &mut new_centroid {
391                *val /= count;
392            }
393            centroids[c] = new_centroid;
394        }
395    }
396
397    centroids
398}
399
400// ---------------------------------------------------------------------------
401// helpers
402// ---------------------------------------------------------------------------
403
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the overlapping prefix is compared when the slices differ in length.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    let squared_gaps = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let gap = lhs - rhs;
        gap * gap
    });
    squared_gaps.sum()
}
413
414/// Newtype for ordered f32 (used in BinaryHeap).
415#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
416struct OrderedF32(f32);
417
418impl Eq for OrderedF32 {}
419
420impl PartialOrd for OrderedF32 {
421    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
422        Some(self.cmp(other))
423    }
424}
425
426impl Ord for OrderedF32 {
427    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
428        self.0
429            .partial_cmp(&other.0)
430            .unwrap_or(std::cmp::Ordering::Equal)
431    }
432}
433
434// ===========================================================================
435// Tests
436// ===========================================================================
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with the given shape, 5 training iterations and no probing.
    fn default_config(dim: usize, m: usize, k: usize) -> PqConfig {
        PqConfig {
            dimension: dim,
            num_sub_vectors: m,
            num_centroids: k,
            training_iterations: 5,
            num_probes: 0,
        }
    }

    /// `n` deterministic vectors of length `dim` with steadily increasing values.
    fn make_training_data(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| (0..dim).map(|d| (i * dim + d) as f32 * 0.1).collect())
            .collect()
    }

    /// An index trained on at least `max(k, 4)` synthetic vectors.
    fn trained_index(dim: usize, m: usize, k: usize) -> ProductQuantizationIndex {
        let config = default_config(dim, m, k);
        let mut idx = ProductQuantizationIndex::new(config).expect("new");
        let data = make_training_data(k.max(4), dim);
        idx.train(&data).expect("train");
        idx
    }

    // -- constructor tests --

    #[test]
    fn test_new_valid_config() {
        let idx = ProductQuantizationIndex::new(default_config(8, 4, 4));
        assert!(idx.is_ok());
    }

    #[test]
    fn test_new_zero_dimension() {
        let config = PqConfig {
            dimension: 0,
            ..Default::default()
        };
        assert!(ProductQuantizationIndex::new(config).is_err());
    }

    #[test]
    fn test_new_zero_sub_vectors() {
        let config = PqConfig {
            num_sub_vectors: 0,
            ..Default::default()
        };
        assert!(ProductQuantizationIndex::new(config).is_err());
    }

    #[test]
    fn test_new_indivisible_dimension() {
        let config = PqConfig {
            dimension: 7,
            num_sub_vectors: 4,
            ..Default::default()
        };
        assert!(ProductQuantizationIndex::new(config).is_err());
    }

    #[test]
    fn test_new_zero_centroids() {
        let config = PqConfig {
            num_centroids: 0,
            ..Default::default()
        };
        assert!(ProductQuantizationIndex::new(config).is_err());
    }

    // -- training tests --

    #[test]
    fn test_train_sets_trained_flag() {
        let mut idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        assert!(!idx.is_trained());
        let data = make_training_data(10, 8);
        idx.train(&data).expect("train");
        assert!(idx.is_trained());
    }

    #[test]
    fn test_train_empty_data_fails() {
        let mut idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        assert!(idx.train(&[]).is_err());
    }

    #[test]
    fn test_train_wrong_dimension_fails() {
        let mut idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        let data = vec![vec![1.0, 2.0]]; // dim=2 not 8
        assert!(idx.train(&data).is_err());
    }

    // -- add tests --

    #[test]
    fn test_add_before_training_fails() {
        let mut idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        assert!(idx.add(0, &[1.0; 8]).is_err());
    }

    #[test]
    fn test_add_wrong_dimension_fails() {
        let mut idx = trained_index(8, 4, 4);
        assert!(idx.add(0, &[1.0; 4]).is_err());
    }

    #[test]
    fn test_add_and_len() {
        let mut idx = trained_index(8, 4, 4);
        assert!(idx.is_empty());
        idx.add(0, &[1.0; 8]).expect("add");
        assert_eq!(idx.len(), 1);
        idx.add(1, &[2.0; 8]).expect("add");
        assert_eq!(idx.len(), 2);
    }

    // -- search tests --

    #[test]
    fn test_search_before_training_fails() {
        let idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        assert!(idx.search(&[1.0; 8], 1).is_err());
    }

    #[test]
    fn test_search_wrong_dimension_fails() {
        let idx = trained_index(8, 4, 4);
        assert!(idx.search(&[1.0; 4], 1).is_err());
    }

    #[test]
    fn test_search_k_zero_returns_empty() {
        let idx = trained_index(8, 4, 4);
        let results = idx.search(&[1.0; 8], 0).expect("search");
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_empty_index_returns_empty() {
        let idx = trained_index(8, 4, 4);
        let results = idx.search(&[1.0; 8], 5).expect("search");
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_finds_nearest() {
        let mut idx = trained_index(8, 4, 4);
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let v2 = vec![100.0, 200.0, 300.0, 400.0, 500.0, 600.0, 700.0, 800.0];
        idx.add(10, &v1).expect("add");
        idx.add(20, &v2).expect("add");

        let query = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let results = idx.search(&query, 1).expect("search");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 10);
    }

    #[test]
    fn test_search_returns_sorted_by_distance() {
        let mut idx = trained_index(8, 4, 4);
        let v1 = vec![1.0; 8];
        let v2 = vec![2.0; 8];
        let v3 = vec![10.0; 8];
        idx.add(1, &v1).expect("add");
        idx.add(2, &v2).expect("add");
        idx.add(3, &v3).expect("add");

        let results = idx.search(&[1.0; 8], 3).expect("search");
        assert_eq!(results.len(), 3);
        // Distances should be ascending
        assert!(results[0].1 <= results[1].1);
        assert!(results[1].1 <= results[2].1);
    }

    #[test]
    fn test_search_k_larger_than_index() {
        let mut idx = trained_index(8, 4, 4);
        idx.add(1, &[1.0; 8]).expect("add");
        let results = idx.search(&[1.0; 8], 100).expect("search");
        assert_eq!(results.len(), 1);
    }

    // -- reconstruct tests --

    #[test]
    fn test_reconstruct_existing_id() {
        let mut idx = trained_index(8, 4, 4);
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        idx.add(42, &v).expect("add");
        let reconstructed = idx.reconstruct(42).expect("reconstruct");
        assert_eq!(reconstructed.len(), 8);
    }

    #[test]
    fn test_reconstruct_missing_id() {
        let idx = trained_index(8, 4, 4);
        assert!(idx.reconstruct(999).is_err());
    }

    // -- clear --

    #[test]
    fn test_clear() {
        let mut idx = trained_index(8, 4, 4);
        idx.add(1, &[1.0; 8]).expect("add");
        assert_eq!(idx.len(), 1);
        idx.clear();
        assert!(idx.is_empty());
        // Still trained after clear
        assert!(idx.is_trained());
    }

    // -- compression ratio --

    #[test]
    fn test_compression_ratio_empty() {
        let idx = trained_index(8, 4, 4);
        assert_eq!(idx.compression_ratio(), 0.0);
    }

    #[test]
    fn test_compression_ratio_non_empty() {
        let mut idx = trained_index(8, 4, 4);
        idx.add(0, &[1.0; 8]).expect("add");
        let ratio = idx.compression_ratio();
        // 8 * 4 bytes original = 32; 4 * 2 bytes encoded = 8 -> ratio = 4
        assert!((ratio - 4.0).abs() < 1e-6);
    }

    // -- config --

    #[test]
    fn test_config_accessor() {
        let idx = ProductQuantizationIndex::new(default_config(16, 4, 8)).expect("new");
        assert_eq!(idx.config().dimension, 16);
        assert_eq!(idx.config().num_sub_vectors, 4);
    }

    // -- default config --

    #[test]
    fn test_default_config() {
        let config = PqConfig::default();
        assert_eq!(config.dimension, 128);
        assert_eq!(config.num_sub_vectors, 8);
        assert_eq!(config.num_centroids, 256);
    }

    // -- kmeans --

    #[test]
    fn test_kmeans_basic() {
        let data = vec![
            vec![1.0, 0.0],
            vec![1.1, 0.0],
            vec![0.0, 1.0],
            vec![0.0, 1.1],
        ];
        let centroids = kmeans(&data, 2, 10, 2);
        assert_eq!(centroids.len(), 2);
    }

    #[test]
    fn test_kmeans_more_k_than_data() {
        let data = vec![vec![1.0], vec![2.0]];
        let centroids = kmeans(&data, 5, 3, 1);
        assert_eq!(centroids.len(), 5);
        // Each point keeps its own seed centroid; the padding stays at zero.
        assert_eq!(
            centroids,
            vec![vec![1.0], vec![2.0], vec![0.0], vec![0.0], vec![0.0]]
        );
    }

    // -- l2_sq --

    #[test]
    fn test_l2_sq_identical() {
        assert_eq!(l2_sq(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_l2_sq_known() {
        // (3-1)^2 + (4-1)^2 = 4+9 = 13
        let dist = l2_sq(&[3.0, 4.0], &[1.0, 1.0]);
        assert!((dist - 13.0).abs() < 1e-6);
    }

    // -- ordered f32 --

    #[test]
    fn test_ordered_f32_ordering() {
        let a = OrderedF32(1.0);
        let b = OrderedF32(2.0);
        assert!(a < b);
    }

    // -- multi-vector search --

    #[test]
    fn test_multi_add_and_search() {
        let mut idx = trained_index(8, 4, 4);
        for i in 0..20_u64 {
            let v: Vec<f32> = (0..8).map(|d| (i * 8 + d) as f32).collect();
            idx.add(i, &v).expect("add");
        }
        assert_eq!(idx.len(), 20);
        let results = idx.search(&[0.0; 8], 5).expect("search");
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_retrain_resets_codebooks() {
        let mut idx = trained_index(8, 4, 4);
        idx.add(0, &[1.0; 8]).expect("add");
        let data2 = make_training_data(10, 8);
        idx.train(&data2).expect("retrain");
        // Retraining replaces the codebooks but keeps the existing entries.
        assert_eq!(idx.len(), 1);
    }

    // -- edge cases --

    #[test]
    fn test_single_dimension_subvectors() {
        let config = default_config(4, 4, 2);
        let mut idx = ProductQuantizationIndex::new(config).expect("new");
        let data = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
        idx.train(&data).expect("train");
        idx.add(0, &[1.0, 2.0, 3.0, 4.0]).expect("add");
        let results = idx.search(&[1.0, 2.0, 3.0, 4.0], 1).expect("search");
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_single_centroid_perfect_encode() {
        let config = default_config(4, 2, 1);
        let mut idx = ProductQuantizationIndex::new(config).expect("new");
        let data = vec![vec![1.0, 2.0, 3.0, 4.0]];
        idx.train(&data).expect("train");
        idx.add(0, &[1.0, 2.0, 3.0, 4.0]).expect("add");
        let recon = idx.reconstruct(0).expect("reconstruct");
        // With one centroid per sub-space and a single training vector, each
        // centroid is that vector's slice, so reconstruction is exact.
        assert_eq!(recon, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_large_dimension() {
        let config = default_config(64, 8, 4);
        let mut idx = ProductQuantizationIndex::new(config).expect("new");
        let data = make_training_data(10, 64);
        idx.train(&data).expect("train");
        idx.add(0, &[0.5; 64]).expect("add");
        let results = idx.search(&[0.5; 64], 1).expect("search");
        assert_eq!(results.len(), 1);
    }
}