1#[derive(Debug, Clone)]
6pub struct IvfPqConfig {
7 pub nlist: usize,
9 pub m: usize,
11 pub k_per_sub: usize,
13 pub nprobe: usize,
15 pub dimension: usize,
17}
18
19impl IvfPqConfig {
20 pub fn validate(&self) -> Result<(), IvfPqError> {
22 if self.m == 0 {
23 return Err(IvfPqError::InvalidConfig("m must be > 0".to_string()));
24 }
25 if self.dimension == 0 {
26 return Err(IvfPqError::InvalidConfig(
27 "dimension must be > 0".to_string(),
28 ));
29 }
30 if self.dimension % self.m != 0 {
31 return Err(IvfPqError::InvalidConfig(format!(
32 "dimension ({}) must be divisible by m ({})",
33 self.dimension, self.m
34 )));
35 }
36 if self.nlist == 0 {
37 return Err(IvfPqError::InvalidConfig("nlist must be > 0".to_string()));
38 }
39 if self.nprobe == 0 {
40 return Err(IvfPqError::InvalidConfig("nprobe must be > 0".to_string()));
41 }
42 if self.nprobe > self.nlist {
43 return Err(IvfPqError::InvalidConfig(format!(
44 "nprobe ({}) must be <= nlist ({})",
45 self.nprobe, self.nlist
46 )));
47 }
48 if self.k_per_sub == 0 {
49 return Err(IvfPqError::InvalidConfig(
50 "k_per_sub must be > 0".to_string(),
51 ));
52 }
53 Ok(())
54 }
55}
56
/// Errors produced by IVF-PQ configuration validation and index operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IvfPqError {
    /// A vector's length differs from the configured dimension.
    DimensionMismatch {
        /// Length required by the configuration.
        expected: usize,
        /// Length that was actually supplied.
        got: usize,
    },
    /// The operation needs a trained index.
    NotTrained,
    /// The configuration violates an invariant; the payload names it.
    InvalidConfig(String),
    /// Too little data was supplied; the payload states the requirement.
    InsufficientData(String),
}

impl std::fmt::Display for IvfPqError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {expected}, got {got}")
            }
            Self::NotTrained => f.write_str("Index has not been trained yet"),
            Self::InvalidConfig(msg) => write!(f, "Invalid config: {msg}"),
            Self::InsufficientData(msg) => write!(f, "Insufficient data: {msg}"),
        }
    }
}

impl std::error::Error for IvfPqError {}
80
81pub struct IvfPqIndex {
83 config: IvfPqConfig,
84 coarse_centroids: Vec<Vec<f64>>,
86 inverted_lists: Vec<Vec<(u64, Vec<u8>)>>,
88 pq_codebook: Vec<Vec<Vec<f64>>>,
90 is_trained: bool,
91 next_id: u64,
92}
93
94impl IvfPqIndex {
95 pub fn new(config: IvfPqConfig) -> Result<Self, IvfPqError> {
97 config.validate()?;
98 let nlist = config.nlist;
99 Ok(Self {
100 config,
101 coarse_centroids: Vec::new(),
102 inverted_lists: vec![Vec::new(); nlist],
103 pq_codebook: Vec::new(),
104 is_trained: false,
105 next_id: 0,
106 })
107 }
108
109 pub fn train(&mut self, vectors: &[Vec<f64>]) -> Result<(), IvfPqError> {
111 if vectors.is_empty() {
112 return Err(IvfPqError::InsufficientData(
113 "Need at least 1 vector to train".to_string(),
114 ));
115 }
116 let n = vectors.len();
117 let dim = self.config.dimension;
118
119 for v in vectors.iter() {
121 if v.len() != dim {
122 return Err(IvfPqError::DimensionMismatch {
123 expected: dim,
124 got: v.len(),
125 });
126 }
127 }
128
129 let nlist = self.config.nlist.min(n); let m = self.config.m;
131 let k_per_sub = self.config.k_per_sub;
132 let sub_dim = dim / m;
133
134 self.coarse_centroids = Self::kmeans(vectors, nlist, dim, 10);
136
137 let residuals: Vec<Vec<f64>> = vectors
140 .iter()
141 .map(|v| {
142 let nearest = self.find_nearest_centroid_trained(v);
143 let centroid = &self.coarse_centroids[nearest];
144 v.iter().zip(centroid.iter()).map(|(a, b)| a - b).collect()
145 })
146 .collect();
147
148 let mut pq_codebook = Vec::with_capacity(m);
150 for sub_idx in 0..m {
151 let start = sub_idx * sub_dim;
152 let end = start + sub_dim;
153 let sub_data: Vec<Vec<f64>> =
154 residuals.iter().map(|r| r[start..end].to_vec()).collect();
155 let k = k_per_sub.min(sub_data.len());
156 let centroids = Self::kmeans(&sub_data, k, sub_dim, 5);
157 pq_codebook.push(centroids);
158 }
159 self.pq_codebook = pq_codebook;
160 self.is_trained = true;
161
162 let actual_nlist = self.coarse_centroids.len();
164 self.inverted_lists = vec![Vec::new(); actual_nlist];
165 Ok(())
166 }
167
168 pub fn add(&mut self, vector: &[f64]) -> Result<u64, IvfPqError> {
170 if !self.is_trained {
171 return Err(IvfPqError::NotTrained);
172 }
173 let dim = self.config.dimension;
174 if vector.len() != dim {
175 return Err(IvfPqError::DimensionMismatch {
176 expected: dim,
177 got: vector.len(),
178 });
179 }
180 let cluster_idx = self.find_nearest_centroid(vector);
181 let centroid = &self.coarse_centroids[cluster_idx];
182 let residual: Vec<f64> = vector
183 .iter()
184 .zip(centroid.iter())
185 .map(|(a, b)| a - b)
186 .collect();
187 let codes = self.encode_residual(&residual);
188 let id = self.next_id;
189 self.next_id += 1;
190 self.inverted_lists[cluster_idx].push((id, codes));
191 Ok(id)
192 }
193
194 pub fn add_batch(&mut self, vectors: &[Vec<f64>]) -> Result<Vec<u64>, IvfPqError> {
196 vectors.iter().map(|v| self.add(v)).collect()
197 }
198
199 pub fn search(&self, query: &[f64], k: usize) -> Result<Vec<(u64, f64)>, IvfPqError> {
203 if !self.is_trained {
204 return Err(IvfPqError::NotTrained);
205 }
206 let dim = self.config.dimension;
207 if query.len() != dim {
208 return Err(IvfPqError::DimensionMismatch {
209 expected: dim,
210 got: query.len(),
211 });
212 }
213
214 let nprobe = self.config.nprobe.min(self.coarse_centroids.len());
216 let mut centroid_dists: Vec<(usize, f64)> = self
217 .coarse_centroids
218 .iter()
219 .enumerate()
220 .map(|(i, c)| (i, Self::l2_distance(query, c)))
221 .collect();
222 centroid_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
223
224 let sub_dim = dim / self.config.m;
226 let m = self.config.m;
227
228 let mut candidates: Vec<(u64, f64)> = Vec::new();
229
230 for &(cluster_idx, _) in centroid_dists.iter().take(nprobe) {
231 let centroid = &self.coarse_centroids[cluster_idx];
232 let residual: Vec<f64> = query
233 .iter()
234 .zip(centroid.iter())
235 .map(|(a, b)| a - b)
236 .collect();
237
238 let dist_tables: Vec<Vec<f64>> = (0..m)
240 .map(|sub_idx| {
241 let start = sub_idx * sub_dim;
242 let q_sub = &residual[start..start + sub_dim];
243 self.pq_codebook[sub_idx]
244 .iter()
245 .map(|code_centroid| Self::l2_distance(q_sub, code_centroid))
246 .collect()
247 })
248 .collect();
249
250 for &(id, ref codes) in &self.inverted_lists[cluster_idx] {
251 let dist: f64 = codes
253 .iter()
254 .enumerate()
255 .map(|(sub_idx, &code)| {
256 let code_idx = code as usize;
257 dist_tables[sub_idx]
258 .get(code_idx)
259 .copied()
260 .unwrap_or(f64::MAX)
261 })
262 .sum();
263 candidates.push((id, dist));
264 }
265 }
266
267 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
269 candidates.truncate(k);
270 Ok(candidates)
271 }
272
273 pub fn size(&self) -> usize {
275 self.inverted_lists.iter().map(|l| l.len()).sum()
276 }
277
278 pub fn is_trained(&self) -> bool {
280 self.is_trained
281 }
282
283 pub fn find_nearest_centroid(&self, vector: &[f64]) -> usize {
285 self.find_nearest_centroid_trained(vector)
286 }
287
288 fn find_nearest_centroid_trained(&self, vector: &[f64]) -> usize {
289 let mut best_idx = 0;
290 let mut best_dist = f64::MAX;
291 for (i, centroid) in self.coarse_centroids.iter().enumerate() {
292 let d = Self::l2_distance(vector, centroid);
293 if d < best_dist {
294 best_dist = d;
295 best_idx = i;
296 }
297 }
298 best_idx
299 }
300
301 pub fn encode_residual(&self, residual: &[f64]) -> Vec<u8> {
303 let sub_dim = self.config.dimension / self.config.m;
304 let m = self.config.m;
305 let mut codes = Vec::with_capacity(m);
306 for sub_idx in 0..m {
307 let start = sub_idx * sub_dim;
308 let sub = &residual[start..start + sub_dim];
309 let mut best_code = 0u8;
311 let mut best_dist = f64::MAX;
312 for (code_idx, centroid) in self.pq_codebook[sub_idx].iter().enumerate() {
313 let d = Self::l2_distance(sub, centroid);
314 if d < best_dist {
315 best_dist = d;
316 best_code = (code_idx & 0xFF) as u8;
317 }
318 }
319 codes.push(best_code);
320 }
321 codes
322 }
323
324 pub fn l2_distance(a: &[f64], b: &[f64]) -> f64 {
326 a.iter()
327 .zip(b.iter())
328 .map(|(x, y)| (x - y).powi(2))
329 .sum::<f64>()
330 }
331
332 pub fn kmeans(data: &[Vec<f64>], k: usize, dim: usize, iters: usize) -> Vec<Vec<f64>> {
334 if data.is_empty() || k == 0 {
335 return Vec::new();
336 }
337 let k = k.min(data.len());
338
339 let mut centroids: Vec<Vec<f64>> =
341 (0..k).map(|i| data[i * data.len() / k].clone()).collect();
342
343 for _ in 0..iters {
344 let mut clusters: Vec<Vec<usize>> = vec![Vec::new(); k];
346 for (idx, point) in data.iter().enumerate() {
347 let best = centroids
348 .iter()
349 .enumerate()
350 .map(|(ci, c)| (ci, Self::l2_distance(point, c)))
351 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
352 .map(|(ci, _)| ci)
353 .unwrap_or(0);
354 clusters[best].push(idx);
355 }
356
357 let mut new_centroids = Vec::with_capacity(k);
359 for (ci, members) in clusters.iter().enumerate() {
360 if members.is_empty() {
361 new_centroids.push(centroids[ci].clone());
363 } else {
364 let mut centroid = vec![0.0f64; dim];
365 for &idx in members {
366 for (d, val) in centroid.iter_mut().zip(data[idx].iter()) {
367 *d += val;
368 }
369 }
370 let count = members.len() as f64;
371 for d in centroid.iter_mut() {
372 *d /= count;
373 }
374 new_centroids.push(centroid);
375 }
376 }
377 centroids = new_centroids;
378 }
379 centroids
380 }
381}
382
#[cfg(test)]
mod tests {
    //! Unit tests for configuration validation and the index's train / add /
    //! search behaviour.

    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
    use super::*;

    /// Builds a config from positional values: dimension, nlist, m, k_per_sub, nprobe.
    fn make_config(dim: usize, nlist: usize, m: usize, k: usize, nprobe: usize) -> IvfPqConfig {
        IvfPqConfig {
            nlist,
            m,
            k_per_sub: k,
            nprobe,
            dimension: dim,
        }
    }

    /// The valid configuration shared by most tests.
    fn default_config() -> IvfPqConfig {
        make_config(8, 4, 2, 4, 2)
    }

    /// Deterministic pseudo-random vectors with components in `[-1, 1]`,
    /// driven by a 64-bit linear congruential generator seeded with `seed`.
    fn make_random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut state = seed;
        (0..n)
            .map(|_| {
                (0..dim)
                    .map(|_| {
                        state = state
                            .wrapping_mul(6364136223846793005)
                            .wrapping_add(1442695040888963407);
                        ((state >> 33) as f64) / (u32::MAX as f64) * 2.0 - 1.0
                    })
                    .collect()
            })
            .collect()
    }

    /// Creates an index with [`default_config`] and trains it on `n`
    /// 8-dimensional vectors generated from `seed`; returns both.
    fn trained_index(n: usize, seed: u64) -> Result<(IvfPqIndex, Vec<Vec<f64>>)> {
        let mut index = IvfPqIndex::new(default_config())?;
        let vectors = make_random_vectors(n, 8, seed);
        index.train(&vectors)?;
        Ok((index, vectors))
    }

    #[test]
    fn test_new_valid_config() {
        assert!(IvfPqIndex::new(default_config()).is_ok());
    }

    #[test]
    fn test_new_m_not_divides_dimension() {
        // 7 is not divisible by 3.
        let config = make_config(7, 4, 3, 4, 2);
        assert!(matches!(
            IvfPqIndex::new(config),
            Err(IvfPqError::InvalidConfig(_))
        ));
    }

    #[test]
    fn test_new_m_zero() {
        let config = make_config(8, 4, 0, 4, 2);
        assert!(matches!(
            IvfPqIndex::new(config),
            Err(IvfPqError::InvalidConfig(_))
        ));
    }

    #[test]
    fn test_new_nlist_zero() {
        let config = make_config(8, 0, 2, 4, 2);
        assert!(matches!(
            IvfPqIndex::new(config),
            Err(IvfPqError::InvalidConfig(_))
        ));
    }

    #[test]
    fn test_new_nprobe_gt_nlist() {
        // nprobe = 5 exceeds nlist = 2.
        let config = make_config(8, 2, 2, 4, 5);
        assert!(matches!(
            IvfPqIndex::new(config),
            Err(IvfPqError::InvalidConfig(_))
        ));
    }

    #[test]
    fn test_new_dimension_zero() {
        // `m` stays valid so that the dimension check, not the `m == 0` check,
        // is what rejects this configuration.
        let config = make_config(0, 4, 2, 4, 2);
        assert!(matches!(
            IvfPqIndex::new(config),
            Err(IvfPqError::InvalidConfig(msg)) if msg.contains("dimension must be > 0")
        ));
    }

    #[test]
    fn test_not_trained_initially() -> Result<()> {
        let index = IvfPqIndex::new(default_config())?;
        assert!(!index.is_trained());
        Ok(())
    }

    #[test]
    fn test_train_basic() -> Result<()> {
        let (index, _) = trained_index(20, 42)?;
        assert!(index.is_trained());
        Ok(())
    }

    #[test]
    fn test_train_too_few_vectors() -> Result<()> {
        let mut index = IvfPqIndex::new(default_config())?;
        let result = index.train(&[]);
        assert!(matches!(result, Err(IvfPqError::InsufficientData(_))));
        Ok(())
    }

    #[test]
    fn test_train_dimension_mismatch() -> Result<()> {
        let mut index = IvfPqIndex::new(default_config())?;
        // Three components where eight are configured.
        let vectors = vec![vec![1.0, 2.0, 3.0]];
        let result = index.train(&vectors);
        assert!(matches!(result, Err(IvfPqError::DimensionMismatch { .. })));
        Ok(())
    }

    #[test]
    fn test_add_before_training_error() -> Result<()> {
        let mut index = IvfPqIndex::new(default_config())?;
        let result = index.add(&[0.0; 8]);
        assert!(matches!(result, Err(IvfPqError::NotTrained)));
        Ok(())
    }

    #[test]
    fn test_add_after_training() -> Result<()> {
        let (mut index, vectors) = trained_index(20, 1)?;
        let id = index.add(&vectors[0])?;
        assert_eq!(id, 0);
        assert_eq!(index.size(), 1);
        Ok(())
    }

    #[test]
    fn test_add_dimension_mismatch() -> Result<()> {
        let (mut index, _) = trained_index(20, 2)?;
        let result = index.add(&[1.0, 2.0]);
        assert!(matches!(result, Err(IvfPqError::DimensionMismatch { .. })));
        Ok(())
    }

    #[test]
    fn test_add_batch() -> Result<()> {
        let (mut index, _) = trained_index(20, 3)?;
        let add_data = make_random_vectors(5, 8, 4);
        let ids = index.add_batch(&add_data)?;
        assert_eq!(ids.len(), 5);
        assert_eq!(index.size(), 5);
        Ok(())
    }

    #[test]
    fn test_size_starts_at_zero() -> Result<()> {
        // Training alone must not store any vectors.
        let (index, _) = trained_index(20, 5)?;
        assert_eq!(index.size(), 0);
        Ok(())
    }

    #[test]
    fn test_size_after_adding() -> Result<()> {
        let (mut index, vectors) = trained_index(20, 6)?;
        for v in &vectors {
            index.add(v)?;
        }
        assert_eq!(index.size(), 20);
        Ok(())
    }

    #[test]
    fn test_search_before_training_error() -> Result<()> {
        let index = IvfPqIndex::new(default_config())?;
        let result = index.search(&[0.0; 8], 5);
        assert!(matches!(result, Err(IvfPqError::NotTrained)));
        Ok(())
    }

    #[test]
    fn test_search_empty_index() -> Result<()> {
        let (index, _) = trained_index(20, 7)?;
        let results = index.search(&[0.0; 8], 5)?;
        assert!(results.is_empty());
        Ok(())
    }

    #[test]
    fn test_search_returns_k_results() -> Result<()> {
        let (mut index, vectors) = trained_index(50, 8)?;
        for v in &vectors {
            index.add(v)?;
        }
        let results = index.search(&[0.0; 8], 10)?;
        assert!(results.len() <= 10);
        assert!(!results.is_empty());
        Ok(())
    }

    #[test]
    fn test_search_sorted_by_distance() -> Result<()> {
        let (mut index, vectors) = trained_index(30, 9)?;
        for v in &vectors {
            index.add(v)?;
        }
        let results = index.search(&[0.0; 8], 10)?;
        for pair in results.windows(2) {
            assert!(
                pair[0].1 <= pair[1].1,
                "Results not sorted: {} > {}",
                pair[0].1,
                pair[1].1
            );
        }
        Ok(())
    }

    #[test]
    fn test_search_dimension_mismatch() -> Result<()> {
        let (index, _) = trained_index(20, 10)?;
        let result = index.search(&[1.0, 2.0], 5);
        assert!(matches!(result, Err(IvfPqError::DimensionMismatch { .. })));
        Ok(())
    }

    #[test]
    fn test_l2_distance_zero() {
        let a = vec![1.0, 2.0, 3.0];
        assert!(IvfPqIndex::l2_distance(&a, &a) < 1e-10);
    }

    #[test]
    fn test_l2_distance_known() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let d = IvfPqIndex::l2_distance(&a, &b);
        // The function returns the squared distance: 3^2 + 4^2 = 25.
        assert!((d - 25.0).abs() < 1e-10);
    }

    #[test]
    fn test_error_display() {
        let e = IvfPqError::DimensionMismatch {
            expected: 8,
            got: 4,
        };
        assert!(format!("{e}").contains("8"));
        let e2 = IvfPqError::NotTrained;
        assert!(!format!("{e2}").is_empty());
        let e3 = IvfPqError::InvalidConfig("m".to_string());
        assert!(format!("{e3}").contains("m"));
        let e4 = IvfPqError::InsufficientData("need more".to_string());
        assert!(format!("{e4}").contains("need more"));
    }

    #[test]
    fn test_config_validation_valid() {
        assert!(default_config().validate().is_ok());
    }

    #[test]
    fn test_config_validation_k_per_sub_zero() {
        let config = IvfPqConfig {
            nlist: 4,
            m: 2,
            k_per_sub: 0,
            nprobe: 2,
            dimension: 8,
        };
        assert!(matches!(
            config.validate(),
            Err(IvfPqError::InvalidConfig(_))
        ));
    }
}