use anyhow::{Context, Result};
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::simd;
use crate::types::{DistanceMetric, SearchResult};

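/// Configuration for an IVF-PQ index: an inverted file built from coarse k-means
/// clusters, combined with product quantization for compact per-vector codes.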
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfPqConfig {
    /// Number of coarse clusters (inverted lists).
    pub nclusters: usize,

    /// Number of sub-vectors each vector is split into for product quantization.
    pub nsubvectors: usize,

    /// Bits per sub-vector code; each codebook holds 2^nbits centroids.
    pub nbits: usize,

    /// Number of nearest clusters probed at search time.
    pub nprobe: usize,

    /// Distance metric used for coarse assignment and cluster selection.
    pub metric: DistanceMetric,

    /// Maximum number of k-means iterations when training the coarse quantizer.
    pub max_kmeans_iterations: usize,

    /// Convergence tolerance for k-means training.
    pub kmeans_tolerance: f32,
}

impl Default for IvfPqConfig {
    fn default() -> Self {
        Self {
            nclusters: 256,
            nsubvectors: 64,
            nbits: 8,
            nprobe: 16,
            metric: DistanceMetric::Cosine,
            max_kmeans_iterations: 100,
            kmeans_tolerance: 1e-4,
        }
    }
}

impl IvfPqConfig {
    pub fn with_nclusters(mut self, nclusters: usize) -> Self {
        self.nclusters = nclusters;
        self
    }

    pub fn with_nsubvectors(mut self, nsubvectors: usize) -> Self {
        self.nsubvectors = nsubvectors;
        self
    }

    pub fn with_nbits(mut self, nbits: usize) -> Self {
        self.nbits = nbits;
        self
    }

    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
        self.nprobe = nprobe;
        self
    }

    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }
}

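/// Product quantizer: splits vectors into sub-vectors and encodes each sub-vector
/// as the index of its nearest codebook centroid, yielding one byte per sub-vector.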
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ProductQuantizer {
    /// Number of sub-vectors per vector.
    nsubvectors: usize,
    /// Dimension of each sub-vector (dim / nsubvectors).
    subvector_dim: usize,
    /// One codebook per sub-vector; each codebook is a list of centroids.
    codebooks: Vec<Vec<Vec<f32>>>,
    /// Number of centroids per codebook (2^nbits).
    ncentroids: usize,
}

impl ProductQuantizer {
    fn new(dim: usize, nsubvectors: usize, nbits: usize) -> Result<Self> {
        if !dim.is_multiple_of(nsubvectors) {
            anyhow::bail!(
                "Vector dimension {} must be divisible by number of sub-vectors {}",
                dim,
                nsubvectors
            );
        }

        let subvector_dim = dim / nsubvectors;
        let ncentroids = 1 << nbits;

        Ok(Self {
            nsubvectors,
            subvector_dim,
            codebooks: vec![],
            ncentroids,
        })
    }

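    /// Trains one k-means codebook per sub-vector over the corresponding slices
    /// of the training vectors.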
    fn train(&mut self, vectors: &[Vec<f32>], iterations: usize) -> Result<()> {
        self.codebooks.clear();

        for subvec_idx in 0..self.nsubvectors {
            let start = subvec_idx * self.subvector_dim;
            let end = start + self.subvector_dim;

            let subvectors: Vec<Vec<f32>> =
                vectors.iter().map(|v| v[start..end].to_vec()).collect();

            let centroids = kmeans(&subvectors, self.ncentroids, iterations)?;
            self.codebooks.push(centroids);
        }

        Ok(())
    }

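    /// Encodes a vector as one codebook index per sub-vector (nearest centroid).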
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        let mut codes = Vec::with_capacity(self.nsubvectors);

        for subvec_idx in 0..self.nsubvectors {
            let start = subvec_idx * self.subvector_dim;
            let end = start + self.subvector_dim;
            let subvector = &vector[start..end];

            let mut best_idx = 0;
            let mut best_dist = f32::MAX;

            for (centroid_idx, centroid) in self.codebooks[subvec_idx].iter().enumerate() {
                let dist = euclidean_distance(subvector, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_idx = centroid_idx;
                }
            }

            codes.push(best_idx as u8);
        }

        codes
    }

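    /// Asymmetric distance: compares an uncompressed query against a PQ code by
    /// summing, over sub-vectors, the distance from the query slice to the
    /// centroid selected by the code.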
    fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        let mut total_dist = 0.0;

        #[allow(clippy::needless_range_loop)]
        for subvec_idx in 0..self.nsubvectors {
            let start = subvec_idx * self.subvector_dim;
            let end = start + self.subvector_dim;
            let query_subvector = &query[start..end];

            let code = codes[subvec_idx] as usize;
            let centroid = &self.codebooks[subvec_idx][code];

            total_dist += euclidean_distance(query_subvector, centroid);
        }

        total_dist
    }
}

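/// IVF-PQ index: coarse centroids route each vector to an inverted list, where it
/// is stored as a PQ code alongside its entity id.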
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfPqIndex {
    config: IvfPqConfig,
    /// Coarse cluster centroids.
    centroids: Vec<Vec<f32>>,
    /// One inverted list per cluster, holding (entity id, PQ codes) pairs.
    inverted_lists: Vec<Vec<(String, Vec<u8>)>>,
    /// Product quantizer, populated by `build`.
    pq: Option<ProductQuantizer>,
    /// Vector dimension, recorded at build time.
    dim: Option<usize>,
    /// Number of indexed vectors.
    size: usize,
}

impl IvfPqIndex {
    pub fn new(config: IvfPqConfig) -> Self {
        Self {
            config,
            centroids: Vec::new(),
            inverted_lists: Vec::new(),
            pq: None,
            dim: None,
            size: 0,
        }
    }

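    /// Builds the index: trains the coarse quantizer and the product quantizer,
    /// then assigns every vector to its nearest cluster and stores its PQ code.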
    pub fn build(&mut self, vectors: &HashMap<String, Vec<f32>>) -> Result<()> {
        if vectors.is_empty() {
            anyhow::bail!("Cannot build index with empty vector collection");
        }

        let dim = vectors.values().next().unwrap().len();
        self.dim = Some(dim);

        let vec_list: Vec<Vec<f32>> = vectors.values().cloned().collect();

        println!(
            "Training coarse quantizer ({} clusters)...",
            self.config.nclusters
        );
        self.centroids = kmeans(
            &vec_list,
            self.config.nclusters,
            self.config.max_kmeans_iterations,
        )
        .context("Failed to train coarse quantizer")?;

        println!(
            "Training product quantizer ({} sub-vectors)...",
            self.config.nsubvectors
        );
        let mut pq = ProductQuantizer::new(dim, self.config.nsubvectors, self.config.nbits)?;
        pq.train(&vec_list, 50)?;
        self.pq = Some(pq);

        println!("Assigning vectors to clusters and quantizing...");
        self.inverted_lists = vec![Vec::new(); self.config.nclusters];

        for (entity_id, vector) in vectors {
            let cluster_id = self.assign_to_cluster(vector);
            let codes = self.pq.as_ref().unwrap().encode(vector);
            self.inverted_lists[cluster_id].push((entity_id.clone(), codes));
        }

        self.size = vectors.len();

        println!(
            "Index built: {} vectors in {} clusters",
            self.size, self.config.nclusters
        );

        Ok(())
    }

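    /// Returns the index of the coarse centroid nearest to `vector` under the
    /// configured metric.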
    fn assign_to_cluster(&self, vector: &[f32]) -> usize {
        let mut best_idx = 0;
        let mut best_dist = f32::MAX;

        for (idx, centroid) in self.centroids.iter().enumerate() {
            let dist = compute_distance(&self.config.metric, vector, centroid);
            if dist < best_dist {
                best_dist = dist;
                best_idx = idx;
            }
        }

        best_idx
    }

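    /// Searches for the `k` nearest neighbours of `query`: ranks the coarse
    /// centroids, probes the `nprobe` closest inverted lists, scores their entries
    /// with asymmetric PQ distances, and returns the best `k` ranked results.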
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        if self.pq.is_none() {
            anyhow::bail!("Index not built yet");
        }

        let mut cluster_distances: Vec<(usize, f32)> = self
            .centroids
            .iter()
            .enumerate()
            .map(|(idx, centroid)| (idx, compute_distance(&self.config.metric, query, centroid)))
            .collect();

        cluster_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        let probe_clusters: Vec<usize> = cluster_distances
            .iter()
            .take(self.config.nprobe.min(self.centroids.len()))
            .map(|(idx, _)| *idx)
            .collect();

        let pq = self.pq.as_ref().unwrap();
        let mut candidates = Vec::new();

        for cluster_id in probe_clusters {
            for (entity_id, codes) in &self.inverted_lists[cluster_id] {
                let dist = pq.asymmetric_distance(query, codes);
                candidates.push(SearchResult {
                    entity_id: entity_id.clone(),
                    score: dist,
                    distance: dist,
                    rank: 0,
                });
            }
        }

        candidates.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());

        let results: Vec<SearchResult> = candidates
            .into_iter()
            .take(k)
            .enumerate()
            .map(|(rank, mut r)| {
                r.distance = r.score;
                r.rank = rank + 1;
                r
            })
            .collect();

        Ok(results)
    }

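    /// Reports index statistics: cluster and vector counts, dimension, average
    /// inverted-list length, estimated memory use, and compression ratio.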
    pub fn stats(&self) -> IvfPqStats {
        let avg_list_size = if self.centroids.is_empty() {
            0.0
        } else {
            self.size as f32 / self.centroids.len() as f32
        };

        let memory_bytes = self.estimate_memory();

        IvfPqStats {
            nclusters: self.centroids.len(),
            nvectors: self.size,
            dimension: self.dim.unwrap_or(0),
            avg_list_size,
            memory_bytes,
            compression_ratio: self.compression_ratio(),
        }
    }

    fn estimate_memory(&self) -> usize {
        // Coarse centroids: f32 values, 4 bytes each.
        let centroids_mem = self.centroids.len() * self.dim.unwrap_or(0) * 4;

        // Inverted lists: one byte of PQ code per sub-vector per vector.
        let inverted_mem = self.size * self.config.nsubvectors;

        // PQ codebooks: f32 centroids for every sub-vector.
        let pq_mem = if let Some(pq) = &self.pq {
            pq.nsubvectors * pq.ncentroids * pq.subvector_dim * 4
        } else {
            0
        };

        centroids_mem + inverted_mem + pq_mem
    }

    fn compression_ratio(&self) -> f32 {
        if self.size == 0 || self.dim.is_none() {
            return 0.0;
        }

        let original_size = self.size * self.dim.unwrap() * 4;
        let compressed_size = self.estimate_memory();

        original_size as f32 / compressed_size as f32
    }
}

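/// Summary statistics reported by [`IvfPqIndex::stats`].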
#[derive(Debug, Clone)]
pub struct IvfPqStats {
    pub nclusters: usize,
    pub nvectors: usize,
    pub dimension: usize,
    pub avg_list_size: f32,
    pub memory_bytes: usize,
    pub compression_ratio: f32,
}

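/// Lloyd's k-means with k-means++ seeding. Returns `k` centroids; fails on an
/// empty input or when `k` exceeds the number of vectors.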
fn kmeans(vectors: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
    if vectors.is_empty() {
        anyhow::bail!("Cannot run k-means on empty vector set");
    }

    let dim = vectors[0].len();
    let n = vectors.len();

    if k > n {
        anyhow::bail!("Number of clusters {} exceeds number of vectors {}", k, n);
    }

    let mut rng = rand::rng();

    // k-means++ seeding: pick the first centroid uniformly at random.
    let mut centroids = Vec::with_capacity(k);
    let first_idx = rng.random_range(0..n);
    centroids.push(vectors[first_idx].clone());

    // Pick each remaining centroid with probability proportional to its squared
    // distance to the nearest centroid chosen so far.
    for _ in 1..k {
        let distances: Vec<f32> = vectors
            .iter()
            .map(|v| {
                centroids
                    .iter()
                    .map(|c| euclidean_distance(v, c))
                    .fold(f32::MAX, f32::min)
            })
            .collect();

        let total: f32 = distances.iter().map(|d| d * d).sum();
        let mut threshold = rng.random_range(0.0..total);

        for (idx, &dist) in distances.iter().enumerate() {
            threshold -= dist * dist;
            if threshold <= 0.0 {
                centroids.push(vectors[idx].clone());
                break;
            }
        }
    }

    // Lloyd iterations: assign every vector to its nearest centroid in parallel,
    // then recompute each centroid as the mean of its assigned vectors.
    for _iter in 0..max_iterations {
        let assignments: Vec<usize> = vectors
            .par_iter()
            .map(|v| {
                centroids
                    .iter()
                    .enumerate()
                    .map(|(idx, c)| (idx, euclidean_distance(v, c)))
                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
                    .unwrap()
                    .0
            })
            .collect();

        let mut new_centroids = vec![vec![0.0; dim]; k];
        let mut counts = vec![0; k];

        for (vec, &cluster_id) in vectors.iter().zip(&assignments) {
            for (i, &val) in vec.iter().enumerate() {
                new_centroids[cluster_id][i] += val;
            }
            counts[cluster_id] += 1;
        }

        for (centroid, count) in new_centroids.iter_mut().zip(&counts) {
            if *count > 0 {
                for val in centroid.iter_mut() {
                    *val /= *count as f32;
                }
            }
        }

        // Stop once total centroid movement drops below a small threshold.
        let mut total_movement = 0.0;
        for (old, new) in centroids.iter().zip(&new_centroids) {
            total_movement += euclidean_distance(old, new);
        }

        centroids = new_centroids;

        if total_movement < 0.001 {
            break;
        }
    }

    Ok(centroids)
}

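/// Euclidean distance between two equal-length slices, delegated to the SIMD helpers.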
#[inline]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    simd::euclidean_distance_simd(a, b)
}

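/// Metric-aware distance where a smaller value always means closer, delegated to
/// the SIMD helpers.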
#[inline]
fn compute_distance(metric: &DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
    simd::compute_distance_lower_is_better_simd(*metric, a, b)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ivf_pq_creation() {
        let config = IvfPqConfig::default()
            .with_nclusters(16)
            .with_nsubvectors(8);

        let index = IvfPqIndex::new(config);
        assert_eq!(index.config.nclusters, 16);
        assert_eq!(index.config.nsubvectors, 8);
    }

    #[test]
    fn test_product_quantizer() {
        let dim = 64;
        let nsubvectors = 8;
        let nbits = 8;

        let pq = ProductQuantizer::new(dim, nsubvectors, nbits);
        assert!(pq.is_ok());

        let pq = pq.unwrap();
        assert_eq!(pq.subvector_dim, 8);
        assert_eq!(pq.ncentroids, 256);
    }

    #[test]
    fn test_kmeans_basic() {
        let vectors = vec![
            vec![1.0, 0.0],
            vec![1.1, 0.1],
            vec![0.0, 1.0],
            vec![0.1, 1.1],
        ];

        let centroids = kmeans(&vectors, 2, 10);
        assert!(centroids.is_ok());

        let centroids = centroids.unwrap();
        assert_eq!(centroids.len(), 2);
    }

    #[test]
    fn test_ivf_pq_build_and_search() {
        let mut vectors = HashMap::new();
        for i in 0..300 {
            let vec: Vec<f32> = (0..64).map(|j| (i + j) as f32 / 300.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let config = IvfPqConfig::default()
            .with_nclusters(8)
            .with_nsubvectors(8)
            .with_nbits(4)
            .with_nprobe(2);

        let mut index = IvfPqIndex::new(config);
        let build_result = index.build(&vectors);
        if let Err(e) = &build_result {
            panic!("Build failed: {}", e);
        }

        let query = vectors.get("doc150").unwrap().clone();
        let results = index.search(&query, 5);
        assert!(results.is_ok());

        let results = results.unwrap();
        assert_eq!(results.len(), 5);

        assert!(results[0].entity_id.starts_with("doc"));
    }

    #[test]
    fn test_ivf_pq_nprobe_effect() {
        let mut vectors = HashMap::new();
        for i in 0..300 {
            let vec: Vec<f32> = (0..64).map(|j| (i + j) as f32 / 300.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let config1 = IvfPqConfig::default()
            .with_nclusters(4)
            .with_nsubvectors(8)
            .with_nbits(4)
            .with_nprobe(1);

        let mut index1 = IvfPqIndex::new(config1);
        assert!(index1.build(&vectors).is_ok());

        let config2 = IvfPqConfig::default()
            .with_nclusters(4)
            .with_nsubvectors(8)
            .with_nbits(4)
            .with_nprobe(4);

        let mut index2 = IvfPqIndex::new(config2);
        assert!(index2.build(&vectors).is_ok());

        let query = vectors.get("doc150").unwrap().clone();
        let results1 = index1.search(&query, 5).unwrap();
        let results2 = index2.search(&query, 5).unwrap();

        assert_eq!(results1.len(), 5);
        assert_eq!(results2.len(), 5);

        assert!(results1[0].score >= 0.0);
        assert!(results2[0].score >= 0.0);
    }

    #[test]
    fn test_ivf_pq_stats() {
        let mut vectors = HashMap::new();
        for i in 0..300 {
            let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 300.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let config = IvfPqConfig::default()
            .with_nclusters(10)
            .with_nsubvectors(16)
            .with_nbits(4);

        let mut index = IvfPqIndex::new(config);
        assert!(index.build(&vectors).is_ok());

        let stats = index.stats();
        assert_eq!(stats.nclusters, 10);
        assert_eq!(stats.nvectors, 300);
        assert_eq!(stats.dimension, 128);
        assert!(stats.avg_list_size > 0.0);
        assert!(stats.memory_bytes > 0);
        assert!(stats.compression_ratio > 1.0);
    }

    #[test]
    fn test_ivf_pq_compression_ratio() {
        let mut vectors = HashMap::new();
        for i in 0..200 {
            let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 200.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let config = IvfPqConfig {
            nclusters: 8,
            nsubvectors: 8,
            nbits: 4,
            max_kmeans_iterations: 20,
            ..IvfPqConfig::default()
        };

        let mut index = IvfPqIndex::new(config);
        assert!(index.build(&vectors).is_ok());

        let stats = index.stats();

        let original_size = 200 * 128 * 4;
        assert!(stats.memory_bytes < original_size);

        assert!(stats.compression_ratio > 1.0);

        println!(
            "Compression: {:.2}x (original: {} bytes, compressed: {} bytes)",
            stats.compression_ratio, original_size, stats.memory_bytes
        );
    }

    #[test]
    #[ignore]
    fn test_ivf_pq_compression_ratio_full() {
        let mut vectors = HashMap::new();
        for i in 0..500 {
            let vec: Vec<f32> = (0..768).map(|j| (i + j) as f32 / 500.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let config = IvfPqConfig::default()
            .with_nclusters(16)
            .with_nsubvectors(64)
            .with_nbits(6);

        let mut index = IvfPqIndex::new(config);
        assert!(index.build(&vectors).is_ok());

        let stats = index.stats();

        let original_size = 500 * 768 * 4;
        assert!(stats.memory_bytes < original_size);

        assert!(stats.compression_ratio > 1.0);

        println!(
            "Compression: {:.2}x (original: {} bytes, compressed: {} bytes)",
            stats.compression_ratio, original_size, stats.memory_bytes
        );
    }

    #[test]
    fn test_ivf_pq_empty_vectors_error() {
        let vectors = HashMap::new();
        let config = IvfPqConfig::default();
        let mut index = IvfPqIndex::new(config);

        let result = index.build(&vectors);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Cannot build index with empty vector collection"));
    }

    #[test]
    fn test_ivf_pq_search_before_build_error() {
        let config = IvfPqConfig::default();
        let index = IvfPqIndex::new(config);

        let query = vec![0.1; 64];
        let result = index.search(&query, 10);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Index not built"));
    }

    #[test]
    fn test_ivf_pq_invalid_dimension_error() {
        let _config = IvfPqConfig::default().with_nsubvectors(8);

        // 65 is not divisible by 8, so constructing the product quantizer must fail.
        let pq = ProductQuantizer::new(65, 8, 8);
        assert!(pq.is_err());
        assert!(pq.unwrap_err().to_string().contains("must be divisible by"));
    }

    #[test]
    fn test_ivf_pq_different_metrics() {
        let mut vectors = HashMap::new();
        for i in 0..300 {
            let vec: Vec<f32> = (0..64).map(|j| (i + j) as f32 / 300.0).collect();
            vectors.insert(format!("doc{}", i), vec);
        }

        let query = vectors.get("doc150").unwrap().clone();

        let metrics = vec![
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ];

        for metric in metrics {
            let config = IvfPqConfig::default()
                .with_nclusters(4)
                .with_nsubvectors(8)
                .with_nbits(4)
                .with_metric(metric);

            let mut index = IvfPqIndex::new(config);
            assert!(index.build(&vectors).is_ok());

            let results = index.search(&query, 3);
            assert!(results.is_ok());

            let results = results.unwrap();
            assert_eq!(results.len(), 3);
        }
    }

    #[test]
    fn test_product_quantizer_encode_decode() {
        let dim = 64;
        let nsubvectors = 8;
        let nbits = 4;

        let mut pq = ProductQuantizer::new(dim, nsubvectors, nbits).unwrap();

        let mut train_vectors = Vec::new();
        for i in 0..100 {
            let vec: Vec<f32> = (0..dim).map(|j| (i + j) as f32 / 100.0).collect();
            train_vectors.push(vec);
        }

        let train_result = pq.train(&train_vectors, 20);
        if let Err(e) = &train_result {
            panic!("PQ training failed: {}", e);
        }

        let test_vector: Vec<f32> = (0..dim).map(|i| i as f32 / 64.0).collect();
        let codes = pq.encode(&test_vector);

        assert_eq!(codes.len(), nsubvectors);

        for &code in &codes {
            assert!((code as usize) < pq.ncentroids);
        }

        let distance = pq.asymmetric_distance(&test_vector, &codes);
        assert!(distance >= 0.0);
    }

    #[test]
    fn test_kmeans_convergence() {
        let mut vectors = Vec::new();

        // One tight group near (1, 1)...
        for i in 0..20 {
            vectors.push(vec![1.0 + (i as f32) * 0.01, 1.0 + (i as f32) * 0.01]);
        }

        // ...and another near (10, 10).
        for i in 0..20 {
            vectors.push(vec![10.0 + (i as f32) * 0.01, 10.0 + (i as f32) * 0.01]);
        }

        let centroids = kmeans(&vectors, 2, 50).unwrap();
        assert_eq!(centroids.len(), 2);

        let mut has_low_centroid = false;
        let mut has_high_centroid = false;

        for centroid in &centroids {
            if centroid[0] < 5.0 {
                has_low_centroid = true;
                assert!(centroid[0] > 0.5 && centroid[0] < 1.5);
            } else {
                has_high_centroid = true;
                assert!(centroid[0] > 9.5 && centroid[0] < 10.5);
            }
        }

        assert!(has_low_centroid);
        assert!(has_high_centroid);
    }

    #[test]
    fn test_kmeans_error_cases() {
        let empty_vectors: Vec<Vec<f32>> = vec![];
        let result = kmeans(&empty_vectors, 2, 10);
        assert!(result.is_err());

        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = kmeans(&vectors, 5, 10);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds"));
    }
}