1mod incremental;
63mod filtered;
64pub mod simd;
65pub mod pq;
66pub mod storage;
67pub mod sq;
68pub mod formats;
69mod quantized;
70
71pub use quantized::{QuantizedDiskANN, QuantizedConfig};
72
73pub use incremental::{
74 IncrementalDiskANN, IncrementalConfig, IncrementalStats,
75 IncrementalQuantizedConfig, QuantizerKind,
76 is_delta_id, delta_local_idx,
77};
78
79pub use filtered::{FilteredDiskANN, Filter};
80
81pub use simd::{SimdL2, SimdDot, SimdCosine, simd_info};
82
83pub use pq::{ProductQuantizer, PQConfig, PQStats};
84
85pub use storage::Storage;
86
87pub use sq::{VectorQuantizer, F16Quantizer, Int8Quantizer};
88
89use anndists::prelude::Distance;
90use bytemuck;
91use rand::prelude::*;
92use rayon::prelude::*;
93use serde::{Deserialize, Serialize};
94use std::cmp::{Ordering, Reverse};
95use std::collections::{BinaryHeap, HashSet};
96use std::fs::OpenOptions;
97use std::io::{Read, Seek, SeekFrom, Write};
98use std::sync::Arc;
99use thiserror::Error;
100
/// Sentinel neighbor id used to pad adjacency rows up to `max_degree` entries.
///
/// Readers filter this value out before treating a row entry as a real node id.
pub(crate) const PAD_U32: u32 = u32::MAX;

/// Magic number written (little-endian) as the first four bytes of an index.
///
/// Read as big-endian ASCII the value spells `"DANN"`.
const CORE_MAGIC: u32 = 0x44414E4E;
/// Version of the header layout that follows the magic number.
const CORE_FORMAT_VERSION: u32 = 1;

/// Default maximum number of out-neighbors stored per node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width used by the searches performed while building the graph.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default `alpha` factor handed to neighbor pruning during the build.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
113
/// Build-time tuning knobs accepted by `DiskANN::build_index_with_params`.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum number of out-neighbors kept per node.
    pub max_degree: usize,
    /// Beam width of the searches run while constructing the graph.
    pub build_beam_width: usize,
    /// Pruning factor forwarded to the graph builder.
    pub alpha: f32,
}
121impl Default for DiskAnnParams {
122 fn default() -> Self {
123 Self {
124 max_degree: DISKANN_DEFAULT_MAX_DEGREE,
125 build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
126 alpha: DISKANN_DEFAULT_ALPHA,
127 }
128 }
129}
130
/// Errors produced while building, opening, or parsing an index.
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// An underlying file or memory-map operation failed.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The metadata record could not be encoded or decoded with bincode.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// The input or the index contents are invalid; the message explains why.
    #[error("Index error: {0}")]
    IndexError(String),
}
146
/// Header record serialized with bincode at the start of an index.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    /// Number of `f32` components per vector.
    dim: usize,
    /// Number of vectors (and adjacency rows) stored.
    num_vectors: usize,
    /// Number of `u32` slots in every adjacency row.
    max_degree: usize,
    /// Node id used as the search entry point.
    medoid_id: u32,
    /// Byte offset of the first vector.
    vectors_offset: u64,
    /// Byte offset of the first adjacency row.
    adjacency_offset: u64,
    /// `std::any::type_name` of the distance type the index was built with.
    distance_name: String,
}
158
/// A node id paired with its distance to the current query.
///
/// Candidates are ordered by `dist` first and by `id` second. The distance comparison uses
/// [`f32::total_cmp`], so NaN values have a well-defined position and can never corrupt a
/// `BinaryHeap`. All four comparison traits are derived from the same `cmp`, which keeps
/// `a == b`, `a.partial_cmp(&b)` and `a.cmp(&b)` mutually consistent as the `Ord` contract
/// requires.
#[derive(Clone, Copy)]
pub(crate) struct Candidate {
    /// Distance from the query to the node.
    pub dist: f32,
    /// Node id.
    pub id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // The id tie-break makes the order total over distinct candidates, so equal-distance
        // nodes are not reported as `Equal` while comparing unequal under `==`.
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
182
/// Minimal read-only interface over a graph index: sizes, an entry point, distances,
/// adjacency, and raw vectors.
#[allow(dead_code)]
pub(crate) trait GraphIndex: Send + Sync {
    /// Number of vectors in the index.
    fn num_vectors(&self) -> usize;
    /// Number of components per vector.
    fn dim(&self) -> usize;
    /// Node id at which a search should start.
    fn entry_point(&self) -> u32;
    /// Distance between `query` and the stored vector `id`.
    fn distance_to(&self, query: &[f32], id: u32) -> f32;
    /// Out-neighbors of node `id`.
    fn get_neighbors(&self, id: u32) -> Vec<u32>;
    /// Owned copy of the stored vector `id`.
    fn get_vector(&self, id: u32) -> Vec<f32>;
    /// Whether `_id` should be considered live; the default accepts every id.
    fn is_live(&self, _id: u32) -> bool {
        true
    }
}
196
197impl<D> GraphIndex for DiskANN<D>
198where
199 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
200{
201 fn num_vectors(&self) -> usize {
202 self.num_vectors
203 }
204 fn dim(&self) -> usize {
205 self.dim
206 }
207 fn entry_point(&self) -> u32 {
208 self.medoid_id
209 }
210 fn distance_to(&self, query: &[f32], id: u32) -> f32 {
211 DiskANN::distance_to(self, query, id as usize)
212 }
213 fn get_neighbors(&self, id: u32) -> Vec<u32> {
214 DiskANN::get_neighbors(self, id)
215 .iter()
216 .copied()
217 .filter(|&nb| nb != PAD_U32)
218 .collect()
219 }
220 fn get_vector(&self, id: u32) -> Vec<f32> {
221 DiskANN::get_vector(self, id as usize)
222 }
223}
224
/// Optional knobs for [`beam_search`].
///
/// `Default` leaves every field `None`, which selects the plain (unfiltered) search.
#[derive(Debug, Clone, Default)]
pub(crate) struct BeamSearchConfig {
    /// When set, switches the search into filtered mode and replaces the caller's beam width.
    pub expanded_beam: Option<usize>,
    /// Upper bound on the number of frontier expansions; `None` means unbounded.
    pub max_iterations: Option<usize>,
    /// Filtered mode only: once `k` results are held, stop when the best unexplored distance
    /// exceeds the worst kept distance multiplied by this factor.
    pub early_term_factor: Option<f32>,
}
247
248pub(crate) fn beam_search(
258 start_ids: &[u32],
259 beam_width: usize,
260 k: usize,
261 distance_fn: impl Fn(u32) -> f32,
262 neighbors_fn: impl Fn(u32) -> Vec<u32>,
263 filter_fn: impl Fn(u32) -> bool,
264 config: BeamSearchConfig,
265) -> Vec<(u32, f32)> {
266 let working_beam = config.expanded_beam.unwrap_or(beam_width);
267 let is_filtered = config.expanded_beam.is_some();
268
269 let mut visited = HashSet::new();
270 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
271 let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let mut results: Vec<(u32, f32)> = if is_filtered {
275 Vec::with_capacity(k)
276 } else {
277 Vec::new() };
279
280 for &sid in start_ids {
282 if !visited.insert(sid) {
283 continue;
284 }
285 let d = distance_fn(sid);
286 let cand = Candidate { dist: d, id: sid };
287 frontier.push(Reverse(cand));
288 w.push(cand);
289 if is_filtered && filter_fn(sid) {
290 results.push((sid, d));
291 }
292 }
293
294 let mut iterations = 0;
295 let max_iterations = config.max_iterations.unwrap_or(usize::MAX);
296 let early_term_factor = config.early_term_factor.unwrap_or(f32::MAX);
297
298 while let Some(Reverse(best)) = frontier.peek().copied() {
299 iterations += 1;
300 if iterations > max_iterations {
301 break;
302 }
303
304 if is_filtered && results.len() >= k {
306 if let Some((_, worst_dist)) = results.last() {
307 if best.dist > *worst_dist * early_term_factor {
308 break;
309 }
310 }
311 }
312
313 if w.len() >= working_beam {
315 if let Some(worst) = w.peek() {
316 if best.dist >= worst.dist {
317 break;
318 }
319 }
320 }
321
322 let Reverse(current) = frontier.pop().unwrap();
323
324 for nb in neighbors_fn(current.id) {
325 if !visited.insert(nb) {
326 continue;
327 }
328
329 let d = distance_fn(nb);
330 let cand = Candidate { dist: d, id: nb };
331
332 if w.len() < working_beam {
334 w.push(cand);
335 frontier.push(Reverse(cand));
336 } else if d < w.peek().unwrap().dist {
337 w.pop();
338 w.push(cand);
339 frontier.push(Reverse(cand));
340 }
341
342 if is_filtered && filter_fn(nb) {
344 let pos = results
345 .iter()
346 .position(|(_, dist)| d < *dist)
347 .unwrap_or(results.len());
348 if pos < k {
349 results.insert(pos, (nb, d));
350 if results.len() > k {
351 results.pop();
352 }
353 }
354 }
355 }
356 }
357
358 if is_filtered {
359 results
360 } else {
361 let mut candidates: Vec<_> = w.into_vec();
363 candidates.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
364 candidates.truncate(k);
365 candidates.into_iter().map(|c| (c.id, c.dist)).collect()
366 }
367}
368
/// A graph-based nearest-neighbor index whose vectors and fixed-width adjacency rows are read
/// from a single byte buffer (`storage`).
pub struct DiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
    /// Number of `f32` components per vector.
    pub dim: usize,
    /// Number of vectors stored in the index.
    pub num_vectors: usize,
    /// Number of `u32` slots in every adjacency row.
    pub max_degree: usize,
    /// Distance type name recorded in the index metadata.
    pub distance_name: String,

    /// Node id used as the search entry point.
    pub(crate) medoid_id: u32,
    /// Byte offset of the first vector inside `storage`.
    pub(crate) vectors_offset: u64,
    /// Byte offset of the first adjacency row inside `storage`.
    pub(crate) adjacency_offset: u64,

    /// Backing bytes of the whole index (memory-mapped, owned, or shared).
    pub(crate) storage: Storage,

    /// Distance functor used to compare a query with stored vectors.
    pub(crate) dist: D,
}
395
396impl<D> DiskANN<D>
399where
400 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
401{
402 pub fn build_index_default(
404 vectors: &[Vec<f32>],
405 dist: D,
406 file_path: &str,
407 ) -> Result<Self, DiskAnnError> {
408 Self::build_index(
409 vectors,
410 DISKANN_DEFAULT_MAX_DEGREE,
411 DISKANN_DEFAULT_BUILD_BEAM,
412 DISKANN_DEFAULT_ALPHA,
413 dist,
414 file_path,
415 )
416 }
417
418 pub fn build_index_with_params(
420 vectors: &[Vec<f32>],
421 dist: D,
422 file_path: &str,
423 p: DiskAnnParams,
424 ) -> Result<Self, DiskAnnError> {
425 Self::build_index(
426 vectors,
427 p.max_degree,
428 p.build_beam_width,
429 p.alpha,
430 dist,
431 file_path,
432 )
433 }
434}
435
436impl<D> DiskANN<D>
438where
439 D: Distance<f32> + Default + Send + Sync + Copy + Clone + 'static,
440{
441 pub fn build_index_default_metric(
443 vectors: &[Vec<f32>],
444 file_path: &str,
445 ) -> Result<Self, DiskAnnError> {
446 Self::build_index_default(vectors, D::default(), file_path)
447 }
448
449 pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
451 Self::open_index_with(path, D::default())
452 }
453}
454
455impl<D> DiskANN<D>
456where
457 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
458{
459 pub fn build_index(
469 vectors: &[Vec<f32>],
470 max_degree: usize,
471 build_beam_width: usize,
472 alpha: f32,
473 dist: D,
474 file_path: &str,
475 ) -> Result<Self, DiskAnnError> {
476 if vectors.is_empty() {
477 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
478 }
479
480 let num_vectors = vectors.len();
481 let dim = vectors[0].len();
482 for (i, v) in vectors.iter().enumerate() {
483 if v.len() != dim {
484 return Err(DiskAnnError::IndexError(format!(
485 "Vector {} has dimension {} but expected {}",
486 i,
487 v.len(),
488 dim
489 )));
490 }
491 }
492
493 let mut file = OpenOptions::new()
494 .create(true)
495 .write(true)
496 .read(true)
497 .truncate(true)
498 .open(file_path)?;
499
500 let vectors_offset = 1024 * 1024;
502 let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
503
504 file.seek(SeekFrom::Start(vectors_offset))?;
506 for vector in vectors {
507 let bytes = bytemuck::cast_slice(vector);
508 file.write_all(bytes)?;
509 }
510
511 let medoid_id = calculate_medoid(vectors, dist);
513
514 let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
516 let graph = build_vamana_graph(
517 vectors,
518 max_degree,
519 build_beam_width,
520 alpha,
521 dist,
522 medoid_id as u32,
523 );
524
525 file.seek(SeekFrom::Start(adjacency_offset))?;
527 for neighbors in &graph {
528 let mut padded = neighbors.clone();
529 padded.resize(max_degree, PAD_U32);
530 let bytes = bytemuck::cast_slice(&padded);
531 file.write_all(bytes)?;
532 }
533
534 let metadata = Metadata {
536 dim,
537 num_vectors,
538 max_degree,
539 medoid_id: medoid_id as u32,
540 vectors_offset: vectors_offset as u64,
541 adjacency_offset,
542 distance_name: std::any::type_name::<D>().to_string(),
543 };
544
545 let md_bytes = bincode::serialize(&metadata)?;
546 file.seek(SeekFrom::Start(0))?;
547 file.write_all(&CORE_MAGIC.to_le_bytes())?;
548 file.write_all(&CORE_FORMAT_VERSION.to_le_bytes())?;
549 let md_len = md_bytes.len() as u64;
550 file.write_all(&md_len.to_le_bytes())?;
551 file.write_all(&md_bytes)?;
552 file.sync_all()?;
553
554 let mmap = unsafe { memmap2::Mmap::map(&file)? };
556
557 Ok(Self {
558 dim,
559 num_vectors,
560 max_degree,
561 distance_name: metadata.distance_name,
562 medoid_id: metadata.medoid_id,
563 vectors_offset: metadata.vectors_offset,
564 adjacency_offset: metadata.adjacency_offset,
565 storage: Storage::Mmap(mmap),
566 dist,
567 })
568 }
569
570 pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
572 let mut file = OpenOptions::new().read(true).write(false).open(path)?;
573
574 let mut buf4 = [0u8; 4];
576 file.seek(SeekFrom::Start(0))?;
577 file.read_exact(&mut buf4)?;
578 let first_u32 = u32::from_le_bytes(buf4);
579
580 let md_offset = if first_u32 == CORE_MAGIC {
581 let mut ver_buf = [0u8; 4];
583 file.read_exact(&mut ver_buf)?;
584 let version = u32::from_le_bytes(ver_buf);
585 if version != CORE_FORMAT_VERSION {
586 return Err(DiskAnnError::IndexError(format!(
587 "Unsupported core format version: {}", version
588 )));
589 }
590 8u64 } else {
592 file.seek(SeekFrom::Start(0))?;
594 0u64
595 };
596
597 let mut buf8 = [0u8; 8];
599 file.seek(SeekFrom::Start(md_offset))?;
600 file.read_exact(&mut buf8)?;
601 let md_len = u64::from_le_bytes(buf8);
602
603 let file_size = file.seek(SeekFrom::End(0))?;
605 if md_len > 1024 * 1024 || md_offset + 8 + md_len > file_size {
606 return Err(DiskAnnError::IndexError(format!(
607 "Invalid metadata length {} (file size {})",
608 md_len, file_size
609 )));
610 }
611 file.seek(SeekFrom::Start(md_offset + 8))?;
612
613 let mut md_bytes = vec![0u8; md_len as usize];
615 file.read_exact(&mut md_bytes)?;
616 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
617
618 let mmap = unsafe { memmap2::Mmap::map(&file)? };
619
620 let expected = std::any::type_name::<D>();
622 if metadata.distance_name != expected {
623 eprintln!(
624 "Warning: index recorded distance `{}` but you opened with `{}`",
625 metadata.distance_name, expected
626 );
627 }
628
629 Ok(Self {
630 dim: metadata.dim,
631 num_vectors: metadata.num_vectors,
632 max_degree: metadata.max_degree,
633 distance_name: metadata.distance_name,
634 medoid_id: metadata.medoid_id,
635 vectors_offset: metadata.vectors_offset,
636 adjacency_offset: metadata.adjacency_offset,
637 storage: Storage::Mmap(mmap),
638 dist,
639 })
640 }
641
642 pub fn from_bytes(bytes: Vec<u8>, dist: D) -> Result<Self, DiskAnnError> {
644 let metadata = Self::parse_metadata(&bytes)?;
645
646 let expected = std::any::type_name::<D>();
647 if metadata.distance_name != expected {
648 eprintln!(
649 "Warning: index recorded distance `{}` but you opened with `{}`",
650 metadata.distance_name, expected
651 );
652 }
653
654 Ok(Self {
655 dim: metadata.dim,
656 num_vectors: metadata.num_vectors,
657 max_degree: metadata.max_degree,
658 distance_name: metadata.distance_name,
659 medoid_id: metadata.medoid_id,
660 vectors_offset: metadata.vectors_offset,
661 adjacency_offset: metadata.adjacency_offset,
662 storage: Storage::Owned(bytes),
663 dist,
664 })
665 }
666
667 pub fn from_shared_bytes(bytes: Arc<[u8]>, dist: D) -> Result<Self, DiskAnnError> {
669 let metadata = Self::parse_metadata(&bytes)?;
670
671 let expected = std::any::type_name::<D>();
672 if metadata.distance_name != expected {
673 eprintln!(
674 "Warning: index recorded distance `{}` but you opened with `{}`",
675 metadata.distance_name, expected
676 );
677 }
678
679 Ok(Self {
680 dim: metadata.dim,
681 num_vectors: metadata.num_vectors,
682 max_degree: metadata.max_degree,
683 distance_name: metadata.distance_name,
684 medoid_id: metadata.medoid_id,
685 vectors_offset: metadata.vectors_offset,
686 adjacency_offset: metadata.adjacency_offset,
687 storage: Storage::Shared(bytes),
688 dist,
689 })
690 }
691
692 pub fn to_bytes(&self) -> Vec<u8> {
694 self.storage.to_vec()
695 }
696
697 fn parse_metadata(bytes: &[u8]) -> Result<Metadata, DiskAnnError> {
700 if bytes.len() < 8 {
701 return Err(DiskAnnError::IndexError("Buffer too small for metadata length".into()));
702 }
703
704 let first_u32 = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
706 let md_offset = if first_u32 == CORE_MAGIC {
707 if bytes.len() < 16 {
709 return Err(DiskAnnError::IndexError("Buffer too small for header".into()));
710 }
711 let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
712 if version != CORE_FORMAT_VERSION {
713 return Err(DiskAnnError::IndexError(format!(
714 "Unsupported core format version: {}", version
715 )));
716 }
717 8
718 } else {
719 0
720 };
721
722 if bytes.len() < md_offset + 8 {
723 return Err(DiskAnnError::IndexError("Buffer too small for metadata length".into()));
724 }
725 let md_len = u64::from_le_bytes(bytes[md_offset..md_offset + 8].try_into().unwrap()) as usize;
726 if bytes.len() < md_offset + 8 + md_len {
727 return Err(DiskAnnError::IndexError("Buffer too small for metadata".into()));
728 }
729 let metadata: Metadata = bincode::deserialize(&bytes[md_offset + 8..md_offset + 8 + md_len])?;
730 Ok(metadata)
731 }
732
733 pub fn search_with_dists(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
736 assert_eq!(
737 query.len(),
738 self.dim,
739 "Query dim {} != index dim {}",
740 query.len(),
741 self.dim
742 );
743
744 beam_search(
745 &[self.medoid_id],
746 beam_width,
747 k,
748 |id| self.distance_to(query, id as usize),
749 |id| self.get_neighbors(id).iter().copied().filter(|&nb| nb != PAD_U32).collect(),
750 |_| true,
751 BeamSearchConfig::default(),
752 )
753 }
754 pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
756 self.search_with_dists(query, k, beam_width)
757 .into_iter()
758 .map(|(id, _dist)| id)
759 .collect()
760 }
761
762 pub(crate) fn get_neighbors(&self, node_id: u32) -> &[u32] {
764 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
765 let start = offset as usize;
766 let end = start + (self.max_degree * 4);
767 let bytes = &self.storage[start..end];
768 bytemuck::cast_slice(bytes)
769 }
770
771 pub(crate) fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
773 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
774 let start = offset as usize;
775 let end = start + (self.dim * 4);
776 let bytes = &self.storage[start..end];
777 let vector: &[f32] = bytemuck::cast_slice(bytes);
778 self.dist.eval(query, vector)
779 }
780
781 pub fn get_vector(&self, idx: usize) -> Vec<f32> {
783 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
784 let start = offset as usize;
785 let end = start + (self.dim * 4);
786 let bytes = &self.storage[start..end];
787 let vector: &[f32] = bytemuck::cast_slice(bytes);
788 vector.to_vec()
789 }
790}
791
792fn calculate_medoid<D: Distance<f32> + Copy + Sync>(vectors: &[Vec<f32>], dist: D) -> usize {
795 let dim = vectors[0].len();
796 let mut centroid = vec![0.0f32; dim];
797
798 for v in vectors {
799 for (i, &val) in v.iter().enumerate() {
800 centroid[i] += val;
801 }
802 }
803 for val in &mut centroid {
804 *val /= vectors.len() as f32;
805 }
806
807 let (best_idx, _best_dist) = vectors
808 .par_iter()
809 .enumerate()
810 .map(|(idx, v)| (idx, dist.eval(¢roid, v)))
811 .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
812
813 best_idx
814}
815
/// Builds the neighbor lists for every vector and returns them indexed by node id.
///
/// Steps, as implemented below:
/// 1. Give every node a random set of distinct neighbors (never itself).
/// 2. Run `PASSES` refinement passes. In each pass every node gathers candidates from its
///    current neighbors and from `greedy_search` runs seeded at the medoid plus `EXTRA_SEEDS`
///    random nodes, then prunes them with `prune_neighbors`. Afterwards each node's pruned
///    list is merged with its incoming edges and pruned again.
/// 3. Re-prune any list that still exceeds `max_degree`.
///
/// The result is not deterministic: seeds, initial neighbors and the visiting order all come
/// from `thread_rng()`.
fn build_vamana_graph<D: Distance<f32> + Copy + Sync>(
    vectors: &[Vec<f32>],
    max_degree: usize,
    build_beam_width: usize,
    alpha: f32,
    dist: D,
    medoid_id: u32,
) -> Vec<Vec<u32>> {
    let n = vectors.len();
    let mut graph = vec![Vec::<u32>::new(); n];

    {
        // Random initial graph.
        let mut rng = thread_rng();
        for i in 0..n {
            let mut s = HashSet::new();
            // At least 2 neighbors where possible, but never more than the n - 1 other nodes.
            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
            while s.len() < target {
                let nb = rng.gen_range(0..n);
                if nb != i {
                    s.insert(nb as u32);
                }
            }
            graph[i] = s.into_iter().collect();
        }
    }

    // Number of refinement passes over all nodes.
    const PASSES: usize = 2;
    // Random search seeds used per node in addition to the medoid.
    const EXTRA_SEEDS: usize = 2;

    let mut rng = thread_rng();
    for _pass in 0..PASSES {
        // Visit nodes in a fresh random order each pass.
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);

        // Every node in this pass searches the same read-only graph from the previous pass.
        let snapshot = &graph;

        // new_graph[pos] holds the pruned neighbor list of node order[pos].
        let new_graph: Vec<Vec<u32>> = order
            .par_iter()
            .map(|&u| {
                let mut candidates: Vec<(u32, f32)> =
                    Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));

                // Current neighbors are always candidates.
                for &nb in &snapshot[u] {
                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
                    candidates.push((nb, d));
                }

                let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
                seeds.push(medoid_id as usize);
                let mut trng = thread_rng();
                for _ in 0..EXTRA_SEEDS {
                    seeds.push(trng.gen_range(0..n));
                }

                // Collect what a search for u's own vector finds from each seed.
                for start in seeds {
                    let mut part = greedy_search(
                        &vectors[u],
                        vectors,
                        snapshot,
                        start,
                        build_beam_width,
                        dist,
                    );
                    candidates.append(&mut part);
                }

                // Sort by id so duplicates are adjacent, then keep one entry per id carrying
                // the smaller of the duplicate distances.
                candidates.sort_by(|a, b| a.0.cmp(&b.0));
                candidates.dedup_by(|a, b| {
                    if a.0 == b.0 {
                        if a.1 < b.1 {
                            *b = *a;
                        }
                        true
                    } else {
                        false
                    }
                });

                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
            })
            .collect();

        // Inverse of `order`: position of each node id inside `new_graph`.
        let mut pos_of = vec![0usize; n];
        for (pos, &u) in order.iter().enumerate() {
            pos_of[u] = pos;
        }

        // Incoming edges of every node, stored as one flat array plus per-node offsets.
        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);

        // Merge outgoing and incoming edges per node, then prune the union.
        graph = (0..n)
            .into_par_iter()
            .map(|u| {
                let ng = &new_graph[pos_of[u]];
                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]];
                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
                pool_ids.extend_from_slice(ng);
                pool_ids.extend_from_slice(inc);
                pool_ids.sort_unstable();
                pool_ids.dedup();

                let pool: Vec<(u32, f32)> = pool_ids
                    .into_iter()
                    .filter(|&id| id as usize != u)
                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                    .collect();

                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
            })
            .collect();
    }

    // Final guard: re-prune any list that is still longer than max_degree.
    graph
        .into_par_iter()
        .enumerate()
        .map(|(u, neigh)| {
            if neigh.len() <= max_degree {
                return neigh;
            }
            let pool: Vec<(u32, f32)> = neigh
                .iter()
                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                .collect();
            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
        })
        .collect()
}
964
965fn greedy_search<D: Distance<f32> + Copy>(
968 query: &[f32],
969 vectors: &[Vec<f32>],
970 graph: &[Vec<u32>],
971 start_id: usize,
972 beam_width: usize,
973 dist: D,
974) -> Vec<(u32, f32)> {
975 let mut visited = HashSet::new();
976 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = dist.eval(query, &vectors[start_id]);
980 let start = Candidate {
981 dist: start_dist,
982 id: start_id as u32,
983 };
984 frontier.push(Reverse(start));
985 w.push(start);
986 visited.insert(start_id as u32);
987
988 while let Some(Reverse(best)) = frontier.peek().copied() {
989 if w.len() >= beam_width {
990 if let Some(worst) = w.peek() {
991 if best.dist >= worst.dist {
992 break;
993 }
994 }
995 }
996 let Reverse(cur) = frontier.pop().unwrap();
997
998 for &nb in &graph[cur.id as usize] {
999 if !visited.insert(nb) {
1000 continue;
1001 }
1002 let d = dist.eval(query, &vectors[nb as usize]);
1003 let cand = Candidate { dist: d, id: nb };
1004
1005 if w.len() < beam_width {
1006 w.push(cand);
1007 frontier.push(Reverse(cand));
1008 } else if d < w.peek().unwrap().dist {
1009 w.pop();
1010 w.push(cand);
1011 frontier.push(Reverse(cand));
1012 }
1013 }
1014 }
1015
1016 let mut v = w.into_vec();
1017 v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
1018 v.into_iter().map(|c| (c.id, c.dist)).collect()
1019}
1020
1021fn prune_neighbors<D: Distance<f32> + Copy>(
1023 node_id: usize,
1024 candidates: &[(u32, f32)],
1025 vectors: &[Vec<f32>],
1026 max_degree: usize,
1027 alpha: f32,
1028 dist: D,
1029) -> Vec<u32> {
1030 if candidates.is_empty() {
1031 return Vec::new();
1032 }
1033
1034 let mut sorted = candidates.to_vec();
1035 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
1036
1037 let mut pruned = Vec::<u32>::new();
1038
1039 for &(cand_id, cand_dist) in &sorted {
1040 if cand_id as usize == node_id {
1041 continue;
1042 }
1043 let mut ok = true;
1044 for &sel in &pruned {
1045 let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
1046 if d < alpha * cand_dist {
1047 ok = false;
1048 break;
1049 }
1050 }
1051 if ok {
1052 pruned.push(cand_id);
1053 if pruned.len() >= max_degree {
1054 break;
1055 }
1056 }
1057 }
1058
1059 for &(cand_id, _) in &sorted {
1061 if pruned.len() >= max_degree {
1062 break;
1063 }
1064 if cand_id as usize == node_id {
1065 continue;
1066 }
1067 if !pruned.contains(&cand_id) {
1068 pruned.push(cand_id);
1069 }
1070 }
1071
1072 pruned
1073}
1074
/// Builds the reverse adjacency of `new_graph` in compressed-sparse-row form.
///
/// `new_graph[pos]` is the outgoing neighbor list of node `order[pos]`. The returned pair is
/// `(sources, offsets)`, where `sources[offsets[v]..offsets[v + 1]]` lists the nodes that have
/// an edge to `v`, in the order their rows appear in `order`. `offsets` has `n + 1` entries.
///
/// # Panics
/// Panics if `new_graph` is shorter than `order` or if a neighbor id is `>= n`.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    let rows = &new_graph[..order.len()];

    // Count in-degrees one slot to the right, then prefix-sum them into start offsets.
    let mut offsets = vec![0usize; n + 1];
    for row in rows {
        for &dst in row {
            offsets[dst as usize + 1] += 1;
        }
    }
    for i in 0..n {
        offsets[i + 1] += offsets[i];
    }

    // `cursor[v]` is the next free slot inside v's segment.
    let mut cursor = offsets.clone();
    let mut sources = vec![0u32; offsets[n]];
    for (&src, row) in order.iter().zip(rows) {
        for &dst in row {
            let slot = &mut cursor[dst as usize];
            sources[*slot] = src as u32;
            *slot += 1;
        }
    }

    (sources, offsets)
}
1100
/// Unit tests for index building, persistence, search, and the internal helpers.
///
/// Tests that build an index write a file with a test-specific name into the current working
/// directory and remove it again at the end.
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Reference Euclidean distance used to sanity-check search results.
    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    /// Builds a five-point L2 index and checks that a query returns three nearby results.
    #[test]
    fn test_small_index_l2() {
        let path = "test_small_l2.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        // The top hit must be reasonably close to the query.
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);

        let _ = fs::remove_file(path);
    }

    /// Builds a cosine index and checks that the top hit points roughly along the query.
    #[test]
    fn test_cosine() {
        let path = "test_cosine.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<DistCosine>::build_index_default(&vectors, DistCosine {}, path).unwrap();

        // Same direction as the first vector, different magnitude.
        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // Recompute the cosine similarity of the best hit by hand.
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);

        let _ = fs::remove_file(path);
    }

    /// Builds an index, drops it, reopens the file and searches the reopened index.
    #[test]
    fn test_persistence_and_open() {
        let path = "test_persist.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            // Scope the first index so it is dropped before the file is reopened.
            let _idx = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        }

        let idx2 = DiskANN::<DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        // [1.0, 1.0] (id 3) is the closest point to the query.
        assert_eq!(res[0], 3);

        let _ = fs::remove_file(path);
    }

    /// Builds a low-degree index over a 5x5 grid and checks that every grid point can be
    /// reached, or at least closely approximated, by a search for itself.
    #[test]
    fn test_grid_connectivity() {
        let path = "test_grid.db";
        let _ = fs::remove_file(path);

        let mut vectors = Vec::new();
        // 25 points on an integer grid.
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            // If the exact point is missed, the best hit must still be close.
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }

        let _ = fs::remove_file(path);
    }

    /// Builds an index over 200 random 32-dimensional vectors and checks that results come
    /// back sorted by distance.
    #[test]
    fn test_medium_random() {
        let path = "test_medium.db";
        let _ = fs::remove_file(path);

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Distances recomputed from the returned ids must already be ascending.
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);

        let _ = fs::remove_file(path);
    }

    /// An index rebuilt from `to_bytes()` output must answer queries like the original.
    #[test]
    fn test_to_bytes_from_bytes_round_trip() {
        let path = "test_bytes_rt.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let bytes = index.to_bytes();

        let index2 = DiskANN::<DistL2>::from_bytes(bytes, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors, 5);
        assert_eq!(index2.dim, 2);

        let q = vec![0.9, 0.9];
        let res1 = index.search(&q, 3, 8);
        let res2 = index2.search(&q, 3, 8);
        assert_eq!(res1, res2);

        let _ = fs::remove_file(path);
    }

    /// An index created from a shared `Arc<[u8]>` buffer must be searchable.
    #[test]
    fn test_from_shared_bytes() {
        let path = "test_shared_bytes.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let bytes = index.to_bytes();
        let shared: std::sync::Arc<[u8]> = bytes.into();

        let index2 = DiskANN::<DistL2>::from_shared_bytes(shared, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors, 4);
        assert_eq!(index2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = index2.search(&q, 2, 8);
        // [1.0, 1.0] (id 3) is the closest point to the query.
        assert_eq!(res[0], 3);
        let _ = fs::remove_file(path);
    }

    /// `Candidate` must order by distance so that `Reverse<Candidate>` yields a min-heap and
    /// a plain `Candidate` heap yields a max-heap.
    #[test]
    fn test_candidate_ordering() {
        use std::cmp::Reverse;
        use std::collections::BinaryHeap;

        let a = Candidate { dist: 1.0, id: 0 };
        let b = Candidate { dist: 2.0, id: 1 };
        let c = Candidate { dist: 0.5, id: 2 };

        // Smaller distance compares as smaller.
        assert!(a < b);
        assert!(c < a);

        // Min-heap: pops in ascending distance order.
        let mut min_heap: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
        min_heap.push(Reverse(a));
        min_heap.push(Reverse(b));
        min_heap.push(Reverse(c));
        assert_eq!(min_heap.pop().unwrap().0.id, 2);
        assert_eq!(min_heap.pop().unwrap().0.id, 0);
        assert_eq!(min_heap.pop().unwrap().0.id, 1);

        // Max-heap: the top is the farthest candidate.
        let mut max_heap: BinaryHeap<Candidate> = BinaryHeap::new();
        max_heap.push(a);
        max_heap.push(b);
        max_heap.push(c);
        assert_eq!(max_heap.peek().unwrap().id, 1);
    }

    /// Runs `beam_search` on a hand-written five-node graph and checks the two nearest nodes.
    #[test]
    fn test_beam_search_small_graph() {
        // Node positions, indexed by id.
        let positions: Vec<[f32; 2]> = vec![
            [0.0, 0.0],
            [1.0, 0.0],
            [2.0, 0.0],
            [0.0, 2.0],
            [2.0, 1.0],
        ];

        // Adjacency lists, indexed by id.
        let neighbors: Vec<Vec<u32>> = vec![
            vec![1, 3],
            vec![0, 2],
            vec![1, 4],
            vec![0, 4],
            vec![2, 3],
        ];

        // Closest to node 4, then node 2.
        let query = [2.1f32, 0.9];

        let results = beam_search(
            &[0],
            5,
            3,
            |id| {
                let p = &positions[id as usize];
                ((query[0] - p[0]).powi(2) + (query[1] - p[1]).powi(2)).sqrt()
            },
            |id| neighbors[id as usize].clone(),
            |_| true,
            BeamSearchConfig::default(),
        );

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 2);
        // Results are in ascending distance order.
        assert!(results[0].1 <= results[1].1);
        assert!(results[1].1 <= results[2].1);
    }

    /// Runs `beam_search` in filtered mode and checks that only accepted ids are returned.
    #[test]
    fn test_beam_search_with_filter() {
        let positions: Vec<[f32; 2]> = vec![
            [0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 1.0],
        ];
        let neighbors: Vec<Vec<u32>> = vec![
            vec![1, 3], vec![0, 2], vec![1, 4], vec![0, 4], vec![2, 3],
        ];

        let query = [2.1f32, 0.9];

        let results = beam_search(
            &[0],
            5,
            3,
            |id| {
                let p = &positions[id as usize];
                ((query[0] - p[0]).powi(2) + (query[1] - p[1]).powi(2)).sqrt()
            },
            |id| neighbors[id as usize].clone(),
            // Accept odd ids only.
            |id| id % 2 == 1,
            BeamSearchConfig {
                expanded_beam: Some(10),
                max_iterations: Some(20),
                early_term_factor: Some(1.5),
            },
        );

        for (id, _) in &results {
            assert!(id % 2 == 1, "Expected only odd IDs, got {}", id);
        }
        // Both odd nodes of the graph must be found.
        let ids: HashSet<u32> = results.iter().map(|(id, _)| *id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&3));
    }

    /// Checks which candidates `prune_neighbors` keeps for a small hand-written layout.
    #[test]
    fn test_prune_neighbors_alpha() {
        // Node 0 is the origin; nodes 1 and 2 lie close together on the x axis, node 3 lies
        // on the y axis.
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![1.2, 0.0],
            vec![0.0, 2.0],
        ];

        let candidates: Vec<(u32, f32)> = vec![
            (1, DistL2 {}.eval(&vectors[0], &vectors[1])),
            (2, DistL2 {}.eval(&vectors[0], &vectors[2])),
            (3, DistL2 {}.eval(&vectors[0], &vectors[3])),
        ];

        let pruned = prune_neighbors(0, &candidates, &vectors, 3, 1.0, DistL2 {});

        // The nearest candidate and the one in a different direction must both be kept.
        assert!(pruned.contains(&1));
        assert!(pruned.contains(&3));
    }

    /// `prune_neighbors` must return exactly `max_degree` ids when enough candidates exist.
    #[test]
    fn test_prune_neighbors_max_degree() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 0.0],
            vec![0.0, 2.0],
        ];

        // Every node except node 0 itself is a candidate.
        let candidates: Vec<(u32, f32)> = (1..6)
            .map(|i| (i as u32, DistL2 {}.eval(&vectors[0], &vectors[i])))
            .collect();

        let pruned = prune_neighbors(0, &candidates, &vectors, 2, 1.2, DistL2 {});
        assert_eq!(pruned.len(), 2);
        assert!(!pruned.is_empty());

        let pruned = prune_neighbors(0, &candidates, &vectors, 5, 1.2, DistL2 {});
        assert_eq!(pruned.len(), 5);

        let pruned = prune_neighbors(0, &candidates, &vectors, 1, 1.2, DistL2 {});
        assert_eq!(pruned.len(), 1);
    }

    /// The serialized index must start with the magic number followed by the format version.
    #[test]
    fn test_core_magic_number_in_bytes() {
        let path = "test_magic.db";
        let _ = fs::remove_file(path);

        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let bytes = index.to_bytes();

        // Bytes 0..4: magic number, little-endian.
        let magic = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert_eq!(magic, CORE_MAGIC, "Expected magic 0x{:08X}, got 0x{:08X}", CORE_MAGIC, magic);

        // Bytes 4..8: format version, little-endian.
        let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
        assert_eq!(version, CORE_FORMAT_VERSION);

        let _ = fs::remove_file(path);
    }
}