use arc_swap::ArcSwap;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

use crate::compression::rabitq::QuantizedVector;
use crate::compression::{ADCTable, RaBitQ, RaBitQParams, ScalarParams};
use crate::distance::dot_product;

static EMPTY_NEIGHBORS: &[u32] = &[];

/// Adjacency lists for a multi-level graph, readable without locks.
#[derive(Debug)]
pub struct NeighborLists {
    /// `neighbors[node][level]` holds that node's neighbor ids. Readers load
    /// the current snapshot through `ArcSwap`; writers swap in a new boxed
    /// slice, so reads never block.
    neighbors: Vec<Vec<ArcSwap<Box<[u32]>>>>,

    /// One mutex per `(node, level)` slot, used only by the `*_parallel`
    /// methods to serialize concurrent writers.
    write_locks: Vec<Vec<Mutex<()>>>,

    /// Number of levels allocated for every node.
    max_levels: usize,

    /// Upper bound on neighbors per node (`2 * m` when built via
    /// `with_capacity`, 32 by default).
    m_max: usize,
}

impl NeighborLists {
    #[must_use]
    pub fn new(max_levels: usize) -> Self {
        Self {
            neighbors: Vec::new(),
            write_locks: Vec::new(),
            max_levels,
            m_max: 32,
        }
    }

    #[must_use]
    pub fn with_capacity(num_nodes: usize, max_levels: usize, m: usize) -> Self {
        Self {
            neighbors: Vec::with_capacity(num_nodes),
            write_locks: Vec::with_capacity(num_nodes),
            max_levels,
            m_max: m * 2,
        }
    }

    #[must_use]
    pub fn m_max(&self) -> usize {
        self.m_max
    }

    #[must_use]
    pub fn len(&self) -> usize {
        self.neighbors.len()
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.neighbors.is_empty()
    }

    /// Returns a cloned snapshot of the neighbor list (allocates a `Vec`).
    #[must_use]
    pub fn get_neighbors(&self, node_id: u32, level: u8) -> Vec<u32> {
        let node_idx = node_id as usize;
        let level_idx = level as usize;

        if node_idx >= self.neighbors.len() {
            return Vec::new();
        }

        if level_idx >= self.neighbors[node_idx].len() {
            return Vec::new();
        }

        self.neighbors[node_idx][level_idx].load().to_vec()
    }

    /// Runs `f` against the current neighbor snapshot without cloning it.
    #[inline]
    pub fn with_neighbors<F, R>(&self, node_id: u32, level: u8, f: F) -> R
    where
        F: FnOnce(&[u32]) -> R,
    {
        let node_idx = node_id as usize;
        let level_idx = level as usize;

        if node_idx >= self.neighbors.len() {
            return f(EMPTY_NEIGHBORS);
        }

        if level_idx >= self.neighbors[node_idx].len() {
            return f(EMPTY_NEIGHBORS);
        }

        let guard = self.neighbors[node_idx][level_idx].load();
        f(&guard)
    }

    #[inline]
    pub fn prefetch(&self, node_id: u32, level: u8) {
        use super::prefetch::PrefetchConfig;
        if !PrefetchConfig::enabled() {
            return;
        }

        let node_idx = node_id as usize;
        let level_idx = level as usize;

        if node_idx >= self.neighbors.len() {
            return;
        }
        if level_idx >= self.neighbors[node_idx].len() {
            return;
        }

        let ptr = &self.neighbors[node_idx][level_idx] as *const _ as *const u8;
        #[cfg(target_arch = "x86_64")]
        unsafe {
            use std::arch::x86_64::_mm_prefetch;
            use std::arch::x86_64::_MM_HINT_T0;
            _mm_prefetch(ptr.cast(), _MM_HINT_T0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            std::arch::asm!(
                "prfm pldl1keep, [{ptr}]",
                ptr = in(reg) ptr,
                options(nostack, preserves_flags)
            );
        }
    }

    fn ensure_node_exists(&mut self, node_idx: usize) {
        while self.neighbors.len() <= node_idx {
            let mut levels = Vec::with_capacity(self.max_levels);
            let mut locks = Vec::with_capacity(self.max_levels);
            for _ in 0..self.max_levels {
                levels.push(ArcSwap::from_pointee(Vec::new().into_boxed_slice()));
                locks.push(Mutex::new(()));
            }
            self.neighbors.push(levels);
            self.write_locks.push(locks);
        }
    }

    pub fn set_neighbors(&mut self, node_id: u32, level: u8, neighbors_list: Vec<u32>) {
        let node_idx = node_id as usize;
        let level_idx = level as usize;

        self.ensure_node_exists(node_idx);

        self.neighbors[node_idx][level_idx].store(Arc::new(neighbors_list.into_boxed_slice()));
    }

    /// Adds an undirected edge between `node_a` and `node_b`, growing the
    /// structure if either node does not exist yet. Requires `&mut self`; see
    /// `add_bidirectional_link_parallel` for the shared-reference variant.
    pub fn add_bidirectional_link(&mut self, node_a: u32, node_b: u32, level: u8) {
        let node_a_idx = node_a as usize;
        let node_b_idx = node_b as usize;
        let level_idx = level as usize;

        if node_a_idx == node_b_idx {
            return;
        }

        let max_idx = node_a_idx.max(node_b_idx);
        self.ensure_node_exists(max_idx);

        {
            let current = self.neighbors[node_a_idx][level_idx].load();
            if !current.contains(&node_b) {
                let mut new_list = current.to_vec();
                new_list.push(node_b);
                self.neighbors[node_a_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
            }
        }

        {
            let current = self.neighbors[node_b_idx][level_idx].load();
            if !current.contains(&node_a) {
                let mut new_list = current.to_vec();
                new_list.push(node_a);
                self.neighbors[node_b_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
            }
        }
    }

    /// Thread-safe counterpart of `add_bidirectional_link`. Takes `&self`, so
    /// it cannot grow the structure: if either node does not exist yet, this
    /// is a no-op. Locks are acquired in ascending node order so concurrent
    /// calls on overlapping pairs cannot deadlock.
    pub fn add_bidirectional_link_parallel(&self, node_a: u32, node_b: u32, level: u8) {
        let node_a_idx = node_a as usize;
        let node_b_idx = node_b as usize;
        let level_idx = level as usize;

        if node_a_idx == node_b_idx {
            return;
        }

        if node_a_idx >= self.neighbors.len() || node_b_idx >= self.neighbors.len() {
            return;
        }

        let (first_idx, second_idx, first_neighbor, second_neighbor) = if node_a_idx < node_b_idx {
            (node_a_idx, node_b_idx, node_b, node_a)
        } else {
            (node_b_idx, node_a_idx, node_a, node_b)
        };

        let _lock_first = self.write_locks[first_idx][level_idx].lock();
        let _lock_second = self.write_locks[second_idx][level_idx].lock();

        {
            let current = self.neighbors[first_idx][level_idx].load();
            if !current.contains(&first_neighbor) {
                let mut new_list = current.to_vec();
                new_list.push(first_neighbor);
                self.neighbors[first_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
            }
        }

        {
            let current = self.neighbors[second_idx][level_idx].load();
            if !current.contains(&second_neighbor) {
                let mut new_list = current.to_vec();
                new_list.push(second_neighbor);
                self.neighbors[second_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
            }
        }
    }

    /// Removes every occurrence of `node_b` from `node_a`'s neighbor list
    /// (one direction only).
    pub fn remove_link_parallel(&self, node_a: u32, node_b: u32, level: u8) {
        let node_a_idx = node_a as usize;
        let level_idx = level as usize;

        if node_a_idx >= self.neighbors.len() {
            return;
        }

        let _lock = self.write_locks[node_a_idx][level_idx].lock();
        let current = self.neighbors[node_a_idx][level_idx].load();
        let new_list: Vec<u32> = current.iter().copied().filter(|&n| n != node_b).collect();
        self.neighbors[node_a_idx][level_idx].store(Arc::new(new_list.into_boxed_slice()));
    }

    /// Thread-safe replacement of a whole neighbor list; no-op if the node
    /// does not exist yet.
    pub fn set_neighbors_parallel(&self, node_id: u32, level: u8, neighbors_list: Vec<u32>) {
        let node_idx = node_id as usize;
        let level_idx = level as usize;

        if node_idx >= self.neighbors.len() {
            return;
        }

        let _lock = self.write_locks[node_idx][level_idx].lock();
        self.neighbors[node_idx][level_idx].store(Arc::new(neighbors_list.into_boxed_slice()));
    }

    #[must_use]
    pub fn total_neighbors(&self) -> usize {
        self.neighbors
            .iter()
            .flat_map(|node| node.iter())
            .map(|level| level.load().len())
            .sum()
    }

    #[must_use]
    pub fn memory_usage(&self) -> usize {
        let mut total = 0;

        total += self.neighbors.capacity() * std::mem::size_of::<Vec<ArcSwap<Box<[u32]>>>>();

        for node in &self.neighbors {
            total += node.capacity() * std::mem::size_of::<ArcSwap<Box<[u32]>>>();

            for level in node {
                let guard = level.load();
                total += guard.len() * std::mem::size_of::<u32>();
            }
        }

        total += self.write_locks.capacity() * std::mem::size_of::<Vec<Mutex<()>>>();
        for node in &self.write_locks {
            total += node.capacity() * std::mem::size_of::<Mutex<()>>();
        }

        total
    }

    /// Renumbers nodes in BFS order from `entry_point`, rewriting all
    /// neighbor lists, and returns the old-to-new id mapping.
    pub fn reorder_bfs(&mut self, entry_point: u32, start_level: u8) -> Vec<u32> {
        use std::collections::{HashSet, VecDeque};

        let num_nodes = self.neighbors.len();
        if num_nodes == 0 {
            return Vec::new();
        }

        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();
        let mut old_to_new = vec![u32::MAX; num_nodes];
        let mut new_id = 0u32;

        queue.push_back(entry_point);
        visited.insert(entry_point);

        while let Some(node_id) = queue.pop_front() {
            old_to_new[node_id as usize] = new_id;
            new_id += 1;

            for level in (0..=start_level).rev() {
                let neighbors = self.get_neighbors(node_id, level);
                for &neighbor_id in &neighbors {
                    if visited.insert(neighbor_id) {
                        queue.push_back(neighbor_id);
                    }
                }
            }
        }

        // Nodes unreachable from the entry point keep their relative order
        // after all visited nodes.
        for mapping in &mut old_to_new {
            if *mapping == u32::MAX {
                *mapping = new_id;
                new_id += 1;
            }
        }

        let mut new_neighbors = Vec::with_capacity(num_nodes);
        let mut new_write_locks = Vec::with_capacity(num_nodes);
        for _ in 0..num_nodes {
            let mut levels = Vec::with_capacity(self.max_levels);
            let mut locks = Vec::with_capacity(self.max_levels);
            for _ in 0..self.max_levels {
                levels.push(ArcSwap::from_pointee(Vec::new().into_boxed_slice()));
                locks.push(Mutex::new(()));
            }
            new_neighbors.push(levels);
            new_write_locks.push(locks);
        }

        for old_id in 0..num_nodes {
            let new_node_id = old_to_new[old_id] as usize;
            #[allow(clippy::needless_range_loop)]
            for level in 0..self.max_levels {
                let old_neighbor_list = self.neighbors[old_id][level].load();
                let remapped: Vec<u32> = old_neighbor_list
                    .iter()
                    .map(|&old_neighbor| old_to_new[old_neighbor as usize])
                    .collect();
                new_neighbors[new_node_id][level].store(Arc::new(remapped.into_boxed_slice()));
            }
        }

        self.neighbors = new_neighbors;
        self.write_locks = new_write_locks;

        old_to_new
    }

    #[must_use]
    pub fn num_nodes(&self) -> usize {
        self.neighbors.len()
    }
}

impl Serialize for NeighborLists {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let mut state = serializer.serialize_struct("NeighborLists", 3)?;

        let neighbors_data: Vec<Vec<Vec<u32>>> = self
            .neighbors
            .iter()
            .map(|node| node.iter().map(|level| level.load().to_vec()).collect())
            .collect();

        state.serialize_field("neighbors", &neighbors_data)?;
        state.serialize_field("max_levels", &self.max_levels)?;
        state.serialize_field("m_max", &self.m_max)?;
        state.end()
    }
}

impl<'de> Deserialize<'de> for NeighborLists {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct NeighborListsData {
            neighbors: Vec<Vec<Vec<u32>>>,
            max_levels: usize,
            m_max: usize,
        }

        let data = NeighborListsData::deserialize(deserializer)?;

        let neighbors: Vec<Vec<ArcSwap<Box<[u32]>>>> = data
            .neighbors
            .iter()
            .map(|node| {
                node.iter()
                    .map(|level| ArcSwap::from_pointee(level.clone().into_boxed_slice()))
                    .collect()
            })
            .collect();

        let write_locks: Vec<Vec<Mutex<()>>> = data
            .neighbors
            .iter()
            .map(|node| node.iter().map(|_| Mutex::new(())).collect())
            .collect();

        Ok(NeighborLists {
            neighbors,
            write_locks,
            max_levels: data.max_levels,
            m_max: data.m_max,
        })
    }
}

/// Number of neighbor codes packed per FastScan block.
#[allow(dead_code)]
pub const FASTSCAN_BATCH_SIZE: usize = 32;

/// Neighbor codes stored in FastScan-interleaved order: each vertex owns a
/// fixed-size block where byte `sq` of neighbor `n`'s code lives at offset
/// `sq * FASTSCAN_BATCH_SIZE + n`.
#[allow(dead_code)]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct NeighborCodeStorage {
    /// Interleaved code bytes, one `block_size` region per vertex.
    codes: Vec<u8>,

    /// Byte offset of each vertex's block within `codes`.
    offsets: Vec<usize>,

    /// Number of valid (non-padding) neighbors in each vertex's block.
    neighbor_counts: Vec<usize>,

    /// Bytes per quantized vector code.
    code_size: usize,

    /// Bytes per vertex block (`code_size * FASTSCAN_BATCH_SIZE`).
    block_size: usize,
}

#[allow(dead_code)]
impl NeighborCodeStorage {
    #[must_use]
    pub fn new(code_size: usize) -> Self {
        Self {
            codes: Vec::new(),
            offsets: Vec::new(),
            neighbor_counts: Vec::new(),
            code_size,
            block_size: code_size * FASTSCAN_BATCH_SIZE,
        }
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.offsets.is_empty()
    }

    #[must_use]
    pub fn len(&self) -> usize {
        self.offsets.len()
    }

    #[must_use]
    #[inline]
    pub fn get_block(&self, vertex_id: u32) -> Option<&[u8]> {
        let idx = vertex_id as usize;
        if idx >= self.offsets.len() {
            return None;
        }
        let start = self.offsets[idx];
        let end = start + self.block_size;
        if end > self.codes.len() {
            return None;
        }
        Some(&self.codes[start..end])
    }

    #[must_use]
    #[inline]
    pub fn get_neighbor_count(&self, vertex_id: u32) -> usize {
        let idx = vertex_id as usize;
        if idx >= self.neighbor_counts.len() {
            return 0;
        }
        self.neighbor_counts[idx]
    }

    /// Builds interleaved neighbor code blocks for every vertex at `level`.
    /// Returns `None` unless `vectors` is RaBitQ-quantized.
    pub fn build_from_storage(
        vectors: &VectorStorage,
        neighbors: &NeighborLists,
        level: u8,
    ) -> Option<Self> {
        let (quantized_data, code_size) = match vectors {
            VectorStorage::RaBitQQuantized {
                quantized_data,
                code_size,
                ..
            } => (quantized_data, *code_size),
            _ => return None,
        };

        let num_vertices = neighbors.len();
        if num_vertices == 0 {
            return Some(Self::new(code_size));
        }

        let block_size = code_size * FASTSCAN_BATCH_SIZE;
        let mut codes = Vec::with_capacity(num_vertices * block_size);
        let mut offsets = Vec::with_capacity(num_vertices);
        let mut neighbor_counts = Vec::with_capacity(num_vertices);

        for vertex_id in 0..num_vertices {
            let offset = codes.len();
            offsets.push(offset);

            let vertex_neighbors = neighbors.get_neighbors(vertex_id as u32, level);
            let count = vertex_neighbors.len().min(FASTSCAN_BATCH_SIZE);
            neighbor_counts.push(count);

            let block_start = codes.len();
            codes.resize(block_start + block_size, 0);

            // Transpose into FastScan order: byte `sq` of neighbor `n` lands
            // at `block_start + sq * FASTSCAN_BATCH_SIZE + n`.
            for sq in 0..code_size {
                for (n, &neighbor_id) in vertex_neighbors
                    .iter()
                    .take(FASTSCAN_BATCH_SIZE)
                    .enumerate()
                {
                    let neighbor_idx = neighbor_id as usize;
                    let code_start = neighbor_idx * code_size;
                    if code_start + sq < quantized_data.len() {
                        codes[block_start + sq * FASTSCAN_BATCH_SIZE + n] =
                            quantized_data[code_start + sq];
                    }
                }
            }
        }

        Some(Self {
            codes,
            offsets,
            neighbor_counts,
            code_size,
            block_size,
        })
    }

    /// Rebuilds the interleaved block for one vertex, growing the storage as
    /// needed so that every vertex up to `vertex_id` has a block.
    pub fn update_vertex(&mut self, vertex_id: u32, new_neighbors: &[u32], quantized_data: &[u8]) {
        let idx = vertex_id as usize;

        while self.offsets.len() <= idx {
            let offset = self.codes.len();
            self.offsets.push(offset);
            self.neighbor_counts.push(0);
            self.codes.resize(self.codes.len() + self.block_size, 0);
        }

        let block_start = self.offsets[idx];
        let count = new_neighbors.len().min(FASTSCAN_BATCH_SIZE);
        self.neighbor_counts[idx] = count;

        // Clear the block before writing the new codes.
        self.codes[block_start..block_start + self.block_size].fill(0);

        for sq in 0..self.code_size {
            for (n, &neighbor_id) in new_neighbors.iter().take(FASTSCAN_BATCH_SIZE).enumerate() {
                let neighbor_idx = neighbor_id as usize;
                let code_start = neighbor_idx * self.code_size;
                if code_start + sq < quantized_data.len() {
                    self.codes[block_start + sq * FASTSCAN_BATCH_SIZE + n] =
                        quantized_data[code_start + sq];
                }
            }
        }
    }

    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.codes.len()
            + self.offsets.len() * std::mem::size_of::<usize>()
            + self.neighbor_counts.len() * std::mem::size_of::<usize>()
    }
}

/// Query-side ADC lookup table, dispatched on the storage kind.
#[derive(Clone, Debug)]
pub enum UnifiedADC {
    RaBitQ(ADCTable),
}

/// Vector storage backends with different precision/memory trade-offs.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum VectorStorage {
    /// Uncompressed f32 vectors stored contiguously.
    FullPrecision {
        vectors: Vec<f32>,
        /// Cached squared L2 norm of each vector.
        norms: Vec<f32>,
        count: usize,
        dimensions: usize,
    },

    /// One bit per dimension, split at per-dimension thresholds.
    BinaryQuantized {
        quantized: Vec<Vec<u8>>,

        /// Full-precision copies, kept only when requested at construction.
        original: Option<Vec<Vec<f32>>>,

        thresholds: Vec<f32>,

        dimensions: usize,
    },

    /// RaBitQ-quantized codes plus the full-precision originals.
    RaBitQQuantized {
        #[serde(skip)]
        quantizer: Option<RaBitQ>,

        params: RaBitQParams,

        /// Flat quantized codes, `code_size` bytes per vector.
        quantized_data: Vec<u8>,

        /// Per-vector dequantization scale.
        quantized_scales: Vec<f32>,

        code_size: usize,

        /// Flat full-precision vectors, kept for exact reads and reranking.
        original: Vec<f32>,

        original_count: usize,

        dimensions: usize,
    },

    /// 8-bit scalar quantization; buffers raw vectors until trained.
    ScalarQuantized {
        params: ScalarParams,

        /// Flat quantized codes, one byte per dimension.
        quantized: Vec<u8>,

        /// Cached squared L2 norm of each vector.
        norms: Vec<f32>,

        /// Cached sum of each vector's quantized codes (used by the
        /// decomposed distance).
        sums: Vec<i32>,

        /// Raw vectors accumulated before training; drained once trained.
        training_buffer: Vec<f32>,

        count: usize,

        dimensions: usize,

        trained: bool,
    },
}

impl VectorStorage {
    #[must_use]
    pub fn new_full_precision(dimensions: usize) -> Self {
        Self::FullPrecision {
            vectors: Vec::new(),
            norms: Vec::new(),
            count: 0,
            dimensions,
        }
    }

    #[must_use]
    pub fn new_binary_quantized(dimensions: usize, keep_original: bool) -> Self {
        Self::BinaryQuantized {
            quantized: Vec::new(),
            original: if keep_original {
                Some(Vec::new())
            } else {
                None
            },
            thresholds: vec![0.0; dimensions],
            dimensions,
        }
    }

    #[must_use]
    pub fn new_rabitq_quantized(dimensions: usize, params: RaBitQParams) -> Self {
        let values_per_byte = params.bits_per_dim.values_per_byte();
        let code_size = dimensions.div_ceil(values_per_byte);

        Self::RaBitQQuantized {
            quantizer: Some(RaBitQ::new(params.clone())),
            params,
            quantized_data: Vec::new(),
            quantized_scales: Vec::new(),
            code_size,
            original: Vec::new(),
            original_count: 0,
            dimensions,
        }
    }

    #[must_use]
    pub fn new_sq8_quantized(dimensions: usize) -> Self {
        Self::ScalarQuantized {
            params: ScalarParams::uninitialized(dimensions),
            quantized: Vec::new(),
            norms: Vec::new(),
            sums: Vec::new(),
            training_buffer: Vec::new(),
            count: 0,
            dimensions,
            trained: false,
        }
    }

    #[must_use]
    pub fn is_asymmetric(&self) -> bool {
        matches!(
            self,
            Self::RaBitQQuantized { .. } | Self::ScalarQuantized { .. }
        )
    }

    #[must_use]
    pub fn is_sq8(&self) -> bool {
        matches!(self, Self::ScalarQuantized { .. })
    }

    #[must_use]
    pub fn len(&self) -> usize {
        match self {
            Self::FullPrecision { count, .. } | Self::ScalarQuantized { count, .. } => *count,
            Self::BinaryQuantized { quantized, .. } => quantized.len(),
            Self::RaBitQQuantized { original_count, .. } => *original_count,
        }
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    #[must_use]
    pub fn dimensions(&self) -> usize {
        match self {
            Self::FullPrecision { dimensions, .. }
            | Self::BinaryQuantized { dimensions, .. }
            | Self::RaBitQQuantized { dimensions, .. }
            | Self::ScalarQuantized { dimensions, .. } => *dimensions,
        }
    }

    pub fn insert(&mut self, vector: Vec<f32>) -> Result<u32, String> {
        match self {
            Self::FullPrecision {
                vectors,
                norms,
                count,
                dimensions,
            } => {
                if vector.len() != *dimensions {
                    return Err(format!(
                        "Vector dimension mismatch: expected {}, got {}",
                        dimensions,
                        vector.len()
                    ));
                }
                let id = *count as u32;
                let norm_sq: f32 = vector.iter().map(|&x| x * x).sum();
                norms.push(norm_sq);
                vectors.extend(vector);
                *count += 1;
                Ok(id)
            }
            Self::BinaryQuantized {
                quantized,
                original,
                thresholds,
                dimensions,
            } => {
                if vector.len() != *dimensions {
                    return Err(format!(
                        "Vector dimension mismatch: expected {}, got {}",
                        dimensions,
                        vector.len()
                    ));
                }

                let quant = Self::quantize_binary(&vector, thresholds);
                let id = quantized.len() as u32;
                quantized.push(quant);

                if let Some(orig) = original {
                    orig.push(vector);
                }

                Ok(id)
            }
            Self::RaBitQQuantized {
                quantizer,
                params,
                quantized_data,
                quantized_scales,
                original,
                original_count,
                dimensions,
                ..
            } => {
                if vector.len() != *dimensions {
                    return Err(format!(
                        "Vector dimension mismatch: expected {}, got {}",
                        dimensions,
                        vector.len()
                    ));
                }

                let q = quantizer.get_or_insert_with(|| RaBitQ::new(params.clone()));

                let quant = q.quantize(&vector);
                let id = *original_count as u32;
                quantized_data.extend(&quant.data);
                quantized_scales.push(quant.scale);

                original.extend(vector);
                *original_count += 1;

                Ok(id)
            }
            Self::ScalarQuantized {
                params,
                quantized,
                norms,
                sums,
                training_buffer,
                count,
                dimensions,
                trained,
            } => {
                if vector.len() != *dimensions {
                    return Err(format!(
                        "Vector dimension mismatch: expected {}, got {}",
                        dimensions,
                        vector.len()
                    ));
                }

                let id = *count as u32;
                let dim = *dimensions;

                if *trained {
                    let quant = params.quantize(&vector);
                    norms.push(quant.norm_sq);
                    sums.push(quant.sum);
                    quantized.extend(quant.data);
                    *count += 1;
                } else {
                    // Buffer raw vectors until enough samples arrive to train.
                    training_buffer.extend(vector);
                    *count += 1;

                    if *count >= 256 {
                        // Train on the first 256 buffered vectors, then
                        // quantize the whole buffer and drop the raw copies.
                        let training_refs: Vec<&[f32]> = (0..256)
                            .map(|i| &training_buffer[i * dim..(i + 1) * dim])
                            .collect();
                        *params =
                            ScalarParams::train(&training_refs).map_err(ToString::to_string)?;
                        *trained = true;

                        quantized.reserve(*count * dim);
                        norms.reserve(*count);
                        sums.reserve(*count);
                        for i in 0..*count {
                            let vec_slice = &training_buffer[i * dim..(i + 1) * dim];
                            let quant = params.quantize(vec_slice);
                            norms.push(quant.norm_sq);
                            sums.push(quant.sum);
                            quantized.extend(quant.data);
                        }

                        training_buffer.clear();
                        training_buffer.shrink_to_fit();
                    }
                }
                Ok(id)
            }
        }
    }

    #[inline]
    #[must_use]
    pub fn get(&self, id: u32) -> Option<&[f32]> {
        match self {
            Self::FullPrecision {
                vectors,
                count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(&vectors[start..end])
            }
            Self::BinaryQuantized { original, .. } => original
                .as_ref()
                .and_then(|o| o.get(id as usize).map(std::vec::Vec::as_slice)),
            Self::RaBitQQuantized {
                original,
                original_count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(&original[start..end])
            }
            Self::ScalarQuantized {
                training_buffer,
                count,
                dimensions,
                trained,
                ..
            } => {
                if *trained {
                    // Raw vectors are dropped once training quantizes the
                    // buffer; use `get_dequantized` for an approximation.
                    return None;
                }
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(&training_buffer[start..end])
            }
        }
    }

    #[must_use]
    pub fn get_dequantized(&self, id: u32) -> Option<Vec<f32>> {
        match self {
            Self::FullPrecision {
                vectors,
                count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(vectors[start..end].to_vec())
            }
            Self::BinaryQuantized { original, .. } => {
                original.as_ref().and_then(|o| o.get(id as usize).cloned())
            }
            Self::RaBitQQuantized {
                original,
                original_count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                Some(original[start..end].to_vec())
            }
            Self::ScalarQuantized {
                params,
                quantized,
                training_buffer,
                count,
                dimensions,
                trained,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let dim = *dimensions;
                if *trained {
                    let start = idx * dim;
                    let end = start + dim;
                    Some(params.dequantize(&quantized[start..end]))
                } else {
                    let start = idx * dim;
                    let end = start + dim;
                    Some(training_buffer[start..end].to_vec())
                }
            }
        }
    }

    #[inline(always)]
    #[must_use]
    pub fn distance_asymmetric_l2(&self, query: &[f32], id: u32) -> Option<f32> {
        match self {
            Self::RaBitQQuantized {
                quantizer,
                quantized_data,
                quantized_scales,
                code_size,
                original_count,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }

                let q = quantizer.as_ref()?;

                let start = idx * code_size;
                let end = start + code_size;
                let data = &quantized_data[start..end];
                let scale = quantized_scales[idx];

                Some(q.distance_asymmetric_l2_flat(query, data, scale))
            }
            Self::ScalarQuantized {
                params,
                quantized,
                norms,
                sums,
                count,
                dimensions,
                trained,
                ..
            } => {
                if !*trained {
                    return None;
                }

                let idx = id as usize;
                if idx >= *count {
                    return None;
                }

                let start = idx * *dimensions;
                let end = start + *dimensions;
                let query_prep = params.prepare_query(query);
                Some(params.distance_l2_squared_raw(
                    &query_prep,
                    &quantized[start..end],
                    sums[idx],
                    norms[idx],
                ))
            }
            _ => None,
        }
    }

    #[inline]
    #[must_use]
    pub fn get_norm(&self, id: u32) -> Option<f32> {
        match self {
            Self::FullPrecision { norms, count, .. } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                Some(norms[idx])
            }
            _ => None,
        }
    }

    #[inline]
    #[must_use]
    pub fn supports_l2_decomposition(&self) -> bool {
        matches!(self, Self::FullPrecision { .. })
    }

    #[inline(always)]
    #[must_use]
    pub fn distance_l2_decomposed(&self, query: &[f32], query_norm: f32, id: u32) -> Option<f32> {
        match self {
            Self::FullPrecision {
                vectors,
                norms,
                count,
                dimensions,
            } => {
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                let vec = &vectors[start..end];
                let vec_norm = norms[idx];

                // ||q - v||^2 = ||q||^2 + ||v||^2 - 2 * <q, v>
                let dot = dot_product(query, vec);
                Some(query_norm + vec_norm - 2.0 * dot)
            }
            Self::ScalarQuantized {
                params,
                quantized,
                norms,
                sums,
                count,
                dimensions,
                trained,
                ..
            } => {
                if !*trained {
                    return None;
                }
                let idx = id as usize;
                if idx >= *count {
                    return None;
                }
                let start = idx * *dimensions;
                let end = start + *dimensions;
                let vec_norm = norms[idx];
                let vec_sum = sums[idx];

                let query_prep = params.prepare_query(query);
                let quantized_slice = &quantized[start..end];
                Some(params.distance_l2_squared_raw(
                    &query_prep,
                    quantized_slice,
                    vec_sum,
                    vec_norm,
                ))
            }
            _ => None,
        }
    }

    #[inline]
    #[must_use]
    pub fn get_quantized(&self, id: u32) -> Option<QuantizedVector> {
        match self {
            Self::RaBitQQuantized {
                quantized_data,
                quantized_scales,
                code_size,
                original_count,
                dimensions,
                params,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * code_size;
                let end = start + code_size;
                Some(QuantizedVector::new(
                    quantized_data[start..end].to_vec(),
                    quantized_scales[idx],
                    params.bits_per_dim.to_u8(),
                    *dimensions,
                ))
            }
            _ => None,
        }
    }

    #[must_use]
    pub fn quantizer(&self) -> Option<&RaBitQ> {
        match self {
            Self::RaBitQQuantized { quantizer, .. } => quantizer.as_ref(),
            _ => None,
        }
    }

    #[must_use]
    pub fn build_adc_table(&self, query: &[f32]) -> Option<UnifiedADC> {
        match self {
            Self::RaBitQQuantized { quantizer, .. } => {
                let q = quantizer.as_ref()?;
                Some(UnifiedADC::RaBitQ(q.build_adc_table(query)?))
            }
            _ => None,
        }
    }

    #[inline]
    #[must_use]
    pub fn distance_adc(&self, adc: &UnifiedADC, id: u32) -> Option<f32> {
        match (self, adc) {
            (
                Self::RaBitQQuantized {
                    quantized_data,
                    code_size,
                    original_count,
                    ..
                },
                UnifiedADC::RaBitQ(table),
            ) => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * code_size;
                let end = start + code_size;
                Some(table.distance(&quantized_data[start..end]))
            }
            _ => None,
        }
    }

    #[inline]
    pub fn prefetch(&self, id: u32) {
        let ptr: Option<*const u8> = match self {
            Self::FullPrecision {
                vectors,
                count,
                dimensions,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    None
                } else {
                    let start = idx * *dimensions;
                    Some(vectors[start..].as_ptr().cast())
                }
            }
            Self::BinaryQuantized { original, .. } => original
                .as_ref()
                .and_then(|o| o.get(id as usize).map(|v| v.as_ptr().cast())),
            Self::RaBitQQuantized {
                quantized_data,
                code_size,
                original_count,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    None
                } else {
                    let start = idx * code_size;
                    Some(quantized_data[start..].as_ptr())
                }
            }
            Self::ScalarQuantized {
                quantized,
                training_buffer,
                count,
                dimensions,
                trained,
                ..
            } => {
                let idx = id as usize;
                if idx >= *count {
                    None
                } else if *trained {
                    let start = idx * *dimensions;
                    Some(quantized[start..].as_ptr())
                } else {
                    let start = idx * *dimensions;
                    Some(training_buffer[start..].as_ptr().cast())
                }
            }
        };

        if let Some(ptr) = ptr {
            #[cfg(target_arch = "x86_64")]
            unsafe {
                std::arch::x86_64::_mm_prefetch(ptr.cast::<i8>(), std::arch::x86_64::_MM_HINT_T0);
            }
            #[cfg(target_arch = "aarch64")]
            unsafe {
                std::arch::asm!(
                    "prfm pldl1keep, [{ptr}]",
                    ptr = in(reg) ptr,
                    options(nostack, preserves_flags)
                );
            }
        }
    }

    #[inline]
    pub fn prefetch_quantized(&self, id: u32) {
        if let Self::RaBitQQuantized {
            quantized_data,
            code_size,
            original_count,
            ..
        } = self
        {
            let idx = id as usize;
            if idx < *original_count {
                let start = idx * code_size;
                let ptr = quantized_data[start..].as_ptr();
                #[cfg(target_arch = "x86_64")]
                unsafe {
                    std::arch::x86_64::_mm_prefetch(
                        ptr.cast::<i8>(),
                        std::arch::x86_64::_MM_HINT_T0,
                    );
                }
                #[cfg(target_arch = "aarch64")]
                unsafe {
                    std::arch::asm!(
                        "prfm pldl1keep, [{ptr}]",
                        ptr = in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[must_use]
    pub fn rabitq_code_size(&self) -> Option<usize> {
        match self {
            Self::RaBitQQuantized { code_size, .. } => Some(*code_size),
            _ => None,
        }
    }

    #[must_use]
    pub fn get_rabitq_code(&self, id: u32) -> Option<&[u8]> {
        match self {
            Self::RaBitQQuantized {
                quantized_data,
                code_size,
                original_count,
                ..
            } => {
                let idx = id as usize;
                if idx >= *original_count {
                    return None;
                }
                let start = idx * code_size;
                let end = start + code_size;
                if end <= quantized_data.len() {
                    Some(&quantized_data[start..end])
                } else {
                    None
                }
            }
            _ => None,
        }
    }

    /// Gathers the RaBitQ codes of `neighbors` into `output` in FastScan
    /// interleaved order (byte `sq` of neighbor `n` at `sq * batch_size + n`).
    /// Returns the number of neighbors written, or 0 if the storage is not
    /// RaBitQ-quantized or `output` is too small.
    pub fn build_interleaved_codes(&self, neighbors: &[u32], output: &mut [u8]) -> usize {
        let code_size = match self.rabitq_code_size() {
            Some(cs) => cs,
            None => return 0,
        };

        let batch_size = FASTSCAN_BATCH_SIZE;
        let expected_len = code_size * batch_size;
        if output.len() < expected_len {
            return 0;
        }

        output[..expected_len].fill(0);

        let valid_count = neighbors.len().min(batch_size);

        for (n, &neighbor_id) in neighbors.iter().take(valid_count).enumerate() {
            if let Some(code) = self.get_rabitq_code(neighbor_id) {
                for sq in 0..code_size {
                    output[sq * batch_size + n] = code[sq];
                }
            }
        }

        valid_count
    }

    fn quantize_binary(vector: &[f32], thresholds: &[f32]) -> Vec<u8> {
        debug_assert_eq!(vector.len(), thresholds.len());

        let num_bytes = vector.len().div_ceil(8);
        let mut quantized = vec![0u8; num_bytes];

        for (i, (&value, &threshold)) in vector.iter().zip(thresholds.iter()).enumerate() {
            if value >= threshold {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                quantized[byte_idx] |= 1 << bit_idx;
            }
        }

        quantized
    }

    pub fn train_quantization(&mut self, sample_vectors: &[Vec<f32>]) -> Result<(), String> {
        match self {
            Self::BinaryQuantized {
                thresholds,
                dimensions,
                ..
            } => {
                if sample_vectors.is_empty() {
                    return Err("Cannot train on empty sample".to_string());
                }

                for vec in sample_vectors {
                    if vec.len() != *dimensions {
                        return Err("Sample vector dimension mismatch".to_string());
                    }
                }

                // Use the per-dimension median of the sample as the 1-bit
                // threshold.
                for dim in 0..*dimensions {
                    let mut values: Vec<f32> = sample_vectors.iter().map(|v| v[dim]).collect();
                    values.sort_unstable_by_key(|&x| OrderedFloat(x));

                    let median = if values.len().is_multiple_of(2) {
                        let mid = values.len() / 2;
                        f32::midpoint(values[mid - 1], values[mid])
                    } else {
                        values[values.len() / 2]
                    };

                    thresholds[dim] = median;
                }

                Ok(())
            }
            Self::FullPrecision { .. } => {
                Err("Cannot train quantization on full precision storage".to_string())
            }
            Self::RaBitQQuantized {
                quantizer, params, ..
            } => {
                if sample_vectors.is_empty() {
                    return Err("Cannot train on empty sample".to_string());
                }
                let q = quantizer.get_or_insert_with(|| RaBitQ::new(params.clone()));
                q.train_owned(sample_vectors).map_err(ToString::to_string)?;
                Ok(())
            }
            Self::ScalarQuantized {
                params,
                quantized,
                norms,
                sums,
                training_buffer,
                count,
                dimensions,
                trained,
            } => {
                if sample_vectors.is_empty() {
                    return Err("Cannot train on empty sample".to_string());
                }

                let refs: Vec<&[f32]> =
                    sample_vectors.iter().map(std::vec::Vec::as_slice).collect();
                *params = ScalarParams::train(&refs).map_err(ToString::to_string)?;
                *trained = true;

                // Quantize anything that was buffered before training.
                if *count > 0 && quantized.is_empty() && !training_buffer.is_empty() {
                    let dim = *dimensions;
                    quantized.reserve(*count * dim);
                    norms.reserve(*count);
                    sums.reserve(*count);
                    for i in 0..*count {
                        let vec_slice = &training_buffer[i * dim..(i + 1) * dim];
                        let quant = params.quantize(vec_slice);
                        norms.push(quant.norm_sq);
                        sums.push(quant.sum);
                        quantized.extend(quant.data);
                    }
                    training_buffer.clear();
                    training_buffer.shrink_to_fit();
                }

                Ok(())
            }
        }
    }

    #[must_use]
    pub fn memory_usage(&self) -> usize {
        match self {
            Self::FullPrecision { vectors, norms, .. } => {
                vectors.len() * std::mem::size_of::<f32>()
                    + norms.len() * std::mem::size_of::<f32>()
            }
            Self::BinaryQuantized {
                quantized,
                original,
                thresholds,
                dimensions,
            } => {
                let quantized_size = quantized.len() * (dimensions + 7) / 8;
                let original_size = original
                    .as_ref()
                    .map_or(0, |o| o.len() * dimensions * std::mem::size_of::<f32>());
                let thresholds_size = thresholds.len() * std::mem::size_of::<f32>();
                quantized_size + original_size + thresholds_size
            }
            Self::RaBitQQuantized {
                quantized_data,
                quantized_scales,
                original,
                ..
            } => {
                let quantized_size =
                    quantized_data.len() + quantized_scales.len() * std::mem::size_of::<f32>();
                let original_size = original.len() * std::mem::size_of::<f32>();
                quantized_size + original_size
            }
            Self::ScalarQuantized {
                quantized,
                norms,
                sums,
                training_buffer,
                ..
            } => {
                let quantized_size = quantized.len();
                let norms_size = norms.len() * std::mem::size_of::<f32>();
                let sums_size = sums.len() * std::mem::size_of::<i32>();
                let buffer_size = training_buffer.len() * std::mem::size_of::<f32>();
                let params_size = 2 * std::mem::size_of::<f32>() + std::mem::size_of::<usize>();
                quantized_size + norms_size + sums_size + buffer_size + params_size
            }
        }
    }

    /// Applies an `old_to_new` id permutation to the stored vectors, matching
    /// the graph renumbering produced by `NeighborLists::reorder_bfs`.
    pub fn reorder(&mut self, old_to_new: &[u32]) {
        match self {
            Self::FullPrecision {
                vectors,
                norms,
                count,
                dimensions,
            } => {
                let dim = *dimensions;
                let n = *count;
                let mut new_vectors = vec![0.0f32; vectors.len()];
                let mut new_norms = vec![0.0f32; norms.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    if old_id < n {
                        let old_start = old_id * dim;
                        let new_start = new_id as usize * dim;
                        new_vectors[new_start..new_start + dim]
                            .copy_from_slice(&vectors[old_start..old_start + dim]);
                        new_norms[new_id as usize] = norms[old_id];
                    }
                }
                *vectors = new_vectors;
                *norms = new_norms;
            }
            Self::BinaryQuantized {
                quantized,
                original,
                ..
            } => {
                let mut new_quantized = vec![Vec::new(); quantized.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    new_quantized[new_id as usize] = std::mem::take(&mut quantized[old_id]);
                }
                *quantized = new_quantized;

                if let Some(orig) = original {
                    let mut new_original = vec![Vec::new(); orig.len()];
                    for (old_id, &new_id) in old_to_new.iter().enumerate() {
                        new_original[new_id as usize] = std::mem::take(&mut orig[old_id]);
                    }
                    *orig = new_original;
                }
            }
            Self::RaBitQQuantized {
                quantized_data,
                quantized_scales,
                code_size,
                original,
                original_count,
                dimensions,
                ..
            } => {
                let dim = *dimensions;
                let n = *original_count;
                let cs = *code_size;

                let mut new_data = vec![0u8; quantized_data.len()];
                let mut new_scales = vec![0.0f32; quantized_scales.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    if old_id < n {
                        let old_start = old_id * cs;
                        let new_start = new_id as usize * cs;
                        new_data[new_start..new_start + cs]
                            .copy_from_slice(&quantized_data[old_start..old_start + cs]);
                        new_scales[new_id as usize] = quantized_scales[old_id];
                    }
                }
                *quantized_data = new_data;
                *quantized_scales = new_scales;

                let mut new_original = vec![0.0f32; original.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    if old_id < n {
                        let old_start = old_id * dim;
                        let new_start = new_id as usize * dim;
                        new_original[new_start..new_start + dim]
                            .copy_from_slice(&original[old_start..old_start + dim]);
                    }
                }
                *original = new_original;
            }
            Self::ScalarQuantized {
                quantized,
                norms,
                sums,
                count,
                dimensions,
                ..
            } => {
                let dim = *dimensions;
                let n = *count;

                let mut new_quantized = vec![0u8; quantized.len()];
                let mut new_norms = vec![0.0f32; norms.len()];
                let mut new_sums = vec![0i32; sums.len()];
                for (old_id, &new_id) in old_to_new.iter().enumerate() {
                    if old_id < n {
                        let old_start = old_id * dim;
                        let new_start = new_id as usize * dim;
                        new_quantized[new_start..new_start + dim]
                            .copy_from_slice(&quantized[old_start..old_start + dim]);
                        if old_id < norms.len() {
                            new_norms[new_id as usize] = norms[old_id];
                        }
                        if old_id < sums.len() {
                            new_sums[new_id as usize] = sums[old_id];
                        }
                    }
                }
                *quantized = new_quantized;
                *norms = new_norms;
                *sums = new_sums;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neighbor_lists_basic() {
        let mut lists = NeighborLists::new(8);

        lists.set_neighbors(0, 0, vec![1, 2, 3]);

        let neighbors = lists.get_neighbors(0, 0);
        assert_eq!(neighbors, &[1, 2, 3]);

        let empty = lists.get_neighbors(0, 1);
        assert_eq!(empty.len(), 0);
    }
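
    // A small sketch of the lock-free read path and the `*_parallel` writers.
    // The parallel methods take `&self` and never grow the structure, so
    // nodes must be created through `set_neighbors` first.
    #[test]
    fn test_with_neighbors_and_parallel_ops() {
        let mut lists = NeighborLists::new(2);
        lists.set_neighbors(0, 0, vec![1]);
        lists.set_neighbors(1, 0, vec![0]);

        // Zero-copy read of the current snapshot.
        let count = lists.with_neighbors(0, 0, |n| n.len());
        assert_eq!(count, 1);

        // Out-of-bounds reads see an empty slice instead of panicking.
        assert_eq!(lists.with_neighbors(99, 0, |n| n.len()), 0);

        lists.set_neighbors_parallel(0, 0, vec![1, 1, 1]);
        assert_eq!(lists.get_neighbors(0, 0), vec![1, 1, 1]);

        // remove_link_parallel filters every occurrence of the target id.
        lists.remove_link_parallel(0, 1, 0);
        assert!(lists.get_neighbors(0, 0).is_empty());
    }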

    #[test]
    fn test_neighbor_lists_bidirectional() {
        let mut lists = NeighborLists::new(8);

        lists.add_bidirectional_link(0, 1, 0);

        assert_eq!(lists.get_neighbors(0, 0), &[1]);
        assert_eq!(lists.get_neighbors(1, 0), &[0]);
    }
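
    // Sketch of the concurrent-writer path: add_bidirectional_link_parallel
    // takes its two per-node locks in ascending index order, so links touching
    // a shared node can be added from several threads without deadlocking.
    // Assumes NeighborLists is Sync (it holds only ArcSwap and parking_lot
    // Mutex); the compiler checks that assumption here.
    #[test]
    fn test_parallel_links_from_threads() {
        let mut lists = NeighborLists::new(2);
        // Pre-create nodes 0..4; the parallel methods cannot grow the lists.
        for id in 0..4 {
            lists.set_neighbors(id, 0, Vec::new());
        }
        let lists = Arc::new(lists);

        let handles: Vec<_> = (1..4u32)
            .map(|other| {
                let lists = Arc::clone(&lists);
                std::thread::spawn(move || {
                    lists.add_bidirectional_link_parallel(0, other, 0);
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        let mut linked = lists.get_neighbors(0, 0);
        linked.sort_unstable();
        assert_eq!(linked, vec![1, 2, 3]);
        for other in 1..4u32 {
            assert_eq!(lists.get_neighbors(other, 0), vec![0]);
        }
    }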

    #[test]
    fn test_vector_storage_full_precision() {
        let mut storage = VectorStorage::new_full_precision(3);

        let vec1 = vec![1.0, 2.0, 3.0];
        let vec2 = vec![4.0, 5.0, 6.0];

        let id1 = storage.insert(vec1.clone()).unwrap();
        let id2 = storage.insert(vec2.clone()).unwrap();

        assert_eq!(id1, 0);
        assert_eq!(id2, 1);
        assert_eq!(storage.len(), 2);

        assert_eq!(storage.get(0), Some(vec1.as_slice()));
        assert_eq!(storage.get(1), Some(vec2.as_slice()));
    }
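
    // Minimal sketch of BFS reordering: the entry point becomes id 0 and all
    // neighbor lists are rewritten under the returned old-to-new mapping.
    #[test]
    fn test_reorder_bfs_remaps_ids() {
        let mut lists = NeighborLists::new(1);
        lists.set_neighbors(0, 0, vec![2]);
        lists.set_neighbors(1, 0, vec![]);
        lists.set_neighbors(2, 0, vec![0, 1]);

        let old_to_new = lists.reorder_bfs(2, 0);

        // BFS from node 2 visits 2, then 0, then 1.
        assert_eq!(old_to_new, vec![1, 2, 0]);
        // Old node 2's links [0, 1] are remapped to [1, 2] under the new ids.
        assert_eq!(lists.get_neighbors(0, 0), vec![1, 2]);
    }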

    #[test]
    fn test_vector_storage_dimension_check() {
        let mut storage = VectorStorage::new_full_precision(3);

        let wrong_dim = vec![1.0, 2.0];
        assert!(storage.insert(wrong_dim).is_err());
    }
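
    // Worked example of the decomposed L2 distance used for full precision:
    // ||q - v||^2 = ||q||^2 + ||v||^2 - 2 * <q, v>.
    #[test]
    fn test_l2_decomposition_matches_direct() {
        let mut storage = VectorStorage::new_full_precision(3);
        storage.insert(vec![1.0, 2.0, 2.0]).unwrap();
        assert!(storage.supports_l2_decomposition());

        let query = [2.0, 0.0, 1.0];
        let query_norm: f32 = query.iter().map(|&x| x * x).sum();
        let dist = storage.distance_l2_decomposed(&query, query_norm, 0).unwrap();

        // Direct computation: (2-1)^2 + (0-2)^2 + (1-2)^2 = 6.
        assert!((dist - 6.0).abs() < 1e-5);
    }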

    #[test]
    fn test_binary_quantization() {
        let vector = vec![0.5, -0.3, 0.8, -0.1];
        let thresholds = vec![0.0, 0.0, 0.0, 0.0];

        let quantized = VectorStorage::quantize_binary(&vector, &thresholds);

        // Dimensions 0 and 2 are above threshold: bits 0 and 2 set -> 0b0101.
        assert_eq!(quantized[0], 5);
    }

    #[test]
    fn test_quantization_training() {
        let mut storage = VectorStorage::new_binary_quantized(2, true);

        let samples = vec![vec![1.0, 5.0], vec![2.0, 6.0], vec![3.0, 7.0]];

        storage.train_quantization(&samples).unwrap();

        // With an odd sample count, the per-dimension median is the middle value.
        match storage {
            VectorStorage::BinaryQuantized { thresholds, .. } => {
                assert_eq!(thresholds, vec![2.0, 6.0]);
            }
            _ => panic!("Expected BinaryQuantized storage"),
        }
    }

    #[test]
    fn test_rabitq_storage_insert_and_get() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);

        let vec1 = vec![1.0, 2.0, 3.0, 4.0];
        let vec2 = vec![5.0, 6.0, 7.0, 8.0];

        let id1 = storage.insert(vec1.clone()).unwrap();
        let id2 = storage.insert(vec2.clone()).unwrap();

        assert_eq!(id1, 0);
        assert_eq!(id2, 1);
        assert_eq!(storage.len(), 2);
        assert!(storage.is_asymmetric());

        assert_eq!(storage.get(0), Some(vec1.as_slice()));
        assert_eq!(storage.get(1), Some(vec2.as_slice()));
    }

    #[test]
    fn test_rabitq_asymmetric_distance() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);

        let vec1 = vec![1.0, 0.0, 0.0, 0.0];
        let vec2 = vec![0.0, 1.0, 0.0, 0.0];

        storage.insert(vec1.clone()).unwrap();
        storage.insert(vec2.clone()).unwrap();

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let dist0 = storage.distance_asymmetric_l2(&query, 0).unwrap();
        let dist1 = storage.distance_asymmetric_l2(&query, 1).unwrap();

        assert!(dist0 < 0.5, "Distance to self should be small: {dist0}");
        assert!(dist1 > dist0, "Distance to orthogonal should be larger");
    }
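
    // Hedged sketch of the ADC path: a per-query table is built once and then
    // reused for every candidate id. Exact distance values depend on RaBitQ
    // internals, and build_adc_table may legitimately return None, so this
    // only checks consistency when a table is produced.
    #[test]
    fn test_adc_table_usage() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);
        storage.insert(vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        storage.insert(vec![0.0, 1.0, 0.0, 0.0]).unwrap();

        let query = vec![1.0, 0.0, 0.0, 0.0];
        if let Some(adc) = storage.build_adc_table(&query) {
            let d0 = storage.distance_adc(&adc, 0).unwrap();
            let d1 = storage.distance_adc(&adc, 1).unwrap();
            assert!(d0.is_finite() && d1.is_finite());
            // Ids past the stored count yield None rather than panicking.
            assert!(storage.distance_adc(&adc, 99).is_none());
        }
    }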

    #[test]
    fn test_rabitq_get_quantized() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);

        let vec1 = vec![1.0, 2.0, 3.0, 4.0];
        storage.insert(vec1).unwrap();

        let qv = storage.get_quantized(0);
        assert!(qv.is_some());
        let qv = qv.unwrap();
        assert_eq!(qv.dimensions, 4);
        assert_eq!(qv.bits, 4);
    }
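
    // Sketch of VectorStorage::reorder with a simple two-element permutation:
    // old id 0 moves to new id 1 and vice versa.
    #[test]
    fn test_full_precision_reorder() {
        let mut storage = VectorStorage::new_full_precision(2);
        storage.insert(vec![1.0, 1.0]).unwrap();
        storage.insert(vec![2.0, 2.0]).unwrap();

        storage.reorder(&[1, 0]);

        assert_eq!(storage.get(0), Some(&[2.0, 2.0][..]));
        assert_eq!(storage.get(1), Some(&[1.0, 1.0][..]));
    }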

    #[test]
    fn test_binary_quantized_train_empty_sample_rejected() {
        let mut storage = VectorStorage::new_binary_quantized(4, true);
        let empty_samples: Vec<Vec<f32>> = vec![];
        let result = storage.train_quantization(&empty_samples);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty sample"));
    }

    #[test]
    fn test_binary_quantized_train_dimension_mismatch_rejected() {
        let mut storage = VectorStorage::new_binary_quantized(4, true);
        let samples = vec![vec![1.0, 2.0]];
        let result = storage.train_quantization(&samples);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("dimension mismatch"));
    }

    #[test]
    fn test_rabitq_train_empty_sample_rejected() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(4, params);
        let empty_samples: Vec<Vec<f32>> = vec![];
        let result = storage.train_quantization(&empty_samples);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty sample"));
    }

    #[test]
    fn test_sq8_train_empty_sample_rejected() {
        let mut storage = VectorStorage::new_sq8_quantized(4);
        let empty_samples: Vec<Vec<f32>> = vec![];
        let result = storage.train_quantization(&empty_samples);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty sample"));
    }
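
    // Sketch of the SQ8 lifecycle: raw vectors are buffered until the 256th
    // insert triggers training, after which the buffer is quantized and
    // dropped. Assumes ScalarParams::train succeeds on this synthetic sample.
    #[test]
    fn test_sq8_auto_training_at_256() {
        let mut storage = VectorStorage::new_sq8_quantized(4);
        for i in 0..255 {
            let v = vec![i as f32, (i % 7) as f32, (i % 13) as f32, (i % 3) as f32];
            storage.insert(v).unwrap();
        }

        // Still buffering: raw reads work, quantized distances do not.
        assert!(storage.get(0).is_some());
        assert!(storage.distance_asymmetric_l2(&[0.0; 4], 0).is_none());

        // The 256th insert trains the quantizer and drains the buffer.
        storage.insert(vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(storage.len(), 256);
        assert!(storage.get(0).is_none());
        assert!(storage.get_dequantized(0).is_some());
        assert!(storage.distance_asymmetric_l2(&[0.0; 4], 0).is_some());
    }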

    #[test]
    fn test_neighbor_code_storage_interleaving() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(8, params);

        let vec0 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let vec1 = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let vec2 = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
        storage.insert(vec0).unwrap();
        storage.insert(vec1).unwrap();
        storage.insert(vec2).unwrap();

        let mut neighbors = NeighborLists::new(8);
        neighbors.set_neighbors(0, 0, vec![1, 2]);
        neighbors.set_neighbors(1, 0, vec![0, 2]);
        neighbors.set_neighbors(2, 0, vec![0, 1]);

        let ncs = NeighborCodeStorage::build_from_storage(&storage, &neighbors, 0);
        assert!(ncs.is_some());
        let ncs = ncs.unwrap();

        assert_eq!(ncs.len(), 3);
        assert_eq!(ncs.get_neighbor_count(0), 2);
        assert_eq!(ncs.get_neighbor_count(1), 2);
        assert_eq!(ncs.get_neighbor_count(2), 2);

        let block = ncs.get_block(0);
        assert!(block.is_some());
        let block = block.unwrap();

        // 8 dims at 4 bits -> code_size 4, so each block is 4 * 32 bytes.
        assert_eq!(block.len(), 4 * FASTSCAN_BATCH_SIZE);

        assert!(ncs.memory_usage() > 0);
    }
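
    // Hedged check of the interleaving contract documented on
    // NeighborCodeStorage: byte `sq` of neighbor `n`'s code is stored at
    // `sq * FASTSCAN_BATCH_SIZE + n` within the vertex block.
    #[test]
    fn test_neighbor_code_storage_layout_contract() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(8, params);
        for i in 0..3 {
            let v: Vec<f32> = (0..8).map(|j| (i * 8 + j) as f32).collect();
            storage.insert(v).unwrap();
        }

        let mut neighbors = NeighborLists::new(1);
        neighbors.set_neighbors(0, 0, vec![1, 2]);
        neighbors.set_neighbors(1, 0, vec![]);
        neighbors.set_neighbors(2, 0, vec![]);

        let ncs = NeighborCodeStorage::build_from_storage(&storage, &neighbors, 0).unwrap();
        let block = ncs.get_block(0).unwrap();
        let code_size = storage.rabitq_code_size().unwrap();
        let code1 = storage.get_rabitq_code(1).unwrap();
        let code2 = storage.get_rabitq_code(2).unwrap();
        for sq in 0..code_size {
            assert_eq!(block[sq * FASTSCAN_BATCH_SIZE], code1[sq]);
            assert_eq!(block[sq * FASTSCAN_BATCH_SIZE + 1], code2[sq]);
        }
    }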

    #[test]
    fn test_neighbor_code_storage_update() {
        let params = RaBitQParams::bits4();
        let mut storage = VectorStorage::new_rabitq_quantized(8, params);

        for i in 0..4 {
            let v: Vec<f32> = (0..8).map(|j| (i * 8 + j) as f32).collect();
            storage.insert(v).unwrap();
        }

        let quantized_data = match &storage {
            VectorStorage::RaBitQQuantized {
                quantized_data,
                code_size,
                ..
            } => (quantized_data.clone(), *code_size),
            _ => panic!("Expected RaBitQQuantized"),
        };

        let mut ncs = NeighborCodeStorage::new(quantized_data.1);
        assert!(ncs.is_empty());

        ncs.update_vertex(0, &[1, 2, 3], &quantized_data.0);
        assert_eq!(ncs.len(), 1);
        assert_eq!(ncs.get_neighbor_count(0), 3);

        let block = ncs.get_block(0);
        assert!(block.is_some());

        // Updating vertex 2 implicitly allocates blocks for vertices 0..=2.
        ncs.update_vertex(2, &[0, 1], &quantized_data.0);
        assert_eq!(ncs.len(), 3);
        assert_eq!(ncs.get_neighbor_count(2), 2);
    }
}