1use serde::{Deserialize, Serialize};
52
/// Vector quantization scheme used to compress stored embeddings.
///
/// Marked `#[non_exhaustive]` so new schemes can be added without a
/// breaking change for downstream matchers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub enum QuantizationType {
    /// No quantization: vectors kept as full-precision `f32`.
    #[default]
    None,
    /// Scalar quantization: each component mapped to one `u8` (4x smaller).
    Scalar,
    /// Binary quantization: one sign bit per component (32x smaller).
    Binary,
    /// Product quantization: the vector is split into subvectors, each
    /// encoded as a single-byte centroid code.
    Product {
        /// Number of subvector partitions (code bytes per vector).
        num_subvectors: usize,
    },
}
74
75impl QuantizationType {
76 #[must_use]
81 pub fn compression_ratio(&self, dimensions: usize) -> usize {
82 match self {
83 Self::None => 1,
84 Self::Scalar => 4, Self::Binary => 32, Self::Product { num_subvectors } => {
87 let m = (*num_subvectors).max(1);
91 (dimensions * 4) / m
92 }
93 }
94 }
95
96 #[must_use]
98 pub fn name(&self) -> &'static str {
99 match self {
100 Self::None => "none",
101 Self::Scalar => "scalar",
102 Self::Binary => "binary",
103 Self::Product { .. } => "product",
104 }
105 }
106
107 #[must_use]
109 pub fn from_str(s: &str) -> Option<Self> {
110 match s.to_lowercase().as_str() {
111 "none" | "full" | "f32" => Some(Self::None),
112 "scalar" | "sq" | "u8" | "int8" => Some(Self::Scalar),
113 "binary" | "bin" | "bit" | "1bit" => Some(Self::Binary),
114 "product" | "pq" => Some(Self::Product { num_subvectors: 8 }),
115 s if s.starts_with("pq") => {
116 s[2..]
118 .parse()
119 .ok()
120 .map(|n| Self::Product { num_subvectors: n })
121 }
122 _ => None,
123 }
124 }
125
126 #[must_use]
128 pub const fn requires_training(&self) -> bool {
129 matches!(self, Self::Scalar | Self::Product { .. })
130 }
131}
132
/// Scalar (per-dimension, 8-bit) quantizer.
///
/// Each dimension `i` maps `[min[i], min[i] + range]` linearly onto `0..=255`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
    // Per-dimension minimum observed during training (offset for encoding).
    min: Vec<f32>,
    // Per-dimension forward factor: 255 / range (f32 -> u8).
    scale: Vec<f32>,
    // Per-dimension inverse factor: range / 255 (u8 -> f32).
    inv_scale: Vec<f32>,
    // Expected vector length; all slices above have this length.
    dimensions: usize,
}
179
180impl ScalarQuantizer {
181 #[must_use]
194 pub fn train(vectors: &[&[f32]]) -> Self {
195 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
196
197 let dimensions = vectors[0].len();
198 assert!(
199 vectors.iter().all(|v| v.len() == dimensions),
200 "All training vectors must have the same dimensions"
201 );
202
203 let mut min = vec![f32::INFINITY; dimensions];
205 let mut max = vec![f32::NEG_INFINITY; dimensions];
206
207 for vec in vectors {
208 for (i, &v) in vec.iter().enumerate() {
209 min[i] = min[i].min(v);
210 max[i] = max[i].max(v);
211 }
212 }
213
214 let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
216 .iter()
217 .zip(&max)
218 .map(|(&mn, &mx)| {
219 let range = mx - mn;
220 if range.abs() < f32::EPSILON {
221 (1.0, 1.0)
223 } else {
224 (255.0 / range, range / 255.0)
225 }
226 })
227 .unzip();
228
229 Self {
230 min,
231 scale,
232 inv_scale,
233 dimensions,
234 }
235 }
236
237 #[must_use]
243 pub fn with_ranges(min: Vec<f32>, max: Vec<f32>) -> Self {
244 let dimensions = min.len();
245 assert_eq!(min.len(), max.len(), "Min and max must have same length");
246
247 let (scale, inv_scale): (Vec<f32>, Vec<f32>) = min
248 .iter()
249 .zip(&max)
250 .map(|(&mn, &mx)| {
251 let range = mx - mn;
252 if range.abs() < f32::EPSILON {
253 (1.0, 1.0)
254 } else {
255 (255.0 / range, range / 255.0)
256 }
257 })
258 .unzip();
259
260 Self {
261 min,
262 scale,
263 inv_scale,
264 dimensions,
265 }
266 }
267
268 #[must_use]
270 pub fn dimensions(&self) -> usize {
271 self.dimensions
272 }
273
274 #[must_use]
276 pub fn min_values(&self) -> &[f32] {
277 &self.min
278 }
279
280 #[must_use]
284 pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
285 debug_assert_eq!(
286 vector.len(),
287 self.dimensions,
288 "Vector dimension mismatch: expected {}, got {}",
289 self.dimensions,
290 vector.len()
291 );
292
293 vector
294 .iter()
295 .enumerate()
296 .map(|(i, &v)| {
297 let normalized = (v - self.min[i]) * self.scale[i];
298 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
300 {
301 normalized.clamp(0.0, 255.0) as u8
302 }
303 })
304 .collect()
305 }
306
307 #[must_use]
309 pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
310 vectors.iter().map(|v| self.quantize(v)).collect()
311 }
312
313 #[must_use]
315 pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
316 debug_assert_eq!(quantized.len(), self.dimensions);
317
318 quantized
319 .iter()
320 .enumerate()
321 .map(|(i, &q)| self.min[i] + (q as f32) * self.inv_scale[i])
322 .collect()
323 }
324
325 #[must_use]
330 pub fn distance_squared_u8(&self, a: &[u8], b: &[u8]) -> f32 {
331 debug_assert_eq!(a.len(), self.dimensions);
332 debug_assert_eq!(b.len(), self.dimensions);
333
334 let mut sum = 0.0f32;
336 for i in 0..a.len() {
337 let diff = (a[i] as f32) - (b[i] as f32);
338 sum += diff * diff * self.inv_scale[i] * self.inv_scale[i];
339 }
340 sum
341 }
342
343 #[must_use]
345 #[inline]
346 pub fn distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
347 self.distance_squared_u8(a, b).sqrt()
348 }
349
350 #[must_use]
354 pub fn cosine_distance_u8(&self, a: &[u8], b: &[u8]) -> f32 {
355 debug_assert_eq!(a.len(), self.dimensions);
356 debug_assert_eq!(b.len(), self.dimensions);
357
358 let mut dot = 0.0f32;
359 let mut norm_a = 0.0f32;
360 let mut norm_b = 0.0f32;
361
362 for i in 0..a.len() {
363 let va = self.min[i] + (a[i] as f32) * self.inv_scale[i];
365 let vb = self.min[i] + (b[i] as f32) * self.inv_scale[i];
366
367 dot += va * vb;
368 norm_a += va * va;
369 norm_b += vb * vb;
370 }
371
372 let denom = (norm_a * norm_b).sqrt();
373 if denom < f32::EPSILON {
374 1.0 } else {
376 1.0 - (dot / denom)
377 }
378 }
379
380 #[must_use]
384 pub fn asymmetric_distance_squared(&self, query: &[f32], quantized: &[u8]) -> f32 {
385 debug_assert_eq!(query.len(), self.dimensions);
386 debug_assert_eq!(quantized.len(), self.dimensions);
387
388 let mut sum = 0.0f32;
389 for i in 0..query.len() {
390 let dequant = self.min[i] + (quantized[i] as f32) * self.inv_scale[i];
392 let diff = query[i] - dequant;
393 sum += diff * diff;
394 }
395 sum
396 }
397
398 #[must_use]
400 #[inline]
401 pub fn asymmetric_distance(&self, query: &[f32], quantized: &[u8]) -> f32 {
402 self.asymmetric_distance_squared(query, quantized).sqrt()
403 }
404}
405
/// Stateless binary (1-bit) quantizer: each component becomes a single sign
/// bit, packed little-endian-by-index into `u64` words.
pub struct BinaryQuantizer;

impl BinaryQuantizer {
    /// Packs `vector` into sign bits: bit `i` is set iff `vector[i] >= 0.0`.
    /// (NaN components compare false and leave the bit clear.)
    #[must_use]
    pub fn quantize(vector: &[f32]) -> Vec<u64> {
        // div_ceil: one u64 word per 64 components, rounding up.
        let mut words = vec![0u64; vector.len().div_ceil(64)];

        for (i, &component) in vector.iter().enumerate() {
            if component >= 0.0 {
                words[i / 64] |= 1u64 << (i % 64);
            }
        }

        words
    }

    /// Packs every vector in `vectors`; see [`Self::quantize`].
    #[must_use]
    pub fn quantize_batch(vectors: &[&[f32]]) -> Vec<Vec<u64>> {
        vectors.iter().map(|v| Self::quantize(v)).collect()
    }

    /// Number of differing bits between two packed vectors.
    #[must_use]
    pub fn hamming_distance(a: &[u64], b: &[u64]) -> u32 {
        debug_assert_eq!(a.len(), b.len(), "Binary vectors must have same length");

        a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
    }

    /// Hamming distance divided by the original dimension count (0.0..=1.0
    /// when padding bits agree).
    #[must_use]
    pub fn hamming_distance_normalized(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        hamming as f32 / dimensions as f32
    }

    /// Euclidean-distance estimate from the bit disagreement rate:
    /// `sqrt(2 * hamming / dimensions)`.
    #[must_use]
    pub fn approximate_euclidean(a: &[u64], b: &[u64], dimensions: usize) -> f32 {
        let hamming = Self::hamming_distance(a, b);
        (2.0 * hamming as f32 / dimensions as f32).sqrt()
    }

    /// Number of `u64` words needed to store `dimensions` sign bits.
    #[must_use]
    pub const fn words_needed(dimensions: usize) -> usize {
        dimensions.div_ceil(64)
    }

    /// Number of bytes needed to store `dimensions` sign bits.
    #[must_use]
    pub const fn bytes_needed(dimensions: usize) -> usize {
        Self::words_needed(dimensions) * 8
    }
}
499
/// Product quantizer: splits vectors into `num_subvectors` partitions and
/// encodes each partition as the index of its nearest trained centroid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    // Number of partitions (M); one code byte per partition.
    num_subvectors: usize,
    // Codebook size per partition (K); at most 256 so codes fit in a u8.
    num_centroids: usize,
    // Components per partition: dimensions / num_subvectors.
    subvector_dim: usize,
    // Full vector length.
    dimensions: usize,
    // Flattened codebooks; centroid k of partition m starts at
    // (m * num_centroids + k) * subvector_dim.
    centroids: Vec<f32>,
}
553
554impl ProductQuantizer {
555 #[must_use]
569 pub fn train(
570 vectors: &[&[f32]],
571 num_subvectors: usize,
572 num_centroids: usize,
573 iterations: usize,
574 ) -> Self {
575 assert!(!vectors.is_empty(), "Cannot train on empty vector set");
576 assert!(
577 num_centroids <= 256,
578 "num_centroids must be <= 256 for u8 codes"
579 );
580 assert!(num_subvectors > 0, "num_subvectors must be > 0");
581
582 let dimensions = vectors[0].len();
583 assert!(
584 dimensions.is_multiple_of(num_subvectors),
585 "dimensions ({dimensions}) must be divisible by num_subvectors ({num_subvectors})"
586 );
587 assert!(
588 vectors.iter().all(|v| v.len() == dimensions),
589 "All training vectors must have the same dimensions"
590 );
591
592 let subvector_dim = dimensions / num_subvectors;
593
594 let mut centroids = Vec::with_capacity(num_subvectors * num_centroids * subvector_dim);
596
597 for m in 0..num_subvectors {
598 let subvectors: Vec<Vec<f32>> = vectors
600 .iter()
601 .map(|v| {
602 let start = m * subvector_dim;
603 let end = start + subvector_dim;
604 v[start..end].to_vec()
605 })
606 .collect();
607
608 let partition_centroids =
610 Self::kmeans(&subvectors, num_centroids, subvector_dim, iterations);
611
612 centroids.extend(partition_centroids);
613 }
614
615 Self {
616 num_subvectors,
617 num_centroids,
618 subvector_dim,
619 dimensions,
620 centroids,
621 }
622 }
623
624 fn kmeans(vectors: &[Vec<f32>], k: usize, dims: usize, iterations: usize) -> Vec<f32> {
626 let n = vectors.len();
627
628 let actual_k = k.min(n);
630 let mut centroids: Vec<f32> = if actual_k == n {
631 vectors.iter().flat_map(|v| v.iter().copied()).collect()
632 } else {
633 let step = n / actual_k;
635 (0..actual_k)
636 .flat_map(|i| vectors[i * step].iter().copied())
637 .collect()
638 };
639
640 if actual_k < k {
642 centroids.resize(k * dims, 0.0);
643 }
644
645 let mut assignments = vec![0usize; n];
646 let mut counts = vec![0usize; k];
647
648 for _ in 0..iterations {
649 for (i, vec) in vectors.iter().enumerate() {
651 let mut best_dist = f32::INFINITY;
652 let mut best_k = 0;
653
654 for j in 0..k {
655 let centroid_start = j * dims;
656 let dist: f32 = vec
657 .iter()
658 .enumerate()
659 .map(|(d, &v)| {
660 let diff = v - centroids[centroid_start + d];
661 diff * diff
662 })
663 .sum();
664
665 if dist < best_dist {
666 best_dist = dist;
667 best_k = j;
668 }
669 }
670
671 assignments[i] = best_k;
672 }
673
674 centroids.fill(0.0);
676 counts.fill(0);
677
678 for (i, vec) in vectors.iter().enumerate() {
679 let k_idx = assignments[i];
680 let centroid_start = k_idx * dims;
681 counts[k_idx] += 1;
682
683 for (d, &v) in vec.iter().enumerate() {
684 centroids[centroid_start + d] += v;
685 }
686 }
687
688 for j in 0..k {
690 if counts[j] > 0 {
691 let centroid_start = j * dims;
692 let count = counts[j] as f32;
693 for d in 0..dims {
694 centroids[centroid_start + d] /= count;
695 }
696 }
697 }
698 }
699
700 centroids
701 }
702
703 #[must_use]
710 pub fn with_centroids(
711 num_subvectors: usize,
712 num_centroids: usize,
713 dimensions: usize,
714 centroids: Vec<f32>,
715 ) -> Self {
716 let subvector_dim = dimensions / num_subvectors;
717 assert_eq!(
718 centroids.len(),
719 num_subvectors * num_centroids * subvector_dim,
720 "Invalid centroid count"
721 );
722
723 Self {
724 num_subvectors,
725 num_centroids,
726 subvector_dim,
727 dimensions,
728 centroids,
729 }
730 }
731
732 #[must_use]
734 pub fn num_subvectors(&self) -> usize {
735 self.num_subvectors
736 }
737
738 #[must_use]
740 pub fn num_centroids(&self) -> usize {
741 self.num_centroids
742 }
743
744 #[must_use]
746 pub fn dimensions(&self) -> usize {
747 self.dimensions
748 }
749
750 #[must_use]
752 pub fn subvector_dim(&self) -> usize {
753 self.subvector_dim
754 }
755
756 #[must_use]
758 pub fn code_size(&self) -> usize {
759 self.num_subvectors }
761
762 #[must_use]
764 pub fn compression_ratio(&self) -> usize {
765 (self.dimensions * 4) / self.num_subvectors
768 }
769
770 #[must_use]
774 pub fn quantize(&self, vector: &[f32]) -> Vec<u8> {
775 debug_assert_eq!(
776 vector.len(),
777 self.dimensions,
778 "Vector dimension mismatch: expected {}, got {}",
779 self.dimensions,
780 vector.len()
781 );
782
783 let mut codes = Vec::with_capacity(self.num_subvectors);
784
785 for m in 0..self.num_subvectors {
786 let subvec_start = m * self.subvector_dim;
787 let subvec = &vector[subvec_start..subvec_start + self.subvector_dim];
788
789 let mut best_dist = f32::INFINITY;
791 let mut best_k = 0u8;
792
793 for k in 0..self.num_centroids {
794 let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
795 let dist: f32 = subvec
796 .iter()
797 .enumerate()
798 .map(|(d, &v)| {
799 let diff = v - self.centroids[centroid_start + d];
800 diff * diff
801 })
802 .sum();
803
804 if dist < best_dist {
805 best_dist = dist;
806 #[allow(clippy::cast_possible_truncation)]
808 {
809 best_k = k as u8;
810 }
811 }
812 }
813
814 codes.push(best_k);
815 }
816
817 codes
818 }
819
820 #[must_use]
822 pub fn quantize_batch(&self, vectors: &[&[f32]]) -> Vec<Vec<u8>> {
823 vectors.iter().map(|v| self.quantize(v)).collect()
824 }
825
826 #[must_use]
832 pub fn build_distance_table(&self, query: &[f32]) -> Vec<f32> {
833 debug_assert_eq!(query.len(), self.dimensions);
834
835 let mut table = Vec::with_capacity(self.num_subvectors * self.num_centroids);
836
837 for m in 0..self.num_subvectors {
838 let query_start = m * self.subvector_dim;
839 let query_subvec = &query[query_start..query_start + self.subvector_dim];
840
841 for k in 0..self.num_centroids {
842 let centroid_start = (m * self.num_centroids + k) * self.subvector_dim;
843
844 let dist: f32 = query_subvec
845 .iter()
846 .enumerate()
847 .map(|(d, &v)| {
848 let diff = v - self.centroids[centroid_start + d];
849 diff * diff
850 })
851 .sum();
852
853 table.push(dist);
854 }
855 }
856
857 table
858 }
859
860 #[must_use]
864 #[inline]
865 pub fn distance_with_table(&self, table: &[f32], codes: &[u8]) -> f32 {
866 debug_assert_eq!(codes.len(), self.num_subvectors);
867 debug_assert_eq!(table.len(), self.num_subvectors * self.num_centroids);
868
869 codes
870 .iter()
871 .enumerate()
872 .map(|(m, &code)| table[m * self.num_centroids + code as usize])
873 .sum()
874 }
875
876 #[must_use]
881 pub fn asymmetric_distance_squared(&self, query: &[f32], codes: &[u8]) -> f32 {
882 let table = self.build_distance_table(query);
883 self.distance_with_table(&table, codes)
884 }
885
886 #[must_use]
888 #[inline]
889 pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
890 self.asymmetric_distance_squared(query, codes).sqrt()
891 }
892
893 #[must_use]
897 pub fn reconstruct(&self, codes: &[u8]) -> Vec<f32> {
898 debug_assert_eq!(codes.len(), self.num_subvectors);
899
900 let mut result = Vec::with_capacity(self.dimensions);
901
902 for (m, &code) in codes.iter().enumerate() {
903 let centroid_start = (m * self.num_centroids + code as usize) * self.subvector_dim;
904 result.extend_from_slice(
905 &self.centroids[centroid_start..centroid_start + self.subvector_dim],
906 );
907 }
908
909 result
910 }
911
912 #[must_use]
918 pub fn get_partition_centroids(&self, partition: usize) -> Vec<&[f32]> {
919 assert!(partition < self.num_subvectors);
920
921 (0..self.num_centroids)
922 .map(|k| {
923 let start = (partition * self.num_centroids + k) * self.subvector_dim;
924 &self.centroids[start..start + self.subvector_dim]
925 })
926 .collect()
927 }
928}
929
/// Hamming distance between packed binary vectors, using the hardware POPCNT
/// instruction when the running CPU supports it.
///
/// The previous version invoked `_popcnt64` unconditionally; that intrinsic's
/// safety contract requires the `popcnt` target feature, so calling it on a
/// CPU without POPCNT was unsound. The feature is now verified at runtime and
/// a portable `count_ones` path is used otherwise.
#[cfg(target_arch = "x86_64")]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    /// POPCNT-enabled inner function so codegen actually emits the instruction.
    #[target_feature(enable = "popcnt")]
    #[allow(unsafe_code, clippy::cast_possible_wrap, clippy::cast_sign_loss)]
    unsafe fn popcnt_sum(a: &[u64], b: &[u64]) -> u32 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| {
                // SAFETY: caller guarantees the popcnt feature is available.
                unsafe { std::arch::x86_64::_popcnt64((x ^ y) as i64) as u32 }
            })
            .sum()
    }

    if std::arch::is_x86_feature_detected!("popcnt") {
        // SAFETY: POPCNT availability was verified at runtime just above.
        #[allow(unsafe_code)]
        unsafe {
            popcnt_sum(a, b)
        }
    } else {
        // Portable fallback; count_ones is correct on every CPU.
        a.iter().zip(b).map(|(&x, &y)| (x ^ y).count_ones()).sum()
    }
}
956
/// Portable fallback for [`hamming_distance_simd`] on non-x86_64 targets;
/// delegates to the scalar `count_ones` implementation.
#[cfg(not(target_arch = "x86_64"))]
#[must_use]
pub fn hamming_distance_simd(a: &[u64], b: &[u64]) -> u32 {
    BinaryQuantizer::hamming_distance(a, b)
}
963
#[cfg(test)]
mod tests {
    //! Unit tests for [`QuantizationType`], the scalar/binary/product
    //! quantizers, and the SIMD hamming-distance helper.
    use super::*;

    // --- QuantizationType ---

    // Compression ratios for every scheme at a typical embedding size.
    #[test]
    fn test_quantization_type_compression_ratio() {
        let dims = 384;
        assert_eq!(QuantizationType::None.compression_ratio(dims), 1);
        assert_eq!(QuantizationType::Scalar.compression_ratio(dims), 4);
        assert_eq!(QuantizationType::Binary.compression_ratio(dims), 32);

        // Product ratio depends on dims: 384 * 4 bytes / M code bytes.
        let pq8 = QuantizationType::Product { num_subvectors: 8 };
        assert_eq!(pq8.compression_ratio(dims), 192);

        let pq16 = QuantizationType::Product { num_subvectors: 16 };
        assert_eq!(pq16.compression_ratio(dims), 96);
    }

    // Name parsing: canonical names, aliases, case-insensitivity, rejection.
    #[test]
    fn test_quantization_type_from_str() {
        assert_eq!(
            QuantizationType::from_str("none"),
            Some(QuantizationType::None)
        );
        assert_eq!(
            QuantizationType::from_str("scalar"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("SQ"),
            Some(QuantizationType::Scalar)
        );
        assert_eq!(
            QuantizationType::from_str("binary"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(
            QuantizationType::from_str("bit"),
            Some(QuantizationType::Binary)
        );
        assert_eq!(QuantizationType::from_str("invalid"), None);
    }

    // --- ScalarQuantizer ---

    // Training captures per-dimension minimums from the training set.
    #[test]
    fn test_scalar_quantizer_train() {
        let vectors = [
            vec![0.0f32, 0.5, 1.0],
            vec![0.2, 0.3, 0.8],
            vec![0.1, 0.6, 0.9],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let quantizer = ScalarQuantizer::train(&refs);

        assert_eq!(quantizer.dimensions(), 3);
        assert_eq!(quantizer.min_values()[0], 0.0);
        assert_eq!(quantizer.min_values()[1], 0.3);
        assert_eq!(quantizer.min_values()[2], 0.8);
    }

    // Range endpoints map to 0 and 255; midpoint lands near 127.
    #[test]
    fn test_scalar_quantizer_quantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let q_min = quantizer.quantize(&[0.0, 0.0]);
        assert_eq!(q_min, vec![0, 0]);

        let q_max = quantizer.quantize(&[1.0, 1.0]);
        assert_eq!(q_max, vec![255, 255]);

        let q_mid = quantizer.quantize(&[0.5, 0.5]);
        assert!(q_mid[0] >= 126 && q_mid[0] <= 128);
    }

    // Round-trip error stays within one quantization step.
    #[test]
    fn test_scalar_quantizer_dequantize() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0], vec![1.0]);

        let original = [0.5f32];
        let quantized = quantizer.quantize(&original);
        let dequantized = quantizer.dequantize(&quantized);

        assert!((original[0] - dequantized[0]).abs() < 0.01);
    }

    // Symmetric distance between quantized codes approximates Euclidean.
    #[test]
    fn test_scalar_quantizer_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let a = quantizer.quantize(&[0.0, 0.0]);
        let b = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.distance_u8(&a, &b);
        assert!((dist - 1.0).abs() < 0.1);
    }

    // Asymmetric (f32 query vs u8 stored) distance approximates Euclidean.
    #[test]
    fn test_scalar_quantizer_asymmetric_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![0.0, 0.0], vec![1.0, 1.0]);

        let query = [0.0f32, 0.0];
        let stored = quantizer.quantize(&[1.0, 0.0]);

        let dist = quantizer.asymmetric_distance(&query, &stored);
        assert!((dist - 1.0).abs() < 0.1);
    }

    // Orthogonal vectors should have cosine distance ~1.
    #[test]
    fn test_scalar_quantizer_cosine_distance() {
        let quantizer = ScalarQuantizer::with_ranges(vec![-1.0, -1.0], vec![1.0, 1.0]);

        let a = quantizer.quantize(&[1.0, 0.0]);
        let b = quantizer.quantize(&[0.0, 1.0]);

        let dist = quantizer.cosine_distance_u8(&a, &b);
        assert!((dist - 1.0).abs() < 0.1);
    }

    // Training on an empty set must panic with the documented message.
    #[test]
    #[should_panic(expected = "Cannot train on empty vector set")]
    fn test_scalar_quantizer_empty_training() {
        let vectors: Vec<&[f32]> = vec![];
        let _ = ScalarQuantizer::train(&vectors);
    }

    // --- BinaryQuantizer ---

    // Bit i is set iff component i is non-negative (0.0 counts as set).
    #[test]
    fn test_binary_quantizer_quantize() {
        let v = vec![0.5f32, -0.3, 0.0, 0.8];
        let bits = BinaryQuantizer::quantize(&v);

        assert_eq!(bits.len(), 1);
        assert_eq!(bits[0] & 0xF, 0b1101);
    }

    // Two sign flips -> hamming distance 2.
    #[test]
    fn test_binary_quantizer_hamming_distance() {
        let v1 = vec![1.0f32, 1.0, 1.0, 1.0];
        let v2 = vec![1.0f32, -1.0, 1.0, -1.0];
        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 2);
    }

    // A vector is at distance 0 from itself.
    #[test]
    fn test_binary_quantizer_identical_vectors() {
        let v = vec![0.1f32, -0.2, 0.3, -0.4, 0.5];
        let bits = BinaryQuantizer::quantize(&v);

        let dist = BinaryQuantizer::hamming_distance(&bits, &bits);
        assert_eq!(dist, 0);
    }

    // Fully opposite signs -> every bit differs.
    #[test]
    fn test_binary_quantizer_opposite_vectors() {
        let v1 = vec![1.0f32; 64];
        let v2 = vec![-1.0f32; 64];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let dist = BinaryQuantizer::hamming_distance(&bits1, &bits2);
        assert_eq!(dist, 64);
    }

    // 1000 bits need ceil(1000 / 64) = 16 words.
    #[test]
    fn test_binary_quantizer_large_vector() {
        let v: Vec<f32> = (0..1000)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let bits = BinaryQuantizer::quantize(&v);

        assert_eq!(bits.len(), 16);
    }

    // Opposite vectors normalize to a distance of 1.0.
    #[test]
    fn test_binary_quantizer_normalized_distance() {
        let v1 = vec![1.0f32; 100];
        let v2 = vec![-1.0f32; 100];

        let bits1 = BinaryQuantizer::quantize(&v1);
        let bits2 = BinaryQuantizer::quantize(&v2);

        let norm_dist = BinaryQuantizer::hamming_distance_normalized(&bits1, &bits2, 100);
        assert!((norm_dist - 1.0).abs() < 0.01);
    }

    // Word counts round up to the next multiple of 64 bits.
    #[test]
    fn test_binary_quantizer_words_needed() {
        assert_eq!(BinaryQuantizer::words_needed(1), 1);
        assert_eq!(BinaryQuantizer::words_needed(64), 1);
        assert_eq!(BinaryQuantizer::words_needed(65), 2);
        assert_eq!(BinaryQuantizer::words_needed(128), 2);
        assert_eq!(BinaryQuantizer::words_needed(1536), 24);
    }

    // Byte counts are words * 8.
    #[test]
    fn test_binary_quantizer_bytes_needed() {
        assert_eq!(BinaryQuantizer::bytes_needed(64), 8);
        assert_eq!(BinaryQuantizer::bytes_needed(128), 16);
        assert_eq!(BinaryQuantizer::bytes_needed(1536), 192);
    }

    // SIMD helper must agree with the scalar definition (all 128 bits differ).
    #[test]
    fn test_hamming_distance_simd() {
        let a = vec![0xFFFF_FFFF_FFFF_FFFFu64, 0x0000_0000_0000_0000];
        let b = vec![0x0000_0000_0000_0000u64, 0xFFFF_FFFF_FFFF_FFFF];

        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 128);
    }

    // --- ProductQuantizer ---

    // "pq" defaults to 8 subvectors; "pq<N>" parses N.
    #[test]
    fn test_quantization_type_product_from_str() {
        assert_eq!(
            QuantizationType::from_str("pq"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("product"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );

        assert_eq!(
            QuantizationType::from_str("pq8"),
            Some(QuantizationType::Product { num_subvectors: 8 })
        );
        assert_eq!(
            QuantizationType::from_str("pq16"),
            Some(QuantizationType::Product { num_subvectors: 16 })
        );
        assert_eq!(
            QuantizationType::from_str("pq32"),
            Some(QuantizationType::Product { num_subvectors: 32 })
        );
    }

    // Only Scalar and Product need a training pass.
    #[test]
    fn test_quantization_type_requires_training() {
        assert!(!QuantizationType::None.requires_training());
        assert!(QuantizationType::Scalar.requires_training());
        assert!(!QuantizationType::Binary.requires_training());
        assert!(QuantizationType::Product { num_subvectors: 8 }.requires_training());
    }

    // Training records the requested configuration.
    #[test]
    fn test_product_quantizer_train() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..16).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 4, 8, 5);

        assert_eq!(pq.num_subvectors(), 4);
        assert_eq!(pq.num_centroids(), 8);
        assert_eq!(pq.dimensions(), 16);
        assert_eq!(pq.subvector_dim(), 4);
        assert_eq!(pq.code_size(), 4);
    }

    // Codes have one byte per subvector and stay within the codebook range.
    #[test]
    fn test_product_quantizer_quantize() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..8).map(|j| ((i * j) as f32 * 0.1).cos()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 16, 3);

        let codes = pq.quantize(&vectors[0]);
        assert_eq!(codes.len(), 2);

        for &code in &codes {
            assert!(code < 16);
        }
    }

    // Reconstruction restores the full dimensionality with bounded error.
    #[test]
    fn test_product_quantizer_reconstruct() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..12).map(|j| (i + j) as f32 * 0.05).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 3, 8, 5);

        let original = &vectors[10];
        let codes = pq.quantize(original);
        let reconstructed = pq.reconstruct(&codes);

        assert_eq!(reconstructed.len(), 12);

        let error: f32 = original
            .iter()
            .zip(&reconstructed)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        assert!(error < 2.0, "Reconstruction error too high: {error}");
    }

    // A vector should be closer to its own code than to a distant vector's.
    #[test]
    fn test_product_quantizer_asymmetric_distance() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| ((i * j) as f32 * 0.01).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 8, 32, 5);

        let query = &vectors[0];
        let codes = pq.quantize(query);
        let self_dist = pq.asymmetric_distance(query, &codes);
        assert!(self_dist < 1.0, "Self-distance too high: {self_dist}");

        let other_codes = pq.quantize(&vectors[50]);
        let other_dist = pq.asymmetric_distance(query, &other_codes);
        assert!(other_dist > self_dist, "Other vector should be farther");
    }

    // Table-based lookup must match the direct asymmetric computation.
    #[test]
    fn test_product_quantizer_distance_table() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..16).map(|j| (i + j) as f32 * 0.02).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 4, 8, 3);

        let query = &vectors[0];
        let table = pq.build_distance_table(query);

        assert_eq!(table.len(), 4 * 8);

        let codes = pq.quantize(&vectors[5]);
        let dist_direct = pq.asymmetric_distance_squared(query, &codes);
        let dist_table = pq.distance_with_table(&table, &codes);

        assert!((dist_direct - dist_table).abs() < 0.001);
    }

    // Batch encoding produces one code vector per input.
    #[test]
    fn test_product_quantizer_batch() {
        let vectors: Vec<Vec<f32>> = (0..20)
            .map(|i| (0..8).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 2);

        let batch_codes = pq.quantize_batch(&refs[0..5]);
        assert_eq!(batch_codes.len(), 5);

        for codes in &batch_codes {
            assert_eq!(codes.len(), 2);
        }
    }

    // Compression ratio is dims * 4 bytes over num_subvectors code bytes.
    #[test]
    fn test_product_quantizer_compression_ratio() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..384).map(|j| ((i * j) as f32).sin()).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq8 = ProductQuantizer::train(&refs, 8, 256, 3);
        assert_eq!(pq8.compression_ratio(), 192);
        let pq48 = ProductQuantizer::train(&refs, 48, 256, 3);
        assert_eq!(pq48.compression_ratio(), 32);
    }

    // Indivisible dimensions must panic with the documented message.
    #[test]
    #[should_panic(expected = "dimensions (15) must be divisible by num_subvectors (4)")]
    fn test_product_quantizer_invalid_dimensions() {
        let vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| (0..15).map(|j| (i + j) as f32).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let _ = ProductQuantizer::train(&refs, 4, 8, 3);
    }

    // Partition codebooks expose num_centroids slices of subvector_dim each.
    #[test]
    fn test_product_quantizer_get_partition_centroids() {
        let vectors: Vec<Vec<f32>> = (0..30)
            .map(|i| (0..8).map(|j| (i + j) as f32 * 0.1).collect())
            .collect();
        let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();

        let pq = ProductQuantizer::train(&refs, 2, 4, 3);

        let centroids = pq.get_partition_centroids(0);
        assert_eq!(centroids.len(), 4);
        assert_eq!(centroids[0].len(), 4);
    }
}