1#[derive(Debug, Clone)]
6pub struct IvfPqConfig {
7 pub nlist: usize,
9 pub m: usize,
11 pub k_per_sub: usize,
13 pub nprobe: usize,
15 pub dimension: usize,
17}
18
19impl IvfPqConfig {
20 pub fn validate(&self) -> Result<(), IvfPqError> {
22 if self.m == 0 {
23 return Err(IvfPqError::InvalidConfig("m must be > 0".to_string()));
24 }
25 if self.dimension == 0 {
26 return Err(IvfPqError::InvalidConfig(
27 "dimension must be > 0".to_string(),
28 ));
29 }
30 if self.dimension % self.m != 0 {
31 return Err(IvfPqError::InvalidConfig(format!(
32 "dimension ({}) must be divisible by m ({})",
33 self.dimension, self.m
34 )));
35 }
36 if self.nlist == 0 {
37 return Err(IvfPqError::InvalidConfig("nlist must be > 0".to_string()));
38 }
39 if self.nprobe == 0 {
40 return Err(IvfPqError::InvalidConfig("nprobe must be > 0".to_string()));
41 }
42 if self.nprobe > self.nlist {
43 return Err(IvfPqError::InvalidConfig(format!(
44 "nprobe ({}) must be <= nlist ({})",
45 self.nprobe, self.nlist
46 )));
47 }
48 if self.k_per_sub == 0 {
49 return Err(IvfPqError::InvalidConfig(
50 "k_per_sub must be > 0".to_string(),
51 ));
52 }
53 Ok(())
54 }
55}
56
/// Errors reported by IVF-PQ configuration validation and index operations.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IvfPqError {
    /// A vector's length (`got`) differs from the configured dimension (`expected`).
    DimensionMismatch { expected: usize, got: usize },
    /// The operation requires a trained index.
    NotTrained,
    /// The configuration violates a validation rule; the payload names the rule.
    InvalidConfig(String),
    /// The supplied training data is not sufficient; the payload explains why.
    InsufficientData(String),
}

impl std::fmt::Display for IvfPqError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {expected}, got {got}")
            }
            Self::NotTrained => f.write_str("Index has not been trained yet"),
            Self::InvalidConfig(reason) => write!(f, "Invalid config: {reason}"),
            Self::InsufficientData(reason) => write!(f, "Insufficient data: {reason}"),
        }
    }
}

impl std::error::Error for IvfPqError {}
80
81pub struct IvfPqIndex {
83 config: IvfPqConfig,
84 coarse_centroids: Vec<Vec<f64>>,
86 inverted_lists: Vec<Vec<(u64, Vec<u8>)>>,
88 pq_codebook: Vec<Vec<Vec<f64>>>,
90 is_trained: bool,
91 next_id: u64,
92}
93
94impl IvfPqIndex {
95 pub fn new(config: IvfPqConfig) -> Result<Self, IvfPqError> {
97 config.validate()?;
98 let nlist = config.nlist;
99 Ok(Self {
100 config,
101 coarse_centroids: Vec::new(),
102 inverted_lists: vec![Vec::new(); nlist],
103 pq_codebook: Vec::new(),
104 is_trained: false,
105 next_id: 0,
106 })
107 }
108
109 pub fn train(&mut self, vectors: &[Vec<f64>]) -> Result<(), IvfPqError> {
111 if vectors.is_empty() {
112 return Err(IvfPqError::InsufficientData(
113 "Need at least 1 vector to train".to_string(),
114 ));
115 }
116 let n = vectors.len();
117 let dim = self.config.dimension;
118
119 for v in vectors.iter() {
121 if v.len() != dim {
122 return Err(IvfPqError::DimensionMismatch {
123 expected: dim,
124 got: v.len(),
125 });
126 }
127 }
128
129 let nlist = self.config.nlist.min(n); let m = self.config.m;
131 let k_per_sub = self.config.k_per_sub;
132 let sub_dim = dim / m;
133
134 self.coarse_centroids = Self::kmeans(vectors, nlist, dim, 10);
136
137 let residuals: Vec<Vec<f64>> = vectors
140 .iter()
141 .map(|v| {
142 let nearest = self.find_nearest_centroid_trained(v);
143 let centroid = &self.coarse_centroids[nearest];
144 v.iter().zip(centroid.iter()).map(|(a, b)| a - b).collect()
145 })
146 .collect();
147
148 let mut pq_codebook = Vec::with_capacity(m);
150 for sub_idx in 0..m {
151 let start = sub_idx * sub_dim;
152 let end = start + sub_dim;
153 let sub_data: Vec<Vec<f64>> =
154 residuals.iter().map(|r| r[start..end].to_vec()).collect();
155 let k = k_per_sub.min(sub_data.len());
156 let centroids = Self::kmeans(&sub_data, k, sub_dim, 5);
157 pq_codebook.push(centroids);
158 }
159 self.pq_codebook = pq_codebook;
160 self.is_trained = true;
161
162 let actual_nlist = self.coarse_centroids.len();
164 self.inverted_lists = vec![Vec::new(); actual_nlist];
165 Ok(())
166 }
167
168 pub fn add(&mut self, vector: &[f64]) -> Result<u64, IvfPqError> {
170 if !self.is_trained {
171 return Err(IvfPqError::NotTrained);
172 }
173 let dim = self.config.dimension;
174 if vector.len() != dim {
175 return Err(IvfPqError::DimensionMismatch {
176 expected: dim,
177 got: vector.len(),
178 });
179 }
180 let cluster_idx = self.find_nearest_centroid(vector);
181 let centroid = &self.coarse_centroids[cluster_idx];
182 let residual: Vec<f64> = vector
183 .iter()
184 .zip(centroid.iter())
185 .map(|(a, b)| a - b)
186 .collect();
187 let codes = self.encode_residual(&residual);
188 let id = self.next_id;
189 self.next_id += 1;
190 self.inverted_lists[cluster_idx].push((id, codes));
191 Ok(id)
192 }
193
194 pub fn add_batch(&mut self, vectors: &[Vec<f64>]) -> Result<Vec<u64>, IvfPqError> {
196 vectors.iter().map(|v| self.add(v)).collect()
197 }
198
199 pub fn search(&self, query: &[f64], k: usize) -> Result<Vec<(u64, f64)>, IvfPqError> {
203 if !self.is_trained {
204 return Err(IvfPqError::NotTrained);
205 }
206 let dim = self.config.dimension;
207 if query.len() != dim {
208 return Err(IvfPqError::DimensionMismatch {
209 expected: dim,
210 got: query.len(),
211 });
212 }
213
214 let nprobe = self.config.nprobe.min(self.coarse_centroids.len());
216 let mut centroid_dists: Vec<(usize, f64)> = self
217 .coarse_centroids
218 .iter()
219 .enumerate()
220 .map(|(i, c)| (i, Self::l2_distance(query, c)))
221 .collect();
222 centroid_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
223
224 let sub_dim = dim / self.config.m;
226 let m = self.config.m;
227
228 let mut candidates: Vec<(u64, f64)> = Vec::new();
229
230 for &(cluster_idx, _) in centroid_dists.iter().take(nprobe) {
231 let centroid = &self.coarse_centroids[cluster_idx];
232 let residual: Vec<f64> = query
233 .iter()
234 .zip(centroid.iter())
235 .map(|(a, b)| a - b)
236 .collect();
237
238 let dist_tables: Vec<Vec<f64>> = (0..m)
240 .map(|sub_idx| {
241 let start = sub_idx * sub_dim;
242 let q_sub = &residual[start..start + sub_dim];
243 self.pq_codebook[sub_idx]
244 .iter()
245 .map(|code_centroid| Self::l2_distance(q_sub, code_centroid))
246 .collect()
247 })
248 .collect();
249
250 for &(id, ref codes) in &self.inverted_lists[cluster_idx] {
251 let dist: f64 = codes
253 .iter()
254 .enumerate()
255 .map(|(sub_idx, &code)| {
256 let code_idx = code as usize;
257 dist_tables[sub_idx]
258 .get(code_idx)
259 .copied()
260 .unwrap_or(f64::MAX)
261 })
262 .sum();
263 candidates.push((id, dist));
264 }
265 }
266
267 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
269 candidates.truncate(k);
270 Ok(candidates)
271 }
272
273 pub fn size(&self) -> usize {
275 self.inverted_lists.iter().map(|l| l.len()).sum()
276 }
277
278 pub fn is_trained(&self) -> bool {
280 self.is_trained
281 }
282
283 pub fn find_nearest_centroid(&self, vector: &[f64]) -> usize {
285 self.find_nearest_centroid_trained(vector)
286 }
287
288 fn find_nearest_centroid_trained(&self, vector: &[f64]) -> usize {
289 let mut best_idx = 0;
290 let mut best_dist = f64::MAX;
291 for (i, centroid) in self.coarse_centroids.iter().enumerate() {
292 let d = Self::l2_distance(vector, centroid);
293 if d < best_dist {
294 best_dist = d;
295 best_idx = i;
296 }
297 }
298 best_idx
299 }
300
301 pub fn encode_residual(&self, residual: &[f64]) -> Vec<u8> {
303 let sub_dim = self.config.dimension / self.config.m;
304 let m = self.config.m;
305 let mut codes = Vec::with_capacity(m);
306 for sub_idx in 0..m {
307 let start = sub_idx * sub_dim;
308 let sub = &residual[start..start + sub_dim];
309 let mut best_code = 0u8;
311 let mut best_dist = f64::MAX;
312 for (code_idx, centroid) in self.pq_codebook[sub_idx].iter().enumerate() {
313 let d = Self::l2_distance(sub, centroid);
314 if d < best_dist {
315 best_dist = d;
316 best_code = (code_idx & 0xFF) as u8;
317 }
318 }
319 codes.push(best_code);
320 }
321 codes
322 }
323
324 pub fn l2_distance(a: &[f64], b: &[f64]) -> f64 {
326 a.iter()
327 .zip(b.iter())
328 .map(|(x, y)| (x - y).powi(2))
329 .sum::<f64>()
330 }
331
332 pub fn kmeans(data: &[Vec<f64>], k: usize, dim: usize, iters: usize) -> Vec<Vec<f64>> {
334 if data.is_empty() || k == 0 {
335 return Vec::new();
336 }
337 let k = k.min(data.len());
338
339 let mut centroids: Vec<Vec<f64>> =
341 (0..k).map(|i| data[i * data.len() / k].clone()).collect();
342
343 for _ in 0..iters {
344 let mut clusters: Vec<Vec<usize>> = vec![Vec::new(); k];
346 for (idx, point) in data.iter().enumerate() {
347 let best = centroids
348 .iter()
349 .enumerate()
350 .map(|(ci, c)| (ci, Self::l2_distance(point, c)))
351 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
352 .map(|(ci, _)| ci)
353 .unwrap_or(0);
354 clusters[best].push(idx);
355 }
356
357 let mut new_centroids = Vec::with_capacity(k);
359 for (ci, members) in clusters.iter().enumerate() {
360 if members.is_empty() {
361 new_centroids.push(centroids[ci].clone());
363 } else {
364 let mut centroid = vec![0.0f64; dim];
365 for &idx in members {
366 for (d, val) in centroid.iter_mut().zip(data[idx].iter()) {
367 *d += val;
368 }
369 }
370 let count = members.len() as f64;
371 for d in centroid.iter_mut() {
372 *d /= count;
373 }
374 new_centroids.push(centroid);
375 }
376 }
377 centroids = new_centroids;
378 }
379 centroids
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a configuration from positional values (note the argument order).
    fn make_config(dim: usize, nlist: usize, m: usize, k: usize, nprobe: usize) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            m,
            k_per_sub: k,
            nprobe,
            dimension: dim,
        }
    }

    /// The configuration most tests share: 8 dimensions, 4 lists, 2 sub-quantizers,
    /// 4 codes per sub-quantizer, 2 probes.
    fn default_config() -> IvfPqConfig {
        make_config(8, 4, 2, 4, 2)
    }

    /// Deterministic pseudo-random vectors from a simple linear congruential generator,
    /// so tests need no external randomness.
    fn make_random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut state = seed;
        let mut next_value = || {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 33) as f64) / (u32::MAX as f64) * 2.0 - 1.0
        };
        (0..n)
            .map(|_| (0..dim).map(|_| next_value()).collect())
            .collect()
    }

    /// Returns a default-config index trained on `n` seeded vectors, plus those vectors.
    fn trained_index(n: usize, seed: u64) -> (IvfPqIndex, Vec<Vec<f64>>) {
        let mut index = IvfPqIndex::new(default_config()).unwrap();
        let vectors = make_random_vectors(n, 8, seed);
        index.train(&vectors).unwrap();
        (index, vectors)
    }

    /// Asserts that `IvfPqIndex::new` rejects `config` as invalid.
    fn assert_rejected(config: IvfPqConfig) {
        assert!(matches!(
            IvfPqIndex::new(config),
            Err(IvfPqError::InvalidConfig(_))
        ));
    }

    // --- construction ---

    #[test]
    fn test_new_valid_config() {
        assert!(IvfPqIndex::new(default_config()).is_ok());
    }

    #[test]
    fn test_new_m_not_divides_dimension() {
        assert_rejected(make_config(7, 4, 3, 4, 2));
    }

    #[test]
    fn test_new_m_zero() {
        assert_rejected(make_config(8, 4, 0, 4, 2));
    }

    #[test]
    fn test_new_nlist_zero() {
        assert_rejected(make_config(8, 0, 2, 4, 2));
    }

    #[test]
    fn test_new_nprobe_gt_nlist() {
        assert_rejected(make_config(8, 2, 2, 4, 5));
    }

    #[test]
    fn test_new_dimension_zero() {
        // `m` must be non-zero here: validation checks `m` first, so `m == 0` would
        // be rejected before the dimension rule is ever reached.
        let err = IvfPqIndex::new(make_config(0, 4, 2, 4, 2)).err();
        assert!(
            matches!(&err, Some(IvfPqError::InvalidConfig(msg)) if msg.contains("dimension")),
            "expected a dimension-related InvalidConfig, got {err:?}"
        );
    }

    // --- training ---

    #[test]
    fn test_not_trained_initially() {
        let index = IvfPqIndex::new(default_config()).unwrap();
        assert!(!index.is_trained());
    }

    #[test]
    fn test_train_basic() {
        let (index, _) = trained_index(20, 42);
        assert!(index.is_trained());
    }

    #[test]
    fn test_train_too_few_vectors() {
        let mut index = IvfPqIndex::new(default_config()).unwrap();
        let result = index.train(&[]);
        assert!(matches!(result, Err(IvfPqError::InsufficientData(_))));
    }

    #[test]
    fn test_train_dimension_mismatch() {
        let mut index = IvfPqIndex::new(default_config()).unwrap();
        let vectors = vec![vec![1.0, 2.0, 3.0]];
        let result = index.train(&vectors);
        assert!(matches!(result, Err(IvfPqError::DimensionMismatch { .. })));
    }

    // --- adding ---

    #[test]
    fn test_add_before_training_error() {
        let mut index = IvfPqIndex::new(default_config()).unwrap();
        let v = vec![0.0; 8];
        let result = index.add(&v);
        assert!(matches!(result, Err(IvfPqError::NotTrained)));
    }

    #[test]
    fn test_add_after_training() {
        let (mut index, vectors) = trained_index(20, 1);
        let id = index.add(&vectors[0]).unwrap();
        assert_eq!(id, 0);
        assert_eq!(index.size(), 1);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        let (mut index, _) = trained_index(20, 2);
        let bad_v = vec![1.0, 2.0];
        let result = index.add(&bad_v);
        assert!(matches!(result, Err(IvfPqError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_add_batch() {
        let (mut index, _) = trained_index(20, 3);
        let add_data = make_random_vectors(5, 8, 4);
        let ids = index.add_batch(&add_data).unwrap();
        assert_eq!(ids.len(), 5);
        assert_eq!(index.size(), 5);
    }

    // --- size ---

    #[test]
    fn test_size_starts_at_zero() {
        let (index, _) = trained_index(20, 5);
        assert_eq!(index.size(), 0);
    }

    #[test]
    fn test_size_after_adding() {
        let (mut index, vectors) = trained_index(20, 6);
        for v in &vectors {
            index.add(v).unwrap();
        }
        assert_eq!(index.size(), 20);
    }

    // --- search ---

    #[test]
    fn test_search_before_training_error() {
        let index = IvfPqIndex::new(default_config()).unwrap();
        let q = vec![0.0; 8];
        let result = index.search(&q, 5);
        assert!(matches!(result, Err(IvfPqError::NotTrained)));
    }

    #[test]
    fn test_search_empty_index() {
        let (index, _) = trained_index(20, 7);
        let q = vec![0.0; 8];
        let results = index.search(&q, 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_returns_k_results() {
        let (mut index, vectors) = trained_index(50, 8);
        for v in &vectors {
            index.add(v).unwrap();
        }
        let q = vec![0.0; 8];
        let results = index.search(&q, 10).unwrap();
        // Only `nprobe` clusters are scanned, so fewer than `k` hits are possible.
        assert!(results.len() <= 10);
        assert!(!results.is_empty());
    }

    #[test]
    fn test_search_sorted_by_distance() {
        let (mut index, vectors) = trained_index(30, 9);
        for v in &vectors {
            index.add(v).unwrap();
        }
        let q = vec![0.0; 8];
        let results = index.search(&q, 10).unwrap();
        for pair in results.windows(2) {
            assert!(
                pair[0].1 <= pair[1].1,
                "Results not sorted: {} > {}",
                pair[0].1,
                pair[1].1
            );
        }
    }

    #[test]
    fn test_search_dimension_mismatch() {
        let (index, _) = trained_index(20, 10);
        let bad_q = vec![1.0, 2.0];
        let result = index.search(&bad_q, 5);
        assert!(matches!(result, Err(IvfPqError::DimensionMismatch { .. })));
    }

    // --- distance ---

    #[test]
    fn test_l2_distance_zero() {
        let a = vec![1.0, 2.0, 3.0];
        assert!(IvfPqIndex::l2_distance(&a, &a) < 1e-10);
    }

    #[test]
    fn test_l2_distance_known() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let d = IvfPqIndex::l2_distance(&a, &b);
        // `l2_distance` returns the squared distance: 3^2 + 4^2.
        assert!((d - 25.0).abs() < 1e-10);
    }

    // --- errors ---

    #[test]
    fn test_error_display() {
        let e = IvfPqError::DimensionMismatch {
            expected: 8,
            got: 4,
        };
        assert!(format!("{e}").contains("8"));
        let e2 = IvfPqError::NotTrained;
        assert!(!format!("{e2}").is_empty());
        let e3 = IvfPqError::InvalidConfig("m".to_string());
        assert!(format!("{e3}").contains("m"));
        let e4 = IvfPqError::InsufficientData("need more".to_string());
        assert!(format!("{e4}").contains("need more"));
    }

    // --- config validation ---

    #[test]
    fn test_config_validation_valid() {
        assert!(default_config().validate().is_ok());
    }

    #[test]
    fn test_config_validation_k_per_sub_zero() {
        let config = IvfPqConfig {
            nlist: 4,
            m: 2,
            k_per_sub: 0,
            nprobe: 2,
            dimension: 8,
        };
        assert!(matches!(
            config.validate(),
            Err(IvfPqError::InvalidConfig(_))
        ));
    }
}