use arc_swap::ArcSwap;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

use crate::compression::binary::hamming_distance;
use crate::compression::rabitq::QuantizedVector;
use crate::compression::{ADCTable, RaBitQ, RaBitQParams, ScalarParams};
use crate::distance::dot_product;

static EMPTY_NEIGHBORS: &[u32] = &[];
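
// Adjacency lists with lock-free reads: each (node, level) slot holds an
// `ArcSwap<Box<[u32]>>`, so readers load a snapshot without locking while
// writers serialize on a per-(node, level) mutex and swap in a fresh list.
//
// Usage sketch (illustrative only, mirroring the API defined below):
//
//     let mut lists = NeighborLists::new(2);
//     lists.set_neighbors(0, 0, vec![1, 2]);
//     let degree = lists.with_neighbors(0, 0, |n| n.len()); // zero-copy read
//     assert_eq!(degree, 2);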
#[derive(Debug)]
pub struct NeighborLists {
    /// Per-node, per-level adjacency lists, swapped atomically on update.
    neighbors: Vec<Vec<ArcSwap<Box<[u32]>>>>,

    /// Per-node, per-level mutexes serializing concurrent writers.
    write_locks: Vec<Vec<Mutex<()>>>,

    /// Number of levels allocated per node.
    max_levels: usize,

    /// Maximum neighbor list length (defaults to 32; `2 * m` when built with
    /// [`Self::with_capacity`]).
    m_max: usize,
}

impl NeighborLists {
    #[must_use]
    pub fn new(max_levels: usize) -> Self {
        Self {
            neighbors: Vec::new(),
            write_locks: Vec::new(),
            max_levels,
            m_max: 32,
        }
    }

    #[must_use]
    pub fn with_capacity(num_nodes: usize, max_levels: usize, m: usize) -> Self {
        Self {
            neighbors: Vec::with_capacity(num_nodes),
            write_locks: Vec::with_capacity(num_nodes),
            max_levels,
            m_max: m * 2,
        }
    }

    #[must_use]
    pub fn m_max(&self) -> usize {
        self.m_max
    }

    #[must_use]
    pub fn len(&self) -> usize {
        self.neighbors.len()
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.neighbors.is_empty()
    }

    /// Returns a copy of the neighbor list; prefer [`Self::with_neighbors`]
    /// on hot paths to avoid the allocation.
    #[must_use]
    pub fn get_neighbors(&self, node_id: u32, level: u8) -> Vec<u32> {
        let node_idx = node_id as usize;
        let level_idx = level as usize;

        if node_idx >= self.neighbors.len() {
            return Vec::new();
        }

        if level_idx >= self.neighbors[node_idx].len() {
            return Vec::new();
        }

        self.neighbors[node_idx][level_idx].load().to_vec()
    }

    /// Runs `f` against the current neighbor snapshot without copying.
    #[inline]
    pub fn with_neighbors<F, R>(&self, node_id: u32, level: u8, f: F) -> R
    where
        F: FnOnce(&[u32]) -> R,
    {
        let node_idx = node_id as usize;
        let level_idx = level as usize;

        if node_idx >= self.neighbors.len() {
            return f(EMPTY_NEIGHBORS);
        }

        if level_idx >= self.neighbors[node_idx].len() {
            return f(EMPTY_NEIGHBORS);
        }

        let guard = self.neighbors[node_idx][level_idx].load();
        f(&guard)
    }

    #[inline]
    pub fn prefetch(&self, node_id: u32, level: u8) {
        use super::prefetch::PrefetchConfig;
        if !PrefetchConfig::enabled() {
            return;
        }

        let node_idx = node_id as usize;
        let level_idx = level as usize;

        if node_idx >= self.neighbors.len() {
            return;
        }
        if level_idx >= self.neighbors[node_idx].len() {
            return;
        }

        let ptr = &self.neighbors[node_idx][level_idx] as *const _ as *const u8;
        #[cfg(target_arch = "x86_64")]
        unsafe {
            use std::arch::x86_64::_mm_prefetch;
            use std::arch::x86_64::_MM_HINT_T0;
            _mm_prefetch(ptr.cast(), _MM_HINT_T0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            std::arch::asm!(
                "prfm pldl1keep, [{ptr}]",
                ptr = in(reg) ptr,
                options(nostack, preserves_flags)
            );
        }
    }

    fn ensure_node_exists(&mut self, node_idx: usize) {
        while self.neighbors.len() <= node_idx {
            let mut levels = Vec::with_capacity(self.max_levels);
            let mut locks = Vec::with_capacity(self.max_levels);
            for _ in 0..self.max_levels {
                levels.push(ArcSwap::from_pointee(Vec::new().into_boxed_slice()));
                locks.push(Mutex::new(()));
            }
            self.neighbors.push(levels);
            self.write_locks.push(locks);
        }
    }

    pub fn set_neighbors(&mut self, node_id: u32, level: u8, neighbors_list: Vec<u32>) {
        let node_idx = node_id as usize;
        let level_idx = level as usize;

        self.ensure_node_exists(node_idx);

        self.neighbors[node_idx][level_idx].store(Arc::new(neighbors_list.into_boxed_slice()));
    }

    pub fn add_bidirectional_link(&mut self, node_a: u32, node_b: u32, level: u8) {
        let node_a_idx = node_a as usize;
        let node_b_idx = node_b as usize;
        let level_idx = level as usize;

        if node_a_idx == node_b_idx {
            return;
        }

        let max_idx = node_a_idx.max(node_b_idx);
        self.ensure_node_exists(max_idx);

        {
            let current = self.neighbors[node_a_idx][level_idx].load();
            if !current.contains(&node_b) {
                let mut new_list = current.to_vec();
                new_list.push(node_b);
                self.neighbors[node_a_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
            }
        }

        {
            let current = self.neighbors[node_b_idx][level_idx].load();
            if !current.contains(&node_a) {
                let mut new_list = current.to_vec();
                new_list.push(node_a);
                self.neighbors[node_b_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
            }
        }
    }

    pub fn add_bidirectional_link_parallel(&self, node_a: u32, node_b: u32, level: u8) {
        let node_a_idx = node_a as usize;
        let node_b_idx = node_b as usize;
        let level_idx = level as usize;

        if node_a_idx == node_b_idx {
            return;
        }

        if node_a_idx >= self.neighbors.len() || node_b_idx >= self.neighbors.len() {
            return;
        }

        let (first_idx, second_idx, first_neighbor, second_neighbor) = if node_a_idx < node_b_idx {
            (node_a_idx, node_b_idx, node_b, node_a)
        } else {
            (node_b_idx, node_a_idx, node_a, node_b)
        };
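
        // Locks are always taken in ascending node-index order, so two
        // threads linking the same pair from opposite directions cannot
        // deadlock.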
        let _lock_first = self.write_locks[first_idx][level_idx].lock();
        let _lock_second = self.write_locks[second_idx][level_idx].lock();

        {
            let current = self.neighbors[first_idx][level_idx].load();
            if !current.contains(&first_neighbor) {
                let mut new_list = current.to_vec();
                new_list.push(first_neighbor);
                self.neighbors[first_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
            }
        }

        {
            let current = self.neighbors[second_idx][level_idx].load();
            if !current.contains(&second_neighbor) {
                let mut new_list = current.to_vec();
                new_list.push(second_neighbor);
                self.neighbors[second_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
            }
        }
    }

    pub fn remove_link_parallel(&self, node_a: u32, node_b: u32, level: u8) {
        let node_a_idx = node_a as usize;
        let level_idx = level as usize;

        if node_a_idx >= self.neighbors.len() {
            return;
        }

        let _lock = self.write_locks[node_a_idx][level_idx].lock();
        let current = self.neighbors[node_a_idx][level_idx].load();
        let new_list: Vec<u32> = current.iter().copied().filter(|&n| n != node_b).collect();
        self.neighbors[node_a_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
    }

    pub fn set_neighbors_parallel(&self, node_id: u32, level: u8, neighbors_list: Vec<u32>) {
        let node_idx = node_id as usize;
        let level_idx = level as usize;

        if node_idx >= self.neighbors.len() {
            return;
        }

        let _lock = self.write_locks[node_idx][level_idx].lock();
        self.neighbors[node_idx][level_idx].store(Arc::new(neighbors_list.into_boxed_slice()));
    }

    #[must_use]
    pub fn total_neighbors(&self) -> usize {
        self.neighbors
            .iter()
            .flat_map(|node| node.iter())
            .map(|level| level.load().len())
            .sum()
    }

    #[must_use]
    pub fn memory_usage(&self) -> usize {
        let mut total = 0;

        total += self.neighbors.capacity() * std::mem::size_of::<Vec<ArcSwap<Box<[u32]>>>>();

        for node in &self.neighbors {
            total += node.capacity() * std::mem::size_of::<ArcSwap<Box<[u32]>>>();

            for level in node {
                let guard = level.load();
                total += guard.len() * std::mem::size_of::<u32>();
            }
        }

        total += self.write_locks.capacity() * std::mem::size_of::<Vec<Mutex<()>>>();
        for node in &self.write_locks {
            total += node.capacity() * std::mem::size_of::<Mutex<()>>();
        }

        total
    }

    /// Relabels node ids in BFS order from `entry_point` and rebuilds the
    /// adjacency lists accordingly. Returns the old-id -> new-id mapping.
    pub fn reorder_bfs(&mut self, entry_point: u32, start_level: u8) -> Vec<u32> {
        use std::collections::{HashSet, VecDeque};

        let num_nodes = self.neighbors.len();
        if num_nodes == 0 {
            return Vec::new();
        }

        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();
        let mut old_to_new = vec![u32::MAX; num_nodes];
        let mut new_id = 0u32;

        queue.push_back(entry_point);
        visited.insert(entry_point);

        while let Some(node_id) = queue.pop_front() {
            old_to_new[node_id as usize] = new_id;
            new_id += 1;

            for level in (0..=start_level).rev() {
                let neighbors = self.get_neighbors(node_id, level);
                for &neighbor_id in &neighbors {
                    if visited.insert(neighbor_id) {
                        queue.push_back(neighbor_id);
                    }
                }
            }
        }

        // Nodes unreachable from the entry point keep their relative order
        // and are appended after the BFS-visited ones.
        for mapping in old_to_new.iter_mut() {
            if *mapping == u32::MAX {
                *mapping = new_id;
                new_id += 1;
            }
        }

        let mut new_neighbors = Vec::with_capacity(num_nodes);
        let mut new_write_locks = Vec::with_capacity(num_nodes);
        for _ in 0..num_nodes {
            let mut levels = Vec::with_capacity(self.max_levels);
            let mut locks = Vec::with_capacity(self.max_levels);
            for _ in 0..self.max_levels {
                levels.push(ArcSwap::from_pointee(Vec::new().into_boxed_slice()));
                locks.push(Mutex::new(()));
            }
            new_neighbors.push(levels);
            new_write_locks.push(locks);
        }

        for old_id in 0..num_nodes {
            let new_node_id = old_to_new[old_id] as usize;
            #[allow(clippy::needless_range_loop)]
            for level in 0..self.max_levels {
                let old_neighbor_list = self.neighbors[old_id][level].load();
                let remapped: Vec<u32> = old_neighbor_list
                    .iter()
                    .map(|&old_neighbor| old_to_new[old_neighbor as usize])
                    .collect();
                new_neighbors[new_node_id][level].store(Arc::new(remapped.into_boxed_slice()));
            }
        }

        self.neighbors = new_neighbors;
        self.write_locks = new_write_locks;

        old_to_new
    }

    #[must_use]
    pub fn num_nodes(&self) -> usize {
        self.neighbors.len()
    }
}

impl Serialize for NeighborLists {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let mut state = serializer.serialize_struct("NeighborLists", 3)?;

        let neighbors_data: Vec<Vec<Vec<u32>>> = self
            .neighbors
            .iter()
            .map(|node| node.iter().map(|level| level.load().to_vec()).collect())
            .collect();

        state.serialize_field("neighbors", &neighbors_data)?;
        state.serialize_field("max_levels", &self.max_levels)?;
        state.serialize_field("m_max", &self.m_max)?;
        state.end()
    }
}

impl<'de> Deserialize<'de> for NeighborLists {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct NeighborListsData {
            neighbors: Vec<Vec<Vec<u32>>>,
            max_levels: usize,
            m_max: usize,
        }

        let data = NeighborListsData::deserialize(deserializer)?;

        let neighbors: Vec<Vec<ArcSwap<Box<[u32]>>>> = data
            .neighbors
            .iter()
            .map(|node| {
                node.iter()
                    .map(|level| ArcSwap::from_pointee(level.clone().into_boxed_slice()))
                    .collect()
            })
            .collect();

        let write_locks: Vec<Vec<Mutex<()>>> = data
            .neighbors
            .iter()
            .map(|node| node.iter().map(|_| Mutex::new(())).collect())
            .collect();

        Ok(NeighborLists {
            neighbors,
            write_locks,
            max_levels: data.max_levels,
            m_max: data.m_max,
        })
    }
}

/// Number of neighbor codes packed into one FastScan block.
#[allow(dead_code)]
pub const FASTSCAN_BATCH_SIZE: usize = 32;
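
// Codes are stored transposed ("interleaved") for FastScan-style scanning:
// within a vertex's block, byte `sq` of all 32 neighbor codes is contiguous
// at `block[sq * FASTSCAN_BATCH_SIZE + n]`, so a SIMD pass can process one
// code byte across the whole batch at a time.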
#[allow(dead_code)]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct NeighborCodeStorage {
    /// Interleaved code bytes, one fixed-size block per vertex.
    codes: Vec<u8>,

    /// Byte offset of each vertex's block within `codes`.
    offsets: Vec<usize>,

    /// Number of valid neighbor codes in each vertex's block.
    neighbor_counts: Vec<usize>,

    /// Size of a single quantized code in bytes.
    code_size: usize,

    /// Block size in bytes: `code_size * FASTSCAN_BATCH_SIZE`.
    block_size: usize,
}

#[allow(dead_code)]
impl NeighborCodeStorage {
    #[must_use]
    pub fn new(code_size: usize) -> Self {
        Self {
            codes: Vec::new(),
            offsets: Vec::new(),
            neighbor_counts: Vec::new(),
            code_size,
            block_size: code_size * FASTSCAN_BATCH_SIZE,
        }
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.offsets.is_empty()
    }

    #[must_use]
    pub fn len(&self) -> usize {
        self.offsets.len()
    }

    #[must_use]
    #[inline]
    pub fn get_block(&self, vertex_id: u32) -> Option<&[u8]> {
        let idx = vertex_id as usize;
        if idx >= self.offsets.len() {
            return None;
        }
        let start = self.offsets[idx];
        let end = start + self.block_size;
        if end > self.codes.len() {
            return None;
        }
        Some(&self.codes[start..end])
    }

    #[must_use]
    #[inline]
    pub fn get_neighbor_count(&self, vertex_id: u32) -> usize {
        let idx = vertex_id as usize;
        if idx >= self.neighbor_counts.len() {
            return 0;
        }
        self.neighbor_counts[idx]
    }
644
645 pub fn build_from_storage(
658 vectors: &VectorStorage,
659 neighbors: &NeighborLists,
660 level: u8,
661 ) -> Option<Self> {
662 let (quantized_data, code_size) = match vectors {
664 VectorStorage::RaBitQQuantized {
665 quantized_data,
666 code_size,
667 ..
668 } => (quantized_data, *code_size),
669 _ => return None,
670 };
671
672 let num_vertices = neighbors.len();
673 if num_vertices == 0 {
674 return Some(Self::new(code_size));
675 }
676
677 let block_size = code_size * FASTSCAN_BATCH_SIZE;
678 let mut codes = Vec::with_capacity(num_vertices * block_size);
679 let mut offsets = Vec::with_capacity(num_vertices);
680 let mut neighbor_counts = Vec::with_capacity(num_vertices);
681
682 for vertex_id in 0..num_vertices {
683 let offset = codes.len();
684 offsets.push(offset);
685
686 let vertex_neighbors = neighbors.get_neighbors(vertex_id as u32, level);
688 let count = vertex_neighbors.len().min(FASTSCAN_BATCH_SIZE);
689 neighbor_counts.push(count);
690
691 let block_start = codes.len();
693 codes.resize(block_start + block_size, 0);
694
695 for sq in 0..code_size {
697 for (n, &neighbor_id) in vertex_neighbors
698 .iter()
699 .take(FASTSCAN_BATCH_SIZE)
700 .enumerate()
701 {
702 let neighbor_idx = neighbor_id as usize;
703 let code_start = neighbor_idx * code_size;
705 if code_start + sq < quantized_data.len() {
706 codes[block_start + sq * FASTSCAN_BATCH_SIZE + n] =
707 quantized_data[code_start + sq];
708 }
709 }
711 }
712 }
713
714 Some(Self {
715 codes,
716 offsets,
717 neighbor_counts,
718 code_size,
719 block_size,
720 })
721 }

    pub fn update_vertex(&mut self, vertex_id: u32, new_neighbors: &[u32], quantized_data: &[u8]) {
        let idx = vertex_id as usize;

        while self.offsets.len() <= idx {
            let offset = self.codes.len();
            self.offsets.push(offset);
            self.neighbor_counts.push(0);
            self.codes.resize(self.codes.len() + self.block_size, 0);
        }

        let block_start = self.offsets[idx];
        let count = new_neighbors.len().min(FASTSCAN_BATCH_SIZE);
        self.neighbor_counts[idx] = count;

        self.codes[block_start..block_start + self.block_size].fill(0);

        for sq in 0..self.code_size {
            for (n, &neighbor_id) in new_neighbors.iter().take(FASTSCAN_BATCH_SIZE).enumerate() {
                let neighbor_idx = neighbor_id as usize;
                let code_start = neighbor_idx * self.code_size;
                if code_start + sq < quantized_data.len() {
                    self.codes[block_start + sq * FASTSCAN_BATCH_SIZE + n] =
                        quantized_data[code_start + sq];
                }
            }
        }
    }

    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.codes.len()
            + self.offsets.len() * std::mem::size_of::<usize>()
            + self.neighbor_counts.len() * std::mem::size_of::<usize>()
    }
}

#[derive(Clone, Debug)]
pub enum UnifiedADC {
    RaBitQ(ADCTable),
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum VectorStorage {
    FullPrecision {
        vectors: Vec<f32>,
        norms: Vec<f32>,
        count: usize,
        dimensions: usize,
    },

    BinaryQuantized {
        quantized: Vec<Vec<u8>>,

        /// Original vectors, kept only when reranking needs them.
        original: Option<Vec<Vec<f32>>>,

        /// Per-dimension quantization thresholds (trained medians).
        thresholds: Vec<f32>,

        dimensions: usize,
    },

    RaBitQQuantized {
        #[serde(skip)]
        quantizer: Option<RaBitQ>,

        params: RaBitQParams,

        /// Flat quantized codes, `code_size` bytes per vector.
        quantized_data: Vec<u8>,

        /// Per-vector dequantization scales.
        quantized_scales: Vec<f32>,

        code_size: usize,

        /// Flat original vectors, `dimensions` floats per vector.
        original: Vec<f32>,

        original_count: usize,

        dimensions: usize,
    },

    ScalarQuantized {
        params: ScalarParams,

        /// Flat SQ8 codes, one byte per dimension.
        quantized: Vec<u8>,

        norms: Vec<f32>,

        sums: Vec<i32>,

        /// Vectors buffered until enough samples arrive to train the
        /// quantizer.
        training_buffer: Vec<f32>,

        count: usize,

        dimensions: usize,

        trained: bool,
    },
}

impl VectorStorage {
    #[must_use]
    pub fn new_full_precision(dimensions: usize) -> Self {
        Self::FullPrecision {
            vectors: Vec::new(),
            norms: Vec::new(),
            count: 0,
            dimensions,
        }
    }

    #[must_use]
    pub fn new_binary_quantized(dimensions: usize, keep_original: bool) -> Self {
        Self::BinaryQuantized {
            quantized: Vec::new(),
            original: if keep_original {
                Some(Vec::new())
            } else {
                None
            },
            thresholds: vec![0.0; dimensions],
            dimensions,
        }
    }

    #[must_use]
    pub fn new_rabitq_quantized(dimensions: usize, params: RaBitQParams) -> Self {
        let values_per_byte = params.bits_per_dim.values_per_byte();
        let code_size = dimensions.div_ceil(values_per_byte);

        Self::RaBitQQuantized {
            quantizer: Some(RaBitQ::new(params.clone())),
            params,
            quantized_data: Vec::new(),
            quantized_scales: Vec::new(),
            code_size,
            original: Vec::new(),
            original_count: 0,
            dimensions,
        }
    }

    #[must_use]
    pub fn new_sq8_quantized(dimensions: usize) -> Self {
        Self::ScalarQuantized {
            params: ScalarParams::uninitialized(dimensions),
            quantized: Vec::new(),
            norms: Vec::new(),
            sums: Vec::new(),
            training_buffer: Vec::new(),
            count: 0,
            dimensions,
            trained: false,
        }
    }

    #[must_use]
    pub fn is_asymmetric(&self) -> bool {
        matches!(
            self,
            Self::RaBitQQuantized { .. }
                | Self::ScalarQuantized { .. }
                | Self::BinaryQuantized { .. }
        )
    }

    #[must_use]
    pub fn is_binary_quantized(&self) -> bool {
        matches!(self, Self::BinaryQuantized { .. })
    }

    #[must_use]
    pub fn is_sq8(&self) -> bool {
        matches!(self, Self::ScalarQuantized { .. })
    }

    #[must_use]
    pub fn len(&self) -> usize {
        match self {
            Self::FullPrecision { count, .. } | Self::ScalarQuantized { count, .. } => *count,
            Self::BinaryQuantized { quantized, .. } => quantized.len(),
            Self::RaBitQQuantized { original_count, .. } => *original_count,
        }
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    #[must_use]
    pub fn dimensions(&self) -> usize {
        match self {
            Self::FullPrecision { dimensions, .. }
            | Self::BinaryQuantized { dimensions, .. }
            | Self::RaBitQQuantized { dimensions, .. }
            | Self::ScalarQuantized { dimensions, .. } => *dimensions,
        }
    }

    pub fn insert(&mut self, vector: Vec<f32>) -> Result<u32, String> {
        match self {
            Self::FullPrecision {
                vectors,
                norms,
                count,
                dimensions,
            } => {
                if vector.len() != *dimensions {
                    return Err(format!(
                        "Vector dimension mismatch: expected {}, got {}",
                        dimensions,
                        vector.len()
                    ));
                }
                let id = *count as u32;
                let norm_sq: f32 = vector.iter().map(|&x| x * x).sum();
                norms.push(norm_sq);
                vectors.extend(vector);
                *count += 1;
                Ok(id)
            }
            Self::BinaryQuantized {
                quantized,
                original,
                thresholds,
                dimensions,
            } => {
                if vector.len() != *dimensions {
                    return Err(format!(
                        "Vector dimension mismatch: expected {}, got {}",
                        dimensions,
                        vector.len()
                    ));
                }

                let quant = Self::quantize_binary(&vector, thresholds);
                let id = quantized.len() as u32;
                quantized.push(quant);

                if let Some(orig) = original {
                    orig.push(vector);
                }

                Ok(id)
            }
            Self::RaBitQQuantized {
                quantizer,
                params,
                quantized_data,
                quantized_scales,
                original,
                original_count,
                dimensions,
                ..
            } => {
                if vector.len() != *dimensions {
                    return Err(format!(
                        "Vector dimension mismatch: expected {}, got {}",
                        dimensions,
                        vector.len()
                    ));
                }

                let q = quantizer.get_or_insert_with(|| RaBitQ::new(params.clone()));

                let quant = q.quantize(&vector);
                let id = *original_count as u32;
                quantized_data.extend(&quant.data);
                quantized_scales.push(quant.scale);

                original.extend(vector);
                *original_count += 1;

                Ok(id)
            }
            Self::ScalarQuantized {
                params,
                quantized,
                norms,
                sums,
                training_buffer,
                count,
                dimensions,
                trained,
            } => {
                if vector.len() != *dimensions {
                    return Err(format!(
                        "Vector dimension mismatch: expected {}, got {}",
                        dimensions,
                        vector.len()
                    ));
                }

                let id = *count as u32;
                let dim = *dimensions;

                if *trained {
                    let quant = params.quantize(&vector);
                    norms.push(quant.norm_sq);
                    sums.push(quant.sum);
                    quantized.extend(quant.data);
                    *count += 1;
                } else {
                    training_buffer.extend(vector);
                    *count += 1;

                    // Once 256 vectors have accumulated, train the scalar
                    // quantizer on them and migrate the buffered vectors into
                    // quantized form.
                    if *count >= 256 {
                        let training_refs: Vec<&[f32]> = (0..256)
                            .map(|i| &training_buffer[i * dim..(i + 1) * dim])
                            .collect();
                        *params =
                            ScalarParams::train(&training_refs).map_err(ToString::to_string)?;
                        *trained = true;

                        quantized.reserve(*count * dim);
                        norms.reserve(*count);
                        sums.reserve(*count);
                        for i in 0..*count {
                            let vec_slice = &training_buffer[i * dim..(i + 1) * dim];
                            let quant = params.quantize(vec_slice);
                            norms.push(quant.norm_sq);
                            sums.push(quant.sum);
                            quantized.extend(quant.data);
                        }

                        training_buffer.clear();
                        training_buffer.shrink_to_fit();
                    }
                }
                Ok(id)
            }
        }
    }

    #[inline]
    #[must_use]
    pub fn get(&self, id: u32) -> Option<&[f32]> {
        match self {
            Self::FullPrecision {
                vectors,
                count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(&vectors[start..end])
            }
            Self::BinaryQuantized { original, .. } => original
                .as_ref()
                .and_then(|o| o.get(id as usize).map(std::vec::Vec::as_slice)),
            Self::RaBitQQuantized {
                original,
                original_count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(&original[start..end])
            }
            Self::ScalarQuantized {
                training_buffer,
                count,
                dimensions,
                trained,
                ..
            } => {
                // After training, only quantized bytes are kept; use
                // `get_dequantized` to recover an approximate vector.
                if *trained {
                    return None;
                }
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(&training_buffer[start..end])
            }
        }
    }

    #[must_use]
    pub fn get_dequantized(&self, id: u32) -> Option<Vec<f32>> {
        match self {
            Self::FullPrecision {
                vectors,
                count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(vectors[start..end].to_vec())
            }
            Self::BinaryQuantized { original, .. } => {
                original.as_ref().and_then(|o| o.get(id as usize).cloned())
            }
            Self::RaBitQQuantized {
                original,
                original_count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(original[start..end].to_vec())
            }
            Self::ScalarQuantized {
                params,
                quantized,
                training_buffer,
                count,
                dimensions,
                trained,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let dim = *dimensions;
                if *trained {
                    let start = idx * dim;
                    let end = start + dim;
                    Some(params.dequantize(&quantized[start..end]))
                } else {
                    let start = idx * dim;
                    let end = start + dim;
                    Some(training_buffer[start..end].to_vec())
                }
            }
        }
    }

    #[inline(always)]
    #[must_use]
    pub fn distance_asymmetric_l2(&self, query: &[f32], id: u32) -> Option<f32> {
        match self {
            Self::RaBitQQuantized {
                quantizer,
                quantized_data,
                quantized_scales,
                code_size,
                original_count,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }

                let q = quantizer.as_ref()?;

                let start = idx * code_size;
                let end = start + code_size;
                let data = &quantized_data[start..end];
                let scale = quantized_scales[idx];

                Some(q.distance_asymmetric_l2_flat(query, data, scale))
            }
            Self::ScalarQuantized {
                params,
                quantized,
                norms,
                sums,
                count,
                dimensions,
                trained,
                ..
            } => {
                if !*trained {
                    return None;
                }

                let idx = id as usize;
                if idx >= *count {
                    return None;
                }

                let start = idx * *dimensions;
                let end = start + *dimensions;
                let query_prep = params.prepare_query(query);
                Some(params.distance_l2_squared_raw(
                    &query_prep,
                    &quantized[start..end],
                    sums[idx],
                    norms[idx],
                ))
            }
            Self::BinaryQuantized {
                quantized,
                thresholds,
                ..
            } => {
                let idx = id as usize;
                if idx >= quantized.len() {
                    return None;
                }

                // Binary path: quantize the query, then compare by Hamming
                // distance.
                let query_quantized = Self::quantize_binary(query, thresholds);

                let hamming = hamming_distance(&query_quantized, &quantized[idx]);

                Some(hamming as f32)
            }
            Self::FullPrecision { .. } => None,
        }
    }

    #[inline]
    #[must_use]
    pub fn get_norm(&self, id: u32) -> Option<f32> {
        match self {
            Self::FullPrecision { norms, count, .. } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                Some(norms[idx])
            }
            _ => None,
        }
    }

    #[inline]
    #[must_use]
    pub fn supports_l2_decomposition(&self) -> bool {
        matches!(self, Self::FullPrecision { .. })
    }

    #[inline(always)]
    #[must_use]
    pub fn distance_l2_decomposed(&self, query: &[f32], query_norm: f32, id: u32) -> Option<f32> {
        match self {
            Self::FullPrecision {
                vectors,
                norms,
                count,
                dimensions,
            } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                let vec = &vectors[start..end];
                let vec_norm = norms[idx];

                // ||q - v||^2 = ||q||^2 + ||v||^2 - 2 * (q . v)
                let dot = dot_product(query, vec);
                Some(query_norm + vec_norm - 2.0 * dot)
            }
            Self::ScalarQuantized {
                params,
                quantized,
                norms,
                sums,
                count,
                dimensions,
                trained,
                ..
            } => {
                if !*trained {
                    return None;
                }
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                let vec_norm = norms[idx];
                let vec_sum = sums[idx];

                let query_prep = params.prepare_query(query);
                let quantized_slice = &quantized[start..end];
                Some(params.distance_l2_squared_raw(
                    &query_prep,
                    quantized_slice,
                    vec_sum,
                    vec_norm,
                ))
            }
            _ => None,
        }
    }

    #[inline]
    #[must_use]
    pub fn get_quantized(&self, id: u32) -> Option<QuantizedVector> {
        match self {
            Self::RaBitQQuantized {
                quantized_data,
                quantized_scales,
                code_size,
                original_count,
                dimensions,
                params,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * code_size;
                let end = start + code_size;
                Some(QuantizedVector::new(
                    quantized_data[start..end].to_vec(),
                    quantized_scales[idx],
                    params.bits_per_dim.to_u8(),
                    *dimensions,
                ))
            }
            _ => None,
        }
    }

    #[must_use]
    pub fn quantizer(&self) -> Option<&RaBitQ> {
        match self {
            Self::RaBitQQuantized { quantizer, .. } => quantizer.as_ref(),
            _ => None,
        }
    }

    #[must_use]
    pub fn build_adc_table(&self, query: &[f32]) -> Option<UnifiedADC> {
        match self {
            Self::RaBitQQuantized { quantizer, .. } => {
                let q = quantizer.as_ref()?;
                Some(UnifiedADC::RaBitQ(q.build_adc_table(query)?))
            }
            _ => None,
        }
    }

    #[inline]
    #[must_use]
    pub fn distance_adc(&self, adc: &UnifiedADC, id: u32) -> Option<f32> {
        match (self, adc) {
            (
                Self::RaBitQQuantized {
                    quantized_data,
                    code_size,
                    original_count,
                    ..
                },
                UnifiedADC::RaBitQ(table),
            ) => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * code_size;
                let end = start + code_size;
                Some(table.distance(&quantized_data[start..end]))
            }
            _ => None,
        }
    }

    #[inline]
    pub fn prefetch(&self, id: u32) {
        let ptr: Option<*const u8> = match self {
            Self::FullPrecision {
                vectors,
                count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    None
                } else {
                    let start = idx * *dimensions;
                    Some(vectors[start..].as_ptr().cast())
                }
            }
            Self::BinaryQuantized { original, .. } => original
                .as_ref()
                .and_then(|o| o.get(id as usize).map(|v| v.as_ptr().cast())),
            Self::RaBitQQuantized {
                quantized_data,
                code_size,
                original_count,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    None
                } else {
                    let start = idx * code_size;
                    Some(quantized_data[start..].as_ptr())
                }
            }
            Self::ScalarQuantized {
                quantized,
                training_buffer,
                count,
                dimensions,
                trained,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    None
                } else if *trained {
                    let start = idx * *dimensions;
                    Some(quantized[start..].as_ptr())
                } else {
                    let start = idx * *dimensions;
                    Some(training_buffer[start..].as_ptr().cast())
                }
            }
        };

        if let Some(ptr) = ptr {
            #[cfg(target_arch = "x86_64")]
            unsafe {
                std::arch::x86_64::_mm_prefetch(ptr.cast::<i8>(), std::arch::x86_64::_MM_HINT_T0);
            }
            #[cfg(target_arch = "aarch64")]
            unsafe {
                std::arch::asm!(
                    "prfm pldl1keep, [{ptr}]",
                    ptr = in(reg) ptr,
                    options(nostack, preserves_flags)
                );
            }
        }
    }

    #[inline]
    pub fn prefetch_quantized(&self, id: u32) {
        if let Self::RaBitQQuantized {
            quantized_data,
            code_size,
            original_count,
            ..
        } = self
        {
            let idx = id as usize;
            if idx < *original_count {
                let start = idx * code_size;
                let ptr = quantized_data[start..].as_ptr();
                #[cfg(target_arch = "x86_64")]
                unsafe {
                    std::arch::x86_64::_mm_prefetch(
                        ptr.cast::<i8>(),
                        std::arch::x86_64::_MM_HINT_T0,
                    );
                }
                #[cfg(target_arch = "aarch64")]
                unsafe {
                    std::arch::asm!(
                        "prfm pldl1keep, [{ptr}]",
                        ptr = in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[must_use]
    pub fn rabitq_code_size(&self) -> Option<usize> {
        match self {
            Self::RaBitQQuantized { code_size, .. } => Some(*code_size),
            _ => None,
        }
    }

    #[must_use]
    pub fn get_rabitq_code(&self, id: u32) -> Option<&[u8]> {
        match self {
            Self::RaBitQQuantized {
                quantized_data,
                code_size,
                original_count,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * code_size;
                let end = start + code_size;
                if end <= quantized_data.len() {
                    Some(&quantized_data[start..end])
                } else {
                    None
                }
            }
            _ => None,
        }
    }

    /// Packs the RaBitQ codes of up to 32 `neighbors` into `output` in the
    /// interleaved FastScan layout; returns how many codes were written.
    pub fn build_interleaved_codes(&self, neighbors: &[u32], output: &mut [u8]) -> usize {
        let code_size = match self.rabitq_code_size() {
            Some(cs) => cs,
            None => return 0,
        };

        let batch_size = 32;
        let expected_len = code_size * batch_size;
        if output.len() < expected_len {
            return 0;
        }

        output[..expected_len].fill(0);

        let valid_count = neighbors.len().min(batch_size);

        for (n, &neighbor_id) in neighbors.iter().take(valid_count).enumerate() {
            if let Some(code) = self.get_rabitq_code(neighbor_id) {
                for sq in 0..code_size {
                    output[sq * batch_size + n] = code[sq];
                }
            }
        }

        valid_count
    }

    fn quantize_binary(vector: &[f32], thresholds: &[f32]) -> Vec<u8> {
        debug_assert_eq!(vector.len(), thresholds.len());

        let num_bytes = vector.len().div_ceil(8);
        let mut quantized = vec![0u8; num_bytes];

        for (i, (&value, &threshold)) in vector.iter().zip(thresholds.iter()).enumerate() {
            if value >= threshold {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                quantized[byte_idx] |= 1 << bit_idx;
            }
        }

        quantized
    }

    pub fn train_quantization(&mut self, sample_vectors: &[Vec<f32>]) -> Result<(), String> {
        match self {
            Self::BinaryQuantized {
                thresholds,
                dimensions,
                ..
            } => {
                if sample_vectors.is_empty() {
                    return Err("Cannot train on empty sample".to_string());
                }

                for vec in sample_vectors {
                    if vec.len() != *dimensions {
                        return Err("Sample vector dimension mismatch".to_string());
                    }
                }

                // Use the per-dimension median as the 1-bit threshold.
                for dim in 0..*dimensions {
                    let mut values: Vec<f32> = sample_vectors.iter().map(|v| v[dim]).collect();
                    values.sort_unstable_by_key(|&x| OrderedFloat(x));

                    let median = if values.len().is_multiple_of(2) {
                        let mid = values.len() / 2;
                        f32::midpoint(values[mid - 1], values[mid])
                    } else {
                        values[values.len() / 2]
                    };

                    thresholds[dim] = median;
                }

                Ok(())
            }
            Self::FullPrecision { .. } => {
                Err("Cannot train quantization on full precision storage".to_string())
            }
            Self::RaBitQQuantized {
                quantizer, params, ..
            } => {
                if sample_vectors.is_empty() {
                    return Err("Cannot train on empty sample".to_string());
                }
                let q = quantizer.get_or_insert_with(|| RaBitQ::new(params.clone()));
                q.train_owned(sample_vectors).map_err(ToString::to_string)?;
                Ok(())
            }
            Self::ScalarQuantized {
                params,
                quantized,
                norms,
                sums,
                training_buffer,
                count,
                dimensions,
                trained,
            } => {
                if sample_vectors.is_empty() {
                    return Err("Cannot train on empty sample".to_string());
                }

                let refs: Vec<&[f32]> =
                    sample_vectors.iter().map(std::vec::Vec::as_slice).collect();
                *params = ScalarParams::train(&refs).map_err(ToString::to_string)?;
                *trained = true;

                // If vectors were buffered before training, quantize them now.
                if *count > 0 && quantized.is_empty() && !training_buffer.is_empty() {
                    let dim = *dimensions;
                    quantized.reserve(*count * dim);
                    norms.reserve(*count);
                    sums.reserve(*count);
                    for i in 0..*count {
                        let vec_slice = &training_buffer[i * dim..(i + 1) * dim];
                        let quant = params.quantize(vec_slice);
                        norms.push(quant.norm_sq);
                        sums.push(quant.sum);
                        quantized.extend(quant.data);
                    }
                    training_buffer.clear();
                    training_buffer.shrink_to_fit();
                }

                Ok(())
            }
        }
    }

    #[must_use]
    pub fn memory_usage(&self) -> usize {
        match self {
            Self::FullPrecision { vectors, norms, .. } => {
                vectors.len() * std::mem::size_of::<f32>()
                    + norms.len() * std::mem::size_of::<f32>()
            }
            Self::BinaryQuantized {
                quantized,
                original,
                thresholds,
                dimensions,
            } => {
                let quantized_size = quantized.len() * (dimensions + 7) / 8;
                let original_size = original
                    .as_ref()
                    .map_or(0, |o| o.len() * dimensions * std::mem::size_of::<f32>());
                let thresholds_size = thresholds.len() * std::mem::size_of::<f32>();
                quantized_size + original_size + thresholds_size
            }
            Self::RaBitQQuantized {
                quantized_data,
                quantized_scales,
                original,
                ..
            } => {
                let quantized_size =
                    quantized_data.len() + quantized_scales.len() * std::mem::size_of::<f32>();
                let original_size = original.len() * std::mem::size_of::<f32>();
                quantized_size + original_size
            }
            Self::ScalarQuantized {
                quantized,
                norms,
                sums,
                training_buffer,
                ..
            } => {
                let quantized_size = quantized.len();
                let norms_size = norms.len() * std::mem::size_of::<f32>();
                let sums_size = sums.len() * std::mem::size_of::<i32>();
                let buffer_size = training_buffer.len() * std::mem::size_of::<f32>();
                let params_size = 2 * std::mem::size_of::<f32>() + std::mem::size_of::<usize>();
                quantized_size + norms_size + sums_size + buffer_size + params_size
            }
        }
    }

    /// Applies an old-id -> new-id permutation (e.g. the mapping returned by
    /// [`NeighborLists::reorder_bfs`]) to the stored vectors.
    pub fn reorder(&mut self, old_to_new: &[u32]) {
        match self {
            Self::FullPrecision {
                vectors,
                norms,
                count,
                dimensions,
            } => {
                let dim = *dimensions;
                let n = *count;
                let mut new_vectors = vec![0.0f32; vectors.len()];
                let mut new_norms = vec![0.0f32; norms.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    if old_id < n {
                        let old_start = old_id * dim;
                        let new_start = new_id as usize * dim;
                        new_vectors[new_start..new_start + dim]
                            .copy_from_slice(&vectors[old_start..old_start + dim]);
                        new_norms[new_id as usize] = norms[old_id];
                    }
                }
                *vectors = new_vectors;
                *norms = new_norms;
            }
            Self::BinaryQuantized {
                quantized,
                original,
                ..
            } => {
                let mut new_quantized = vec![Vec::new(); quantized.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    new_quantized[new_id as usize] = std::mem::take(&mut quantized[old_id]);
                }
                *quantized = new_quantized;

                if let Some(orig) = original {
                    let mut new_original = vec![Vec::new(); orig.len()];
                    for (old_id, &new_id) in old_to_new.iter().enumerate() {
                        new_original[new_id as usize] = std::mem::take(&mut orig[old_id]);
                    }
                    *orig = new_original;
                }
            }
            Self::RaBitQQuantized {
                quantized_data,
                quantized_scales,
                code_size,
                original,
                original_count,
                dimensions,
                ..
            } => {
                let dim = *dimensions;
                let n = *original_count;
                let cs = *code_size;

                let mut new_data = vec![0u8; quantized_data.len()];
                let mut new_scales = vec![0.0f32; quantized_scales.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    if old_id < n {
                        let old_start = old_id * cs;
                        let new_start = new_id as usize * cs;
                        new_data[new_start..new_start + cs]
                            .copy_from_slice(&quantized_data[old_start..old_start + cs]);
                        new_scales[new_id as usize] = quantized_scales[old_id];
                    }
                }
                *quantized_data = new_data;
                *quantized_scales = new_scales;

                let mut new_original = vec![0.0f32; original.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    if old_id < n {
                        let old_start = old_id * dim;
                        let new_start = new_id as usize * dim;
                        new_original[new_start..new_start + dim]
                            .copy_from_slice(&original[old_start..old_start + dim]);
                    }
                }
                *original = new_original;
            }
            Self::ScalarQuantized {
                quantized,
                norms,
                sums,
                count,
                dimensions,
                ..
            } => {
                let dim = *dimensions;
                let n = *count;

                let mut new_quantized = vec![0u8; quantized.len()];
                let mut new_norms = vec![0.0f32; norms.len()];
                let mut new_sums = vec![0i32; sums.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    if old_id < n {
                        let old_start = old_id * dim;
                        let new_start = new_id as usize * dim;
                        new_quantized[new_start..new_start + dim]
                            .copy_from_slice(&quantized[old_start..old_start + dim]);
                        if old_id < norms.len() {
                            new_norms[new_id as usize] = norms[old_id];
                        }
                        if old_id < sums.len() {
                            new_sums[new_id as usize] = sums[old_id];
                        }
                    }
                }
                *quantized = new_quantized;
                *norms = new_norms;
                *sums = new_sums;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neighbor_lists_basic() {
        let mut lists = NeighborLists::new(8);

        lists.set_neighbors(0, 0, vec![1, 2, 3]);

        let neighbors = lists.get_neighbors(0, 0);
        assert_eq!(neighbors, &[1, 2, 3]);

        let empty = lists.get_neighbors(0, 1);
        assert_eq!(empty.len(), 0);
    }

    #[test]
    fn test_neighbor_lists_bidirectional() {
        let mut lists = NeighborLists::new(8);

        lists.add_bidirectional_link(0, 1, 0);

        assert_eq!(lists.get_neighbors(0, 0), &[1]);
        assert_eq!(lists.get_neighbors(1, 0), &[0]);
    }
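
    // Added example (a minimal sketch of the lock-based write path): the
    // `_parallel` methods require the nodes to exist already, so they are
    // seeded via `set_neighbors` first.
    #[test]
    fn test_neighbor_lists_parallel_links() {
        let mut lists = NeighborLists::new(8);
        lists.set_neighbors(0, 0, vec![]);
        lists.set_neighbors(1, 0, vec![]);

        lists.add_bidirectional_link_parallel(0, 1, 0);
        assert_eq!(lists.get_neighbors(0, 0), &[1]);
        assert_eq!(lists.get_neighbors(1, 0), &[0]);

        // Removal is one-directional: only node 0's list is touched.
        lists.remove_link_parallel(0, 1, 0);
        assert!(lists.get_neighbors(0, 0).is_empty());
        assert_eq!(lists.get_neighbors(1, 0), &[0]);
    }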

    #[test]
    fn test_vector_storage_full_precision() {
        let mut storage = VectorStorage::new_full_precision(3);

        let vec1 = vec![1.0, 2.0, 3.0];
        let vec2 = vec![4.0, 5.0, 6.0];

        let id1 = storage.insert(vec1.clone()).unwrap();
        let id2 = storage.insert(vec2.clone()).unwrap();

        assert_eq!(id1, 0);
        assert_eq!(id2, 1);
        assert_eq!(storage.len(), 2);

        assert_eq!(storage.get(0), Some(vec1.as_slice()));
        assert_eq!(storage.get(1), Some(vec2.as_slice()));
    }
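
    // Added example (illustrative): for full-precision storage the decomposed
    // form ||q||^2 + ||v||^2 - 2*(q.v) reproduces the squared L2 distance,
    // and the values chosen here are exact in f32.
    #[test]
    fn test_full_precision_decomposed_distance() {
        let mut storage = VectorStorage::new_full_precision(3);
        storage.insert(vec![1.0, 2.0, 3.0]).unwrap();
        storage.insert(vec![4.0, 5.0, 6.0]).unwrap();

        assert!(storage.supports_l2_decomposition());
        assert_eq!(storage.get_norm(0), Some(14.0)); // 1 + 4 + 9

        let query = vec![1.0, 2.0, 3.0];
        let query_norm: f32 = query.iter().map(|x| x * x).sum();

        let d0 = storage.distance_l2_decomposed(&query, query_norm, 0).unwrap();
        let d1 = storage.distance_l2_decomposed(&query, query_norm, 1).unwrap();
        assert_eq!(d0, 0.0); // distance to itself
        assert_eq!(d1, 27.0); // 3^2 + 3^2 + 3^2
    }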

    #[test]
    fn test_vector_storage_dimension_check() {
        let mut storage = VectorStorage::new_full_precision(3);

        let wrong_dim = vec![1.0, 2.0];
        assert!(storage.insert(wrong_dim).is_err());
    }

    #[test]
    fn test_binary_quantization() {
        let vector = vec![0.5, -0.3, 0.8, -0.1];
        let thresholds = vec![0.0, 0.0, 0.0, 0.0];

        let quantized = VectorStorage::quantize_binary(&vector, &thresholds);

        // Bits 0 and 2 are set (values above threshold): 0b0101 = 5.
        assert_eq!(quantized[0], 5);
    }

    #[test]
    fn test_quantization_training() {
        let mut storage = VectorStorage::new_binary_quantized(2, true);

        let samples = vec![vec![1.0, 5.0], vec![2.0, 6.0], vec![3.0, 7.0]];

        storage.train_quantization(&samples).unwrap();

        match storage {
            VectorStorage::BinaryQuantized { thresholds, .. } => {
                assert_eq!(thresholds, vec![2.0, 6.0]);
            }
            _ => panic!("Expected BinaryQuantized storage"),
        }
    }
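
    // Added example (illustrative): BFS reordering from node 2 reverses the
    // ids of this 3-node chain 0 - 1 - 2 and remaps every adjacency list.
    #[test]
    fn test_reorder_bfs_remaps_ids() {
        let mut lists = NeighborLists::new(8);
        lists.set_neighbors(0, 0, vec![1]);
        lists.set_neighbors(1, 0, vec![0, 2]);
        lists.set_neighbors(2, 0, vec![1]);

        let old_to_new = lists.reorder_bfs(2, 0);
        assert_eq!(old_to_new, vec![2, 1, 0]);

        // Old node 2 is now node 0; its sole neighbor (old 1) keeps id 1.
        assert_eq!(lists.get_neighbors(0, 0), &[1]);
        assert_eq!(lists.get_neighbors(1, 0), &[2, 0]);
        assert_eq!(lists.get_neighbors(2, 0), &[1]);
    }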

    #[test]
    fn test_rabitq_storage_insert_and_get() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);

        let vec1 = vec![1.0, 2.0, 3.0, 4.0];
        let vec2 = vec![5.0, 6.0, 7.0, 8.0];

        let id1 = storage.insert(vec1.clone()).unwrap();
        let id2 = storage.insert(vec2.clone()).unwrap();

        assert_eq!(id1, 0);
        assert_eq!(id2, 1);
        assert_eq!(storage.len(), 2);
        assert!(storage.is_asymmetric());

        assert_eq!(storage.get(0), Some(vec1.as_slice()));
        assert_eq!(storage.get(1), Some(vec2.as_slice()));
    }

    #[test]
    fn test_rabitq_asymmetric_distance() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);

        let vec1 = vec![1.0, 0.0, 0.0, 0.0];
        let vec2 = vec![0.0, 1.0, 0.0, 0.0];

        storage.insert(vec1.clone()).unwrap();
        storage.insert(vec2.clone()).unwrap();

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let dist0 = storage.distance_asymmetric_l2(&query, 0).unwrap();
        let dist1 = storage.distance_asymmetric_l2(&query, 1).unwrap();

        assert!(dist0 < 0.5, "Distance to self should be small: {dist0}");
        assert!(dist1 > dist0, "Distance to orthogonal should be larger");
    }
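
    // Added example (hedged sketch): whether an ADC table can be built for an
    // untrained quantizer depends on the RaBitQ implementation, so this only
    // asserts consistency when a table is actually returned.
    #[test]
    fn test_rabitq_adc_table_consistency() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);
        storage.insert(vec![1.0, 0.0, 0.0, 0.0]).unwrap();

        let query = vec![1.0, 0.0, 0.0, 0.0];
        if let Some(adc) = storage.build_adc_table(&query) {
            assert!(storage.distance_adc(&adc, 0).is_some());
            assert!(storage.distance_adc(&adc, 99).is_none()); // out of range
        }
    }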

    #[test]
    fn test_rabitq_get_quantized() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);

        let vec1 = vec![1.0, 2.0, 3.0, 4.0];
        storage.insert(vec1).unwrap();

        let qv = storage.get_quantized(0);
        assert!(qv.is_some());
        let qv = qv.unwrap();
        assert_eq!(qv.dimensions, 4);
        assert_eq!(qv.bits, 4);
    }

    #[test]
    fn test_binary_quantized_train_empty_sample_rejected() {
        let mut storage = VectorStorage::new_binary_quantized(4, true);
        let empty_samples: Vec<Vec<f32>> = vec![];
        let result = storage.train_quantization(&empty_samples);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty sample"));
    }

    #[test]
    fn test_binary_quantized_train_dimension_mismatch_rejected() {
        let mut storage = VectorStorage::new_binary_quantized(4, true);
        let samples = vec![vec![1.0, 2.0]];
        let result = storage.train_quantization(&samples);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("dimension mismatch"));
    }

    #[test]
    fn test_rabitq_train_empty_sample_rejected() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);
        let empty_samples: Vec<Vec<f32>> = vec![];
        let result = storage.train_quantization(&empty_samples);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty sample"));
    }

    #[test]
    fn test_sq8_train_empty_sample_rejected() {
        let mut storage = VectorStorage::new_sq8_quantized(4);
        let empty_samples: Vec<Vec<f32>> = vec![];
        let result = storage.train_quantization(&empty_samples);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty sample"));
    }
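
    // Added example (illustrative): before 256 vectors accumulate, SQ8
    // storage serves reads from its training buffer and cannot answer
    // quantized distance queries yet.
    #[test]
    fn test_sq8_untrained_reads() {
        let mut storage = VectorStorage::new_sq8_quantized(3);
        let v0 = vec![1.0, 2.0, 3.0];
        storage.insert(v0.clone()).unwrap();

        assert!(storage.is_sq8());
        assert_eq!(storage.get(0), Some(v0.as_slice()));
        assert_eq!(storage.get_dequantized(0), Some(v0));
        // Untrained: no quantized codes to compare against.
        assert!(storage.distance_asymmetric_l2(&[1.0, 2.0, 3.0], 0).is_none());
    }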

    #[test]
    fn test_neighbor_code_storage_interleaving() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(8, params);

        let vec0 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let vec1 = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let vec2 = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
        storage.insert(vec0).unwrap();
        storage.insert(vec1).unwrap();
        storage.insert(vec2).unwrap();

        let mut neighbors = NeighborLists::new(8);
        neighbors.set_neighbors(0, 0, vec![1, 2]);
        neighbors.set_neighbors(1, 0, vec![0, 2]);
        neighbors.set_neighbors(2, 0, vec![0, 1]);

        let ncs = NeighborCodeStorage::build_from_storage(&storage, &neighbors, 0);
        assert!(ncs.is_some());
        let ncs = ncs.unwrap();

        assert_eq!(ncs.len(), 3);
        assert_eq!(ncs.get_neighbor_count(0), 2);
        assert_eq!(ncs.get_neighbor_count(1), 2);
        assert_eq!(ncs.get_neighbor_count(2), 2);

        let block = ncs.get_block(0);
        assert!(block.is_some());
        let block = block.unwrap();

        // 8 dims at 4 bits/dim -> 4-byte codes, one block holds 32 of them.
        assert_eq!(block.len(), 4 * FASTSCAN_BATCH_SIZE);

        assert!(ncs.memory_usage() > 0);
    }
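
    // Added example (illustrative): `build_interleaved_codes` writes code
    // byte `sq` of neighbor `n` at `output[sq * 32 + n]`, and refuses buffers
    // smaller than one full block.
    #[test]
    fn test_build_interleaved_codes() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(8, params);
        for i in 0..3 {
            let v: Vec<f32> = (0..8).map(|j| (i * 8 + j) as f32).collect();
            storage.insert(v).unwrap();
        }

        let code_size = storage.rabitq_code_size().unwrap();
        let mut output = vec![0u8; code_size * 32];
        let written = storage.build_interleaved_codes(&[1, 2], &mut output);
        assert_eq!(written, 2);
        assert_eq!(output[0], storage.get_rabitq_code(1).unwrap()[0]);
        assert_eq!(output[1], storage.get_rabitq_code(2).unwrap()[0]);

        let mut too_small = vec![0u8; code_size * 32 - 1];
        assert_eq!(storage.build_interleaved_codes(&[1], &mut too_small), 0);
    }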

    #[test]
    fn test_neighbor_code_storage_update() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(8, params);

        for i in 0..4 {
            let v: Vec<f32> = (0..8).map(|j| (i * 8 + j) as f32).collect();
            storage.insert(v).unwrap();
        }

        let (codes, code_size) = match &storage {
            VectorStorage::RaBitQQuantized {
                quantized_data,
                code_size,
                ..
            } => (quantized_data.clone(), *code_size),
            _ => panic!("Expected RaBitQQuantized"),
        };

        let mut ncs = NeighborCodeStorage::new(code_size);
        assert!(ncs.is_empty());

        ncs.update_vertex(0, &[1, 2, 3], &codes);
        assert_eq!(ncs.len(), 1);
        assert_eq!(ncs.get_neighbor_count(0), 3);

        let block = ncs.get_block(0);
        assert!(block.is_some());

        // Updating vertex 2 grows the storage to cover vertices 0..=2.
        ncs.update_vertex(2, &[0, 1], &codes);
        assert_eq!(ncs.len(), 3);
        assert_eq!(ncs.get_neighbor_count(2), 2);
    }
}