1use pulp::Simd;
35use serde::{Deserialize, Serialize};
36
/// Common interface for vector quantizers.
///
/// A quantizer turns `f32` vectors into a compact representation
/// ([`Quantizer::Quantized`]) and can compare vectors in that form.
/// Implementors must be `Send + Sync`.
pub trait Quantizer: Send + Sync {
    /// Compact representation produced by [`Quantizer::quantize`].
    type Quantized: Clone + Send + Sync;

    /// Converts `vector` into its quantized representation.
    fn quantize(&self, vector: &[f32]) -> Self::Quantized;

    /// Rebuilds an `f32` vector from `quantized`.
    fn dequantize(&self, quantized: &Self::Quantized) -> Vec<f32>;

    /// Distance between two quantized vectors; the metric is defined by the implementor.
    fn distance_quantized(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32;

    /// Distance between an unquantized `query` and a quantized vector; the
    /// metric is defined by the implementor.
    fn distance_asymmetric(&self, query: &[f32], quantized: &Self::Quantized) -> f32;
}
54
/// Parameters of a linear mapping between `f32` values and `u8` codes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizationParams {
    /// Lower bound of the value range; subtracted before scaling, so it maps to code 0.
    pub min: f32,
    /// Upper bound of the value range as supplied to `new`.
    pub max: f32,
    /// Width of one code step. `new` sets it to `(max - min) / 255`, or to
    /// `1.0` when that range is not strictly positive.
    pub scale: f32,
}
69
70impl ScalarQuantizationParams {
71 pub fn new(min: f32, max: f32) -> Self {
73 let range = max - min;
74 let scale = if range > 0.0 { range / 255.0 } else { 1.0 };
75 Self { min, max, scale }
76 }
77
78 #[inline]
80 pub fn quantize_value(&self, value: f32) -> u8 {
81 let normalized = (value - self.min) / self.scale;
82 normalized.clamp(0.0, 255.0) as u8
83 }
84
85 #[inline]
87 pub fn dequantize_value(&self, quantized: u8) -> f32 {
88 (quantized as f32) * self.scale + self.min
89 }
90}
91
/// Scalar quantizer that stores one `u8` code per vector component, using a
/// separate [`ScalarQuantizationParams`] for each dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
    /// Quantization parameters, one entry per dimension.
    params: Vec<ScalarQuantizationParams>,
    /// Number of dimensions this quantizer was built for.
    dim: usize,
}
125
/// Output of [`ScalarQuantizer`]: one `u8` code per quantized component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizedVector {
    /// The codes, in the same order as the source vector's components.
    pub data: Vec<u8>,
}
132
133impl ScalarQuantizer {
134 pub fn fit(training_vectors: &[Vec<f32>]) -> Self {
142 assert!(
143 !training_vectors.is_empty(),
144 "Need at least one training vector"
145 );
146
147 let dim = training_vectors[0].len();
148 let mut mins = vec![f32::INFINITY; dim];
149 let mut maxs = vec![f32::NEG_INFINITY; dim];
150
151 for vector in training_vectors {
152 assert_eq!(vector.len(), dim, "Inconsistent vector dimensions");
153 for (i, &val) in vector.iter().enumerate() {
154 mins[i] = mins[i].min(val);
155 maxs[i] = maxs[i].max(val);
156 }
157 }
158
159 let params: Vec<_> = mins
160 .iter()
161 .zip(maxs.iter())
162 .map(|(&min, &max)| ScalarQuantizationParams::new(min, max))
163 .collect();
164
165 Self { params, dim }
166 }
167
168 pub fn with_bounds(dim: usize, min: f32, max: f32) -> Self {
172 let params = vec![ScalarQuantizationParams::new(min, max); dim];
173 Self { params, dim }
174 }
175
176 pub fn for_normalized(dim: usize) -> Self {
178 Self::with_bounds(dim, -1.0, 1.0)
179 }
180
181 pub fn dim(&self) -> usize {
183 self.dim
184 }
185
186 pub fn params(&self) -> &[ScalarQuantizationParams] {
188 &self.params
189 }
190}
191
192impl Quantizer for ScalarQuantizer {
193 type Quantized = ScalarQuantizedVector;
194
195 fn quantize(&self, vector: &[f32]) -> Self::Quantized {
196 debug_assert_eq!(vector.len(), self.dim);
197
198 let data: Vec<u8> = vector
199 .iter()
200 .zip(self.params.iter())
201 .map(|(&val, param)| param.quantize_value(val))
202 .collect();
203
204 ScalarQuantizedVector { data }
205 }
206
207 fn dequantize(&self, quantized: &Self::Quantized) -> Vec<f32> {
208 quantized
209 .data
210 .iter()
211 .zip(self.params.iter())
212 .map(|(&val, param)| param.dequantize_value(val))
213 .collect()
214 }
215
216 fn distance_quantized(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32 {
217 sq8_l2_distance_simd(&a.data, &b.data)
219 }
220
221 fn distance_asymmetric(&self, query: &[f32], quantized: &Self::Quantized) -> f32 {
222 sq8_asymmetric_l2_distance_simd(query, &quantized.data, &self.params)
224 }
225}
226
/// Quantizer that keeps one bit per vector component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantizer {
    /// Number of components (bits) per vector.
    dim: usize,
    /// Number of bytes needed to hold `dim` bits.
    byte_len: usize,
}
261
/// Output of [`BinaryQuantizer`]: component bits packed into bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantizedVector {
    /// Packed bits; component `i` lives in byte `i / 8` at bit position `i % 8`.
    pub data: Vec<u8>,
}
268
269impl BinaryQuantizer {
270 pub fn new(dim: usize) -> Self {
272 let byte_len = (dim + 7) / 8; Self { dim, byte_len }
274 }
275
276 pub fn dim(&self) -> usize {
278 self.dim
279 }
280
281 pub fn byte_len(&self) -> usize {
283 self.byte_len
284 }
285}
286
287impl Quantizer for BinaryQuantizer {
288 type Quantized = BinaryQuantizedVector;
289
290 fn quantize(&self, vector: &[f32]) -> Self::Quantized {
291 debug_assert_eq!(vector.len(), self.dim);
292
293 let mut data = vec![0u8; self.byte_len];
294
295 for (i, &val) in vector.iter().enumerate() {
296 if val >= 0.0 {
297 let byte_idx = i / 8;
298 let bit_idx = i % 8;
299 data[byte_idx] |= 1 << bit_idx;
300 }
301 }
302
303 BinaryQuantizedVector { data }
304 }
305
306 fn dequantize(&self, quantized: &Self::Quantized) -> Vec<f32> {
307 let mut result = vec![0.0f32; self.dim];
309
310 for i in 0..self.dim {
311 let byte_idx = i / 8;
312 let bit_idx = i % 8;
313 let bit = (quantized.data[byte_idx] >> bit_idx) & 1;
314 result[i] = if bit == 1 { 1.0 } else { -1.0 };
315 }
316
317 result
318 }
319
320 fn distance_quantized(&self, a: &Self::Quantized, b: &Self::Quantized) -> f32 {
321 hamming_distance_simd(&a.data, &b.data) as f32
323 }
324
325 fn distance_asymmetric(&self, query: &[f32], quantized: &Self::Quantized) -> f32 {
326 let mut mismatches = 0u32;
328
329 for (i, &val) in query.iter().enumerate() {
330 let byte_idx = i / 8;
331 let bit_idx = i % 8;
332 let quantized_bit = (quantized.data[byte_idx] >> bit_idx) & 1;
333 let query_bit = if val >= 0.0 { 1 } else { 0 };
334
335 if quantized_bit != query_bit {
336 mismatches += 1;
337 }
338 }
339
340 mismatches as f32
341 }
342}
343
/// L2 distance between two vectors of `u8` codes.
///
/// The result is expressed in code units: the square root of the summed
/// squared code differences, with no per-dimension scale applied.
/// Equal lengths are checked only by a debug assertion.
#[inline]
pub fn sq8_l2_distance_simd(a: &[u8], b: &[u8]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let simd = pulp::Arch::new();
    // NOTE(review): the helper dispatches on `simd` again internally, so this
    // outer dispatch looks redundant — confirm before removing.
    simd.dispatch(|| sq8_l2_distance_impl(simd, a, b))
}
356
357#[inline(always)]
359fn sq8_l2_distance_impl(simd: pulp::Arch, a: &[u8], b: &[u8]) -> f32 {
360 struct Sq8L2<'a> {
361 a: &'a [u8],
362 b: &'a [u8],
363 }
364
365 impl pulp::WithSimd for Sq8L2<'_> {
366 type Output = f32;
367
368 #[inline(always)]
369 fn with_simd<S: Simd>(self, _simd: S) -> Self::Output {
370 let mut sum_sq: u32 = 0;
374
375 let mut chunks = self.a.chunks_exact(4).zip(self.b.chunks_exact(4));
376 for (a_chunk, b_chunk) in &mut chunks {
377 let d0 = (a_chunk[0] as i32) - (b_chunk[0] as i32);
378 let d1 = (a_chunk[1] as i32) - (b_chunk[1] as i32);
379 let d2 = (a_chunk[2] as i32) - (b_chunk[2] as i32);
380 let d3 = (a_chunk[3] as i32) - (b_chunk[3] as i32);
381 sum_sq += (d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3) as u32;
382 }
383
384 let rem_start = self.a.len() - self.a.len() % 4;
386 for i in rem_start..self.a.len() {
387 let diff = (self.a[i] as i32) - (self.b[i] as i32);
388 sum_sq += (diff * diff) as u32;
389 }
390
391 (sum_sq as f32).sqrt()
392 }
393 }
394
395 simd.dispatch(Sq8L2 { a, b })
396}
397
/// L2 distance between a float `query` and a vector of `u8` codes.
///
/// Each code is converted back to a float with the `params` entry of its
/// dimension before being compared with the matching `query` component.
/// All three slices are expected to have the same length; this is checked
/// only by debug assertions.
#[inline]
pub fn sq8_asymmetric_l2_distance_simd(
    query: &[f32],
    quantized: &[u8],
    params: &[ScalarQuantizationParams],
) -> f32 {
    debug_assert_eq!(query.len(), quantized.len());
    debug_assert_eq!(query.len(), params.len());

    let simd = pulp::Arch::new();
    // NOTE(review): the helper dispatches on `simd` again internally, so this
    // outer dispatch looks redundant — confirm before removing.
    simd.dispatch(|| sq8_asymmetric_l2_impl(simd, query, quantized, params))
}
411
412#[inline(always)]
414fn sq8_asymmetric_l2_impl(
415 simd: pulp::Arch,
416 query: &[f32],
417 quantized: &[u8],
418 params: &[ScalarQuantizationParams],
419) -> f32 {
420 struct AsymL2<'a> {
421 query: &'a [f32],
422 quantized: &'a [u8],
423 params: &'a [ScalarQuantizationParams],
424 }
425
426 impl pulp::WithSimd for AsymL2<'_> {
427 type Output = f32;
428
429 #[inline(always)]
430 fn with_simd<S: Simd>(self, _simd: S) -> Self::Output {
431 let mut sum_sq: f32 = 0.0;
435 let n = self.query.len();
436
437 let mut i = 0;
438 while i + 4 <= n {
439 let d0 = self.query[i] - self.params[i].dequantize_value(self.quantized[i]);
440 let d1 =
441 self.query[i + 1] - self.params[i + 1].dequantize_value(self.quantized[i + 1]);
442 let d2 =
443 self.query[i + 2] - self.params[i + 2].dequantize_value(self.quantized[i + 2]);
444 let d3 =
445 self.query[i + 3] - self.params[i + 3].dequantize_value(self.quantized[i + 3]);
446 sum_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
447 i += 4;
448 }
449
450 while i < n {
452 let dequantized = self.params[i].dequantize_value(self.quantized[i]);
453 let diff = self.query[i] - dequantized;
454 sum_sq += diff * diff;
455 i += 1;
456 }
457
458 sum_sq.sqrt()
459 }
460 }
461
462 simd.dispatch(AsymL2 {
463 query,
464 quantized,
465 params,
466 })
467}
468
/// Hamming distance between two byte slices: the number of bit positions in
/// which they differ.
///
/// Equal lengths are checked only by a debug assertion.
#[inline]
pub fn hamming_distance_simd(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len());

    let simd = pulp::Arch::new();
    // NOTE(review): the helper dispatches on `simd` again internally, so this
    // outer dispatch looks redundant — confirm before removing.
    simd.dispatch(|| hamming_distance_impl(simd, a, b))
}
477
478#[inline(always)]
480fn hamming_distance_impl(simd: pulp::Arch, a: &[u8], b: &[u8]) -> u32 {
481 struct Hamming<'a> {
482 a: &'a [u8],
483 b: &'a [u8],
484 }
485
486 impl pulp::WithSimd for Hamming<'_> {
487 type Output = u32;
488
489 #[inline(always)]
490 fn with_simd<S: Simd>(self, _simd: S) -> Self::Output {
491 let mut distance = 0u32;
493
494 let chunks = self.a.len() / 8;
496 for i in 0..chunks {
497 let offset = i * 8;
498 let a_u64 = u64::from_le_bytes([
499 self.a[offset],
500 self.a[offset + 1],
501 self.a[offset + 2],
502 self.a[offset + 3],
503 self.a[offset + 4],
504 self.a[offset + 5],
505 self.a[offset + 6],
506 self.a[offset + 7],
507 ]);
508 let b_u64 = u64::from_le_bytes([
509 self.b[offset],
510 self.b[offset + 1],
511 self.b[offset + 2],
512 self.b[offset + 3],
513 self.b[offset + 4],
514 self.b[offset + 5],
515 self.b[offset + 6],
516 self.b[offset + 7],
517 ]);
518 distance += (a_u64 ^ b_u64).count_ones();
519 }
520
521 for i in (chunks * 8)..self.a.len() {
523 distance += (self.a[i] ^ self.b[i]).count_ones();
524 }
525
526 distance
527 }
528 }
529
530 simd.dispatch(Hamming { a, b })
531}
532
533#[inline]
539pub fn binary_dot_product(query: &[f32], quantized: &BinaryQuantizedVector, dim: usize) -> f32 {
540 let mut sum = 0.0f32;
541
542 for i in 0..dim {
543 let byte_idx = i / 8;
544 let bit_idx = i % 8;
545 let bit = ((quantized.data[byte_idx] >> bit_idx) & 1) as f32;
546 let sign = bit * 2.0 - 1.0;
548 sum += query[i] * sign;
549 }
550
551 sum
552}
553
/// Configuration for a product quantizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizerConfig {
    /// Dimensionality of the full input vectors.
    pub dim: usize,
    /// Number of subvectors each input vector is split into.
    pub num_subvectors: usize,
    /// Number of code bits per subvector; `num_centroids` is `2^bits_per_subvector`.
    pub bits_per_subvector: usize,
}
568
569impl ProductQuantizerConfig {
570 pub fn default_for_dim(dim: usize) -> Self {
574 let num_subvectors = 8.min(dim);
575 Self {
576 dim,
577 num_subvectors,
578 bits_per_subvector: 8,
579 }
580 }
581
582 pub fn subvector_dim(&self) -> usize {
584 self.dim / self.num_subvectors
585 }
586
587 pub fn num_centroids(&self) -> usize {
589 1 << self.bits_per_subvector
590 }
591
592 pub fn compressed_size(&self) -> usize {
594 self.num_subvectors * ((self.bits_per_subvector + 7) / 8)
595 }
596}
597
// Unit tests for the scalar and binary quantizers, the distance kernels and
// the product-quantizer configuration.
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute tolerance for float comparisons that should be exact up to rounding.
    const EPSILON: f32 = 1e-5;

    // `fit` must record the per-dimension minimum and maximum of the training set.
    #[test]
    fn test_scalar_quantizer_fit() {
        let vectors = vec![
            vec![0.0, 0.5, 1.0],
            vec![0.2, 0.3, 0.8],
            vec![0.1, 0.6, 0.9],
        ];

        let sq = ScalarQuantizer::fit(&vectors);
        assert_eq!(sq.dim(), 3);

        // Dimension 0 sees 0.0, 0.2, 0.1; dimension 2 sees 1.0, 0.8, 0.9.
        assert!((sq.params[0].min - 0.0).abs() < EPSILON);
        assert!((sq.params[0].max - 0.2).abs() < EPSILON);
        assert!((sq.params[2].min - 0.8).abs() < EPSILON);
        assert!((sq.params[2].max - 1.0).abs() < EPSILON);
    }

    // Quantize followed by dequantize must stay close to the original values.
    #[test]
    fn test_scalar_quantizer_roundtrip() {
        let vectors = vec![vec![-1.0, 0.0, 1.0], vec![-0.5, 0.5, 0.5]];

        let sq = ScalarQuantizer::fit(&vectors);
        let original = vec![-0.7, 0.3, 0.8];
        let quantized = sq.quantize(&original);
        let reconstructed = sq.dequantize(&quantized);

        for (o, r) in original.iter().zip(reconstructed.iter()) {
            assert!((o - r).abs() < 0.02, "orig={}, recon={}", o, r);
        }
    }

    // With the fixed [-1, 1] range one code step is 2/255 (about 0.0078), so
    // the 0.01 bound leaves room for a full step of error.
    #[test]
    fn test_scalar_quantizer_for_normalized() {
        let sq = ScalarQuantizer::for_normalized(384);
        assert_eq!(sq.dim(), 384);

        // Evenly spaced values starting at -1.0 and staying below 1.0.
        let vector: Vec<f32> = (0..384).map(|i| (i as f32 / 192.0) - 1.0).collect();
        let quantized = sq.quantize(&vector);
        let reconstructed = sq.dequantize(&quantized);

        // Largest absolute reconstruction error over all components.
        let max_error: f32 = vector
            .iter()
            .zip(reconstructed.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, |a, b| a.max(b));

        assert!(max_error < 0.01, "Max error: {}", max_error);
    }

    // Identical inputs quantize to identical codes, so their distance is small.
    #[test]
    fn test_sq8_distance_quantized() {
        let sq = ScalarQuantizer::for_normalized(4);

        let a = vec![1.0, 0.0, -1.0, 0.5];
        let b = vec![1.0, 0.0, -1.0, 0.5];

        let qa = sq.quantize(&a);
        let qb = sq.quantize(&b);

        let dist = sq.distance_quantized(&qa, &qb);
        assert!(dist < 1.0, "Same vectors should have near-zero distance");
    }

    // Vectors at opposite ends of the range are far apart in code units.
    #[test]
    fn test_sq8_distance_different() {
        let sq = ScalarQuantizer::for_normalized(4);

        let a = vec![1.0, 1.0, 1.0, 1.0];
        let b = vec![-1.0, -1.0, -1.0, -1.0];

        let qa = sq.quantize(&a);
        let qb = sq.quantize(&b);

        let dist = sq.distance_quantized(&qa, &qb);
        assert!(dist > 100.0, "Opposite vectors should have large distance");
    }

    // Eight dimensions fit exactly into one byte.
    #[test]
    fn test_binary_quantizer_basic() {
        let bq = BinaryQuantizer::new(8);
        assert_eq!(bq.dim(), 8);
        assert_eq!(bq.byte_len(), 1);
    }

    // `byte_len` rounds the bit count up to whole bytes.
    #[test]
    fn test_binary_quantizer_byte_len() {
        assert_eq!(BinaryQuantizer::new(1).byte_len(), 1);
        assert_eq!(BinaryQuantizer::new(8).byte_len(), 1);
        assert_eq!(BinaryQuantizer::new(9).byte_len(), 2);
        assert_eq!(BinaryQuantizer::new(16).byte_len(), 2);
        assert_eq!(BinaryQuantizer::new(384).byte_len(), 48);
    }

    // Every non-negative component sets its bit.
    #[test]
    fn test_binary_quantizer_all_positive() {
        let bq = BinaryQuantizer::new(8);
        let vector = vec![0.5, 0.3, 0.1, 0.9, 0.2, 0.4, 0.6, 0.8];
        let quantized = bq.quantize(&vector);

        assert_eq!(quantized.data[0], 0xFF);
    }

    // Every negative component leaves its bit clear.
    #[test]
    fn test_binary_quantizer_all_negative() {
        let bq = BinaryQuantizer::new(8);
        let vector = vec![-0.5, -0.3, -0.1, -0.9, -0.2, -0.4, -0.6, -0.8];
        let quantized = bq.quantize(&vector);

        assert_eq!(quantized.data[0], 0x00);
    }

    // Component `i` lands on bit `i`, so alternating signs starting with a
    // positive value set the even bits.
    #[test]
    fn test_binary_quantizer_mixed() {
        let bq = BinaryQuantizer::new(8);
        let vector = vec![0.5, -0.3, 0.1, -0.9, 0.2, -0.4, 0.6, -0.8];
        let quantized = bq.quantize(&vector);
        assert_eq!(quantized.data[0], 0b01010101);
    }

    // All eight bits differ between an all-positive and an all-negative vector.
    #[test]
    fn test_binary_hamming_distance() {
        let bq = BinaryQuantizer::new(8);

        let a = vec![1.0; 8];
        let b = vec![-1.0; 8];
        let qa = bq.quantize(&a);
        let qb = bq.quantize(&b);

        let dist = bq.distance_quantized(&qa, &qb);
        assert_eq!(dist, 8.0);
    }

    // A vector is at Hamming distance zero from itself.
    #[test]
    fn test_binary_hamming_same() {
        let bq = BinaryQuantizer::new(16);

        let a = vec![
            0.5, -0.3, 0.1, -0.9, 0.2, -0.4, 0.6, -0.8, 0.5, -0.3, 0.1, -0.9, 0.2, -0.4, 0.6, -0.8,
        ];

        let qa = bq.quantize(&a);
        let qb = bq.quantize(&a);

        let dist = bq.distance_quantized(&qa, &qb);
        assert_eq!(dist, 0.0);
    }

    // Dequantization yields +1.0 for set bits and -1.0 for clear bits.
    #[test]
    fn test_binary_dequantize() {
        let bq = BinaryQuantizer::new(4);
        let vector = vec![0.5, -0.3, 0.1, -0.9];
        let quantized = bq.quantize(&vector);
        let dequantized = bq.dequantize(&quantized);

        assert_eq!(dequantized, vec![1.0, -1.0, 1.0, -1.0]);
    }

    // Round trip over many bytes with an alternating sign pattern.
    #[test]
    fn test_binary_large_dimension() {
        let bq = BinaryQuantizer::new(384);
        let vector: Vec<f32> = (0..384)
            .map(|i| if i % 2 == 0 { 0.5 } else { -0.5 })
            .collect();

        let quantized = bq.quantize(&vector);
        assert_eq!(quantized.data.len(), 48);

        let dequantized = bq.dequantize(&quantized);
        for (i, &val) in dequantized.iter().enumerate() {
            let expected = if i % 2 == 0 { 1.0 } else { -1.0 };
            assert_eq!(val, expected);
        }
    }

    // First bytes differ in all eight bits, second bytes are equal.
    #[test]
    fn test_hamming_distance_simd_basic() {
        let a = vec![0b11110000u8, 0b10101010];
        let b = vec![0b00001111u8, 0b10101010];

        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 8);
    }

    // Identical byte slices have no differing bits.
    #[test]
    fn test_hamming_distance_simd_same() {
        let a = vec![0xFF, 0x00, 0xAB, 0xCD];
        let b = a.clone();

        let dist = hamming_distance_simd(&a, &b);
        assert_eq!(dist, 0);
    }

    // Six elements exercise both the 4-wide body and the remainder path.
    #[test]
    fn test_sq8_l2_distance_simd_basic() {
        let a = vec![0u8, 50, 100, 150, 200, 250];
        let b = vec![0u8, 50, 100, 150, 200, 250];

        let dist = sq8_l2_distance_simd(&a, &b);
        assert!(dist < EPSILON);
    }

    // Four components each differing by 255: sqrt(4 * 255^2) = 510.
    #[test]
    fn test_sq8_l2_distance_simd_different() {
        let a = vec![0u8, 0, 0, 0];
        let b = vec![255u8, 255, 255, 255];

        let dist = sq8_l2_distance_simd(&a, &b);
        assert!((dist - 510.0).abs() < 1.0);
    }

    // Default configuration and the values derived from it for 384 dimensions.
    #[test]
    fn test_pq_config_defaults() {
        let config = ProductQuantizerConfig::default_for_dim(384);

        assert_eq!(config.dim, 384);
        assert_eq!(config.num_subvectors, 8);
        assert_eq!(config.bits_per_subvector, 8);
        assert_eq!(config.subvector_dim(), 48);
        assert_eq!(config.num_centroids(), 256);
        assert_eq!(config.compressed_size(), 8);
    }

    // Nearest neighbours ranked by quantized distance should mostly match the
    // ranking by exact L2 distance (recall@10 of at least 7).
    #[test]
    fn test_sq8_recall_approximation() {
        use rand::SeedableRng;
        // Fixed seed keeps the generated data, and therefore the test, deterministic.
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);

        let dim = 128;
        let num_vectors = 100;

        let vectors: Vec<Vec<f32>> = (0..num_vectors)
            .map(|_| {
                (0..dim)
                    .map(|_| rand::Rng::gen_range(&mut rng, -1.0..1.0))
                    .collect()
            })
            .collect();

        let sq = ScalarQuantizer::fit(&vectors);
        let quantized: Vec<_> = vectors.iter().map(|v| sq.quantize(v)).collect();

        // One of the generated vectors serves as the query.
        let query_idx = 42;
        let query = &vectors[query_idx];
        let query_q = &quantized[query_idx];

        // Ground truth: exact L2 distance to every other vector.
        let mut exact_distances: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != query_idx)
            .map(|(i, v)| {
                let dist: f32 = query
                    .iter()
                    .zip(v.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f32>()
                    .sqrt();
                (i, dist)
            })
            .collect();

        // Approximation: distance between the quantized forms.
        let mut quantized_distances: Vec<(usize, f32)> = quantized
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != query_idx)
            .map(|(i, q)| (i, sq.distance_quantized(query_q, q)))
            .collect();

        exact_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        quantized_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Compare the ten nearest neighbours found by each method.
        let exact_top10: std::collections::HashSet<_> =
            exact_distances[..10].iter().map(|(i, _)| *i).collect();
        let quantized_top10: std::collections::HashSet<_> =
            quantized_distances[..10].iter().map(|(i, _)| *i).collect();

        let recall = exact_top10.intersection(&quantized_top10).count();
        assert!(recall >= 7, "Recall@10: {}/10", recall);
    }

    // Hamming ranking on sign bits should overlap the exact cosine-distance
    // ranking in at least 5 of the top 10 neighbours.
    #[test]
    fn test_binary_recall_approximation() {
        use rand::SeedableRng;
        // Fixed seed keeps the generated data, and therefore the test, deterministic.
        let mut rng = rand::rngs::StdRng::seed_from_u64(123);

        let dim = 128;
        let num_vectors = 100;

        let vectors: Vec<Vec<f32>> = (0..num_vectors)
            .map(|_| {
                (0..dim)
                    .map(|_| rand::Rng::gen_range(&mut rng, -1.0..1.0))
                    .collect()
            })
            .collect();

        let bq = BinaryQuantizer::new(dim);
        let quantized: Vec<_> = vectors.iter().map(|v| bq.quantize(v)).collect();

        // One of the generated vectors serves as the query.
        let query_idx = 42;
        let query = &vectors[query_idx];
        let query_q = &quantized[query_idx];

        // Ground truth: cosine distance (1 - cosine similarity) to every other vector.
        let mut exact_distances: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != query_idx)
            .map(|(i, v)| {
                let dot: f32 = query.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
                let norm_q: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
                let norm_v: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                let cosine = dot / (norm_q * norm_v);
                (i, 1.0 - cosine)
            })
            .collect();

        // Approximation: Hamming distance between the packed sign bits.
        let mut quantized_distances: Vec<(usize, f32)> = quantized
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != query_idx)
            .map(|(i, q)| (i, bq.distance_quantized(query_q, q)))
            .collect();

        exact_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        quantized_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Compare the ten nearest neighbours found by each method.
        let exact_top10: std::collections::HashSet<_> =
            exact_distances[..10].iter().map(|(i, _)| *i).collect();
        let quantized_top10: std::collections::HashSet<_> =
            quantized_distances[..10].iter().map(|(i, _)| *i).collect();

        let recall = exact_top10.intersection(&quantized_top10).count();
        assert!(recall >= 5, "Binary recall@10: {}/10", recall);
    }
}