1use serde::{Deserialize, Serialize};
52
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
59pub enum QuantizationType {
60 #[default]
62 None,
63 Scalar,
65 Binary,
67 Product {
69 num_subvectors: usize,
71 },
72}
73
74impl QuantizationType {
75 #[must_use]
80 pub fn compression_ratio(&self, dimensions: usize) -> usize {
81 match self {
82 Self::None => 1,
83 Self::Scalar => 4, Self::Binary => 32, Self::Product { num_subvectors } => {
86 let m = (*num_subvectors).max(1);
90 (dimensions * 4) / m
91 }
92 }
93 }
94
95 #[must_use]
97 pub fn name(&self) -> &'static str {
98 match self {
99 Self::None => "none",
100 Self::Scalar => "scalar",
101 Self::Binary => "binary",
102 Self::Product { .. } => "product",
103 }
104 }
105
106 #[must_use]
108 pub fn from_str(s: &str) -> Option<Self> {
109 match s.to_lowercase().as_str() {
110 "none" | "full" | "f32" => Some(Self::None),
111 "scalar" | "sq" | "u8" | "int8" => Some(Self::Scalar),
112 "binary" | "bin" | "bit" | "1bit" => Some(Self::Binary),
113 "product" | "pq" => Some(Self::Product { num_subvectors: 8 }),
114 s if s.starts_with("pq") => {
115 s[2..]
117 .parse()
118 .ok()
119 .map(|n| Self::Product { num_subvectors: n })
120 }
121 _ => None,
122 }
123 }
124
125 #[must_use]
127 pub const fn requires_training(&self) -> bool {
128 matches!(self, Self::Scalar | Self::Product { .. })
129 }
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct ScalarQuantizer {
169 min: Vec<f32>,
171 scale: Vec<f32>,
173 inv_scale: Vec<f32>,
175 dimensions: usize,
177}
178
179impl ScalarQuantizer {
180 #[must_use]
193 pub fn train(vectors: &[&[f32]]) -> Self {
194 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
195
196 let dimensions = vectors[0].len();
197 assert!(
198 vectors.iter().all(|v| v.len() == dimensions),
199 "All training vectors must have the same dimensions"
200 );
201
202 let mut min = vec![f32::INFINITY; dimensions];
204 let mut max = vec![f32::NEG_INFINITY; dimensions];
205
206 for vec in vectors {
207 for (i, &v) in vec.iter().enumerate() {
208 min[i] = min[i].min(v);
209 max[i] = max[i].max(v);
210 }
211 }
212
213 let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
215 .iter()
216 .zip(&max)
217 .map(|(&mn, &mx)| {
218 let range = mx - mn;
219 if range.abs() < f32::EPSILON {
220 (1.0, 1.0)
222 } else {
223 (255.0 / range, range / 255.0)
224 }
225 })
226 .unzip();
227
228 Self {
229 min,
230 scale,
231 inv_scale,
232 dimensions,
233 }
234 }
235
236 #[must_use]
238 pub fn with_ranges(min: Vec<f32>, max: Vec<f32>) -> Self {
239 let dimensions = min.len();
240 assert_eq!(min.len(), max.len(), "Min and max must have same length");
241
242 let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
243 .iter()
244 .zip(&max)
245 .map(|(&mn, &mx)| {
246 let range = mx - mn;
247 if range.abs() < f32::EPSILON {
248 (1.0, 1.0)
249 } else {
250 (255.0 / range, range / 255.0)
251 }
252 })
253 .unzip();
254
255 Self {
256 min,
257 scale,
258 inv_scale,
259 dimensions,
260 }
261 }
262
263 #[must_use]
265 pub fn dimensions(&self) -> usize {
266 self.dimensions
267 }
268
269 #[must_use]
271 pub fn min_values(&self) -> &[f32] {
272 &self.min
273 }
274
275 #[must_use]
279 pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
280 debug_assert_eq!(
281 vector.len(),
282 self.dimensions,
283 "Vector dimension mismatch: expected {}, got {}",
284 self.dimensions,
285 vector.len()
286 );
287
288 vector
289 .iter()
290 .enumerate()
291 .map(|(i, &v)| {
292 let normalized = (v - self.min[i]) * self.scale[i];
293 normalized.clamp(0.0, 255.0) as u8
294 })
295 .collect()
296 }
297
298 #[must_use]
300 pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
301 vectors.iter().map(|v| self.quantize(v)).collect()
302 }
303
304 #[must_use]
306 pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
307 debug_assert_eq!(quantized.len(), self.dimensions);
308
309 quantized
310 .iter()
311 .enumerate()
312 .map(|(i, &q)| self.min[i] + (q as f32) * self.inv_scale[i])
313 .collect()
314 }
315
316 #[must_use]
321 pub fn distance_squared_u8(&self, a: &[u8], b: &[u8]) -> f32 {
322 debug_assert_eq!(a.len(), self.dimensions);
323 debug_assert_eq!(b.len(), self.dimensions);
324
325 let mut sum = 0.0f32;
327 for i in 0..a.len() {
328 let diff = (a[i] as f32) - (b[i] as f32);
329 sum += diff * diff * self.inv_scale[i] * self.inv_scale[i];
330 }
331 sum
332 }
333
334 #[must_use]
336 #[inline]
337 pub fn distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
338 self.distance_squared_u8(a, b).sqrt()
339 }
340
341 #[must_use]
345 pub fn cosine_distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
346 debug_assert_eq!(a.len(), self.dimensions);
347 debug_assert_eq!(b.len(), self.dimensions);
348
349 let mut dot = 0.0f32;
350 let mut norm_a = 0.0f32;
351 let mut norm_b = 0.0f32;
352
353 for i in 0..a.len() {
354 let va = self.min[i] + (a[i] as f32) * self.inv_scale[i];
356 let vb = self.min[i] + (b[i] as f32) * self.inv_scale[i];
357
358 dot += va * vb;
359 norm_a += va * va;
360 norm_b += vb * vb;
361 }
362
363 let denom = (norm_a * norm_b).sqrt();
364 if denom < f32::EPSILON {
365 1.0 } else {
367 1.0 - (dot / denom)
368 }
369 }
370
371 #[must_use]
375 pub fn asymmetric_distance_squared(&self, query: &[f32], quantized: &[u8]) -> f32 {
376 debug_assert_eq!(query.len(), self.dimensions);
377 debug_assert_eq!(quantized.len(), self.dimensions);
378
379 let mut sum = 0.0f32;
380 for i in 0..query.len() {
381 let dequant = self.min[i] + (quantized[i] as f32) * self.inv_scale[i];
383 let diff = query[i] - dequant;
384 sum += diff * diff;
385 }
386 sum
387 }
388
389 #[must_use]
391 #[inline]
392 pub fn asymmetric_distance(&self, query: &[f32], quantized: &[u8]) -> f32 {
393 self.asymmetric_distance_squared(query, quantized).sqrt()
394 }
395}
396
/// Stateless 1-bit sign quantizer: each dimension becomes a single bit
/// (set iff the value is non-negative), packed into `u64` words.
pub struct BinaryQuantizer;

impl BinaryQuantizer {
    /// Packs a float vector into sign bits.
    ///
    /// Dimension `i` maps to bit `i % 64` of word `i / 64`; a bit is set
    /// iff `vector[i] >= 0.0` (zero counts as positive).
    #[must_use]
    pub fn quantize(vector: &[f32]) -> Vec<u64> {
        // Reuse `words_needed` so the sizing rule lives in one place.
        let mut result = vec![0u64; Self::words_needed(vector.len())];

        for (i, &v) in vector.iter().enumerate() {
            if v >= 0.0 {
                result[i / 64] |= 1u64 << (i % 64);
            }
        }

        result
    }

    /// Quantizes a batch of vectors.
    #[must_use]
    pub fn quantize_batch(vectors: &[&[f32]]) -> Vec<Vec<u64>> {
        vectors.iter().map(|v| Self::quantize(v)).collect()
    }

    /// Number of differing bits between two packed bit vectors.
    #[must_use]
    pub fn hamming_distance(a: &[u64], b: &[u64]) -> u32 {
        debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");

        a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
    }

    /// Hamming distance divided by the dimension count, in `[0, 1]`.
    #[must_use]
    pub fn hamming_distance_normalized(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        hamming as f32 / dimensions as f32
    }

    /// Rough Euclidean distance estimate from the fraction of differing
    /// sign bits; meaningful for (approximately) unit-normalized inputs.
    #[must_use]
    pub fn approximate_euclidean(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        (2.0 * hamming as f32 / dimensions as f32).sqrt()
    }

    /// `u64` words required to store `dimensions` sign bits.
    #[must_use]
    pub const fn words_needed(dimensions: usize) -> usize {
        dimensions.div_ceil(64)
    }

    /// Bytes required to store `dimensions` sign bits.
    #[must_use]
    pub const fn bytes_needed(dimensions: usize) -> usize {
        Self::words_needed(dimensions) * 8
    }
}
490
491#[derive(Debug, Clone, Serialize, Deserialize)]
532pub struct ProductQuantizer {
533 num_subvectors: usize,
535 num_centroids: usize,
537 subvector_dim: usize,
539 dimensions: usize,
541 centroids: Vec<f32>,
543}
544
545impl ProductQuantizer {
546 #[must_use]
560 pub fn train(
561 vectors: &[&[f32]],
562 num_subvectors: usize,
563 num_centroids: usize,
564 iterations: usize,
565 ) -> Self {
566 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
567 assert!(
568 num_centroids <= 256,
569 "num_centroids must be <= 256 for u8 codes"
570 );
571 assert!(num_subvectors > 0, "num_subvectors must be > 0");
572
573 let dimensions = vectors[0].len();
574 assert!(
575 dimensions.is_multiple_of(num_subvectors),
576 "dimensions ({dimensions}) must be divisible by num_subvectors ({num_subvectors})"
577 );
578 assert!(
579 vectors.iter().all(|v| v.len() == dimensions),
580 "All training vectors must have the same dimensions"
581 );
582
583 let subvector_dim = dimensions / num_subvectors;
584
585 let mut centroids = Vec::with_capacity(num_subvectors * num_centroids * subvector_dim);
587
588 for m in 0..num_subvectors {
589 let subvectors: Vec<Vec<f32>> = vectors
591 .iter()
592 .map(|v| {
593 let start = m * subvector_dim;
594 let end = start + subvector_dim;
595 v[start..end].to_vec()
596 })
597 .collect();
598
599 let partition_centroids =
601 Self::kmeans(&subvectors, num_centroids, subvector_dim, iterations);
602
603 centroids.extend(partition_centroids);
604 }
605
606 Self {
607 num_subvectors,
608 num_centroids,
609 subvector_dim,
610 dimensions,
611 centroids,
612 }
613 }
614
615 fn kmeans(vectors: &[Vec<f32>], k: usize, dims: usize, iterations: usize) -> Vec<f32> {
617 let n = vectors.len();
618
619 let actual_k = k.min(n);
621 let mut centroids: Vec<f32> = if actual_k == n {
622 vectors.iter().flat_map(|v| v.iter().copied()).collect()
623 } else {
624 let step = n / actual_k;
626 (0..actual_k)
627 .flat_map(|i| vectors[i * step].iter().copied())
628 .collect()
629 };
630
631 if actual_k < k {
633 centroids.resize(k * dims, 0.0);
634 }
635
636 let mut assignments = vec![0usize; n];
637 let mut counts = vec![0usize; k];
638
639 for _ in 0..iterations {
640 for (i, vec) in vectors.iter().enumerate() {
642 let mut best_dist = f32::INFINITY;
643 let mut best_k = 0;
644
645 for j in 0..k {
646 let centroid_start = j * dims;
647 let dist: f32 = vec
648 .iter()
649 .enumerate()
650 .map(|(d, &v)| {
651 let diff = v - centroids[centroid_start + d];
652 diff * diff
653 })
654 .sum();
655
656 if dist < best_dist {
657 best_dist = dist;
658 best_k = j;
659 }
660 }
661
662 assignments[i] = best_k;
663 }
664
665 centroids.fill(0.0);
667 counts.fill(0);
668
669 for (i, vec) in vectors.iter().enumerate() {
670 let k_idx = assignments[i];
671 let centroid_start = k_idx * dims;
672 counts[k_idx] += 1;
673
674 for (d, &v) in vec.iter().enumerate() {
675 centroids[centroid_start + d] += v;
676 }
677 }
678
679 for j in 0..k {
681 if counts[j] > 0 {
682 let centroid_start = j * dims;
683 let count = counts[j] as f32;
684 for d in 0..dims {
685 centroids[centroid_start + d] /= count;
686 }
687 }
688 }
689 }
690
691 centroids
692 }
693
694 #[must_use]
696 pub fn with_centroids(
697 num_subvectors: usize,
698 num_centroids: usize,
699 dimensions: usize,
700 centroids: Vec<f32>,
701 ) -> Self {
702 let subvector_dim = dimensions / num_subvectors;
703 assert_eq!(
704 centroids.len(),
705 num_subvectors * num_centroids * subvector_dim,
706 "Invalid centroid count"
707 );
708
709 Self {
710 num_subvectors,
711 num_centroids,
712 subvector_dim,
713 dimensions,
714 centroids,
715 }
716 }
717
718 #[must_use]
720 pub fn num_subvectors(&self) -> usize {
721 self.num_subvectors
722 }
723
724 #[must_use]
726 pub fn num_centroids(&self) -> usize {
727 self.num_centroids
728 }
729
730 #[must_use]
732 pub fn dimensions(&self) -> usize {
733 self.dimensions
734 }
735
736 #[must_use]
738 pub fn subvector_dim(&self) -> usize {
739 self.subvector_dim
740 }
741
742 #[must_use]
744 pub fn code_size(&self) -> usize {
745 self.num_subvectors }
747
748 #[must_use]
750 pub fn compression_ratio(&self) -> usize {
751 (self.dimensions * 4) / self.num_subvectors
754 }
755
756 #[must_use]
760 pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
761 debug_assert_eq!(
762 vector.len(),
763 self.dimensions,
764 "Vector dimension mismatch: expected {}, got {}",
765 self.dimensions,
766 vector.len()
767 );
768
769 let mut codes = Vec::with_capacity(self.num_subvectors);
770
771 for m in 0..self.num_subvectors {
772 let subvec_start = m * self.subvector_dim;
773 let subvec = &vector[subvec_start..subvec_start + self.subvector_dim];
774
775 let mut best_dist = f32::INFINITY;
777 let mut best_k = 0u8;
778
779 for k in 0..self.num_centroids {
780 let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
781 let dist: f32 = subvec
782 .iter()
783 .enumerate()
784 .map(|(d, &v)| {
785 let diff = v - self.centroids[centroid_start + d];
786 diff * diff
787 })
788 .sum();
789
790 if dist < best_dist {
791 best_dist = dist;
792 best_k = k as u8;
793 }
794 }
795
796 codes.push(best_k);
797 }
798
799 codes
800 }
801
802 #[must_use]
804 pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
805 vectors.iter().map(|v| self.quantize(v)).collect()
806 }
807
808 #[must_use]
814 pub fn build_distance_table(&self, query: &[f32]) -> Vec<f32> {
815 debug_assert_eq!(query.len(), self.dimensions);
816
817 let mut table = Vec::with_capacity(self.num_subvectors * self.num_centroids);
818
819 for m in 0..self.num_subvectors {
820 let query_start = m * self.subvector_dim;
821 let query_subvec = &query[query_start..query_start + self.subvector_dim];
822
823 for k in 0..self.num_centroids {
824 let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
825
826 let dist: f32 = query_subvec
827 .iter()
828 .enumerate()
829 .map(|(d, &v)| {
830 let diff = v - self.centroids[centroid_start + d];
831 diff * diff
832 })
833 .sum();
834
835 table.push(dist);
836 }
837 }
838
839 table
840 }
841
842 #[must_use]
846 #[inline]
847 pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
848 debug_assert_eq!(codes.len(), self.num_subvectors);
849 debug_assert_eq!(table.len(), self.num_subvectors * self.num_centroids);
850
851 codes
852 .iter()
853 .enumerate()
854 .map(|(m, &code)| table[m * self.num_centroids + code as usize])
855 .sum()
856 }
857
858 #[must_use]
863 pub fn asymmetric_distance_squared(&self, query: &[f32], codes: &[u8]) -> f32 {
864 let table = self.build_distance_table(query);
865 self.distance_with_table(&table, codes)
866 }
867
868 #[must_use]
870 #[inline]
871 pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
872 self.asymmetric_distance_squared(query, codes).sqrt()
873 }
874
875 #[must_use]
879 pub fn reconstruct(&self, codes: &[u8]) -> Vec<f32> {
880 debug_assert_eq!(codes.len(), self.num_subvectors);
881
882 let mut result = Vec::with_capacity(self.dimensions);
883
884 for (m, &code) in codes.iter().enumerate() {
885 let centroid_start = (m * self.num_centroids + code as usize) * self.subvector_dim;
886 result.extend_from_slice(
887 &self.centroids[centroid_start..centroid_start + self.subvector_dim],
888 );
889 }
890
891 result
892 }
893
894 #[must_use]
896 pub fn get_partition_centroids(&self, partition: usize) -> Vec<&[f32]> {
897 assert!(partition < self.num_subvectors);
898
899 (0..self.num_centroids)
900 .map(|k| {
901 let start = (partition * self.num_centroids + k) * self.subvector_dim;
902 &self.centroids[start..start + self.subvector_dim]
903 })
904 .collect()
905 }
906}
907
/// Hamming distance between two packed bit vectors.
///
/// Uses `u64::count_ones`, which LLVM lowers to the native `popcnt`
/// instruction when the compilation target supports it, and to an efficient
/// bit-twiddling fallback otherwise — so one portable definition covers all
/// architectures. (The previous x86_64 path called the `_popcnt64`
/// intrinsic inside `unsafe` without enabling the `popcnt` target feature,
/// which is undefined behavior on baseline x86-64 builds; the safe form is
/// equally fast and sound everywhere.)
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");

    a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
}
940
#[cfg(test)]
mod tests {
    use super::*;

    // ----- QuantizationType -----

    #[test]
    fn test_quantization_type_compression_ratio() {
        let dims = 384;
        assert_eq!(QuantizationType::None.compression_ratio(dims), 1);
        assert_eq!(QuantizationType::Scalar.compression_ratio(dims), 4);
        assert_eq!(QuantizationType::Binary.compression_ratio(dims), 32);

        // PQ ratio = dims * 4 bytes / num_subvectors code bytes.
        let pq8 = QuantizationType::Product { num_subvectors: 8 };
        assert_eq!(pq8.compression_ratio(dims), 192);

        let pq16 = QuantizationType::Product { num_subvectors: 16 };
        assert_eq!(pq16.compression_ratio(dims), 96);
    }

    #[test]
    fn test_quantization_type_from_str() {
        // Parsing is case-insensitive and accepts several aliases per variant.
        assert_eq!(
            QuantizationType::from_str("none"),
            Some(QuantizationType::None)
        );
        assert_eq!(
            QuantizationType::from_str("scalar"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("SQ"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("binary"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(
            QuantizationType::from_str("bit"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(QuantizationType::from_str("invalid"), None);
    }

    // ----- ScalarQuantizer -----

    #[test]
    fn test_scalar_quantizer_train() {
        let vectors = [
            vec![0.0f32, 0.5, 1.0],
            vec![0.2, 0.3, 0.8],
            vec![0.1, 0.6, 0.9],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let quantizer = ScalarQuantizer::train(&refs);

        // Training learns the per-dimension minimum over all vectors.
        assert_eq!(quantizer.dimensions(), 3);
        assert_eq!(quantizer.min_values()[0], 0.0);
        assert_eq!(quantizer.min_values()[1], 0.3);
        assert_eq!(quantizer.min_values()[2], 0.8);
    }

    #[test]
    fn test_scalar_quantizer_quantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        // Range endpoints map to the extreme codes.
        let q_min = quantizer.quantize(&[0.0, 0.0]);
        assert_eq!(q_min, vec![0, 0]);

        let q_max = quantizer.quantize(&[1.0, 1.0]);
        assert_eq!(q_max, vec![255, 255]);

        // Midpoint lands near 127/128 (float-to-u8 cast truncates).
        let q_mid = quantizer.quantize(&[0.5, 0.5]);
        assert!(q_mid[0] >= 126 && q_mid[0] <= 128);
    }

    #[test]
    fn test_scalar_quantizer_dequantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0], vec![1.0]);

        // Round-trip error is bounded by one code step (1/255 ≈ 0.004).
        let original = [0.5f32];
        let quantized = quantizer.quantize(&original);
        let dequantized = quantizer.dequantize(&quantized);

        assert!((original[0] - dequantized[0]).abs() < 0.01);
    }

    #[test]
    fn test_scalar_quantizer_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let a = quantizer.quantize(&[0.0, 0.0]);
        let b = quantizer.quantize(&[1.0, 0.0]);

        // Symmetric distance on codes should approximate the true distance 1.0.
        let dist = quantizer.distance_u8(&a, &b);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_asymmetric_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        // Full-precision query vs quantized stored vector.
        let query = [0.0f32, 0.0];
        let stored = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.asymmetric_distance(&query, &stored);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_cosine_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![-1.0, -1.0], vec![1.0, 1.0]);

        // Orthogonal unit vectors: cosine distance should be ~1.0.
        let a = quantizer.quantize(&[1.0, 0.0]);
        let b = quantizer.quantize(&[0.0, 1.0]);

        let dist = quantizer.cosine_distance_u8(&a, &b);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "Cannot train on empty vector set")]
    fn test_scalar_quantizer_empty_training() {
        let vectors: Vec<&[f32]> = vec![];
        let _ = ScalarQuantizer::train(&vectors);
    }

    // ----- BinaryQuantizer -----

    #[test]
    fn test_binary_quantizer_quantize() {
        // Signs: +, -, 0 (counts as positive), + => bits 0b1101.
        let v = vec![0.5f32, -0.3, 0.0, 0.8];
        let bits = BinaryQuantizer::quantize(&v);

        assert_eq!(bits.len(), 1);
        assert_eq!(bits[0] & 0xF, 0b1101);
    }

    #[test]
    fn test_binary_quantizer_hamming_distance() {
        // Vectors differ in sign at positions 1 and 3.
        let v1 = vec![1.0f32, 1.0, 1.0, 1.0];
        let v2 = vec![1.0f32, -1.0, 1.0, -1.0];
        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 2);
    }

    #[test]
    fn test_binary_quantizer_identical_vectors() {
        let v = vec![0.1f32, -0.2, 0.3, -0.4, 0.5];
        let bits = BinaryQuantizer::quantize(&v);

        let dist = BinaryQuantizer::hamming_distance(&bits, &bits);
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_binary_quantizer_opposite_vectors() {
        // All 64 sign bits flip between the two vectors.
        let v1 = vec![1.0f32; 64];
        let v2 = vec![-1.0f32; 64];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 64);
    }

    #[test]
    fn test_binary_quantizer_large_vector() {
        // 1000 bits need ceil(1000 / 64) = 16 words.
        let v: Vec<f32> = (0..1000)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let bits = BinaryQuantizer::quantize(&v);

        assert_eq!(bits.len(), 16);
    }

    #[test]
    fn test_binary_quantizer_normalized_distance() {
        // Fully opposite vectors normalize to distance 1.0.
        let v1 = vec![1.0f32; 100];
        let v2 = vec![-1.0f32; 100];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let norm_dist = BinaryQuantizer::hamming_distance_normalized(&bits1, &bits2, 100);
        assert!((norm_dist - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_binary_quantizer_words_needed() {
        assert_eq!(BinaryQuantizer::words_needed(1), 1);
        assert_eq!(BinaryQuantizer::words_needed(64), 1);
        assert_eq!(BinaryQuantizer::words_needed(65), 2);
        assert_eq!(BinaryQuantizer::words_needed(128), 2);
        assert_eq!(BinaryQuantizer::words_needed(1536), 24);
    }

    #[test]
    fn test_binary_quantizer_bytes_needed() {
        assert_eq!(BinaryQuantizer::bytes_needed(64), 8);
        assert_eq!(BinaryQuantizer::bytes_needed(128), 16);
        assert_eq!(BinaryQuantizer::bytes_needed(1536), 192);
    }

    #[test]
    fn test_hamming_distance_simd() {
        // Both words fully differ: 2 * 64 = 128 bits.
        let a = vec![0xFFFF_FFFF_FFFF_FFFFu64, 0x0000_0000_0000_0000];
        let b = vec![0x0000_0000_0000_0000u64, 0xFFFF_FFFF_FFFF_FFFF];

        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 128);
    }

    #[test]
    fn test_quantization_type_product_from_str() {
        // Bare "pq"/"product" default to 8 subvectors.
        assert_eq!(
            QuantizationType::from_str("pq"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("product"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );

        // "pq<N>" selects an explicit subvector count.
        assert_eq!(
            QuantizationType::from_str("pq8"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("pq16"),
            Some(QuantizationType::Product { num_subvectors: 16 })
        );
        assert_eq!(
            QuantizationType::from_str("pq32"),
            Some(QuantizationType::Product { num_subvectors: 32 })
        );
    }

    #[test]
    fn test_quantization_type_requires_training() {
        assert!(!QuantizationType::None.requires_training());
        assert!(QuantizationType::Scalar.requires_training());
        assert!(!QuantizationType::Binary.requires_training());
        assert!(QuantizationType::Product { num_subvectors: 8 }.requires_training());
    }

    // ----- ProductQuantizer -----

    #[test]
    fn test_product_quantizer_train() {
        // 100 synthetic 16-dim vectors split into 4 partitions of 4 dims.
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..16).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 4, 8, 5);

        assert_eq!(pq.num_subvectors(), 4);
        assert_eq!(pq.num_centroids(), 8);
        assert_eq!(pq.dimensions(), 16);
        assert_eq!(pq.subvector_dim(), 4);
        assert_eq!(pq.code_size(), 4);
    }

    #[test]
    fn test_product_quantizer_quantize() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| ((i * j) as f32 * 0.1).cos()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 16, 3);

        // One code per partition, each below num_centroids.
        let codes = pq.quantize(&vectors[0]);
        assert_eq!(codes.len(), 2);

        for &code in &codes {
            assert!(code < 16);
        }
    }

    #[test]
    fn test_product_quantizer_reconstruct() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..12).map(|j| (i + j) as f32 * 0.05).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 3, 8, 5);

        // Encode then decode and bound the L2 reconstruction error.
        let original = &vectors[10];
        let codes = pq.quantize(original);
        let reconstructed = pq.reconstruct(&codes);

        assert_eq!(reconstructed.len(), 12);

        let error: f32 = original
            .iter()
            .zip(&reconstructed)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        assert!(error < 2.0, "Reconstruction error too high: {error}");
    }

    #[test]
    fn test_product_quantizer_asymmetric_distance() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 8, 32, 5);

        // A vector should be closest to its own codes.
        let query = &vectors[0];
        let codes = pq.quantize(query);
        let self_dist = pq.asymmetric_distance(query, &codes);
        assert!(self_dist < 1.0, "Self-distance too high: {self_dist}");

        let other_codes = pq.quantize(&vectors[50]);
        let other_dist = pq.asymmetric_distance(query, &other_codes);
        assert!(other_dist > self_dist, "Other vector should be farther");
    }

    #[test]
    fn test_product_quantizer_distance_table() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..16).map(|j| (i + j) as f32 * 0.02).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 4, 8, 3);

        let query = &vectors[0];
        let table = pq.build_distance_table(query);

        // One entry per (partition, centroid) pair.
        assert_eq!(table.len(), 4 * 8);

        // Table lookup must agree with the direct computation.
        let codes = pq.quantize(&vectors[5]);
        let dist_direct = pq.asymmetric_distance_squared(query, &codes);
        let dist_table = pq.distance_with_table(&table, &codes);

        assert!((dist_direct - dist_table).abs() < 0.001);
    }

    #[test]
    fn test_product_quantizer_batch() {
        let vectors: Vec<Vec<f32>> = (0..20)
            .map(|i| (0..8).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 2);

        let batch_codes = pq.quantize_batch(&refs[0..5]);
        assert_eq!(batch_codes.len(), 5);

        for codes in &batch_codes {
            assert_eq!(codes.len(), 2);
        }
    }

    #[test]
    fn test_product_quantizer_compression_ratio() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..384).map(|j| ((i * j) as f32).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // 384 dims * 4 bytes / 8 codes = 192x; / 48 codes = 32x.
        let pq8 = ProductQuantizer::train(&refs, 8, 256, 3);
        assert_eq!(pq8.compression_ratio(), 192);
        let pq48 = ProductQuantizer::train(&refs, 48, 256, 3);
        assert_eq!(pq48.compression_ratio(), 32);
    }

    #[test]
    #[should_panic(expected = "dimensions (15) must be divisible by num_subvectors (4)")]
    fn test_product_quantizer_invalid_dimensions() {
        // 15 dims cannot be split into 4 equal partitions.
        let vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| (0..15).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let _ = ProductQuantizer::train(&refs, 4, 8, 3);
    }

    #[test]
    fn test_product_quantizer_get_partition_centroids() {
        let vectors: Vec<Vec<f32>> = (0..30)
            .map(|i| (0..8).map(|j| (i + j) as f32 * 0.1).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 3);

        // 4 centroids per partition, each subvector_dim (= 4) floats long.
        let centroids = pq.get_partition_centroids(0);
        assert_eq!(centroids.len(), 4);
        assert_eq!(centroids[0].len(), 4);
    }
}