1use serde::{Deserialize, Serialize};
43
/// Vector quantization scheme, trading recall for memory footprint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub enum QuantizationType {
    /// No quantization: vectors stored as full-precision `f32`.
    #[default]
    None,
    /// Per-dimension scalar quantization to `u8` codes (4x compression).
    Scalar,
    /// One sign bit per dimension (32x compression).
    Binary,
    /// Product quantization: the vector is split into subvectors, each
    /// encoded as a one-byte centroid index.
    Product {
        /// Number of subvector partitions (one code byte per partition).
        num_subvectors: usize,
    },
}
64
65impl QuantizationType {
66 #[must_use]
71 pub fn compression_ratio(&self, dimensions: usize) -> usize {
72 match self {
73 Self::None => 1,
74 Self::Scalar => 4, Self::Binary => 32, Self::Product { num_subvectors } => {
77 let m = (*num_subvectors).max(1);
81 (dimensions * 4) / m
82 }
83 }
84 }
85
86 #[must_use]
88 pub fn name(&self) -> &'static str {
89 match self {
90 Self::None => "none",
91 Self::Scalar => "scalar",
92 Self::Binary => "binary",
93 Self::Product { .. } => "product",
94 }
95 }
96
97 #[must_use]
99 pub fn from_str(s: &str) -> Option<Self> {
100 match s.to_lowercase().as_str() {
101 "none" | "full" | "f32" => Some(Self::None),
102 "scalar" | "sq" | "u8" | "int8" => Some(Self::Scalar),
103 "binary" | "bin" | "bit" | "1bit" => Some(Self::Binary),
104 "product" | "pq" => Some(Self::Product { num_subvectors: 8 }),
105 s if s.starts_with("pq") => {
106 s[2..]
108 .parse()
109 .ok()
110 .map(|n| Self::Product { num_subvectors: n })
111 }
112 _ => None,
113 }
114 }
115
116 #[must_use]
118 pub const fn requires_training(&self) -> bool {
119 matches!(self, Self::Scalar | Self::Product { .. })
120 }
121}
122
/// Per-dimension scalar quantizer mapping `f32` values to `u8` codes.
///
/// Each dimension `i` is affinely mapped from `[min[i], max[i]]` onto
/// `[0, 255]`; `scale` and `inv_scale` cache the forward and inverse
/// factors so encode/decode avoid per-call divisions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
    // Per-dimension minimum (trained or configured).
    min: Vec<f32>,
    // Per-dimension forward factor: 255 / (max - min).
    scale: Vec<f32>,
    // Per-dimension inverse factor: (max - min) / 255.
    inv_scale: Vec<f32>,
    // Expected vector length; all vectors above have this length.
    dimensions: usize,
}
169
170impl ScalarQuantizer {
171 #[must_use]
184 pub fn train(vectors: &[&[f32]]) -> Self {
185 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
186
187 let dimensions = vectors[0].len();
188 assert!(
189 vectors.iter().all(|v| v.len() == dimensions),
190 "All training vectors must have the same dimensions"
191 );
192
193 let mut min = vec![f32::INFINITY; dimensions];
195 let mut max = vec![f32::NEG_INFINITY; dimensions];
196
197 for vec in vectors {
198 for (i, &v) in vec.iter().enumerate() {
199 min[i] = min[i].min(v);
200 max[i] = max[i].max(v);
201 }
202 }
203
204 let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
206 .iter()
207 .zip(&max)
208 .map(|(&mn, &mx)| {
209 let range = mx - mn;
210 if range.abs() < f32::EPSILON {
211 (1.0, 1.0)
213 } else {
214 (255.0 / range, range / 255.0)
215 }
216 })
217 .unzip();
218
219 Self {
220 min,
221 scale,
222 inv_scale,
223 dimensions,
224 }
225 }
226
227 #[must_use]
229 pub fn with_ranges(min: Vec<f32>, max: Vec<f32>) -> Self {
230 let dimensions = min.len();
231 assert_eq!(min.len(), max.len(), "Min and max must have same length");
232
233 let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
234 .iter()
235 .zip(&max)
236 .map(|(&mn, &mx)| {
237 let range = mx - mn;
238 if range.abs() < f32::EPSILON {
239 (1.0, 1.0)
240 } else {
241 (255.0 / range, range / 255.0)
242 }
243 })
244 .unzip();
245
246 Self {
247 min,
248 scale,
249 inv_scale,
250 dimensions,
251 }
252 }
253
254 #[must_use]
256 pub fn dimensions(&self) -> usize {
257 self.dimensions
258 }
259
260 #[must_use]
262 pub fn min_values(&self) -> &[f32] {
263 &self.min
264 }
265
266 #[must_use]
270 pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
271 debug_assert_eq!(
272 vector.len(),
273 self.dimensions,
274 "Vector dimension mismatch: expected {}, got {}",
275 self.dimensions,
276 vector.len()
277 );
278
279 vector
280 .iter()
281 .enumerate()
282 .map(|(i, &v)| {
283 let normalized = (v - self.min[i]) * self.scale[i];
284 normalized.clamp(0.0, 255.0) as u8
285 })
286 .collect()
287 }
288
289 #[must_use]
291 pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
292 vectors.iter().map(|v| self.quantize(v)).collect()
293 }
294
295 #[must_use]
297 pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
298 debug_assert_eq!(quantized.len(), self.dimensions);
299
300 quantized
301 .iter()
302 .enumerate()
303 .map(|(i, &q)| self.min[i] + (q as f32) * self.inv_scale[i])
304 .collect()
305 }
306
307 #[must_use]
312 pub fn distance_squared_u8(&self, a: &[u8], b: &[u8]) -> f32 {
313 debug_assert_eq!(a.len(), self.dimensions);
314 debug_assert_eq!(b.len(), self.dimensions);
315
316 let mut sum = 0.0f32;
318 for i in 0..a.len() {
319 let diff = (a[i] as f32) - (b[i] as f32);
320 sum += diff * diff * self.inv_scale[i] * self.inv_scale[i];
321 }
322 sum
323 }
324
325 #[must_use]
327 #[inline]
328 pub fn distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
329 self.distance_squared_u8(a, b).sqrt()
330 }
331
332 #[must_use]
336 pub fn cosine_distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
337 debug_assert_eq!(a.len(), self.dimensions);
338 debug_assert_eq!(b.len(), self.dimensions);
339
340 let mut dot = 0.0f32;
341 let mut norm_a = 0.0f32;
342 let mut norm_b = 0.0f32;
343
344 for i in 0..a.len() {
345 let va = self.min[i] + (a[i] as f32) * self.inv_scale[i];
347 let vb = self.min[i] + (b[i] as f32) * self.inv_scale[i];
348
349 dot += va * vb;
350 norm_a += va * va;
351 norm_b += vb * vb;
352 }
353
354 let denom = (norm_a * norm_b).sqrt();
355 if denom < f32::EPSILON {
356 1.0 } else {
358 1.0 - (dot / denom)
359 }
360 }
361
362 #[must_use]
366 pub fn asymmetric_distance_squared(&self, query: &[f32], quantized: &[u8]) -> f32 {
367 debug_assert_eq!(query.len(), self.dimensions);
368 debug_assert_eq!(quantized.len(), self.dimensions);
369
370 let mut sum = 0.0f32;
371 for i in 0..query.len() {
372 let dequant = self.min[i] + (quantized[i] as f32) * self.inv_scale[i];
374 let diff = query[i] - dequant;
375 sum += diff * diff;
376 }
377 sum
378 }
379
380 #[must_use]
382 #[inline]
383 pub fn asymmetric_distance(&self, query: &[f32], quantized: &[u8]) -> f32 {
384 self.asymmetric_distance_squared(query, quantized).sqrt()
385 }
386}
387
/// Stateless 1-bit-per-dimension quantizer: each dimension stores its
/// sign bit (set when the value is >= 0.0), packed into `u64` words.
pub struct BinaryQuantizer;

impl BinaryQuantizer {
    /// Packs `vector` into sign bits, 64 dimensions per `u64` word.
    ///
    /// Bit `i % 64` of word `i / 64` is set when `vector[i] >= 0.0`;
    /// unused high bits of the last word stay zero.
    #[must_use]
    pub fn quantize(vector: &[f32]) -> Vec<u64> {
        // div_ceil rounds up so a partial final word is allocated.
        let mut result = vec![0u64; vector.len().div_ceil(64)];

        for (i, &v) in vector.iter().enumerate() {
            if v >= 0.0 {
                result[i / 64] |= 1u64 << (i % 64);
            }
        }

        result
    }

    /// Quantizes each vector in the batch.
    #[must_use]
    pub fn quantize_batch(vectors: &[&[f32]]) -> Vec<Vec<u64>> {
        vectors.iter().map(|v| Self::quantize(v)).collect()
    }

    /// Number of differing bits between two packed binary vectors.
    #[must_use]
    pub fn hamming_distance(a: &[u64], b: &[u64]) -> u32 {
        debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");

        a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
    }

    /// Hamming distance scaled to `[0, 1]` by the original dimensionality.
    #[must_use]
    pub fn hamming_distance_normalized(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        hamming as f32 / dimensions as f32
    }

    /// Rough Euclidean-distance proxy derived from the Hamming distance:
    /// `sqrt(2 * hamming / dimensions)`.
    #[must_use]
    pub fn approximate_euclidean(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        (2.0 * hamming as f32 / dimensions as f32).sqrt()
    }

    /// `u64` words needed to store `dimensions` sign bits.
    #[must_use]
    pub const fn words_needed(dimensions: usize) -> usize {
        dimensions.div_ceil(64)
    }

    /// Bytes needed to store `dimensions` sign bits (8 per word).
    #[must_use]
    pub const fn bytes_needed(dimensions: usize) -> usize {
        Self::words_needed(dimensions) * 8
    }
}
481
/// Product quantizer (PQ): splits vectors into `num_subvectors` contiguous
/// chunks and encodes each chunk as the index of its nearest centroid,
/// yielding one `u8` code per subvector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    // Number of subvector partitions (= encoded bytes per vector).
    num_subvectors: usize,
    // Centroids per partition (<= 256 so codes fit in u8).
    num_centroids: usize,
    // Dimensions per subvector (= dimensions / num_subvectors).
    subvector_dim: usize,
    // Full input vector dimensionality.
    dimensions: usize,
    // Flattened codebook, laid out so dimension `d` of centroid `k` in
    // partition `m` is centroids[(m * num_centroids + k) * subvector_dim + d].
    centroids: Vec<f32>,
}
535
536impl ProductQuantizer {
537 #[must_use]
551 pub fn train(
552 vectors: &[&[f32]],
553 num_subvectors: usize,
554 num_centroids: usize,
555 iterations: usize,
556 ) -> Self {
557 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
558 assert!(
559 num_centroids <= 256,
560 "num_centroids must be <= 256 for u8 codes"
561 );
562 assert!(num_subvectors > 0, "num_subvectors must be > 0");
563
564 let dimensions = vectors[0].len();
565 assert!(
566 dimensions.is_multiple_of(num_subvectors),
567 "dimensions ({dimensions}) must be divisible by num_subvectors ({num_subvectors})"
568 );
569 assert!(
570 vectors.iter().all(|v| v.len() == dimensions),
571 "All training vectors must have the same dimensions"
572 );
573
574 let subvector_dim = dimensions / num_subvectors;
575
576 let mut centroids = Vec::with_capacity(num_subvectors * num_centroids * subvector_dim);
578
579 for m in 0..num_subvectors {
580 let subvectors: Vec<Vec<f32>> = vectors
582 .iter()
583 .map(|v| {
584 let start = m * subvector_dim;
585 let end = start + subvector_dim;
586 v[start..end].to_vec()
587 })
588 .collect();
589
590 let partition_centroids =
592 Self::kmeans(&subvectors, num_centroids, subvector_dim, iterations);
593
594 centroids.extend(partition_centroids);
595 }
596
597 Self {
598 num_subvectors,
599 num_centroids,
600 subvector_dim,
601 dimensions,
602 centroids,
603 }
604 }
605
606 fn kmeans(vectors: &[Vec<f32>], k: usize, dims: usize, iterations: usize) -> Vec<f32> {
608 let n = vectors.len();
609
610 let actual_k = k.min(n);
612 let mut centroids: Vec<f32> = if actual_k == n {
613 vectors.iter().flat_map(|v| v.iter().copied()).collect()
614 } else {
615 let step = n / actual_k;
617 (0..actual_k)
618 .flat_map(|i| vectors[i * step].iter().copied())
619 .collect()
620 };
621
622 if actual_k < k {
624 centroids.resize(k * dims, 0.0);
625 }
626
627 let mut assignments = vec![0usize; n];
628 let mut counts = vec![0usize; k];
629
630 for _ in 0..iterations {
631 for (i, vec) in vectors.iter().enumerate() {
633 let mut best_dist = f32::INFINITY;
634 let mut best_k = 0;
635
636 for j in 0..k {
637 let centroid_start = j * dims;
638 let dist: f32 = vec
639 .iter()
640 .enumerate()
641 .map(|(d, &v)| {
642 let diff = v - centroids[centroid_start + d];
643 diff * diff
644 })
645 .sum();
646
647 if dist < best_dist {
648 best_dist = dist;
649 best_k = j;
650 }
651 }
652
653 assignments[i] = best_k;
654 }
655
656 centroids.fill(0.0);
658 counts.fill(0);
659
660 for (i, vec) in vectors.iter().enumerate() {
661 let k_idx = assignments[i];
662 let centroid_start = k_idx * dims;
663 counts[k_idx] += 1;
664
665 for (d, &v) in vec.iter().enumerate() {
666 centroids[centroid_start + d] += v;
667 }
668 }
669
670 for j in 0..k {
672 if counts[j] > 0 {
673 let centroid_start = j * dims;
674 let count = counts[j] as f32;
675 for d in 0..dims {
676 centroids[centroid_start + d] /= count;
677 }
678 }
679 }
680 }
681
682 centroids
683 }
684
685 #[must_use]
687 pub fn with_centroids(
688 num_subvectors: usize,
689 num_centroids: usize,
690 dimensions: usize,
691 centroids: Vec<f32>,
692 ) -> Self {
693 let subvector_dim = dimensions / num_subvectors;
694 assert_eq!(
695 centroids.len(),
696 num_subvectors * num_centroids * subvector_dim,
697 "Invalid centroid count"
698 );
699
700 Self {
701 num_subvectors,
702 num_centroids,
703 subvector_dim,
704 dimensions,
705 centroids,
706 }
707 }
708
709 #[must_use]
711 pub fn num_subvectors(&self) -> usize {
712 self.num_subvectors
713 }
714
715 #[must_use]
717 pub fn num_centroids(&self) -> usize {
718 self.num_centroids
719 }
720
721 #[must_use]
723 pub fn dimensions(&self) -> usize {
724 self.dimensions
725 }
726
727 #[must_use]
729 pub fn subvector_dim(&self) -> usize {
730 self.subvector_dim
731 }
732
733 #[must_use]
735 pub fn code_size(&self) -> usize {
736 self.num_subvectors }
738
739 #[must_use]
741 pub fn compression_ratio(&self) -> usize {
742 (self.dimensions * 4) / self.num_subvectors
745 }
746
747 #[must_use]
751 pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
752 debug_assert_eq!(
753 vector.len(),
754 self.dimensions,
755 "Vector dimension mismatch: expected {}, got {}",
756 self.dimensions,
757 vector.len()
758 );
759
760 let mut codes = Vec::with_capacity(self.num_subvectors);
761
762 for m in 0..self.num_subvectors {
763 let subvec_start = m * self.subvector_dim;
764 let subvec = &vector[subvec_start..subvec_start + self.subvector_dim];
765
766 let mut best_dist = f32::INFINITY;
768 let mut best_k = 0u8;
769
770 for k in 0..self.num_centroids {
771 let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
772 let dist: f32 = subvec
773 .iter()
774 .enumerate()
775 .map(|(d, &v)| {
776 let diff = v - self.centroids[centroid_start + d];
777 diff * diff
778 })
779 .sum();
780
781 if dist < best_dist {
782 best_dist = dist;
783 best_k = k as u8;
784 }
785 }
786
787 codes.push(best_k);
788 }
789
790 codes
791 }
792
793 #[must_use]
795 pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
796 vectors.iter().map(|v| self.quantize(v)).collect()
797 }
798
799 #[must_use]
805 pub fn build_distance_table(&self, query: &[f32]) -> Vec<f32> {
806 debug_assert_eq!(query.len(), self.dimensions);
807
808 let mut table = Vec::with_capacity(self.num_subvectors * self.num_centroids);
809
810 for m in 0..self.num_subvectors {
811 let query_start = m * self.subvector_dim;
812 let query_subvec = &query[query_start..query_start + self.subvector_dim];
813
814 for k in 0..self.num_centroids {
815 let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
816
817 let dist: f32 = query_subvec
818 .iter()
819 .enumerate()
820 .map(|(d, &v)| {
821 let diff = v - self.centroids[centroid_start + d];
822 diff * diff
823 })
824 .sum();
825
826 table.push(dist);
827 }
828 }
829
830 table
831 }
832
833 #[must_use]
837 #[inline]
838 pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
839 debug_assert_eq!(codes.len(), self.num_subvectors);
840 debug_assert_eq!(table.len(), self.num_subvectors * self.num_centroids);
841
842 codes
843 .iter()
844 .enumerate()
845 .map(|(m, &code)| table[m * self.num_centroids + code as usize])
846 .sum()
847 }
848
849 #[must_use]
854 pub fn asymmetric_distance_squared(&self, query: &[f32], codes: &[u8]) -> f32 {
855 let table = self.build_distance_table(query);
856 self.distance_with_table(&table, codes)
857 }
858
859 #[must_use]
861 #[inline]
862 pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
863 self.asymmetric_distance_squared(query, codes).sqrt()
864 }
865
866 #[must_use]
870 pub fn reconstruct(&self, codes: &[u8]) -> Vec<f32> {
871 debug_assert_eq!(codes.len(), self.num_subvectors);
872
873 let mut result = Vec::with_capacity(self.dimensions);
874
875 for (m, &code) in codes.iter().enumerate() {
876 let centroid_start = (m * self.num_centroids + code as usize) * self.subvector_dim;
877 result.extend_from_slice(
878 &self.centroids[centroid_start..centroid_start + self.subvector_dim],
879 );
880 }
881
882 result
883 }
884
885 #[must_use]
887 pub fn get_partition_centroids(&self, partition: usize) -> Vec<&[f32]> {
888 assert!(partition < self.num_subvectors);
889
890 (0..self.num_centroids)
891 .map(|k| {
892 let start = (partition * self.num_centroids + k) * self.subvector_dim;
893 &self.centroids[start..start + self.subvector_dim]
894 })
895 .collect()
896 }
897}
898
/// Hamming distance between packed binary vectors using the x86-64
/// `POPCNT` instruction when the running CPU supports it.
///
/// The previous version executed `_popcnt64` unconditionally, which is
/// undefined behavior on x86-64 CPUs lacking the `popcnt` feature; this
/// version detects the feature at runtime and otherwise falls back to the
/// portable `count_ones` path (which LLVM lowers to `POPCNT` anyway when
/// the compile target permits).
#[cfg(target_arch = "x86_64")]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");

    if std::arch::is_x86_feature_detected!("popcnt") {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| {
                let xor = x ^ y;
                // SAFETY: `popcnt` support was verified at runtime above.
                #[allow(unsafe_code)]
                let ones = unsafe { std::arch::x86_64::_popcnt64(xor as i64) as u32 };
                ones
            })
            .sum()
    } else {
        // Portable fallback, correct on every x86-64 CPU.
        a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
    }
}
924
/// Portable Hamming distance for non-x86-64 targets.
///
/// `u64::count_ones` compiles down to the native popcount instruction on
/// targets that have one, so this path costs nothing versus the
/// intrinsic-based variant.
#[cfg(not(target_arch = "x86_64"))]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");

    a.iter().zip(b).map(|(&lhs, &rhs)| (lhs ^ rhs).count_ones()).sum()
}
931
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_type_compression_ratio() {
        let dims = 384;
        assert_eq!(QuantizationType::None.compression_ratio(dims), 1);
        assert_eq!(QuantizationType::Scalar.compression_ratio(dims), 4);
        assert_eq!(QuantizationType::Binary.compression_ratio(dims), 32);

        // 384 dims * 4 bytes / 8 code bytes = 192.
        let pq8 = QuantizationType::Product { num_subvectors: 8 };
        assert_eq!(pq8.compression_ratio(dims), 192);

        // 384 dims * 4 bytes / 16 code bytes = 96.
        let pq16 = QuantizationType::Product { num_subvectors: 16 };
        assert_eq!(pq16.compression_ratio(dims), 96);
    }

    #[test]
    fn test_quantization_type_from_str() {
        assert_eq!(
            QuantizationType::from_str("none"),
            Some(QuantizationType::None)
        );
        assert_eq!(
            QuantizationType::from_str("scalar"),
            Some(QuantizationType::Scalar)
        );
        // Parsing is case-insensitive.
        assert_eq!(
            QuantizationType::from_str("SQ"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("binary"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(
            QuantizationType::from_str("bit"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(QuantizationType::from_str("invalid"), None);
    }

    #[test]
    fn test_scalar_quantizer_train() {
        let vectors = [
            vec![0.0f32, 0.5, 1.0],
            vec![0.2, 0.3, 0.8],
            vec![0.1, 0.6, 0.9],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let quantizer = ScalarQuantizer::train(&refs);

        assert_eq!(quantizer.dimensions(), 3);
        // Per-dimension minima over the three training vectors.
        assert_eq!(quantizer.min_values()[0], 0.0);
        assert_eq!(quantizer.min_values()[1], 0.3);
        assert_eq!(quantizer.min_values()[2], 0.8);
    }

    #[test]
    fn test_scalar_quantizer_quantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        // Range endpoints map to the code endpoints.
        let q_min = quantizer.quantize(&[0.0, 0.0]);
        assert_eq!(q_min, vec![0, 0]);

        let q_max = quantizer.quantize(&[1.0, 1.0]);
        assert_eq!(q_max, vec![255, 255]);

        // Midpoint lands near code 127 (allow rounding slack).
        let q_mid = quantizer.quantize(&[0.5, 0.5]);
        assert!(q_mid[0] >= 126 && q_mid[0] <= 128);
    }

    #[test]
    fn test_scalar_quantizer_dequantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0], vec![1.0]);

        let original = [0.5f32];
        let quantized = quantizer.quantize(&original);
        let dequantized = quantizer.dequantize(&quantized);

        // Round-trip error bounded by one code step (~1/255).
        assert!((original[0] - dequantized[0]).abs() < 0.01);
    }

    #[test]
    fn test_scalar_quantizer_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let a = quantizer.quantize(&[0.0, 0.0]);
        let b = quantizer.quantize(&[1.0, 0.0]);

        // True Euclidean distance is 1.0.
        let dist = quantizer.distance_u8(&a, &b);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_asymmetric_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let query = [0.0f32, 0.0];
        let stored = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.asymmetric_distance(&query, &stored);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_scalar_quantizer_cosine_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![-1.0, -1.0], vec![1.0, 1.0]);

        // Orthogonal unit vectors: cosine distance should be ~1.0.
        let a = quantizer.quantize(&[1.0, 0.0]);
        let b = quantizer.quantize(&[0.0, 1.0]);

        let dist = quantizer.cosine_distance_u8(&a, &b);
        assert!((dist - 1.0).abs() < 0.1);
    }

    #[test]
    #[should_panic(expected = "Cannot train on empty vector set")]
    fn test_scalar_quantizer_empty_training() {
        let vectors: Vec<&[f32]> = vec![];
        let _ = ScalarQuantizer::train(&vectors);
    }

    #[test]
    fn test_binary_quantizer_quantize() {
        // Signs are +, -, + (zero counts as non-negative), + => 0b1101.
        let v = vec![0.5f32, -0.3, 0.0, 0.8];
        let bits = BinaryQuantizer::quantize(&v);

        assert_eq!(bits.len(), 1);
        assert_eq!(bits[0] & 0xF, 0b1101);
    }

    #[test]
    fn test_binary_quantizer_hamming_distance() {
        // Two of four sign bits differ.
        let v1 = vec![1.0f32, 1.0, 1.0, 1.0];
        let v2 = vec![1.0f32, -1.0, 1.0, -1.0];
        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 2);
    }

    #[test]
    fn test_binary_quantizer_identical_vectors() {
        let v = vec![0.1f32, -0.2, 0.3, -0.4, 0.5];
        let bits = BinaryQuantizer::quantize(&v);

        let dist = BinaryQuantizer::hamming_distance(&bits, &bits);
        assert_eq!(dist, 0);
    }

    #[test]
    fn test_binary_quantizer_opposite_vectors() {
        let v1 = vec![1.0f32; 64];
        let v2 = vec![-1.0f32; 64];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        // Every one of the 64 sign bits flips.
        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 64);
    }

    #[test]
    fn test_binary_quantizer_large_vector() {
        let v: Vec<f32> = (0..1000)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let bits = BinaryQuantizer::quantize(&v);

        // ceil(1000 / 64) = 16 words.
        assert_eq!(bits.len(), 16);
    }

    #[test]
    fn test_binary_quantizer_normalized_distance() {
        let v1 = vec![1.0f32; 100];
        let v2 = vec![-1.0f32; 100];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        // All 100 bits differ => normalized distance ~1.0.
        let norm_dist = BinaryQuantizer::hamming_distance_normalized(&bits1, &bits2, 100);
        assert!((norm_dist - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_binary_quantizer_words_needed() {
        assert_eq!(BinaryQuantizer::words_needed(1), 1);
        assert_eq!(BinaryQuantizer::words_needed(64), 1);
        assert_eq!(BinaryQuantizer::words_needed(65), 2);
        assert_eq!(BinaryQuantizer::words_needed(128), 2);
        assert_eq!(BinaryQuantizer::words_needed(1536), 24);
    }

    #[test]
    fn test_binary_quantizer_bytes_needed() {
        assert_eq!(BinaryQuantizer::bytes_needed(64), 8);
        assert_eq!(BinaryQuantizer::bytes_needed(128), 16);
        assert_eq!(BinaryQuantizer::bytes_needed(1536), 192);
    }

    #[test]
    fn test_hamming_distance_simd() {
        // Word 0 and word 1 each differ in all 64 bits.
        let a = vec![0xFFFF_FFFF_FFFF_FFFFu64, 0x0000_0000_0000_0000];
        let b = vec![0x0000_0000_0000_0000u64, 0xFFFF_FFFF_FFFF_FFFF];

        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 128);
    }

    #[test]
    fn test_quantization_type_product_from_str() {
        // Bare aliases default to 8 subvectors.
        assert_eq!(
            QuantizationType::from_str("pq"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("product"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );

        // "pqN" encodes an explicit subvector count.
        assert_eq!(
            QuantizationType::from_str("pq8"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("pq16"),
            Some(QuantizationType::Product { num_subvectors: 16 })
        );
        assert_eq!(
            QuantizationType::from_str("pq32"),
            Some(QuantizationType::Product { num_subvectors: 32 })
        );
    }

    #[test]
    fn test_quantization_type_requires_training() {
        assert!(!QuantizationType::None.requires_training());
        assert!(QuantizationType::Scalar.requires_training());
        assert!(!QuantizationType::Binary.requires_training());
        assert!(QuantizationType::Product { num_subvectors: 8 }.requires_training());
    }

    #[test]
    fn test_product_quantizer_train() {
        // 100 synthetic 16-dim vectors with smooth structure.
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..16).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 4, 8, 5);

        assert_eq!(pq.num_subvectors(), 4);
        assert_eq!(pq.num_centroids(), 8);
        assert_eq!(pq.dimensions(), 16);
        assert_eq!(pq.subvector_dim(), 4);
        assert_eq!(pq.code_size(), 4);
    }

    #[test]
    fn test_product_quantizer_quantize() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| ((i * j) as f32 * 0.1).cos()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 16, 3);

        // One code per partition.
        let codes = pq.quantize(&vectors[0]);
        assert_eq!(codes.len(), 2);

        // Codes index into the 16-centroid codebook.
        for &code in &codes {
            assert!(code < 16);
        }
    }

    #[test]
    fn test_product_quantizer_reconstruct() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..12).map(|j| (i + j) as f32 * 0.05).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 3, 8, 5);

        // Encode/decode round-trip should land near the original.
        let original = &vectors[10];
        let codes = pq.quantize(original);
        let reconstructed = pq.reconstruct(&codes);

        assert_eq!(reconstructed.len(), 12);

        let error: f32 = original
            .iter()
            .zip(&reconstructed)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        assert!(error < 2.0, "Reconstruction error too high: {error}");
    }

    #[test]
    fn test_product_quantizer_asymmetric_distance() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 8, 32, 5);

        // Distance from a vector to its own codes should be small.
        let query = &vectors[0];
        let codes = pq.quantize(query);
        let self_dist = pq.asymmetric_distance(query, &codes);
        assert!(self_dist < 1.0, "Self-distance too high: {self_dist}");

        // And smaller than the distance to a different vector's codes.
        let other_codes = pq.quantize(&vectors[50]);
        let other_dist = pq.asymmetric_distance(query, &other_codes);
        assert!(other_dist > self_dist, "Other vector should be farther");
    }

    #[test]
    fn test_product_quantizer_distance_table() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..16).map(|j| (i + j) as f32 * 0.02).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 4, 8, 3);

        let query = &vectors[0];
        let table = pq.build_distance_table(query);

        // One entry per (partition, centroid) pair.
        assert_eq!(table.len(), 4 * 8);

        // Table lookups must agree with the direct computation.
        let codes = pq.quantize(&vectors[5]);
        let dist_direct = pq.asymmetric_distance_squared(query, &codes);
        let dist_table = pq.distance_with_table(&table, &codes);

        assert!((dist_direct - dist_table).abs() < 0.001);
    }

    #[test]
    fn test_product_quantizer_batch() {
        let vectors: Vec<Vec<f32>> = (0..20)
            .map(|i| (0..8).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 2);

        let batch_codes = pq.quantize_batch(&refs[0..5]);
        assert_eq!(batch_codes.len(), 5);

        for codes in &batch_codes {
            assert_eq!(codes.len(), 2);
        }
    }

    #[test]
    fn test_product_quantizer_compression_ratio() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..384).map(|j| ((i * j) as f32).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        // 384 * 4 bytes -> 8 code bytes = 192x.
        let pq8 = ProductQuantizer::train(&refs, 8, 256, 3);
        assert_eq!(pq8.compression_ratio(), 192);

        // 384 * 4 bytes -> 48 code bytes = 32x.
        let pq48 = ProductQuantizer::train(&refs, 48, 256, 3);
        assert_eq!(pq48.compression_ratio(), 32);
    }

    #[test]
    #[should_panic(expected = "dimensions (15) must be divisible by num_subvectors (4)")]
    fn test_product_quantizer_invalid_dimensions() {
        // 15 dims cannot be split into 4 equal subvectors.
        let vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| (0..15).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let _ = ProductQuantizer::train(&refs, 4, 8, 3);
    }

    #[test]
    fn test_product_quantizer_get_partition_centroids() {
        let vectors: Vec<Vec<f32>> = (0..30)
            .map(|i| (0..8).map(|j| (i + j) as f32 * 0.1).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 3);

        // 4 centroids per partition, each of subvector_dim = 8 / 2 = 4.
        let centroids = pq.get_partition_centroids(0);
        assert_eq!(centroids.len(), 4);
        assert_eq!(centroids[0].len(), 4);
    }
}