1use serde::{Deserialize, Serialize};
21use smallvec::SmallVec;
22use std::fmt;
23
24#[cfg(target_arch = "aarch64")]
25use std::arch::aarch64::{vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vsubq_f32};
26#[cfg(target_arch = "x86_64")]
27#[allow(clippy::wildcard_imports)]
28use std::arch::x86_64::*;
29
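/// Inline capacity for each per-dimension ADC entry: 16 slots cover every code of a
/// 4-bit (or narrower) quantizer without spilling to the heap.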
30const MAX_CODES: usize = 16;
32
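/// Supported code widths, in bits per dimension.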
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35pub enum QuantizationBits {
36 Bits1,
38 Bits2,
40 Bits3,
42 Bits4,
44 Bits5,
46 Bits7,
48 Bits8,
50}
51
52impl QuantizationBits {
53 #[must_use]
55 pub fn to_u8(self) -> u8 {
56 match self {
57 QuantizationBits::Bits1 => 1,
58 QuantizationBits::Bits2 => 2,
59 QuantizationBits::Bits3 => 3,
60 QuantizationBits::Bits4 => 4,
61 QuantizationBits::Bits5 => 5,
62 QuantizationBits::Bits7 => 7,
63 QuantizationBits::Bits8 => 8,
64 }
65 }
66
67 #[must_use]
69 pub fn levels(self) -> usize {
70 1 << self.to_u8()
71 }
72
73 #[must_use]
75 pub fn compression_ratio(self) -> f32 {
76 32.0 / self.to_u8() as f32
77 }
78
79 #[must_use]
81 pub fn values_per_byte(self) -> usize {
82 8 / self.to_u8() as usize
83 }
84}
85
86impl fmt::Display for QuantizationBits {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 write!(f, "{}-bit", self.to_u8())
89 }
90}
91
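/// Configuration for a [`RaBitQ`] quantizer: the code width per dimension, how many
/// candidate rescale factors to try when quantizing without trained parameters, and
/// the range those factors are drawn from.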
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct RaBitQParams {
95 pub bits_per_dim: QuantizationBits,
97
98 pub num_rescale_factors: usize,
103
104 pub rescale_range: (f32, f32),
108}
109
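/// Per-dimension value ranges learned from sample vectors; quantization maps each
/// dimension linearly between its (percentile-clipped) min and max.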
110#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct TrainedParams {
116 pub mins: Vec<f32>,
118 pub maxs: Vec<f32>,
120 pub dimensions: usize,
122}
123
124impl TrainedParams {
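    /// Trains per-dimension ranges with default percentile clipping (1st/99th percentiles).
    ///
    /// # Errors
    ///
    /// Returns an error if `vectors` is empty or the vectors have mismatched dimensions.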
125 pub fn train(vectors: &[&[f32]]) -> Result<Self, &'static str> {
136 Self::train_with_percentiles(vectors, 0.01, 0.99)
137 }
138
139 pub fn train_with_percentiles(
149 vectors: &[&[f32]],
150 lower_percentile: f32,
151 upper_percentile: f32,
152 ) -> Result<Self, &'static str> {
153 if vectors.is_empty() {
154 return Err("Need at least one vector to train");
155 }
156 let dimensions = vectors[0].len();
157 if !vectors.iter().all(|v| v.len() == dimensions) {
            return Err("All vectors must have the same dimensions");
159 }
160
161 let n = vectors.len();
162 let lower_idx = ((n as f32 * lower_percentile) as usize).min(n - 1);
163 let upper_idx = ((n as f32 * upper_percentile) as usize).min(n - 1);
164
165 let mut mins = Vec::with_capacity(dimensions);
166 let mut maxs = Vec::with_capacity(dimensions);
167
168 let mut dim_values: Vec<f32> = Vec::with_capacity(n);
170 for d in 0..dimensions {
171 dim_values.clear();
172 for v in vectors {
173 dim_values.push(v[d]);
174 }
175 dim_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
176
177 let min_val = dim_values[lower_idx];
178 let max_val = dim_values[upper_idx];
179
180 let range = max_val - min_val;
182 if range < 1e-7 {
183 mins.push(min_val - 0.5);
184 maxs.push(max_val + 0.5);
185 } else {
186 mins.push(min_val);
187 maxs.push(max_val);
188 }
189 }
190
191 Ok(Self {
192 mins,
193 maxs,
194 dimensions,
195 })
196 }
197
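    /// Maps `value` in dimension `dim` to a code in `0..levels` using that dimension's
    /// trained range; out-of-range values are clamped.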
198 #[inline]
200 #[must_use]
201 pub fn quantize_value(&self, value: f32, dim: usize, levels: usize) -> u8 {
202 let min = self.mins[dim];
203 let max = self.maxs[dim];
204 let range = max - min;
205
206 let normalized = (value - min) / range;
208 let level = (normalized * (levels - 1) as f32).round();
209 level.clamp(0.0, (levels - 1) as f32) as u8
210 }
211
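    /// Reconstructs the representative value for `code` in dimension `dim` given `levels`
    /// quantization levels.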
212 #[inline]
214 #[must_use]
215 pub fn dequantize_value(&self, code: u8, dim: usize, levels: usize) -> f32 {
216 let min = self.mins[dim];
217 let max = self.maxs[dim];
218 let range = max - min;
219
220 (code as f32 / (levels - 1) as f32) * range + min
222 }
223}
224
225impl Default for RaBitQParams {
226 fn default() -> Self {
227 Self {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 12,
            rescale_range: (0.5, 2.0),
        }
232 }
233}
234
235impl RaBitQParams {
236 #[must_use]
238 pub fn bits2() -> Self {
239 Self {
240 bits_per_dim: QuantizationBits::Bits2,
241 ..Default::default()
242 }
243 }
244
245 #[must_use]
247 pub fn bits4() -> Self {
248 Self {
249 bits_per_dim: QuantizationBits::Bits4,
250 ..Default::default()
251 }
252 }
253
254 #[must_use]
256 pub fn bits8() -> Self {
257 Self {
258 bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 16,
            rescale_range: (0.7, 1.5),
        }
262 }
263}
264
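/// A quantized vector: bit-packed codes plus the scale, bit width, and dimensionality
/// needed to reconstruct or score it.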
265#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct QuantizedVector {
273 pub data: Vec<u8>,
281
282 pub scale: f32,
287
288 pub bits: u8,
290
291 pub dimensions: usize,
293}
294
295impl QuantizedVector {
296 #[must_use]
298 pub fn new(data: Vec<u8>, scale: f32, bits: u8, dimensions: usize) -> Self {
299 Self {
300 data,
301 scale,
302 bits,
303 dimensions,
304 }
305 }
306
307 #[must_use]
309 pub fn memory_bytes(&self) -> usize {
310 std::mem::size_of::<Self>() + self.data.len()
311 }
312
313 #[must_use]
315 pub fn compression_ratio(&self) -> f32 {
        let original_bytes = self.dimensions * 4; // 4 bytes per f32 dimension
        let compressed_bytes = self.data.len() + 4 + 1; // packed codes + scale (f32) + bits (u8)
        original_bytes as f32 / compressed_bytes as f32
319 }
320}
321
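/// RaBitQ-style scalar quantizer.
///
/// Without trained parameters it tries each candidate rescale factor and keeps the one
/// with the lowest reconstruction error; once trained, it quantizes against per-dimension
/// ranges from [`TrainedParams`].
///
/// A minimal usage sketch (marked `ignore` since the crate path isn't shown here):
///
/// ```ignore
/// let quantizer = RaBitQ::default_4bit();
/// let qv = quantizer.quantize(&[0.1_f32, 0.4, 0.7, 0.9]);
/// assert_eq!(qv.bits, 4);
///
/// // Asymmetric distance: full-precision query vs. packed codes.
/// let dist = quantizer.distance_asymmetric_l2(&[0.2_f32, 0.4, 0.6, 0.8], &qv);
/// assert!(dist.is_finite());
/// ```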
322#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct RaBitQ {
344 params: RaBitQParams,
345 trained: Option<TrainedParams>,
348}
349
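/// Asymmetric distance computation (ADC) lookup table.
///
/// For each query dimension it precomputes the squared difference to every reconstructed
/// code value, so scoring a packed vector reduces to table lookups and additions.
///
/// A minimal sketch of query-side use (marked `ignore`; assumes candidates were packed
/// with the same scale as the table):
///
/// ```ignore
/// let quantizer = RaBitQ::default_4bit();
/// let query = vec![0.5_f32, 0.5, 0.5, 0.5];
/// let qv = quantizer.quantize(&[0.6_f32, 0.6, 0.6, 0.6]);
///
/// let adc = quantizer.build_adc_table_with_scale(&query, qv.scale);
/// assert!(adc.distance(&qv.data).is_finite());
/// ```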
350#[derive(Debug, Clone)]
378pub struct ADCTable {
379 table: Vec<SmallVec<[f32; MAX_CODES]>>,
383
384 bits: u8,
386
387 dimensions: usize,
389}
390
391impl ADCTable {
392 #[must_use]
403 pub fn new_trained(query: &[f32], trained: &TrainedParams, params: &RaBitQParams) -> Self {
404 let bits = params.bits_per_dim.to_u8();
405 let num_codes = params.bits_per_dim.levels();
406 let dimensions = query.len();
407
408 let mut table = Vec::with_capacity(dimensions);
409
410 for (d, &q_value) in query.iter().enumerate() {
412 let mut dim_table = SmallVec::new();
413
414 for code in 0..num_codes {
415 let reconstructed = trained.dequantize_value(code as u8, d, num_codes);
417
418 let diff = q_value - reconstructed;
420 dim_table.push(diff * diff);
421 }
422
423 table.push(dim_table);
424 }
425
426 Self {
427 table,
428 bits,
429 dimensions,
430 }
431 }
432
433 #[must_use]
455 pub fn new(query: &[f32], scale: f32, params: &RaBitQParams) -> Self {
456 let bits = params.bits_per_dim.to_u8();
457 let num_codes = params.bits_per_dim.levels();
458 let dimensions = query.len();
459
460 let mut table = Vec::with_capacity(dimensions);
461
462 let levels = num_codes as f32;
464 let dequant_factor = 1.0 / ((levels - 1.0) * scale);
465
466 for &q_value in query {
468 let mut dim_table = SmallVec::new();
469
470 for code in 0..num_codes {
471 let reconstructed = (code as f32) * dequant_factor;
473
474 let diff = q_value - reconstructed;
476 dim_table.push(diff * diff);
477 }
478
479 table.push(dim_table);
480 }
481
482 Self {
483 table,
484 bits,
485 dimensions,
486 }
487 }
488
489 #[inline]
508 #[must_use]
509 pub fn distance_squared(&self, data: &[u8]) -> f32 {
510 match self.bits {
511 4 => self.distance_squared_4bit(data),
512 2 => self.distance_squared_2bit(data),
513 8 => self.distance_squared_8bit(data),
514 _ => self.distance_squared_generic(data),
515 }
516 }
517
518 #[inline]
522 #[must_use]
523 pub fn distance(&self, data: &[u8]) -> f32 {
524 self.distance_squared_simd(data).sqrt()
525 }
526
527 #[inline]
534 fn distance_squared_4bit(&self, data: &[u8]) -> f32 {
535 let mut sum = 0.0f32;
536 let num_pairs = self.dimensions / 2;
537
538 for i in 0..num_pairs {
540 if i >= data.len() {
541 break;
542 }
543
544 let byte = unsafe { *data.get_unchecked(i) };
            // High nibble holds the even dimension's code, low nibble the odd one.
            let code_hi = (byte >> 4) as usize;
            let code_lo = (byte & 0x0F) as usize;
            sum += unsafe {
555 *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
556 + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo)
557 };
558 }
559
        // Odd dimension count: the final code occupies only the high nibble of the last byte.
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            let byte = unsafe { *data.get_unchecked(num_pairs) };
            let code_hi = (byte >> 4) as usize;
            sum += unsafe {
569 *self
570 .table
571 .get_unchecked(self.dimensions - 1)
572 .get_unchecked(code_hi)
573 };
574 }
575
576 sum
577 }
578
579 #[inline]
586 fn distance_squared_2bit(&self, data: &[u8]) -> f32 {
587 let mut sum = 0.0f32;
588 let num_quads = self.dimensions / 4;
589
590 for i in 0..num_quads {
592 if i >= data.len() {
593 break;
594 }
595
596 let byte = unsafe { *data.get_unchecked(i) };
598
599 sum += unsafe {
604 *self
605 .table
606 .get_unchecked(i * 4)
607 .get_unchecked((byte & 0b11) as usize)
608 + *self
609 .table
610 .get_unchecked(i * 4 + 1)
611 .get_unchecked(((byte >> 2) & 0b11) as usize)
612 + *self
613 .table
614 .get_unchecked(i * 4 + 2)
615 .get_unchecked(((byte >> 4) & 0b11) as usize)
616 + *self
617 .table
618 .get_unchecked(i * 4 + 3)
619 .get_unchecked(((byte >> 6) & 0b11) as usize)
620 };
621 }
622
623 let remaining = self.dimensions % 4;
625 if remaining > 0 && num_quads < data.len() {
626 let byte = unsafe { *data.get_unchecked(num_quads) };
628 for j in 0..remaining {
                let code = ((byte >> (j * 2)) & 0b11) as usize;
                sum += unsafe {
634 *self
635 .table
636 .get_unchecked(num_quads * 4 + j)
637 .get_unchecked(code)
638 };
639 }
640 }
641
642 sum
643 }
644
645 #[inline]
652 fn distance_squared_8bit(&self, data: &[u8]) -> f32 {
653 let mut sum = 0.0f32;
654
655 for (i, &byte) in data.iter().enumerate().take(self.dimensions) {
656 sum += unsafe { *self.table.get_unchecked(i).get_unchecked(byte as usize) };
661 }
662
663 sum
664 }
665
666 #[inline]
668 fn distance_squared_generic(&self, data: &[u8]) -> f32 {
669 let mut sum = 0.0f32;
671
672 for (i, dim_table) in self.table.iter().enumerate() {
673 if i >= data.len() {
674 break;
675 }
676 let code = data[i] as usize;
677 if let Some(&dist) = dim_table.get(code) {
678 sum += dist;
679 }
680 }
681
682 sum
683 }
684
685 #[inline]
690 #[must_use]
691 pub fn distance_squared_simd(&self, data: &[u8]) -> f32 {
692 match self.bits {
693 4 => {
694 #[cfg(target_arch = "x86_64")]
695 {
696 if is_x86_feature_detected!("avx2") {
697 unsafe { self.distance_squared_4bit_avx2(data) }
698 } else {
699 self.distance_squared_4bit(data)
701 }
702 }
703 #[cfg(target_arch = "aarch64")]
704 {
705 unsafe { self.distance_squared_4bit_neon(data) }
707 }
708 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
709 {
710 self.distance_squared_4bit(data)
712 }
713 }
714 2 => {
715 self.distance_squared_2bit(data)
717 }
718 8 => {
719 self.distance_squared_8bit(data)
721 }
722 _ => self.distance_squared_generic(data),
723 }
724 }
725
726 #[cfg(target_arch = "x86_64")]
728 #[target_feature(enable = "avx2")]
729 #[target_feature(enable = "fma")]
730 unsafe fn distance_squared_4bit_avx2(&self, data: &[u8]) -> f32 {
731 let mut sum = _mm256_setzero_ps();
732 let num_pairs = self.dimensions / 2;
733
734 let chunks = num_pairs / 8;
736 for chunk_idx in 0..chunks {
737 let byte_idx = chunk_idx * 8;
738 if byte_idx + 8 > data.len() {
739 break;
740 }
741
742 let mut values = [0.0f32; 8];
744 for (i, value) in values.iter_mut().enumerate() {
745 let byte = *data.get_unchecked(byte_idx + i);
746 let code_hi = (byte >> 4) as usize;
747 let code_lo = (byte & 0x0F) as usize;
748
749 let dist_hi = *self
750 .table
751 .get_unchecked((byte_idx + i) * 2)
752 .get_unchecked(code_hi);
753 let dist_lo = *self
754 .table
755 .get_unchecked((byte_idx + i) * 2 + 1)
756 .get_unchecked(code_lo);
757 *value = dist_hi + dist_lo;
758 }
759
760 let vec = _mm256_loadu_ps(values.as_ptr());
761 sum = _mm256_add_ps(sum, vec);
762 }
763
764 let mut result = horizontal_sum_avx2(sum);
766
767 for i in (chunks * 8)..num_pairs {
769 if i >= data.len() {
770 break;
771 }
772 let byte = *data.get_unchecked(i);
773 let code_hi = (byte >> 4) as usize;
774 let code_lo = (byte & 0x0F) as usize;
775
776 result += *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
777 + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo);
778 }
779
780 if self.dimensions % 2 == 1 && num_pairs < data.len() {
782 let byte = *data.get_unchecked(num_pairs);
783 let code_hi = (byte >> 4) as usize;
784 result += *self
785 .table
786 .get_unchecked(self.dimensions - 1)
787 .get_unchecked(code_hi);
788 }
789
790 result
791 }
792
793 #[cfg(target_arch = "aarch64")]
795 unsafe fn distance_squared_4bit_neon(&self, data: &[u8]) -> f32 {
796 let mut sum = vdupq_n_f32(0.0);
797 let num_pairs = self.dimensions / 2;
798
799 let chunks = num_pairs / 4;
801 for chunk_idx in 0..chunks {
802 let byte_idx = chunk_idx * 4;
803 if byte_idx + 4 > data.len() {
804 break;
805 }
806
807 let mut values = [0.0f32; 4];
808 for (i, value) in values.iter_mut().enumerate() {
809 let byte = *data.get_unchecked(byte_idx + i);
810 let code_hi = (byte >> 4) as usize;
811 let code_lo = (byte & 0x0F) as usize;
812
813 let dist_hi = *self
814 .table
815 .get_unchecked((byte_idx + i) * 2)
816 .get_unchecked(code_hi);
817 let dist_lo = *self
818 .table
819 .get_unchecked((byte_idx + i) * 2 + 1)
820 .get_unchecked(code_lo);
821 *value = dist_hi + dist_lo;
822 }
823
824 let vec = vld1q_f32(values.as_ptr());
825 sum = vaddq_f32(sum, vec);
826 }
827
828 let mut result = vaddvq_f32(sum);
829
830 for i in (chunks * 4)..num_pairs {
832 if i >= data.len() {
833 break;
834 }
835 let byte = *data.get_unchecked(i);
836 let code_hi = (byte >> 4) as usize;
837 let code_lo = (byte & 0x0F) as usize;
838
839 result += *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
840 + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo);
841 }
842
843 if self.dimensions % 2 == 1 && num_pairs < data.len() {
845 let byte = *data.get_unchecked(num_pairs);
846 let code_hi = (byte >> 4) as usize;
847 result += *self
848 .table
849 .get_unchecked(self.dimensions - 1)
850 .get_unchecked(code_hi);
851 }
852
853 result
854 }
855
856 #[must_use]
858 pub fn bits(&self) -> u8 {
859 self.bits
860 }
861
862 #[must_use]
864 pub fn dimensions(&self) -> usize {
865 self.dimensions
866 }
867
868 #[must_use]
872 pub fn get(&self, dim: usize, code: usize) -> f32 {
873 self.table
874 .get(dim)
875 .and_then(|t| t.get(code))
876 .copied()
877 .unwrap_or(0.0)
878 }
879
880 #[must_use]
882 pub fn memory_bytes(&self) -> usize {
883 std::mem::size_of::<Self>()
884 + self.table.len() * std::mem::size_of::<SmallVec<[f32; MAX_CODES]>>()
885 + self.table.iter().map(|t| t.len() * 4).sum::<usize>()
886 }
887}
888
889impl RaBitQ {
890 #[must_use]
894 pub fn new(params: RaBitQParams) -> Self {
895 Self {
896 params,
897 trained: None,
898 }
899 }
900
901 #[must_use]
903 pub fn new_trained(params: RaBitQParams, trained: TrainedParams) -> Self {
904 Self {
905 params,
906 trained: Some(trained),
907 }
908 }
909
910 #[must_use]
912 pub fn default_4bit() -> Self {
913 Self::new(RaBitQParams::bits4())
914 }
915
916 #[must_use]
918 pub fn params(&self) -> &RaBitQParams {
919 &self.params
920 }
921
922 #[must_use]
924 pub fn is_trained(&self) -> bool {
925 self.trained.is_some()
926 }
927
928 #[must_use]
930 pub fn trained_params(&self) -> Option<&TrainedParams> {
931 self.trained.as_ref()
932 }
933
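    /// Learns per-dimension ranges from sample vectors; later calls to [`Self::quantize`]
    /// use them instead of the rescale-factor search.
    ///
    /// # Errors
    ///
    /// Returns an error if `vectors` is empty or the vectors have mismatched dimensions.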
934 pub fn train(&mut self, vectors: &[&[f32]]) -> Result<(), &'static str> {
945 self.trained = Some(TrainedParams::train(vectors)?);
946 Ok(())
947 }
948
949 pub fn train_owned(&mut self, vectors: &[Vec<f32>]) -> Result<(), &'static str> {
954 let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
955 self.train(&refs)
956 }
957
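    /// Quantizes a vector, using trained per-dimension ranges when available and otherwise
    /// searching the configured rescale factors for the lowest reconstruction error.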
958 #[must_use]
963 pub fn quantize(&self, vector: &[f32]) -> QuantizedVector {
964 if let Some(ref trained) = self.trained {
966 return self.quantize_trained(vector, trained);
967 }
968
969 let mut best_error = f32::MAX;
971 let mut best_quantized = Vec::new();
972 let mut best_scale = 1.0;
973
974 let scales = self.generate_scales();
976
977 for scale in scales {
979 let quantized = self.quantize_with_scale(vector, scale);
980 let error = self.compute_error(vector, &quantized, scale);
981
982 if error < best_error {
983 best_error = error;
984 best_quantized = quantized;
985 best_scale = scale;
986 }
987 }
988
989 QuantizedVector::new(
990 best_quantized,
991 best_scale,
992 self.params.bits_per_dim.to_u8(),
993 vector.len(),
994 )
995 }
996
997 fn quantize_trained(&self, vector: &[f32], trained: &TrainedParams) -> QuantizedVector {
1001 let bits = self.params.bits_per_dim.to_u8();
1002 let levels = self.params.bits_per_dim.levels();
1003
1004 let quantized: Vec<u8> = vector
1006 .iter()
1007 .enumerate()
1008 .map(|(d, &v)| trained.quantize_value(v, d, levels))
1009 .collect();
1010
1011 let packed = self.pack_quantized(&quantized, bits);
1013
1014 QuantizedVector::new(packed, 1.0, bits, vector.len())
1016 }
1017
1018 fn generate_scales(&self) -> Vec<f32> {
1023 let (min_scale, max_scale) = self.params.rescale_range;
1024 let n = self.params.num_rescale_factors;
1025
1026 if n == 1 {
1027 return vec![f32::midpoint(min_scale, max_scale)];
1028 }
1029
1030 let step = (max_scale - min_scale) / (n - 1) as f32;
1031 (0..n).map(|i| min_scale + i as f32 * step).collect()
1032 }
1033
1034 fn quantize_with_scale(&self, vector: &[f32], scale: f32) -> Vec<u8> {
1042 let bits = self.params.bits_per_dim.to_u8();
1043 let levels = self.params.bits_per_dim.levels() as f32;
1044 let max_level = (levels - 1.0) as u8;
1045
1046 let quantized: Vec<u8> = vector
1048 .iter()
1049 .map(|&v| {
1050 let scaled = v * scale;
1052 let level = (scaled * (levels - 1.0)).round();
1054 level.clamp(0.0, max_level as f32) as u8
1056 })
1057 .collect();
1058
1059 self.pack_quantized(&quantized, bits)
1061 }
1062
1063 #[allow(clippy::unused_self)]
1070 fn pack_quantized(&self, values: &[u8], bits: u8) -> Vec<u8> {
1071 match bits {
1072 2 => {
1073 let mut packed = Vec::with_capacity(values.len().div_ceil(4));
1075 for chunk in values.chunks(4) {
1076 let mut byte = 0u8;
1077 for (i, &val) in chunk.iter().enumerate() {
1078 byte |= (val & 0b11) << (i * 2);
1079 }
1080 packed.push(byte);
1081 }
1082 packed
1083 }
1084 4 => {
1085 let mut packed = Vec::with_capacity(values.len().div_ceil(2));
1087 for chunk in values.chunks(2) {
1088 let byte = if chunk.len() == 2 {
1089 (chunk[0] << 4) | (chunk[1] & 0x0F)
1090 } else {
1091 chunk[0] << 4
1092 };
1093 packed.push(byte);
1094 }
1095 packed
1096 }
1097 8 => {
1098 values.to_vec()
1100 }
1101 _ => {
1102 values.to_vec()
1108 }
1109 }
1110 }
1111
1112 #[must_use]
1114 pub fn unpack_quantized(&self, packed: &[u8], bits: u8, dimensions: usize) -> Vec<u8> {
1115 match bits {
1116 2 => {
1117 let mut values = Vec::with_capacity(dimensions);
1119 for &byte in packed {
1120 for i in 0..4 {
1121 if values.len() < dimensions {
1122 values.push((byte >> (i * 2)) & 0b11);
1123 }
1124 }
1125 }
1126 values
1127 }
1128 4 => {
1129 let mut values = Vec::with_capacity(dimensions);
1131 for &byte in packed {
1132 values.push(byte >> 4);
1133 if values.len() < dimensions {
1134 values.push(byte & 0x0F);
1135 }
1136 }
1137 values.truncate(dimensions);
1138 values
1139 }
1140 8 => {
1141 packed[..dimensions.min(packed.len())].to_vec()
1143 }
1144 _ => {
1145 packed[..dimensions.min(packed.len())].to_vec()
1147 }
1148 }
1149 }
1150
1151 fn compute_error(&self, original: &[f32], quantized: &[u8], scale: f32) -> f32 {
1155 let reconstructed = self.reconstruct(quantized, scale, original.len());
1156
1157 original
1158 .iter()
1159 .zip(reconstructed.iter())
1160 .map(|(o, r)| (o - r).powi(2))
1161 .sum()
1162 }
1163
1164 #[must_use]
1171 pub fn reconstruct(&self, quantized: &[u8], scale: f32, dimensions: usize) -> Vec<f32> {
1172 let bits = self.params.bits_per_dim.to_u8();
1173 let levels = self.params.bits_per_dim.levels() as f32;
1174
1175 let values = self.unpack_quantized(quantized, bits, dimensions);
1177
1178 values
1180 .iter()
1181 .map(|&q| {
1182 let denorm = q as f32 / (levels - 1.0);
1184 denorm / scale
1186 })
1187 .collect()
1188 }
1189
1190 #[must_use]
1195 pub fn distance_l2(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1196 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1197 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1198
1199 v1.iter()
1200 .zip(v2.iter())
1201 .map(|(a, b)| (a - b).powi(2))
1202 .sum::<f32>()
1203 .sqrt()
1204 }
1205
1206 #[must_use]
1210 pub fn distance_cosine(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1211 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1212 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1213
1214 let dot: f32 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
1215 let norm1: f32 = v1.iter().map(|a| a * a).sum::<f32>().sqrt();
1216 let norm2: f32 = v2.iter().map(|b| b * b).sum::<f32>().sqrt();
1217
1218 if norm1 < 1e-10 || norm2 < 1e-10 {
            return 1.0; // treat zero-norm vectors as orthogonal
        }
1221
1222 let cosine_sim = dot / (norm1 * norm2);
1223 1.0 - cosine_sim
1224 }
1225
1226 #[must_use]
1228 pub fn distance_dot(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1229 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1230 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1231
1232 -v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum::<f32>()
1234 }
1235
1236 #[must_use]
1241 pub fn distance_approximate(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1242 let v1 = self.unpack_quantized(&qv1.data, qv1.bits, qv1.dimensions);
1244 let v2 = self.unpack_quantized(&qv2.data, qv2.bits, qv2.dimensions);
1245
1246 v1.iter()
1248 .zip(v2.iter())
1249 .map(|(a, b)| {
1250 let diff = (*a as i16 - *b as i16) as f32;
1251 diff * diff
1252 })
1253 .sum::<f32>()
1254 .sqrt()
1255 }
1256
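    /// L2 distance between a full-precision query and a quantized vector, dequantizing the
    /// packed codes on the fly (a small stack buffer is used for typical dimensionalities).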
1257 #[must_use]
1265 pub fn distance_asymmetric_l2(&self, query: &[f32], quantized: &QuantizedVector) -> f32 {
1266 if let Some(trained) = &self.trained {
1268 return self.distance_asymmetric_l2_trained(query, &quantized.data, trained);
1269 }
1270 self.distance_asymmetric_l2_raw(query, &quantized.data, quantized.scale, quantized.bits)
1272 }
1273
1274 #[must_use]
1280 #[inline]
1281 pub fn distance_asymmetric_l2_flat(&self, query: &[f32], data: &[u8], scale: f32) -> f32 {
1282 if let Some(trained) = &self.trained {
1283 return self.distance_asymmetric_l2_trained(query, data, trained);
1284 }
1285 self.distance_asymmetric_l2_raw(query, data, scale, self.params.bits_per_dim.to_u8())
1287 }
1288
1289 #[must_use]
1294 fn distance_asymmetric_l2_trained(
1295 &self,
1296 query: &[f32],
1297 data: &[u8],
1298 trained: &TrainedParams,
1299 ) -> f32 {
1300 let levels = self.params.bits_per_dim.levels() as f32;
1301 let bits = self.params.bits_per_dim.to_u8();
1302
1303 let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());
1305
1306 match bits {
1307 4 => {
1308 let num_pairs = query.len() / 2;
1309 if data.len() < query.len().div_ceil(2) {
1310 return f32::MAX;
1311 }
1312
1313 for i in 0..num_pairs {
1314 let byte = unsafe { *data.get_unchecked(i) };
1315 let d0 = i * 2;
1316 let d1 = i * 2 + 1;
1317
1318 let code0 = (byte >> 4) as f32;
1320 let code1 = (byte & 0x0F) as f32;
1321
1322 let range0 = trained.maxs[d0] - trained.mins[d0];
1323 let range1 = trained.maxs[d1] - trained.mins[d1];
1324
1325 buffer.push((code0 / (levels - 1.0)) * range0 + trained.mins[d0]);
1326 buffer.push((code1 / (levels - 1.0)) * range1 + trained.mins[d1]);
1327 }
1328
1329 if !query.len().is_multiple_of(2) {
1330 let byte = unsafe { *data.get_unchecked(num_pairs) };
1331 let d = num_pairs * 2;
1332 let code = (byte >> 4) as f32;
1333 let range = trained.maxs[d] - trained.mins[d];
1334 buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
1335 }
1336 }
1337 2 => {
1338 let num_quads = query.len() / 4;
1339 if data.len() < query.len().div_ceil(4) {
1340 return f32::MAX;
1341 }
1342
1343 for i in 0..num_quads {
1344 let byte = unsafe { *data.get_unchecked(i) };
1345 for j in 0..4 {
1346 let d = i * 4 + j;
1347 let code = ((byte >> (j * 2)) & 0b11) as f32;
1348 let range = trained.maxs[d] - trained.mins[d];
1349 buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
1350 }
1351 }
1352
1353 let remaining = query.len() % 4;
1354 if remaining > 0 {
1355 let byte = unsafe { *data.get_unchecked(num_quads) };
1356 for j in 0..remaining {
1357 let d = num_quads * 4 + j;
1358 let code = ((byte >> (j * 2)) & 0b11) as f32;
1359 let range = trained.maxs[d] - trained.mins[d];
1360 buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
1361 }
1362 }
1363 }
1364 8 => {
1365 if data.len() < query.len() {
1366 return f32::MAX;
1367 }
1368 for (d, &byte) in data.iter().enumerate().take(query.len()) {
1369 let code = byte as f32;
1370 let range = trained.maxs[d] - trained.mins[d];
1371 buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
1372 }
1373 }
1374 _ => {
1375 let unpacked = self.unpack_quantized(data, bits, query.len());
1377 for (d, &code) in unpacked.iter().enumerate().take(query.len()) {
1378 let range = trained.maxs[d] - trained.mins[d];
1379 buffer.push((code as f32 / (levels - 1.0)) * range + trained.mins[d]);
1380 }
1381 }
1382 }
1383
1384 simd_l2_distance(query, &buffer)
1385 }
1386
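    /// Builds an ADC table for `query` from the trained parameters, or returns `None` if
    /// the quantizer has not been trained.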
1387 #[must_use]
1403 pub fn build_adc_table(&self, query: &[f32]) -> Option<ADCTable> {
1404 self.trained
1405 .as_ref()
1406 .map(|trained| ADCTable::new_trained(query, trained, &self.params))
1407 }
1408
1409 #[must_use]
1415 pub fn build_adc_table_with_scale(&self, query: &[f32], scale: f32) -> ADCTable {
1416 ADCTable::new(query, scale, &self.params)
1417 }
1418
1419 #[must_use]
1423 pub fn distance_with_adc(&self, query: &[f32], quantized: &QuantizedVector) -> Option<f32> {
1424 let adc = self.build_adc_table(query)?;
1425 Some(adc.distance(&quantized.data))
1426 }
1427
1428 #[must_use]
1433 pub fn distance_asymmetric_l2_raw(
1434 &self,
1435 query: &[f32],
1436 data: &[u8],
1437 scale: f32,
1438 bits: u8,
1439 ) -> f32 {
1440 let levels = self.params.bits_per_dim.levels() as f32;
1441
1442 let factor = 1.0 / ((levels - 1.0) * scale);
1446
1447 match bits {
1448 4 => {
1449 let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());
1451
1452 let num_pairs = query.len() / 2;
1453
1454 if data.len() < query.len().div_ceil(2) {
1456 return f32::MAX;
1458 }
1459
1460 for i in 0..num_pairs {
1461 let byte = unsafe { *data.get_unchecked(i) };
1462 buffer.push((byte >> 4) as f32 * factor);
1463 buffer.push((byte & 0x0F) as f32 * factor);
1464 }
1465
1466 if !query.len().is_multiple_of(2) {
1467 let byte = unsafe { *data.get_unchecked(num_pairs) };
1468 buffer.push((byte >> 4) as f32 * factor);
1469 }
1470
1471 simd_l2_distance(query, &buffer)
1472 }
1473 2 => {
1474 let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());
1475 let num_quads = query.len() / 4;
1476
1477 if data.len() < query.len().div_ceil(4) {
1478 return f32::MAX;
1479 }
1480
1481 for i in 0..num_quads {
1482 let byte = unsafe { *data.get_unchecked(i) };
1483 buffer.push((byte & 0b11) as f32 * factor);
1484 buffer.push(((byte >> 2) & 0b11) as f32 * factor);
1485 buffer.push(((byte >> 4) & 0b11) as f32 * factor);
1486 buffer.push(((byte >> 6) & 0b11) as f32 * factor);
1487 }
1488
1489 let remaining = query.len() % 4;
1491 if remaining > 0 {
1492 let byte = unsafe { *data.get_unchecked(num_quads) };
1493 for i in 0..remaining {
1494 buffer.push(((byte >> (i * 2)) & 0b11) as f32 * factor);
1495 }
1496 }
1497
1498 simd_l2_distance(query, &buffer)
1499 }
1500 _ => {
1501 let unpacked = self.unpack_quantized(data, bits, query.len());
1505 let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());
1506
1507 for &q in &unpacked {
1508 buffer.push(q as f32 * factor);
1509 }
1510
1511 simd_l2_distance(query, &buffer)
1512 }
1513 }
1514 }
1515
1516 #[inline]
1524 #[must_use]
1525 pub fn distance_l2_simd(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1526 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1528 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1529
1530 simd_l2_distance(&v1, &v2)
1532 }
1533
1534 #[inline]
1536 #[must_use]
1537 pub fn distance_cosine_simd(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1538 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1539 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1540
1541 simd_cosine_distance(&v1, &v2)
1542 }
1543}
1544
1545#[inline]
1549fn simd_l2_distance(v1: &[f32], v2: &[f32]) -> f32 {
1550 #[cfg(target_arch = "x86_64")]
1551 {
1552 if is_x86_feature_detected!("avx2") {
1553 return unsafe { l2_distance_avx2(v1, v2) };
1554 } else if is_x86_feature_detected!("sse2") {
1555 return unsafe { l2_distance_sse2(v1, v2) };
1556 }
1557 l2_distance_scalar(v1, v2)
1559 }
1560
1561 #[cfg(target_arch = "aarch64")]
1562 {
1563 unsafe { l2_distance_neon(v1, v2) }
1565 }
1566
1567 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1568 {
1569 l2_distance_scalar(v1, v2)
1571 }
1572}
1573
1574#[inline]
1576fn simd_cosine_distance(v1: &[f32], v2: &[f32]) -> f32 {
1577 #[cfg(target_arch = "x86_64")]
1578 {
1579 if is_x86_feature_detected!("avx2") {
1580 return unsafe { cosine_distance_avx2(v1, v2) };
1581 } else if is_x86_feature_detected!("sse2") {
1582 return unsafe { cosine_distance_sse2(v1, v2) };
1583 }
1584 cosine_distance_scalar(v1, v2)
1586 }
1587
1588 #[cfg(target_arch = "aarch64")]
1589 {
1590 unsafe { cosine_distance_neon(v1, v2) }
1592 }
1593
1594 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1595 {
1596 cosine_distance_scalar(v1, v2)
1598 }
1599}
1600
1601#[inline]
#[allow(dead_code)] // not referenced on every target architecture
fn l2_distance_scalar(v1: &[f32], v2: &[f32]) -> f32 {
1606 v1.iter()
1607 .zip(v2.iter())
1608 .map(|(a, b)| (a - b).powi(2))
1609 .sum::<f32>()
1610 .sqrt()
1611}
1612
1613#[inline]
#[allow(dead_code)] // not referenced on every target architecture
fn cosine_distance_scalar(v1: &[f32], v2: &[f32]) -> f32 {
1616 let dot: f32 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
1617 let norm1: f32 = v1.iter().map(|a| a * a).sum::<f32>().sqrt();
1618 let norm2: f32 = v2.iter().map(|b| b * b).sum::<f32>().sqrt();
1619
1620 if norm1 < 1e-10 || norm2 < 1e-10 {
1621 return 1.0;
1622 }
1623
1624 let cosine_sim = dot / (norm1 * norm2);
1625 1.0 - cosine_sim
1626}
1627
1628#[cfg(target_arch = "x86_64")]
1631#[target_feature(enable = "avx2")]
1632#[target_feature(enable = "fma")]
1633unsafe fn l2_distance_avx2(v1: &[f32], v2: &[f32]) -> f32 {
1634 unsafe {
1635 let len = v1.len().min(v2.len());
1636 let mut sum = _mm256_setzero_ps();
1637
1638 let chunks = len / 8;
1639 for i in 0..chunks {
1640 let a = _mm256_loadu_ps(v1.as_ptr().add(i * 8));
1641 let b = _mm256_loadu_ps(v2.as_ptr().add(i * 8));
1642 let diff = _mm256_sub_ps(a, b);
1643 sum = _mm256_fmadd_ps(diff, diff, sum);
1644 }
1645
1646 let sum_high = _mm256_extractf128_ps(sum, 1);
1648 let sum_low = _mm256_castps256_ps128(sum);
1649 let sum128 = _mm_add_ps(sum_low, sum_high);
1650 let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
1651 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
1652 let mut result = _mm_cvtss_f32(sum32);
1653
1654 for i in (chunks * 8)..len {
1656 let diff = v1[i] - v2[i];
1657 result += diff * diff;
1658 }
1659
1660 result.sqrt()
1661 }
1662}
1663
1664#[cfg(target_arch = "x86_64")]
1665#[target_feature(enable = "avx2")]
1666#[target_feature(enable = "fma")]
1667unsafe fn cosine_distance_avx2(v1: &[f32], v2: &[f32]) -> f32 {
1668 unsafe {
1669 let len = v1.len().min(v2.len());
1670 let mut dot_sum = _mm256_setzero_ps();
1671 let mut norm1_sum = _mm256_setzero_ps();
1672 let mut norm2_sum = _mm256_setzero_ps();
1673
1674 let chunks = len / 8;
1675 for i in 0..chunks {
1676 let a = _mm256_loadu_ps(v1.as_ptr().add(i * 8));
1677 let b = _mm256_loadu_ps(v2.as_ptr().add(i * 8));
1678 dot_sum = _mm256_fmadd_ps(a, b, dot_sum);
1679 norm1_sum = _mm256_fmadd_ps(a, a, norm1_sum);
1680 norm2_sum = _mm256_fmadd_ps(b, b, norm2_sum);
1681 }
1682
1683 let mut dot = horizontal_sum_avx2(dot_sum);
1685 let mut norm1 = horizontal_sum_avx2(norm1_sum);
1686 let mut norm2 = horizontal_sum_avx2(norm2_sum);
1687
1688 for i in (chunks * 8)..len {
1690 dot += v1[i] * v2[i];
1691 norm1 += v1[i] * v1[i];
1692 norm2 += v2[i] * v2[i];
1693 }
1694
1695 if norm1 < 1e-10 || norm2 < 1e-10 {
1696 return 1.0;
1697 }
1698
1699 let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
1700 1.0 - cosine_sim
1701 }
1702}
1703
1704#[cfg(target_arch = "x86_64")]
1705#[inline]
1706unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
1707 unsafe {
1708 let sum_high = _mm256_extractf128_ps(v, 1);
1709 let sum_low = _mm256_castps256_ps128(v);
1710 let sum128 = _mm_add_ps(sum_low, sum_high);
1711 let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
1712 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
1713 _mm_cvtss_f32(sum32)
1714 }
1715}
1716
1717#[cfg(target_arch = "x86_64")]
1720#[target_feature(enable = "sse2")]
1721unsafe fn l2_distance_sse2(v1: &[f32], v2: &[f32]) -> f32 {
1722 unsafe {
1723 let len = v1.len().min(v2.len());
1724 let mut sum = _mm_setzero_ps();
1725
1726 let chunks = len / 4;
1727 for i in 0..chunks {
1728 let a = _mm_loadu_ps(v1.as_ptr().add(i * 4));
1729 let b = _mm_loadu_ps(v2.as_ptr().add(i * 4));
1730 let diff = _mm_sub_ps(a, b);
1731 sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
1732 }
1733
1734 let sum64 = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
1736 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
1737 let mut result = _mm_cvtss_f32(sum32);
1738
1739 for i in (chunks * 4)..len {
1741 let diff = v1[i] - v2[i];
1742 result += diff * diff;
1743 }
1744
1745 result.sqrt()
1746 }
1747}
1748
1749#[cfg(target_arch = "x86_64")]
1750#[target_feature(enable = "sse2")]
1751unsafe fn cosine_distance_sse2(v1: &[f32], v2: &[f32]) -> f32 {
1752 unsafe {
1753 let len = v1.len().min(v2.len());
1754 let mut dot_sum = _mm_setzero_ps();
1755 let mut norm1_sum = _mm_setzero_ps();
1756 let mut norm2_sum = _mm_setzero_ps();
1757
1758 let chunks = len / 4;
1759 for i in 0..chunks {
1760 let a = _mm_loadu_ps(v1.as_ptr().add(i * 4));
1761 let b = _mm_loadu_ps(v2.as_ptr().add(i * 4));
1762 dot_sum = _mm_add_ps(dot_sum, _mm_mul_ps(a, b));
1763 norm1_sum = _mm_add_ps(norm1_sum, _mm_mul_ps(a, a));
1764 norm2_sum = _mm_add_ps(norm2_sum, _mm_mul_ps(b, b));
1765 }
1766
1767 let mut dot = horizontal_sum_sse2(dot_sum);
1769 let mut norm1 = horizontal_sum_sse2(norm1_sum);
1770 let mut norm2 = horizontal_sum_sse2(norm2_sum);
1771
1772 for i in (chunks * 4)..len {
1774 dot += v1[i] * v2[i];
1775 norm1 += v1[i] * v1[i];
1776 norm2 += v2[i] * v2[i];
1777 }
1778
1779 if norm1 < 1e-10 || norm2 < 1e-10 {
1780 return 1.0;
1781 }
1782
1783 let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
1784 1.0 - cosine_sim
1785 }
1786}
1787
1788#[cfg(target_arch = "x86_64")]
1789#[inline]
1790unsafe fn horizontal_sum_sse2(v: __m128) -> f32 {
1791 unsafe {
1792 let sum64 = _mm_add_ps(v, _mm_movehl_ps(v, v));
1793 let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
1794 _mm_cvtss_f32(sum32)
1795 }
1796}
1797
1798#[cfg(target_arch = "aarch64")]
1801unsafe fn l2_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
1802 let len = v1.len().min(v2.len());
1803
1804 unsafe {
1806 let mut sum = vdupq_n_f32(0.0);
1807
1808 let chunks = len / 4;
1809 for i in 0..chunks {
1810 let a = vld1q_f32(v1.as_ptr().add(i * 4));
1811 let b = vld1q_f32(v2.as_ptr().add(i * 4));
1812 let diff = vsubq_f32(a, b);
1813 sum = vfmaq_f32(sum, diff, diff);
1814 }
1815
1816 let mut result = vaddvq_f32(sum);
1818
1819 for i in (chunks * 4)..len {
1821 let diff = v1[i] - v2[i];
1822 result += diff * diff;
1823 }
1824
1825 result.sqrt()
1826 }
1827}
1828
1829#[cfg(target_arch = "aarch64")]
1830unsafe fn cosine_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
1831 let len = v1.len().min(v2.len());
1832
1833 unsafe {
1835 let mut dot_sum = vdupq_n_f32(0.0);
1836 let mut norm1_sum = vdupq_n_f32(0.0);
1837 let mut norm2_sum = vdupq_n_f32(0.0);
1838
1839 let chunks = len / 4;
1840 for i in 0..chunks {
1841 let a = vld1q_f32(v1.as_ptr().add(i * 4));
1842 let b = vld1q_f32(v2.as_ptr().add(i * 4));
1843 dot_sum = vfmaq_f32(dot_sum, a, b);
1844 norm1_sum = vfmaq_f32(norm1_sum, a, a);
1845 norm2_sum = vfmaq_f32(norm2_sum, b, b);
1846 }
1847
1848 let mut dot = vaddvq_f32(dot_sum);
1850 let mut norm1 = vaddvq_f32(norm1_sum);
1851 let mut norm2 = vaddvq_f32(norm2_sum);
1852
1853 for i in (chunks * 4)..len {
1855 dot += v1[i] * v2[i];
1856 norm1 += v1[i] * v1[i];
1857 norm2 += v2[i] * v2[i];
1858 }
1859
1860 if norm1 < 1e-10 || norm2 < 1e-10 {
1861 return 1.0;
1862 }
1863
1864 let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
1865 1.0 - cosine_sim
1866 }
1867}
1868
1869#[cfg(test)]
1870#[allow(clippy::float_cmp)]
1871mod tests {
1872 use super::*;
1873
1874 #[test]
1875 fn test_quantization_bits_conversion() {
1876 assert_eq!(QuantizationBits::Bits2.to_u8(), 2);
1877 assert_eq!(QuantizationBits::Bits4.to_u8(), 4);
1878 assert_eq!(QuantizationBits::Bits8.to_u8(), 8);
1879 }
1880
1881 #[test]
1882 fn test_quantization_bits_levels() {
        assert_eq!(QuantizationBits::Bits2.levels(), 4);
        assert_eq!(QuantizationBits::Bits4.levels(), 16);
        assert_eq!(QuantizationBits::Bits8.levels(), 256);
    }
1887
1888 #[test]
1889 fn test_quantization_bits_compression() {
        assert_eq!(QuantizationBits::Bits2.compression_ratio(), 16.0);
        assert_eq!(QuantizationBits::Bits4.compression_ratio(), 8.0);
        assert_eq!(QuantizationBits::Bits8.compression_ratio(), 4.0);
    }
1894
1895 #[test]
1896 fn test_quantization_bits_values_per_byte() {
        assert_eq!(QuantizationBits::Bits2.values_per_byte(), 4);
        assert_eq!(QuantizationBits::Bits4.values_per_byte(), 2);
        assert_eq!(QuantizationBits::Bits8.values_per_byte(), 1);
    }
1901
1902 #[test]
1903 fn test_default_params() {
1904 let params = RaBitQParams::default();
1905 assert_eq!(params.bits_per_dim, QuantizationBits::Bits4);
1906 assert_eq!(params.num_rescale_factors, 12);
1907 assert_eq!(params.rescale_range, (0.5, 2.0));
1908 }
1909
1910 #[test]
1911 fn test_preset_params() {
1912 let params2 = RaBitQParams::bits2();
1913 assert_eq!(params2.bits_per_dim, QuantizationBits::Bits2);
1914
1915 let params4 = RaBitQParams::bits4();
1916 assert_eq!(params4.bits_per_dim, QuantizationBits::Bits4);
1917
1918 let params8 = RaBitQParams::bits8();
1919 assert_eq!(params8.bits_per_dim, QuantizationBits::Bits8);
1920 assert_eq!(params8.num_rescale_factors, 16);
1921 }
1922
1923 #[test]
1924 fn test_quantized_vector_creation() {
1925 let data = vec![0u8, 128, 255];
1926 let qv = QuantizedVector::new(data.clone(), 1.5, 8, 3);
1927
1928 assert_eq!(qv.data, data);
1929 assert_eq!(qv.scale, 1.5);
1930 assert_eq!(qv.bits, 8);
1931 assert_eq!(qv.dimensions, 3);
1932 }
1933
1934 #[test]
1935 fn test_quantized_vector_memory() {
        let data = vec![0u8; 16]; // 32 dims at 4 bits = 16 bytes
        let qv = QuantizedVector::new(data, 1.0, 4, 32);
1938
        let expected_min = 16; // at least the packed data itself
        assert!(qv.memory_bytes() >= expected_min);
1942 }
1943
1944 #[test]
1945 fn test_quantized_vector_compression_ratio() {
1946 let data = vec![0u8; 64];
1948 let qv = QuantizedVector::new(data, 1.0, 4, 128);
1949
1950 let ratio = qv.compression_ratio();
1954 assert!(ratio > 7.0 && ratio < 8.0);
1955 }
1956
1957 #[test]
1958 fn test_create_quantizer() {
1959 let quantizer = RaBitQ::default_4bit();
1960 assert_eq!(quantizer.params().bits_per_dim, QuantizationBits::Bits4);
1961 }
1962
1963 #[test]
1966 fn test_generate_scales() {
1967 let quantizer = RaBitQ::new(RaBitQParams {
1968 bits_per_dim: QuantizationBits::Bits4,
1969 num_rescale_factors: 5,
1970 rescale_range: (0.5, 1.5),
1971 });
1972
1973 let scales = quantizer.generate_scales();
1974 assert_eq!(scales.len(), 5);
1975 assert_eq!(scales[0], 0.5);
1976 assert_eq!(scales[4], 1.5);
        assert!((scales[2] - 1.0).abs() < 0.01); // midpoint of the range
    }
1979
1980 #[test]
1981 fn test_generate_scales_single() {
1982 let quantizer = RaBitQ::new(RaBitQParams {
1983 bits_per_dim: QuantizationBits::Bits4,
1984 num_rescale_factors: 1,
1985 rescale_range: (0.5, 1.5),
1986 });
1987
1988 let scales = quantizer.generate_scales();
1989 assert_eq!(scales.len(), 1);
        assert_eq!(scales[0], 1.0); // single factor collapses to the range midpoint
    }
1992
1993 #[test]
1994 fn test_pack_unpack_2bit() {
1995 let quantizer = RaBitQ::new(RaBitQParams {
1996 bits_per_dim: QuantizationBits::Bits2,
1997 ..Default::default()
1998 });
1999
2000 let values = vec![0u8, 1, 2, 3, 0, 1, 2, 3];
2002 let packed = quantizer.pack_quantized(&values, 2);
        assert_eq!(packed.len(), 2); // 8 values at 2 bits each

        let unpacked = quantizer.unpack_quantized(&packed, 2, 8);
2006 assert_eq!(unpacked, values);
2007 }
2008
2009 #[test]
2010 fn test_pack_unpack_4bit() {
2011 let quantizer = RaBitQ::new(RaBitQParams {
2012 bits_per_dim: QuantizationBits::Bits4,
2013 ..Default::default()
2014 });
2015
2016 let values = vec![0u8, 1, 2, 3, 4, 5, 6, 7];
2018 let packed = quantizer.pack_quantized(&values, 4);
        assert_eq!(packed.len(), 4); // 8 values at 4 bits each

        let unpacked = quantizer.unpack_quantized(&packed, 4, 8);
2022 assert_eq!(unpacked, values);
2023 }
2024
2025 #[test]
2026 fn test_pack_unpack_8bit() {
2027 let quantizer = RaBitQ::new(RaBitQParams {
2028 bits_per_dim: QuantizationBits::Bits8,
2029 ..Default::default()
2030 });
2031
2032 let values = vec![0u8, 10, 20, 30, 40, 50, 60, 70];
2034 let packed = quantizer.pack_quantized(&values, 8);
        assert_eq!(packed.len(), 8); // 8 values at 8 bits each

        let unpacked = quantizer.unpack_quantized(&packed, 8, 8);
2038 assert_eq!(unpacked, values);
2039 }
2040
2041 #[test]
2042 fn test_quantize_simple_vector() {
2043 let quantizer = RaBitQ::new(RaBitQParams {
2044 bits_per_dim: QuantizationBits::Bits4,
2045 num_rescale_factors: 4,
2046 rescale_range: (0.5, 1.5),
2047 });
2048
2049 let vector = vec![0.0, 0.25, 0.5, 0.75, 1.0];
2051 let quantized = quantizer.quantize(&vector);
2052
2053 assert_eq!(quantized.dimensions, 5);
2055 assert_eq!(quantized.bits, 4);
2056 assert!(quantized.scale > 0.0);
2057
2058 assert!(quantized.data.len() <= 4);
2061 }
2062
2063 #[test]
2064 fn test_quantize_reconstruct_accuracy() {
2065 let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
2068 rescale_range: (0.8, 1.2),
2069 });
2070
2071 let vector = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
2073 let quantized = quantizer.quantize(&vector);
2074
2075 let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, vector.len());
2077
2078 for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
2080 let error = (orig - recon).abs();
2081 assert!(error < 0.1, "Error too large: {orig} vs {recon}");
2082 }
2083 }
2084
2085 #[test]
2086 fn test_quantize_uniform_vector() {
2087 let quantizer = RaBitQ::default_4bit();
2088
2089 let vector = vec![0.5; 10];
2091 let quantized = quantizer.quantize(&vector);
2092
2093 let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, vector.len());
2095
2096 let avg = reconstructed.iter().sum::<f32>() / reconstructed.len() as f32;
2098 for &val in &reconstructed {
2099 assert!((val - avg).abs() < 0.2);
2100 }
2101 }
2102
2103 #[test]
2104 fn test_compute_error() {
2105 let quantizer = RaBitQ::default_4bit();
2106
2107 let original = vec![0.1, 0.2, 0.3, 0.4];
2108 let quantized_vec = quantizer.quantize(&original);
2109
2110 let error = quantizer.compute_error(&original, &quantized_vec.data, quantized_vec.scale);
2112
2113 assert!(error >= 0.0);
2115 assert!(error.is_finite());
2116 }
2117
2118 #[test]
2119 fn test_quantize_different_bit_widths() {
2120 let test_vector = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2121
2122 let q2 = RaBitQ::new(RaBitQParams::bits2());
2124 let qv2 = q2.quantize(&test_vector);
2125 assert_eq!(qv2.bits, 2);
2126
2127 let q4 = RaBitQ::default_4bit();
2129 let qv4 = q4.quantize(&test_vector);
2130 assert_eq!(qv4.bits, 4);
2131
2132 let q8 = RaBitQ::new(RaBitQParams::bits8());
2134 let qv8 = q8.quantize(&test_vector);
2135 assert_eq!(qv8.bits, 8);
2136
2137 assert!(qv2.data.len() <= qv4.data.len());
2139 assert!(qv4.data.len() <= qv8.data.len());
2140 }
2141
2142 #[test]
2143 fn test_quantize_high_dimensional() {
2144 let quantizer = RaBitQ::default_4bit();
2145
2146 let vector: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
2148 let quantized = quantizer.quantize(&vector);
2149
2150 assert_eq!(quantized.dimensions, 128);
2151 assert_eq!(quantized.bits, 4);
2152
2153 assert_eq!(quantized.data.len(), 64);
2155
2156 let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, 128);
2158 assert_eq!(reconstructed.len(), 128);
2159 }
2160
2161 #[test]
2164 fn test_distance_l2() {
2165 let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
2168 rescale_range: (0.8, 1.2),
2169 });
2170
2171 let v1 = vec![0.0, 0.0, 0.0];
2172 let v2 = vec![1.0, 0.0, 0.0];
2173
2174 let qv1 = quantizer.quantize(&v1);
2175 let qv2 = quantizer.quantize(&v2);
2176
2177 let dist = quantizer.distance_l2(&qv1, &qv2);
2178
2179 assert!((dist - 1.0).abs() < 0.2, "Distance: {dist}");
2181 }
2182
2183 #[test]
2184 fn test_distance_l2_identical() {
2185 let quantizer = RaBitQ::default_4bit();
2186
2187 let v = vec![0.5, 0.3, 0.8, 0.2];
2188 let qv1 = quantizer.quantize(&v);
2189 let qv2 = quantizer.quantize(&v);
2190
2191 let dist = quantizer.distance_l2(&qv1, &qv2);
2192
2193 assert!(dist < 0.3, "Distance should be near zero, got: {dist}");
2195 }
2196
2197 #[test]
2198 fn test_distance_cosine() {
2199 let quantizer = RaBitQ::new(RaBitQParams {
2200 bits_per_dim: QuantizationBits::Bits8,
2201 num_rescale_factors: 8,
2202 rescale_range: (0.8, 1.2),
2203 });
2204
2205 let v1 = vec![1.0, 0.0, 0.0];
2207 let v2 = vec![0.0, 1.0, 0.0];
2208
2209 let qv1 = quantizer.quantize(&v1);
2210 let qv2 = quantizer.quantize(&v2);
2211
2212 let dist = quantizer.distance_cosine(&qv1, &qv2);
2213
2214 assert!((dist - 1.0).abs() < 0.3, "Distance: {dist}");
2216 }
2217
2218 #[test]
2219 fn test_distance_cosine_identical() {
2220 let quantizer = RaBitQ::default_4bit();
2221
2222 let v = vec![0.5, 0.3, 0.8];
2223 let qv1 = quantizer.quantize(&v);
2224 let qv2 = quantizer.quantize(&v);
2225
2226 let dist = quantizer.distance_cosine(&qv1, &qv2);
2227
2228 assert!(dist < 0.2, "Distance should be near zero, got: {dist}");
2230 }
2231
2232 #[test]
2233 fn test_distance_dot() {
2234 let quantizer = RaBitQ::new(RaBitQParams {
2235 bits_per_dim: QuantizationBits::Bits8,
2236 num_rescale_factors: 8,
2237 rescale_range: (0.8, 1.2),
2238 });
2239
2240 let v1 = vec![1.0, 0.0, 0.0];
2241 let v2 = vec![1.0, 0.0, 0.0];
2242
2243 let qv1 = quantizer.quantize(&v1);
2244 let qv2 = quantizer.quantize(&v2);
2245
2246 let dist = quantizer.distance_dot(&qv1, &qv2);
2247
2248 assert!((dist + 1.0).abs() < 0.3, "Distance: {dist}");
2250 }
2251
2252 #[test]
2253 fn test_distance_approximate() {
2254 let quantizer = RaBitQ::default_4bit();
2255
2256 let v1 = vec![0.0, 0.0, 0.0];
2257 let v2 = vec![0.5, 0.5, 0.5];
2258
2259 let qv1 = quantizer.quantize(&v1);
2260 let qv2 = quantizer.quantize(&v2);
2261
2262 let dist_approx = quantizer.distance_approximate(&qv1, &qv2);
2263 let dist_exact = quantizer.distance_l2(&qv1, &qv2);
2264
2265 assert!(dist_approx >= 0.0);
2267 assert!(dist_approx.is_finite());
2268
2269 let v3 = vec![1.0, 1.0, 1.0];
2272 let qv3 = quantizer.quantize(&v3);
2273
2274 let dist_approx2 = quantizer.distance_approximate(&qv1, &qv3);
2275 let dist_exact2 = quantizer.distance_l2(&qv1, &qv3);
2276
2277 if dist_exact2 > dist_exact {
            assert!(dist_approx2 > dist_approx * 0.5);
        }
2281 }
2282
2283 #[test]
2284 fn test_distance_correlation() {
2285 let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 12,
2288 rescale_range: (0.8, 1.2),
2289 });
2290
2291 let vectors = [
2293 vec![0.1, 0.2, 0.3],
2294 vec![0.4, 0.5, 0.6],
2295 vec![0.7, 0.8, 0.9],
2296 ];
2297
2298 let quantized: Vec<QuantizedVector> =
2300 vectors.iter().map(|v| quantizer.quantize(v)).collect();
2301
2302 let ground_truth_01 = vectors[0]
2304 .iter()
2305 .zip(vectors[1].iter())
2306 .map(|(a, b)| (a - b).powi(2))
2307 .sum::<f32>()
2308 .sqrt();
2309
2310 let ground_truth_02 = vectors[0]
2311 .iter()
2312 .zip(vectors[2].iter())
2313 .map(|(a, b)| (a - b).powi(2))
2314 .sum::<f32>()
2315 .sqrt();
2316
2317 let quantized_01 = quantizer.distance_l2(&quantized[0], &quantized[1]);
2319 let quantized_02 = quantizer.distance_l2(&quantized[0], &quantized[2]);
2320
2321 if ground_truth_02 > ground_truth_01 {
2323 assert!(
2324 quantized_02 > quantized_01 * 0.8,
2325 "Order not preserved: {quantized_01} vs {quantized_02}"
2326 );
2327 }
2328 }
2329
2330 #[test]
2331 fn test_distance_zero_vectors() {
2332 let quantizer = RaBitQ::default_4bit();
2333
2334 let v_zero = vec![0.0, 0.0, 0.0];
2335 let qv_zero = quantizer.quantize(&v_zero);
2336
2337 let dist = quantizer.distance_l2(&qv_zero, &qv_zero);
2339 assert!(dist < 0.1);
2340
2341 let dist_cosine = quantizer.distance_cosine(&qv_zero, &qv_zero);
2343 assert!(dist_cosine.is_finite());
2344 }
2345
2346 #[test]
2347 fn test_distance_high_dimensional() {
2348 let quantizer = RaBitQ::default_4bit();
2349
2350 let v1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
2352 let v2: Vec<f32> = (0..128).map(|i| ((i + 10) as f32) / 128.0).collect();
2353
2354 let qv1 = quantizer.quantize(&v1);
2355 let qv2 = quantizer.quantize(&v2);
2356
2357 let dist_l2 = quantizer.distance_l2(&qv1, &qv2);
2359 let dist_cosine = quantizer.distance_cosine(&qv1, &qv2);
2360 let dist_dot = quantizer.distance_dot(&qv1, &qv2);
2361 let dist_approx = quantizer.distance_approximate(&qv1, &qv2);
2362
2363 assert!(dist_l2 > 0.0 && dist_l2.is_finite());
2364 assert!(dist_cosine >= 0.0 && dist_cosine.is_finite());
2365 assert!(dist_dot.is_finite());
2366 assert!(dist_approx > 0.0 && dist_approx.is_finite());
2367 }
2368
2369 #[test]
2370 fn test_distance_asymmetric_l2() {
2371 let quantizer = RaBitQ::default_4bit();
2372
2373 let query = vec![0.1, 0.2, 0.3, 0.4];
2374 let vector = vec![0.12, 0.22, 0.32, 0.42];
2376
2377 let quantized = quantizer.quantize(&vector);
2378
2379 let dist_sym = quantizer.distance_l2_simd(&quantized, &quantizer.quantize(&query));
2381
2382 let dist_asym = quantizer.distance_asymmetric_l2(&query, &quantized);
2384
2385 assert!(dist_asym >= 0.0);
2388 assert!((dist_asym - dist_sym).abs() < 0.2);
2389 }
2390
2391 #[test]
2394 fn test_simd_l2_matches_scalar() {
2395 let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
2398 rescale_range: (0.8, 1.2),
2399 });
2400
2401 let v1 = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
2402 let v2 = vec![0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
2403
2404 let qv1 = quantizer.quantize(&v1);
2405 let qv2 = quantizer.quantize(&v2);
2406
2407 let dist_scalar = quantizer.distance_l2(&qv1, &qv2);
2408 let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);
2409
2410 let diff = (dist_scalar - dist_simd).abs();
2412 assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
2413 }
2414
2415 #[test]
2416 fn test_simd_cosine_matches_scalar() {
2417 let quantizer = RaBitQ::new(RaBitQParams {
2418 bits_per_dim: QuantizationBits::Bits8,
2419 num_rescale_factors: 8,
2420 rescale_range: (0.8, 1.2),
2421 });
2422
2423 let v1 = vec![1.0, 0.0, 0.0];
2424 let v2 = vec![0.0, 1.0, 0.0];
2425
2426 let qv1 = quantizer.quantize(&v1);
2427 let qv2 = quantizer.quantize(&v2);
2428
2429 let dist_scalar = quantizer.distance_cosine(&qv1, &qv2);
2430 let dist_simd = quantizer.distance_cosine_simd(&qv1, &qv2);
2431
2432 let diff = (dist_scalar - dist_simd).abs();
2434 assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
2435 }
2436
2437 #[test]
2438 fn test_simd_high_dimensional() {
2439 let quantizer = RaBitQ::default_4bit();
2440
2441 let v1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
2443 let v2: Vec<f32> = (0..128).map(|i| ((i + 1) as f32) / 128.0).collect();
2444
2445 let qv1 = quantizer.quantize(&v1);
2446 let qv2 = quantizer.quantize(&v2);
2447
2448 let dist_scalar = quantizer.distance_l2(&qv1, &qv2);
2449 let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);
2450
2451 let diff = (dist_scalar - dist_simd).abs();
2453 assert!(
2454 diff < 0.1,
2455 "High-D SIMD vs scalar: {dist_simd} vs {dist_scalar}"
2456 );
2457 }
2458
2459 #[test]
2460 fn test_simd_scalar_fallback() {
2461 let quantizer = RaBitQ::default_4bit();
2462
2463 let v1 = vec![0.1, 0.2, 0.3];
2465 let v2 = vec![0.4, 0.5, 0.6];
2466
2467 let qv1 = quantizer.quantize(&v1);
2468 let qv2 = quantizer.quantize(&v2);
2469
2470 let dist_l2 = quantizer.distance_l2_simd(&qv1, &qv2);
2472 let dist_cosine = quantizer.distance_cosine_simd(&qv1, &qv2);
2473
2474 assert!(dist_l2.is_finite());
2475 assert!(dist_cosine.is_finite());
2476 }
2477
2478 #[test]
2479 fn test_simd_performance_improvement() {
2480 let quantizer = RaBitQ::default_4bit();
2481
2482 let v1: Vec<f32> = (0..1536).map(|i| (i as f32) / 1536.0).collect();
2484 let v2: Vec<f32> = (0..1536).map(|i| ((i + 10) as f32) / 1536.0).collect();
2485
2486 let qv1 = quantizer.quantize(&v1);
2487 let qv2 = quantizer.quantize(&v2);
2488
2489 let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);
2491 assert!(dist_simd > 0.0 && dist_simd.is_finite());
2492
2493 }
2495
2496 #[test]
2497 fn test_scalar_distance_functions() {
2498 let v1 = vec![0.0, 0.0, 0.0];
2500 let v2 = vec![1.0, 0.0, 0.0];
2501
2502 let dist = l2_distance_scalar(&v1, &v2);
2503 assert!((dist - 1.0).abs() < 0.001);
2504
2505 let v1 = vec![1.0, 0.0, 0.0];
2506 let v2 = vec![0.0, 1.0, 0.0];
2507
2508 let dist = cosine_distance_scalar(&v1, &v2);
2509 assert!((dist - 1.0).abs() < 0.001);
2510 }
2511
2512 #[test]
2515 fn test_adc_table_creation() {
2516 let quantizer = RaBitQ::default_4bit();
2517 let query = vec![0.1, 0.2, 0.3, 0.4];
2518 let scale = 1.0;
2519
2520 let adc = quantizer.build_adc_table_with_scale(&query, scale);
2521
2522 assert_eq!(adc.dimensions, 4);
2524 assert_eq!(adc.bits, 4);
2525 assert_eq!(adc.table.len(), 4);
2526
2527 for dim_table in &adc.table {
2529 assert_eq!(dim_table.len(), 16);
2530 }
2531 }
2532
2533 #[test]
2534 fn test_adc_table_2bit() {
2535 let quantizer = RaBitQ::new(RaBitQParams::bits2());
2536 let query = vec![0.1, 0.2, 0.3, 0.4];
2537 let scale = 1.0;
2538
2539 let adc = quantizer.build_adc_table_with_scale(&query, scale);
2540
2541 for dim_table in &adc.table {
2543 assert_eq!(dim_table.len(), 4);
2544 }
2545 }
2546
2547 #[test]
2548 fn test_adc_distance_matches_asymmetric() {
2549 let quantizer = RaBitQ::default_4bit();
2550
2551 let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
2553 let vector = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];
2554
2555 let quantized = quantizer.quantize(&vector);
2557
2558 let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);
2560
2561 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2563 let dist_adc = adc.distance(&quantized.data);
2564
2565 let diff = (dist_asymmetric - dist_adc).abs();
2568 assert!(
2569 diff < 0.1,
2570 "ADC vs asymmetric: {dist_adc} vs {dist_asymmetric}, diff: {diff}"
2571 );
2572 }
2573
2574 #[test]
2575 fn test_adc_distance_accuracy() {
2576 let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 16,
2579 rescale_range: (0.8, 1.2),
2580 });
2581
2582 let query = vec![0.1, 0.2, 0.3, 0.4];
        let vector = vec![0.1, 0.2, 0.3, 0.4]; // identical to the query

        let quantized = quantizer.quantize(&vector);
2586
2587 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2589
2590 let dist = adc.distance(&quantized.data);
2592 assert!(dist < 0.2, "Distance should be near zero, got: {dist}");
2593 }
2594
2595 #[test]
2596 fn test_adc_distance_ordering() {
2597 let quantizer = RaBitQ::default_4bit();
2598
2599 let query = vec![0.5, 0.5, 0.5, 0.5];
        let v1 = vec![0.5, 0.5, 0.5, 0.5]; // exact match
        let v2 = vec![0.6, 0.6, 0.6, 0.6]; // close
        let v3 = vec![0.9, 0.9, 0.9, 0.9]; // far

        let qv1 = quantizer.quantize(&v1);
2605 let qv2 = quantizer.quantize(&v2);
2606 let qv3 = quantizer.quantize(&v3);
2607
2608 let adc1 = quantizer.build_adc_table_with_scale(&query, qv1.scale);
2610 let adc2 = quantizer.build_adc_table_with_scale(&query, qv2.scale);
2611 let adc3 = quantizer.build_adc_table_with_scale(&query, qv3.scale);
2612
2613 let dist1 = adc1.distance(&qv1.data);
2614 let dist2 = adc2.distance(&qv2.data);
2615 let dist3 = adc3.distance(&qv3.data);
2616
2617 assert!(
2619 dist1 < dist2,
2620 "v1 should be closer than v2: {dist1} vs {dist2}"
2621 );
2622 assert!(
2623 dist2 < dist3,
2624 "v2 should be closer than v3: {dist2} vs {dist3}"
2625 );
2626 }
2627
2628 #[test]
2629 fn test_adc_high_dimensional() {
2630 let quantizer = RaBitQ::default_4bit();
2631
2632 let query: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
2634 let vector: Vec<f32> = (0..128).map(|i| ((i + 5) as f32) / 128.0).collect();
2635
2636 let quantized = quantizer.quantize(&vector);
2637
2638 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2640
2641 let dist = adc.distance(&quantized.data);
2643 assert!(dist > 0.0 && dist.is_finite());
2644 }
2645
2646 #[test]
2647 fn test_adc_batch_search() {
2648 let quantizer = RaBitQ::default_4bit();
2649
2650 let query = vec![0.5, 0.5, 0.5, 0.5];
2651 let candidates = [
2652 vec![0.5, 0.5, 0.5, 0.5],
2653 vec![0.6, 0.6, 0.6, 0.6],
2654 vec![0.4, 0.4, 0.4, 0.4],
2655 vec![0.7, 0.7, 0.7, 0.7],
2656 ];
2657
2658 let quantized: Vec<QuantizedVector> =
2660 candidates.iter().map(|v| quantizer.quantize(v)).collect();
2661
2662 let mut results: Vec<(usize, f32)> = quantized
2664 .iter()
2665 .enumerate()
2666 .map(|(i, qv)| {
2667 let adc = quantizer.build_adc_table_with_scale(&query, qv.scale);
2668 (i, adc.distance(&qv.data))
2669 })
2670 .collect();
2671
2672 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
2674
2675 assert_eq!(results[0].0, 0, "Results: {results:?}");
2677 }
2678
2679 #[test]
2680 fn test_adc_distance_squared() {
2681 let quantizer = RaBitQ::default_4bit();
2682
2683 let query = vec![0.0, 0.0, 0.0];
2684 let vector = vec![1.0, 0.0, 0.0];
2685
2686 let quantized = quantizer.quantize(&vector);
2687 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2688
2689 let dist_squared = adc.distance_squared(&quantized.data);
2690 let dist = adc.distance(&quantized.data);
2691
2692 let diff = (dist_squared - dist * dist).abs();
2694 assert!(
2695 diff < 0.01,
2696 "distance_squared != dist^2: {} vs {}",
2697 dist_squared,
2698 dist * dist
2699 );
2700 }
2701
2702 #[test]
2703 fn test_adc_simd_matches_scalar() {
2704 let quantizer = RaBitQ::default_4bit();
2705
2706 let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
2707 let vector = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];
2708
2709 let quantized = quantizer.quantize(&vector);
2710 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2711
2712 let dist_scalar = adc.distance_squared(&quantized.data);
2713 let dist_simd = adc.distance_squared_simd(&quantized.data);
2714
2715 let diff = (dist_scalar - dist_simd).abs();
2717 assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
2718 }
2719
2720 #[test]
2721 fn test_adc_simd_high_dimensional() {
2722 let quantizer = RaBitQ::default_4bit();
2723
2724 let query: Vec<f32> = (0..1536).map(|i| (i as f32) / 1536.0).collect();
2726 let vector: Vec<f32> = (0..1536).map(|i| ((i + 10) as f32) / 1536.0).collect();
2727
2728 let quantized = quantizer.quantize(&vector);
2729 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2730
2731 let dist_simd = adc.distance_squared_simd(&quantized.data);
2733 assert!(dist_simd > 0.0 && dist_simd.is_finite());
2734 }
2735
2736 #[test]
2737 fn test_adc_memory_usage() {
2738 let quantizer = RaBitQ::default_4bit();
2739
2740 let query: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
2741 let adc = quantizer.build_adc_table_with_scale(&query, 1.0);
2742
2743 let memory = adc.memory_bytes();
2744
2745 let expected_min = 128 * 16 * 4;
2747 assert!(
2748 memory >= expected_min,
2749 "Memory {memory} should be at least {expected_min}"
2750 );
2751 }
2752
2753 #[test]
2754 fn test_adc_different_scales() {
2755 let quantizer = RaBitQ::default_4bit();
2756
2757 let query = vec![0.5, 0.5, 0.5, 0.5];
2758 let vector = vec![0.6, 0.6, 0.6, 0.6];
2759
2760 let quantized = quantizer.quantize(&vector);
2761
2762 let adc1 = quantizer.build_adc_table_with_scale(&query, 0.5);
2764 let adc2 = quantizer.build_adc_table_with_scale(&query, 1.0);
2765 let adc3 = quantizer.build_adc_table_with_scale(&query, 2.0);
2766
2767 let dist1 = adc1.distance(&quantized.data);
2769 let dist2 = adc2.distance(&quantized.data);
2770 let dist3 = adc3.distance(&quantized.data);
2771
2772 assert!(dist1.is_finite());
2774 assert!(dist2.is_finite());
2775 assert!(dist3.is_finite());
2776 }
2777
2778 #[test]
2779 fn test_adc_edge_cases() {
2780 let quantizer = RaBitQ::default_4bit();
2781
2782 let query = vec![0.5];
2784 let vector = vec![0.6];
2785 let quantized = quantizer.quantize(&vector);
2786 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2787 let dist = adc.distance(&quantized.data);
2788 assert!(dist.is_finite());
2789
2790 let query = vec![0.0, 0.0, 0.0, 0.0];
2792 let vector = vec![0.0, 0.0, 0.0, 0.0];
2793 let quantized = quantizer.quantize(&vector);
2794 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2795 let dist = adc.distance(&quantized.data);
2796 assert!(dist.is_finite());
2797 }
2798
2799 #[test]
2800 fn test_adc_2bit_accuracy() {
2801 let quantizer = RaBitQ::new(RaBitQParams::bits2());
2802
2803 let query = vec![0.1, 0.2, 0.3, 0.4];
2804 let vector = vec![0.12, 0.22, 0.32, 0.42];
2805
2806 let quantized = quantizer.quantize(&vector);
2807
2808 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2810 let dist_adc = adc.distance(&quantized.data);
2811 let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);
2812
2813 let diff = (dist_adc - dist_asymmetric).abs();
2815 assert!(diff < 0.2, "2-bit ADC diff too large: {diff}");
2816 }
2817
2818 #[test]
2819 fn test_adc_8bit_accuracy() {
2820 let quantizer = RaBitQ::new(RaBitQParams::bits8());
2821
2822 let query = vec![0.1, 0.2, 0.3, 0.4];
2823 let vector = vec![0.12, 0.22, 0.32, 0.42];
2824
2825 let quantized = quantizer.quantize(&vector);
2826
2827 let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
2829 let dist_adc = adc.distance(&quantized.data);
2830 let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);
2831
2832 let diff = (dist_adc - dist_asymmetric).abs();
2834 assert!(
2835 diff < 0.05,
2836 "8-bit ADC should be highly accurate, diff: {diff}"
2837 );
2838 }
2839}