// oxify_vector/ivf.rs

//! IVF-PQ (Inverted File Index with Product Quantization)
//!
//! Memory-efficient ANN search for large-scale datasets (1M+ vectors).
//!
//! ## Algorithm Overview
//!
//! 1. **Clustering (IVF)**: Partition vectors into clusters using k-means
//! 2. **Product Quantization (PQ)**: Compress vectors by quantizing sub-vectors
//! 3. **Search**: Probe the nearest clusters, then scan their compressed vectors
//!
//! ## Benefits
//!
//! - **Memory**: 8-16x compression (768D → 64-96 bytes); see the arithmetic sketch below
//! - **Speed**: Search only the most relevant partitions (controlled by `nprobe`)
//! - **Scalability**: Handles 1M+ vectors efficiently
//!
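//! ## Memory Footprint (sketch)
//!
//! Illustrative arithmetic with the default parameters (64 sub-vectors, 8-bit
//! codes) on 768-dimensional `f32` vectors: a raw vector occupies 768 × 4 = 3072
//! bytes, while its PQ codes occupy 64 × 1 = 64 bytes. The coarse centroids and
//! PQ codebooks add a fixed overhead on top of the per-vector codes, so the
//! end-to-end ratio reported by [`IvfPqIndex::stats`] is lower than the raw
//! per-vector figure, especially for small collections.
//!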
//! ## Example
//!
//! ```rust
//! use oxify_vector::ivf::{IvfPqIndex, IvfPqConfig};
//! use std::collections::HashMap;
//!
//! # fn example() -> anyhow::Result<()> {
//! let config = IvfPqConfig::default()
//!     .with_nclusters(256)
//!     .with_nsubvectors(64)
//!     .with_nprobe(16);
//!
//! let mut index = IvfPqIndex::new(config);
//!
//! let mut vectors = HashMap::new();
//! vectors.insert("doc1".to_string(), vec![0.1; 768]);
//! vectors.insert("doc2".to_string(), vec![0.2; 768]);
//!
//! index.build(&vectors)?;
//!
//! let query = vec![0.15; 768];
//! let results = index.search(&query, 10)?;
//! # Ok(())
//! # }
//! ```

use anyhow::{Context, Result};
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::simd;
use crate::types::{DistanceMetric, SearchResult};

/// IVF-PQ configuration
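///
/// Tuning sketch using the builder methods defined below (the values shown are
/// only illustrative starting points, not recommendations for any dataset):
///
/// ```rust
/// use oxify_vector::ivf::IvfPqConfig;
///
/// // Larger collections generally benefit from more clusters and a higher nprobe.
/// let config = IvfPqConfig::default()
///     .with_nclusters(1024)
///     .with_nprobe(32);
/// assert_eq!(config.nclusters, 1024);
/// assert_eq!(config.nprobe, 32);
/// ```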
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfPqConfig {
    /// Number of clusters (partitions) for IVF
    /// Typical values: 256, 1024, 4096
    /// More clusters = faster search but more memory
    pub nclusters: usize,

    /// Number of sub-vectors for product quantization
    /// Typical values: 8, 16, 32, 64
    /// More sub-vectors = better accuracy but more memory
    pub nsubvectors: usize,

    /// Number of bits per sub-vector quantizer
    /// Typical value: 8 (256 centroids per sub-quantizer)
    pub nbits: usize,

    /// Number of clusters to probe during search
    /// Typical values: 1, 4, 16, 64
    /// More probes = better recall but slower search
    pub nprobe: usize,

    /// Distance metric for clustering and search
    pub metric: DistanceMetric,

    /// Max iterations for k-means clustering
    pub max_kmeans_iterations: usize,

    /// Convergence threshold for k-means
    pub kmeans_tolerance: f32,
}

impl Default for IvfPqConfig {
    fn default() -> Self {
        Self {
            nclusters: 256,
            nsubvectors: 64,
            nbits: 8,
            nprobe: 16,
            metric: DistanceMetric::Cosine,
            max_kmeans_iterations: 100,
            kmeans_tolerance: 1e-4,
        }
    }
}

impl IvfPqConfig {
    pub fn with_nclusters(mut self, nclusters: usize) -> Self {
        self.nclusters = nclusters;
        self
    }

    pub fn with_nsubvectors(mut self, nsubvectors: usize) -> Self {
        self.nsubvectors = nsubvectors;
        self
    }

    pub fn with_nbits(mut self, nbits: usize) -> Self {
        self.nbits = nbits;
        self
    }

    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
        self.nprobe = nprobe;
        self
    }

    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }
}

/// Product quantizer for compressing vectors
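///
/// Shape sketch with `dim = 768`, `nsubvectors = 64`, `nbits = 8`: each sub-vector
/// spans `768 / 64 = 12` dimensions, `ncentroids = 2^8 = 256`, so `codebooks`
/// holds `64 × 256 × 12` f32 values and every encoded vector shrinks to 64
/// one-byte codes.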
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ProductQuantizer {
    /// Number of sub-vectors
    nsubvectors: usize,
    /// Dimension of each sub-vector
    subvector_dim: usize,
    /// Codebooks for each sub-vector (nsubvectors x ncentroids x subvector_dim)
    codebooks: Vec<Vec<Vec<f32>>>,
    /// Number of centroids per sub-quantizer (2^nbits)
    ncentroids: usize,
}

impl ProductQuantizer {
    fn new(dim: usize, nsubvectors: usize, nbits: usize) -> Result<Self> {
        if !dim.is_multiple_of(nsubvectors) {
            anyhow::bail!(
                "Vector dimension {} must be divisible by number of sub-vectors {}",
                dim,
                nsubvectors
            );
        }

        let subvector_dim = dim / nsubvectors;
        let ncentroids = 1 << nbits; // 2^nbits

        Ok(Self {
            nsubvectors,
            subvector_dim,
            codebooks: vec![],
            ncentroids,
        })
    }

    /// Train product quantizer on a set of vectors
    fn train(&mut self, vectors: &[Vec<f32>], iterations: usize) -> Result<()> {
        self.codebooks.clear();

        for subvec_idx in 0..self.nsubvectors {
            let start = subvec_idx * self.subvector_dim;
            let end = start + self.subvector_dim;

            // Extract sub-vectors for this dimension
            let subvectors: Vec<Vec<f32>> =
                vectors.iter().map(|v| v[start..end].to_vec()).collect();

            // Run k-means clustering on sub-vectors
            let centroids = kmeans(&subvectors, self.ncentroids, iterations)?;
            self.codebooks.push(centroids);
        }

        Ok(())
    }

    /// Encode a vector into quantized codes
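    ///
    /// Returns one `u8` code per sub-vector: the index of the nearest codebook
    /// centroid for that block (the `u8` storage assumes `nbits <= 8`).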
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        let mut codes = Vec::with_capacity(self.nsubvectors);

        for subvec_idx in 0..self.nsubvectors {
            let start = subvec_idx * self.subvector_dim;
            let end = start + self.subvector_dim;
            let subvector = &vector[start..end];

            // Find nearest centroid
            let mut best_idx = 0;
            let mut best_dist = f32::MAX;

            for (centroid_idx, centroid) in self.codebooks[subvec_idx].iter().enumerate() {
                let dist = euclidean_distance(subvector, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_idx = centroid_idx;
                }
            }

            codes.push(best_idx as u8);
        }

        codes
    }

    /// Compute asymmetric distance between query vector and quantized vector
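    ///
    /// The query stays in full precision while the stored vector is represented
    /// only by its codes; the distance to the corresponding codebook centroid is
    /// accumulated per sub-vector (lower is better).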
    fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        let mut total_dist = 0.0;

        #[allow(clippy::needless_range_loop)]
        for subvec_idx in 0..self.nsubvectors {
            let start = subvec_idx * self.subvector_dim;
            let end = start + self.subvector_dim;
            let query_subvector = &query[start..end];

            let code = codes[subvec_idx] as usize;
            let centroid = &self.codebooks[subvec_idx][code];

            total_dist += euclidean_distance(query_subvector, centroid);
        }

        total_dist
    }
}

/// IVF-PQ index for memory-efficient ANN search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfPqIndex {
    config: IvfPqConfig,
    /// Cluster centroids (coarse quantizer)
    centroids: Vec<Vec<f32>>,
    /// Inverted lists: cluster_id -> list of (entity_id, quantized_codes)
    inverted_lists: Vec<Vec<(String, Vec<u8>)>>,
    /// Product quantizer for fine quantization
    pq: Option<ProductQuantizer>,
    /// Original vector dimension
    dim: Option<usize>,
    /// Total number of indexed vectors
    size: usize,
}

impl IvfPqIndex {
    pub fn new(config: IvfPqConfig) -> Self {
        Self {
            config,
            centroids: Vec::new(),
            inverted_lists: Vec::new(),
            pq: None,
            dim: None,
            size: 0,
        }
    }

    /// Build the index from a collection of vectors
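    ///
    /// Trains the coarse quantizer (k-means over all vectors) and the product
    /// quantizer, then assigns each vector to its nearest cluster and stores the
    /// vector's PQ codes in that cluster's inverted list.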
    pub fn build(&mut self, vectors: &HashMap<String, Vec<f32>>) -> Result<()> {
        if vectors.is_empty() {
            anyhow::bail!("Cannot build index with empty vector collection");
        }

        // Get dimension from first vector
        let dim = vectors.values().next().unwrap().len();
        self.dim = Some(dim);

        let vec_list: Vec<Vec<f32>> = vectors.values().cloned().collect();

        // Step 1: Train coarse quantizer (IVF)
        println!(
            "Training coarse quantizer ({} clusters)...",
            self.config.nclusters
        );
        self.centroids = kmeans(
            &vec_list,
            self.config.nclusters,
            self.config.max_kmeans_iterations,
        )
        .context("Failed to train coarse quantizer")?;

        // Step 2: Train product quantizer (PQ)
        println!(
            "Training product quantizer ({} sub-vectors)...",
            self.config.nsubvectors
        );
        let mut pq = ProductQuantizer::new(dim, self.config.nsubvectors, self.config.nbits)?;
        pq.train(&vec_list, 50)?; // Fewer iterations for PQ
        self.pq = Some(pq);

        // Step 3: Assign vectors to clusters and quantize
        println!("Assigning vectors to clusters and quantizing...");
        self.inverted_lists = vec![Vec::new(); self.config.nclusters];

        for (entity_id, vector) in vectors {
            // Find nearest cluster
            let cluster_id = self.assign_to_cluster(vector);

            // Quantize vector with PQ
            let codes = self.pq.as_ref().unwrap().encode(vector);

            // Add to inverted list
            self.inverted_lists[cluster_id].push((entity_id.clone(), codes));
        }

        self.size = vectors.len();

        println!(
            "Index built: {} vectors in {} clusters",
            self.size, self.config.nclusters
        );

        Ok(())
    }

    /// Assign a vector to the nearest cluster
    fn assign_to_cluster(&self, vector: &[f32]) -> usize {
        let mut best_idx = 0;
        let mut best_dist = f32::MAX;

        for (idx, centroid) in self.centroids.iter().enumerate() {
            let dist = compute_distance(&self.config.metric, vector, centroid);
            if dist < best_dist {
                best_dist = dist;
                best_idx = idx;
            }
        }

        best_idx
    }

    /// Search for k nearest neighbors
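    ///
    /// Probes the `nprobe` nearest clusters and ranks their entries by PQ
    /// asymmetric distance; returns at most `k` results ordered best-first, with
    /// `rank` starting at 1.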
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        if self.pq.is_none() {
            anyhow::bail!("Index not built yet");
        }

        // Step 1: Find nprobe nearest clusters
        let mut cluster_distances: Vec<(usize, f32)> = self
            .centroids
            .iter()
            .enumerate()
            .map(|(idx, centroid)| (idx, compute_distance(&self.config.metric, query, centroid)))
            .collect();

        cluster_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        let probe_clusters: Vec<usize> = cluster_distances
            .iter()
            .take(self.config.nprobe.min(self.centroids.len()))
            .map(|(idx, _)| *idx)
            .collect();

        // Step 2: Search within probed clusters using asymmetric distance
        let pq = self.pq.as_ref().unwrap();
        let mut candidates = Vec::new();

        for cluster_id in probe_clusters {
            for (entity_id, codes) in &self.inverted_lists[cluster_id] {
                let dist = pq.asymmetric_distance(query, codes);
                candidates.push(SearchResult {
                    entity_id: entity_id.clone(),
                    score: dist,
                    distance: dist,
                    rank: 0, // Will be set below
                });
            }
        }

        // Step 3: Sort and return top-k
        candidates.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());

        let results: Vec<SearchResult> = candidates
            .into_iter()
            .take(k)
            .enumerate()
            .map(|(rank, mut r)| {
                r.distance = r.score;
                r.rank = rank + 1;
                r
            })
            .collect();

        Ok(results)
    }

    /// Get index statistics
    pub fn stats(&self) -> IvfPqStats {
        let avg_list_size = if self.centroids.is_empty() {
            0.0
        } else {
            self.size as f32 / self.centroids.len() as f32
        };

        let memory_bytes = self.estimate_memory();

        IvfPqStats {
            nclusters: self.centroids.len(),
            nvectors: self.size,
            dimension: self.dim.unwrap_or(0),
            avg_list_size,
            memory_bytes,
            compression_ratio: self.compression_ratio(),
        }
    }

    fn estimate_memory(&self) -> usize {
        // Centroids: nclusters * dim * 4 bytes
        let centroids_mem = self.centroids.len() * self.dim.unwrap_or(0) * 4;

        // Inverted lists: nvectors * nsubvectors * 1 byte (u8 codes)
        let inverted_mem = self.size * self.config.nsubvectors;

        // PQ codebooks: nsubvectors * ncentroids * subvector_dim * 4 bytes
        let pq_mem = if let Some(pq) = &self.pq {
            pq.nsubvectors * pq.ncentroids * pq.subvector_dim * 4
        } else {
            0
        };

        centroids_mem + inverted_mem + pq_mem
    }

    fn compression_ratio(&self) -> f32 {
        if self.size == 0 || self.dim.is_none() {
            return 0.0;
        }

        let original_size = self.size * self.dim.unwrap() * 4; // f32 = 4 bytes
        let compressed_size = self.estimate_memory();

        original_size as f32 / compressed_size as f32
    }
}

/// IVF-PQ index statistics
#[derive(Debug, Clone)]
pub struct IvfPqStats {
    pub nclusters: usize,
    pub nvectors: usize,
    pub dimension: usize,
    pub avg_list_size: f32,
    pub memory_bytes: usize,
    pub compression_ratio: f32,
}

/// K-means clustering algorithm
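///
/// Seeds centroids with k-means++ (first pick uniform at random, later picks
/// weighted by squared distance to the nearest existing centroid), then runs
/// Lloyd iterations with rayon-parallel assignment until the centroids stop
/// moving or `max_iterations` is reached.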
fn kmeans(vectors: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
    if vectors.is_empty() {
        anyhow::bail!("Cannot run k-means on empty vector set");
    }

    let dim = vectors[0].len();
    let n = vectors.len();

    if k > n {
        anyhow::bail!("Number of clusters {} exceeds number of vectors {}", k, n);
    }

    let mut rng = rand::rng();

    // Initialize centroids with k-means++ seeding
    let mut centroids = Vec::with_capacity(k);
    let first_idx = rng.random_range(0..n);
    centroids.push(vectors[first_idx].clone());

    for _ in 1..k {
        // Compute distance to the nearest existing centroid for each vector
        let distances: Vec<f32> = vectors
            .iter()
            .map(|v| {
                centroids
                    .iter()
                    .map(|c| euclidean_distance(v, c))
                    .fold(f32::MAX, f32::min)
            })
            .collect();

        // Select the next centroid with probability proportional to distance^2.
        // If every vector already coincides with a centroid (total == 0), fall
        // back to a uniform pick so we never sample from an empty range.
        let total: f32 = distances.iter().map(|d| d * d).sum();
        if total <= 0.0 {
            centroids.push(vectors[rng.random_range(0..n)].clone());
            continue;
        }

        let mut threshold = rng.random_range(0.0..total);
        let mut chosen = n - 1; // fallback in case of floating-point rounding
        for (idx, &dist) in distances.iter().enumerate() {
            threshold -= dist * dist;
            if threshold <= 0.0 {
                chosen = idx;
                break;
            }
        }
        centroids.push(vectors[chosen].clone());
    }

    // Run k-means iterations
    for _iter in 0..max_iterations {
        // Assign vectors to nearest centroid
        let assignments: Vec<usize> = vectors
            .par_iter()
            .map(|v| {
                centroids
                    .iter()
                    .enumerate()
                    .map(|(idx, c)| (idx, euclidean_distance(v, c)))
                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                    .unwrap()
                    .0
            })
            .collect();

        // Update centroids
        let mut new_centroids = vec![vec![0.0; dim]; k];
        let mut counts = vec![0; k];

        for (vec, &cluster_id) in vectors.iter().zip(&assignments) {
            for (i, &val) in vec.iter().enumerate() {
                new_centroids[cluster_id][i] += val;
            }
            counts[cluster_id] += 1;
        }

        // Average to get new centroids
        for (centroid, count) in new_centroids.iter_mut().zip(&counts) {
            if *count > 0 {
                for val in centroid.iter_mut() {
                    *val /= *count as f32;
                }
            }
        }

        // Check for convergence (total centroid movement)
        let mut total_movement = 0.0;
        for (old, new) in centroids.iter().zip(&new_centroids) {
            total_movement += euclidean_distance(old, new);
        }

        centroids = new_centroids;

        if total_movement < 0.001 {
            break;
        }
    }

    Ok(centroids)
}

/// Euclidean distance between two vectors
///
/// Uses SIMD-optimized calculation for better performance.
#[inline]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    simd::euclidean_distance_simd(a, b)
}

/// Compute distance based on metric
///
/// Uses SIMD-optimized distance calculations for better performance.
#[inline]
fn compute_distance(metric: &DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
    // Use SIMD-optimized implementations for hot path performance
    simd::compute_distance_lower_is_better_simd(*metric, a, b)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ivf_pq_creation() {
        let config = IvfPqConfig::default()
            .with_nclusters(16)
            .with_nsubvectors(8);

        let index = IvfPqIndex::new(config);
        assert_eq!(index.config.nclusters, 16);
        assert_eq!(index.config.nsubvectors, 8);
    }

    #[test]
    fn test_product_quantizer() {
        let dim = 64;
        let nsubvectors = 8;
        let nbits = 8;

        let pq = ProductQuantizer::new(dim, nsubvectors, nbits);
        assert!(pq.is_ok());

        let pq = pq.unwrap();
        assert_eq!(pq.subvector_dim, 8);
        assert_eq!(pq.ncentroids, 256);
    }

    #[test]
    fn test_kmeans_basic() {
        let vectors = vec![
            vec![1.0, 0.0],
            vec![1.1, 0.1],
            vec![0.0, 1.0],
            vec![0.1, 1.1],
        ];

        let centroids = kmeans(&vectors, 2, 10);
        assert!(centroids.is_ok());

        let centroids = centroids.unwrap();
        assert_eq!(centroids.len(), 2);
    }

    #[test]
    fn test_ivf_pq_build_and_search() {
        // Create a dataset with 300 vectors (k-means needs at least as many
        // training vectors as codebook centroids)
        let mut vectors = HashMap::new();
        for i in 0..300 {
            let vec: Vec<f32> = (0..64).map(|j| (i + j) as f32 / 300.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        // Build index with small cluster count for testing
        let config = IvfPqConfig::default()
            .with_nclusters(8)
            .with_nsubvectors(8)
            .with_nbits(4) // 4-bit = 16 centroids per sub-quantizer
            .with_nprobe(2);

        let mut index = IvfPqIndex::new(config);
        let build_result = index.build(&vectors);
        if let Err(e) = &build_result {
            panic!("Build failed: {}", e);
        }

        // Search for a vector similar to doc150
        let query = vectors.get("doc150").unwrap().clone();
        let results = index.search(&query, 5);
        assert!(results.is_ok());

        let results = results.unwrap();
        assert_eq!(results.len(), 5);

        // The nearest neighbor should be doc150 itself or very close to it
        assert!(results[0].entity_id.starts_with("doc"));
    }

    #[test]
    fn test_ivf_pq_nprobe_effect() {
        // Create a dataset with 300 vectors
        let mut vectors = HashMap::new();
        for i in 0..300 {
            let vec: Vec<f32> = (0..64).map(|j| (i + j) as f32 / 300.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        // Build with nprobe=1
        let config1 = IvfPqConfig::default()
            .with_nclusters(4)
            .with_nsubvectors(8)
            .with_nbits(4) // 4-bit quantization
            .with_nprobe(1);

        let mut index1 = IvfPqIndex::new(config1);
        assert!(index1.build(&vectors).is_ok());

        // Build with nprobe=4 (search all clusters)
        let config2 = IvfPqConfig::default()
            .with_nclusters(4)
            .with_nsubvectors(8)
            .with_nbits(4) // 4-bit quantization
            .with_nprobe(4);

        let mut index2 = IvfPqIndex::new(config2);
        assert!(index2.build(&vectors).is_ok());

        // Search with same query
        let query = vectors.get("doc150").unwrap().clone();
        let results1 = index1.search(&query, 5).unwrap();
        let results2 = index2.search(&query, 5).unwrap();

        // Both should return results
        assert_eq!(results1.len(), 5);
        assert_eq!(results2.len(), 5);

        // Higher nprobe should generally give better results (though not guaranteed in small dataset)
        assert!(results1[0].score >= 0.0);
        assert!(results2[0].score >= 0.0);
    }

    #[test]
    fn test_ivf_pq_stats() {
        let mut vectors = HashMap::new();
        for i in 0..300 {
            let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 300.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let config = IvfPqConfig::default()
            .with_nclusters(10)
            .with_nsubvectors(16)
            .with_nbits(4); // 4-bit quantization

        let mut index = IvfPqIndex::new(config);
        assert!(index.build(&vectors).is_ok());

        let stats = index.stats();
        assert_eq!(stats.nclusters, 10);
        assert_eq!(stats.nvectors, 300);
        assert_eq!(stats.dimension, 128);
        assert!(stats.avg_list_size > 0.0);
        assert!(stats.memory_bytes > 0);
        assert!(stats.compression_ratio > 1.0); // Should be compressed
    }

    #[test]
    fn test_ivf_pq_compression_ratio() {
        // Optimized test with reduced parameters for fast execution
        let mut vectors = HashMap::new();
        for i in 0..200 {
            let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 200.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let config = IvfPqConfig {
            nclusters: 8,
            nsubvectors: 8,            // Reduced from 64 to 8 (8x fewer k-means runs)
            nbits: 4,                  // 4-bit = 16 centroids per sub-quantizer (vs 256 at the 8-bit default)
            max_kmeans_iterations: 20, // Reduced from 100
            ..IvfPqConfig::default()
        };

        let mut index = IvfPqIndex::new(config);
        assert!(index.build(&vectors).is_ok());

        let stats = index.stats();

        // Original size: 200 vectors * 128 dims * 4 bytes = 102,400 bytes
        // Compressed should be significantly smaller
        let original_size = 200 * 128 * 4;
        assert!(stats.memory_bytes < original_size);

        // Compression ratio should be > 1
        assert!(stats.compression_ratio > 1.0);

        println!(
            "Compression: {:.2}x (original: {} bytes, compressed: {} bytes)",
            stats.compression_ratio, original_size, stats.memory_bytes
        );
    }

    #[test]
    #[ignore]
    fn test_ivf_pq_compression_ratio_full() {
        // Slow comprehensive test with production-scale parameters (75s+)
        // Run with: cargo test test_ivf_pq_compression_ratio_full -- --ignored
        let mut vectors = HashMap::new();
        for i in 0..500 {
            let vec: Vec<f32> = (0..768).map(|j| (i + j) as f32 / 500.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let config = IvfPqConfig::default()
            .with_nclusters(16)
            .with_nsubvectors(64)
            .with_nbits(6); // 6-bit = 64 centroids per sub-quantizer

        let mut index = IvfPqIndex::new(config);
        assert!(index.build(&vectors).is_ok());

        let stats = index.stats();

        // Original size: 500 vectors * 768 dims * 4 bytes = 1,536,000 bytes
        // Compressed should be significantly smaller
        let original_size = 500 * 768 * 4;
        assert!(stats.memory_bytes < original_size);

        // Compression ratio should be > 1
        assert!(stats.compression_ratio > 1.0);

        println!(
            "Compression: {:.2}x (original: {} bytes, compressed: {} bytes)",
            stats.compression_ratio, original_size, stats.memory_bytes
        );
    }

    #[test]
    fn test_ivf_pq_empty_vectors_error() {
        let vectors = HashMap::new();
        let config = IvfPqConfig::default();
        let mut index = IvfPqIndex::new(config);

        let result = index.build(&vectors);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Cannot build index with empty vector collection"));
    }

    #[test]
    fn test_ivf_pq_search_before_build_error() {
        let config = IvfPqConfig::default();
        let index = IvfPqIndex::new(config);

        let query = vec![0.1; 64];
        let result = index.search(&query, 10);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Index not built"));
    }

    #[test]
    fn test_ivf_pq_invalid_dimension_error() {
        let _config = IvfPqConfig::default().with_nsubvectors(8);

        // 65 is not divisible by 8
        let pq = ProductQuantizer::new(65, 8, 8);
        assert!(pq.is_err());
        assert!(pq.unwrap_err().to_string().contains("must be divisible by"));
    }

    #[test]
    fn test_ivf_pq_different_metrics() {
        let mut vectors = HashMap::new();
        for i in 0..300 {
            let vec: Vec<f32> = (0..64).map(|j| (i + j) as f32 / 300.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let query = vectors.get("doc150").unwrap().clone();

        // Test with different distance metrics
        let metrics = vec![
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ];

        for metric in metrics {
            let config = IvfPqConfig::default()
                .with_nclusters(4)
                .with_nsubvectors(8)
                .with_nbits(4) // 4-bit quantization
                .with_metric(metric);

            let mut index = IvfPqIndex::new(config);
            assert!(index.build(&vectors).is_ok());

            let results = index.search(&query, 3);
            assert!(results.is_ok());

            let results = results.unwrap();
            assert_eq!(results.len(), 3);
        }
    }

    #[test]
    fn test_product_quantizer_encode_decode() {
        let dim = 64;
        let nsubvectors = 8;
        let nbits = 4; // 4-bit = 16 centroids per sub-quantizer

        let mut pq = ProductQuantizer::new(dim, nsubvectors, nbits).unwrap();

        // Create training vectors (need at least 16 vectors for 4-bit quantization)
        let mut train_vectors = Vec::new();
        for i in 0..100 {
            let vec: Vec<f32> = (0..dim).map(|j| (i + j) as f32 / 100.0).collect();
            train_vectors.push(vec);
        }

        // Train the quantizer
        let train_result = pq.train(&train_vectors, 20);
        if let Err(e) = &train_result {
            panic!("PQ training failed: {}", e);
        }

        // Encode a vector
        let test_vector: Vec<f32> = (0..dim).map(|i| i as f32 / 64.0).collect();
        let codes = pq.encode(&test_vector);

        // Should have one code per sub-vector
        assert_eq!(codes.len(), nsubvectors);

        // All codes should be valid (< 16 for 4-bit)
        for &code in &codes {
            assert!((code as usize) < pq.ncentroids);
        }

        // Compute asymmetric distance
        let distance = pq.asymmetric_distance(&test_vector, &codes);
        assert!(distance >= 0.0);
    }

    #[test]
    fn test_kmeans_convergence() {
        // Create two well-separated clusters
        let mut vectors = Vec::new();

        // Cluster 1: around (1, 1)
        for i in 0..20 {
            vectors.push(vec![1.0 + (i as f32) * 0.01, 1.0 + (i as f32) * 0.01]);
        }

        // Cluster 2: around (10, 10)
        for i in 0..20 {
            vectors.push(vec![10.0 + (i as f32) * 0.01, 10.0 + (i as f32) * 0.01]);
        }

        let centroids = kmeans(&vectors, 2, 50).unwrap();
        assert_eq!(centroids.len(), 2);

        // Centroids should be roughly at (1, 1) and (10, 10)
        let mut has_low_centroid = false;
        let mut has_high_centroid = false;

        for centroid in &centroids {
            if centroid[0] < 5.0 {
                has_low_centroid = true;
                assert!(centroid[0] > 0.5 && centroid[0] < 1.5);
            } else {
                has_high_centroid = true;
                assert!(centroid[0] > 9.5 && centroid[0] < 10.5);
            }
        }

        assert!(has_low_centroid);
        assert!(has_high_centroid);
    }

    #[test]
    fn test_kmeans_error_cases() {
        // Test empty vectors
        let empty_vectors: Vec<Vec<f32>> = vec![];
        let result = kmeans(&empty_vectors, 2, 10);
        assert!(result.is_err());

        // Test k > n
        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = kmeans(&vectors, 5, 10);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds"));
    }
}