1use common::{DistanceMetric, Vector};
13use parking_lot::RwLock;
14use rand::seq::SliceRandom;
15use std::collections::HashMap;
16
17use crate::pq::{PQConfig, ProductQuantizer};
18
/// Tunable parameters for an [`IvfPqIndex`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IvfPqConfig {
    /// Requested number of coarse k-means clusters (one inverted list per
    /// cluster). `train` clamps this to the number of training vectors.
    pub n_clusters: usize,
    /// Number of nearest clusters whose lists are scanned by `search`
    /// (clamped to the number of trained centroids).
    pub n_probe: usize,
    /// Forwarded to the product quantizer as `PQConfig::num_subquantizers`;
    /// `stats` also counts this many code bytes per stored vector.
    pub pq_subquantizers: usize,
    /// Forwarded to the product quantizer as `PQConfig::num_centroids`.
    pub pq_centroids: usize,
    /// Upper bound on k-means iterations when training the coarse centroids.
    pub ivf_iterations: usize,
    /// Forwarded to the product quantizer as `PQConfig::kmeans_iterations`.
    pub pq_iterations: usize,
    /// Forwarded to the product quantizer as `PQConfig::distance_metric`.
    /// Coarse cluster assignment in this file always uses Euclidean distance.
    pub metric: DistanceMetric,
}
37
38impl Default for IvfPqConfig {
39 fn default() -> Self {
40 Self {
41 n_clusters: 256,
42 n_probe: 8,
43 pq_subquantizers: 8,
44 pq_centroids: 256,
45 ivf_iterations: 20,
46 pq_iterations: 10,
47 metric: DistanceMetric::Euclidean,
48 }
49 }
50}
51
/// One hit returned by [`IvfPqIndex::search`].
#[derive(Debug, Clone)]
pub struct IvfPqSearchResult {
    /// Identifier of the stored vector, copied from the matching index entry.
    pub id: String,
    /// Value produced by the product quantizer's `compute_distance_adc` for
    /// this entry. `search` orders its results by this field, highest first.
    pub score: f32,
    /// Index of the cluster (inverted list) in which the entry was found.
    pub cluster_id: usize,
}
59
/// A single encoded vector stored in an inverted list.
#[derive(Debug, Clone)]
struct PqEntry {
    /// Identifier of the original vector.
    id: String,
    /// Output of `ProductQuantizer::encode` applied to the vector's residual
    /// (the vector minus its assigned coarse centroid).
    codes: Vec<u8>,
}
67
/// Inverted-file index whose entries are product-quantized residuals.
///
/// Vectors are assigned to their nearest coarse centroid; the difference
/// between vector and centroid is encoded by a product quantizer and stored
/// in that centroid's list.
pub struct IvfPqIndex {
    /// Parameters supplied at construction.
    config: IvfPqConfig,
    /// Vector dimensionality, taken from the first training vector; `None`
    /// until `train` records it.
    dimension: Option<usize>,
    /// Coarse k-means centroids, one per cluster.
    centroids: Vec<Vec<f32>>,
    /// Quantizer trained on residuals; `None` until `train` stores one.
    pq: Option<ProductQuantizer>,
    /// One independently locked list of encoded entries per centroid.
    inverted_lists: Vec<RwLock<Vec<PqEntry>>>,
    /// Set to `true` at the end of a successful `train` call.
    trained: bool,
}
81
82impl IvfPqIndex {
83 pub fn new(config: IvfPqConfig) -> Self {
85 Self {
86 config,
87 dimension: None,
88 centroids: Vec::new(),
89 pq: None,
90 inverted_lists: Vec::new(),
91 trained: false,
92 }
93 }
94
    /// Reports whether `train` has completed successfully on this index.
    pub fn is_trained(&self) -> bool {
        self.trained
    }
99
    /// Returns the vector dimensionality recorded by `train`, or `None` if
    /// none has been recorded yet.
    pub fn dimension(&self) -> Option<usize> {
        self.dimension
    }
104
105 pub fn stats(&self) -> IvfPqStats {
107 let mut list_sizes = Vec::with_capacity(self.inverted_lists.len());
108 let mut total_vectors = 0usize;
109
110 for list in &self.inverted_lists {
111 let size = list.read().len();
112 list_sizes.push(size);
113 total_vectors += size;
114 }
115
116 let avg_list_size = if list_sizes.is_empty() {
117 0.0
118 } else {
119 total_vectors as f64 / list_sizes.len() as f64
120 };
121
122 let max_list_size = list_sizes.iter().copied().max().unwrap_or(0);
123 let min_list_size = list_sizes.iter().copied().min().unwrap_or(0);
124
125 let centroid_memory = self.centroids.len() * self.dimension.unwrap_or(0) * 4;
127 let pq_memory = self
128 .pq
129 .as_ref()
130 .map(|pq| {
131 pq.config.num_subquantizers
132 * pq.config.num_centroids
133 * (self.dimension.unwrap_or(0) / pq.config.num_subquantizers)
134 * 4
135 })
136 .unwrap_or(0);
137 let codes_memory = total_vectors * self.config.pq_subquantizers;
138
139 IvfPqStats {
140 n_clusters: self.centroids.len(),
141 total_vectors,
142 avg_list_size,
143 max_list_size,
144 min_list_size,
145 trained: self.trained,
146 dimension: self.dimension,
147 memory_bytes: centroid_memory + pq_memory + codes_memory,
148 }
149 }
150
151 pub fn train(&mut self, vectors: &[Vector]) -> Result<(), String> {
158 if vectors.is_empty() {
159 return Err("Cannot train on empty vector set".to_string());
160 }
161
162 let dim = vectors[0].values.len();
163 self.dimension = Some(dim);
164
165 for v in vectors {
167 if v.values.len() != dim {
168 return Err(format!(
169 "Dimension mismatch: expected {}, got {}",
170 dim,
171 v.values.len()
172 ));
173 }
174 }
175
176 let n_clusters = self.config.n_clusters.min(vectors.len());
178 self.centroids = self.kmeans_train(vectors, n_clusters)?;
179
180 self.inverted_lists = (0..n_clusters).map(|_| RwLock::new(Vec::new())).collect();
182
183 let mut residuals = Vec::with_capacity(vectors.len());
185 for v in vectors {
186 let (cluster_id, _) = self.find_nearest_centroid(&v.values);
187 let residual = self.compute_residual(&v.values, cluster_id);
188 residuals.push(Vector {
189 id: v.id.clone(),
190 values: residual,
191 metadata: None,
192 ttl_seconds: None,
193 expires_at: None,
194 });
195 }
196
197 let pq_config = PQConfig {
199 num_subquantizers: self.config.pq_subquantizers,
200 num_centroids: self.config.pq_centroids,
201 kmeans_iterations: self.config.pq_iterations,
202 distance_metric: self.config.metric,
203 };
204
205 let mut pq = ProductQuantizer::new(pq_config, dim)?;
206 pq.train(&residuals)?;
207 self.pq = Some(pq);
208
209 self.trained = true;
210 Ok(())
211 }
212
213 pub fn add(&self, vectors: &[Vector]) -> Result<usize, String> {
215 if !self.trained {
216 return Err("Index must be trained before adding vectors".to_string());
217 }
218
219 let pq = self.pq.as_ref().ok_or("PQ not initialized")?;
220 let dim = self.dimension.ok_or("Dimension not set")?;
221
222 let mut added = 0;
223 for v in vectors {
224 if v.values.len() != dim {
225 continue;
226 }
227
228 let (cluster_id, _) = self.find_nearest_centroid(&v.values);
230
231 let residual = self.compute_residual(&v.values, cluster_id);
233
234 let codes = pq.encode(&residual)?;
236
237 let entry = PqEntry {
239 id: v.id.clone(),
240 codes,
241 };
242 self.inverted_lists[cluster_id].write().push(entry);
243 added += 1;
244 }
245
246 Ok(added)
247 }
248
249 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<IvfPqSearchResult>, String> {
251 if !self.trained {
252 return Err("Index must be trained before searching".to_string());
253 }
254
255 let pq = self.pq.as_ref().ok_or("PQ not initialized")?;
256 let dim = self.dimension.ok_or("Dimension not set")?;
257
258 if query.len() != dim {
259 return Err(format!(
260 "Query dimension {} doesn't match index dimension {}",
261 query.len(),
262 dim
263 ));
264 }
265
266 let n_probe = self.config.n_probe.min(self.centroids.len());
268 let nearest_clusters = self.find_nearest_centroids(query, n_probe);
269
270 let mut candidates: Vec<IvfPqSearchResult> = Vec::new();
272
273 for (cluster_id, _) in nearest_clusters {
274 let query_residual = self.compute_residual(query, cluster_id);
276
277 let distance_table = pq.compute_distance_table(&query_residual)?;
279
280 let list = self.inverted_lists[cluster_id].read();
282 for entry in list.iter() {
283 let score = pq.compute_distance_adc(&distance_table, &entry.codes);
286
287 candidates.push(IvfPqSearchResult {
288 id: entry.id.clone(),
289 score, cluster_id,
291 });
292 }
293 }
294
295 candidates.sort_by(|a, b| {
297 b.score
298 .partial_cmp(&a.score)
299 .unwrap_or(std::cmp::Ordering::Equal)
300 });
301 candidates.truncate(k);
302
303 Ok(candidates)
304 }
305
306 fn find_nearest_centroid(&self, vector: &[f32]) -> (usize, f32) {
308 let mut best_idx = 0;
309 let mut best_dist = f32::MAX;
310
311 for (idx, centroid) in self.centroids.iter().enumerate() {
312 let dist = euclidean_distance(vector, centroid);
313 if dist < best_dist {
314 best_dist = dist;
315 best_idx = idx;
316 }
317 }
318
319 (best_idx, best_dist)
320 }
321
322 fn find_nearest_centroids(&self, vector: &[f32], n: usize) -> Vec<(usize, f32)> {
324 let mut distances: Vec<(usize, f32)> = self
325 .centroids
326 .iter()
327 .enumerate()
328 .map(|(idx, centroid)| (idx, euclidean_distance(vector, centroid)))
329 .collect();
330
331 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
332 distances.truncate(n);
333 distances
334 }
335
336 fn compute_residual(&self, vector: &[f32], cluster_id: usize) -> Vec<f32> {
338 let centroid = &self.centroids[cluster_id];
339 vector
340 .iter()
341 .zip(centroid.iter())
342 .map(|(v, c)| v - c)
343 .collect()
344 }
345
346 fn kmeans_train(&self, vectors: &[Vector], k: usize) -> Result<Vec<Vec<f32>>, String> {
348 if vectors.is_empty() || k == 0 {
349 return Err("Invalid input for k-means".to_string());
350 }
351
352 let dim = vectors[0].values.len();
353 let mut rng = rand::thread_rng();
354
355 let mut indices: Vec<usize> = (0..vectors.len()).collect();
357 indices.shuffle(&mut rng);
358 let mut centroids: Vec<Vec<f32>> = indices
359 .iter()
360 .take(k)
361 .map(|&i| vectors[i].values.clone())
362 .collect();
363
364 while centroids.len() < k {
366 centroids.push(vec![0.0; dim]);
367 }
368
369 for _ in 0..self.config.ivf_iterations {
371 let mut assignments: HashMap<usize, Vec<usize>> = HashMap::new();
373 for cluster_id in 0..k {
374 assignments.insert(cluster_id, Vec::new());
375 }
376
377 for (vec_idx, v) in vectors.iter().enumerate() {
378 let mut best_cluster = 0;
379 let mut best_dist = f32::MAX;
380
381 for (cluster_id, centroid) in centroids.iter().enumerate() {
382 let dist = euclidean_distance(&v.values, centroid);
383 if dist < best_dist {
384 best_dist = dist;
385 best_cluster = cluster_id;
386 }
387 }
388
389 if let Some(members) = assignments.get_mut(&best_cluster) {
390 members.push(vec_idx);
391 }
392 }
393
394 let mut converged = true;
396 for (cluster_id, member_indices) in &assignments {
397 if member_indices.is_empty() {
398 continue;
399 }
400
401 let mut new_centroid = vec![0.0; dim];
402 for &idx in member_indices {
403 for (j, val) in vectors[idx].values.iter().enumerate() {
404 new_centroid[j] += val;
405 }
406 }
407 for val in &mut new_centroid {
408 *val /= member_indices.len() as f32;
409 }
410
411 let diff = euclidean_distance(¢roids[*cluster_id], &new_centroid);
413 if diff > 1e-4 {
414 converged = false;
415 }
416
417 centroids[*cluster_id] = new_centroid;
418 }
419
420 if converged {
421 break;
422 }
423 }
424
425 Ok(centroids)
426 }
427}
428
/// Point-in-time summary of an [`IvfPqIndex`], produced by `IvfPqIndex::stats`.
#[derive(Debug, Clone)]
pub struct IvfPqStats {
    /// Number of coarse centroids currently held by the index.
    pub n_clusters: usize,
    /// Sum of the lengths of all inverted lists.
    pub total_vectors: usize,
    /// `total_vectors` divided by the number of lists; `0.0` when there are no lists.
    pub avg_list_size: f64,
    /// Length of the longest inverted list (`0` when there are no lists).
    pub max_list_size: usize,
    /// Length of the shortest inverted list (`0` when there are no lists).
    pub min_list_size: usize,
    /// Whether the index reported itself as trained.
    pub trained: bool,
    /// Vector dimensionality recorded by the index, if any.
    pub dimension: Option<usize>,
    /// Estimated size in bytes of centroids, PQ codebooks and stored codes.
    pub memory_bytes: usize,
}
441
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length the
/// extra components of the longer one are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_norm: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();

    squared_norm.sqrt()
}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `n` vectors of length `dim` with ids `v0`, `v1`, ... whose
    /// components come from a fixed-seed RNG over `-1.0..1.0`, so every test
    /// sees the same data.
    fn create_test_vectors(n: usize, dim: usize) -> Vec<Vector> {
        use rand::Rng;
        use rand::SeedableRng;

        let mut rng = rand::rngs::StdRng::seed_from_u64(42);

        let mut vectors = Vec::with_capacity(n);
        for i in 0..n {
            let values: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
            vectors.push(Vector {
                id: format!("v{}", i),
                values,
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            });
        }
        vectors
    }

    /// Small configuration shared by most tests: 4 clusters, 2 probed,
    /// 2 sub-quantizers of 8 centroids, 5 iterations for both k-means stages.
    fn small_config() -> IvfPqConfig {
        IvfPqConfig {
            n_clusters: 4,
            n_probe: 2,
            pq_subquantizers: 2,
            pq_centroids: 8,
            ivf_iterations: 5,
            pq_iterations: 5,
            metric: DistanceMetric::Euclidean,
        }
    }

    #[test]
    fn test_ivfpq_creation() {
        let index = IvfPqIndex::new(IvfPqConfig::default());

        assert!(!index.is_trained());
        assert_eq!(index.dimension(), None);
    }

    #[test]
    fn test_ivfpq_training() {
        let mut index = IvfPqIndex::new(small_config());
        let vectors = create_test_vectors(100, 16);

        let result = index.train(&vectors);
        assert!(result.is_ok(), "Training failed: {:?}", result.err());
        assert!(index.is_trained());
        assert_eq!(index.dimension(), Some(16));
    }

    #[test]
    fn test_ivfpq_add_and_search() {
        // Probe every cluster so the query's own list is always scanned.
        let config = IvfPqConfig {
            n_probe: 4,
            ..small_config()
        };

        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(100, 16);

        index.train(&vectors).unwrap();

        let added = index.add(&vectors).unwrap();
        assert_eq!(added, 100);

        let query = &vectors[0].values;
        let results = index.search(query, 10).unwrap();

        assert!(!results.is_empty(), "Results should not be empty");

        let found_self = results.iter().any(|r| r.id == "v0");
        assert!(
            found_self,
            "Query vector should be found in results. Got: {:?}",
            results.iter().map(|r| &r.id).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_ivfpq_stats() {
        let mut index = IvfPqIndex::new(small_config());
        let vectors = create_test_vectors(100, 16);

        index.train(&vectors).unwrap();
        index.add(&vectors).unwrap();

        let stats = index.stats();
        assert_eq!(stats.n_clusters, 4);
        assert_eq!(stats.total_vectors, 100);
        assert!(stats.trained);
        assert_eq!(stats.dimension, Some(16));
        assert!(stats.memory_bytes > 0);
    }

    #[test]
    fn test_ivfpq_search_quality() {
        let config = IvfPqConfig {
            n_clusters: 8,
            n_probe: 8,
            pq_subquantizers: 4,
            pq_centroids: 16,
            ivf_iterations: 10,
            pq_iterations: 10,
            metric: DistanceMetric::Euclidean,
        };

        let mut index = IvfPqIndex::new(config);
        let vectors = create_test_vectors(200, 32);

        index.train(&vectors).unwrap();
        index.add(&vectors).unwrap();

        // Query with every tenth stored vector and count how often the
        // vector itself shows up among the top 20 results.
        let test_queries: usize = 10;
        let hits = (0..test_queries)
            .filter(|&i| {
                let query = &vectors[i * 10].values;
                let results = index.search(query, 20).unwrap();

                let expected_id = format!("v{}", i * 10);
                results.iter().any(|r| r.id == expected_id)
            })
            .count();

        let recall = hits as f32 / test_queries as f32;
        assert!(
            recall >= 0.5,
            "Recall should be at least 50%, got {}%",
            recall * 100.0
        );
    }

    #[test]
    fn test_ivfpq_empty_search() {
        let mut index = IvfPqIndex::new(small_config());
        let vectors = create_test_vectors(50, 16);

        // Train only; nothing is added, so every inverted list is empty.
        index.train(&vectors).unwrap();
        let query = &vectors[0].values;
        let results = index.search(query, 5).unwrap();

        assert!(results.is_empty());
    }

    #[test]
    fn test_ivfpq_untrained_error() {
        let index = IvfPqIndex::new(IvfPqConfig::default());

        let result = index.search(&[0.0; 128], 5);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("trained"));
    }
}