1use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Sentinel stored in unused adjacency slots; every on-disk neighbor list is
/// padded to exactly `max_degree` entries with this value.
const PAD_U32: u32 = u32::MAX;

/// Default maximum number of neighbors kept per node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width of the greedy searches run during graph construction.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default pruning factor. A candidate is dropped when `alpha * d(candidate, selected)`
/// is at most its distance to the node, so larger values keep more (longer) edges.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement passes over all nodes.
pub const DISKANN_DEFAULT_PASSES: usize = 2;
/// Default number of random start nodes searched from per node during construction,
/// in addition to the medoid.
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 2;
61
/// Tunable graph-construction parameters accepted by `DiskANN::build_index_with_params`.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum number of neighbors stored per node (also the fixed on-disk list length).
    pub max_degree: usize,
    /// Beam width of the greedy searches used to collect neighbor candidates.
    pub build_beam_width: usize,
    /// Pruning factor; larger values keep more long-range edges.
    pub alpha: f32,
    /// Number of refinement passes (values below 1 are treated as 1).
    pub passes: usize,
    /// Number of random start nodes searched from per node, in addition to the medoid.
    pub extra_seeds: usize,
}
73
74impl Default for DiskAnnParams {
75 fn default() -> Self {
76 Self {
77 max_degree: DISKANN_DEFAULT_MAX_DEGREE,
78 build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
79 alpha: DISKANN_DEFAULT_ALPHA,
80 passes: DISKANN_DEFAULT_PASSES,
81 extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
82 }
83 }
84}
85
/// Errors produced while building or opening an index.
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// An underlying file operation (open, read, write, seek, mmap, sync) failed.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Encoding or decoding the bincode metadata header failed.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// Invalid input data or an index file whose contents do not match expectations.
    #[error("Index error: {0}")]
    IndexError(String),
}
101
/// On-disk index header. It is stored at offset 0 as an 8-byte little-endian length
/// followed by the bincode encoding of this struct.
///
/// Field order is part of the file format: bincode encodes fields positionally,
/// so reordering or retyping fields breaks existing index files.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    // Number of components per vector.
    dim: usize,
    // Number of indexed vectors (= graph nodes).
    num_vectors: usize,
    // Number of u32 slots in every on-disk adjacency list.
    max_degree: usize,
    // Node used as the search entry point.
    medoid_id: u32,
    // Byte offset of the first vector in the file.
    vectors_offset: u64,
    // Byte offset of the first adjacency list in the file.
    adjacency_offset: u64,
    // size_of::<T>() at build time, truncated to u8.
    elem_size: u8,
    // std::any::type_name of the distance type used at build time.
    distance_name: String,
}
114
/// A graph node paired with its distance to the current query, used as a heap entry.
///
/// Ordering is by `dist` under IEEE-754 total ordering (`f32::total_cmp`), with ties
/// broken by `id`. This makes `Ord`, `PartialOrd` and `Eq` mutually consistent and
/// keeps `BinaryHeap` invariants intact even when a distance is NaN.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality agrees with the ordering (and is reflexive for NaN).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
137
/// A disk-resident approximate-nearest-neighbor index over vectors of `T`.
///
/// Vectors and fixed-width adjacency lists are read directly from a read-only memory
/// map of the index file; `D` is the distance used for searching.
pub struct DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
    /// Number of components per vector.
    pub dim: usize,
    /// Number of indexed vectors.
    pub num_vectors: usize,
    /// Number of `u32` slots in every adjacency list.
    pub max_degree: usize,
    /// Distance type name recorded in the index file.
    pub distance_name: String,

    // Search entry point.
    medoid_id: u32,
    // Byte offset of vector data inside `mmap`.
    vectors_offset: u64,
    // Byte offset of adjacency data inside `mmap`.
    adjacency_offset: u64,

    // Read-only mapping of the whole index file.
    mmap: Mmap,

    // Distance function used by searches.
    dist: D,

    // Marks the element type `T` without storing any value of it.
    _phantom: PhantomData<T>,
}
168
169impl<T, D> DiskANN<T, D>
172where
173 T: bytemuck::Pod + Copy + Send + Sync + 'static,
174 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
175{
176 pub fn build_index_default(
178 vectors: &[Vec<T>],
179 dist: D,
180 file_path: &str,
181 ) -> Result<Self, DiskAnnError> {
182 Self::build_index(
183 vectors,
184 DISKANN_DEFAULT_MAX_DEGREE,
185 DISKANN_DEFAULT_BUILD_BEAM,
186 DISKANN_DEFAULT_ALPHA,
187 DISKANN_DEFAULT_PASSES,
188 DISKANN_DEFAULT_EXTRA_SEEDS,
189 dist,
190 file_path,
191 )
192 }
193
194 pub fn build_index_with_params(
196 vectors: &[Vec<T>],
197 dist: D,
198 file_path: &str,
199 p: DiskAnnParams,
200 ) -> Result<Self, DiskAnnError> {
201 Self::build_index(
202 vectors,
203 p.max_degree,
204 p.build_beam_width,
205 p.alpha,
206 p.passes,
207 p.extra_seeds,
208 dist,
209 file_path,
210 )
211 }
212
213 pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
215 let mut file = OpenOptions::new().read(true).write(false).open(path)?;
216
217 let mut buf8 = [0u8; 8];
219 file.seek(SeekFrom::Start(0))?;
220 file.read_exact(&mut buf8)?;
221 let md_len = u64::from_le_bytes(buf8);
222
223 let mut md_bytes = vec![0u8; md_len as usize];
225 file.read_exact(&mut md_bytes)?;
226 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
227
228 let mmap = unsafe { memmap2::Mmap::map(&file)? };
229
230 let want = std::mem::size_of::<T>() as u8;
232 if metadata.elem_size != want {
233 return Err(DiskAnnError::IndexError(format!(
234 "element size mismatch: file has {}B, T is {}B",
235 metadata.elem_size, want
236 )));
237 }
238
239 let expected = std::any::type_name::<D>();
241 if metadata.distance_name != expected {
242 eprintln!(
243 "Warning: index recorded distance `{}` but you opened with `{}`",
244 metadata.distance_name, expected
245 );
246 }
247
248 Ok(Self {
249 dim: metadata.dim,
250 num_vectors: metadata.num_vectors,
251 max_degree: metadata.max_degree,
252 distance_name: metadata.distance_name,
253 medoid_id: metadata.medoid_id,
254 vectors_offset: metadata.vectors_offset,
255 adjacency_offset: metadata.adjacency_offset,
256 mmap,
257 dist,
258 _phantom: PhantomData,
259 })
260 }
261}
262
263impl<T, D> DiskANN<T, D>
265where
266 T: bytemuck::Pod + Copy + Send + Sync + 'static,
267 D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
268{
269 pub fn build_index_default_metric(
271 vectors: &[Vec<T>],
272 file_path: &str,
273 ) -> Result<Self, DiskAnnError> {
274 Self::build_index_default(vectors, D::default(), file_path)
275 }
276
277 pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
279 Self::open_index_with(path, D::default())
280 }
281}
282
283impl<T, D> DiskANN<T, D>
284where
285 T: bytemuck::Pod + Copy + Send + Sync + 'static,
286 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
287{
288 pub fn build_index(
300 vectors: &[Vec<T>],
301 max_degree: usize,
302 build_beam_width: usize,
303 alpha: f32,
304 passes: usize,
305 extra_seeds: usize,
306 dist: D,
307 file_path: &str,
308 ) -> Result<Self, DiskAnnError> {
309 if vectors.is_empty() {
310 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
311 }
312
313 let num_vectors = vectors.len();
314 let dim = vectors[0].len();
315 for (i, v) in vectors.iter().enumerate() {
316 if v.len() != dim {
317 return Err(DiskAnnError::IndexError(format!(
318 "Vector {} has dimension {} but expected {}",
319 i,
320 v.len(),
321 dim
322 )));
323 }
324 }
325
326 let mut file = OpenOptions::new()
327 .create(true)
328 .write(true)
329 .read(true)
330 .truncate(true)
331 .open(file_path)?;
332
333 let vectors_offset = 1024 * 1024;
335 assert_eq!(
337 (vectors_offset as usize) % std::mem::align_of::<T>(),
338 0,
339 "vectors_offset must be aligned for T"
340 );
341
342 let elem_sz = std::mem::size_of::<T>() as u64;
343 let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
344
345 file.seek(SeekFrom::Start(vectors_offset as u64))?;
347 for vector in vectors {
348 let bytes = bytemuck::cast_slice::<T, u8>(vector);
349 file.write_all(bytes)?;
350 }
351
352 let medoid_id = calculate_medoid(vectors, dist);
354
355 let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
357 let graph = build_vamana_graph(
358 vectors,
359 max_degree,
360 build_beam_width,
361 alpha,
362 passes,
363 extra_seeds,
364 dist,
365 medoid_id as u32,
366 );
367
368 file.seek(SeekFrom::Start(adjacency_offset))?;
370 for neighbors in &graph {
371 let mut padded = neighbors.clone();
372 padded.resize(max_degree, PAD_U32);
373 let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
374 file.write_all(bytes)?;
375 }
376
377 let metadata = Metadata {
379 dim,
380 num_vectors,
381 max_degree,
382 medoid_id: medoid_id as u32,
383 vectors_offset: vectors_offset as u64,
384 adjacency_offset,
385 elem_size: std::mem::size_of::<T>() as u8,
386 distance_name: std::any::type_name::<D>().to_string(),
387 };
388
389 let md_bytes = bincode::serialize(&metadata)?;
390 file.seek(SeekFrom::Start(0))?;
391 let md_len = md_bytes.len() as u64;
392 file.write_all(&md_len.to_le_bytes())?;
393 file.write_all(&md_bytes)?;
394 file.sync_all()?;
395
396 let mmap = unsafe { memmap2::Mmap::map(&file)? };
398
399 Ok(Self {
400 dim,
401 num_vectors,
402 max_degree,
403 distance_name: metadata.distance_name,
404 medoid_id: metadata.medoid_id,
405 vectors_offset: metadata.vectors_offset,
406 adjacency_offset: metadata.adjacency_offset,
407 mmap,
408 dist,
409 _phantom: PhantomData,
410 })
411 }
412
413 pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
417 assert_eq!(
418 query.len(),
419 self.dim,
420 "Query dim {} != index dim {}",
421 query.len(),
422 self.dim
423 );
424
425 let mut visited = HashSet::new();
426 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = self.distance_to(query, self.medoid_id as usize);
431 let start = Candidate {
432 dist: start_dist,
433 id: self.medoid_id,
434 };
435 frontier.push(Reverse(start));
436 w.push(start);
437 visited.insert(self.medoid_id);
438
439 while let Some(Reverse(best)) = frontier.peek().copied() {
441 if w.len() >= beam_width {
442 if let Some(worst) = w.peek() {
443 if best.dist >= worst.dist {
444 break;
445 }
446 }
447 }
448 let Reverse(current) = frontier.pop().unwrap();
449
450 for &nb in self.get_neighbors(current.id) {
451 if nb == PAD_U32 {
452 continue;
453 }
454 if !visited.insert(nb) {
455 continue;
456 }
457
458 let d = self.distance_to(query, nb as usize);
459 let cand = Candidate { dist: d, id: nb };
460
461 if w.len() < beam_width {
462 w.push(cand);
463 frontier.push(Reverse(cand));
464 } else if d < w.peek().unwrap().dist {
465 w.pop();
466 w.push(cand);
467 frontier.push(Reverse(cand));
468 }
469 }
470 }
471
472 let mut results: Vec<_> = w.into_vec();
474 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
475 results.truncate(k);
476 results.into_iter().map(|c| (c.id, c.dist)).collect()
477 }
478
479 pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
481 self.search_with_dists(query, k, beam_width)
482 .into_iter()
483 .map(|(id, _dist)| id)
484 .collect()
485 }
486
487 fn get_neighbors(&self, node_id: u32) -> &[u32] {
489 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
490 let start = offset as usize;
491 let end = start + (self.max_degree * 4);
492 let bytes = &self.mmap[start..end];
493 bytemuck::cast_slice(bytes)
494 }
495
496 fn distance_to(&self, query: &[T], idx: usize) -> f32 {
498 let elem_sz = std::mem::size_of::<T>();
499 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
500 let start = offset as usize;
501 let end = start + (self.dim * elem_sz);
502 let bytes = &self.mmap[start..end];
503 let vector: &[T] = bytemuck::cast_slice(bytes);
504 self.dist.eval(query, vector)
505 }
506
507 pub fn get_vector(&self, idx: usize) -> Vec<T> {
509 let elem_sz = std::mem::size_of::<T>();
510 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
511 let start = offset as usize;
512 let end = start + (self.dim * elem_sz);
513 let bytes = &self.mmap[start..end];
514 let vector: &[T] = bytemuck::cast_slice(bytes);
515 vector.to_vec()
516 }
517}
518
519fn calculate_medoid<T, D>(vectors: &[Vec<T>], dist: D) -> usize
522where
523 T: bytemuck::Pod + Copy + Send + Sync,
524 D: Distance<T> + Copy + Sync,
525{
526 let n = vectors.len();
527 let k = 8.min(n); let mut rng = thread_rng();
529 let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
530
531 let (best_idx, _best_score) = (0..n)
532 .into_par_iter()
533 .map(|i| {
534 let score: f32 = pivots
535 .iter()
536 .map(|&p| dist.eval(&vectors[i], &vectors[p]))
537 .sum();
538 (i, score)
539 })
540 .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
541
542 best_idx
543}
544
/// Builds the Vamana proximity graph in memory and returns one neighbor list per vector
/// (list `i` belongs to `vectors[i]`).
///
/// Steps:
/// 1. Random initialization: every node gets `max(max_degree / 2, 2)` distinct random
///    neighbors, capped at `n - 1`.
/// 2. `max(passes, 1)` refinement passes. Each pass visits the nodes in a fresh random
///    order and works in parallel against a read-only snapshot of the previous graph:
///    - build each node's candidate pool from its current neighbors plus greedy
///      searches started at the medoid and at `extra_seeds` random nodes;
///    - deduplicate the pool, keeping the smallest distance per id, and prune it
///      with `prune_neighbors`;
///    - merge each node's new list with its incoming edges (from
///      `build_incoming_csr`) and prune again.
/// 3. Any list still longer than `max_degree` is pruned once more.
///
/// Uses `thread_rng`, so the result is nondeterministic.
fn build_vamana_graph<T, D>(
    vectors: &[Vec<T>],
    max_degree: usize,
    build_beam_width: usize,
    alpha: f32,
    passes: usize,
    extra_seeds: usize,
    dist: D,
    medoid_id: u32,
) -> Vec<Vec<u32>>
where
    T: bytemuck::Pod + Copy + Send + Sync,
    D: Distance<T> + Copy + Sync,
{
    let n = vectors.len();
    let mut graph = vec![Vec::<u32>::new(); n];

    // Random initialization: connect each node to `target` distinct random other nodes.
    {
        let mut rng = thread_rng();
        for i in 0..n {
            let mut s = HashSet::new();
            // At most n - 1 distinct other nodes exist, so the loop below terminates.
            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
            while s.len() < target {
                let nb = rng.gen_range(0..n);
                if nb != i {
                    s.insert(nb as u32);
                }
            }
            graph[i] = s.into_iter().collect();
        }
    }

    let passes = passes.max(1);
    let mut rng = thread_rng();
    for _pass in 0..passes {
        // Fresh random visiting order each pass; `new_graph[pos]` belongs to node `order[pos]`.
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);

        // Every node in this pass reads the previous pass's graph.
        let snapshot = &graph;

        let new_graph: Vec<Vec<u32>> = order
            .par_iter()
            .map(|&u| {
                let mut candidates: Vec<(u32, f32)> =
                    Vec::with_capacity(build_beam_width * (2 + extra_seeds));

                // Current neighbors are always candidates.
                for &nb in &snapshot[u] {
                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
                    candidates.push((nb, d));
                }

                // Search from the medoid plus `extra_seeds` random start nodes.
                let mut seeds = Vec::with_capacity(1 + extra_seeds);
                seeds.push(medoid_id as usize);
                let mut trng = thread_rng();
                for _ in 0..extra_seeds {
                    seeds.push(trng.gen_range(0..n));
                }

                for start in seeds {
                    let mut part = greedy_search(
                        &vectors[u],
                        vectors,
                        snapshot,
                        start,
                        build_beam_width,
                        dist,
                    );
                    candidates.append(&mut part);
                }

                // Sort by id, then collapse duplicate ids keeping the smallest distance.
                // `dedup_by` passes the later element as `a` and the retained earlier one as `b`.
                candidates.sort_by(|a, b| a.0.cmp(&b.0));
                candidates.dedup_by(|a, b| {
                    if a.0 == b.0 {
                        if a.1 < b.1 {
                            *b = *a;
                        }
                        true
                    } else {
                        false
                    }
                });

                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
            })
            .collect();

        // Inverse of `order`: node id -> its position in `new_graph`.
        let mut pos_of = vec![0usize; n];
        for (pos, &u) in order.iter().enumerate() {
            pos_of[u] = pos;
        }

        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);

        // Merge each node's outgoing list with its incoming edges, then prune again.
        graph = (0..n)
            .into_par_iter()
            .map(|u| {
                let ng = &new_graph[pos_of[u]];
                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]];
                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
                pool_ids.extend_from_slice(ng);
                pool_ids.extend_from_slice(inc);
                pool_ids.sort_unstable();
                pool_ids.dedup();

                let pool: Vec<(u32, f32)> = pool_ids
                    .into_iter()
                    .filter(|&id| id as usize != u)
                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                    .collect();

                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
            })
            .collect();
    }

    // Safety net: re-prune any list that still exceeds `max_degree`.
    graph
        .into_par_iter()
        .enumerate()
        .map(|(u, neigh)| {
            if neigh.len() <= max_degree {
                return neigh;
            }
            let pool: Vec<(u32, f32)> = neigh
                .iter()
                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                .collect();
            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
        })
        .collect()
}
696
697fn greedy_search<T, D>(
700 query: &[T],
701 vectors: &[Vec<T>],
702 graph: &[Vec<u32>],
703 start_id: usize,
704 beam_width: usize,
705 dist: D,
706) -> Vec<(u32, f32)>
707where
708 T: bytemuck::Pod + Copy + Send + Sync,
709 D: Distance<T> + Copy,
710{
711 let mut visited = HashSet::new();
712 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = dist.eval(query, &vectors[start_id]);
716 let start = Candidate {
717 dist: start_dist,
718 id: start_id as u32,
719 };
720 frontier.push(Reverse(start));
721 w.push(start);
722 visited.insert(start_id as u32);
723
724 while let Some(Reverse(best)) = frontier.peek().copied() {
725 if w.len() >= beam_width {
726 if let Some(worst) = w.peek() {
727 if best.dist >= worst.dist {
728 break;
729 }
730 }
731 }
732 let Reverse(cur) = frontier.pop().unwrap();
733
734 for &nb in &graph[cur.id as usize] {
735 if !visited.insert(nb) {
736 continue;
737 }
738 let d = dist.eval(query, &vectors[nb as usize]);
739 let cand = Candidate { dist: d, id: nb };
740
741 if w.len() < beam_width {
742 w.push(cand);
743 frontier.push(Reverse(cand));
744 } else if d < w.peek().unwrap().dist {
745 w.pop();
746 w.push(cand);
747 frontier.push(Reverse(cand));
748 }
749 }
750 }
751
752 let mut v = w.into_vec();
753 v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
754 v.into_iter().map(|c| (c.id, c.dist)).collect()
755}
756
757fn prune_neighbors<T, D>(
759 node_id: usize,
760 candidates: &[(u32, f32)],
761 vectors: &[Vec<T>],
762 max_degree: usize,
763 alpha: f32,
764 dist: D,
765) -> Vec<u32>
766where
767 T: bytemuck::Pod + Copy + Send + Sync,
768 D: Distance<T> + Copy,
769{
770 if candidates.is_empty() {
771 return Vec::new();
772 }
773
774 let mut sorted = candidates.to_vec();
775 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
776
777 let mut pruned = Vec::<u32>::new();
778
779 for &(cand_id, cand_dist) in &sorted {
780 if cand_id as usize == node_id {
781 continue;
782 }
783 let mut ok = true;
784 for &sel in &pruned {
785 let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
786 if alpha * d <= cand_dist {
789 ok = false;
790 break;
791 }
792 }
793 if ok {
794 pruned.push(cand_id);
795 if pruned.len() >= max_degree {
796 break;
797 }
798 }
799 }
800
801 for &(cand_id, _) in &sorted {
803 if cand_id as usize == node_id {
804 continue;
805 }
806 if !pruned.contains(&cand_id) {
807 pruned.push(cand_id);
808 if pruned.len() >= max_degree {
809 break;
810 }
811 }
812 }
813
814 pruned
815}
816
/// Inverts the per-pass outgoing lists into a CSR (compressed sparse row) table of
/// incoming edges.
///
/// `new_graph[pos]` holds the outgoing neighbors of node `order[pos]`. Returns
/// `(sources, offsets)`, where `offsets` has `n + 1` entries and
/// `sources[offsets[v]..offsets[v + 1]]` lists every node with an edge to `v`, in
/// `order` sequence.
///
/// Panics if `new_graph` is shorter than `order` or if any target is `>= n`.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    let edges = &new_graph[..order.len()];

    // Count in-degrees one slot to the right, then prefix-sum into start offsets.
    let mut offsets = vec![0usize; n + 1];
    for &target in edges.iter().flatten() {
        offsets[target as usize + 1] += 1;
    }
    for i in 1..=n {
        offsets[i] += offsets[i - 1];
    }

    // Scatter each source into its target's bucket, advancing a per-target cursor.
    let mut cursor = offsets[..n].to_vec();
    let mut sources = vec![0u32; offsets[n]];
    for (&source, targets) in order.iter().zip(edges) {
        for &target in targets {
            let slot = &mut cursor[target as usize];
            sources[*slot] = source as u32;
            *slot += 1;
        }
    }

    (sources, offsets)
}
842
// Integration-style tests: each builds a real index file and searches it.
// NOTE(review): index files are created in the current working directory under fixed
// names and are not removed if an assertion fails first.
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Reference Euclidean distance, used to check returned vectors independently of the index.
    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    /// Five 2-D points: asks for 3 neighbors and checks the top hit is near the query.
    #[test]
    fn test_small_index_l2() {
        let path = "test_small_l2.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);

        let _ = fs::remove_file(path);
    }

    /// Cosine metric: a query along the x axis should return a vector with high cosine similarity.
    #[test]
    fn test_cosine() {
        let path = "test_cosine.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, path).unwrap();

        // Not unit length: cosine should ignore magnitude.
        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);

        let _ = fs::remove_file(path);
    }

    /// Builds, drops, then reopens from disk and checks metadata and a search result.
    #[test]
    fn test_persistence_and_open() {
        let path = "test_persist.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Inner scope: the built index (and its mmap) is dropped before reopening.
        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
        }

        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        // (0.9, 0.9) is closest to vector 3 = (1.0, 1.0).
        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);

        let _ = fs::remove_file(path);
    }

    /// 5x5 grid with a small degree: every grid point should be found (or something close).
    #[test]
    fn test_grid_connectivity() {
        let path = "test_grid.db";
        let _ = fs::remove_file(path);

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            // If the exact point is missed, the best hit must still be nearby.
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }

        let _ = fs::remove_file(path);
    }

    /// 200 random 32-D points: returns 10 results sorted by true Euclidean distance.
    #[test]
    fn test_medium_random() {
        let path = "test_medium.db";
        let _ = fs::remove_file(path);

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Results must come back in nondecreasing distance order.
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);

        let _ = fs::remove_file(path);
    }
}