1use crate::distance::DistanceMetric;
7use crate::error::VectorError;
8
9use super::config::PQConfig;
10use super::training::{KMeans, KMeansConfig};
11
/// A product-quantization code: one centroid index per segment.
///
/// Codes are held unpacked in memory (one `u8` per segment). `bits_per_code` only
/// controls the byte layout produced by [`PQCode::to_bytes`] and consumed by
/// [`PQCode::from_bytes`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PQCode {
    /// Centroid index for each segment, in segment order.
    codes: Vec<u8>,
    /// Number of significant bits per code when bit-packing.
    bits_per_code: u8,
}

impl PQCode {
    /// Creates a code from per-segment centroid indices and the packing bit width.
    #[must_use]
    pub fn new(codes: Vec<u8>, bits_per_code: u8) -> Self {
        Self { codes, bits_per_code }
    }

    /// Number of segments (one code per segment).
    #[must_use]
    pub fn num_segments(&self) -> usize {
        self.codes.len()
    }

    /// Centroid index for `segment`, or `None` if `segment` is out of range.
    #[must_use]
    pub fn get(&self, segment: usize) -> Option<u8> {
        self.codes.get(segment).copied()
    }

    /// All per-segment centroid indices, in segment order.
    #[must_use]
    pub fn as_slice(&self) -> &[u8] {
        &self.codes
    }

    /// Serializes the code.
    ///
    /// With 8 or more bits per code, each code occupies one byte. Narrower codes are
    /// bit-packed LSB-first: code `i` starts at bit `i * bits_per_code`, where bit 0 is the
    /// least-significant bit of byte 0. Only the low `bits_per_code` bits of each code are
    /// stored; zero-width codes produce no bytes.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        if self.bits_per_code >= 8 {
            // A u8 code never needs more than one byte.
            self.codes.clone()
        } else {
            self.pack_codes()
        }
    }

    /// Deserializes a code written by [`PQCode::to_bytes`] with the same `num_segments`
    /// and `bits_per_code`.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is too short. With byte-wide codes it must hold at least
    /// `num_segments` bytes; with packed codes it must contain the byte in which every code
    /// starts. A missing trailing spill-over byte is read as zero.
    #[must_use]
    pub fn from_bytes(bytes: &[u8], num_segments: usize, bits_per_code: u8) -> Self {
        if bits_per_code >= 8 {
            Self { codes: bytes[..num_segments].to_vec(), bits_per_code }
        } else {
            Self::unpack_codes(bytes, num_segments, bits_per_code)
        }
    }

    /// Mask selecting the low `bits_per_code` bits of a code (all bits for widths >= 8).
    fn code_mask(bits_per_code: u8) -> u8 {
        if bits_per_code >= 8 {
            u8::MAX
        } else {
            (1u8 << bits_per_code) - 1
        }
    }

    /// Bit-packs codes narrower than 8 bits, LSB-first.
    fn pack_codes(&self) -> Vec<u8> {
        let width = usize::from(self.bits_per_code);
        let mut bytes = vec![0u8; (self.codes.len() * width).div_ceil(8)];

        // Zero-width codes carry no information. Return early because the buffer is empty
        // and the indexing below would panic.
        if width == 0 {
            return bytes;
        }

        let mask = Self::code_mask(self.bits_per_code);
        for (i, &code) in self.codes.iter().enumerate() {
            // Drop out-of-range high bits so they cannot bleed into the next code.
            let code = code & mask;
            let bit_pos = i * width;
            let byte_idx = bit_pos / 8;
            let bit_offset = bit_pos % 8;

            bytes[byte_idx] |= code << bit_offset;

            if bit_offset + width > 8 {
                // The code straddles a byte boundary, so its high bits spill into the next
                // byte. That byte always exists because `bit_pos + width <= total bits`.
                bytes[byte_idx + 1] |= code >> (8 - bit_offset);
            }
        }

        bytes
    }

    /// Inverse of `pack_codes` for widths below 8 bits.
    fn unpack_codes(bytes: &[u8], num_segments: usize, bits_per_code: u8) -> Self {
        let width = usize::from(bits_per_code);

        // Zero-width codes occupy no bytes; every code is 0.
        if width == 0 {
            return Self { codes: vec![0; num_segments], bits_per_code };
        }

        let mask = Self::code_mask(bits_per_code);
        let codes = (0..num_segments)
            .map(|i| {
                let bit_pos = i * width;
                let byte_idx = bit_pos / 8;
                let bit_offset = bit_pos % 8;

                let low = bytes[byte_idx] >> bit_offset;
                let high = if bit_offset + width > 8 {
                    // Tolerate a missing trailing byte: its bits read as zero.
                    bytes.get(byte_idx + 1).map_or(0, |&next| next << (8 - bit_offset))
                } else {
                    0
                };
                (low | high) & mask
            })
            .collect();

        Self { codes, bits_per_code }
    }
}
137
138#[derive(Debug, Clone)]
143pub struct ProductQuantizer {
144 config: PQConfig,
146 codebooks: Vec<Vec<Vec<f32>>>,
148}
149
150impl ProductQuantizer {
151 pub fn train(config: &PQConfig, training_data: &[&[f32]]) -> Result<Self, VectorError> {
165 config.validate()?;
166
167 if training_data.is_empty() {
168 return Err(VectorError::Encoding("cannot train PQ on empty data".to_string()));
169 }
170
171 for (i, v) in training_data.iter().enumerate() {
173 if v.len() != config.dimension {
174 return Err(VectorError::DimensionMismatch {
175 expected: config.dimension,
176 actual: v.len(),
177 });
178 }
179 if i > 1000 {
180 break; }
182 }
183
184 let subspace_dim = config.subspace_dimension();
185 let mut codebooks = Vec::with_capacity(config.num_segments);
186
187 for segment in 0..config.num_segments {
189 let start = segment * subspace_dim;
190 let end = start + subspace_dim;
191
192 let subvectors: Vec<Vec<f32>> =
194 training_data.iter().map(|v| v[start..end].to_vec()).collect();
195
196 let subvector_refs: Vec<&[f32]> = subvectors.iter().map(|v| v.as_slice()).collect();
197
198 let kmeans_config = KMeansConfig::new(config.num_centroids)
200 .with_max_iterations(config.training_iterations)
201 .with_seed(config.seed.map(|s| s + segment as u64).unwrap_or(segment as u64));
202
203 let kmeans = KMeans::train(&subvector_refs, &kmeans_config, config.distance_metric)?;
204 codebooks.push(kmeans.centroids);
205 }
206
207 Ok(Self { config: config.clone(), codebooks })
208 }
209
210 pub fn from_codebooks(
221 config: &PQConfig,
222 codebooks: Vec<Vec<Vec<f32>>>,
223 ) -> Result<Self, VectorError> {
224 config.validate()?;
225
226 if codebooks.len() != config.num_segments {
227 return Err(VectorError::Encoding(format!(
228 "expected {} codebooks, got {}",
229 config.num_segments,
230 codebooks.len()
231 )));
232 }
233
234 let subspace_dim = config.subspace_dimension();
235 for (i, codebook) in codebooks.iter().enumerate() {
236 if codebook.len() != config.num_centroids {
237 return Err(VectorError::Encoding(format!(
238 "codebook {} has {} centroids, expected {}",
239 i,
240 codebook.len(),
241 config.num_centroids
242 )));
243 }
244 for centroid in codebook {
245 if centroid.len() != subspace_dim {
246 return Err(VectorError::DimensionMismatch {
247 expected: subspace_dim,
248 actual: centroid.len(),
249 });
250 }
251 }
252 }
253
254 Ok(Self { config: config.clone(), codebooks })
255 }
256
257 #[must_use]
259 pub fn config(&self) -> &PQConfig {
260 &self.config
261 }
262
263 #[must_use]
265 pub fn codebooks(&self) -> &[Vec<Vec<f32>>] {
266 &self.codebooks
267 }
268
269 #[must_use]
279 #[allow(clippy::cast_possible_truncation)]
280 pub fn encode(&self, vector: &[f32]) -> PQCode {
281 debug_assert_eq!(vector.len(), self.config.dimension);
282
283 let subspace_dim = self.config.subspace_dimension();
284 let mut codes = Vec::with_capacity(self.config.num_segments);
285
286 for (segment, codebook) in self.codebooks.iter().enumerate() {
287 let start = segment * subspace_dim;
288 let end = start + subspace_dim;
289 let subvector = &vector[start..end];
290
291 let mut min_dist = f32::MAX;
293 let mut min_idx = 0u8;
294
295 for (idx, centroid) in codebook.iter().enumerate() {
296 let dist = self.subspace_distance(subvector, centroid);
297 if dist < min_dist {
298 min_dist = dist;
299 min_idx = idx as u8;
300 }
301 }
302
303 codes.push(min_idx);
304 }
305
306 PQCode::new(codes, self.config.bits_per_code() as u8)
307 }
308
309 #[must_use]
314 pub fn decode(&self, code: &PQCode) -> Vec<f32> {
315 let mut vector = Vec::with_capacity(self.config.dimension);
316
317 for (segment, &idx) in code.as_slice().iter().enumerate() {
318 let centroid = &self.codebooks[segment][idx as usize];
319 vector.extend_from_slice(centroid);
320 }
321
322 vector
323 }
324
325 #[must_use]
336 pub fn compute_distance_table(&self, query: &[f32]) -> DistanceTable {
337 debug_assert_eq!(query.len(), self.config.dimension);
338
339 let subspace_dim = self.config.subspace_dimension();
340 let mut table = Vec::with_capacity(self.config.num_segments);
341
342 for (segment, codebook) in self.codebooks.iter().enumerate() {
343 let start = segment * subspace_dim;
344 let end = start + subspace_dim;
345 let subvector = &query[start..end];
346
347 let mut segment_distances = Vec::with_capacity(codebook.len());
348 for centroid in codebook {
349 let dist = self.subspace_distance(subvector, centroid);
350 segment_distances.push(dist);
351 }
352
353 table.push(segment_distances);
354 }
355
356 DistanceTable { table, metric: self.config.distance_metric }
357 }
358
359 #[must_use]
369 #[inline]
370 pub fn asymmetric_distance(&self, table: &DistanceTable, code: &PQCode) -> f32 {
371 let mut total = 0.0f32;
372
373 for (segment, &idx) in code.as_slice().iter().enumerate() {
374 total += table.table[segment][idx as usize];
375 }
376
377 match self.config.distance_metric {
380 DistanceMetric::Euclidean => total.sqrt(),
381 _ => total,
382 }
383 }
384
385 #[must_use]
390 #[inline]
391 pub fn asymmetric_distance_squared(&self, table: &DistanceTable, code: &PQCode) -> f32 {
392 let mut total = 0.0f32;
393
394 for (segment, &idx) in code.as_slice().iter().enumerate() {
395 total += table.table[segment][idx as usize];
396 }
397
398 total
399 }
400
401 #[must_use]
406 pub fn symmetric_distance(&self, code_a: &PQCode, code_b: &PQCode) -> f32 {
407 let mut total = 0.0f32;
408
409 for segment in 0..self.config.num_segments {
410 let idx_a = code_a.as_slice()[segment] as usize;
411 let idx_b = code_b.as_slice()[segment] as usize;
412
413 let centroid_a = &self.codebooks[segment][idx_a];
414 let centroid_b = &self.codebooks[segment][idx_b];
415
416 total += self.subspace_distance(centroid_a, centroid_b);
417 }
418
419 match self.config.distance_metric {
420 DistanceMetric::Euclidean => total.sqrt(),
421 _ => total,
422 }
423 }
424
425 #[inline]
427 fn subspace_distance(&self, a: &[f32], b: &[f32]) -> f32 {
428 match self.config.distance_metric {
429 DistanceMetric::Euclidean => {
430 a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
432 }
433 DistanceMetric::Cosine => {
434 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
435 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
436 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
437 if norm_a == 0.0 || norm_b == 0.0 {
438 1.0
439 } else {
440 1.0 - (dot / (norm_a * norm_b))
441 }
442 }
443 DistanceMetric::DotProduct => {
444 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
445 -dot
446 }
447 DistanceMetric::Manhattan => a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
448 DistanceMetric::Chebyshev => {
449 a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max)
450 }
451 }
452 }
453
454 #[must_use]
456 pub fn to_bytes(&self) -> Vec<u8> {
457 let mut bytes = Vec::new();
458
459 bytes.push(1u8);
461
462 bytes.extend_from_slice(&(self.config.dimension as u32).to_le_bytes());
464 bytes.extend_from_slice(&(self.config.num_segments as u32).to_le_bytes());
465 bytes.extend_from_slice(&(self.config.num_centroids as u32).to_le_bytes());
466 bytes.push(match self.config.distance_metric {
467 DistanceMetric::Euclidean => 0,
468 DistanceMetric::Cosine => 1,
469 DistanceMetric::DotProduct => 2,
470 DistanceMetric::Manhattan => 3,
471 DistanceMetric::Chebyshev => 4,
472 });
473
474 for codebook in &self.codebooks {
476 for centroid in codebook {
477 for &val in centroid {
478 bytes.extend_from_slice(&val.to_le_bytes());
479 }
480 }
481 }
482
483 bytes
484 }
485
486 pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
492 if bytes.len() < 14 {
493 return Err(VectorError::Encoding("PQ bytes too short".to_string()));
494 }
495
496 let version = bytes[0];
497 if version != 1 {
498 return Err(VectorError::Encoding(format!("unsupported PQ version: {}", version)));
499 }
500
501 let dimension = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
502 let num_segments = u32::from_le_bytes([bytes[5], bytes[6], bytes[7], bytes[8]]) as usize;
503 let num_centroids =
504 u32::from_le_bytes([bytes[9], bytes[10], bytes[11], bytes[12]]) as usize;
505 let distance_metric = match bytes[13] {
506 0 => DistanceMetric::Euclidean,
507 1 => DistanceMetric::Cosine,
508 2 => DistanceMetric::DotProduct,
509 3 => DistanceMetric::Manhattan,
510 4 => DistanceMetric::Chebyshev,
511 m => return Err(VectorError::Encoding(format!("unknown metric: {}", m))),
512 };
513
514 let config = PQConfig::new(dimension, num_segments)
515 .with_num_centroids(num_centroids)
516 .with_distance_metric(distance_metric);
517
518 let subspace_dim = dimension / num_segments;
519 let codebook_size = num_centroids * subspace_dim * 4; let expected_size = 14 + num_segments * codebook_size;
521
522 if bytes.len() < expected_size {
523 return Err(VectorError::Encoding(format!(
524 "PQ bytes too short: expected {}, got {}",
525 expected_size,
526 bytes.len()
527 )));
528 }
529
530 let mut offset = 14;
531 let mut codebooks = Vec::with_capacity(num_segments);
532
533 for _ in 0..num_segments {
534 let mut codebook = Vec::with_capacity(num_centroids);
535 for _ in 0..num_centroids {
536 let mut centroid = Vec::with_capacity(subspace_dim);
537 for _ in 0..subspace_dim {
538 let val = f32::from_le_bytes([
539 bytes[offset],
540 bytes[offset + 1],
541 bytes[offset + 2],
542 bytes[offset + 3],
543 ]);
544 centroid.push(val);
545 offset += 4;
546 }
547 codebook.push(centroid);
548 }
549 codebooks.push(codebook);
550 }
551
552 Self::from_codebooks(&config, codebooks)
553 }
554}
555
556#[derive(Debug, Clone)]
560pub struct DistanceTable {
561 table: Vec<Vec<f32>>,
563 metric: DistanceMetric,
565}
566
567impl DistanceTable {
568 #[must_use]
570 #[inline]
571 pub fn get(&self, segment: usize, centroid_idx: usize) -> f32 {
572 self.table[segment][centroid_idx]
573 }
574
575 #[must_use]
577 pub fn num_segments(&self) -> usize {
578 self.table.len()
579 }
580
581 #[must_use]
583 pub fn num_centroids(&self) -> usize {
584 self.table.first().map_or(0, Vec::len)
585 }
586
587 #[must_use]
589 pub fn metric(&self) -> DistanceMetric {
590 self.metric
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic xorshift64 generator: `n` vectors of `dim` values in [-1, 1].
    fn generate_random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut state = seed;
        let mut next = move || {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            (state as f64 / u64::MAX as f64) as f32 * 2.0 - 1.0
        };
        (0..n).map(|_| (0..dim).map(|_| next()).collect()).collect()
    }

    /// Borrows each owned vector as a slice, as `ProductQuantizer::train` expects.
    fn as_refs(vectors: &[Vec<f32>]) -> Vec<&[f32]> {
        vectors.iter().map(Vec::as_slice).collect()
    }

    /// Trains a quantizer on `n` random vectors (data seed 42) with config seed 42.
    fn trained_pq(n: usize, dim: usize, segments: usize, centroids: usize) -> ProductQuantizer {
        let data = generate_random_vectors(n, dim, 42);
        let config = PQConfig::new(dim, segments).with_num_centroids(centroids).with_seed(42);
        ProductQuantizer::train(&config, &as_refs(&data)).unwrap()
    }

    #[test]
    fn test_pq_code_roundtrip() {
        let original = PQCode::new((1..=8).collect(), 8);
        let decoded = PQCode::from_bytes(&original.to_bytes(), 8, 8);
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_pq_code_4bit_roundtrip() {
        let original = PQCode::new(vec![1, 15, 8, 3], 4);
        let decoded = PQCode::from_bytes(&original.to_bytes(), 4, 4);
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_pq_train_and_encode() {
        let pq = trained_pq(100, 32, 4, 16);
        let code = pq.encode(&generate_random_vectors(1, 32, 123)[0]);

        assert_eq!(code.num_segments(), 4);
        for segment in 0..4 {
            assert!(code.get(segment).unwrap() < 16);
        }
    }

    #[test]
    fn test_pq_decode() {
        let pq = trained_pq(100, 32, 4, 16);
        let decoded = pq.decode(&pq.encode(&generate_random_vectors(1, 32, 123)[0]));
        assert_eq!(decoded.len(), 32);
    }

    #[test]
    fn test_asymmetric_distance() {
        let pq = trained_pq(200, 64, 8, 32);
        let codes: Vec<PQCode> =
            generate_random_vectors(50, 64, 100).iter().map(|v| pq.encode(v)).collect();
        let table = pq.compute_distance_table(&generate_random_vectors(1, 64, 200)[0]);

        for code in &codes {
            let d = pq.asymmetric_distance(&table, code);
            assert!(d >= 0.0, "distance should be non-negative: {}", d);
        }
    }

    #[test]
    fn test_symmetric_distance() {
        let pq = trained_pq(100, 32, 4, 16);
        let first = pq.encode(&generate_random_vectors(1, 32, 100)[0]);
        let second = pq.encode(&generate_random_vectors(1, 32, 200)[0]);

        assert!(pq.symmetric_distance(&first, &second) >= 0.0);
        assert!(pq.symmetric_distance(&first, &first) < 1e-6);
    }

    #[test]
    fn test_pq_serialization() {
        let pq = trained_pq(100, 32, 4, 16);
        let restored = ProductQuantizer::from_bytes(&pq.to_bytes()).unwrap();

        let (before, after) = (pq.config(), restored.config());
        assert_eq!(before.dimension, after.dimension);
        assert_eq!(before.num_segments, after.num_segments);
        assert_eq!(before.num_centroids, after.num_centroids);

        for (seg, (orig, rest)) in pq.codebooks().iter().zip(restored.codebooks()).enumerate() {
            for (cent, (o, r)) in orig.iter().zip(rest).enumerate() {
                for (dim, (&ov, &rv)) in o.iter().zip(r).enumerate() {
                    assert!(
                        (ov - rv).abs() < 1e-6,
                        "mismatch at seg={}, cent={}, dim={}: {} vs {}",
                        seg,
                        cent,
                        dim,
                        ov,
                        rv
                    );
                }
            }
        }
    }

    #[test]
    fn test_distance_approximation_quality() {
        let pq = trained_pq(500, 64, 8, 256);
        let query = generate_random_vectors(1, 64, 100).remove(0);
        let database = generate_random_vectors(100, 64, 200);
        let table = pq.compute_distance_table(&query);

        // Relative error of the PQ estimate, skipping near-zero true distances.
        let rel_errors: Vec<f32> = database
            .iter()
            .filter_map(|db_vec| {
                let true_dist = query
                    .iter()
                    .zip(db_vec)
                    .map(|(a, b)| (a - b) * (a - b))
                    .sum::<f32>()
                    .sqrt();
                let approx_dist = pq.asymmetric_distance(&table, &pq.encode(db_vec));
                (true_dist > 0.1).then(|| (approx_dist - true_dist).abs() / true_dist)
            })
            .collect();

        let avg_error = rel_errors.iter().sum::<f32>() / rel_errors.len() as f32;
        assert!(avg_error < 0.5, "average relative error too high: {}", avg_error);
    }
}