1use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
50const PAD_U32: u32 = u32::MAX;
52
53pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
55pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
56pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
57pub const DISKANN_DEFAULT_PASSES: usize = 1;
59pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 1;
61
62const GRAPH_SLACK_FACTOR: f32 = 1.3;
66
67
68const MICRO_BATCH_CHUNK_SIZE: usize = 256;
84
85#[derive(Clone, Copy, Debug)]
87pub struct DiskAnnParams {
88 pub max_degree: usize,
89 pub build_beam_width: usize,
90 pub alpha: f32,
91 pub passes: usize,
93 pub extra_seeds: usize,
95}
96
97impl Default for DiskAnnParams {
98 fn default() -> Self {
99 Self {
100 max_degree: DISKANN_DEFAULT_MAX_DEGREE,
101 build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
102 alpha: DISKANN_DEFAULT_ALPHA,
103 passes: DISKANN_DEFAULT_PASSES,
104 extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
105 }
106 }
107}
108
109#[derive(Debug, Error)]
111pub enum DiskAnnError {
112 #[error("I/O error: {0}")]
114 Io(#[from] std::io::Error),
115
116 #[error("Serialization error: {0}")]
118 Bincode(#[from] bincode::Error),
119
120 #[error("Index error: {0}")]
122 IndexError(String),
123}
124
125#[derive(Serialize, Deserialize, Debug)]
127struct Metadata {
128 dim: usize,
129 num_vectors: usize,
130 max_degree: usize,
131 medoid_id: u32,
132 vectors_offset: u64,
133 adjacency_offset: u64,
134 elem_size: u8,
135 distance_name: String,
136}
137
138#[derive(Clone, Copy, Debug)]
140struct Candidate {
141 dist: f32,
142 id: u32,
143}
144impl PartialEq for Candidate {
145 fn eq(&self, other: &Self) -> bool {
146 self.id == other.id && self.dist.to_bits() == other.dist.to_bits()
147 }
148}
149impl Eq for Candidate {}
150impl PartialOrd for Candidate {
151 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
152 Some(
153 self.dist
154 .total_cmp(&other.dist)
155 .then_with(|| self.id.cmp(&other.id)),
156 )
157 }
158}
159impl Ord for Candidate {
160 fn cmp(&self, other: &Self) -> Ordering {
161 self.partial_cmp(other).unwrap_or(Ordering::Equal)
162 }
163}
164
165#[derive(Clone, Debug)]
169struct FlatVectors<T> {
170 data: Vec<T>,
171 dim: usize,
172 n: usize,
173}
174
175impl<T: Copy> FlatVectors<T> {
176 fn from_vecs(vectors: &[Vec<T>]) -> Result<Self, DiskAnnError> {
177 if vectors.is_empty() {
178 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
179 }
180 let dim = vectors[0].len();
181 for (i, v) in vectors.iter().enumerate() {
182 if v.len() != dim {
183 return Err(DiskAnnError::IndexError(format!(
184 "Vector {} has dimension {} but expected {}",
185 i,
186 v.len(),
187 dim
188 )));
189 }
190 }
191
192 let n = vectors.len();
193 let mut data = Vec::with_capacity(n * dim);
194 for v in vectors {
195 data.extend_from_slice(v);
196 }
197
198 Ok(Self { data, dim, n })
199 }
200
201 #[inline]
202 fn row(&self, idx: usize) -> &[T] {
203 let start = idx * self.dim;
204 let end = start + self.dim;
205 &self.data[start..end]
206 }
207}
208
209#[derive(Default, Debug)]
220struct OrderedBeam {
221 items: Vec<Candidate>,
222}
223
224impl OrderedBeam {
225 #[inline]
226 fn clear(&mut self) {
227 self.items.clear();
228 }
229
230 #[inline]
231 fn len(&self) -> usize {
232 self.items.len()
233 }
234
235 #[inline]
236 fn is_empty(&self) -> bool {
237 self.items.is_empty()
238 }
239
240 #[inline]
241 fn best(&self) -> Option<Candidate> {
242 self.items.last().copied()
243 }
244
245 #[inline]
246 fn worst(&self) -> Option<Candidate> {
247 self.items.first().copied()
248 }
249
250 #[inline]
251 fn pop_best(&mut self) -> Option<Candidate> {
252 self.items.pop()
253 }
254
255 #[inline]
256 fn reserve(&mut self, cap: usize) {
257 if self.items.capacity() < cap {
258 self.items.reserve(cap - self.items.capacity());
259 }
260 }
261
262 #[inline]
263 fn insert_unbounded(&mut self, cand: Candidate) {
264 let pos = self.items.partition_point(|x| {
265 x.dist > cand.dist || (x.dist.to_bits() == cand.dist.to_bits() && x.id > cand.id)
266 });
267 self.items.insert(pos, cand);
268 }
269
270 #[inline]
271 fn insert_capped(&mut self, cand: Candidate, cap: usize) {
272 if cap == 0 {
273 return;
274 }
275
276 if self.items.len() < cap {
277 self.insert_unbounded(cand);
278 return;
279 }
280
281 let worst = self.items[0];
283 if cand.dist >= worst.dist {
284 return;
285 }
286
287 self.insert_unbounded(cand);
288
289 if self.items.len() > cap {
290 self.items.remove(0);
291 }
292 }
293}
294
295#[derive(Debug)]
299struct BuildScratch {
300 marks: Vec<u32>,
301 epoch: u32,
302
303 visited_ids: Vec<u32>,
304 visited_dists: Vec<f32>,
305
306 frontier: OrderedBeam,
307 work: OrderedBeam,
308
309 seeds: Vec<usize>,
310 candidates: Vec<(u32, f32)>,
311}
312
313impl BuildScratch {
314 fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
315 Self {
316 marks: vec![0u32; n],
317 epoch: 1,
318 visited_ids: Vec::with_capacity(beam_width * 4),
319 visited_dists: Vec::with_capacity(beam_width * 4),
320 frontier: {
321 let mut b = OrderedBeam::default();
322 b.reserve(beam_width * 2);
323 b
324 },
325 work: {
326 let mut b = OrderedBeam::default();
327 b.reserve(beam_width * 2);
328 b
329 },
330 seeds: Vec::with_capacity(1 + extra_seeds),
331 candidates: Vec::with_capacity(beam_width * (4 + extra_seeds) + max_degree * 2),
332 }
333 }
334
335 #[inline]
336 fn reset_search(&mut self) {
337 self.epoch = self.epoch.wrapping_add(1);
338 if self.epoch == 0 {
339 self.marks.fill(0);
340 self.epoch = 1;
341 }
342 self.visited_ids.clear();
343 self.visited_dists.clear();
344 self.frontier.clear();
345 self.work.clear();
346 }
347
348 #[inline]
349 fn is_marked(&self, idx: usize) -> bool {
350 self.marks[idx] == self.epoch
351 }
352
353 #[inline]
354 fn mark_with_dist(&mut self, idx: usize, dist: f32) {
355 self.marks[idx] = self.epoch;
356 self.visited_ids.push(idx as u32);
357 self.visited_dists.push(dist);
358 }
359}
360
361#[derive(Debug)]
362struct IncrementalInsertScratch {
363 build: BuildScratch,
364}
365
366impl IncrementalInsertScratch {
367 fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
368 Self {
369 build: BuildScratch::new(n, beam_width, max_degree, extra_seeds),
370 }
371 }
372}
373
374pub struct DiskANN<T, D>
376where
377 T: bytemuck::Pod + Copy + Send + Sync + 'static,
378 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
379{
380 pub dim: usize,
382 pub num_vectors: usize,
384 pub max_degree: usize,
386 pub distance_name: String,
388
389 medoid_id: u32,
391 vectors_offset: u64,
393 adjacency_offset: u64,
394
395 mmap: Mmap,
397
398 dist: D,
400
401 _phantom: PhantomData<T>,
403}
404
405impl<T, D> DiskANN<T, D>
408where
409 T: bytemuck::Pod + Copy + Send + Sync + 'static,
410 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
411{
412 pub fn build_index_default(
414 vectors: &[Vec<T>],
415 dist: D,
416 file_path: &str,
417 ) -> Result<Self, DiskAnnError> {
418 Self::build_index(
419 vectors,
420 DISKANN_DEFAULT_MAX_DEGREE,
421 DISKANN_DEFAULT_BUILD_BEAM,
422 DISKANN_DEFAULT_ALPHA,
423 DISKANN_DEFAULT_PASSES,
424 DISKANN_DEFAULT_EXTRA_SEEDS,
425 dist,
426 file_path,
427 )
428 }
429
430 pub fn build_index_with_params(
432 vectors: &[Vec<T>],
433 dist: D,
434 file_path: &str,
435 p: DiskAnnParams,
436 ) -> Result<Self, DiskAnnError> {
437 Self::build_index(
438 vectors,
439 p.max_degree,
440 p.build_beam_width,
441 p.alpha,
442 p.passes,
443 p.extra_seeds,
444 dist,
445 file_path,
446 )
447 }
448
449 pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
451 let mut file = OpenOptions::new().read(true).write(false).open(path)?;
452
453 let mut buf8 = [0u8; 8];
455 file.seek(SeekFrom::Start(0))?;
456 file.read_exact(&mut buf8)?;
457 let md_len = u64::from_le_bytes(buf8);
458
459 let mut md_bytes = vec![0u8; md_len as usize];
461 file.read_exact(&mut md_bytes)?;
462 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
463
464 let mmap = unsafe { memmap2::Mmap::map(&file)? };
465
466 let want = std::mem::size_of::<T>() as u8;
468 if metadata.elem_size != want {
469 return Err(DiskAnnError::IndexError(format!(
470 "element size mismatch: file has {}B, T is {}B",
471 metadata.elem_size, want
472 )));
473 }
474
475 let expected = std::any::type_name::<D>();
477 if metadata.distance_name != expected {
478 eprintln!(
479 "Warning: index recorded distance `{}` but you opened with `{}`",
480 metadata.distance_name, expected
481 );
482 }
483
484 Ok(Self {
485 dim: metadata.dim,
486 num_vectors: metadata.num_vectors,
487 max_degree: metadata.max_degree,
488 distance_name: metadata.distance_name,
489 medoid_id: metadata.medoid_id,
490 vectors_offset: metadata.vectors_offset,
491 adjacency_offset: metadata.adjacency_offset,
492 mmap,
493 dist,
494 _phantom: PhantomData,
495 })
496 }
497}
498
499impl<T, D> DiskANN<T, D>
501where
502 T: bytemuck::Pod + Copy + Send + Sync + 'static,
503 D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
504{
505 pub fn build_index_default_metric(
507 vectors: &[Vec<T>],
508 file_path: &str,
509 ) -> Result<Self, DiskAnnError> {
510 Self::build_index_default(vectors, D::default(), file_path)
511 }
512
513 pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
515 Self::open_index_with(path, D::default())
516 }
517}
518
519impl<T, D> DiskANN<T, D>
520where
521 T: bytemuck::Pod + Copy + Send + Sync + 'static,
522 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
523{
524 pub fn build_index(
536 vectors: &[Vec<T>],
537 max_degree: usize,
538 build_beam_width: usize,
539 alpha: f32,
540 passes: usize,
541 extra_seeds: usize,
542 dist: D,
543 file_path: &str,
544 ) -> Result<Self, DiskAnnError> {
545 let flat = FlatVectors::from_vecs(vectors)?;
546
547 let num_vectors = flat.n;
548 let dim = flat.dim;
549
550 let mut file = OpenOptions::new()
551 .create(true)
552 .write(true)
553 .read(true)
554 .truncate(true)
555 .open(file_path)?;
556
557 let vectors_offset = 1024 * 1024;
559 assert_eq!(
560 (vectors_offset as usize) % std::mem::align_of::<T>(),
561 0,
562 "vectors_offset must be aligned for T"
563 );
564
565 let elem_sz = std::mem::size_of::<T>() as u64;
566 let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
567
568 file.seek(SeekFrom::Start(vectors_offset as u64))?;
570 file.write_all(bytemuck::cast_slice::<T, u8>(&flat.data))?;
571
572 let medoid_id = calculate_medoid(&flat, dist);
574
575 let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
577 let graph = build_vamana_graph(
578 &flat,
579 max_degree,
580 build_beam_width,
581 alpha,
582 passes,
583 extra_seeds,
584 dist,
585 medoid_id as u32,
586 );
587
588 file.seek(SeekFrom::Start(adjacency_offset))?;
590 for neighbors in &graph {
591 let mut padded = neighbors.clone();
592 padded.resize(max_degree, PAD_U32);
593 let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
594 file.write_all(bytes)?;
595 }
596
597 let metadata = Metadata {
599 dim,
600 num_vectors,
601 max_degree,
602 medoid_id: medoid_id as u32,
603 vectors_offset: vectors_offset as u64,
604 adjacency_offset,
605 elem_size: std::mem::size_of::<T>() as u8,
606 distance_name: std::any::type_name::<D>().to_string(),
607 };
608
609 let md_bytes = bincode::serialize(&metadata)?;
610 file.seek(SeekFrom::Start(0))?;
611 let md_len = md_bytes.len() as u64;
612 file.write_all(&md_len.to_le_bytes())?;
613 file.write_all(&md_bytes)?;
614 file.sync_all()?;
615
616 let mmap = unsafe { memmap2::Mmap::map(&file)? };
618
619 Ok(Self {
620 dim,
621 num_vectors,
622 max_degree,
623 distance_name: metadata.distance_name,
624 medoid_id: metadata.medoid_id,
625 vectors_offset: metadata.vectors_offset,
626 adjacency_offset: metadata.adjacency_offset,
627 mmap,
628 dist,
629 _phantom: PhantomData,
630 })
631 }
632
633 pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
636 assert_eq!(
637 query.len(),
638 self.dim,
639 "Query dim {} != index dim {}",
640 query.len(),
641 self.dim
642 );
643
644 let mut visited = HashSet::new();
645 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
646 let mut w: BinaryHeap<Candidate> = BinaryHeap::new();
647
648 let start_dist = self.distance_to(query, self.medoid_id as usize);
649 let start = Candidate {
650 dist: start_dist,
651 id: self.medoid_id,
652 };
653 frontier.push(Reverse(start));
654 w.push(start);
655 visited.insert(self.medoid_id);
656
657 while let Some(Reverse(best)) = frontier.peek().copied() {
658 if w.len() >= beam_width {
659 if let Some(worst) = w.peek() {
660 if best.dist >= worst.dist {
661 break;
662 }
663 }
664 }
665 let Reverse(current) = frontier.pop().unwrap();
666
667 for &nb in self.get_neighbors(current.id) {
668 if nb == PAD_U32 {
669 continue;
670 }
671 if !visited.insert(nb) {
672 continue;
673 }
674
675 let d = self.distance_to(query, nb as usize);
676 let cand = Candidate { dist: d, id: nb };
677
678 if w.len() < beam_width {
679 w.push(cand);
680 frontier.push(Reverse(cand));
681 } else if d < w.peek().unwrap().dist {
682 w.pop();
683 w.push(cand);
684 frontier.push(Reverse(cand));
685 }
686 }
687 }
688
689 let mut results: Vec<_> = w.into_vec();
690 results.sort_by(|a, b| a.dist.total_cmp(&b.dist));
691 results.truncate(k);
692 results.into_iter().map(|c| (c.id, c.dist)).collect()
693 }
694
695 pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
697 self.search_with_dists(query, k, beam_width)
698 .into_iter()
699 .map(|(id, _dist)| id)
700 .collect()
701 }
702
703 fn get_neighbors(&self, node_id: u32) -> &[u32] {
705 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
706 let start = offset as usize;
707 let end = start + (self.max_degree * 4);
708 let bytes = &self.mmap[start..end];
709 bytemuck::cast_slice(bytes)
710 }
711
712 fn distance_to(&self, query: &[T], idx: usize) -> f32 {
714 let elem_sz = std::mem::size_of::<T>();
715 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
716 let start = offset as usize;
717 let end = start + (self.dim * elem_sz);
718 let bytes = &self.mmap[start..end];
719 let vector: &[T] = bytemuck::cast_slice(bytes);
720 self.dist.eval(query, vector)
721 }
722
723 pub fn get_vector(&self, idx: usize) -> Vec<T> {
725 let elem_sz = std::mem::size_of::<T>();
726 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
727 let start = offset as usize;
728 let end = start + (self.dim * elem_sz);
729 let bytes = &self.mmap[start..end];
730 let vector: &[T] = bytemuck::cast_slice(bytes);
731 vector.to_vec()
732 }
733}
734
735fn calculate_medoid<T, D>(vectors: &FlatVectors<T>, dist: D) -> usize
737where
738 T: bytemuck::Pod + Copy + Send + Sync,
739 D: Distance<T> + Copy + Sync,
740{
741 let n = vectors.n;
742 let k = 8.min(n);
743 let mut rng = thread_rng();
744 let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
745
746 let (best_idx, _best_score) = (0..n)
747 .into_par_iter()
748 .map(|i| {
749 let vi = vectors.row(i);
750 let score: f32 = pivots.iter().map(|&p| dist.eval(vi, vectors.row(p))).sum();
751 (i, score)
752 })
753 .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
754
755 best_idx
756}
757
758fn dedup_keep_best_by_id_in_place(cands: &mut Vec<(u32, f32)>) {
759 if cands.is_empty() {
760 return;
761 }
762
763 cands.sort_by(|a, b| {
764 a.0.cmp(&b.0)
765 .then_with(|| a.1.total_cmp(&b.1))
766 });
767
768 let mut write = 0usize;
769 for read in 0..cands.len() {
770 if write == 0 || cands[read].0 != cands[write - 1].0 {
771 cands[write] = cands[read];
772 write += 1;
773 }
774 }
775 cands.truncate(write);
776}
777
778fn merge_chunk_updates_into_graph_reuse<T, D>(
785 graph: &mut [Vec<u32>],
786 chunk_nodes: &[usize],
787 chunk_pruned: &[Vec<u32>],
788 vectors: &FlatVectors<T>,
789 max_degree: usize,
790 slack_limit: usize,
791 alpha: f32,
792 dist: D,
793 merge: &mut MergeScratch,
794) where
795 T: bytemuck::Pod + Copy + Send + Sync,
796 D: Distance<T> + Copy + Sync,
797{
798 merge.reset();
799
800 for &u in chunk_nodes {
802 merge.mark_affected(u);
803 }
804
805 let mut total_incoming = 0usize;
807
808 for (local_idx, &u) in chunk_nodes.iter().enumerate() {
809 for &dst in &chunk_pruned[local_idx] {
810 let dst_usize = dst as usize;
811 if dst_usize == u {
812 continue;
813 }
814
815 merge.mark_affected(dst_usize);
816 merge.incoming_counts[dst_usize] += 1;
817 total_incoming += 1;
818 }
819 }
820
821 merge.affected_nodes.sort_unstable();
823
824 let mut running = 0usize;
825 for &u in &merge.affected_nodes {
826 merge.incoming_offsets[u] = running;
827 running += merge.incoming_counts[u];
828 merge.incoming_offsets[u + 1] = running;
829 }
830
831 merge.incoming_flat.resize(total_incoming, PAD_U32);
832
833 for &u in &merge.affected_nodes {
835 merge.incoming_write[u] = merge.incoming_offsets[u];
836 }
837
838 for (local_idx, &u) in chunk_nodes.iter().enumerate() {
840 for &dst in &chunk_pruned[local_idx] {
841 let dst_usize = dst as usize;
842 if dst_usize == u {
843 continue;
844 }
845
846 let pos = merge.incoming_write[dst_usize];
847 merge.incoming_flat[pos] = u as u32;
848 merge.incoming_write[dst_usize] += 1;
849 }
850 }
851
852 for (local_idx, &u) in chunk_nodes.iter().enumerate() {
854 graph[u] = chunk_pruned[local_idx].clone();
855 }
856
857 let affected = merge.affected_nodes.clone();
860
861 let updated_pairs: Vec<(usize, Vec<u32>)> = affected
862 .into_par_iter()
863 .map(|u| {
864 let start = merge.incoming_offsets[u];
865 let end = merge.incoming_offsets[u + 1];
866
867 let mut ids: Vec<u32> = Vec::with_capacity(graph[u].len() + (end - start));
868
869 ids.extend_from_slice(&graph[u]);
871
872 if start < end {
874 ids.extend_from_slice(&merge.incoming_flat[start..end]);
875 }
876
877 ids.retain(|&id| id != PAD_U32 && id as usize != u);
879
880 ids.sort_unstable();
882 ids.dedup();
883
884 if ids.is_empty() {
885 return (u, Vec::new());
886 }
887
888 if ids.len() <= slack_limit {
890 return (u, ids);
891 }
892
893 let mut pool = Vec::<(u32, f32)>::with_capacity(ids.len());
895 for id in ids {
896 let d = dist.eval(vectors.row(u), vectors.row(id as usize));
897 pool.push((id, d));
898 }
899
900 let pruned = prune_neighbors(u, &pool, vectors, max_degree, alpha, dist);
901 (u, pruned)
902 })
903 .collect();
904
905 for (u, neigh) in updated_pairs {
906 graph[u] = neigh;
907 }
908
909 for &u in &merge.affected_nodes {
911 merge.incoming_counts[u] = 0;
912 merge.incoming_offsets[u + 1] = 0;
913 }
914}
915
916#[derive(Debug)]
926struct MergeScratch {
927 incoming_counts: Vec<usize>,
928 incoming_offsets: Vec<usize>,
929 incoming_write: Vec<usize>,
930 incoming_flat: Vec<u32>,
931
932 affected_marks: Vec<u32>,
933 affected_epoch: u32,
934 affected_nodes: Vec<usize>,
935}
936
937impl MergeScratch {
938 fn new(n: usize) -> Self {
939 Self {
940 incoming_counts: vec![0usize; n],
941 incoming_offsets: vec![0usize; n + 1],
942 incoming_write: vec![0usize; n],
943 incoming_flat: Vec::new(),
944 affected_marks: vec![0u32; n],
945 affected_epoch: 1,
946 affected_nodes: Vec::new(),
947 }
948 }
949
950 #[inline]
951 fn reset(&mut self) {
952 self.affected_epoch = self.affected_epoch.wrapping_add(1);
953 if self.affected_epoch == 0 {
954 self.affected_marks.fill(0);
955 self.affected_epoch = 1;
956 }
957 self.affected_nodes.clear();
958 self.incoming_flat.clear();
959 }
960
961 #[inline]
962 fn mark_affected(&mut self, u: usize) {
963 if self.affected_marks[u] != self.affected_epoch {
964 self.affected_marks[u] = self.affected_epoch;
965 self.affected_nodes.push(u);
966 self.incoming_counts[u] = 0;
967 }
968 }
969}
970
971fn build_vamana_graph<T, D>(
974 vectors: &FlatVectors<T>,
975 max_degree: usize,
976 build_beam_width: usize,
977 alpha: f32,
978 passes: usize,
979 extra_seeds: usize,
980 dist: D,
981 medoid_id: u32,
982) -> Vec<Vec<u32>>
983where
984 T: bytemuck::Pod + Copy + Send + Sync,
985 D: Distance<T> + Copy + Sync,
986{
987 let n = vectors.n;
988 let mut graph = vec![Vec::<u32>::new(); n];
989 {
991 let mut rng = thread_rng();
992 let target = max_degree.min(n.saturating_sub(1));
993
994 for i in 0..n {
995 let mut s = HashSet::with_capacity(target);
996 while s.len() < target {
997 let nb = rng.gen_range(0..n);
998 if nb != i {
999 s.insert(nb as u32);
1000 }
1001 }
1002 graph[i] = s.into_iter().collect();
1003 }
1004 }
1005
1006 let passes = passes.max(1);
1007 let mut rng = thread_rng();
1008 let slack_limit = ((GRAPH_SLACK_FACTOR * max_degree as f32).ceil() as usize).max(max_degree);
1009
1010 let mut merge_scratch = MergeScratch::new(n);
1012
1013 for pass_idx in 0..passes {
1014 let pass_alpha = if passes == 1 {
1015 alpha
1016 } else if pass_idx == 0 {
1017 1.0
1018 } else {
1019 alpha
1020 };
1021
1022 let mut order: Vec<usize> = (0..n).collect();
1023 order.shuffle(&mut rng);
1024
1025 for chunk in order.chunks(MICRO_BATCH_CHUNK_SIZE) {
1026 let snapshot = &graph;
1027 let chunk_results: Vec<(usize, Vec<u32>)> = chunk
1029 .par_iter()
1030 .map_init(
1031 || IncrementalInsertScratch::new(n, build_beam_width, max_degree, extra_seeds),
1032 |scratch, &u| {
1033 let bs = &mut scratch.build;
1034 bs.candidates.clear();
1035
1036 for &nb in &snapshot[u] {
1038 let d = dist.eval(vectors.row(u), vectors.row(nb as usize));
1039 bs.candidates.push((nb, d));
1040 }
1041
1042 bs.seeds.clear();
1044 bs.seeds.push(medoid_id as usize);
1045
1046 let mut local_rng = thread_rng();
1047 while bs.seeds.len() < 1 + extra_seeds {
1048 let s = local_rng.gen_range(0..n);
1049 if !bs.seeds.contains(&s) {
1050 bs.seeds.push(s);
1051 }
1052 }
1053
1054 let seeds_len = bs.seeds.len();
1055 for si in 0..seeds_len {
1056 let start = bs.seeds[si];
1057
1058 greedy_search_visited_collect(
1059 vectors.row(u),
1060 vectors,
1061 snapshot,
1062 start,
1063 build_beam_width,
1064 dist,
1065 bs,
1066 );
1067
1068 for i in 0..bs.visited_ids.len() {
1069 bs.candidates.push((bs.visited_ids[i], bs.visited_dists[i]));
1070 }
1071 }
1072
1073 dedup_keep_best_by_id_in_place(&mut bs.candidates);
1074
1075 let pruned = prune_neighbors(
1076 u,
1077 &bs.candidates,
1078 vectors,
1079 max_degree,
1080 pass_alpha,
1081 dist,
1082 );
1083
1084 (u, pruned)
1085 },
1086 )
1087 .collect();
1088
1089 let mut chunk_nodes = Vec::<usize>::with_capacity(chunk_results.len());
1090 let mut chunk_pruned = Vec::<Vec<u32>>::with_capacity(chunk_results.len());
1091
1092 for (u, pruned) in chunk_results {
1093 chunk_nodes.push(u);
1094 chunk_pruned.push(pruned);
1095 }
1096 merge_chunk_updates_into_graph_reuse(
1098 &mut graph,
1099 &chunk_nodes,
1100 &chunk_pruned,
1101 vectors,
1102 max_degree,
1103 slack_limit,
1104 pass_alpha,
1105 dist,
1106 &mut merge_scratch,
1107 );
1108 }
1109 }
1110
1111 graph
1113 .into_par_iter()
1114 .enumerate()
1115 .map(|(u, neigh)| {
1116 if neigh.len() <= max_degree {
1117 return neigh;
1118 }
1119
1120 let mut ids = neigh;
1121 ids.sort_unstable();
1122 ids.dedup();
1123
1124 let pool: Vec<(u32, f32)> = ids
1125 .into_iter()
1126 .filter(|&id| id as usize != u)
1127 .map(|id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
1128 .collect();
1129
1130 prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
1131 })
1132 .collect()
1133}
1134
1135fn greedy_search_visited_collect<T, D>(
1141 query: &[T],
1142 vectors: &FlatVectors<T>,
1143 graph: &[Vec<u32>],
1144 start_id: usize,
1145 beam_width: usize,
1146 dist: D,
1147 scratch: &mut BuildScratch,
1148) where
1149 T: bytemuck::Pod + Copy + Send + Sync,
1150 D: Distance<T> + Copy,
1151{
1152 scratch.reset_search();
1153
1154 let start_dist = dist.eval(query, vectors.row(start_id));
1155 let start = Candidate {
1156 dist: start_dist,
1157 id: start_id as u32,
1158 };
1159
1160 scratch.frontier.insert_unbounded(start);
1161 scratch.work.insert_capped(start, beam_width);
1162 scratch.mark_with_dist(start_id, start_dist);
1163
1164 while !scratch.frontier.is_empty() {
1165 let best = scratch.frontier.best().unwrap();
1166 if scratch.work.len() >= beam_width {
1167 if let Some(worst) = scratch.work.worst() {
1168 if best.dist >= worst.dist {
1169 break;
1170 }
1171 }
1172 }
1173
1174 let cur = scratch.frontier.pop_best().unwrap();
1175
1176 for &nb in &graph[cur.id as usize] {
1177 let nb_usize = nb as usize;
1178 if scratch.is_marked(nb_usize) {
1179 continue;
1180 }
1181
1182 let d = dist.eval(query, vectors.row(nb_usize));
1183 scratch.mark_with_dist(nb_usize, d);
1184
1185 let cand = Candidate { dist: d, id: nb };
1186
1187 if scratch.work.len() < beam_width {
1188 scratch.work.insert_unbounded(cand);
1189 scratch.frontier.insert_unbounded(cand);
1190 } else if let Some(worst) = scratch.work.worst() {
1191 if d < worst.dist {
1192 scratch.work.insert_capped(cand, beam_width);
1193 scratch.frontier.insert_unbounded(cand);
1194 }
1195 }
1196 }
1197 }
1198}
1199
1200fn prune_neighbors<T, D>(
1202 node_id: usize,
1203 candidates: &[(u32, f32)],
1204 vectors: &FlatVectors<T>,
1205 max_degree: usize,
1206 alpha: f32,
1207 dist: D,
1208) -> Vec<u32>
1209where
1210 T: bytemuck::Pod + Copy + Send + Sync,
1211 D: Distance<T> + Copy,
1212{
1213 if candidates.is_empty() || max_degree == 0 {
1214 return Vec::new();
1215 }
1216
1217 let mut sorted = candidates.to_vec();
1219 sorted.sort_by(|a, b| a.1.total_cmp(&b.1));
1220
1221 let mut uniq = Vec::<(u32, f32)>::with_capacity(sorted.len());
1223 let mut last_id: Option<u32> = None;
1224 for &(cand_id, cand_dist) in &sorted {
1225 if cand_id as usize == node_id {
1226 continue;
1227 }
1228 if last_id == Some(cand_id) {
1229 continue;
1230 }
1231 uniq.push((cand_id, cand_dist));
1232 last_id = Some(cand_id);
1233 }
1234
1235 if uniq.is_empty() {
1236 return Vec::new();
1237 }
1238
1239 let mut pruned = Vec::<u32>::with_capacity(max_degree);
1240
1241 for &(cand_id, cand_dist_to_node) in &uniq {
1243 let mut occluded = false;
1244
1245 for &sel_id in &pruned {
1246 let d_cand_sel = dist.eval(
1247 vectors.row(cand_id as usize),
1248 vectors.row(sel_id as usize),
1249 );
1250
1251 if alpha * d_cand_sel <= cand_dist_to_node {
1252 occluded = true;
1253 break;
1254 }
1255 }
1256
1257 if !occluded {
1258 pruned.push(cand_id);
1259 if pruned.len() >= max_degree {
1260 return pruned;
1261 }
1262 }
1263 }
1264
1265 if pruned.len() < max_degree {
1267 for &(cand_id, _) in &uniq {
1268 if pruned.contains(&cand_id) {
1269 continue;
1270 }
1271 pruned.push(cand_id);
1272 if pruned.len() >= max_degree {
1273 break;
1274 }
1275 }
1276 }
1277
1278 pruned
1279}
1280
1281#[cfg(test)]
1282mod tests {
1283 use super::*;
1284 use anndists::dist::{DistCosine, DistL2};
1285 use rand::Rng;
1286 use std::fs;
1287
1288 fn euclid(a: &[f32], b: &[f32]) -> f32 {
1289 a.iter()
1290 .zip(b)
1291 .map(|(x, y)| (x - y) * (x - y))
1292 .sum::<f32>()
1293 .sqrt()
1294 }
1295
1296 #[test]
1297 fn test_small_index_l2() {
1298 let path = "test_small_l2.db";
1299 let _ = fs::remove_file(path);
1300
1301 let vectors = vec![
1302 vec![0.0, 0.0],
1303 vec![1.0, 0.0],
1304 vec![0.0, 1.0],
1305 vec![1.0, 1.0],
1306 vec![0.5, 0.5],
1307 ];
1308
1309 let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
1310
1311 let q = vec![0.1, 0.1];
1312 let nns = index.search(&q, 3, 8);
1313 assert_eq!(nns.len(), 3);
1314
1315 let v = index.get_vector(nns[0] as usize);
1316 assert!(euclid(&q, &v) < 1.0);
1317
1318 let _ = fs::remove_file(path);
1319 }
1320
1321 #[test]
1322 fn test_cosine() {
1323 let path = "test_cosine.db";
1324 let _ = fs::remove_file(path);
1325
1326 let vectors = vec![
1327 vec![1.0, 0.0, 0.0],
1328 vec![0.0, 1.0, 0.0],
1329 vec![0.0, 0.0, 1.0],
1330 vec![1.0, 1.0, 0.0],
1331 vec![1.0, 0.0, 1.0],
1332 ];
1333
1334 let index =
1335 DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, path).unwrap();
1336
1337 let q = vec![2.0, 0.0, 0.0];
1338 let nns = index.search(&q, 2, 8);
1339 assert_eq!(nns.len(), 2);
1340
1341 let v = index.get_vector(nns[0] as usize);
1342 let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
1343 let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
1344 let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
1345 let cos = dot / (n1 * n2);
1346 assert!(cos > 0.7);
1347
1348 let _ = fs::remove_file(path);
1349 }
1350
1351 #[test]
1352 fn test_persistence_and_open() {
1353 let path = "test_persist.db";
1354 let _ = fs::remove_file(path);
1355
1356 let vectors = vec![
1357 vec![0.0, 0.0],
1358 vec![1.0, 0.0],
1359 vec![0.0, 1.0],
1360 vec![1.0, 1.0],
1361 ];
1362
1363 {
1364 let _idx =
1365 DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
1366 }
1367
1368 let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(path).unwrap();
1369 assert_eq!(idx2.num_vectors, 4);
1370 assert_eq!(idx2.dim, 2);
1371
1372 let q = vec![0.9, 0.9];
1373 let res = idx2.search(&q, 2, 8);
1374 assert_eq!(res[0], 3);
1375
1376 let _ = fs::remove_file(path);
1377 }
1378
1379 #[test]
1380 fn test_grid_connectivity() {
1381 let path = "test_grid.db";
1382 let _ = fs::remove_file(path);
1383
1384 let mut vectors = Vec::new();
1385 for i in 0..5 {
1386 for j in 0..5 {
1387 vectors.push(vec![i as f32, j as f32]);
1388 }
1389 }
1390
1391 let index = DiskANN::<f32, DistL2>::build_index_with_params(
1392 &vectors,
1393 DistL2,
1394 path,
1395 DiskAnnParams {
1396 max_degree: 4,
1397 build_beam_width: 64,
1398 alpha: 1.5,
1399 passes: DISKANN_DEFAULT_PASSES,
1400 extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
1401 },
1402 )
1403 .unwrap();
1404
1405 for target in 0..vectors.len() {
1406 let q = &vectors[target];
1407 let nns = index.search(q, 10, 32);
1408 if !nns.contains(&(target as u32)) {
1409 let v = index.get_vector(nns[0] as usize);
1410 assert!(euclid(q, &v) < 2.0);
1411 }
1412 for &nb in nns.iter().take(5) {
1413 let v = index.get_vector(nb as usize);
1414 assert!(euclid(q, &v) < 5.0);
1415 }
1416 }
1417
1418 let _ = fs::remove_file(path);
1419 }
1420
1421 #[test]
1422 fn test_medium_random() {
1423 let path = "test_medium.db";
1424 let _ = fs::remove_file(path);
1425
1426 let n = 200usize;
1427 let d = 32usize;
1428 let mut rng = rand::thread_rng();
1429 let vectors: Vec<Vec<f32>> = (0..n)
1430 .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
1431 .collect();
1432
1433 let index = DiskANN::<f32, DistL2>::build_index_with_params(
1434 &vectors,
1435 DistL2,
1436 path,
1437 DiskAnnParams {
1438 max_degree: 32,
1439 build_beam_width: 128,
1440 alpha: 1.2,
1441 passes: DISKANN_DEFAULT_PASSES,
1442 extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
1443 },
1444 )
1445 .unwrap();
1446
1447 let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
1448 let res = index.search(&q, 10, 64);
1449 assert_eq!(res.len(), 10);
1450
1451 let dists: Vec<f32> = res
1452 .iter()
1453 .map(|&id| {
1454 let v = index.get_vector(id as usize);
1455 euclid(&q, &v)
1456 })
1457 .collect();
1458 let mut sorted = dists.clone();
1459 sorted.sort_by(|a, b| a.total_cmp(b));
1460 assert_eq!(dists, sorted);
1461
1462 let _ = fs::remove_file(path);
1463 }
1464}