1use std::collections::HashMap;
10
11use parking_lot::RwLock;
12use rand::seq::SliceRandom;
13use serde::{Deserialize, Serialize};
14
15use common::{DistanceMetric, Vector, VectorId};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct PQConfig {
20 pub num_subquantizers: usize,
23 pub num_centroids: usize,
25 pub kmeans_iterations: usize,
27 pub distance_metric: DistanceMetric,
29}
30
31impl Default for PQConfig {
32 fn default() -> Self {
33 Self {
34 num_subquantizers: 8,
35 num_centroids: 256,
36 kmeans_iterations: 20,
37 distance_metric: DistanceMetric::Euclidean,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ProductQuantizer {
45 pub config: PQConfig,
47 pub codebooks: Vec<Vec<Vec<f32>>>,
49 pub dimension: usize,
51 pub subvector_dim: usize,
53}
54
55impl ProductQuantizer {
56 pub fn new(config: PQConfig, dimension: usize) -> Result<Self, String> {
58 if !dimension.is_multiple_of(config.num_subquantizers) {
59 return Err(format!(
60 "Dimension {} not divisible by num_subquantizers {}",
61 dimension, config.num_subquantizers
62 ));
63 }
64
65 let subvector_dim = dimension / config.num_subquantizers;
66
67 Ok(Self {
68 config,
69 codebooks: Vec::new(),
70 dimension,
71 subvector_dim,
72 })
73 }
74
75 pub fn train(&mut self, vectors: &[Vector]) -> Result<(), String> {
77 if vectors.is_empty() {
78 return Err("Cannot train on empty vectors".to_string());
79 }
80
81 if vectors[0].values.len() != self.dimension {
82 return Err(format!(
83 "Vector dimension {} doesn't match expected {}",
84 vectors[0].values.len(),
85 self.dimension
86 ));
87 }
88
89 let m = self.config.num_subquantizers;
90 let k = self.config.num_centroids;
91 let d = self.subvector_dim;
92
93 self.codebooks = Vec::with_capacity(m);
94
95 for subspace_idx in 0..m {
97 let start = subspace_idx * d;
98 let end = start + d;
99
100 let subvectors: Vec<Vec<f32>> = vectors
102 .iter()
103 .map(|v| v.values[start..end].to_vec())
104 .collect();
105
106 let codebook = self.train_kmeans(&subvectors, k)?;
108 self.codebooks.push(codebook);
109 }
110
111 Ok(())
112 }
113
114 fn train_kmeans(&self, subvectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>, String> {
116 if subvectors.is_empty() {
117 return Err("Cannot train k-means on empty subvectors".to_string());
118 }
119 let actual_k = k.min(subvectors.len());
120 let dim = subvectors[0].len();
121
122 let mut centroids = self.kmeans_plus_plus(subvectors, actual_k);
124
125 for _ in 0..self.config.kmeans_iterations {
127 let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); actual_k];
129 for (i, subvec) in subvectors.iter().enumerate() {
130 let nearest = self.find_nearest_centroid(subvec, ¢roids);
131 assignments[nearest].push(i);
132 }
133
134 for (c_idx, assigned) in assignments.iter().enumerate() {
136 if assigned.is_empty() {
137 continue;
138 }
139
140 let mut new_centroid = vec![0.0f32; dim];
141 for &vec_idx in assigned {
142 for (j, &val) in subvectors[vec_idx].iter().enumerate() {
143 new_centroid[j] += val;
144 }
145 }
146
147 let count = assigned.len() as f32;
148 for val in &mut new_centroid {
149 *val /= count;
150 }
151
152 centroids[c_idx] = new_centroid;
153 }
154 }
155
156 Ok(centroids)
157 }
158
159 fn kmeans_plus_plus(&self, subvectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
161 let mut rng = rand::thread_rng();
162 let mut centroids = Vec::with_capacity(k);
163
164 if let Some(first) = subvectors.choose(&mut rng) {
166 centroids.push(first.clone());
167 } else {
168 return centroids;
169 }
170
171 for _ in 1..k {
173 let distances: Vec<f32> = subvectors
174 .iter()
175 .map(|v| {
176 centroids
177 .iter()
178 .map(|c| self.squared_distance(v, c))
179 .fold(f32::MAX, f32::min)
180 })
181 .collect();
182
183 let total: f32 = distances.iter().sum();
184 if total == 0.0 {
185 break;
186 }
187
188 let threshold: f32 = rand::random::<f32>() * total;
189 let mut cumsum = 0.0;
190
191 for (i, &d) in distances.iter().enumerate() {
192 cumsum += d;
193 if cumsum >= threshold {
194 centroids.push(subvectors[i].clone());
195 break;
196 }
197 }
198 }
199
200 centroids
201 }
202
203 fn find_nearest_centroid(&self, subvec: &[f32], centroids: &[Vec<f32>]) -> usize {
205 let mut best_idx = 0;
206 let mut best_dist = f32::MAX;
207
208 for (i, centroid) in centroids.iter().enumerate() {
209 let dist = self.squared_distance(subvec, centroid);
210 if dist < best_dist {
211 best_dist = dist;
212 best_idx = i;
213 }
214 }
215
216 best_idx
217 }
218
219 #[inline]
221 fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
222 a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
223 }
224
225 pub fn is_trained(&self) -> bool {
227 !self.codebooks.is_empty()
228 }
229
230 pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>, String> {
232 if !self.is_trained() {
233 return Err("Quantizer not trained".to_string());
234 }
235
236 if vector.len() != self.dimension {
237 return Err(format!(
238 "Vector dimension {} doesn't match expected {}",
239 vector.len(),
240 self.dimension
241 ));
242 }
243
244 let m = self.config.num_subquantizers;
245 let d = self.subvector_dim;
246 let mut codes = Vec::with_capacity(m);
247
248 for subspace_idx in 0..m {
249 let start = subspace_idx * d;
250 let end = start + d;
251 let subvec = &vector[start..end];
252
253 let nearest = self.find_nearest_centroid(subvec, &self.codebooks[subspace_idx]);
254 codes.push(nearest as u8);
255 }
256
257 Ok(codes)
258 }
259
260 pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>, String> {
262 if !self.is_trained() {
263 return Err("Quantizer not trained".to_string());
264 }
265
266 if codes.len() != self.config.num_subquantizers {
267 return Err(format!(
268 "Code length {} doesn't match num_subquantizers {}",
269 codes.len(),
270 self.config.num_subquantizers
271 ));
272 }
273
274 let mut vector = Vec::with_capacity(self.dimension);
275
276 for (subspace_idx, &code) in codes.iter().enumerate() {
277 let centroid = &self.codebooks[subspace_idx][code as usize];
278 vector.extend_from_slice(centroid);
279 }
280
281 Ok(vector)
282 }
283
284 pub fn compute_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>, String> {
287 if !self.is_trained() {
288 return Err("Quantizer not trained".to_string());
289 }
290
291 if query.len() != self.dimension {
292 return Err(format!(
293 "Query dimension {} doesn't match expected {}",
294 query.len(),
295 self.dimension
296 ));
297 }
298
299 let m = self.config.num_subquantizers;
300 let k = self.config.num_centroids;
301 let d = self.subvector_dim;
302
303 let mut table = Vec::with_capacity(m);
304
305 for subspace_idx in 0..m {
306 let start = subspace_idx * d;
307 let end = start + d;
308 let query_subvec = &query[start..end];
309
310 let mut distances = Vec::with_capacity(k);
311 for centroid in &self.codebooks[subspace_idx] {
312 let dist = match self.config.distance_metric {
313 DistanceMetric::Euclidean => {
314 -self.squared_distance(query_subvec, centroid).sqrt()
315 }
316 DistanceMetric::Cosine => self.cosine_sim(query_subvec, centroid),
317 DistanceMetric::DotProduct => self.dot_product(query_subvec, centroid),
318 };
319 distances.push(dist);
320 }
321
322 table.push(distances);
323 }
324
325 Ok(table)
326 }
327
328 #[inline]
330 pub fn compute_distance_adc(&self, table: &[Vec<f32>], codes: &[u8]) -> f32 {
331 let mut total = 0.0f32;
332 for (subspace_idx, &code) in codes.iter().enumerate() {
333 total += table[subspace_idx][code as usize];
334 }
335 total
336 }
337
338 #[inline]
339 fn cosine_sim(&self, a: &[f32], b: &[f32]) -> f32 {
340 let mut dot = 0.0f32;
341 let mut norm_a = 0.0f32;
342 let mut norm_b = 0.0f32;
343
344 for (x, y) in a.iter().zip(b.iter()) {
345 dot += x * y;
346 norm_a += x * x;
347 norm_b += y * y;
348 }
349
350 let norm_a = norm_a.sqrt();
351 let norm_b = norm_b.sqrt();
352
353 if norm_a == 0.0 || norm_b == 0.0 {
354 0.0
355 } else {
356 dot / (norm_a * norm_b)
357 }
358 }
359
360 #[inline]
361 fn dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
362 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
363 }
364}
365
366pub struct PQIndex {
368 quantizer: RwLock<ProductQuantizer>,
370 encoded_vectors: RwLock<HashMap<VectorId, Vec<u8>>>,
372 original_vectors: RwLock<HashMap<VectorId, Vector>>,
374 store_originals: bool,
376}
377
378#[derive(Debug, Clone)]
380pub struct PQSearchResult {
381 pub id: VectorId,
382 pub score: f32,
383 pub vector: Option<Vector>,
384}
385
386impl PQIndex {
387 pub fn new(config: PQConfig, dimension: usize, store_originals: bool) -> Result<Self, String> {
389 let quantizer = ProductQuantizer::new(config, dimension)?;
390
391 Ok(Self {
392 quantizer: RwLock::new(quantizer),
393 encoded_vectors: RwLock::new(HashMap::new()),
394 original_vectors: RwLock::new(HashMap::new()),
395 store_originals,
396 })
397 }
398
399 pub fn train(&self, vectors: &[Vector]) -> Result<(), String> {
401 let mut quantizer = self.quantizer.write();
402 quantizer.train(vectors)
403 }
404
405 pub fn is_trained(&self) -> bool {
407 self.quantizer.read().is_trained()
408 }
409
410 pub fn add(&self, vectors: Vec<Vector>) -> Result<usize, String> {
412 let quantizer = self.quantizer.read();
413 if !quantizer.is_trained() {
414 return Err("Index not trained".to_string());
415 }
416
417 let mut encoded = self.encoded_vectors.write();
418 let mut originals = self.original_vectors.write();
419 let mut count = 0;
420
421 for vector in vectors {
422 let codes = quantizer.encode(&vector.values)?;
423 encoded.insert(vector.id.clone(), codes);
424
425 if self.store_originals {
426 originals.insert(vector.id.clone(), vector);
427 }
428
429 count += 1;
430 }
431
432 Ok(count)
433 }
434
435 pub fn remove(&self, ids: &[VectorId]) -> usize {
437 let mut encoded = self.encoded_vectors.write();
438 let mut originals = self.original_vectors.write();
439 let mut count = 0;
440
441 for id in ids {
442 if encoded.remove(id).is_some() {
443 count += 1;
444 }
445 originals.remove(id);
446 }
447
448 count
449 }
450
451 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<PQSearchResult>, String> {
453 let quantizer = self.quantizer.read();
454 if !quantizer.is_trained() {
455 return Err("Index not trained".to_string());
456 }
457
458 let table = quantizer.compute_distance_table(query)?;
460
461 let encoded = self.encoded_vectors.read();
462 let originals = self.original_vectors.read();
463
464 let mut results: Vec<PQSearchResult> = encoded
466 .iter()
467 .map(|(id, codes)| {
468 let score = quantizer.compute_distance_adc(&table, codes);
469 let vector = originals.get(id).cloned();
470
471 PQSearchResult {
472 id: id.clone(),
473 score,
474 vector,
475 }
476 })
477 .collect();
478
479 results.sort_by(|a, b| {
481 b.score
482 .partial_cmp(&a.score)
483 .unwrap_or(std::cmp::Ordering::Equal)
484 });
485 results.truncate(k);
486
487 Ok(results)
488 }
489
490 pub fn len(&self) -> usize {
492 self.encoded_vectors.read().len()
493 }
494
495 pub fn is_empty(&self) -> bool {
497 self.encoded_vectors.read().is_empty()
498 }
499
500 pub fn compression_ratio(&self) -> f32 {
502 let quantizer = self.quantizer.read();
503 let original_size = quantizer.dimension * 4; let compressed_size = quantizer.config.num_subquantizers; original_size as f32 / compressed_size as f32
506 }
507
508 pub fn decode(&self, id: &VectorId) -> Result<Vec<f32>, String> {
510 let quantizer = self.quantizer.read();
511 let encoded = self.encoded_vectors.read();
512
513 let codes = encoded
514 .get(id)
515 .ok_or_else(|| format!("Vector {} not found", id))?;
516
517 quantizer.decode(codes)
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` deterministic vectors of length `dim` with ids `"v0"`,
    /// `"v1"`, and so on.
    ///
    /// Component `j` of vector `i` is `sin(0.1 * (i + j))`, so neighbouring ids
    /// give neighbouring vectors.
    fn sample_vectors(n: usize, dim: usize) -> Vec<Vector> {
        let mut out = Vec::with_capacity(n);
        for i in 0..n {
            let values = (i..i + dim).map(|x| (x as f32 * 0.1).sin()).collect();
            out.push(Vector {
                id: format!("v{}", i),
                values,
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            });
        }
        out
    }

    /// Four sub-quantizers with `centroids` centroids each; other fields use
    /// their defaults.
    fn four_way(centroids: usize) -> PQConfig {
        PQConfig {
            num_subquantizers: 4,
            num_centroids: centroids,
            ..Default::default()
        }
    }

    /// A quantizer for 32-dim vectors trained on `n` sample vectors, returned
    /// together with those vectors.
    fn trained(config: PQConfig, n: usize) -> (ProductQuantizer, Vec<Vector>) {
        let mut pq = ProductQuantizer::new(config, 32).unwrap();
        let data = sample_vectors(n, 32);
        pq.train(&data).unwrap();
        (pq, data)
    }

    #[test]
    fn test_pq_config_validation() {
        let config = PQConfig {
            num_subquantizers: 8,
            ..Default::default()
        };

        // 64 splits into 8 subvectors of 8; 65 does not divide evenly.
        assert!(ProductQuantizer::new(config.clone(), 64).is_ok());
        assert!(ProductQuantizer::new(config, 65).is_err());
    }

    #[test]
    fn test_pq_train() {
        let config = PQConfig {
            kmeans_iterations: 10,
            ..four_way(16)
        };
        let mut pq = ProductQuantizer::new(config, 32).unwrap();
        assert!(!pq.is_trained());

        pq.train(&sample_vectors(100, 32)).unwrap();
        assert!(pq.is_trained());

        // 4 codebooks of 16 centroids, each 32 / 4 = 8 floats wide.
        assert_eq!(pq.codebooks.len(), 4);
        assert_eq!(pq.codebooks[0].len(), 16);
        assert_eq!(pq.codebooks[0][0].len(), 8);
    }

    #[test]
    fn test_pq_encode_decode() {
        let (pq, data) = trained(four_way(16), 100);
        let original = &data[0].values;

        let codes = pq.encode(original).unwrap();
        assert_eq!(codes.len(), 4);

        let decoded = pq.decode(&codes).unwrap();
        assert_eq!(decoded.len(), 32);

        // L2 reconstruction error.
        let error = original
            .iter()
            .zip(&decoded)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();
        assert!(error < 5.0, "Quantization error too high: {}", error);
    }

    #[test]
    fn test_pq_distance_table() {
        let (pq, data) = trained(four_way(16), 100);

        let table = pq.compute_distance_table(&data[0].values).unwrap();

        assert_eq!(table.len(), 4);
        assert_eq!(table[0].len(), 16);
    }

    #[test]
    fn test_pq_adc() {
        let (pq, data) = trained(four_way(16), 100);
        let query = &data[50].values;
        let table = pq.compute_distance_table(query).unwrap();

        // Scoring the query against its own code should sit near the top of the scale.
        let own_codes = pq.encode(query).unwrap();
        let dist = pq.compute_distance_adc(&table, &own_codes);
        assert!(
            dist > -3.0,
            "Self-distance should be relatively small, got {}",
            dist
        );
    }

    #[test]
    fn test_pq_index_basic() {
        let index = PQIndex::new(four_way(16), 32, true).unwrap();
        let data = sample_vectors(100, 32);

        index.train(&data).unwrap();
        assert!(index.is_trained());

        assert_eq!(index.add(data).unwrap(), 100);
        assert_eq!(index.len(), 100);
    }

    #[test]
    fn test_pq_index_search() {
        let config = PQConfig {
            kmeans_iterations: 15,
            distance_metric: DistanceMetric::Euclidean,
            ..four_way(32)
        };
        let index = PQIndex::new(config, 32, true).unwrap();
        let data = sample_vectors(200, 32);

        index.train(&data).unwrap();
        index.add(data.clone()).unwrap();

        // Query with v100 itself; it should come back among the top 10.
        let results = index.search(&data[100].values, 10).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 10);
        assert!(results
            .windows(2)
            .all(|pair| pair[0].score >= pair[1].score));
        assert!(
            results.iter().any(|r| r.id == "v100"),
            "Query vector not found in top results"
        );
    }

    #[test]
    fn test_pq_index_remove() {
        let index = PQIndex::new(four_way(16), 32, false).unwrap();
        let data = sample_vectors(50, 32);

        index.train(&data).unwrap();
        index.add(data).unwrap();
        assert_eq!(index.len(), 50);

        let removed = index.remove(&["v0".to_string(), "v1".to_string()]);
        assert_eq!(removed, 2);
        assert_eq!(index.len(), 48);
    }

    #[test]
    fn test_pq_compression_ratio() {
        let config = PQConfig {
            num_subquantizers: 8,
            num_centroids: 256,
            ..Default::default()
        };
        let index = PQIndex::new(config, 128, false).unwrap();

        // 128 floats * 4 bytes = 512 bytes, compressed to 8 one-byte codes.
        assert!((index.compression_ratio() - 64.0).abs() < 0.1);
    }

    #[test]
    fn test_pq_decode_from_index() {
        let index = PQIndex::new(four_way(16), 32, false).unwrap();
        let data = sample_vectors(50, 32);

        index.train(&data).unwrap();
        index.add(data).unwrap();

        assert_eq!(index.decode(&"v10".to_string()).unwrap().len(), 32);
        assert!(index.decode(&"nonexistent".to_string()).is_err());
    }
}