1use anndists::prelude::Distance;
39use memmap2::Mmap;
40use rand::{prelude::*, thread_rng};
41use rayon::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::cmp::{Ordering, Reverse};
44use std::collections::{BinaryHeap, HashSet};
45use std::fs::OpenOptions;
46use std::io::{Read, Seek, SeekFrom, Write};
47use std::marker::PhantomData;
48use thiserror::Error;
49
/// Sentinel stored in unused adjacency slots; every adjacency row is padded to
/// `max_degree` entries with this value.
const PAD_U32: u32 = u32::MAX;

/// Default maximum number of out-neighbours kept per node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width of the greedy searches run while building the graph.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default `alpha` passed to neighbour pruning during graph construction.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2_f32;
57
/// Graph-construction parameters accepted by `DiskANN::build_index_with_params`.
///
/// `PartialEq` is derived so callers can compare or assert on configurations.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DiskAnnParams {
    /// Maximum number of out-neighbours stored per node (each adjacency row has this width).
    pub max_degree: usize,
    /// Beam (candidate-list) width of the greedy searches run during construction.
    pub build_beam_width: usize,
    /// Pruning parameter forwarded to neighbour pruning.
    pub alpha: f32,
}
65impl Default for DiskAnnParams {
66 fn default() -> Self {
67 Self {
68 max_degree: DISKANN_DEFAULT_MAX_DEGREE,
69 build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
70 alpha: DISKANN_DEFAULT_ALPHA,
71 }
72 }
73}
74
75#[derive(Debug, Error)]
77pub enum DiskAnnError {
78 #[error("I/O error: {0}")]
80 Io(#[from] std::io::Error),
81
82 #[error("Serialization error: {0}")]
84 Bincode(#[from] bincode::Error),
85
86 #[error("Index error: {0}")]
88 IndexError(String),
89}
90
91#[derive(Serialize, Deserialize, Debug)]
93struct Metadata {
94 dim: usize,
95 num_vectors: usize,
96 max_degree: usize,
97 medoid_id: u32,
98 vectors_offset: u64,
99 adjacency_offset: u64,
100 elem_size: u8,
101 distance_name: String,
102}
103
/// A graph node paired with its distance to the current query.
///
/// Candidates are ordered by distance first, then by id. The distance is compared with
/// `f32::total_cmp`, so NaN distances cannot corrupt heap ordering. The id tie-break keeps
/// `Ord` consistent with `Eq` (`a == b` iff `a.cmp(b) == Equal`), which `BinaryHeap` relies on.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality is reflexive even for NaN distances.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
126
127pub struct DiskANN<T, D>
129where
130 T: bytemuck::Pod + Copy + Send + Sync + 'static,
131 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
132{
133 pub dim: usize,
135 pub num_vectors: usize,
137 pub max_degree: usize,
139 pub distance_name: String,
141
142 medoid_id: u32,
144 vectors_offset: u64,
146 adjacency_offset: u64,
147
148 mmap: Mmap,
150
151 dist: D,
153
154 _phantom: PhantomData<T>,
156}
157
158impl<T, D> DiskANN<T, D>
161where
162 T: bytemuck::Pod + Copy + Send + Sync + 'static,
163 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
164{
165 pub fn build_index_default(
167 vectors: &[Vec<T>],
168 dist: D,
169 file_path: &str,
170 ) -> Result<Self, DiskAnnError> {
171 Self::build_index(
172 vectors,
173 DISKANN_DEFAULT_MAX_DEGREE,
174 DISKANN_DEFAULT_BUILD_BEAM,
175 DISKANN_DEFAULT_ALPHA,
176 dist,
177 file_path,
178 )
179 }
180
181 pub fn build_index_with_params(
183 vectors: &[Vec<T>],
184 dist: D,
185 file_path: &str,
186 p: DiskAnnParams,
187 ) -> Result<Self, DiskAnnError> {
188 Self::build_index(
189 vectors,
190 p.max_degree,
191 p.build_beam_width,
192 p.alpha,
193 dist,
194 file_path,
195 )
196 }
197
198 pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
200 let mut file = OpenOptions::new().read(true).write(false).open(path)?;
201
202 let mut buf8 = [0u8; 8];
204 file.seek(SeekFrom::Start(0))?;
205 file.read_exact(&mut buf8)?;
206 let md_len = u64::from_le_bytes(buf8);
207
208 let mut md_bytes = vec![0u8; md_len as usize];
210 file.read_exact(&mut md_bytes)?;
211 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
212
213 let mmap = unsafe { memmap2::Mmap::map(&file)? };
214
215 let want = std::mem::size_of::<T>() as u8;
217 if metadata.elem_size != want {
218 return Err(DiskAnnError::IndexError(format!(
219 "element size mismatch: file has {}B, T is {}B",
220 metadata.elem_size, want
221 )));
222 }
223
224 let expected = std::any::type_name::<D>();
226 if metadata.distance_name != expected {
227 eprintln!(
228 "Warning: index recorded distance `{}` but you opened with `{}`",
229 metadata.distance_name, expected
230 );
231 }
232
233 Ok(Self {
234 dim: metadata.dim,
235 num_vectors: metadata.num_vectors,
236 max_degree: metadata.max_degree,
237 distance_name: metadata.distance_name,
238 medoid_id: metadata.medoid_id,
239 vectors_offset: metadata.vectors_offset,
240 adjacency_offset: metadata.adjacency_offset,
241 mmap,
242 dist,
243 _phantom: PhantomData,
244 })
245 }
246}
247
248impl<T, D> DiskANN<T, D>
250where
251 T: bytemuck::Pod + Copy + Send + Sync + 'static,
252 D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
253{
254 pub fn build_index_default_metric(
256 vectors: &[Vec<T>],
257 file_path: &str,
258 ) -> Result<Self, DiskAnnError> {
259 Self::build_index_default(vectors, D::default(), file_path)
260 }
261
262 pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
264 Self::open_index_with(path, D::default())
265 }
266}
267
268impl<T, D> DiskANN<T, D>
269where
270 T: bytemuck::Pod + Copy + Send + Sync + 'static,
271 D: Distance<T> + Send + Sync + Copy + Clone + 'static,
272{
273 pub fn build_index(
283 vectors: &[Vec<T>],
284 max_degree: usize,
285 build_beam_width: usize,
286 alpha: f32,
287 dist: D,
288 file_path: &str,
289 ) -> Result<Self, DiskAnnError> {
290 if vectors.is_empty() {
291 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
292 }
293
294 let num_vectors = vectors.len();
295 let dim = vectors[0].len();
296 for (i, v) in vectors.iter().enumerate() {
297 if v.len() != dim {
298 return Err(DiskAnnError::IndexError(format!(
299 "Vector {} has dimension {} but expected {}",
300 i,
301 v.len(),
302 dim
303 )));
304 }
305 }
306
307 let mut file = OpenOptions::new()
308 .create(true)
309 .write(true)
310 .read(true)
311 .truncate(true)
312 .open(file_path)?;
313
314 let vectors_offset = 1024 * 1024;
316 assert_eq!(
318 (vectors_offset as usize) % std::mem::align_of::<T>(),
319 0,
320 "vectors_offset must be aligned for T"
321 );
322
323 let elem_sz = std::mem::size_of::<T>() as u64;
324 let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
325
326 file.seek(SeekFrom::Start(vectors_offset as u64))?;
328 for vector in vectors {
329 let bytes = bytemuck::cast_slice::<T, u8>(vector);
330 file.write_all(bytes)?;
331 }
332
333 let medoid_id = calculate_medoid(vectors, dist);
335
336 let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
338 let graph = build_vamana_graph(
339 vectors,
340 max_degree,
341 build_beam_width,
342 alpha,
343 dist,
344 medoid_id as u32,
345 );
346
347 file.seek(SeekFrom::Start(adjacency_offset))?;
349 for neighbors in &graph {
350 let mut padded = neighbors.clone();
351 padded.resize(max_degree, PAD_U32);
352 let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
353 file.write_all(bytes)?;
354 }
355
356 let metadata = Metadata {
358 dim,
359 num_vectors,
360 max_degree,
361 medoid_id: medoid_id as u32,
362 vectors_offset: vectors_offset as u64,
363 adjacency_offset,
364 elem_size: std::mem::size_of::<T>() as u8,
365 distance_name: std::any::type_name::<D>().to_string(),
366 };
367
368 let md_bytes = bincode::serialize(&metadata)?;
369 file.seek(SeekFrom::Start(0))?;
370 let md_len = md_bytes.len() as u64;
371 file.write_all(&md_len.to_le_bytes())?;
372 file.write_all(&md_bytes)?;
373 file.sync_all()?;
374
375 let mmap = unsafe { memmap2::Mmap::map(&file)? };
377
378 Ok(Self {
379 dim,
380 num_vectors,
381 max_degree,
382 distance_name: metadata.distance_name,
383 medoid_id: metadata.medoid_id,
384 vectors_offset: metadata.vectors_offset,
385 adjacency_offset: metadata.adjacency_offset,
386 mmap,
387 dist,
388 _phantom: PhantomData,
389 })
390 }
391
392 pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
396 assert_eq!(
397 query.len(),
398 self.dim,
399 "Query dim {} != index dim {}",
400 query.len(),
401 self.dim
402 );
403
404 let mut visited = HashSet::new();
405 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = self.distance_to(query, self.medoid_id as usize);
410 let start = Candidate {
411 dist: start_dist,
412 id: self.medoid_id,
413 };
414 frontier.push(Reverse(start));
415 w.push(start);
416 visited.insert(self.medoid_id);
417
418 while let Some(Reverse(best)) = frontier.peek().copied() {
420 if w.len() >= beam_width {
421 if let Some(worst) = w.peek() {
422 if best.dist >= worst.dist {
423 break;
424 }
425 }
426 }
427 let Reverse(current) = frontier.pop().unwrap();
428
429 for &nb in self.get_neighbors(current.id) {
430 if nb == PAD_U32 {
431 continue;
432 }
433 if !visited.insert(nb) {
434 continue;
435 }
436
437 let d = self.distance_to(query, nb as usize);
438 let cand = Candidate { dist: d, id: nb };
439
440 if w.len() < beam_width {
441 w.push(cand);
442 frontier.push(Reverse(cand));
443 } else if d < w.peek().unwrap().dist {
444 w.pop();
445 w.push(cand);
446 frontier.push(Reverse(cand));
447 }
448 }
449 }
450
451 let mut results: Vec<_> = w.into_vec();
453 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
454 results.truncate(k);
455 results.into_iter().map(|c| (c.id, c.dist)).collect()
456 }
457
458 pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
460 self.search_with_dists(query, k, beam_width)
461 .into_iter()
462 .map(|(id, _dist)| id)
463 .collect()
464 }
465
466 fn get_neighbors(&self, node_id: u32) -> &[u32] {
468 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
469 let start = offset as usize;
470 let end = start + (self.max_degree * 4);
471 let bytes = &self.mmap[start..end];
472 bytemuck::cast_slice(bytes)
473 }
474
475 fn distance_to(&self, query: &[T], idx: usize) -> f32 {
477 let elem_sz = std::mem::size_of::<T>();
478 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
479 let start = offset as usize;
480 let end = start + (self.dim * elem_sz);
481 let bytes = &self.mmap[start..end];
482 let vector: &[T] = bytemuck::cast_slice(bytes);
483 self.dist.eval(query, vector)
484 }
485
486 pub fn get_vector(&self, idx: usize) -> Vec<T> {
488 let elem_sz = std::mem::size_of::<T>();
489 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
490 let start = offset as usize;
491 let end = start + (self.dim * elem_sz);
492 let bytes = &self.mmap[start..end];
493 let vector: &[T] = bytemuck::cast_slice(bytes);
494 vector.to_vec()
495 }
496}
497
498fn calculate_medoid<T, D>(vectors: &[Vec<T>], dist: D) -> usize
501where
502 T: bytemuck::Pod + Copy + Send + Sync,
503 D: Distance<T> + Copy + Sync,
504{
505 let n = vectors.len();
506 let k = 8.min(n); let mut rng = thread_rng();
508 let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
509
510 let (best_idx, _best_score) = (0..n)
511 .into_par_iter()
512 .map(|i| {
513 let score: f32 = pivots
514 .iter()
515 .map(|&p| dist.eval(&vectors[i], &vectors[p]))
516 .sum();
517 (i, score)
518 })
519 .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
520
521 best_idx
522}
523
524fn build_vamana_graph<T, D>(
529 vectors: &[Vec<T>],
530 max_degree: usize,
531 build_beam_width: usize,
532 alpha: f32,
533 dist: D,
534 medoid_id: u32,
535) -> Vec<Vec<u32>>
536where
537 T: bytemuck::Pod + Copy + Send + Sync,
538 D: Distance<T> + Copy + Sync,
539{
540 let n = vectors.len();
541 let mut graph = vec![Vec::<u32>::new(); n];
542
543 {
545 let mut rng = thread_rng();
546 for i in 0..n {
547 let mut s = HashSet::new();
548 let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
549 while s.len() < target {
550 let nb = rng.gen_range(0..n);
551 if nb != i {
552 s.insert(nb as u32);
553 }
554 }
555 graph[i] = s.into_iter().collect();
556 }
557 }
558
559 const PASSES: usize = 2;
561 const EXTRA_SEEDS: usize = 2;
562
563 let mut rng = thread_rng();
564 for _pass in 0..PASSES {
565 let mut order: Vec<usize> = (0..n).collect();
567 order.shuffle(&mut rng);
568
569 let snapshot = &graph;
571
572 let new_graph: Vec<Vec<u32>> = order
574 .par_iter()
575 .map(|&u| {
576 let mut candidates: Vec<(u32, f32)> =
577 Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));
578
579 for &nb in &snapshot[u] {
581 let d = dist.eval(&vectors[u], &vectors[nb as usize]);
582 candidates.push((nb, d));
583 }
584
585 let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
587 seeds.push(medoid_id as usize);
588 let mut trng = thread_rng();
589 for _ in 0..EXTRA_SEEDS {
590 seeds.push(trng.gen_range(0..n));
591 }
592
593 for start in seeds {
595 let mut part =
596 greedy_search(&vectors[u], vectors, snapshot, start, build_beam_width, dist);
597 candidates.append(&mut part);
598 }
599
600 candidates.sort_by(|a, b| a.0.cmp(&b.0));
602 candidates.dedup_by(|a, b| {
603 if a.0 == b.0 {
604 if a.1 < b.1 {
605 *b = *a;
606 }
607 true
608 } else {
609 false
610 }
611 });
612
613 prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
615 })
616 .collect();
617
618 let mut pos_of = vec![0usize; n];
620 for (pos, &u) in order.iter().enumerate() {
621 pos_of[u] = pos;
622 }
623
624 let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);
626
627 graph = (0..n)
629 .into_par_iter()
630 .map(|u| {
631 let ng = &new_graph[pos_of[u]]; let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]]; let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
636 pool_ids.extend_from_slice(ng);
637 pool_ids.extend_from_slice(inc);
638 pool_ids.sort_unstable();
639 pool_ids.dedup();
640
641 let pool: Vec<(u32, f32)> = pool_ids
643 .into_iter()
644 .filter(|&id| id as usize != u)
645 .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
646 .collect();
647
648 prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
649 })
650 .collect();
651 }
652
653 graph
655 .into_par_iter()
656 .enumerate()
657 .map(|(u, neigh)| {
658 if neigh.len() <= max_degree {
659 return neigh;
660 }
661 let pool: Vec<(u32, f32)> = neigh
662 .iter()
663 .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
664 .collect();
665 prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
666 })
667 .collect()
668}
669
670fn greedy_search<T, D>(
673 query: &[T],
674 vectors: &[Vec<T>],
675 graph: &[Vec<u32>],
676 start_id: usize,
677 beam_width: usize,
678 dist: D,
679) -> Vec<(u32, f32)>
680where
681 T: bytemuck::Pod + Copy + Send + Sync,
682 D: Distance<T> + Copy,
683{
684 let mut visited = HashSet::new();
685 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = dist.eval(query, &vectors[start_id]);
689 let start = Candidate {
690 dist: start_dist,
691 id: start_id as u32,
692 };
693 frontier.push(Reverse(start));
694 w.push(start);
695 visited.insert(start_id as u32);
696
697 while let Some(Reverse(best)) = frontier.peek().copied() {
698 if w.len() >= beam_width {
699 if let Some(worst) = w.peek() {
700 if best.dist >= worst.dist {
701 break;
702 }
703 }
704 }
705 let Reverse(cur) = frontier.pop().unwrap();
706
707 for &nb in &graph[cur.id as usize] {
708 if !visited.insert(nb) {
709 continue;
710 }
711 let d = dist.eval(query, &vectors[nb as usize]);
712 let cand = Candidate { dist: d, id: nb };
713
714 if w.len() < beam_width {
715 w.push(cand);
716 frontier.push(Reverse(cand));
717 } else if d < w.peek().unwrap().dist {
718 w.pop();
719 w.push(cand);
720 frontier.push(Reverse(cand));
721 }
722 }
723 }
724
725 let mut v = w.into_vec();
726 v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
727 v.into_iter().map(|c| (c.id, c.dist)).collect()
728}
729
730fn prune_neighbors<T, D>(
732 node_id: usize,
733 candidates: &[(u32, f32)],
734 vectors: &[Vec<T>],
735 max_degree: usize,
736 alpha: f32,
737 dist: D,
738) -> Vec<u32>
739where
740 T: bytemuck::Pod + Copy + Send + Sync,
741 D: Distance<T> + Copy,
742{
743 if candidates.is_empty() {
744 return Vec::new();
745 }
746
747 let mut sorted = candidates.to_vec();
748 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
749
750 let mut pruned = Vec::<u32>::new();
751
752 for &(cand_id, cand_dist) in &sorted {
753 if cand_id as usize == node_id {
754 continue;
755 }
756 let mut ok = true;
757 for &sel in &pruned {
758 let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
759 if d < alpha * cand_dist {
760 ok = false;
761 break;
762 }
763 }
764 if ok {
765 pruned.push(cand_id);
766 if pruned.len() >= max_degree {
767 break;
768 }
769 }
770 }
771
772 for &(cand_id, _) in &sorted {
774 if cand_id as usize == node_id {
775 continue;
776 }
777 if !pruned.contains(&cand_id) {
778 pruned.push(cand_id);
779 if pruned.len() >= max_degree {
780 break;
781 }
782 }
783 }
784
785 pruned
786}
787
/// Inverts a batch of out-edge lists into CSR form.
///
/// `new_graph[pos]` holds the out-edges of node `order[pos]`. Returns `(flat, offsets)`:
/// `flat[offsets[v]..offsets[v + 1]]` lists every source node with an edge into `v`, in
/// `order` order. `offsets` has length `n + 1`.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    // offsets[v + 1] first counts the in-degree of v; a prefix sum turns counts into starts.
    let mut offsets = vec![0usize; n + 1];
    for pos in 0..order.len() {
        for &target in &new_graph[pos] {
            let count = &mut offsets[target as usize + 1];
            *count += 1;
        }
    }
    for slot in 1..=n {
        offsets[slot] += offsets[slot - 1];
    }

    // Scatter each source id into the next free slot of its target's range.
    let mut cursor = offsets.clone();
    let mut flat = vec![0u32; offsets[n]];
    for (pos, &source) in order.iter().enumerate() {
        for &target in &new_graph[pos] {
            let next = &mut cursor[target as usize];
            flat[*next] = source as u32;
            *next += 1;
        }
    }

    (flat, offsets)
}
813
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;
    use std::path::PathBuf;

    /// Index file path in the system temp directory that is deleted on drop.
    ///
    /// The file is therefore removed even when an assertion fails mid-test, and tests never
    /// write into the working directory. Declare it before the index so it is dropped after.
    struct TempIndexFile(PathBuf);

    impl TempIndexFile {
        fn new(name: &str) -> Self {
            // The process id keeps concurrent runs of the test binary from clobbering each other.
            let path =
                std::env::temp_dir().join(format!("diskann_{}_{}", std::process::id(), name));
            let _ = fs::remove_file(&path);
            TempIndexFile(path)
        }

        fn path(&self) -> &str {
            self.0.to_str().expect("temp dir path is valid UTF-8")
        }
    }

    impl Drop for TempIndexFile {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let file = TempIndexFile::new("test_small_l2.db");
        let path = file.path();

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        // The top hit should be close to the query.
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let file = TempIndexFile::new("test_cosine.db");
        let path = file.path();

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, path).unwrap();

        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // The top hit should point in roughly the same direction as the query.
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let file = TempIndexFile::new("test_persist.db");
        let path = file.path();

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
        }

        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let file = TempIndexFile::new("test_grid.db");
        let path = file.path();

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let file = TempIndexFile::new("test_medium.db");
        let path = file.path();

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Results must come back in ascending distance order.
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);
    }
}