1use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Sentinel stored in unused adjacency slots so every node's neighbor list
/// occupies a fixed `max_degree`-wide row on disk.
const PAD_U32: u32 = u32::MAX;

/// Default maximum out-degree per graph node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width for the greedy searches run during construction.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default Vamana pruning parameter (values > 1.0 keep more diverse edges).
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement passes over the data during build.
pub const DISKANN_DEFAULT_PASSES: usize = 1;
/// Default number of extra random entry points (besides the medoid) per
/// insert-time search.
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 1;

// Temporary degree slack tolerated during construction before re-pruning.
const GRAPH_SLACK_FACTOR: f32 = 1.3;

// Number of nodes refined per parallel micro-batch during build.
const MICRO_BATCH_CHUNK_SIZE: usize = 256;
84
/// Tunable parameters for building a DiskANN (Vamana) index.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum out-degree of each node in the final graph.
    pub max_degree: usize,
    /// Beam width used by the greedy searches during construction.
    pub build_beam_width: usize,
    /// Vamana pruning parameter; values > 1.0 keep more diverse edges.
    pub alpha: f32,
    /// Number of refinement passes over the data.
    pub passes: usize,
    /// Extra random entry points (besides the medoid) per insert search.
    pub extra_seeds: usize,
}
96
impl Default for DiskAnnParams {
    /// Mirrors the `DISKANN_DEFAULT_*` constants.
    fn default() -> Self {
        Self {
            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
            alpha: DISKANN_DEFAULT_ALPHA,
            passes: DISKANN_DEFAULT_PASSES,
            extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
        }
    }
}
108
/// Errors produced while building or opening an index.
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// Underlying file / mmap I/O failure.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Metadata (de)serialization failure.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// Logical index error (empty input, dimension mismatch, ...).
    #[error("Index error: {0}")]
    IndexError(String),
}
124
/// On-disk header describing the index layout. Serialized with bincode at
/// the start of the file, after an 8-byte little-endian length prefix.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    dim: usize,              // dimensionality of every stored vector
    num_vectors: usize,      // number of stored vectors
    max_degree: usize,       // fixed adjacency row width (u32s per node)
    medoid_id: u32,          // search entry point chosen at build time
    vectors_offset: u64,     // byte offset of the raw vector section
    adjacency_offset: u64,   // byte offset of the adjacency section
    elem_size: u8,           // size_of::<T>() recorded for a sanity check
    distance_name: String,   // type_name of the distance used at build time
}
137
/// A search candidate: a node id plus its distance to the query.
///
/// Ordered by distance (ascending), ties broken by id, using
/// `f32::total_cmp` so the order is total even in the presence of NaN.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Bitwise distance comparison keeps Eq consistent with the total
        // order below (NaN == NaN with identical bits).
        self.id == other.id && self.dist.to_bits() == other.dist.to_bits()
    }
}
impl Eq for Candidate {}

impl Ord for Candidate {
    // Canonical form: the total order lives in `Ord`; `PartialOrd` delegates.
    // (Previously `cmp` delegated to `partial_cmp().unwrap_or(Equal)`, the
    // non-canonical inversion clippy flags.)
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
164
/// Row-major, contiguous storage for `n` vectors of dimension `dim`.
#[derive(Clone, Debug)]
struct FlatVectors<T> {
    data: Vec<T>, // len == n * dim
    dim: usize,
    n: usize,
}
174
175impl<T: Copy> FlatVectors<T> {
176 fn from_vecs(vectors: &[Vec<T>]) -> Result<Self, DiskAnnError> {
177 if vectors.is_empty() {
178 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
179 }
180 let dim = vectors[0].len();
181 for (i, v) in vectors.iter().enumerate() {
182 if v.len() != dim {
183 return Err(DiskAnnError::IndexError(format!(
184 "Vector {} has dimension {} but expected {}",
185 i,
186 v.len(),
187 dim
188 )));
189 }
190 }
191
192 let n = vectors.len();
193 let mut data = Vec::with_capacity(n * dim);
194 for v in vectors {
195 data.extend_from_slice(v);
196 }
197
198 Ok(Self { data, dim, n })
199 }
200
201 #[inline]
202 fn row(&self, idx: usize) -> &[T] {
203 let start = idx * self.dim;
204 let end = start + self.dim;
205 &self.data[start..end]
206 }
207}
208
/// Small sorted beam used by the build-time greedy search.
///
/// Invariant: `items` is kept sorted by *descending* distance (ties broken
/// by descending id), so the worst candidate sits at index 0 and the best
/// candidate is last.
#[derive(Default, Debug)]
struct OrderedBeam {
    items: Vec<Candidate>,
}
223
impl OrderedBeam {
    // Remove all candidates, keeping the allocation.
    #[inline]
    fn clear(&mut self) {
        self.items.clear();
    }

    #[inline]
    fn len(&self) -> usize {
        self.items.len()
    }

    #[inline]
    fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    // Closest candidate (items are sorted worst-first, so the best is last).
    #[inline]
    fn best(&self) -> Option<Candidate> {
        self.items.last().copied()
    }

    // Farthest candidate currently held (front of the vector).
    #[inline]
    fn worst(&self) -> Option<Candidate> {
        self.items.first().copied()
    }

    // Remove and return the closest candidate (pop from the back is O(1)).
    #[inline]
    fn pop_best(&mut self) -> Option<Candidate> {
        self.items.pop()
    }

    // Ensure the backing buffer can hold at least `cap` candidates.
    #[inline]
    fn reserve(&mut self, cap: usize) {
        if self.items.capacity() < cap {
            self.items.reserve(cap - self.items.capacity());
        }
    }

    // Insert preserving the descending-by-distance order; ties are broken by
    // descending id so the order matches `Candidate`'s total order reversed.
    #[inline]
    fn insert_unbounded(&mut self, cand: Candidate) {
        let pos = self.items.partition_point(|x| {
            x.dist > cand.dist || (x.dist.to_bits() == cand.dist.to_bits() && x.id > cand.id)
        });
        self.items.insert(pos, cand);
    }

    // Insert but keep at most `cap` candidates, evicting the current worst.
    #[inline]
    fn insert_capped(&mut self, cand: Candidate, cap: usize) {
        if cap == 0 {
            return;
        }

        if self.items.len() < cap {
            self.insert_unbounded(cand);
            return;
        }

        // Beam is full: only accept candidates strictly better than the
        // current worst (front element).
        let worst = self.items[0];
        if cand.dist >= worst.dist {
            return;
        }

        self.insert_unbounded(cand);

        if self.items.len() > cap {
            self.items.remove(0);
        }
    }
}
294
/// Reusable per-thread scratch buffers for build-time greedy searches.
#[derive(Debug)]
struct BuildScratch {
    // Epoch-stamped visited set: marks[i] == epoch means node i was visited
    // in the current search. Avoids an O(n) clear between searches.
    marks: Vec<u32>,
    epoch: u32,

    // Parallel arrays recording every visited node and its distance.
    visited_ids: Vec<u32>,
    visited_dists: Vec<f32>,

    // Expansion frontier and the current-best working beam.
    frontier: OrderedBeam,
    work: OrderedBeam,

    // Search entry points (medoid + random extras) and the candidate pool.
    seeds: Vec<usize>,
    candidates: Vec<(u32, f32)>,
}
312
impl BuildScratch {
    // Allocate scratch sized for `n` nodes; buffer capacities are heuristics
    // derived from the search parameters to avoid hot-path reallocation.
    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
        Self {
            marks: vec![0u32; n],
            epoch: 1,
            visited_ids: Vec::with_capacity(beam_width * 4),
            visited_dists: Vec::with_capacity(beam_width * 4),
            frontier: {
                let mut b = OrderedBeam::default();
                b.reserve(beam_width * 2);
                b
            },
            work: {
                let mut b = OrderedBeam::default();
                b.reserve(beam_width * 2);
                b
            },
            seeds: Vec::with_capacity(1 + extra_seeds),
            candidates: Vec::with_capacity(beam_width * (4 + extra_seeds) + max_degree * 2),
        }
    }

    // Start a new search: bumping the epoch is an O(1) "clear" of the visited
    // set. On u32 wraparound the marks must actually be zeroed, otherwise
    // stale marks from epoch N could alias the new epoch.
    #[inline]
    fn reset_search(&mut self) {
        self.epoch = self.epoch.wrapping_add(1);
        if self.epoch == 0 {
            self.marks.fill(0);
            self.epoch = 1;
        }
        self.visited_ids.clear();
        self.visited_dists.clear();
        self.frontier.clear();
        self.work.clear();
    }

    // Was `idx` visited during the current search?
    #[inline]
    fn is_marked(&self, idx: usize) -> bool {
        self.marks[idx] == self.epoch
    }

    // Mark `idx` as visited and record its distance for later harvesting.
    #[inline]
    fn mark_with_dist(&mut self, idx: usize, dist: f32) {
        self.marks[idx] = self.epoch;
        self.visited_ids.push(idx as u32);
        self.visited_dists.push(dist);
    }
}
360
/// Per-thread scratch for insert-style refinement during build; currently
/// just wraps `BuildScratch`.
#[derive(Debug)]
struct IncrementalInsertScratch {
    build: BuildScratch,
}
365
impl IncrementalInsertScratch {
    // Allocate scratch sized for `n` nodes and the given search parameters.
    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
        Self {
            build: BuildScratch::new(n, beam_width, max_degree, extra_seeds),
        }
    }
}
373
/// A memory-mapped DiskANN (Vamana) index over vectors with element type `T`,
/// queried with the distance functor `D`.
pub struct DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
    /// Vector dimensionality.
    pub dim: usize,
    /// Number of indexed vectors.
    pub num_vectors: usize,
    /// Fixed adjacency row width (out-degree cap).
    pub max_degree: usize,
    /// `type_name` of the distance recorded at build time.
    pub distance_name: String,

    // Search entry point chosen at build time.
    medoid_id: u32,
    // Byte offsets of the two data sections inside the mapped file.
    vectors_offset: u64,
    adjacency_offset: u64,

    // Read-only memory map of the whole index file.
    mmap: Mmap,

    // Distance functor used for queries.
    dist: D,

    _phantom: PhantomData<T>,
}
404
impl<T, D> DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
    /// Build an index at `file_path` using the `DISKANN_DEFAULT_*` parameters.
    pub fn build_index_default(
        vectors: &[Vec<T>],
        dist: D,
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        Self::build_index(
            vectors,
            DISKANN_DEFAULT_MAX_DEGREE,
            DISKANN_DEFAULT_BUILD_BEAM,
            DISKANN_DEFAULT_ALPHA,
            DISKANN_DEFAULT_PASSES,
            DISKANN_DEFAULT_EXTRA_SEEDS,
            dist,
            file_path,
        )
    }

    /// Build an index at `file_path` using explicit [`DiskAnnParams`].
    pub fn build_index_with_params(
        vectors: &[Vec<T>],
        dist: D,
        file_path: &str,
        p: DiskAnnParams,
    ) -> Result<Self, DiskAnnError> {
        Self::build_index(
            vectors,
            p.max_degree,
            p.build_beam_width,
            p.alpha,
            p.passes,
            p.extra_seeds,
            dist,
            file_path,
        )
    }

    /// Open an existing index file and memory-map it.
    ///
    /// # Errors
    /// Fails on I/O or deserialization errors, or if the stored element size
    /// does not match `size_of::<T>()`. A mismatched distance type only emits
    /// a warning on stderr (type names are not stable across builds).
    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
        let mut file = OpenOptions::new().read(true).write(false).open(path)?;

        // Header layout: [u64 LE metadata length][bincode Metadata].
        let mut buf8 = [0u8; 8];
        file.seek(SeekFrom::Start(0))?;
        file.read_exact(&mut buf8)?;
        let md_len = u64::from_le_bytes(buf8);

        let mut md_bytes = vec![0u8; md_len as usize];
        file.read_exact(&mut md_bytes)?;
        let metadata: Metadata = bincode::deserialize(&md_bytes)?;

        // SAFETY: the file is opened read-only and only read through `&self`;
        // soundness assumes no concurrent writer mutates the file.
        let mmap = unsafe { memmap2::Mmap::map(&file)? };

        // Guard against opening an index built for a different element type.
        let want = std::mem::size_of::<T>() as u8;
        if metadata.elem_size != want {
            return Err(DiskAnnError::IndexError(format!(
                "element size mismatch: file has {}B, T is {}B",
                metadata.elem_size, want
            )));
        }

        let expected = std::any::type_name::<D>();
        if metadata.distance_name != expected {
            eprintln!(
                "Warning: index recorded distance `{}` but you opened with `{}`",
                metadata.distance_name, expected
            );
        }

        Ok(Self {
            dim: metadata.dim,
            num_vectors: metadata.num_vectors,
            max_degree: metadata.max_degree,
            distance_name: metadata.distance_name,
            medoid_id: metadata.medoid_id,
            vectors_offset: metadata.vectors_offset,
            adjacency_offset: metadata.adjacency_offset,
            mmap,
            dist,
            _phantom: PhantomData,
        })
    }
}
498
impl<T, D> DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
{
    /// Build with default parameters and a `Default`-constructed distance.
    pub fn build_index_default_metric(
        vectors: &[Vec<T>],
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        Self::build_index_default(vectors, D::default(), file_path)
    }

    /// Open an existing index with a `Default`-constructed distance.
    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
        Self::open_index_with(path, D::default())
    }
}
518
impl<T, D> DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
    /// Build a Vamana graph over `vectors`, persist everything to
    /// `file_path`, and return the opened (memory-mapped) index.
    ///
    /// File layout:
    /// - offset 0: u64 LE metadata length, then bincode `Metadata`
    /// - `vectors_offset` (1 MiB): row-major vector data
    /// - `adjacency_offset`: one fixed `max_degree`-wide u32 row per node,
    ///   padded with `PAD_U32`
    ///
    /// # Errors
    /// Fails on empty/ragged input or any I/O / serialization error.
    pub fn build_index(
        vectors: &[Vec<T>],
        max_degree: usize,
        build_beam_width: usize,
        alpha: f32,
        passes: usize,
        extra_seeds: usize,
        dist: D,
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        let flat = FlatVectors::from_vecs(vectors)?;

        let num_vectors = flat.n;
        let dim = flat.dim;

        let mut file = OpenOptions::new()
            .create(true)
            .write(true)
            .read(true)
            .truncate(true)
            .open(file_path)?;

        // Fixed 1 MiB header region: ample room for the metadata and keeps
        // the vector section aligned for T.
        let vectors_offset = 1024 * 1024;
        assert_eq!(
            (vectors_offset as usize) % std::mem::align_of::<T>(),
            0,
            "vectors_offset must be aligned for T"
        );

        let elem_sz = std::mem::size_of::<T>() as u64;
        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;

        file.seek(SeekFrom::Start(vectors_offset as u64))?;
        file.write_all(bytemuck::cast_slice::<T, u8>(&flat.data))?;

        // Entry point for all subsequent searches.
        let medoid_id = calculate_medoid(&flat, dist);

        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
        let graph = build_vamana_graph(
            &flat,
            max_degree,
            build_beam_width,
            alpha,
            passes,
            extra_seeds,
            dist,
            medoid_id as u32,
        );

        // Write each adjacency row at exactly max_degree u32s, PAD_U32-padded.
        file.seek(SeekFrom::Start(adjacency_offset))?;
        for neighbors in &graph {
            let mut padded = neighbors.clone();
            padded.resize(max_degree, PAD_U32);
            let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
            file.write_all(bytes)?;
        }

        let metadata = Metadata {
            dim,
            num_vectors,
            max_degree,
            medoid_id: medoid_id as u32,
            vectors_offset: vectors_offset as u64,
            adjacency_offset,
            elem_size: std::mem::size_of::<T>() as u8,
            distance_name: std::any::type_name::<D>().to_string(),
        };

        // Header is written last, once all offsets are final.
        let md_bytes = bincode::serialize(&metadata)?;
        file.seek(SeekFrom::Start(0))?;
        let md_len = md_bytes.len() as u64;
        file.write_all(&md_len.to_le_bytes())?;
        file.write_all(&md_bytes)?;
        file.sync_all()?;

        // SAFETY: the file was fully written and synced above; the map is
        // only ever read through `&self`.
        let mmap = unsafe { memmap2::Mmap::map(&file)? };

        Ok(Self {
            dim,
            num_vectors,
            max_degree,
            distance_name: metadata.distance_name,
            medoid_id: metadata.medoid_id,
            vectors_offset: metadata.vectors_offset,
            adjacency_offset: metadata.adjacency_offset,
            mmap,
            dist,
            _phantom: PhantomData,
        })
    }

    /// Beam search from the medoid; returns up to `k` `(id, distance)` pairs
    /// sorted by increasing distance.
    ///
    /// # Panics
    /// Panics if `query.len() != self.dim`.
    pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
        assert_eq!(
            query.len(),
            self.dim,
            "Query dim {} != index dim {}",
            query.len(),
            self.dim
        );

        let mut visited = HashSet::new();
        // `frontier` is a min-heap (via Reverse) of unexpanded candidates;
        // `w` is a max-heap holding the best `beam_width` results so far.
        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
        let mut w: BinaryHeap<Candidate> = BinaryHeap::new();

        let start_dist = self.distance_to(query, self.medoid_id as usize);
        let start = Candidate {
            dist: start_dist,
            id: self.medoid_id,
        };
        frontier.push(Reverse(start));
        w.push(start);
        visited.insert(self.medoid_id);

        while let Some(Reverse(best)) = frontier.peek().copied() {
            // Stop once the closest unexpanded node cannot improve the beam.
            if w.len() >= beam_width {
                if let Some(worst) = w.peek() {
                    if best.dist >= worst.dist {
                        break;
                    }
                }
            }
            let Reverse(current) = frontier.pop().unwrap();

            for &nb in self.get_neighbors(current.id) {
                if nb == PAD_U32 {
                    continue;
                }
                if !visited.insert(nb) {
                    continue;
                }

                let d = self.distance_to(query, nb as usize);
                let cand = Candidate { dist: d, id: nb };

                if w.len() < beam_width {
                    w.push(cand);
                    frontier.push(Reverse(cand));
                } else if d < w.peek().unwrap().dist {
                    w.pop();
                    w.push(cand);
                    frontier.push(Reverse(cand));
                }
            }
        }

        let mut results: Vec<_> = w.into_vec();
        results.sort_by(|a, b| a.dist.total_cmp(&b.dist));
        results.truncate(k);
        results.into_iter().map(|c| (c.id, c.dist)).collect()
    }

    /// Like [`Self::search_with_dists`] but returns ids only.
    pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
        self.search_with_dists(query, k, beam_width)
            .into_iter()
            .map(|(id, _dist)| id)
            .collect()
    }

    // Zero-copy view of `node_id`'s fixed-width adjacency row inside the
    // mmap; unused slots contain PAD_U32.
    fn get_neighbors(&self, node_id: u32) -> &[u32] {
        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
        let start = offset as usize;
        let end = start + (self.max_degree * 4);
        let bytes = &self.mmap[start..end];
        bytemuck::cast_slice(bytes)
    }

    // Distance from `query` to stored vector `idx`, read directly from mmap.
    fn distance_to(&self, query: &[T], idx: usize) -> f32 {
        let elem_sz = std::mem::size_of::<T>();
        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
        let start = offset as usize;
        let end = start + (self.dim * elem_sz);
        let bytes = &self.mmap[start..end];
        let vector: &[T] = bytemuck::cast_slice(bytes);
        self.dist.eval(query, vector)
    }

    /// Copy the stored vector `idx` out of the mmap.
    pub fn get_vector(&self, idx: usize) -> Vec<T> {
        let elem_sz = std::mem::size_of::<T>();
        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
        let start = offset as usize;
        let end = start + (self.dim * elem_sz);
        let bytes = &self.mmap[start..end];
        let vector: &[T] = bytemuck::cast_slice(bytes);
        vector.to_vec()
    }
}
734
// Approximate medoid: score every point by its summed distance to a small
// random pivot sample (k = min(8, n)) and return the minimizer, in parallel.
// Random pivots make this O(n * k) instead of the exact O(n^2) medoid.
fn calculate_medoid<T, D>(vectors: &FlatVectors<T>, dist: D) -> usize
where
    T: bytemuck::Pod + Copy + Send + Sync,
    D: Distance<T> + Copy + Sync,
{
    let n = vectors.n;
    let k = 8.min(n);
    let mut rng = thread_rng();
    let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();

    let (best_idx, _best_score) = (0..n)
        .into_par_iter()
        .map(|i| {
            let vi = vectors.row(i);
            let score: f32 = pivots.iter().map(|&p| dist.eval(vi, vectors.row(p))).sum();
            (i, score)
        })
        // `<=` keeps the left operand on ties; reduction order (and thus the
        // tie winner) is whatever rayon's split produces.
        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });

    best_idx
}
757
/// Sort `cands` by id and keep only the closest (smallest-distance) entry for
/// each id, in place.
fn dedup_keep_best_by_id_in_place(cands: &mut Vec<(u32, f32)>) {
    if cands.is_empty() {
        return;
    }

    // Group duplicates together, closest entry first within each group.
    cands.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.total_cmp(&b.1)));

    // `dedup_by_key` keeps the first of each run of equal ids — which, after
    // the sort above, is exactly the best-distance entry. This replaces the
    // previous hand-rolled read/write compaction loop with the stdlib idiom.
    cands.dedup_by_key(|c| c.0);
}
777
// Merge one micro-batch of freshly pruned out-lists into `graph`.
//
// Chunk nodes get their new lists installed verbatim; every node that gains
// an incoming edge from the chunk ("affected") is re-pruned against the
// union of its current neighbors and the new incoming edges, keeping the
// graph roughly symmetric and within `max_degree`. The incoming edges are
// gathered CSR-style into `merge`'s reusable flat buffers to avoid per-chunk
// allocation.
fn merge_chunk_updates_into_graph_reuse<T, D>(
    graph: &mut [Vec<u32>],
    chunk_nodes: &[usize],
    chunk_pruned: &[Vec<u32>],
    vectors: &FlatVectors<T>,
    max_degree: usize,
    _slack_limit: usize,
    alpha: f32,
    dist: D,
    merge: &mut MergeScratch,
) where
    T: bytemuck::Pod + Copy + Send + Sync,
    D: Distance<T> + Copy + Sync,
{
    merge.reset();
    for &u in chunk_nodes {
        merge.mark_affected(u);
    }

    // First pass: count incoming edges per destination (CSR sizing).
    let mut total_incoming = 0usize;

    for (local_idx, &u) in chunk_nodes.iter().enumerate() {
        for &dst in &chunk_pruned[local_idx] {
            let dst_usize = dst as usize;
            if dst_usize == u {
                continue;
            }

            merge.mark_affected(dst_usize);
            merge.incoming_counts[dst_usize] += 1;
            total_incoming += 1;
        }
    }

    // Prefix-sum the counts (in ascending node order) into CSR offsets.
    merge.affected_nodes.sort_unstable();

    let mut running = 0usize;
    for &u in &merge.affected_nodes {
        merge.incoming_offsets[u] = running;
        running += merge.incoming_counts[u];
        merge.incoming_offsets[u + 1] = running;
    }

    merge.incoming_flat.resize(total_incoming, PAD_U32);

    for &u in &merge.affected_nodes {
        merge.incoming_write[u] = merge.incoming_offsets[u];
    }

    // Second pass: scatter the incoming edge sources into the flat array.
    for (local_idx, &u) in chunk_nodes.iter().enumerate() {
        for &dst in &chunk_pruned[local_idx] {
            let dst_usize = dst as usize;
            if dst_usize == u {
                continue;
            }

            let pos = merge.incoming_write[dst_usize];
            merge.incoming_flat[pos] = u as u32;
            merge.incoming_write[dst_usize] += 1;
        }
    }

    // Install the chunk's new out-lists before re-pruning affected nodes.
    for (local_idx, &u) in chunk_nodes.iter().enumerate() {
        graph[u] = chunk_pruned[local_idx].clone();
    }
    // Clone the node list so the parallel loop below can borrow `merge`.
    let affected = merge.affected_nodes.clone();

    let updated_pairs: Vec<(usize, Vec<u32>)> = affected
        .into_par_iter()
        .map(|u| {
            let start = merge.incoming_offsets[u];
            let end = merge.incoming_offsets[u + 1];

            // Union of u's current neighbors and its new incoming edges.
            let mut ids: Vec<u32> = Vec::with_capacity(graph[u].len() + (end - start));

            ids.extend_from_slice(&graph[u]);

            if start < end {
                ids.extend_from_slice(&merge.incoming_flat[start..end]);
            }

            ids.retain(|&id| id != PAD_U32 && id as usize != u);

            ids.sort_unstable();
            ids.dedup();

            if ids.is_empty() {
                return (u, Vec::new());
            }

            let mut pool = Vec::<(u32, f32)>::with_capacity(ids.len());
            for id in ids {
                let d = dist.eval(vectors.row(u), vectors.row(id as usize));
                pool.push((id, d));
            }

            let pruned = prune_neighbors(u, &pool, vectors, max_degree, alpha, dist);
            (u, pruned)
        })
        .collect();

    for (u, neigh) in updated_pairs {
        graph[u] = neigh;
    }
    // Zero only the per-node slots touched this chunk so the next chunk
    // starts clean without an O(n) reset.
    for &u in &merge.affected_nodes {
        merge.incoming_counts[u] = 0;
        merge.incoming_offsets[u + 1] = 0;
    }
}
909
/// Reusable scratch for merging a micro-batch of pruned adjacency updates
/// into the graph (CSR-style gather of incoming edges).
#[derive(Debug)]
struct MergeScratch {
    incoming_counts: Vec<usize>,   // per-node incoming edge count (this chunk)
    incoming_offsets: Vec<usize>,  // CSR offsets into incoming_flat (n + 1)
    incoming_write: Vec<usize>,    // per-node write cursor during scatter
    incoming_flat: Vec<u32>,       // flattened incoming edge sources

    // Epoch-stamped "affected this chunk" set; see BuildScratch for the trick.
    affected_marks: Vec<u32>,
    affected_epoch: u32,
    affected_nodes: Vec<usize>,
}
930
impl MergeScratch {
    // Allocate scratch for a graph of `n` nodes.
    fn new(n: usize) -> Self {
        Self {
            incoming_counts: vec![0usize; n],
            incoming_offsets: vec![0usize; n + 1],
            incoming_write: vec![0usize; n],
            incoming_flat: Vec::new(),
            affected_marks: vec![0u32; n],
            affected_epoch: 1,
            affected_nodes: Vec::new(),
        }
    }

    // Begin a new chunk: bumping the epoch is an O(1) clear of the affected
    // set; the marks array is only physically zeroed on u32 wraparound.
    #[inline]
    fn reset(&mut self) {
        self.affected_epoch = self.affected_epoch.wrapping_add(1);
        if self.affected_epoch == 0 {
            self.affected_marks.fill(0);
            self.affected_epoch = 1;
        }
        self.affected_nodes.clear();
        self.incoming_flat.clear();
    }

    // Record `u` as affected (idempotent within a chunk) and reset its
    // incoming counter so stale counts never leak across chunks.
    #[inline]
    fn mark_affected(&mut self, u: usize) {
        if self.affected_marks[u] != self.affected_epoch {
            self.affected_marks[u] = self.affected_epoch;
            self.affected_nodes.push(u);
            self.incoming_counts[u] = 0;
        }
    }
}
964
// Build the Vamana graph: start from a random regular graph, run `passes`
// rounds of parallel insert-style refinement in micro-batches, then apply a
// final degree cap.
fn build_vamana_graph<T, D>(
    vectors: &FlatVectors<T>,
    max_degree: usize,
    build_beam_width: usize,
    alpha: f32,
    passes: usize,
    extra_seeds: usize,
    dist: D,
    medoid_id: u32,
) -> Vec<Vec<u32>>
where
    T: bytemuck::Pod + Copy + Send + Sync,
    D: Distance<T> + Copy + Sync,
{
    let n = vectors.n;
    let mut graph = vec![Vec::<u32>::new(); n];
    // Random initial graph of out-degree min(max_degree, n - 1) so the first
    // refinement pass has something to navigate.
    {
        let mut rng = thread_rng();
        let target = max_degree.min(n.saturating_sub(1));

        for i in 0..n {
            let mut s = HashSet::with_capacity(target);
            while s.len() < target {
                let nb = rng.gen_range(0..n);
                if nb != i {
                    s.insert(nb as u32);
                }
            }
            graph[i] = s.into_iter().collect();
        }
    }

    let passes = passes.max(1);
    let mut rng = thread_rng();
    let slack_limit = ((GRAPH_SLACK_FACTOR * max_degree as f32).ceil() as usize).max(max_degree);

    let mut merge_scratch = MergeScratch::new(n);

    for pass_idx in 0..passes {
        // Vamana schedule: with multiple passes, the first prunes with
        // alpha = 1.0 (tight graph) and later passes use the caller's
        // diversifying alpha.
        let pass_alpha = if passes == 1 {
            alpha
        } else if pass_idx == 0 {
            1.0
        } else {
            alpha
        };

        // Visit nodes in a fresh random order each pass.
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);

        for chunk in order.chunks(MICRO_BATCH_CHUNK_SIZE) {
            // All searches within a micro-batch run against the same frozen
            // snapshot of the graph; their updates are merged afterwards.
            let snapshot = &graph;
            let chunk_results: Vec<(usize, Vec<u32>)> = chunk
                .par_iter()
                .map_init(
                    || IncrementalInsertScratch::new(n, build_beam_width, max_degree, extra_seeds),
                    |scratch, &u| {
                        let bs = &mut scratch.build;
                        bs.candidates.clear();

                        // Seed the candidate pool with u's current neighbors.
                        for &nb in &snapshot[u] {
                            let d = dist.eval(vectors.row(u), vectors.row(nb as usize));
                            bs.candidates.push((nb, d));
                        }

                        // Entry points: medoid plus `extra_seeds` distinct
                        // random nodes.
                        bs.seeds.clear();
                        bs.seeds.push(medoid_id as usize);

                        let mut local_rng = thread_rng();
                        while bs.seeds.len() < 1 + extra_seeds {
                            let s = local_rng.gen_range(0..n);
                            if !bs.seeds.contains(&s) {
                                bs.seeds.push(s);
                            }
                        }

                        // Harvest every node visited by a greedy search from
                        // each seed toward u as a pruning candidate.
                        let seeds_len = bs.seeds.len();
                        for si in 0..seeds_len {
                            let start = bs.seeds[si];

                            greedy_search_visited_collect(
                                vectors.row(u),
                                vectors,
                                snapshot,
                                start,
                                build_beam_width,
                                dist,
                                bs,
                            );

                            for i in 0..bs.visited_ids.len() {
                                bs.candidates.push((bs.visited_ids[i], bs.visited_dists[i]));
                            }
                        }

                        dedup_keep_best_by_id_in_place(&mut bs.candidates);

                        let pruned = prune_neighbors(
                            u,
                            &bs.candidates,
                            vectors,
                            max_degree,
                            pass_alpha,
                            dist,
                        );

                        (u, pruned)
                    },
                )
                .collect();

            let mut chunk_nodes = Vec::<usize>::with_capacity(chunk_results.len());
            let mut chunk_pruned = Vec::<Vec<u32>>::with_capacity(chunk_results.len());

            for (u, pruned) in chunk_results {
                chunk_nodes.push(u);
                chunk_pruned.push(pruned);
            }
            merge_chunk_updates_into_graph_reuse(
                &mut graph,
                &chunk_nodes,
                &chunk_pruned,
                vectors,
                max_degree,
                slack_limit,
                pass_alpha,
                dist,
                &mut merge_scratch,
            );
        }
    }

    // Final cleanup: any node still above `max_degree` gets one last prune.
    graph
        .into_par_iter()
        .enumerate()
        .map(|(u, neigh)| {
            if neigh.len() <= max_degree {
                return neigh;
            }

            let mut ids = neigh;
            ids.sort_unstable();
            ids.dedup();

            let pool: Vec<(u32, f32)> = ids
                .into_iter()
                .filter(|&id| id as usize != u)
                .map(|id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
                .collect();

            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
        })
        .collect()
}
1128
// Greedy best-first search used during construction. Unlike the query-time
// search, it records *every* visited node (id + distance) into `scratch`
// so the caller can harvest them as pruning candidates afterwards.
fn greedy_search_visited_collect<T, D>(
    query: &[T],
    vectors: &FlatVectors<T>,
    graph: &[Vec<u32>],
    start_id: usize,
    beam_width: usize,
    dist: D,
    scratch: &mut BuildScratch,
) where
    T: bytemuck::Pod + Copy + Send + Sync,
    D: Distance<T> + Copy,
{
    scratch.reset_search();

    let start_dist = dist.eval(query, vectors.row(start_id));
    let start = Candidate {
        dist: start_dist,
        id: start_id as u32,
    };

    scratch.frontier.insert_unbounded(start);
    scratch.work.insert_capped(start, beam_width);
    scratch.mark_with_dist(start_id, start_dist);

    while !scratch.frontier.is_empty() {
        // Stop once the closest unexpanded node cannot improve the beam.
        let best = scratch.frontier.best().unwrap();
        if scratch.work.len() >= beam_width {
            if let Some(worst) = scratch.work.worst() {
                if best.dist >= worst.dist {
                    break;
                }
            }
        }

        let cur = scratch.frontier.pop_best().unwrap();

        for &nb in &graph[cur.id as usize] {
            let nb_usize = nb as usize;
            if scratch.is_marked(nb_usize) {
                continue;
            }

            let d = dist.eval(query, vectors.row(nb_usize));
            scratch.mark_with_dist(nb_usize, d);

            let cand = Candidate { dist: d, id: nb };

            // Only candidates that fit in (or improve) the beam are queued
            // for expansion.
            if scratch.work.len() < beam_width {
                scratch.work.insert_unbounded(cand);
                scratch.frontier.insert_unbounded(cand);
            } else if let Some(worst) = scratch.work.worst() {
                if d < worst.dist {
                    scratch.work.insert_capped(cand, beam_width);
                    scratch.frontier.insert_unbounded(cand);
                }
            }
        }
    }
}
1193
1194fn prune_neighbors<T, D>(
1196 node_id: usize,
1197 candidates: &[(u32, f32)],
1198 vectors: &FlatVectors<T>,
1199 max_degree: usize,
1200 alpha: f32,
1201 dist: D,
1202) -> Vec<u32>
1203where
1204 T: bytemuck::Pod + Copy + Send + Sync,
1205 D: Distance<T> + Copy,
1206{
1207 if candidates.is_empty() || max_degree == 0 {
1208 return Vec::new();
1209 }
1210
1211 let mut sorted = candidates.to_vec();
1213 sorted.sort_by(|a, b| a.1.total_cmp(&b.1));
1214
1215 let mut uniq = Vec::<(u32, f32)>::with_capacity(sorted.len());
1217 let mut last_id: Option<u32> = None;
1218 for &(cand_id, cand_dist) in &sorted {
1219 if cand_id as usize == node_id {
1220 continue;
1221 }
1222 if last_id == Some(cand_id) {
1223 continue;
1224 }
1225 uniq.push((cand_id, cand_dist));
1226 last_id = Some(cand_id);
1227 }
1228
1229 let mut pruned = Vec::<u32>::with_capacity(max_degree);
1230
1231 for &(cand_id, cand_dist_to_node) in &uniq {
1233 let mut occluded = false;
1234
1235 for &sel_id in &pruned {
1236 let d_cand_sel = dist.eval(
1237 vectors.row(cand_id as usize),
1238 vectors.row(sel_id as usize),
1239 );
1240
1241 if alpha * d_cand_sel <= cand_dist_to_node {
1243 occluded = true;
1244 break;
1245 }
1246 }
1247
1248 if !occluded {
1249 pruned.push(cand_id);
1250 if pruned.len() >= max_degree {
1251 break;
1252 }
1253 }
1254 }
1255
1256 pruned
1257}
1258
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    // Reference Euclidean distance, independent of the index's own functor.
    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    // Build + search round-trip on a tiny L2 dataset.
    #[test]
    fn test_small_index_l2() {
        let path = "test_small_l2.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        // The top hit should be geometrically close to the query.
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);

        let _ = fs::remove_file(path);
    }

    // Cosine distance: the nearest neighbor of a scaled basis vector should
    // be nearly colinear with it.
    #[test]
    fn test_cosine() {
        let path = "test_cosine.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, path).unwrap();

        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // Check cosine similarity of the top hit by hand.
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);

        let _ = fs::remove_file(path);
    }

    // The index survives a close/reopen cycle with metadata intact.
    #[test]
    fn test_persistence_and_open() {
        let path = "test_persist.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Build in an inner scope so the index (and its mmap) is dropped.
        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
        }

        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);

        let _ = fs::remove_file(path);
    }

    // Every grid point should retrieve itself (or a very close point) plus
    // nearby neighbors — a basic connectivity sanity check.
    #[test]
    fn test_grid_connectivity() {
        let path = "test_grid.db";
        let _ = fs::remove_file(path);

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }

        let _ = fs::remove_file(path);
    }

    // Larger random dataset: results must come back sorted by distance.
    #[test]
    fn test_medium_random() {
        let path = "test_medium.db";
        let _ = fs::remove_file(path);

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Distances of the returned ids must be non-decreasing.
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.total_cmp(b));
        assert_eq!(dists, sorted);

        let _ = fs::remove_file(path);
    }
}