1use anyhow::{anyhow, bail, Result};
49use serde::{Deserialize, Serialize};
50use std::cmp::Reverse;
51use std::collections::BinaryHeap;
52
/// Tuning parameters for a product-quantization index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PqConfig {
 /// Total dimensionality of indexed vectors; must be > 0 and divisible by
 /// `num_sub_vectors`.
 pub dimension: usize,
 /// Number of sub-vector slots the dimension is split into (one codebook each).
 pub num_sub_vectors: usize,
 /// Centroids per codebook; must be in `1..=65536` because codes are `u16`.
 pub num_centroids: usize,
 /// Number of k-means passes run during training.
 pub training_iterations: usize,
 /// NOTE(review): not read anywhere in this file; presumably reserved for an
 /// IVF-style probe count — confirm before relying on it.
 pub num_probes: usize,
}
71
72impl Default for PqConfig {
73 fn default() -> Self {
74 Self {
75 dimension: 128,
76 num_sub_vectors: 8,
77 num_centroids: 256,
78 training_iterations: 20,
79 num_probes: 0,
80 }
81 }
82}
83
/// One codebook for a single sub-vector slot: a centroid table learned by
/// k-means over that slot's slices of the training data.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SubCodebook {
 /// Centroid vectors, each of length `sub_dim`; a `u16` code indexes this table.
 centroids: Vec<Vec<f32>>,
 /// Dimensionality of each centroid (`dimension / num_sub_vectors`).
 sub_dim: usize,
}
96
97impl SubCodebook {
98 fn new(sub_dim: usize, num_centroids: usize) -> Self {
99 Self {
100 centroids: vec![vec![0.0; sub_dim]; num_centroids],
101 sub_dim,
102 }
103 }
104
105 fn encode(&self, sub_vec: &[f32]) -> u16 {
107 let mut best_idx = 0u16;
108 let mut best_dist = f32::MAX;
109 for (i, centroid) in self.centroids.iter().enumerate() {
110 let dist = l2_sq(sub_vec, centroid);
111 if dist < best_dist {
112 best_dist = dist;
113 best_idx = i as u16;
114 }
115 }
116 best_idx
117 }
118
119 fn decode(&self, code: u16) -> &[f32] {
121 &self.centroids[code as usize]
122 }
123
124 fn build_distance_table(&self, query_sub: &[f32]) -> Vec<f32> {
126 self.centroids.iter().map(|c| l2_sq(query_sub, c)).collect()
127 }
128}
129
/// A stored vector, compressed to one centroid code per sub-vector slot.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PqCode {
 /// Caller-supplied identifier returned from `search`/`reconstruct`.
 id: u64,
 /// One codebook index per sub-vector slot (`num_sub_vectors` entries).
 codes: Vec<u16>,
}
140
/// Product-quantization (PQ) vector index: each vector is split into
/// sub-vectors, each encoded as the id of its nearest codebook centroid, and
/// queries are answered via per-query distance lookup tables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizationIndex {
 config: PqConfig,
 /// One codebook per sub-vector slot (`config.num_sub_vectors` of them).
 codebooks: Vec<SubCodebook>,
 /// Encoded vectors in insertion order.
 entries: Vec<PqCode>,
 /// Set by `train`; `add`/`search` refuse to run until it is true.
 trained: bool,
 /// Cached `config.dimension / config.num_sub_vectors`.
 sub_dim: usize,
}
154
155impl ProductQuantizationIndex {
156 pub fn new(config: PqConfig) -> Result<Self> {
158 if config.dimension == 0 {
159 bail!("dimension must be > 0");
160 }
161 if config.num_sub_vectors == 0 {
162 bail!("num_sub_vectors must be > 0");
163 }
164 if config.dimension % config.num_sub_vectors != 0 {
165 bail!(
166 "dimension ({}) must be divisible by num_sub_vectors ({})",
167 config.dimension,
168 config.num_sub_vectors
169 );
170 }
171 if config.num_centroids == 0 || config.num_centroids > 65536 {
172 bail!("num_centroids must be in 1..=65536");
173 }
174 let sub_dim = config.dimension / config.num_sub_vectors;
175 let codebooks = (0..config.num_sub_vectors)
176 .map(|_| SubCodebook::new(sub_dim, config.num_centroids))
177 .collect();
178 Ok(Self {
179 config,
180 codebooks,
181 entries: Vec::new(),
182 trained: false,
183 sub_dim,
184 })
185 }
186
187 pub fn train(&mut self, training_data: &[Vec<f32>]) -> Result<()> {
189 if training_data.is_empty() {
190 bail!("training data is empty");
191 }
192 for (i, v) in training_data.iter().enumerate() {
193 if v.len() != self.config.dimension {
194 bail!(
195 "training vector {i} has dimension {} but expected {}",
196 v.len(),
197 self.config.dimension
198 );
199 }
200 }
201
202 for m in 0..self.config.num_sub_vectors {
203 let start = m * self.sub_dim;
204 let end = start + self.sub_dim;
205
206 let sub_vectors: Vec<Vec<f32>> = training_data
208 .iter()
209 .map(|v| v[start..end].to_vec())
210 .collect();
211
212 let centroids = kmeans(
214 &sub_vectors,
215 self.config.num_centroids,
216 self.config.training_iterations,
217 self.sub_dim,
218 );
219 self.codebooks[m].centroids = centroids;
220 }
221
222 self.trained = true;
223 Ok(())
224 }
225
226 pub fn is_trained(&self) -> bool {
228 self.trained
229 }
230
231 pub fn add(&mut self, id: u64, vector: &[f32]) -> Result<()> {
233 if !self.trained {
234 bail!("index must be trained before adding vectors");
235 }
236 if vector.len() != self.config.dimension {
237 bail!(
238 "vector dimension {} != expected {}",
239 vector.len(),
240 self.config.dimension
241 );
242 }
243
244 let mut codes = Vec::with_capacity(self.config.num_sub_vectors);
245 for m in 0..self.config.num_sub_vectors {
246 let start = m * self.sub_dim;
247 let end = start + self.sub_dim;
248 let code = self.codebooks[m].encode(&vector[start..end]);
249 codes.push(code);
250 }
251
252 self.entries.push(PqCode { id, codes });
253 Ok(())
254 }
255
256 pub fn len(&self) -> usize {
258 self.entries.len()
259 }
260
261 pub fn is_empty(&self) -> bool {
263 self.entries.is_empty()
264 }
265
266 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
269 if !self.trained {
270 bail!("index must be trained before searching");
271 }
272 if query.len() != self.config.dimension {
273 bail!(
274 "query dimension {} != expected {}",
275 query.len(),
276 self.config.dimension
277 );
278 }
279 if k == 0 {
280 return Ok(Vec::new());
281 }
282
283 let distance_tables: Vec<Vec<f32>> = (0..self.config.num_sub_vectors)
285 .map(|m| {
286 let start = m * self.sub_dim;
287 let end = start + self.sub_dim;
288 self.codebooks[m].build_distance_table(&query[start..end])
289 })
290 .collect();
291
292 let mut heap: BinaryHeap<Reverse<(OrderedF32, u64)>> = BinaryHeap::new();
294 for entry in &self.entries {
295 let mut dist = 0.0f32;
296 for (m, code) in entry.codes.iter().enumerate() {
297 dist += distance_tables[m][*code as usize];
298 }
299 heap.push(Reverse((OrderedF32(dist), entry.id)));
300 }
301
302 let mut results = Vec::with_capacity(k.min(heap.len()));
303 for _ in 0..k {
304 if let Some(Reverse((OrderedF32(d), id))) = heap.pop() {
305 results.push((id, d));
306 } else {
307 break;
308 }
309 }
310 Ok(results)
311 }
312
313 pub fn reconstruct(&self, id: u64) -> Result<Vec<f32>> {
315 let entry = self
316 .entries
317 .iter()
318 .find(|e| e.id == id)
319 .ok_or_else(|| anyhow!("id {id} not found in index"))?;
320
321 let mut vector = Vec::with_capacity(self.config.dimension);
322 for (m, code) in entry.codes.iter().enumerate() {
323 vector.extend_from_slice(self.codebooks[m].decode(*code));
324 }
325 Ok(vector)
326 }
327
328 pub fn clear(&mut self) {
330 self.entries.clear();
331 }
332
333 pub fn config(&self) -> &PqConfig {
335 &self.config
336 }
337
338 pub fn compression_ratio(&self) -> f64 {
340 if self.entries.is_empty() {
341 return 0.0;
342 }
343 let original_bytes = self.config.dimension * 4; let encoded_bytes = self.config.num_sub_vectors * 2; original_bytes as f64 / encoded_bytes as f64
346 }
347}
348
/// Plain Lloyd's k-means over `data`, returning exactly `k` centroids of
/// length `dim`. Seeds come from the first `min(k, data.len())` points; any
/// remaining slots are padded with zero vectors. Empty clusters keep their
/// previous centroid. Runs a fixed number of `iterations` (no convergence test).
fn kmeans(data: &[Vec<f32>], k: usize, iterations: usize, dim: usize) -> Vec<Vec<f32>> {
    // Seed with the leading data points, padding with zeros when data is scarce.
    let seeded = k.min(data.len());
    let mut centers: Vec<Vec<f32>> = Vec::with_capacity(k);
    centers.extend(data.iter().take(seeded).cloned());
    centers.resize(k, vec![0.0; dim]);

    for _ in 0..iterations {
        // Assignment step: each point joins its nearest center (first wins ties).
        let mut members: Vec<Vec<usize>> = vec![Vec::new(); k];
        for (point_idx, point) in data.iter().enumerate() {
            let mut nearest = 0usize;
            let mut nearest_dist = f32::MAX;
            for (center_idx, center) in centers.iter().enumerate() {
                let dist: f32 = point
                    .iter()
                    .zip(center.iter())
                    .map(|(p, c)| (p - c) * (p - c))
                    .sum();
                if dist < nearest_dist {
                    nearest_dist = dist;
                    nearest = center_idx;
                }
            }
            members[nearest].push(point_idx);
        }

        // Update step: move every non-empty center to the mean of its members.
        for (center_idx, assigned) in members.iter().enumerate() {
            if assigned.is_empty() {
                continue; // keep the stale center rather than divide by zero
            }
            let mut mean = vec![0.0f32; dim];
            for &point_idx in assigned {
                for (component, value) in data[point_idx].iter().enumerate() {
                    mean[component] += value;
                }
            }
            let count = assigned.len() as f32;
            for component in mean.iter_mut() {
                *component /= count;
            }
            centers[center_idx] = mean;
        }
    }

    centers
}
399
/// Squared Euclidean distance between `a` and `b`. Pairs are consumed via
/// `zip`, so extra elements of the longer slice are silently ignored.
fn l2_sq(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        acc += diff * diff;
    }
    acc
}
413
/// Total-order wrapper over `f32` so distances can be stored in a
/// `BinaryHeap`; `cmp` collapses incomparable values (NaN) to `Equal`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
struct OrderedF32(f32);
417
// Required by `Ord`. NOTE(review): the derived `PartialEq` on f32 makes
// NaN != NaN, so `Eq`'s reflexivity is technically violated for NaN; the
// distances wrapped here are sums of squares — confirm inputs stay finite.
impl Eq for OrderedF32 {}
419
// Canonical delegation to `Ord::cmp` so the two orderings can never disagree.
impl PartialOrd for OrderedF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
425
426impl Ord for OrderedF32 {
427 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
428 self.0
429 .partial_cmp(&other.0)
430 .unwrap_or(std::cmp::Ordering::Equal)
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    // --- helpers -----------------------------------------------------------

    /// Builds a config with a small iteration count so tests train quickly.
    fn default_config(dim: usize, m: usize, k: usize) -> PqConfig {
        PqConfig {
            dimension: dim,
            num_sub_vectors: m,
            num_centroids: k,
            training_iterations: 5,
            num_probes: 0,
        }
    }

    /// Deterministic ramp data: vector `i`, component `d` = (i*dim + d) * 0.1.
    fn make_training_data(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| (0..dim).map(|d| (i * dim + d) as f32 * 0.1).collect())
            .collect()
    }

    /// A freshly trained index with at least 4 training vectors.
    fn trained_index(dim: usize, m: usize, k: usize) -> ProductQuantizationIndex {
        let config = default_config(dim, m, k);
        let mut idx = ProductQuantizationIndex::new(config).expect("new");
        let data = make_training_data(k.max(4), dim);
        idx.train(&data).expect("train");
        idx
    }

    // --- construction / validation -----------------------------------------

    #[test]
    fn test_new_valid_config() {
        let idx = ProductQuantizationIndex::new(default_config(8, 4, 4));
        assert!(idx.is_ok());
    }

    #[test]
    fn test_new_zero_dimension() {
        let config = PqConfig {
            dimension: 0,
            ..Default::default()
        };
        assert!(ProductQuantizationIndex::new(config).is_err());
    }

    #[test]
    fn test_new_zero_sub_vectors() {
        let config = PqConfig {
            num_sub_vectors: 0,
            ..Default::default()
        };
        assert!(ProductQuantizationIndex::new(config).is_err());
    }

    #[test]
    fn test_new_indivisible_dimension() {
        let config = PqConfig {
            dimension: 7,
            num_sub_vectors: 4,
            ..Default::default()
        };
        assert!(ProductQuantizationIndex::new(config).is_err());
    }

    #[test]
    fn test_new_zero_centroids() {
        let config = PqConfig {
            num_centroids: 0,
            ..Default::default()
        };
        assert!(ProductQuantizationIndex::new(config).is_err());
    }

    // --- training -----------------------------------------------------------

    #[test]
    fn test_train_sets_trained_flag() {
        let mut idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        assert!(!idx.is_trained());
        let data = make_training_data(10, 8);
        idx.train(&data).expect("train");
        assert!(idx.is_trained());
    }

    #[test]
    fn test_train_empty_data_fails() {
        let mut idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        assert!(idx.train(&[]).is_err());
    }

    #[test]
    fn test_train_wrong_dimension_fails() {
        let mut idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        let data = vec![vec![1.0, 2.0]];
        assert!(idx.train(&data).is_err());
    }

    // --- adding vectors ------------------------------------------------------

    #[test]
    fn test_add_before_training_fails() {
        let mut idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        assert!(idx.add(0, &[1.0; 8]).is_err());
    }

    #[test]
    fn test_add_wrong_dimension_fails() {
        let mut idx = trained_index(8, 4, 4);
        assert!(idx.add(0, &[1.0; 4]).is_err());
    }

    #[test]
    fn test_add_and_len() {
        let mut idx = trained_index(8, 4, 4);
        assert!(idx.is_empty());
        idx.add(0, &[1.0; 8]).expect("add");
        assert_eq!(idx.len(), 1);
        idx.add(1, &[2.0; 8]).expect("add");
        assert_eq!(idx.len(), 2);
    }

    // --- search --------------------------------------------------------------

    #[test]
    fn test_search_before_training_fails() {
        let idx = ProductQuantizationIndex::new(default_config(8, 4, 4)).expect("new");
        assert!(idx.search(&[1.0; 8], 1).is_err());
    }

    #[test]
    fn test_search_wrong_dimension_fails() {
        let idx = trained_index(8, 4, 4);
        assert!(idx.search(&[1.0; 4], 1).is_err());
    }

    #[test]
    fn test_search_k_zero_returns_empty() {
        let idx = trained_index(8, 4, 4);
        let results = idx.search(&[1.0; 8], 0).expect("search");
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_empty_index_returns_empty() {
        let idx = trained_index(8, 4, 4);
        let results = idx.search(&[1.0; 8], 5).expect("search");
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_finds_nearest() {
        let mut idx = trained_index(8, 4, 4);
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let v2 = vec![100.0, 200.0, 300.0, 400.0, 500.0, 600.0, 700.0, 800.0];
        idx.add(10, &v1).expect("add");
        idx.add(20, &v2).expect("add");

        let query = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let results = idx.search(&query, 1).expect("search");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 10);
    }

    #[test]
    fn test_search_returns_sorted_by_distance() {
        let mut idx = trained_index(8, 4, 4);
        let v1 = vec![1.0; 8];
        let v2 = vec![2.0; 8];
        let v3 = vec![10.0; 8];
        idx.add(1, &v1).expect("add");
        idx.add(2, &v2).expect("add");
        idx.add(3, &v3).expect("add");

        let results = idx.search(&[1.0; 8], 3).expect("search");
        assert_eq!(results.len(), 3);
        assert!(results[0].1 <= results[1].1);
        assert!(results[1].1 <= results[2].1);
    }

    #[test]
    fn test_search_k_larger_than_index() {
        let mut idx = trained_index(8, 4, 4);
        idx.add(1, &[1.0; 8]).expect("add");
        let results = idx.search(&[1.0; 8], 100).expect("search");
        assert_eq!(results.len(), 1);
    }

    // --- reconstruction ------------------------------------------------------

    #[test]
    fn test_reconstruct_existing_id() {
        let mut idx = trained_index(8, 4, 4);
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        idx.add(42, &v).expect("add");
        let reconstructed = idx.reconstruct(42).expect("reconstruct");
        assert_eq!(reconstructed.len(), 8);
    }

    #[test]
    fn test_reconstruct_missing_id() {
        let idx = trained_index(8, 4, 4);
        assert!(idx.reconstruct(999).is_err());
    }

    // --- maintenance / accessors ---------------------------------------------

    #[test]
    fn test_clear() {
        let mut idx = trained_index(8, 4, 4);
        idx.add(1, &[1.0; 8]).expect("add");
        assert_eq!(idx.len(), 1);
        idx.clear();
        assert!(idx.is_empty());
        // Clearing entries must not discard the trained codebooks.
        assert!(idx.is_trained());
    }

    #[test]
    fn test_compression_ratio_empty() {
        let idx = trained_index(8, 4, 4);
        assert_eq!(idx.compression_ratio(), 0.0);
    }

    #[test]
    fn test_compression_ratio_non_empty() {
        let mut idx = trained_index(8, 4, 4);
        idx.add(0, &[1.0; 8]).expect("add");
        let ratio = idx.compression_ratio();
        // 8 dims * 4 bytes = 32 raw bytes vs 4 codes * 2 bytes = 8 encoded.
        assert!((ratio - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_config_accessor() {
        let idx = ProductQuantizationIndex::new(default_config(16, 4, 8)).expect("new");
        assert_eq!(idx.config().dimension, 16);
        assert_eq!(idx.config().num_sub_vectors, 4);
    }

    #[test]
    fn test_default_config() {
        let config = PqConfig::default();
        assert_eq!(config.dimension, 128);
        assert_eq!(config.num_sub_vectors, 8);
        assert_eq!(config.num_centroids, 256);
    }

    // --- helper functions ----------------------------------------------------

    #[test]
    fn test_kmeans_basic() {
        let data = vec![
            vec![1.0, 0.0],
            vec![1.1, 0.0],
            vec![0.0, 1.0],
            vec![0.0, 1.1],
        ];
        let centroids = kmeans(&data, 2, 10, 2);
        assert_eq!(centroids.len(), 2);
    }

    #[test]
    fn test_kmeans_more_k_than_data() {
        let data = vec![vec![1.0], vec![2.0]];
        let centroids = kmeans(&data, 5, 3, 1);
        assert_eq!(centroids.len(), 5);
    }

    #[test]
    fn test_l2_sq_identical() {
        assert_eq!(l2_sq(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_l2_sq_known() {
        // (3-1)^2 + (4-1)^2 = 4 + 9 = 13
        let dist = l2_sq(&[3.0, 4.0], &[1.0, 1.0]);
        assert!((dist - 13.0).abs() < 1e-6);
    }

    #[test]
    fn test_ordered_f32_ordering() {
        let a = OrderedF32(1.0);
        let b = OrderedF32(2.0);
        assert!(a < b);
    }

    // --- end-to-end scenarios ------------------------------------------------

    #[test]
    fn test_multi_add_and_search() {
        let mut idx = trained_index(8, 4, 4);
        for i in 0..20_u64 {
            let v: Vec<f32> = (0..8).map(|d| (i * 8 + d) as f32).collect();
            idx.add(i, &v).expect("add");
        }
        assert_eq!(idx.len(), 20);
        let results = idx.search(&[0.0; 8], 5).expect("search");
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_retrain_resets_codebooks() {
        let mut idx = trained_index(8, 4, 4);
        idx.add(0, &[1.0; 8]).expect("add");
        let data2 = make_training_data(10, 8);
        idx.train(&data2).expect("retrain");
        // Entries are retained across a retrain (their codes are not re-derived).
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn test_single_dimension_subvectors() {
        let config = default_config(4, 4, 2);
        let mut idx = ProductQuantizationIndex::new(config).expect("new");
        let data = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
        idx.train(&data).expect("train");
        idx.add(0, &[1.0, 2.0, 3.0, 4.0]).expect("add");
        let results = idx.search(&[1.0, 2.0, 3.0, 4.0], 1).expect("search");
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_single_centroid_perfect_encode() {
        let config = default_config(4, 2, 1);
        let mut idx = ProductQuantizationIndex::new(config).expect("new");
        let data = vec![vec![1.0, 2.0, 3.0, 4.0]];
        idx.train(&data).expect("train");
        idx.add(0, &[1.0, 2.0, 3.0, 4.0]).expect("add");
        let recon = idx.reconstruct(0).expect("reconstruct");
        assert_eq!(recon.len(), 4);
    }

    #[test]
    fn test_large_dimension() {
        let config = default_config(64, 8, 4);
        let mut idx = ProductQuantizationIndex::new(config).expect("new");
        let data = make_training_data(10, 64);
        idx.train(&data).expect("train");
        idx.add(0, &vec![0.5; 64]).expect("add");
        let results = idx.search(&[0.5; 64], 1).expect("search");
        assert_eq!(results.len(), 1);
    }
}