use crate::{Vector, VectorIndex};
use anyhow::{anyhow, Result};
use std::collections::HashMap;

/// Configuration for a product quantization (PQ) index.
#[derive(Debug, Clone, PartialEq)]
pub struct PQConfig {
    /// Number of subquantizers the vector is split into.
    pub n_subquantizers: usize,
    /// Number of centroids per subquantizer (should equal 2^n_bits).
    pub n_centroids: usize,
    /// Bits per code; codes are stored as `u8`, so values above 8 cannot be encoded.
    pub n_bits: usize,
    /// Maximum k-means iterations per subquantizer.
    pub max_iterations: usize,
    /// Convergence threshold for k-means training.
    pub convergence_threshold: f32,
    /// Optional seed for deterministic centroid initialization.
    pub seed: Option<u64>,
    /// Enable residual quantization on top of the primary codes.
    pub enable_residual_quantization: bool,
    /// Number of residual quantization levels.
    pub residual_levels: usize,
    /// Enable multiple independent codebooks.
    pub enable_multi_codebook: bool,
    /// Number of codebooks when multi-codebook quantization is enabled.
    pub num_codebooks: usize,
    /// Enable precomputed tables for symmetric (code-to-code) distances.
    pub enable_symmetric_distance: bool,
}

impl Default for PQConfig {
    fn default() -> Self {
        Self {
            n_subquantizers: 8,
            n_centroids: 256,
            n_bits: 8,
            max_iterations: 50,
            convergence_threshold: 1e-4,
            seed: None,
            enable_residual_quantization: false,
            residual_levels: 2,
            enable_multi_codebook: false,
            num_codebooks: 2,
            enable_symmetric_distance: false,
        }
    }
}

impl PQConfig {
    /// Create a config with `n_subquantizers` subquantizers and `2^n_bits` centroids each.
    pub fn with_bits(n_subquantizers: usize, n_bits: usize) -> Self {
        Self {
            n_subquantizers,
            n_centroids: 1 << n_bits,
            n_bits,
            max_iterations: 50,
            convergence_threshold: 1e-4,
            seed: None,
            enable_residual_quantization: false,
            residual_levels: 2,
            enable_multi_codebook: false,
            num_codebooks: 2,
            enable_symmetric_distance: false,
        }
    }

    /// Create a config with residual quantization enabled.
    pub fn with_residual_quantization(
        n_subquantizers: usize,
        n_bits: usize,
        residual_levels: usize,
    ) -> Self {
        Self {
            n_subquantizers,
            n_centroids: 1 << n_bits,
            n_bits,
            enable_residual_quantization: true,
            residual_levels,
            ..Default::default()
        }
    }

    /// Create a config with multi-codebook quantization enabled.
    pub fn with_multi_codebook(
        n_subquantizers: usize,
        n_bits: usize,
        num_codebooks: usize,
    ) -> Self {
        Self {
            n_subquantizers,
            n_centroids: 1 << n_bits,
            n_bits,
            enable_multi_codebook: true,
            num_codebooks,
            ..Default::default()
        }
    }

    /// Create a config with residual quantization, multi-codebook quantization,
    /// and symmetric distance tables all enabled.
    pub fn enhanced(n_subquantizers: usize, n_bits: usize) -> Self {
        Self {
            n_subquantizers,
            n_centroids: 1 << n_bits,
            n_bits,
            enable_residual_quantization: true,
            residual_levels: 2,
            enable_multi_codebook: true,
            num_codebooks: 2,
            enable_symmetric_distance: true,
            ..Default::default()
        }
    }

    /// Check that the configuration is internally consistent.
    pub fn validate(&self) -> Result<()> {
        if self.n_centroids != (1 << self.n_bits) {
            return Err(anyhow!(
                "n_centroids {} doesn't match 2^n_bits ({})",
                self.n_centroids,
                1 << self.n_bits
            ));
        }
        if self.n_subquantizers == 0 {
            return Err(anyhow!("n_subquantizers must be greater than 0"));
        }
        if self.n_bits == 0 || self.n_bits > 16 {
            return Err(anyhow!("n_bits must be between 1 and 16"));
        }
        if self.enable_residual_quantization && self.residual_levels == 0 {
            return Err(anyhow!(
                "residual_levels must be greater than 0 when residual quantization is enabled"
            ));
        }
        if self.enable_multi_codebook && self.num_codebooks < 2 {
            return Err(anyhow!(
                "num_codebooks must be at least 2 when multi-codebook quantization is enabled"
            ));
        }
        Ok(())
    }
}
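
// Illustrative sketch (not part of the original API surface): how a caller
// might pick a configuration and validate it before training an index. The
// function name `example_enhanced_config` is hypothetical.
#[allow(dead_code)]
fn example_enhanced_config() -> Result<PQConfig> {
    // 8 subquantizers with 8-bit codes -> 256 centroids per subquantizer.
    let config = PQConfig::enhanced(8, 8);
    config.validate()?;
    Ok(config)
}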

/// A single subquantizer: a k-means codebook over one contiguous slice of the vector.
#[derive(Debug, Clone)]
struct SubQuantizer {
    /// First dimension (inclusive) covered by this subquantizer.
    start_dim: usize,
    /// Last dimension (exclusive) covered by this subquantizer.
    end_dim: usize,
    /// Trained centroids, one per code value.
    centroids: Vec<Vec<f32>>,
}

impl SubQuantizer {
    fn new(start_dim: usize, end_dim: usize, n_centroids: usize) -> Self {
        Self {
            start_dim,
            end_dim,
            centroids: Vec::with_capacity(n_centroids),
        }
    }

    /// Copy this subquantizer's slice out of a full vector.
    fn extract_subvector(&self, vector: &[f32]) -> Vec<f32> {
        vector[self.start_dim..self.end_dim].to_vec()
    }

    /// Train the codebook with k-means (k-means++ initialization).
    fn train(&mut self, subvectors: &[Vec<f32>], config: &PQConfig) -> Result<()> {
        if subvectors.is_empty() {
            return Err(anyhow!("Cannot train subquantizer with empty data"));
        }

        let dims = subvectors[0].len();

        self.centroids = self.initialize_centroids_kmeans_plus_plus(subvectors, config)?;

        let mut iteration = 0;
        let mut prev_error = f32::INFINITY;

        while iteration < config.max_iterations {
            // Assign each subvector to its nearest centroid.
            let mut clusters: Vec<Vec<&Vec<f32>>> = vec![Vec::new(); config.n_centroids];

            for subvector in subvectors {
                let nearest_idx = self.find_nearest_centroid(subvector)?;
                clusters[nearest_idx].push(subvector);
            }

            // Recompute centroids; track how far they moved as the convergence signal.
            let mut total_error = 0.0;
            for (i, cluster) in clusters.iter().enumerate() {
                if !cluster.is_empty() {
                    let new_centroid = self.compute_centroid(cluster, dims);
                    total_error += self.euclidean_distance(&self.centroids[i], &new_centroid);
                    self.centroids[i] = new_centroid;
                }
            }

            if (prev_error - total_error).abs() < config.convergence_threshold {
                break;
            }

            prev_error = total_error;
            iteration += 1;
        }

        Ok(())
    }

    /// k-means++ seeding: pick centroids with probability proportional to the
    /// squared distance from the already chosen centroids, using a simple
    /// deterministic LCG derived from the configured seed.
    fn initialize_centroids_kmeans_plus_plus(
        &self,
        subvectors: &[Vec<f32>],
        config: &PQConfig,
    ) -> Result<Vec<Vec<f32>>> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        config.seed.unwrap_or(42).hash(&mut hasher);
        let mut rng_state = hasher.finish();

        let mut centroids = Vec::with_capacity(config.n_centroids);

        // First centroid is chosen uniformly.
        let first_idx = (rng_state as usize) % subvectors.len();
        centroids.push(subvectors[first_idx].clone());

        while centroids.len() < config.n_centroids {
            let mut distances = Vec::with_capacity(subvectors.len());
            let mut sum_distances = 0.0;

            for subvector in subvectors {
                let min_dist = centroids
                    .iter()
                    .map(|c| self.euclidean_distance(subvector, c))
                    .fold(f32::INFINITY, |a, b| a.min(b));

                distances.push(min_dist * min_dist);
                sum_distances += min_dist * min_dist;
            }

            // Advance the LCG and sample a threshold in [0, sum_distances].
            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
            let threshold = (rng_state as f32 / u64::MAX as f32) * sum_distances;

            let mut cumulative = 0.0;
            for (i, &dist) in distances.iter().enumerate() {
                cumulative += dist;
                if cumulative >= threshold {
                    centroids.push(subvectors[i].clone());
                    break;
                }
            }
        }

        Ok(centroids)
    }

    fn compute_centroid(&self, cluster: &[&Vec<f32>], dims: usize) -> Vec<f32> {
        if cluster.is_empty() {
            return vec![0.0; dims];
        }

        let mut sum = vec![0.0; dims];
        for vector in cluster {
            for (i, &val) in vector.iter().enumerate() {
                sum[i] += val;
            }
        }

        let count = cluster.len() as f32;
        for val in &mut sum {
            *val /= count;
        }

        sum
    }

    fn find_nearest_centroid(&self, subvector: &[f32]) -> Result<usize> {
        if self.centroids.is_empty() {
            return Err(anyhow!("No centroids available"));
        }

        let mut min_distance = f32::INFINITY;
        let mut nearest_idx = 0;

        for (i, centroid) in self.centroids.iter().enumerate() {
            let distance = self.euclidean_distance(subvector, centroid);
            if distance < min_distance {
                min_distance = distance;
                nearest_idx = i;
            }
        }

        Ok(nearest_idx)
    }

    fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Encode a subvector as the index of its nearest centroid.
    fn encode(&self, subvector: &[f32]) -> Result<u8> {
        if self.centroids.len() > 256 {
            return Err(anyhow!("Too many centroids for u8 encoding"));
        }

        let idx = self.find_nearest_centroid(subvector)?;
        Ok(idx as u8)
    }

    /// Decode a code back to its centroid.
    fn decode(&self, code: u8) -> Result<Vec<f32>> {
        let idx = code as usize;
        if idx >= self.centroids.len() {
            return Err(anyhow!("Invalid code: {}", code));
        }
        Ok(self.centroids[idx].clone())
    }
}
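
// Illustrative sketch (hypothetical helper, not used by the index): shows how a
// D-dimensional vector is partitioned into `n_subquantizers` contiguous
// subvectors of D / n_subquantizers dimensions each, mirroring the ranges that
// `SubQuantizer` stores as `start_dim..end_dim`.
#[allow(dead_code)]
fn example_subvector_ranges(dims: usize, n_subquantizers: usize) -> Vec<(usize, usize)> {
    let subdim = dims / n_subquantizers;
    (0..n_subquantizers)
        .map(|i| (i * subdim, (i + 1) * subdim))
        .collect()
}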

/// Codes produced by enhanced encoding: primary PQ codes plus optional
/// residual-level codes and per-codebook codes.
#[derive(Debug, Clone)]
pub struct EnhancedCodes {
    /// Primary PQ codes, one per subquantizer.
    pub primary: Vec<u8>,
    /// Residual codes, one `Vec<u8>` per residual level.
    pub residual: Vec<Vec<u8>>,
    /// Codes from each additional codebook.
    pub multi_codebook: Vec<Vec<u8>>,
}

/// Product quantization index with optional residual quantization,
/// multi-codebook quantization, and symmetric distance tables.
#[derive(Debug, Clone)]
pub struct PQIndex {
    config: PQConfig,
    /// Primary subquantizers, one per vector slice.
    subquantizers: Vec<SubQuantizer>,
    /// Subquantizers for each residual level.
    residual_quantizers: Vec<Vec<SubQuantizer>>,
    /// Independently seeded subquantizers for each extra codebook.
    multi_codebook_quantizers: Vec<Vec<SubQuantizer>>,
    /// Encoded vectors as (uri, codes) pairs.
    codes: Vec<(String, Vec<u8>)>,
    /// Encoded residuals per level.
    residual_codes: Vec<Vec<(String, Vec<u8>)>>,
    /// Encoded codes per extra codebook.
    multi_codebook_codes: Vec<Vec<(String, Vec<u8>)>>,
    /// Precomputed centroid-to-centroid distances per subquantizer.
    distance_tables: Option<Vec<Vec<Vec<f32>>>>,
    /// Lookup from URI to position in `codes`.
    uri_to_id: HashMap<String, usize>,
    /// Dimensionality fixed at training time.
    dimensions: Option<usize>,
    /// Whether `train` has completed.
    is_trained: bool,
}

impl PQIndex {
    /// Create an untrained index from a configuration.
    pub fn new(config: PQConfig) -> Self {
        Self {
            residual_quantizers: vec![Vec::new(); config.residual_levels],
            multi_codebook_quantizers: vec![Vec::new(); config.num_codebooks],
            residual_codes: vec![Vec::new(); config.residual_levels],
            multi_codebook_codes: vec![Vec::new(); config.num_codebooks],
            distance_tables: None,
            config,
            subquantizers: Vec::new(),
            codes: Vec::new(),
            uri_to_id: HashMap::new(),
            dimensions: None,
            is_trained: false,
        }
    }

    /// Train all subquantizers (and optional residual / multi-codebook quantizers)
    /// on the given vectors.
    pub fn train(&mut self, training_vectors: &[Vector]) -> Result<()> {
        if training_vectors.is_empty() {
            return Err(anyhow!("Cannot train PQ with empty training set"));
        }

        let dims = training_vectors[0].dimensions;
        if !training_vectors.iter().all(|v| v.dimensions == dims) {
            return Err(anyhow!(
                "All training vectors must have the same dimensions"
            ));
        }

        if dims % self.config.n_subquantizers != 0 {
            return Err(anyhow!(
                "Vector dimensions {} must be divisible by n_subquantizers {}",
                dims,
                self.config.n_subquantizers
            ));
        }

        self.dimensions = Some(dims);
        let subdim = dims / self.config.n_subquantizers;

        // Create one subquantizer per contiguous slice of `subdim` dimensions.
        self.subquantizers.clear();
        for i in 0..self.config.n_subquantizers {
            let start = i * subdim;
            let end = start + subdim;
            self.subquantizers
                .push(SubQuantizer::new(start, end, self.config.n_centroids));
        }

        let training_data: Vec<Vec<f32>> = training_vectors.iter().map(|v| v.as_f32()).collect();

        // Train each subquantizer on its slice of the training data.
        for sq in self.subquantizers.iter_mut() {
            let subvectors: Vec<Vec<f32>> = training_data
                .iter()
                .map(|v| sq.extract_subvector(v))
                .collect();

            sq.train(&subvectors, &self.config)?;
        }

        if self.config.enable_residual_quantization {
            self.train_residual_quantizers(&training_data)?;
        }

        if self.config.enable_multi_codebook {
            self.train_multi_codebook_quantizers(&training_data)?;
        }

        if self.config.enable_symmetric_distance {
            self.build_distance_tables()?;
        }

        self.is_trained = true;
        Ok(())
    }

    /// Train residual quantizers level by level: each level quantizes what the
    /// previous levels failed to reconstruct.
    fn train_residual_quantizers(&mut self, training_data: &[Vec<f32>]) -> Result<()> {
        let subdim = self.dimensions.unwrap() / self.config.n_subquantizers;

        let mut current_residuals = training_data.to_vec();

        for level in 0..self.config.residual_levels {
            if level == 0 {
                // First level: residual of the original vectors against the primary codes.
                for (i, vector) in training_data.iter().enumerate() {
                    let primary_codes = self.encode_primary_vector(vector)?;
                    let reconstructed = self.decode_primary_codes(&primary_codes)?;

                    let residual: Vec<f32> = vector
                        .iter()
                        .zip(reconstructed.iter())
                        .map(|(a, b)| a - b)
                        .collect();
                    current_residuals[i] = residual;
                }
            } else {
                // Later levels: residual of the previous residuals against the previous level's codes.
                for (i, residual) in current_residuals.clone().iter().enumerate() {
                    let residual_codes = self.encode_residual_vector(residual, level - 1)?;
                    let reconstructed_residual =
                        self.decode_residual_codes(&residual_codes, level - 1)?;

                    let new_residual: Vec<f32> = residual
                        .iter()
                        .zip(reconstructed_residual.iter())
                        .map(|(a, b)| a - b)
                        .collect();
                    current_residuals[i] = new_residual;
                }
            }

            self.residual_quantizers[level].clear();
            for i in 0..self.config.n_subquantizers {
                let start = i * subdim;
                let end = start + subdim;
                self.residual_quantizers[level].push(SubQuantizer::new(
                    start,
                    end,
                    self.config.n_centroids,
                ));
            }

            for sq in self.residual_quantizers[level].iter_mut() {
                let subvectors: Vec<Vec<f32>> = current_residuals
                    .iter()
                    .map(|v| sq.extract_subvector(v))
                    .collect();

                sq.train(&subvectors, &self.config)?;
            }
        }

        Ok(())
    }

    /// Train one independent set of subquantizers per codebook, each with a
    /// perturbed seed so the codebooks differ.
    fn train_multi_codebook_quantizers(&mut self, training_data: &[Vec<f32>]) -> Result<()> {
        let subdim = self.dimensions.unwrap() / self.config.n_subquantizers;

        for codebook_idx in 0..self.config.num_codebooks {
            self.multi_codebook_quantizers[codebook_idx].clear();
            for i in 0..self.config.n_subquantizers {
                let start = i * subdim;
                let end = start + subdim;
                self.multi_codebook_quantizers[codebook_idx].push(SubQuantizer::new(
                    start,
                    end,
                    self.config.n_centroids,
                ));
            }

            let mut modified_config = self.config.clone();
            modified_config.seed = self.config.seed.map(|s| s + codebook_idx as u64);

            for sq in self.multi_codebook_quantizers[codebook_idx].iter_mut() {
                let subvectors: Vec<Vec<f32>> = training_data
                    .iter()
                    .map(|v| sq.extract_subvector(v))
                    .collect();

                sq.train(&subvectors, &modified_config)?;
            }
        }

        Ok(())
    }

    /// Precompute centroid-to-centroid distances for each subquantizer so that
    /// symmetric (code-to-code) distances become table lookups.
    fn build_distance_tables(&mut self) -> Result<()> {
        let mut tables = Vec::new();

        for sq_idx in 0..self.config.n_subquantizers {
            let sq = &self.subquantizers[sq_idx];
            let mut sq_table = Vec::new();

            for i in 0..sq.centroids.len() {
                let mut centroid_distances = Vec::new();
                for j in 0..sq.centroids.len() {
                    let distance = sq.euclidean_distance(&sq.centroids[i], &sq.centroids[j]);
                    centroid_distances.push(distance);
                }
                sq_table.push(centroid_distances);
            }
            tables.push(sq_table);
        }

        self.distance_tables = Some(tables);
        Ok(())
    }

    /// Encode a raw f32 vector with the primary subquantizers.
    fn encode_primary_vector(&self, vector: &[f32]) -> Result<Vec<u8>> {
        let mut codes = Vec::with_capacity(self.subquantizers.len());

        for sq in &self.subquantizers {
            let subvec = sq.extract_subvector(vector);
            let code = sq.encode(&subvec)?;
            codes.push(code);
        }

        Ok(codes)
    }

    /// Reconstruct an approximate vector from primary codes.
    fn decode_primary_codes(&self, codes: &[u8]) -> Result<Vec<f32>> {
        let mut reconstructed = Vec::new();

        for (sq, &code) in self.subquantizers.iter().zip(codes.iter()) {
            let subvec = sq.decode(code)?;
            reconstructed.extend(subvec);
        }

        Ok(reconstructed)
    }

    /// Encode a residual vector with the quantizers of the given level.
    fn encode_residual_vector(&self, vector: &[f32], level: usize) -> Result<Vec<u8>> {
        if level >= self.residual_quantizers.len() {
            return Err(anyhow!("Invalid residual level: {}", level));
        }

        let mut codes = Vec::with_capacity(self.residual_quantizers[level].len());

        for sq in &self.residual_quantizers[level] {
            let subvec = sq.extract_subvector(vector);
            let code = sq.encode(&subvec)?;
            codes.push(code);
        }

        Ok(codes)
    }

    /// Reconstruct an approximate residual from codes of the given level.
    fn decode_residual_codes(&self, codes: &[u8], level: usize) -> Result<Vec<f32>> {
        if level >= self.residual_quantizers.len() {
            return Err(anyhow!("Invalid residual level: {}", level));
        }

        let mut reconstructed = Vec::new();

        for (sq, &code) in self.residual_quantizers[level].iter().zip(codes.iter()) {
            let subvec = sq.decode(code)?;
            reconstructed.extend(subvec);
        }

        Ok(reconstructed)
    }

    fn encode_vector(&self, vector: &Vector) -> Result<Vec<u8>> {
        if !self.is_trained {
            return Err(anyhow!("PQ index must be trained before encoding"));
        }

        let vector_f32 = vector.as_f32();
        let mut codes = Vec::with_capacity(self.subquantizers.len());

        for sq in &self.subquantizers {
            let subvec = sq.extract_subvector(&vector_f32);
            let code = sq.encode(&subvec)?;
            codes.push(code);
        }

        Ok(codes)
    }

    fn decode_codes(&self, codes: &[u8]) -> Result<Vector> {
        if codes.len() != self.subquantizers.len() {
            return Err(anyhow!("Invalid code length"));
        }

        let mut reconstructed = Vec::new();

        for (sq, &code) in self.subquantizers.iter().zip(codes.iter()) {
            let subvec = sq.decode(code)?;
            reconstructed.extend(subvec);
        }

        Ok(Vector::new(reconstructed))
    }

    /// Encode a vector into PQ codes.
    pub fn encode(&self, vector: &Vector) -> Result<Vec<u8>> {
        self.encode_vector(vector)
    }

    /// Decode PQ codes into an approximate vector.
    pub fn decode(&self, codes: &[u8]) -> Result<Vector> {
        self.decode_codes(codes)
    }

    /// Encode and immediately decode a vector, returning its PQ approximation.
    pub fn reconstruct(&self, vector: &Vector) -> Result<Vector> {
        let codes = self.encode_vector(vector)?;
        self.decode_codes(&codes)
    }

    /// Asymmetric distance: exact query subvectors against encoded centroids.
    fn asymmetric_distance(&self, query: &Vector, codes: &[u8]) -> Result<f32> {
        let query_f32 = query.as_f32();
        let mut total_distance = 0.0;

        for (sq, &code) in self.subquantizers.iter().zip(codes.iter()) {
            let query_subvec = sq.extract_subvector(&query_f32);
            let centroid = &sq.centroids[code as usize];

            let dist: f32 = query_subvec
                .iter()
                .zip(centroid.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum();

            total_distance += dist;
        }

        Ok(total_distance.sqrt())
    }

    /// Encode a vector with the primary quantizers plus any enabled residual
    /// levels and additional codebooks.
    fn encode_vector_enhanced(&self, vector: &Vector) -> Result<EnhancedCodes> {
        if !self.is_trained {
            return Err(anyhow!("PQ index must be trained before encoding"));
        }

        let vector_f32 = vector.as_f32();

        let primary_codes = self.encode_primary_vector(&vector_f32)?;

        let mut residual_codes = Vec::new();
        if self.config.enable_residual_quantization {
            let mut current_residual = vector_f32.clone();

            let primary_reconstructed = self.decode_primary_codes(&primary_codes)?;
            current_residual = current_residual
                .iter()
                .zip(primary_reconstructed.iter())
                .map(|(a, b)| a - b)
                .collect();

            for level in 0..self.config.residual_levels {
                let level_codes = self.encode_residual_vector(&current_residual, level)?;
                residual_codes.push(level_codes.clone());

                if level < self.config.residual_levels - 1 {
                    let level_reconstructed = self.decode_residual_codes(&level_codes, level)?;
                    current_residual = current_residual
                        .iter()
                        .zip(level_reconstructed.iter())
                        .map(|(a, b)| a - b)
                        .collect();
                }
            }
        }

        let mut multi_codebook_codes = Vec::new();
        if self.config.enable_multi_codebook {
            for codebook_idx in 0..self.config.num_codebooks {
                let mut codes =
                    Vec::with_capacity(self.multi_codebook_quantizers[codebook_idx].len());

                for sq in &self.multi_codebook_quantizers[codebook_idx] {
                    let subvec = sq.extract_subvector(&vector_f32);
                    let code = sq.encode(&subvec)?;
                    codes.push(code);
                }
                multi_codebook_codes.push(codes);
            }
        }

        Ok(EnhancedCodes {
            primary: primary_codes,
            residual: residual_codes,
            multi_codebook: multi_codebook_codes,
        })
    }

    /// Symmetric distance between two code words, using the precomputed
    /// centroid-to-centroid tables.
    fn symmetric_distance(&self, codes1: &[u8], codes2: &[u8]) -> Result<f32> {
        if !self.config.enable_symmetric_distance {
            return Err(anyhow!("Symmetric distance computation not enabled"));
        }

        let distance_tables = self
            .distance_tables
            .as_ref()
            .ok_or_else(|| anyhow!("Distance tables not built"))?;

        if codes1.len() != codes2.len() || codes1.len() != self.config.n_subquantizers {
            return Err(anyhow!("Invalid code lengths for symmetric distance"));
        }

        let mut total_distance = 0.0;

        for (sq_idx, (&code1, &code2)) in codes1.iter().zip(codes2.iter()).enumerate() {
            let distance = distance_tables[sq_idx][code1 as usize][code2 as usize];
            total_distance += distance * distance;
        }

        Ok(total_distance.sqrt())
    }

    /// Distance that also accounts for residual levels and extra codebooks when
    /// they are enabled.
    fn enhanced_distance(&self, query: &Vector, enhanced_codes: &EnhancedCodes) -> Result<f32> {
        let mut total_distance = self.asymmetric_distance(query, &enhanced_codes.primary)?;

        if self.config.enable_residual_quantization && !enhanced_codes.residual.is_empty() {
            let query_f32 = query.as_f32();
            let mut current_residual = query_f32.clone();

            let primary_reconstructed = self.decode_primary_codes(&enhanced_codes.primary)?;
            current_residual = current_residual
                .iter()
                .zip(primary_reconstructed.iter())
                .map(|(a, b)| a - b)
                .collect();

            for (level, residual_codes) in enhanced_codes.residual.iter().enumerate() {
                let mut residual_distance = 0.0;

                for (sq, &code) in self.residual_quantizers[level]
                    .iter()
                    .zip(residual_codes.iter())
                {
                    let query_subvec = sq.extract_subvector(&current_residual);
                    let centroid = &sq.centroids[code as usize];

                    let dist: f32 = query_subvec
                        .iter()
                        .zip(centroid.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum();

                    residual_distance += dist;
                }

                // Residual contributions are down-weighted by 0.5.
                total_distance += residual_distance.sqrt() * 0.5;

                if level < enhanced_codes.residual.len() - 1 {
                    let level_reconstructed = self.decode_residual_codes(residual_codes, level)?;
                    current_residual = current_residual
                        .iter()
                        .zip(level_reconstructed.iter())
                        .map(|(a, b)| a - b)
                        .collect();
                }
            }
        }

        if self.config.enable_multi_codebook && !enhanced_codes.multi_codebook.is_empty() {
            // Keep the best (smallest) distance over all codebooks.
            let mut min_codebook_distance = f32::INFINITY;

            for codes in &enhanced_codes.multi_codebook {
                let codebook_distance = self.asymmetric_distance(query, codes)?;
                min_codebook_distance = min_codebook_distance.min(codebook_distance);
            }

            total_distance = total_distance.min(min_codebook_distance);
        }

        Ok(total_distance)
    }

    /// Ratio of the original f32 size (4 bytes per dimension) to the encoded
    /// size (1 byte per subquantizer).
    pub fn compression_ratio(&self) -> f32 {
        if let Some(dims) = self.dimensions {
            (dims as f32 * 4.0) / (self.config.n_subquantizers as f32)
        } else {
            0.0
        }
    }

    /// Summary statistics for the index.
    pub fn stats(&self) -> PQStats {
        PQStats {
            n_vectors: self.codes.len(),
            n_subquantizers: self.config.n_subquantizers,
            n_centroids: self.config.n_centroids,
            is_trained: self.is_trained,
            dimensions: self.dimensions,
            compression_ratio: self.compression_ratio(),
            memory_usage_bytes: self.estimate_memory_usage(),
        }
    }

    /// Approximate memory use of the primary codebooks plus the stored codes.
    fn estimate_memory_usage(&self) -> usize {
        let codebook_size = self
            .subquantizers
            .iter()
            .map(|sq| sq.centroids.len() * (sq.end_dim - sq.start_dim) * 4)
            .sum::<usize>();

        let codes_size = self.codes.len() * self.config.n_subquantizers;

        codebook_size + codes_size
    }

    /// Whether `train` has been called successfully.
    pub fn is_trained(&self) -> bool {
        self.is_trained
    }

    /// Asymmetric distance between a query vector and stored codes.
    pub fn compute_distance(&self, query: &Vector, codes: &[u8]) -> Result<f32> {
        self.asymmetric_distance(query, codes)
    }

    /// Decode stored codes back into an approximate vector.
    pub fn decode_vector(&self, codes: &[u8]) -> Result<Vector> {
        self.decode_codes(codes)
    }
}
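
// Illustrative usage sketch (hypothetical function, not part of the original
// module): train the index, then compress a vector to its PQ codes and decode
// an approximate reconstruction. Each encoded vector occupies
// `n_subquantizers` bytes instead of `dimensions * 4` bytes of f32 data.
#[allow(dead_code)]
fn example_encode_decode(training: &[Vector], query: &Vector) -> Result<(Vec<u8>, Vector)> {
    // 4 subquantizers with 8-bit codes; training vectors must have a
    // dimensionality divisible by 4.
    let mut index = PQIndex::new(PQConfig::with_bits(4, 8));
    index.train(training)?;
    let codes = index.encode(query)?;
    let approx = index.decode(&codes)?;
    Ok((codes, approx))
}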

impl VectorIndex for PQIndex {
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
        if !self.is_trained {
            return Err(anyhow!("PQ index must be trained before inserting vectors"));
        }

        if let Some(dims) = self.dimensions {
            if vector.dimensions != dims {
                return Err(anyhow!(
                    "Vector dimensions {} don't match index dimensions {}",
                    vector.dimensions,
                    dims
                ));
            }
        }

        let codes = self.encode_vector(&vector)?;

        let id = self.codes.len();
        self.uri_to_id.insert(uri.clone(), id);
        self.codes.push((uri, codes));

        Ok(())
    }

    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        if !self.is_trained {
            return Err(anyhow!("PQ index must be trained before searching"));
        }

        let mut distances: Vec<(String, f32)> = self
            .codes
            .iter()
            .map(|(uri, codes)| {
                let dist = self
                    .asymmetric_distance(query, codes)
                    .unwrap_or(f32::INFINITY);
                (uri.clone(), dist)
            })
            .collect();

        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        distances.truncate(k);

        // Convert distances into similarities in (0, 1].
        Ok(distances
            .into_iter()
            .map(|(uri, dist)| (uri, 1.0 / (1.0 + dist)))
            .collect())
    }

    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
        if !self.is_trained {
            return Err(anyhow!("PQ index must be trained before searching"));
        }

        let mut results = Vec::new();

        for (uri, codes) in &self.codes {
            let dist = self.asymmetric_distance(query, codes)?;
            let similarity = 1.0 / (1.0 + dist);

            if similarity >= threshold {
                results.push((uri.clone(), similarity));
            }
        }

        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        Ok(results)
    }

    fn get_vector(&self, _uri: &str) -> Option<&Vector> {
        // PQ stores only codes, not the original vectors, so there is nothing to return.
        None
    }
}

impl PQIndex {
    /// Convenience wrapper around `search_knn`.
    pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        self.search_knn(query, k)
    }
}
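
// Illustrative usage sketch (hypothetical function): build a searchable index
// through the `VectorIndex` trait. `search_knn` returns (uri, similarity)
// pairs, where similarity is 1 / (1 + asymmetric distance).
#[allow(dead_code)]
fn example_knn_search(training: &[Vector], query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
    // Default config uses 8 subquantizers, so the vectors' dimensionality must
    // be divisible by 8.
    let mut index = PQIndex::new(PQConfig::default());
    index.train(training)?;
    for (i, v) in training.iter().enumerate() {
        index.insert(format!("vec{i}"), v.clone())?;
    }
    index.search_knn(query, k)
}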

/// Summary statistics reported by `PQIndex::stats`.
#[derive(Debug, Clone)]
pub struct PQStats {
    pub n_vectors: usize,
    pub n_subquantizers: usize,
    pub n_centroids: usize,
    pub is_trained: bool,
    pub dimensions: Option<usize>,
    pub compression_ratio: f32,
    pub memory_usage_bytes: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pq_basic() {
        let config = PQConfig {
            n_subquantizers: 2,
            n_centroids: 4,
            ..Default::default()
        };

        let mut index = PQIndex::new(config);

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.0, 1.0, 1.0, 0.0]),
            Vector::new(vec![-1.0, 0.0, 0.0, -1.0]),
            Vector::new(vec![0.0, -1.0, -1.0, 0.0]),
            Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
            Vector::new(vec![-0.5, -0.5, -0.5, -0.5]),
        ];

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained);

        for (i, vec) in training_vectors.iter().enumerate() {
            index.insert(format!("vec{i}"), vec.clone()).unwrap();
        }

        let query = Vector::new(vec![0.9, 0.1, 0.1, 0.9]);
        let results = index.search_knn(&query, 3).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 3);
    }

    #[test]
    fn test_pq_compression() {
        let config = PQConfig {
            n_subquantizers: 4,
            n_centroids: 16,
            ..Default::default()
        };

        let mut index = PQIndex::new(config);

        let dims = 128;
        let training_vectors: Vec<Vector> = (0..100)
            .map(|i| {
                let values: Vec<f32> = (0..dims).map(|j| ((i + j) as f32).sin()).collect();
                Vector::new(values)
            })
            .collect();

        index.train(&training_vectors).unwrap();

        // 128 dims * 4 bytes of f32 compressed to 4 one-byte codes.
        let compression_ratio = index.compression_ratio();
        assert_eq!(compression_ratio, 128.0);

        let stats = index.stats();
        assert_eq!(stats.n_subquantizers, 4);
        assert_eq!(stats.n_centroids, 16);
        assert_eq!(stats.dimensions, Some(128));
    }

    #[test]
    fn test_pq_reconstruction() {
        let config = PQConfig {
            n_subquantizers: 2,
            n_centroids: 8,
            ..Default::default()
        };

        let mut index = PQIndex::new(config);

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0]),
            Vector::new(vec![0.0, 1.0]),
            Vector::new(vec![-1.0, 0.0]),
            Vector::new(vec![0.0, -1.0]),
        ];

        index.train(&training_vectors).unwrap();

        let original = Vector::new(vec![0.7, 0.7]);
        let codes = index.encode_vector(&original).unwrap();
        let reconstructed = index.decode_codes(&codes).unwrap();

        let dist = original.euclidean_distance(&reconstructed).unwrap();
        assert!(dist < 1.0);
    }

    #[test]
    fn test_pq_residual_quantization() {
        let config = PQConfig::with_residual_quantization(2, 3, 2);
        let mut index = PQIndex::new(config);

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.0, 1.0, 1.0, 0.0]),
            Vector::new(vec![-1.0, 0.0, 0.0, -1.0]),
            Vector::new(vec![0.0, -1.0, -1.0, 0.0]),
            Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
            Vector::new(vec![-0.5, -0.5, -0.5, -0.5]),
        ];

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.residual_quantizers.len(), 2);

        let test_vector = Vector::new(vec![0.7, 0.3, 0.3, 0.7]);
        let enhanced_codes = index.encode_vector_enhanced(&test_vector).unwrap();

        assert!(!enhanced_codes.primary.is_empty());
        assert_eq!(enhanced_codes.residual.len(), 2);
        assert!(enhanced_codes.multi_codebook.is_empty());
    }

    #[test]
    fn test_pq_multi_codebook() {
        let config = PQConfig::with_multi_codebook(2, 3, 3);
        let mut index = PQIndex::new(config);

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.0, 1.0, 1.0, 0.0]),
            Vector::new(vec![-1.0, 0.0, 0.0, -1.0]),
            Vector::new(vec![0.0, -1.0, -1.0, 0.0]),
            Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
            Vector::new(vec![-0.5, -0.5, -0.5, -0.5]),
        ];

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.multi_codebook_quantizers.len(), 3);

        let test_vector = Vector::new(vec![0.7, 0.3, 0.3, 0.7]);
        let enhanced_codes = index.encode_vector_enhanced(&test_vector).unwrap();

        assert!(!enhanced_codes.primary.is_empty());
        assert!(enhanced_codes.residual.is_empty());
        assert_eq!(enhanced_codes.multi_codebook.len(), 3);
    }

    #[test]
    fn test_pq_symmetric_distance() {
        let config = PQConfig {
            enable_symmetric_distance: true,
            n_subquantizers: 2,
            n_centroids: 4,
            ..Default::default()
        };

        let mut index = PQIndex::new(config);

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.0, 1.0, 1.0, 0.0]),
            Vector::new(vec![-1.0, 0.0, 0.0, -1.0]),
            Vector::new(vec![0.0, -1.0, -1.0, 0.0]),
        ];

        index.train(&training_vectors).unwrap();
        assert!(index.distance_tables.is_some());

        let codes1 = vec![0, 1];
        let codes2 = vec![1, 0];
        let distance = index.symmetric_distance(&codes1, &codes2).unwrap();

        assert!(distance >= 0.0);
        assert!(distance.is_finite());
    }

    #[test]
    fn test_pq_enhanced_features() {
        let config = PQConfig::enhanced(2, 3);
        let mut index = PQIndex::new(config);

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.0, 1.0, 1.0, 0.0]),
            Vector::new(vec![-1.0, 0.0, 0.0, -1.0]),
            Vector::new(vec![0.0, -1.0, -1.0, 0.0]),
            Vector::new(vec![0.5, 0.5, 0.5, 0.5]),
            Vector::new(vec![-0.5, -0.5, -0.5, -0.5]),
        ];

        index.train(&training_vectors).unwrap();
        assert!(index.is_trained());

        assert!(!index.residual_quantizers.is_empty());
        assert!(!index.multi_codebook_quantizers.is_empty());
        assert!(index.distance_tables.is_some());

        let test_vector = Vector::new(vec![0.7, 0.3, 0.3, 0.7]);
        let enhanced_codes = index.encode_vector_enhanced(&test_vector).unwrap();
        let enhanced_distance = index
            .enhanced_distance(&test_vector, &enhanced_codes)
            .unwrap();

        assert!(enhanced_distance >= 0.0);
        assert!(enhanced_distance.is_finite());

        // The enhanced distance should not exceed the primary-only distance by
        // more than a small tolerance.
        let basic_distance = index
            .asymmetric_distance(&test_vector, &enhanced_codes.primary)
            .unwrap();
        assert!(enhanced_distance <= basic_distance * 1.1);
    }

    #[test]
    fn test_pq_config_validation() {
        let config = PQConfig::enhanced(4, 8);
        assert!(config.validate().is_ok());

        let invalid_config = PQConfig {
            enable_residual_quantization: true,
            residual_levels: 0,
            ..Default::default()
        };
        assert!(invalid_config.validate().is_err());

        let invalid_config = PQConfig {
            enable_multi_codebook: true,
            num_codebooks: 1,
            ..Default::default()
        };
        assert!(invalid_config.validate().is_err());
    }
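
    // Additional sketch test (not from the original suite): encoding should
    // always yield exactly one u8 code per subquantizer.
    #[test]
    fn test_pq_code_length_matches_subquantizers() {
        let config = PQConfig {
            n_subquantizers: 2,
            n_centroids: 4,
            ..Default::default()
        };
        let mut index = PQIndex::new(config);

        let training_vectors = vec![
            Vector::new(vec![1.0, 0.0, 0.0, 1.0]),
            Vector::new(vec![0.0, 1.0, 1.0, 0.0]),
            Vector::new(vec![-1.0, 0.0, 0.0, -1.0]),
            Vector::new(vec![0.0, -1.0, -1.0, 0.0]),
        ];
        index.train(&training_vectors).unwrap();

        let codes = index.encode(&Vector::new(vec![0.5, 0.5, 0.5, 0.5])).unwrap();
        assert_eq!(codes.len(), 2);
    }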
}