1use serde::{Deserialize, Serialize};
25use smallvec::SmallVec;
26use std::fmt;
27
28#[cfg(target_arch = "aarch64")]
29use std::arch::aarch64::{vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vsubq_f32};
30#[cfg(target_arch = "x86_64")]
31#[allow(clippy::wildcard_imports)]
32use std::arch::x86_64::*;
33
34const MAX_CODES: usize = 16;
36
/// Supported per-dimension quantization bit widths.
///
/// NOTE(review): 6-bit is absent from the set — presumably intentional,
/// confirm before assuming every width in 1..=8 is representable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationBits {
    /// 1 bit per dimension (2 levels).
    Bits1,
    /// 2 bits per dimension (4 levels).
    Bits2,
    /// 3 bits per dimension (8 levels).
    Bits3,
    /// 4 bits per dimension (16 levels).
    Bits4,
    /// 5 bits per dimension (32 levels).
    Bits5,
    /// 7 bits per dimension (128 levels).
    Bits7,
    /// 8 bits per dimension (256 levels).
    Bits8,
}
55
56impl QuantizationBits {
57 #[must_use]
59 pub fn to_u8(self) -> u8 {
60 match self {
61 QuantizationBits::Bits1 => 1,
62 QuantizationBits::Bits2 => 2,
63 QuantizationBits::Bits3 => 3,
64 QuantizationBits::Bits4 => 4,
65 QuantizationBits::Bits5 => 5,
66 QuantizationBits::Bits7 => 7,
67 QuantizationBits::Bits8 => 8,
68 }
69 }
70
71 #[must_use]
73 pub fn levels(self) -> usize {
74 1 << self.to_u8()
75 }
76
77 #[must_use]
79 pub fn compression_ratio(self) -> f32 {
80 32.0 / self.to_u8() as f32
81 }
82
83 #[must_use]
85 pub fn values_per_byte(self) -> usize {
86 8 / self.to_u8() as usize
87 }
88}
89
90impl fmt::Display for QuantizationBits {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 write!(f, "{}-bit", self.to_u8())
93 }
94}
95
/// Configuration for the RaBitQ quantizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQParams {
    /// Bit width applied to every dimension.
    pub bits_per_dim: QuantizationBits,

    /// How many candidate rescale factors the untrained `quantize` path tries.
    pub num_rescale_factors: usize,

    /// Inclusive `(min, max)` interval the rescale factors are drawn from.
    pub rescale_range: (f32, f32),
}
113
/// Per-dimension value ranges learned from sample vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainedParams {
    /// Per-dimension lower bounds (percentile-clipped during training).
    pub mins: Vec<f32>,
    /// Per-dimension upper bounds (percentile-clipped during training).
    pub maxs: Vec<f32>,
    /// Number of dimensions the bounds cover.
    pub dimensions: usize,
}
127
128impl TrainedParams {
129 pub fn train(vectors: &[&[f32]]) -> Result<Self, &'static str> {
140 Self::train_with_percentiles(vectors, 0.01, 0.99)
141 }
142
143 pub fn train_with_percentiles(
153 vectors: &[&[f32]],
154 lower_percentile: f32,
155 upper_percentile: f32,
156 ) -> Result<Self, &'static str> {
157 if vectors.is_empty() {
158 return Err("Need at least one vector to train");
159 }
160 let dimensions = vectors[0].len();
161 if !vectors.iter().all(|v| v.len() == dimensions) {
162 return Err("All vectors must have same dimensions");
163 }
164
165 let n = vectors.len();
166 let lower_idx = ((n as f32 * lower_percentile) as usize).min(n - 1);
167 let upper_idx = ((n as f32 * upper_percentile) as usize).min(n - 1);
168
169 let mut mins = Vec::with_capacity(dimensions);
170 let mut maxs = Vec::with_capacity(dimensions);
171
172 let mut dim_values: Vec<f32> = Vec::with_capacity(n);
174 for d in 0..dimensions {
175 dim_values.clear();
176 for v in vectors {
177 dim_values.push(v[d]);
178 }
179 dim_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
180
181 let min_val = dim_values[lower_idx];
182 let max_val = dim_values[upper_idx];
183
184 let range = max_val - min_val;
186 if range < 1e-7 {
187 mins.push(min_val - 0.5);
188 maxs.push(max_val + 0.5);
189 } else {
190 mins.push(min_val);
191 maxs.push(max_val);
192 }
193 }
194
195 Ok(Self {
196 mins,
197 maxs,
198 dimensions,
199 })
200 }
201
202 #[inline]
204 #[must_use]
205 pub fn quantize_value(&self, value: f32, dim: usize, levels: usize) -> u8 {
206 let min = self.mins[dim];
207 let max = self.maxs[dim];
208 let range = max - min;
209
210 let normalized = (value - min) / range;
212 let level = (normalized * (levels - 1) as f32).round();
213 level.clamp(0.0, (levels - 1) as f32) as u8
214 }
215
216 #[inline]
218 #[must_use]
219 pub fn dequantize_value(&self, code: u8, dim: usize, levels: usize) -> f32 {
220 let min = self.mins[dim];
221 let max = self.maxs[dim];
222 let range = max - min;
223
224 (code as f32 / (levels - 1) as f32) * range + min
226 }
227}
228
229impl Default for RaBitQParams {
230 fn default() -> Self {
231 Self {
232 bits_per_dim: QuantizationBits::Bits4, num_rescale_factors: 12, rescale_range: (0.5, 2.0), }
236 }
237}
238
239impl RaBitQParams {
240 #[must_use]
242 pub fn bits2() -> Self {
243 Self {
244 bits_per_dim: QuantizationBits::Bits2,
245 ..Default::default()
246 }
247 }
248
249 #[must_use]
251 pub fn bits4() -> Self {
252 Self {
253 bits_per_dim: QuantizationBits::Bits4,
254 ..Default::default()
255 }
256 }
257
258 #[must_use]
260 pub fn bits8() -> Self {
261 Self {
262 bits_per_dim: QuantizationBits::Bits8,
263 num_rescale_factors: 16, rescale_range: (0.7, 1.5), }
266 }
267}
268
/// A vector compressed to packed per-dimension codes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    /// Packed codes. Layout depends on `bits`: four codes per byte at 2-bit
    /// (LSB-first), two per byte at 4-bit (high nibble first), otherwise one
    /// code per byte.
    pub data: Vec<u8>,

    /// Rescale factor chosen by the untrained path (1.0 on the trained path).
    pub scale: f32,

    /// Bit width each dimension was quantized at.
    pub bits: u8,

    /// Number of dimensions in the original vector.
    pub dimensions: usize,
}
298
299impl QuantizedVector {
300 #[must_use]
302 pub fn new(data: Vec<u8>, scale: f32, bits: u8, dimensions: usize) -> Self {
303 Self {
304 data,
305 scale,
306 bits,
307 dimensions,
308 }
309 }
310
311 #[must_use]
313 pub fn memory_bytes(&self) -> usize {
314 std::mem::size_of::<Self>() + self.data.len()
315 }
316
317 #[must_use]
319 pub fn compression_ratio(&self) -> f32 {
320 let original_bytes = self.dimensions * 4; let compressed_bytes = self.data.len() + 4 + 1; original_bytes as f32 / compressed_bytes as f32
323 }
324}
325
/// RaBitQ-style scalar quantizer: packs vectors into per-dimension codes and
/// provides symmetric and asymmetric distance computations over them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RaBitQ {
    /// Quantization configuration.
    params: RaBitQParams,
    /// Per-dimension ranges learned by `train`; `None` selects the raw
    /// scale-search quantization path.
    trained: Option<TrainedParams>,
}
353
/// Asymmetric-distance-computation lookup table: for one query vector, the
/// precomputed squared difference against every possible code of every
/// dimension, so per-candidate distance becomes pure table lookups.
#[derive(Debug, Clone)]
pub struct ADCTable {
    /// `table[dim][code]` = squared distance contribution of `code` at `dim`.
    /// NOTE(review): the inline capacity is `MAX_CODES` (16), so widths above
    /// 4 bits spill each row to the heap — presumably accepted, confirm.
    table: Vec<SmallVec<[f32; MAX_CODES]>>,

    /// Bit width the table was built for (selects the lookup kernel).
    bits: u8,

    /// Number of query dimensions (= `table.len()`).
    dimensions: usize,
}
394
395impl ADCTable {
396 #[must_use]
407 pub fn new_trained(query: &[f32], trained: &TrainedParams, params: &RaBitQParams) -> Self {
408 let bits = params.bits_per_dim.to_u8();
409 let num_codes = params.bits_per_dim.levels();
410 let dimensions = query.len();
411
412 let mut table = Vec::with_capacity(dimensions);
413
414 for (d, &q_value) in query.iter().enumerate() {
416 let mut dim_table = SmallVec::new();
417
418 for code in 0..num_codes {
419 let reconstructed = trained.dequantize_value(code as u8, d, num_codes);
421
422 let diff = q_value - reconstructed;
424 dim_table.push(diff * diff);
425 }
426
427 table.push(dim_table);
428 }
429
430 Self {
431 table,
432 bits,
433 dimensions,
434 }
435 }
436
437 #[must_use]
459 pub fn new(query: &[f32], scale: f32, params: &RaBitQParams) -> Self {
460 let bits = params.bits_per_dim.to_u8();
461 let num_codes = params.bits_per_dim.levels();
462 let dimensions = query.len();
463
464 let mut table = Vec::with_capacity(dimensions);
465
466 let levels = num_codes as f32;
468 let dequant_factor = 1.0 / ((levels - 1.0) * scale);
469
470 for &q_value in query {
472 let mut dim_table = SmallVec::new();
473
474 for code in 0..num_codes {
475 let reconstructed = (code as f32) * dequant_factor;
477
478 let diff = q_value - reconstructed;
480 dim_table.push(diff * diff);
481 }
482
483 table.push(dim_table);
484 }
485
486 Self {
487 table,
488 bits,
489 dimensions,
490 }
491 }
492
493 #[inline]
512 #[must_use]
513 pub fn distance_squared(&self, data: &[u8]) -> f32 {
514 match self.bits {
515 4 => self.distance_squared_4bit(data),
516 2 => self.distance_squared_2bit(data),
517 8 => self.distance_squared_8bit(data),
518 _ => self.distance_squared_generic(data),
519 }
520 }
521
    /// Euclidean distance: square root of the table-based squared distance.
    ///
    /// NOTE(review): this routes through `distance_squared_simd`, while
    /// `distance_squared` uses the scalar dispatch — the two entry points take
    /// different code paths for 4-bit data; confirm this is intentional.
    #[inline]
    #[must_use]
    pub fn distance(&self, data: &[u8]) -> f32 {
        self.distance_squared_simd(data).sqrt()
    }
530
    /// Scalar 4-bit kernel: each byte packs two codes (high nibble = even
    /// dimension, low nibble = odd dimension).
    #[inline]
    fn distance_squared_4bit(&self, data: &[u8]) -> f32 {
        let mut sum = 0.0f32;
        let num_pairs = self.dimensions / 2;

        for i in 0..num_pairs {
            // Guard keeps the unchecked reads below in bounds for short data.
            if i >= data.len() {
                break;
            }

            // SAFETY: i < data.len() checked above; rows i*2 and i*2+1 are
            // < dimensions, and 4-bit codes index < 16 table entries.
            let byte = unsafe { *data.get_unchecked(i) };
            let code_hi = (byte >> 4) as usize; let code_lo = (byte & 0x0F) as usize; sum += unsafe {
                *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
                    + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo)
            };
        }

        // Odd dimension count: the final code sits alone in the high nibble.
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            // SAFETY: num_pairs < data.len() checked; last table row exists.
            let byte = unsafe { *data.get_unchecked(num_pairs) };
            let code_hi = (byte >> 4) as usize; sum += unsafe {
                *self
                    .table
                    .get_unchecked(self.dimensions - 1)
                    .get_unchecked(code_hi)
            };
        }

        sum
    }
582
    /// Scalar 2-bit kernel: each byte packs four codes, LSB-first.
    #[inline]
    fn distance_squared_2bit(&self, data: &[u8]) -> f32 {
        let mut sum = 0.0f32;
        let num_quads = self.dimensions / 4;

        for i in 0..num_quads {
            // Guard keeps the unchecked reads below in bounds for short data.
            if i >= data.len() {
                break;
            }

            // SAFETY: i < data.len() checked above.
            let byte = unsafe { *data.get_unchecked(i) };

            // SAFETY: rows i*4..=i*4+3 are < dimensions, and 2-bit codes
            // index < 4 table entries.
            sum += unsafe {
                *self
                    .table
                    .get_unchecked(i * 4)
                    .get_unchecked((byte & 0b11) as usize)
                    + *self
                    .table
                    .get_unchecked(i * 4 + 1)
                    .get_unchecked(((byte >> 2) & 0b11) as usize)
                    + *self
                    .table
                    .get_unchecked(i * 4 + 2)
                    .get_unchecked(((byte >> 4) & 0b11) as usize)
                    + *self
                    .table
                    .get_unchecked(i * 4 + 3)
                    .get_unchecked(((byte >> 6) & 0b11) as usize)
            };
        }

        // Tail: 1–3 leftover dimensions share the final byte.
        let remaining = self.dimensions % 4;
        if remaining > 0 && num_quads < data.len() {
            // SAFETY: num_quads < data.len() checked above.
            let byte = unsafe { *data.get_unchecked(num_quads) };
            for j in 0..remaining {
                let code = ((byte >> (j * 2)) & 0b11) as usize; sum += unsafe {
                    *self
                        .table
                        .get_unchecked(num_quads * 4 + j)
                        .get_unchecked(code)
                };
            }
        }

        sum
    }
648
649 #[inline]
656 fn distance_squared_8bit(&self, data: &[u8]) -> f32 {
657 let mut sum = 0.0f32;
658
659 for (i, &byte) in data.iter().enumerate().take(self.dimensions) {
660 sum += unsafe { *self.table.get_unchecked(i).get_unchecked(byte as usize) };
665 }
666
667 sum
668 }
669
670 #[inline]
672 fn distance_squared_generic(&self, data: &[u8]) -> f32 {
673 let mut sum = 0.0f32;
675
676 for (i, dim_table) in self.table.iter().enumerate() {
677 if i >= data.len() {
678 break;
679 }
680 let code = data[i] as usize;
681 if let Some(&dist) = dim_table.get(code) {
682 sum += dist;
683 }
684 }
685
686 sum
687 }
688
689 #[inline]
694 #[must_use]
695 pub fn distance_squared_simd(&self, data: &[u8]) -> f32 {
696 match self.bits {
697 4 => {
698 #[cfg(target_arch = "x86_64")]
699 {
700 if is_x86_feature_detected!("avx2") {
701 unsafe { self.distance_squared_4bit_avx2(data) }
702 } else {
703 self.distance_squared_4bit(data)
705 }
706 }
707 #[cfg(target_arch = "aarch64")]
708 {
709 unsafe { self.distance_squared_4bit_neon(data) }
711 }
712 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
713 {
714 self.distance_squared_4bit(data)
716 }
717 }
718 2 => {
719 self.distance_squared_2bit(data)
721 }
722 8 => {
723 self.distance_squared_8bit(data)
725 }
726 _ => self.distance_squared_generic(data),
727 }
728 }
729
    /// AVX2 kernel for the 4-bit ADC lookup: table lookups stay scalar, but
    /// the per-pair sums are accumulated 8 lanes at a time.
    ///
    /// # Safety
    /// Caller must verify AVX2 and FMA support at runtime before calling.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[target_feature(enable = "fma")]
    unsafe fn distance_squared_4bit_avx2(&self, data: &[u8]) -> f32 {
        let mut sum = _mm256_setzero_ps();
        let num_pairs = self.dimensions / 2;

        // Main loop: 8 packed bytes (16 dimensions) per iteration.
        let chunks = num_pairs / 8;
        for chunk_idx in 0..chunks {
            let byte_idx = chunk_idx * 8;
            if byte_idx + 8 > data.len() {
                break;
            }

            // Gather the 8 per-pair sums into a stack array for one wide add.
            let mut values = [0.0f32; 8];
            for (i, value) in values.iter_mut().enumerate() {
                let byte = *data.get_unchecked(byte_idx + i);
                let code_hi = (byte >> 4) as usize;
                let code_lo = (byte & 0x0F) as usize;

                let dist_hi = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2)
                    .get_unchecked(code_hi);
                let dist_lo = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2 + 1)
                    .get_unchecked(code_lo);
                *value = dist_hi + dist_lo;
            }

            let vec = _mm256_loadu_ps(values.as_ptr());
            sum = _mm256_add_ps(sum, vec);
        }

        // Horizontal reduce, then handle the scalar tail of pairs.
        let mut result = horizontal_sum_avx2(sum);

        for i in (chunks * 8)..num_pairs {
            if i >= data.len() {
                break;
            }
            let byte = *data.get_unchecked(i);
            let code_hi = (byte >> 4) as usize;
            let code_lo = (byte & 0x0F) as usize;

            result += *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
                + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo);
        }

        // Odd dimension count: final code lives in the trailing high nibble.
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            let byte = *data.get_unchecked(num_pairs);
            let code_hi = (byte >> 4) as usize;
            result += *self
                .table
                .get_unchecked(self.dimensions - 1)
                .get_unchecked(code_hi);
        }

        result
    }
796
    /// NEON kernel for the 4-bit ADC lookup. No `#[target_feature]` gate:
    /// NEON is baseline on aarch64.
    ///
    /// # Safety
    /// Relies on the same table/data invariants as `distance_squared_4bit`.
    #[cfg(target_arch = "aarch64")]
    unsafe fn distance_squared_4bit_neon(&self, data: &[u8]) -> f32 {
        let mut sum = vdupq_n_f32(0.0);
        let num_pairs = self.dimensions / 2;

        // Main loop: 4 packed bytes (8 dimensions) per iteration.
        let chunks = num_pairs / 4;
        for chunk_idx in 0..chunks {
            let byte_idx = chunk_idx * 4;
            if byte_idx + 4 > data.len() {
                break;
            }

            // Gather the 4 per-pair sums into a stack array for one wide add.
            let mut values = [0.0f32; 4];
            for (i, value) in values.iter_mut().enumerate() {
                let byte = *data.get_unchecked(byte_idx + i);
                let code_hi = (byte >> 4) as usize;
                let code_lo = (byte & 0x0F) as usize;

                let dist_hi = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2)
                    .get_unchecked(code_hi);
                let dist_lo = *self
                    .table
                    .get_unchecked((byte_idx + i) * 2 + 1)
                    .get_unchecked(code_lo);
                *value = dist_hi + dist_lo;
            }

            let vec = vld1q_f32(values.as_ptr());
            sum = vaddq_f32(sum, vec);
        }

        // Horizontal reduce, then handle the scalar tail of pairs.
        let mut result = vaddvq_f32(sum);

        for i in (chunks * 4)..num_pairs {
            if i >= data.len() {
                break;
            }
            let byte = *data.get_unchecked(i);
            let code_hi = (byte >> 4) as usize;
            let code_lo = (byte & 0x0F) as usize;

            result += *self.table.get_unchecked(i * 2).get_unchecked(code_hi)
                + *self.table.get_unchecked(i * 2 + 1).get_unchecked(code_lo);
        }

        // Odd dimension count: final code lives in the trailing high nibble.
        if self.dimensions % 2 == 1 && num_pairs < data.len() {
            let byte = *data.get_unchecked(num_pairs);
            let code_hi = (byte >> 4) as usize;
            result += *self
                .table
                .get_unchecked(self.dimensions - 1)
                .get_unchecked(code_hi);
        }

        result
    }
859
860 #[must_use]
862 pub fn bits(&self) -> u8 {
863 self.bits
864 }
865
866 #[must_use]
868 pub fn dimensions(&self) -> usize {
869 self.dimensions
870 }
871
872 #[must_use]
876 pub fn get(&self, dim: usize, code: usize) -> f32 {
877 self.table
878 .get(dim)
879 .and_then(|t| t.get(code))
880 .copied()
881 .unwrap_or(0.0)
882 }
883
884 #[must_use]
886 pub fn memory_bytes(&self) -> usize {
887 std::mem::size_of::<Self>()
888 + self.table.len() * std::mem::size_of::<SmallVec<[f32; MAX_CODES]>>()
889 + self.table.iter().map(|t| t.len() * 4).sum::<usize>()
890 }
891}
892
893impl RaBitQ {
    /// Creates an untrained quantizer; until `train` is called, quantization
    /// grid-searches rescale factors instead of using learned ranges.
    #[must_use]
    pub fn new(params: RaBitQParams) -> Self {
        Self {
            params,
            trained: None,
        }
    }

    /// Creates a quantizer from pre-computed training parameters.
    #[must_use]
    pub fn new_trained(params: RaBitQParams, trained: TrainedParams) -> Self {
        Self {
            params,
            trained: Some(trained),
        }
    }

    /// Convenience constructor: 4-bit quantization with default settings.
    #[must_use]
    pub fn default_4bit() -> Self {
        Self::new(RaBitQParams::bits4())
    }

    /// Quantization parameters in use.
    #[must_use]
    pub fn params(&self) -> &RaBitQParams {
        &self.params
    }

    /// Whether learned per-dimension ranges are available.
    #[must_use]
    pub fn is_trained(&self) -> bool {
        self.trained.is_some()
    }

    /// The learned per-dimension ranges, if any.
    #[must_use]
    pub fn trained_params(&self) -> Option<&TrainedParams> {
        self.trained.as_ref()
    }

    /// Learns per-dimension ranges from sample vectors, replacing any
    /// previous training.
    ///
    /// # Errors
    /// Propagates `TrainedParams::train` errors (empty input, ragged dims).
    pub fn train(&mut self, vectors: &[&[f32]]) -> Result<(), &'static str> {
        self.trained = Some(TrainedParams::train(vectors)?);
        Ok(())
    }

    /// `train` for owned vectors; borrows each as a slice first.
    ///
    /// # Errors
    /// Same as [`RaBitQ::train`].
    pub fn train_owned(&mut self, vectors: &[Vec<f32>]) -> Result<(), &'static str> {
        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
        self.train(&refs)
    }
961
962 #[must_use]
967 pub fn quantize(&self, vector: &[f32]) -> QuantizedVector {
968 if let Some(ref trained) = self.trained {
970 return self.quantize_trained(vector, trained);
971 }
972
973 let mut best_error = f32::MAX;
975 let mut best_quantized = Vec::new();
976 let mut best_scale = 1.0;
977
978 let scales = self.generate_scales();
980
981 for scale in scales {
983 let quantized = self.quantize_with_scale(vector, scale);
984 let error = self.compute_error(vector, &quantized, scale);
985
986 if error < best_error {
987 best_error = error;
988 best_quantized = quantized;
989 best_scale = scale;
990 }
991 }
992
993 QuantizedVector::new(
994 best_quantized,
995 best_scale,
996 self.params.bits_per_dim.to_u8(),
997 vector.len(),
998 )
999 }
1000
    /// Quantizes using learned per-dimension ranges; scale is fixed at 1.0
    /// because the trained decoder uses the stored mins/maxs instead.
    fn quantize_trained(&self, vector: &[f32], trained: &TrainedParams) -> QuantizedVector {
        let bits = self.params.bits_per_dim.to_u8();
        let levels = self.params.bits_per_dim.levels();

        // One code per dimension, then bit-packed below.
        let quantized: Vec<u8> = vector
            .iter()
            .enumerate()
            .map(|(d, &v)| trained.quantize_value(v, d, levels))
            .collect();

        let packed = self.pack_quantized(&quantized, bits);

        QuantizedVector::new(packed, 1.0, bits, vector.len())
    }
1021
1022 fn generate_scales(&self) -> Vec<f32> {
1027 let (min_scale, max_scale) = self.params.rescale_range;
1028 let n = self.params.num_rescale_factors;
1029
1030 if n == 1 {
1031 return vec![f32::midpoint(min_scale, max_scale)];
1032 }
1033
1034 let step = (max_scale - min_scale) / (n - 1) as f32;
1035 (0..n).map(|i| min_scale + i as f32 * step).collect()
1036 }
1037
    /// Quantizes by scaling values into [0, 1] code space directly.
    ///
    /// NOTE(review): negative inputs clamp to code 0 — this path appears to
    /// assume non-negative vector components; confirm against callers.
    fn quantize_with_scale(&self, vector: &[f32], scale: f32) -> Vec<u8> {
        let bits = self.params.bits_per_dim.to_u8();
        let levels = self.params.bits_per_dim.levels() as f32;
        let max_level = (levels - 1.0) as u8;

        // One clamped code per dimension, then bit-packed below.
        let quantized: Vec<u8> = vector
            .iter()
            .map(|&v| {
                let scaled = v * scale;
                let level = (scaled * (levels - 1.0)).round();
                level.clamp(0.0, max_level as f32) as u8
            })
            .collect();

        self.pack_quantized(&quantized, bits)
    }
1066
1067 #[allow(clippy::unused_self)]
1074 fn pack_quantized(&self, values: &[u8], bits: u8) -> Vec<u8> {
1075 match bits {
1076 2 => {
1077 let mut packed = Vec::with_capacity(values.len().div_ceil(4));
1079 for chunk in values.chunks(4) {
1080 let mut byte = 0u8;
1081 for (i, &val) in chunk.iter().enumerate() {
1082 byte |= (val & 0b11) << (i * 2);
1083 }
1084 packed.push(byte);
1085 }
1086 packed
1087 }
1088 4 => {
1089 let mut packed = Vec::with_capacity(values.len().div_ceil(2));
1091 for chunk in values.chunks(2) {
1092 let byte = if chunk.len() == 2 {
1093 (chunk[0] << 4) | (chunk[1] & 0x0F)
1094 } else {
1095 chunk[0] << 4
1096 };
1097 packed.push(byte);
1098 }
1099 packed
1100 }
1101 8 => {
1102 values.to_vec()
1104 }
1105 _ => {
1106 values.to_vec()
1112 }
1113 }
1114 }
1115
1116 #[must_use]
1118 pub fn unpack_quantized(&self, packed: &[u8], bits: u8, dimensions: usize) -> Vec<u8> {
1119 match bits {
1120 2 => {
1121 let mut values = Vec::with_capacity(dimensions);
1123 for &byte in packed {
1124 for i in 0..4 {
1125 if values.len() < dimensions {
1126 values.push((byte >> (i * 2)) & 0b11);
1127 }
1128 }
1129 }
1130 values
1131 }
1132 4 => {
1133 let mut values = Vec::with_capacity(dimensions);
1135 for &byte in packed {
1136 values.push(byte >> 4);
1137 if values.len() < dimensions {
1138 values.push(byte & 0x0F);
1139 }
1140 }
1141 values.truncate(dimensions);
1142 values
1143 }
1144 8 => {
1145 packed[..dimensions.min(packed.len())].to_vec()
1147 }
1148 _ => {
1149 packed[..dimensions.min(packed.len())].to_vec()
1151 }
1152 }
1153 }
1154
1155 fn compute_error(&self, original: &[f32], quantized: &[u8], scale: f32) -> f32 {
1159 let reconstructed = self.reconstruct(quantized, scale, original.len());
1160
1161 original
1162 .iter()
1163 .zip(reconstructed.iter())
1164 .map(|(o, r)| (o - r).powi(2))
1165 .sum()
1166 }
1167
1168 #[must_use]
1175 pub fn reconstruct(&self, quantized: &[u8], scale: f32, dimensions: usize) -> Vec<f32> {
1176 let bits = self.params.bits_per_dim.to_u8();
1177 let levels = self.params.bits_per_dim.levels() as f32;
1178
1179 let values = self.unpack_quantized(quantized, bits, dimensions);
1181
1182 values
1184 .iter()
1185 .map(|&q| {
1186 let denorm = q as f32 / (levels - 1.0);
1188 denorm / scale
1190 })
1191 .collect()
1192 }
1193
1194 #[must_use]
1199 pub fn distance_l2(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1200 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1201 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1202
1203 v1.iter()
1204 .zip(v2.iter())
1205 .map(|(a, b)| (a - b).powi(2))
1206 .sum::<f32>()
1207 .sqrt()
1208 }
1209
1210 #[must_use]
1214 pub fn distance_cosine(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1215 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1216 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1217
1218 let dot: f32 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
1219 let norm1: f32 = v1.iter().map(|a| a * a).sum::<f32>().sqrt();
1220 let norm2: f32 = v2.iter().map(|b| b * b).sum::<f32>().sqrt();
1221
1222 if norm1 < 1e-10 || norm2 < 1e-10 {
1223 return 1.0; }
1225
1226 let cosine_sim = dot / (norm1 * norm2);
1227 1.0 - cosine_sim
1228 }
1229
1230 #[must_use]
1232 pub fn distance_dot(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1233 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1234 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1235
1236 -v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum::<f32>()
1238 }
1239
1240 #[must_use]
1245 pub fn distance_approximate(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1246 let v1 = self.unpack_quantized(&qv1.data, qv1.bits, qv1.dimensions);
1248 let v2 = self.unpack_quantized(&qv2.data, qv2.bits, qv2.dimensions);
1249
1250 v1.iter()
1252 .zip(v2.iter())
1253 .map(|(a, b)| {
1254 let diff = (*a as i16 - *b as i16) as f32;
1255 diff * diff
1256 })
1257 .sum::<f32>()
1258 .sqrt()
1259 }
1260
1261 #[must_use]
1269 pub fn distance_asymmetric_l2(&self, query: &[f32], quantized: &QuantizedVector) -> f32 {
1270 if let Some(trained) = &self.trained {
1272 return self.distance_asymmetric_l2_trained(query, &quantized.data, trained);
1273 }
1274 self.distance_asymmetric_l2_raw(query, &quantized.data, quantized.scale, quantized.bits)
1276 }
1277
1278 #[must_use]
1284 #[inline]
1285 pub fn distance_asymmetric_l2_flat(&self, query: &[f32], data: &[u8], scale: f32) -> f32 {
1286 if let Some(trained) = &self.trained {
1287 return self.distance_asymmetric_l2_trained(query, data, trained);
1288 }
1289 self.distance_asymmetric_l2_raw(query, data, scale, self.params.bits_per_dim.to_u8())
1291 }
1292
    /// Asymmetric L2 using trained ranges: decodes each packed code into
    /// `code / (levels - 1) * range + min` and compares against the query.
    ///
    /// Returns `f32::MAX` when `data` is too short for `query.len()` codes.
    ///
    /// NOTE(review): indexes `trained.mins`/`maxs` by query dimension —
    /// assumes training covered at least `query.len()` dimensions; a shorter
    /// training would panic here. Confirm callers guarantee this.
    #[must_use]
    fn distance_asymmetric_l2_trained(
        &self,
        query: &[f32],
        data: &[u8],
        trained: &TrainedParams,
    ) -> f32 {
        let levels = self.params.bits_per_dim.levels() as f32;
        let bits = self.params.bits_per_dim.to_u8();

        // Decoded values are staged here, then compared in one SIMD pass.
        let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());

        match bits {
            4 => {
                let num_pairs = query.len() / 2;
                if data.len() < query.len().div_ceil(2) {
                    return f32::MAX;
                }

                for i in 0..num_pairs {
                    // SAFETY: i < num_pairs <= data.len(), checked above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    let d0 = i * 2;
                    let d1 = i * 2 + 1;

                    // High nibble = even dimension, low nibble = odd.
                    let code0 = (byte >> 4) as f32;
                    let code1 = (byte & 0x0F) as f32;

                    let range0 = trained.maxs[d0] - trained.mins[d0];
                    let range1 = trained.maxs[d1] - trained.mins[d1];

                    buffer.push((code0 / (levels - 1.0)) * range0 + trained.mins[d0]);
                    buffer.push((code1 / (levels - 1.0)) * range1 + trained.mins[d1]);
                }

                // Odd dimension count: final code sits in the high nibble of
                // the trailing byte.
                if !query.len().is_multiple_of(2) {
                    // SAFETY: odd length implies data.len() >= num_pairs + 1.
                    let byte = unsafe { *data.get_unchecked(num_pairs) };
                    let d = num_pairs * 2;
                    let code = (byte >> 4) as f32;
                    let range = trained.maxs[d] - trained.mins[d];
                    buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                }
            }
            2 => {
                let num_quads = query.len() / 4;
                if data.len() < query.len().div_ceil(4) {
                    return f32::MAX;
                }

                for i in 0..num_quads {
                    // SAFETY: i < num_quads <= data.len(), checked above.
                    let byte = unsafe { *data.get_unchecked(i) };
                    for j in 0..4 {
                        let d = i * 4 + j;
                        // 2-bit codes are packed LSB-first.
                        let code = ((byte >> (j * 2)) & 0b11) as f32;
                        let range = trained.maxs[d] - trained.mins[d];
                        buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                    }
                }

                let remaining = query.len() % 4;
                if remaining > 0 {
                    // SAFETY: a remainder implies data.len() >= num_quads + 1.
                    let byte = unsafe { *data.get_unchecked(num_quads) };
                    for j in 0..remaining {
                        let d = num_quads * 4 + j;
                        let code = ((byte >> (j * 2)) & 0b11) as f32;
                        let range = trained.maxs[d] - trained.mins[d];
                        buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                    }
                }
            }
            8 => {
                if data.len() < query.len() {
                    return f32::MAX;
                }
                for (d, &byte) in data.iter().enumerate().take(query.len()) {
                    let code = byte as f32;
                    let range = trained.maxs[d] - trained.mins[d];
                    buffer.push((code / (levels - 1.0)) * range + trained.mins[d]);
                }
            }
            _ => {
                // Uncommon widths: unpack first, then decode per dimension.
                let unpacked = self.unpack_quantized(data, bits, query.len());
                for (d, &code) in unpacked.iter().enumerate().take(query.len()) {
                    let range = trained.maxs[d] - trained.mins[d];
                    buffer.push((code as f32 / (levels - 1.0)) * range + trained.mins[d]);
                }
            }
        }

        simd_l2_distance(query, &buffer)
    }
1390
1391 #[must_use]
1407 pub fn build_adc_table(&self, query: &[f32]) -> Option<ADCTable> {
1408 self.trained
1409 .as_ref()
1410 .map(|trained| ADCTable::new_trained(query, trained, &self.params))
1411 }
1412
1413 #[must_use]
1419 pub fn build_adc_table_with_scale(&self, query: &[f32], scale: f32) -> ADCTable {
1420 ADCTable::new(query, scale, &self.params)
1421 }
1422
1423 #[must_use]
1427 pub fn distance_with_adc(&self, query: &[f32], quantized: &QuantizedVector) -> Option<f32> {
1428 let adc = self.build_adc_table(query)?;
1429 Some(adc.distance(&quantized.data))
1430 }
1431
1432 #[must_use]
1437 pub fn distance_asymmetric_l2_raw(
1438 &self,
1439 query: &[f32],
1440 data: &[u8],
1441 scale: f32,
1442 bits: u8,
1443 ) -> f32 {
1444 let levels = self.params.bits_per_dim.levels() as f32;
1445
1446 let factor = 1.0 / ((levels - 1.0) * scale);
1450
1451 match bits {
1452 4 => {
1453 let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());
1455
1456 let num_pairs = query.len() / 2;
1457
1458 if data.len() < query.len().div_ceil(2) {
1460 return f32::MAX;
1462 }
1463
1464 for i in 0..num_pairs {
1465 let byte = unsafe { *data.get_unchecked(i) };
1466 buffer.push((byte >> 4) as f32 * factor);
1467 buffer.push((byte & 0x0F) as f32 * factor);
1468 }
1469
1470 if !query.len().is_multiple_of(2) {
1471 let byte = unsafe { *data.get_unchecked(num_pairs) };
1472 buffer.push((byte >> 4) as f32 * factor);
1473 }
1474
1475 simd_l2_distance(query, &buffer)
1476 }
1477 2 => {
1478 let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());
1479 let num_quads = query.len() / 4;
1480
1481 if data.len() < query.len().div_ceil(4) {
1482 return f32::MAX;
1483 }
1484
1485 for i in 0..num_quads {
1486 let byte = unsafe { *data.get_unchecked(i) };
1487 buffer.push((byte & 0b11) as f32 * factor);
1488 buffer.push(((byte >> 2) & 0b11) as f32 * factor);
1489 buffer.push(((byte >> 4) & 0b11) as f32 * factor);
1490 buffer.push(((byte >> 6) & 0b11) as f32 * factor);
1491 }
1492
1493 let remaining = query.len() % 4;
1495 if remaining > 0 {
1496 let byte = unsafe { *data.get_unchecked(num_quads) };
1497 for i in 0..remaining {
1498 buffer.push(((byte >> (i * 2)) & 0b11) as f32 * factor);
1499 }
1500 }
1501
1502 simd_l2_distance(query, &buffer)
1503 }
1504 _ => {
1505 let unpacked = self.unpack_quantized(data, bits, query.len());
1509 let mut buffer: SmallVec<[f32; 256]> = SmallVec::with_capacity(query.len());
1510
1511 for &q in &unpacked {
1512 buffer.push(q as f32 * factor);
1513 }
1514
1515 simd_l2_distance(query, &buffer)
1516 }
1517 }
1518 }
1519
1520 #[inline]
1528 #[must_use]
1529 pub fn distance_l2_simd(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1530 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1532 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1533
1534 simd_l2_distance(&v1, &v2)
1536 }
1537
1538 #[inline]
1540 #[must_use]
1541 pub fn distance_cosine_simd(&self, qv1: &QuantizedVector, qv2: &QuantizedVector) -> f32 {
1542 let v1 = self.reconstruct(&qv1.data, qv1.scale, qv1.dimensions);
1543 let v2 = self.reconstruct(&qv2.data, qv2.scale, qv2.dimensions);
1544
1545 simd_cosine_distance(&v1, &v2)
1546 }
1547}
1548
1549#[inline]
1553fn simd_l2_distance(v1: &[f32], v2: &[f32]) -> f32 {
1554 #[cfg(target_arch = "x86_64")]
1555 {
1556 if is_x86_feature_detected!("avx2") {
1557 return unsafe { l2_distance_avx2(v1, v2) };
1558 } else if is_x86_feature_detected!("sse2") {
1559 return unsafe { l2_distance_sse2(v1, v2) };
1560 }
1561 l2_distance_scalar(v1, v2)
1563 }
1564
1565 #[cfg(target_arch = "aarch64")]
1566 {
1567 unsafe { l2_distance_neon(v1, v2) }
1569 }
1570
1571 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1572 {
1573 l2_distance_scalar(v1, v2)
1575 }
1576}
1577
1578#[inline]
1580fn simd_cosine_distance(v1: &[f32], v2: &[f32]) -> f32 {
1581 #[cfg(target_arch = "x86_64")]
1582 {
1583 if is_x86_feature_detected!("avx2") {
1584 return unsafe { cosine_distance_avx2(v1, v2) };
1585 } else if is_x86_feature_detected!("sse2") {
1586 return unsafe { cosine_distance_sse2(v1, v2) };
1587 }
1588 cosine_distance_scalar(v1, v2)
1590 }
1591
1592 #[cfg(target_arch = "aarch64")]
1593 {
1594 unsafe { cosine_distance_neon(v1, v2) }
1596 }
1597
1598 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1599 {
1600 cosine_distance_scalar(v1, v2)
1602 }
1603}
1604
/// Scalar L2 distance fallback; pairs are compared up to the shorter slice.
#[inline]
#[allow(dead_code)] // fallback for targets without a SIMD path
fn l2_distance_scalar(v1: &[f32], v2: &[f32]) -> f32 {
    let mut sum = 0.0f32;
    for (a, b) in v1.iter().zip(v2.iter()) {
        let diff = a - b;
        sum += diff * diff;
    }
    sum.sqrt()
}
1616
/// Scalar cosine distance fallback. Note the dot product runs over the
/// zipped (shorter) length while each norm covers its full slice; degenerate
/// norms yield 1.0.
#[inline]
#[allow(dead_code)] // fallback for targets without a SIMD path
fn cosine_distance_scalar(v1: &[f32], v2: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    for (a, b) in v1.iter().zip(v2.iter()) {
        dot += a * b;
    }
    let mut sq1 = 0.0f32;
    for a in v1 {
        sq1 += a * a;
    }
    let mut sq2 = 0.0f32;
    for b in v2 {
        sq2 += b * b;
    }
    let (norm1, norm2) = (sq1.sqrt(), sq2.sqrt());

    if norm1 < 1e-10 || norm2 < 1e-10 {
        return 1.0;
    }

    1.0 - dot / (norm1 * norm2)
}
1631
/// AVX2+FMA L2 distance (8 lanes per iteration, scalar tail).
///
/// NOTE(review): the horizontal reduction below duplicates
/// `horizontal_sum_avx2` inline — consider reusing the helper.
///
/// # Safety
/// Caller must have verified AVX2 and FMA support at runtime.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn l2_distance_avx2(v1: &[f32], v2: &[f32]) -> f32 {
    unsafe {
        let len = v1.len().min(v2.len());
        let mut sum = _mm256_setzero_ps();

        // Main loop: fused (a-b)^2 accumulation, 8 floats per iteration.
        let chunks = len / 8;
        for i in 0..chunks {
            let a = _mm256_loadu_ps(v1.as_ptr().add(i * 8));
            let b = _mm256_loadu_ps(v2.as_ptr().add(i * 8));
            let diff = _mm256_sub_ps(a, b);
            sum = _mm256_fmadd_ps(diff, diff, sum);
        }

        // Horizontal reduce: 256 -> 128 -> 64 -> 32 bits.
        let sum_high = _mm256_extractf128_ps(sum, 1);
        let sum_low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(sum_low, sum_high);
        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        let mut result = _mm_cvtss_f32(sum32);

        // Scalar tail for the final len % 8 elements.
        for i in (chunks * 8)..len {
            let diff = v1[i] - v2[i];
            result += diff * diff;
        }

        result.sqrt()
    }
}
1667
/// AVX2+FMA cosine distance: dot product and both squared norms fused into a
/// single pass.
///
/// NOTE(review): norms here cover only `min(len)` elements, while the scalar
/// fallback norms each full slice — results differ for unequal lengths;
/// confirm callers always pass equal-length slices.
///
/// # Safety
/// Caller must have verified AVX2 and FMA support at runtime.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
unsafe fn cosine_distance_avx2(v1: &[f32], v2: &[f32]) -> f32 {
    unsafe {
        let len = v1.len().min(v2.len());
        let mut dot_sum = _mm256_setzero_ps();
        let mut norm1_sum = _mm256_setzero_ps();
        let mut norm2_sum = _mm256_setzero_ps();

        // Main loop: three fused accumulators, 8 floats per iteration.
        let chunks = len / 8;
        for i in 0..chunks {
            let a = _mm256_loadu_ps(v1.as_ptr().add(i * 8));
            let b = _mm256_loadu_ps(v2.as_ptr().add(i * 8));
            dot_sum = _mm256_fmadd_ps(a, b, dot_sum);
            norm1_sum = _mm256_fmadd_ps(a, a, norm1_sum);
            norm2_sum = _mm256_fmadd_ps(b, b, norm2_sum);
        }

        let mut dot = horizontal_sum_avx2(dot_sum);
        let mut norm1 = horizontal_sum_avx2(norm1_sum);
        let mut norm2 = horizontal_sum_avx2(norm2_sum);

        // Scalar tail for the final len % 8 elements.
        for i in (chunks * 8)..len {
            dot += v1[i] * v2[i];
            norm1 += v1[i] * v1[i];
            norm2 += v2[i] * v2[i];
        }

        // Degenerate norms: defined as distance 1.0 (matches scalar path).
        if norm1 < 1e-10 || norm2 < 1e-10 {
            return 1.0;
        }

        let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
        1.0 - cosine_sim
    }
}
1707
#[cfg(target_arch = "x86_64")]
#[inline]
/// Sums the eight `f32` lanes of an AVX register into a single scalar.
///
/// Reduction chain: fold the upper 128-bit half onto the lower, then
/// 4 lanes -> 2 lanes -> 1 lane via `movehl`/`shuffle` adds.
///
/// # Safety
/// Requires AVX support on the executing CPU.
unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    unsafe {
        let upper = _mm256_extractf128_ps(v, 1);
        let lower = _mm256_castps256_ps128(v);
        let quad = _mm_add_ps(lower, upper);
        let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
        let single = _mm_add_ss(pair, _mm_shuffle_ps(pair, pair, 1));
        _mm_cvtss_f32(single)
    }
}
1720
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
/// SSE2 Euclidean (L2) distance between two `f32` slices.
///
/// Fallback path for CPUs without AVX2: processes the common prefix of
/// `v1`/`v2` four lanes at a time (multiply + add, no FMA), then handles
/// the sub-4 remainder with scalar code.
///
/// # Safety
/// The caller must ensure the CPU supports SSE2 (always true on x86_64).
unsafe fn l2_distance_sse2(v1: &[f32], v2: &[f32]) -> f32 {
    unsafe {
        let len = v1.len().min(v2.len());
        let (a, b) = (&v1[..len], &v2[..len]);

        let mut acc = _mm_setzero_ps();
        let mut lanes_a = a.chunks_exact(4);
        let mut lanes_b = b.chunks_exact(4);
        for (pa, pb) in (&mut lanes_a).zip(&mut lanes_b) {
            let x = _mm_loadu_ps(pa.as_ptr());
            let y = _mm_loadu_ps(pb.as_ptr());
            let d = _mm_sub_ps(x, y);
            acc = _mm_add_ps(acc, _mm_mul_ps(d, d));
        }

        // Horizontal reduction: 4 lanes -> 2 lanes -> 1 lane.
        let pair = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
        let single = _mm_add_ss(pair, _mm_shuffle_ps(pair, pair, 1));
        let mut total = _mm_cvtss_f32(single);

        // Scalar tail for the remaining < 4 elements.
        for (&x, &y) in lanes_a.remainder().iter().zip(lanes_b.remainder()) {
            let d = x - y;
            total += d * d;
        }

        total.sqrt()
    }
}
1752
1753#[cfg(target_arch = "x86_64")]
1754#[target_feature(enable = "sse2")]
1755unsafe fn cosine_distance_sse2(v1: &[f32], v2: &[f32]) -> f32 {
1756 unsafe {
1757 let len = v1.len().min(v2.len());
1758 let mut dot_sum = _mm_setzero_ps();
1759 let mut norm1_sum = _mm_setzero_ps();
1760 let mut norm2_sum = _mm_setzero_ps();
1761
1762 let chunks = len / 4;
1763 for i in 0..chunks {
1764 let a = _mm_loadu_ps(v1.as_ptr().add(i * 4));
1765 let b = _mm_loadu_ps(v2.as_ptr().add(i * 4));
1766 dot_sum = _mm_add_ps(dot_sum, _mm_mul_ps(a, b));
1767 norm1_sum = _mm_add_ps(norm1_sum, _mm_mul_ps(a, a));
1768 norm2_sum = _mm_add_ps(norm2_sum, _mm_mul_ps(b, b));
1769 }
1770
1771 let mut dot = horizontal_sum_sse2(dot_sum);
1773 let mut norm1 = horizontal_sum_sse2(norm1_sum);
1774 let mut norm2 = horizontal_sum_sse2(norm2_sum);
1775
1776 for i in (chunks * 4)..len {
1778 dot += v1[i] * v2[i];
1779 norm1 += v1[i] * v1[i];
1780 norm2 += v2[i] * v2[i];
1781 }
1782
1783 if norm1 < 1e-10 || norm2 < 1e-10 {
1784 return 1.0;
1785 }
1786
1787 let cosine_sim = dot / (norm1.sqrt() * norm2.sqrt());
1788 1.0 - cosine_sim
1789 }
1790}
1791
#[cfg(target_arch = "x86_64")]
#[inline]
/// Sums the four `f32` lanes of an SSE register into a single scalar.
///
/// Reduction chain: 4 lanes -> 2 lanes (`movehl` add) -> 1 lane
/// (`shuffle` + scalar add).
///
/// # Safety
/// Requires SSE2 support on the executing CPU (always true on x86_64).
unsafe fn horizontal_sum_sse2(v: __m128) -> f32 {
    unsafe {
        let pair = _mm_add_ps(v, _mm_movehl_ps(v, v));
        let single = _mm_add_ss(pair, _mm_shuffle_ps(pair, pair, 1));
        _mm_cvtss_f32(single)
    }
}
1801
#[cfg(target_arch = "aarch64")]
/// NEON Euclidean (L2) distance between two `f32` slices.
///
/// Operates on the common prefix of `v1`/`v2`, accumulating squared
/// differences four lanes at a time with fused multiply-add, then
/// finishing the sub-4 remainder with scalar code.
///
/// # Safety
/// NEON is mandatory on aarch64, so this is safe to call on any aarch64
/// target; the `unsafe` only covers the raw-pointer loads.
unsafe fn l2_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
    let len = v1.len().min(v2.len());
    let (a, b) = (&v1[..len], &v2[..len]);

    unsafe {
        let mut acc = vdupq_n_f32(0.0);

        let mut lanes_a = a.chunks_exact(4);
        let mut lanes_b = b.chunks_exact(4);
        for (pa, pb) in (&mut lanes_a).zip(&mut lanes_b) {
            let x = vld1q_f32(pa.as_ptr());
            let y = vld1q_f32(pb.as_ptr());
            let d = vsubq_f32(x, y);
            acc = vfmaq_f32(acc, d, d);
        }

        // Cross-lane add, then the scalar tail for the remaining < 4 elements.
        let mut total = vaddvq_f32(acc);
        for (&x, &y) in lanes_a.remainder().iter().zip(lanes_b.remainder()) {
            let d = x - y;
            total += d * d;
        }

        total.sqrt()
    }
}
1832
#[cfg(target_arch = "aarch64")]
/// NEON cosine distance (`1 - cosine_similarity`) between two `f32` slices.
///
/// Accumulates dot product and both squared norms four lanes at a time
/// over the common prefix of `v1`/`v2`, finishing the remainder with
/// scalar code. Returns `1.0` when either vector has (near-)zero norm.
///
/// # Safety
/// NEON is mandatory on aarch64, so this is safe to call on any aarch64
/// target; the `unsafe` only covers the raw-pointer loads.
unsafe fn cosine_distance_neon(v1: &[f32], v2: &[f32]) -> f32 {
    let len = v1.len().min(v2.len());
    let (a, b) = (&v1[..len], &v2[..len]);

    unsafe {
        let mut dot_acc = vdupq_n_f32(0.0);
        let mut sq1_acc = vdupq_n_f32(0.0);
        let mut sq2_acc = vdupq_n_f32(0.0);

        let mut lanes_a = a.chunks_exact(4);
        let mut lanes_b = b.chunks_exact(4);
        for (pa, pb) in (&mut lanes_a).zip(&mut lanes_b) {
            let x = vld1q_f32(pa.as_ptr());
            let y = vld1q_f32(pb.as_ptr());
            dot_acc = vfmaq_f32(dot_acc, x, y);
            sq1_acc = vfmaq_f32(sq1_acc, x, x);
            sq2_acc = vfmaq_f32(sq2_acc, y, y);
        }

        let mut dot = vaddvq_f32(dot_acc);
        let mut norm1 = vaddvq_f32(sq1_acc);
        let mut norm2 = vaddvq_f32(sq2_acc);

        // Scalar tail for the remaining < 4 elements.
        for (&x, &y) in lanes_a.remainder().iter().zip(lanes_b.remainder()) {
            dot += x * y;
            norm1 += x * x;
            norm2 += y * y;
        }

        // Degenerate (near-zero) vectors: treat as maximally distant.
        if norm1 < 1e-10 || norm2 < 1e-10 {
            return 1.0;
        }

        1.0 - dot / (norm1.sqrt() * norm2.sqrt())
    }
}
1872
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    // ----- QuantizationBits basics -----

    #[test]
    fn test_quantization_bits_conversion() {
        // Each variant reports its own bit width.
        assert_eq!(QuantizationBits::Bits2.to_u8(), 2);
        assert_eq!(QuantizationBits::Bits4.to_u8(), 4);
        assert_eq!(QuantizationBits::Bits8.to_u8(), 8);
    }

    #[test]
    fn test_quantization_bits_levels() {
        // levels == 2^bits
        assert_eq!(QuantizationBits::Bits2.levels(), 4);
        assert_eq!(QuantizationBits::Bits4.levels(), 16);
        assert_eq!(QuantizationBits::Bits8.levels(), 256);
    }

    #[test]
    fn test_quantization_bits_compression() {
        // compression_ratio == 32 / bits (relative to f32 input)
        assert_eq!(QuantizationBits::Bits2.compression_ratio(), 16.0);
        assert_eq!(QuantizationBits::Bits4.compression_ratio(), 8.0);
        assert_eq!(QuantizationBits::Bits8.compression_ratio(), 4.0);
    }

    #[test]
    fn test_quantization_bits_values_per_byte() {
        // values_per_byte == 8 / bits
        assert_eq!(QuantizationBits::Bits2.values_per_byte(), 4);
        assert_eq!(QuantizationBits::Bits4.values_per_byte(), 2);
        assert_eq!(QuantizationBits::Bits8.values_per_byte(), 1);
    }

    // ----- Parameter presets -----

    #[test]
    fn test_default_params() {
        let params = RaBitQParams::default();
        assert_eq!(params.bits_per_dim, QuantizationBits::Bits4);
        assert_eq!(params.num_rescale_factors, 12);
        assert_eq!(params.rescale_range, (0.5, 2.0));
    }

    #[test]
    fn test_preset_params() {
        let params2 = RaBitQParams::bits2();
        assert_eq!(params2.bits_per_dim, QuantizationBits::Bits2);

        let params4 = RaBitQParams::bits4();
        assert_eq!(params4.bits_per_dim, QuantizationBits::Bits4);

        let params8 = RaBitQParams::bits8();
        assert_eq!(params8.bits_per_dim, QuantizationBits::Bits8);
        assert_eq!(params8.num_rescale_factors, 16);
    }

    // ----- QuantizedVector container -----

    #[test]
    fn test_quantized_vector_creation() {
        let data = vec![0u8, 128, 255];
        let qv = QuantizedVector::new(data.clone(), 1.5, 8, 3);

        assert_eq!(qv.data, data);
        assert_eq!(qv.scale, 1.5);
        assert_eq!(qv.bits, 8);
        assert_eq!(qv.dimensions, 3);
    }

    #[test]
    fn test_quantized_vector_memory() {
        // 32 dims at 4 bits -> at least 16 payload bytes.
        let data = vec![0u8; 16];
        let qv = QuantizedVector::new(data, 1.0, 4, 32);

        let expected_min = 16;
        assert!(qv.memory_bytes() >= expected_min);
    }

    #[test]
    fn test_quantized_vector_compression_ratio() {
        // 128 f32 dims packed into 64 bytes: just under 8x once overhead counts.
        let data = vec![0u8; 64];
        let qv = QuantizedVector::new(data, 1.0, 4, 128);

        let ratio = qv.compression_ratio();
        assert!(ratio > 7.0 && ratio < 8.0);
    }

    // ----- Quantizer construction and scale generation -----

    #[test]
    fn test_create_quantizer() {
        let quantizer = RaBitQ::default_4bit();
        assert_eq!(quantizer.params().bits_per_dim, QuantizationBits::Bits4);
    }

    #[test]
    fn test_generate_scales() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 5,
            rescale_range: (0.5, 1.5),
        });

        let scales = quantizer.generate_scales();
        assert_eq!(scales.len(), 5);
        assert_eq!(scales[0], 0.5);
        assert_eq!(scales[4], 1.5);
        // The middle factor should land near the center of the range.
        assert!((scales[2] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_generate_scales_single() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 1,
            rescale_range: (0.5, 1.5),
        });

        // A single factor collapses to 1.0 regardless of the range.
        let scales = quantizer.generate_scales();
        assert_eq!(scales.len(), 1);
        assert_eq!(scales[0], 1.0);
    }

    // ----- Bit-packing round-trips -----

    #[test]
    fn test_pack_unpack_2bit() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits2,
            ..Default::default()
        });

        // 8 values at 2 bits each -> 2 bytes.
        let values = vec![0u8, 1, 2, 3, 0, 1, 2, 3];
        let packed = quantizer.pack_quantized(&values, 2);
        assert_eq!(packed.len(), 2);

        let unpacked = quantizer.unpack_quantized(&packed, 2, 8);
        assert_eq!(unpacked, values);
    }

    #[test]
    fn test_pack_unpack_4bit() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            ..Default::default()
        });

        // 8 values at 4 bits each -> 4 bytes.
        let values = vec![0u8, 1, 2, 3, 4, 5, 6, 7];
        let packed = quantizer.pack_quantized(&values, 4);
        assert_eq!(packed.len(), 4);

        let unpacked = quantizer.unpack_quantized(&packed, 4, 8);
        assert_eq!(unpacked, values);
    }

    #[test]
    fn test_pack_unpack_8bit() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            ..Default::default()
        });

        // 8 values at 8 bits each -> 8 bytes (identity packing).
        let values = vec![0u8, 10, 20, 30, 40, 50, 60, 70];
        let packed = quantizer.pack_quantized(&values, 8);
        assert_eq!(packed.len(), 8);

        let unpacked = quantizer.unpack_quantized(&packed, 8, 8);
        assert_eq!(unpacked, values);
    }

    // ----- Quantize / reconstruct -----

    #[test]
    fn test_quantize_simple_vector() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits4,
            num_rescale_factors: 4,
            rescale_range: (0.5, 1.5),
        });

        let vector = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let quantized = quantizer.quantize(&vector);

        assert_eq!(quantized.dimensions, 5);
        assert_eq!(quantized.bits, 4);
        assert!(quantized.scale > 0.0);

        // 5 dims at 4 bits should pack into no more than 4 bytes.
        assert!(quantized.data.len() <= 4);
    }

    #[test]
    fn test_quantize_reconstruct_accuracy() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let vector = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let quantized = quantizer.quantize(&vector);

        let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, vector.len());

        // 8-bit quantization should reconstruct each element within tolerance.
        for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
            let error = (orig - recon).abs();
            assert!(error < 0.1, "Error too large: {orig} vs {recon}");
        }
    }

    #[test]
    fn test_quantize_uniform_vector() {
        let quantizer = RaBitQ::default_4bit();

        // A constant vector should reconstruct to near-constant values.
        let vector = vec![0.5; 10];
        let quantized = quantizer.quantize(&vector);

        let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, vector.len());

        let avg = reconstructed.iter().sum::<f32>() / reconstructed.len() as f32;
        for &val in &reconstructed {
            assert!((val - avg).abs() < 0.2);
        }
    }

    #[test]
    fn test_compute_error() {
        let quantizer = RaBitQ::default_4bit();

        let original = vec![0.1, 0.2, 0.3, 0.4];
        let quantized_vec = quantizer.quantize(&original);

        let error = quantizer.compute_error(&original, &quantized_vec.data, quantized_vec.scale);

        // Error is a non-negative, finite magnitude.
        assert!(error >= 0.0);
        assert!(error.is_finite());
    }

    #[test]
    fn test_quantize_different_bit_widths() {
        let test_vector = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        let q2 = RaBitQ::new(RaBitQParams::bits2());
        let qv2 = q2.quantize(&test_vector);
        assert_eq!(qv2.bits, 2);

        let q4 = RaBitQ::default_4bit();
        let qv4 = q4.quantize(&test_vector);
        assert_eq!(qv4.bits, 4);

        let q8 = RaBitQ::new(RaBitQParams::bits8());
        let qv8 = q8.quantize(&test_vector);
        assert_eq!(qv8.bits, 8);

        // Fewer bits can never take more space.
        assert!(qv2.data.len() <= qv4.data.len());
        assert!(qv4.data.len() <= qv8.data.len());
    }

    #[test]
    fn test_quantize_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        let vector: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let quantized = quantizer.quantize(&vector);

        assert_eq!(quantized.dimensions, 128);
        assert_eq!(quantized.bits, 4);

        // 128 dims at 4 bits -> exactly 64 bytes.
        assert_eq!(quantized.data.len(), 64);

        let reconstructed = quantizer.reconstruct(&quantized.data, quantized.scale, 128);
        assert_eq!(reconstructed.len(), 128);
    }

    // ----- Symmetric distance functions -----

    #[test]
    fn test_distance_l2() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist = quantizer.distance_l2(&qv1, &qv2);

        // True L2 distance is 1.0; allow quantization error.
        assert!((dist - 1.0).abs() < 0.2, "Distance: {dist}");
    }

    #[test]
    fn test_distance_l2_identical() {
        let quantizer = RaBitQ::default_4bit();

        let v = vec![0.5, 0.3, 0.8, 0.2];
        let qv1 = quantizer.quantize(&v);
        let qv2 = quantizer.quantize(&v);

        let dist = quantizer.distance_l2(&qv1, &qv2);

        assert!(dist < 0.3, "Distance should be near zero, got: {dist}");
    }

    #[test]
    fn test_distance_cosine() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        // Orthogonal unit vectors: cosine distance should be ~1.0.
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist = quantizer.distance_cosine(&qv1, &qv2);

        assert!((dist - 1.0).abs() < 0.3, "Distance: {dist}");
    }

    #[test]
    fn test_distance_cosine_identical() {
        let quantizer = RaBitQ::default_4bit();

        let v = vec![0.5, 0.3, 0.8];
        let qv1 = quantizer.quantize(&v);
        let qv2 = quantizer.quantize(&v);

        let dist = quantizer.distance_cosine(&qv1, &qv2);

        assert!(dist < 0.2, "Distance should be near zero, got: {dist}");
    }

    #[test]
    fn test_distance_dot() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        // Identical unit vectors: negated dot product should be ~-1.0.
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist = quantizer.distance_dot(&qv1, &qv2);

        assert!((dist + 1.0).abs() < 0.3, "Distance: {dist}");
    }

    #[test]
    fn test_distance_approximate() {
        let quantizer = RaBitQ::default_4bit();

        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![0.5, 0.5, 0.5];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_approx = quantizer.distance_approximate(&qv1, &qv2);
        let dist_exact = quantizer.distance_l2(&qv1, &qv2);

        assert!(dist_approx >= 0.0);
        assert!(dist_approx.is_finite());

        // A genuinely farther vector should not look dramatically closer.
        let v3 = vec![1.0, 1.0, 1.0];
        let qv3 = quantizer.quantize(&v3);

        let dist_approx2 = quantizer.distance_approximate(&qv1, &qv3);
        let dist_exact2 = quantizer.distance_l2(&qv1, &qv3);

        if dist_exact2 > dist_exact {
            assert!(dist_approx2 > dist_approx * 0.5);
        }
    }

    #[test]
    fn test_distance_correlation() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 12,
            rescale_range: (0.8, 1.2),
        });

        let vectors = [
            vec![0.1, 0.2, 0.3],
            vec![0.4, 0.5, 0.6],
            vec![0.7, 0.8, 0.9],
        ];

        let quantized: Vec<QuantizedVector> =
            vectors.iter().map(|v| quantizer.quantize(v)).collect();

        // Exact L2 distances from vector 0 to the other two.
        let ground_truth_01 = vectors[0]
            .iter()
            .zip(vectors[1].iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        let ground_truth_02 = vectors[0]
            .iter()
            .zip(vectors[2].iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt();

        let quantized_01 = quantizer.distance_l2(&quantized[0], &quantized[1]);
        let quantized_02 = quantizer.distance_l2(&quantized[0], &quantized[2]);

        // Quantized distances should preserve the ground-truth ordering.
        if ground_truth_02 > ground_truth_01 {
            assert!(
                quantized_02 > quantized_01 * 0.8,
                "Order not preserved: {quantized_01} vs {quantized_02}"
            );
        }
    }

    #[test]
    fn test_distance_zero_vectors() {
        let quantizer = RaBitQ::default_4bit();

        let v_zero = vec![0.0, 0.0, 0.0];
        let qv_zero = quantizer.quantize(&v_zero);

        let dist = quantizer.distance_l2(&qv_zero, &qv_zero);
        assert!(dist < 0.1);

        // Cosine on zero vectors must not produce NaN/inf.
        let dist_cosine = quantizer.distance_cosine(&qv_zero, &qv_zero);
        assert!(dist_cosine.is_finite());
    }

    #[test]
    fn test_distance_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        let v1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let v2: Vec<f32> = (0..128).map(|i| ((i + 10) as f32) / 128.0).collect();

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        // Every metric should yield a sane, finite value in high dimensions.
        let dist_l2 = quantizer.distance_l2(&qv1, &qv2);
        let dist_cosine = quantizer.distance_cosine(&qv1, &qv2);
        let dist_dot = quantizer.distance_dot(&qv1, &qv2);
        let dist_approx = quantizer.distance_approximate(&qv1, &qv2);

        assert!(dist_l2 > 0.0 && dist_l2.is_finite());
        assert!(dist_cosine >= 0.0 && dist_cosine.is_finite());
        assert!(dist_dot.is_finite());
        assert!(dist_approx > 0.0 && dist_approx.is_finite());
    }

    #[test]
    fn test_distance_asymmetric_l2() {
        let quantizer = RaBitQ::default_4bit();

        let query = vec![0.1, 0.2, 0.3, 0.4];
        let vector = vec![0.12, 0.22, 0.32, 0.42];

        let quantized = quantizer.quantize(&vector);

        let dist_sym = quantizer.distance_l2_simd(&quantized, &quantizer.quantize(&query));
        let dist_asym = quantizer.distance_asymmetric_l2(&query, &quantized);

        // Asymmetric (full-precision query) should roughly agree with symmetric.
        assert!(dist_asym >= 0.0);
        assert!((dist_asym - dist_sym).abs() < 0.2);
    }

    // ----- SIMD vs scalar agreement -----

    #[test]
    fn test_simd_l2_matches_scalar() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let v2 = vec![0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_scalar = quantizer.distance_l2(&qv1, &qv2);
        let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);

        let diff = (dist_scalar - dist_simd).abs();
        assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
    }

    #[test]
    fn test_simd_cosine_matches_scalar() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 8,
            rescale_range: (0.8, 1.2),
        });

        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_scalar = quantizer.distance_cosine(&qv1, &qv2);
        let dist_simd = quantizer.distance_cosine_simd(&qv1, &qv2);

        let diff = (dist_scalar - dist_simd).abs();
        assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
    }

    #[test]
    fn test_simd_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        let v1: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let v2: Vec<f32> = (0..128).map(|i| ((i + 1) as f32) / 128.0).collect();

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_scalar = quantizer.distance_l2(&qv1, &qv2);
        let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);

        let diff = (dist_scalar - dist_simd).abs();
        assert!(
            diff < 0.1,
            "High-D SIMD vs scalar: {dist_simd} vs {dist_scalar}"
        );
    }

    #[test]
    fn test_simd_scalar_fallback() {
        let quantizer = RaBitQ::default_4bit();

        // Vectors shorter than one SIMD lane group exercise the fallback path.
        let v1 = vec![0.1, 0.2, 0.3];
        let v2 = vec![0.4, 0.5, 0.6];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_l2 = quantizer.distance_l2_simd(&qv1, &qv2);
        let dist_cosine = quantizer.distance_cosine_simd(&qv1, &qv2);

        assert!(dist_l2.is_finite());
        assert!(dist_cosine.is_finite());
    }

    #[test]
    fn test_simd_performance_improvement() {
        let quantizer = RaBitQ::default_4bit();

        // Embedding-sized vectors (1536 dims) as a smoke test for the SIMD path.
        let v1: Vec<f32> = (0..1536).map(|i| (i as f32) / 1536.0).collect();
        let v2: Vec<f32> = (0..1536).map(|i| ((i + 10) as f32) / 1536.0).collect();

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);

        let dist_simd = quantizer.distance_l2_simd(&qv1, &qv2);
        assert!(dist_simd > 0.0 && dist_simd.is_finite());
    }

    #[test]
    fn test_scalar_distance_functions() {
        // l2_distance_scalar: unit displacement -> distance 1.
        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];

        let dist = l2_distance_scalar(&v1, &v2);
        assert!((dist - 1.0).abs() < 0.001);

        // cosine_distance_scalar: orthogonal vectors -> distance 1.
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];

        let dist = cosine_distance_scalar(&v1, &v2);
        assert!((dist - 1.0).abs() < 0.001);
    }

    // ----- ADC (asymmetric distance computation) lookup tables -----

    #[test]
    fn test_adc_table_creation() {
        let quantizer = RaBitQ::default_4bit();
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let scale = 1.0;

        let adc = quantizer.build_adc_table_with_scale(&query, scale);

        assert_eq!(adc.dimensions, 4);
        assert_eq!(adc.bits, 4);
        assert_eq!(adc.table.len(), 4);

        // One entry per quantization level: 2^4 == 16.
        for dim_table in &adc.table {
            assert_eq!(dim_table.len(), 16);
        }
    }

    #[test]
    fn test_adc_table_2bit() {
        let quantizer = RaBitQ::new(RaBitQParams::bits2());
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let scale = 1.0;

        let adc = quantizer.build_adc_table_with_scale(&query, scale);

        // 2^2 == 4 levels per dimension.
        for dim_table in &adc.table {
            assert_eq!(dim_table.len(), 4);
        }
    }

    #[test]
    fn test_adc_distance_matches_asymmetric() {
        let quantizer = RaBitQ::default_4bit();

        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let vector = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];

        let quantized = quantizer.quantize(&vector);

        let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);

        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
        let dist_adc = adc.distance(&quantized.data);

        // The table lookup must reproduce the direct asymmetric computation.
        let diff = (dist_asymmetric - dist_adc).abs();
        assert!(
            diff < 0.1,
            "ADC vs asymmetric: {dist_adc} vs {dist_asymmetric}, diff: {diff}"
        );
    }

    #[test]
    fn test_adc_distance_accuracy() {
        let quantizer = RaBitQ::new(RaBitQParams {
            bits_per_dim: QuantizationBits::Bits8,
            num_rescale_factors: 16,
            rescale_range: (0.8, 1.2),
        });

        // Query equals the stored vector: distance should be ~0.
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let vector = vec![0.1, 0.2, 0.3, 0.4];
        let quantized = quantizer.quantize(&vector);

        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);

        let dist = adc.distance(&quantized.data);
        assert!(dist < 0.2, "Distance should be near zero, got: {dist}");
    }

    #[test]
    fn test_adc_distance_ordering() {
        let quantizer = RaBitQ::default_4bit();

        // Three candidates at increasing distance from the query.
        let query = vec![0.5, 0.5, 0.5, 0.5];
        let v1 = vec![0.5, 0.5, 0.5, 0.5];
        let v2 = vec![0.6, 0.6, 0.6, 0.6];
        let v3 = vec![0.9, 0.9, 0.9, 0.9];

        let qv1 = quantizer.quantize(&v1);
        let qv2 = quantizer.quantize(&v2);
        let qv3 = quantizer.quantize(&v3);

        let adc1 = quantizer.build_adc_table_with_scale(&query, qv1.scale);
        let adc2 = quantizer.build_adc_table_with_scale(&query, qv2.scale);
        let adc3 = quantizer.build_adc_table_with_scale(&query, qv3.scale);

        let dist1 = adc1.distance(&qv1.data);
        let dist2 = adc2.distance(&qv2.data);
        let dist3 = adc3.distance(&qv3.data);

        assert!(
            dist1 < dist2,
            "v1 should be closer than v2: {dist1} vs {dist2}"
        );
        assert!(
            dist2 < dist3,
            "v2 should be closer than v3: {dist2} vs {dist3}"
        );
    }

    #[test]
    fn test_adc_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        let query: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let vector: Vec<f32> = (0..128).map(|i| ((i + 5) as f32) / 128.0).collect();

        let quantized = quantizer.quantize(&vector);

        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);

        let dist = adc.distance(&quantized.data);
        assert!(dist > 0.0 && dist.is_finite());
    }

    #[test]
    fn test_adc_batch_search() {
        let quantizer = RaBitQ::default_4bit();

        let query = vec![0.5, 0.5, 0.5, 0.5];
        let candidates = [
            vec![0.5, 0.5, 0.5, 0.5],
            vec![0.6, 0.6, 0.6, 0.6],
            vec![0.4, 0.4, 0.4, 0.4],
            vec![0.7, 0.7, 0.7, 0.7],
        ];

        let quantized: Vec<QuantizedVector> =
            candidates.iter().map(|v| quantizer.quantize(v)).collect();

        // Score every candidate against the query via per-vector ADC tables.
        let mut results: Vec<(usize, f32)> = quantized
            .iter()
            .enumerate()
            .map(|(i, qv)| {
                let adc = quantizer.build_adc_table_with_scale(&query, qv.scale);
                (i, adc.distance(&qv.data))
            })
            .collect();

        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // The exact match (index 0) must rank first.
        assert_eq!(results[0].0, 0, "Results: {results:?}");
    }

    #[test]
    fn test_adc_distance_squared() {
        let quantizer = RaBitQ::default_4bit();

        let query = vec![0.0, 0.0, 0.0];
        let vector = vec![1.0, 0.0, 0.0];

        let quantized = quantizer.quantize(&vector);
        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);

        let dist_squared = adc.distance_squared(&quantized.data);
        let dist = adc.distance(&quantized.data);

        // distance() must be the square root of distance_squared().
        let diff = (dist_squared - dist * dist).abs();
        assert!(
            diff < 0.01,
            "distance_squared != dist^2: {} vs {}",
            dist_squared,
            dist * dist
        );
    }

    #[test]
    fn test_adc_simd_matches_scalar() {
        let quantizer = RaBitQ::default_4bit();

        let query = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let vector = vec![0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85];

        let quantized = quantizer.quantize(&vector);
        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);

        let dist_scalar = adc.distance_squared(&quantized.data);
        let dist_simd = adc.distance_squared_simd(&quantized.data);

        let diff = (dist_scalar - dist_simd).abs();
        assert!(diff < 0.01, "SIMD vs scalar: {dist_simd} vs {dist_scalar}");
    }

    #[test]
    fn test_adc_simd_high_dimensional() {
        let quantizer = RaBitQ::default_4bit();

        let query: Vec<f32> = (0..1536).map(|i| (i as f32) / 1536.0).collect();
        let vector: Vec<f32> = (0..1536).map(|i| ((i + 10) as f32) / 1536.0).collect();

        let quantized = quantizer.quantize(&vector);
        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);

        let dist_simd = adc.distance_squared_simd(&quantized.data);
        assert!(dist_simd > 0.0 && dist_simd.is_finite());
    }

    #[test]
    fn test_adc_memory_usage() {
        let quantizer = RaBitQ::default_4bit();

        let query: Vec<f32> = (0..128).map(|i| (i as f32) / 128.0).collect();
        let adc = quantizer.build_adc_table_with_scale(&query, 1.0);

        let memory = adc.memory_bytes();

        // 128 dims * 16 levels * 4 bytes per f32 entry.
        let expected_min = 128 * 16 * 4;
        assert!(
            memory >= expected_min,
            "Memory {memory} should be at least {expected_min}"
        );
    }

    #[test]
    fn test_adc_different_scales() {
        let quantizer = RaBitQ::default_4bit();

        let query = vec![0.5, 0.5, 0.5, 0.5];
        let vector = vec![0.6, 0.6, 0.6, 0.6];

        let quantized = quantizer.quantize(&vector);

        // Tables built at mismatched scales must still yield finite distances.
        let adc1 = quantizer.build_adc_table_with_scale(&query, 0.5);
        let adc2 = quantizer.build_adc_table_with_scale(&query, 1.0);
        let adc3 = quantizer.build_adc_table_with_scale(&query, 2.0);

        let dist1 = adc1.distance(&quantized.data);
        let dist2 = adc2.distance(&quantized.data);
        let dist3 = adc3.distance(&quantized.data);

        assert!(dist1.is_finite());
        assert!(dist2.is_finite());
        assert!(dist3.is_finite());
    }

    #[test]
    fn test_adc_edge_cases() {
        let quantizer = RaBitQ::default_4bit();

        // Single-dimension vectors.
        let query = vec![0.5];
        let vector = vec![0.6];
        let quantized = quantizer.quantize(&vector);
        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
        let dist = adc.distance(&quantized.data);
        assert!(dist.is_finite());

        // All-zero query and vector.
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let vector = vec![0.0, 0.0, 0.0, 0.0];
        let quantized = quantizer.quantize(&vector);
        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
        let dist = adc.distance(&quantized.data);
        assert!(dist.is_finite());
    }

    #[test]
    fn test_adc_2bit_accuracy() {
        let quantizer = RaBitQ::new(RaBitQParams::bits2());

        let query = vec![0.1, 0.2, 0.3, 0.4];
        let vector = vec![0.12, 0.22, 0.32, 0.42];

        let quantized = quantizer.quantize(&vector);

        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
        let dist_adc = adc.distance(&quantized.data);
        let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);

        // 2-bit is coarse; allow a looser tolerance than 8-bit.
        let diff = (dist_adc - dist_asymmetric).abs();
        assert!(diff < 0.2, "2-bit ADC diff too large: {diff}");
    }

    #[test]
    fn test_adc_8bit_accuracy() {
        let quantizer = RaBitQ::new(RaBitQParams::bits8());

        let query = vec![0.1, 0.2, 0.3, 0.4];
        let vector = vec![0.12, 0.22, 0.32, 0.42];

        let quantized = quantizer.quantize(&vector);

        let adc = quantizer.build_adc_table_with_scale(&query, quantized.scale);
        let dist_adc = adc.distance(&quantized.data);
        let dist_asymmetric = quantizer.distance_asymmetric_l2(&query, &quantized);

        let diff = (dist_adc - dist_asymmetric).abs();
        assert!(
            diff < 0.05,
            "8-bit ADC should be highly accurate, diff: {diff}"
        );
    }
}