1mod incremental;
63mod filtered;
64pub mod simd;
65pub mod pq;
66
67pub use incremental::{
68 IncrementalDiskANN, IncrementalConfig, IncrementalStats,
69 is_delta_id, delta_local_idx,
70};
71
72pub use filtered::{FilteredDiskANN, Filter};
73
74pub use simd::{SimdL2, SimdDot, SimdCosine, simd_info};
75
76pub use pq::{ProductQuantizer, PQConfig, PQStats};
77
78use anndists::prelude::Distance;
79use bytemuck;
80use memmap2::Mmap;
81use rand::prelude::*;
82use rayon::prelude::*;
83use serde::{Deserialize, Serialize};
84use std::cmp::{Ordering, Reverse};
85use std::collections::{BinaryHeap, HashSet};
86use std::fs::OpenOptions;
87use std::io::{Read, Seek, SeekFrom, Write};
88use thiserror::Error;
89
/// Sentinel written into unused adjacency slots; searches skip entries equal to it.
const PAD_U32: u32 = u32::MAX;

/// Default maximum number of out-neighbors stored per node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width used while building the graph.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default `alpha` factor used when pruning neighbor lists.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;

/// Build-time parameters accepted by `DiskANN::build_index_with_params`.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum number of out-neighbors kept per node.
    pub max_degree: usize,
    /// Beam width of the searches run during construction.
    pub build_beam_width: usize,
    /// Pruning factor passed to the neighbor-pruning step.
    pub alpha: f32,
}

impl Default for DiskAnnParams {
    /// Returns the parameters made of the three `DISKANN_DEFAULT_*` constants.
    fn default() -> Self {
        DiskAnnParams {
            alpha: DISKANN_DEFAULT_ALPHA,
            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
        }
    }
}
114
115#[derive(Debug, Error)]
117pub enum DiskAnnError {
118 #[error("I/O error: {0}")]
120 Io(#[from] std::io::Error),
121
122 #[error("Serialization error: {0}")]
124 Bincode(#[from] bincode::Error),
125
126 #[error("Index error: {0}")]
128 IndexError(String),
129}
130
131#[derive(Serialize, Deserialize, Debug)]
133struct Metadata {
134 dim: usize,
135 num_vectors: usize,
136 max_degree: usize,
137 medoid_id: u32,
138 vectors_offset: u64,
139 adjacency_offset: u64,
140 distance_name: String,
141}
142
/// A node id together with its distance to the current query.
///
/// Candidates are ordered by distance, then by id, so that `Eq`, `PartialOrd`
/// and `Ord` agree with each other. A NaN distance compares greater than every
/// real distance (and equal to another NaN), which keeps the ordering total and
/// makes such candidates the first to be evicted from a bounded max-heap.
#[derive(Clone, Copy)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        let by_dist = match (self.dist.is_nan(), other.dist.is_nan()) {
            // Neither is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => self
                .dist
                .partial_cmp(&other.dist)
                .unwrap_or(Ordering::Equal),
            // `false < true`: the NaN side is the greater one; two NaNs tie.
            (self_nan, other_nan) => self_nan.cmp(&other_nan),
        };
        by_dist.then_with(|| self.id.cmp(&other.id))
    }
}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}
166
167pub struct DiskANN<D>
169where
170 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
171{
172 pub dim: usize,
174 pub num_vectors: usize,
176 pub max_degree: usize,
178 pub distance_name: String,
180
181 pub(crate) medoid_id: u32,
183 pub(crate) vectors_offset: u64,
185 pub(crate) adjacency_offset: u64,
186
187 pub(crate) mmap: Mmap,
189
190 pub(crate) dist: D,
192}
193
194impl<D> DiskANN<D>
197where
198 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
199{
200 pub fn build_index_default(
202 vectors: &[Vec<f32>],
203 dist: D,
204 file_path: &str,
205 ) -> Result<Self, DiskAnnError> {
206 Self::build_index(
207 vectors,
208 DISKANN_DEFAULT_MAX_DEGREE,
209 DISKANN_DEFAULT_BUILD_BEAM,
210 DISKANN_DEFAULT_ALPHA,
211 dist,
212 file_path,
213 )
214 }
215
216 pub fn build_index_with_params(
218 vectors: &[Vec<f32>],
219 dist: D,
220 file_path: &str,
221 p: DiskAnnParams,
222 ) -> Result<Self, DiskAnnError> {
223 Self::build_index(
224 vectors,
225 p.max_degree,
226 p.build_beam_width,
227 p.alpha,
228 dist,
229 file_path,
230 )
231 }
232}
233
234impl<D> DiskANN<D>
236where
237 D: Distance<f32> + Default + Send + Sync + Copy + Clone + 'static,
238{
239 pub fn build_index_default_metric(
241 vectors: &[Vec<f32>],
242 file_path: &str,
243 ) -> Result<Self, DiskAnnError> {
244 Self::build_index_default(vectors, D::default(), file_path)
245 }
246
247 pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
249 Self::open_index_with(path, D::default())
250 }
251}
252
253impl<D> DiskANN<D>
254where
255 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
256{
257 pub fn build_index(
267 vectors: &[Vec<f32>],
268 max_degree: usize,
269 build_beam_width: usize,
270 alpha: f32,
271 dist: D,
272 file_path: &str,
273 ) -> Result<Self, DiskAnnError> {
274 if vectors.is_empty() {
275 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
276 }
277
278 let num_vectors = vectors.len();
279 let dim = vectors[0].len();
280 for (i, v) in vectors.iter().enumerate() {
281 if v.len() != dim {
282 return Err(DiskAnnError::IndexError(format!(
283 "Vector {} has dimension {} but expected {}",
284 i,
285 v.len(),
286 dim
287 )));
288 }
289 }
290
291 let mut file = OpenOptions::new()
292 .create(true)
293 .write(true)
294 .read(true)
295 .truncate(true)
296 .open(file_path)?;
297
298 let vectors_offset = 1024 * 1024;
300 let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
301
302 file.seek(SeekFrom::Start(vectors_offset))?;
304 for vector in vectors {
305 let bytes = bytemuck::cast_slice(vector);
306 file.write_all(bytes)?;
307 }
308
309 let medoid_id = calculate_medoid(vectors, dist);
311
312 let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
314 let graph = build_vamana_graph(
315 vectors,
316 max_degree,
317 build_beam_width,
318 alpha,
319 dist,
320 medoid_id as u32,
321 );
322
323 file.seek(SeekFrom::Start(adjacency_offset))?;
325 for neighbors in &graph {
326 let mut padded = neighbors.clone();
327 padded.resize(max_degree, PAD_U32);
328 let bytes = bytemuck::cast_slice(&padded);
329 file.write_all(bytes)?;
330 }
331
332 let metadata = Metadata {
334 dim,
335 num_vectors,
336 max_degree,
337 medoid_id: medoid_id as u32,
338 vectors_offset: vectors_offset as u64,
339 adjacency_offset,
340 distance_name: std::any::type_name::<D>().to_string(),
341 };
342
343 let md_bytes = bincode::serialize(&metadata)?;
344 file.seek(SeekFrom::Start(0))?;
345 let md_len = md_bytes.len() as u64;
346 file.write_all(&md_len.to_le_bytes())?;
347 file.write_all(&md_bytes)?;
348 file.sync_all()?;
349
350 let mmap = unsafe { memmap2::Mmap::map(&file)? };
352
353 Ok(Self {
354 dim,
355 num_vectors,
356 max_degree,
357 distance_name: metadata.distance_name,
358 medoid_id: metadata.medoid_id,
359 vectors_offset: metadata.vectors_offset,
360 adjacency_offset: metadata.adjacency_offset,
361 mmap,
362 dist,
363 })
364 }
365
366 pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
368 let mut file = OpenOptions::new().read(true).write(false).open(path)?;
369
370 let mut buf8 = [0u8; 8];
372 file.seek(SeekFrom::Start(0))?;
373 file.read_exact(&mut buf8)?;
374 let md_len = u64::from_le_bytes(buf8);
375
376 let mut md_bytes = vec![0u8; md_len as usize];
378 file.read_exact(&mut md_bytes)?;
379 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
380
381 let mmap = unsafe { memmap2::Mmap::map(&file)? };
382
383 let expected = std::any::type_name::<D>();
385 if metadata.distance_name != expected {
386 eprintln!(
387 "Warning: index recorded distance `{}` but you opened with `{}`",
388 metadata.distance_name, expected
389 );
390 }
391
392 Ok(Self {
393 dim: metadata.dim,
394 num_vectors: metadata.num_vectors,
395 max_degree: metadata.max_degree,
396 distance_name: metadata.distance_name,
397 medoid_id: metadata.medoid_id,
398 vectors_offset: metadata.vectors_offset,
399 adjacency_offset: metadata.adjacency_offset,
400 mmap,
401 dist,
402 })
403 }
404
405 pub fn search_with_dists(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
410 assert_eq!(
411 query.len(),
412 self.dim,
413 "Query dim {} != index dim {}",
414 query.len(),
415 self.dim
416 );
417
418 #[derive(Clone, Copy)]
419 struct Candidate {
420 dist: f32,
421 id: u32,
422 }
423 impl PartialEq for Candidate {
424 fn eq(&self, o: &Self) -> bool {
425 self.dist == o.dist && self.id == o.id
426 }
427 }
428 impl Eq for Candidate {}
429 impl PartialOrd for Candidate {
430 fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
431 self.dist.partial_cmp(&o.dist)
432 }
433 }
434 impl Ord for Candidate {
435 fn cmp(&self, o: &Self) -> Ordering {
436 self.partial_cmp(o).unwrap_or(Ordering::Equal)
437 }
438 }
439
440 let mut visited = HashSet::new();
441 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = self.distance_to(query, self.medoid_id as usize);
446 let start = Candidate {
447 dist: start_dist,
448 id: self.medoid_id,
449 };
450 frontier.push(Reverse(start));
451 w.push(start);
452 visited.insert(self.medoid_id);
453
454 while let Some(Reverse(best)) = frontier.peek().copied() {
456 if w.len() >= beam_width {
457 if let Some(worst) = w.peek() {
458 if best.dist >= worst.dist {
459 break;
460 }
461 }
462 }
463 let Reverse(current) = frontier.pop().unwrap();
464
465 for &nb in self.get_neighbors(current.id) {
466 if nb == PAD_U32 {
467 continue;
468 }
469 if !visited.insert(nb) {
470 continue;
471 }
472
473 let d = self.distance_to(query, nb as usize);
474 let cand = Candidate { dist: d, id: nb };
475
476 if w.len() < beam_width {
477 w.push(cand);
478 frontier.push(Reverse(cand));
479 } else if d < w.peek().unwrap().dist {
480 w.pop();
481 w.push(cand);
482 frontier.push(Reverse(cand));
483 }
484 }
485 }
486
487 let mut results: Vec<_> = w.into_vec();
489 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
490 results.truncate(k);
491 results.into_iter().map(|c| (c.id, c.dist)).collect()
492 }
493 pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
495 self.search_with_dists(query, k, beam_width)
496 .into_iter()
497 .map(|(id, _dist)| id)
498 .collect()
499 }
500
501 fn get_neighbors(&self, node_id: u32) -> &[u32] {
503 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
504 let start = offset as usize;
505 let end = start + (self.max_degree * 4);
506 let bytes = &self.mmap[start..end];
507 bytemuck::cast_slice(bytes)
508 }
509
510 fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
512 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
513 let start = offset as usize;
514 let end = start + (self.dim * 4);
515 let bytes = &self.mmap[start..end];
516 let vector: &[f32] = bytemuck::cast_slice(bytes);
517 self.dist.eval(query, vector)
518 }
519
520 pub fn get_vector(&self, idx: usize) -> Vec<f32> {
522 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
523 let start = offset as usize;
524 let end = start + (self.dim * 4);
525 let bytes = &self.mmap[start..end];
526 let vector: &[f32] = bytemuck::cast_slice(bytes);
527 vector.to_vec()
528 }
529}
530
531fn calculate_medoid<D: Distance<f32> + Copy + Sync>(vectors: &[Vec<f32>], dist: D) -> usize {
534 let dim = vectors[0].len();
535 let mut centroid = vec![0.0f32; dim];
536
537 for v in vectors {
538 for (i, &val) in v.iter().enumerate() {
539 centroid[i] += val;
540 }
541 }
542 for val in &mut centroid {
543 *val /= vectors.len() as f32;
544 }
545
546 let (best_idx, _best_dist) = vectors
547 .par_iter()
548 .enumerate()
549 .map(|(idx, v)| (idx, dist.eval(¢roid, v)))
550 .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
551
552 best_idx
553}
554
555fn build_vamana_graph<D: Distance<f32> + Copy + Sync>(
560 vectors: &[Vec<f32>],
561 max_degree: usize,
562 build_beam_width: usize,
563 alpha: f32,
564 dist: D,
565 medoid_id: u32,
566) -> Vec<Vec<u32>> {
567 let n = vectors.len();
568 let mut graph = vec![Vec::<u32>::new(); n];
569
570 {
572 let mut rng = thread_rng();
573 for i in 0..n {
574 let mut s = HashSet::new();
575 let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
576 while s.len() < target {
577 let nb = rng.gen_range(0..n);
578 if nb != i {
579 s.insert(nb as u32);
580 }
581 }
582 graph[i] = s.into_iter().collect();
583 }
584 }
585
586 const PASSES: usize = 2;
588 const EXTRA_SEEDS: usize = 2;
589
590 let mut rng = thread_rng();
591 for _pass in 0..PASSES {
592 let mut order: Vec<usize> = (0..n).collect();
594 order.shuffle(&mut rng);
595
596 let snapshot = &graph;
598
599 let new_graph: Vec<Vec<u32>> = order
601 .par_iter()
602 .map(|&u| {
603 let mut candidates: Vec<(u32, f32)> =
604 Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));
605
606 for &nb in &snapshot[u] {
608 let d = dist.eval(&vectors[u], &vectors[nb as usize]);
609 candidates.push((nb, d));
610 }
611
612 let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
614 seeds.push(medoid_id as usize);
615 let mut trng = thread_rng();
616 for _ in 0..EXTRA_SEEDS {
617 seeds.push(trng.gen_range(0..n));
618 }
619
620 for start in seeds {
622 let mut part = greedy_search(
623 &vectors[u],
624 vectors,
625 snapshot,
626 start,
627 build_beam_width,
628 dist,
629 );
630 candidates.append(&mut part);
631 }
632
633 candidates.sort_by(|a, b| a.0.cmp(&b.0));
635 candidates.dedup_by(|a, b| {
636 if a.0 == b.0 {
637 if a.1 < b.1 {
638 *b = *a;
639 }
640 true
641 } else {
642 false
643 }
644 });
645
646 prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
648 })
649 .collect();
650
651 let mut pos_of = vec![0usize; n];
654 for (pos, &u) in order.iter().enumerate() {
655 pos_of[u] = pos;
656 }
657
658 let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);
660
661 graph = (0..n)
663 .into_par_iter()
664 .map(|u| {
665 let ng = &new_graph[pos_of[u]]; let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]]; let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
670 pool_ids.extend_from_slice(ng);
671 pool_ids.extend_from_slice(inc);
672 pool_ids.sort_unstable();
673 pool_ids.dedup();
674
675 let pool: Vec<(u32, f32)> = pool_ids
677 .into_iter()
678 .filter(|&id| id as usize != u)
679 .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
680 .collect();
681
682 prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
683 })
684 .collect();
685 }
686
687 graph
689 .into_par_iter()
690 .enumerate()
691 .map(|(u, neigh)| {
692 if neigh.len() <= max_degree {
693 return neigh;
694 }
695 let pool: Vec<(u32, f32)> = neigh
696 .iter()
697 .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
698 .collect();
699 prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
700 })
701 .collect()
702}
703
704fn greedy_search<D: Distance<f32> + Copy>(
707 query: &[f32],
708 vectors: &[Vec<f32>],
709 graph: &[Vec<u32>],
710 start_id: usize,
711 beam_width: usize,
712 dist: D,
713) -> Vec<(u32, f32)> {
714 let mut visited = HashSet::new();
715 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = dist.eval(query, &vectors[start_id]);
719 let start = Candidate {
720 dist: start_dist,
721 id: start_id as u32,
722 };
723 frontier.push(Reverse(start));
724 w.push(start);
725 visited.insert(start_id as u32);
726
727 while let Some(Reverse(best)) = frontier.peek().copied() {
728 if w.len() >= beam_width {
729 if let Some(worst) = w.peek() {
730 if best.dist >= worst.dist {
731 break;
732 }
733 }
734 }
735 let Reverse(cur) = frontier.pop().unwrap();
736
737 for &nb in &graph[cur.id as usize] {
738 if !visited.insert(nb) {
739 continue;
740 }
741 let d = dist.eval(query, &vectors[nb as usize]);
742 let cand = Candidate { dist: d, id: nb };
743
744 if w.len() < beam_width {
745 w.push(cand);
746 frontier.push(Reverse(cand));
747 } else if d < w.peek().unwrap().dist {
748 w.pop();
749 w.push(cand);
750 frontier.push(Reverse(cand));
751 }
752 }
753 }
754
755 let mut v = w.into_vec();
756 v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
757 v.into_iter().map(|c| (c.id, c.dist)).collect()
758}
759
760fn prune_neighbors<D: Distance<f32> + Copy>(
762 node_id: usize,
763 candidates: &[(u32, f32)],
764 vectors: &[Vec<f32>],
765 max_degree: usize,
766 alpha: f32,
767 dist: D,
768) -> Vec<u32> {
769 if candidates.is_empty() {
770 return Vec::new();
771 }
772
773 let mut sorted = candidates.to_vec();
774 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
775
776 let mut pruned = Vec::<u32>::new();
777
778 for &(cand_id, cand_dist) in &sorted {
779 if cand_id as usize == node_id {
780 continue;
781 }
782 let mut ok = true;
783 for &sel in &pruned {
784 let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
785 if d < alpha * cand_dist {
786 ok = false;
787 break;
788 }
789 }
790 if ok {
791 pruned.push(cand_id);
792 if pruned.len() >= max_degree {
793 break;
794 }
795 }
796 }
797
798 for &(cand_id, _) in &sorted {
800 if cand_id as usize == node_id {
801 continue;
802 }
803 if !pruned.contains(&cand_id) {
804 pruned.push(cand_id);
805 if pruned.len() >= max_degree {
806 break;
807 }
808 }
809 }
810
811 pruned
812}
813
/// Inverts the edge lists in `new_graph` into a compressed (CSR-style) table of
/// incoming edges.
///
/// `new_graph[pos]` holds the out-neighbors of node `order[pos]`. The result is
/// `(flat, offsets)`, where `flat[offsets[t]..offsets[t + 1]]` lists the source
/// nodes of all edges pointing at node `t`, in the order they are encountered.
/// `offsets` has `n + 1` entries.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    let rows = &new_graph[..order.len()];

    // First sweep: count the incoming edges of every target.
    let mut in_degree = vec![0usize; n];
    for targets in rows {
        for &target in targets {
            in_degree[target as usize] += 1;
        }
    }

    // A running sum of the counts gives each target's slice boundaries.
    let mut offsets = Vec::with_capacity(n + 1);
    let mut running = 0usize;
    offsets.push(running);
    for &count in &in_degree {
        running += count;
        offsets.push(running);
    }

    // Second sweep: drop every source into the next free slot of its target.
    let mut next_slot = offsets[..n].to_vec();
    let mut incoming_flat = vec![0u32; running];
    for (&source, targets) in order.iter().zip(rows) {
        for &target in targets {
            let slot = &mut next_slot[target as usize];
            incoming_flat[*slot] = source as u32;
            *slot += 1;
        }
    }

    (incoming_flat, offsets)
}
839
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Reference Euclidean distance used to check search results.
    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        let mut sum_sq = 0.0f32;
        for (x, y) in a.iter().zip(b) {
            let diff = x - y;
            sum_sq += diff * diff;
        }
        sum_sq.sqrt()
    }

    /// A tiny 2-D index returns the requested number of nearby hits.
    #[test]
    fn test_small_index_l2() {
        let path = "test_small_l2.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();

        let query = vec![0.1, 0.1];
        let hits = index.search(&query, 3, 8);
        assert_eq!(hits.len(), 3);

        let nearest = index.get_vector(hits[0] as usize);
        assert!(euclid(&query, &nearest) < 1.0);

        let _ = fs::remove_file(path);
    }

    /// With the cosine metric the best hit points in roughly the query's direction.
    #[test]
    fn test_cosine() {
        let path = "test_cosine.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<DistCosine>::build_index_default(&vectors, DistCosine {}, path).unwrap();

        let query = vec![2.0, 0.0, 0.0];
        let hits = index.search(&query, 2, 8);
        assert_eq!(hits.len(), 2);

        let best = index.get_vector(hits[0] as usize);
        let dot = best.iter().zip(&query).map(|(a, b)| a * b).sum::<f32>();
        let norm_best = best.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_query = query.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (norm_best * norm_query);
        assert!(cos > 0.7);

        let _ = fs::remove_file(path);
    }

    /// An index written to disk can be reopened and searched.
    #[test]
    fn test_persistence_and_open() {
        let path = "test_persist.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Build in an inner scope so the first handle is dropped before reopening.
        {
            let _built =
                DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        }

        let reopened = DiskANN::<DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(reopened.num_vectors, 4);
        assert_eq!(reopened.dim, 2);

        let query = vec![0.9, 0.9];
        let hits = reopened.search(&query, 2, 8);
        assert_eq!(hits[0], 3);

        let _ = fs::remove_file(path);
    }

    /// On a 5x5 grid every point is found by, or is close to, its own query.
    #[test]
    fn test_grid_connectivity() {
        let path = "test_grid.db";
        let _ = fs::remove_file(path);

        let vectors: Vec<Vec<f32>> = (0..5)
            .flat_map(|i| (0..5).map(move |j| vec![i as f32, j as f32]))
            .collect();

        let params = DiskAnnParams {
            max_degree: 4,
            build_beam_width: 64,
            alpha: 1.5,
        };
        let index =
            DiskANN::<DistL2>::build_index_with_params(&vectors, DistL2 {}, path, params).unwrap();

        for (target, query) in vectors.iter().enumerate() {
            let hits = index.search(query, 10, 32);
            if !hits.contains(&(target as u32)) {
                let best = index.get_vector(hits[0] as usize);
                assert!(euclid(query, &best) < 2.0);
            }
            for &hit in hits.iter().take(5) {
                let found = index.get_vector(hit as usize);
                assert!(euclid(query, &found) < 5.0);
            }
        }

        let _ = fs::remove_file(path);
    }

    /// On random data the hits come back ordered by increasing distance.
    #[test]
    fn test_medium_random() {
        let path = "test_medium.db";
        let _ = fs::remove_file(path);

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let params = DiskAnnParams {
            max_degree: 32,
            build_beam_width: 128,
            alpha: 1.2,
        };
        let index =
            DiskANN::<DistL2>::build_index_with_params(&vectors, DistL2 {}, path, params).unwrap();

        let query: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let hits = index.search(&query, 10, 64);
        assert_eq!(hits.len(), 10);

        let dists: Vec<f32> = hits
            .iter()
            .map(|&id| euclid(&query, &index.get_vector(id as usize)))
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);

        let _ = fs::remove_file(path);
    }
}