1mod incremental;
63mod filtered;
64pub mod simd;
65pub mod pq;
66pub mod storage;
67pub mod sq;
68pub mod formats;
69
70pub use incremental::{
71 IncrementalDiskANN, IncrementalConfig, IncrementalStats,
72 is_delta_id, delta_local_idx,
73};
74
75pub use filtered::{FilteredDiskANN, Filter};
76
77pub use simd::{SimdL2, SimdDot, SimdCosine, simd_info};
78
79pub use pq::{ProductQuantizer, PQConfig, PQStats};
80
81pub use storage::Storage;
82
83pub use sq::{VectorQuantizer, F16Quantizer, Int8Quantizer};
84
85use anndists::prelude::Distance;
86use bytemuck;
87use rand::prelude::*;
88use rayon::prelude::*;
89use serde::{Deserialize, Serialize};
90use std::cmp::{Ordering, Reverse};
91use std::collections::{BinaryHeap, HashSet};
92use std::fs::OpenOptions;
93use std::io::{Read, Seek, SeekFrom, Write};
94use std::sync::Arc;
95use thiserror::Error;
96
/// Sentinel written into unused adjacency slots so every on-disk neighbour
/// list has exactly `max_degree` entries.
const PAD_U32: u32 = u32::MAX;

/// Default maximum number of neighbours stored per node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam (candidate list) width used by the searches run during construction.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default pruning factor handed to neighbour pruning during construction.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;

/// Construction parameters for a `DiskANN` index.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum out-degree of every graph node.
    pub max_degree: usize,
    /// Beam width of the greedy searches performed while building the graph.
    pub build_beam_width: usize,
    /// Pruning factor used when selecting each node's neighbours.
    pub alpha: f32,
}

impl Default for DiskAnnParams {
    /// Returns the `DISKANN_DEFAULT_*` settings.
    fn default() -> Self {
        let (max_degree, build_beam_width, alpha) = (
            DISKANN_DEFAULT_MAX_DEGREE,
            DISKANN_DEFAULT_BUILD_BEAM,
            DISKANN_DEFAULT_ALPHA,
        );
        Self {
            max_degree,
            build_beam_width,
            alpha,
        }
    }
}
121
122#[derive(Debug, Error)]
124pub enum DiskAnnError {
125 #[error("I/O error: {0}")]
127 Io(#[from] std::io::Error),
128
129 #[error("Serialization error: {0}")]
131 Bincode(#[from] bincode::Error),
132
133 #[error("Index error: {0}")]
135 IndexError(String),
136}
137
138#[derive(Serialize, Deserialize, Debug)]
140struct Metadata {
141 dim: usize,
142 num_vectors: usize,
143 max_degree: usize,
144 medoid_id: u32,
145 vectors_offset: u64,
146 adjacency_offset: u64,
147 distance_name: String,
148}
149
/// A graph node paired with its distance to the current query.
///
/// Candidates are totally ordered by distance, with ties broken by `id` so that
/// `Ord` agrees with `Eq`. NaN distances (of either sign) are treated as
/// infinitely far, so they never win a nearest-first comparison and never break
/// the total-order requirements of `BinaryHeap` or `sort`.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl Candidate {
    /// Distance used for ordering: NaN is mapped to `+inf`.
    fn order_key(&self) -> f32 {
        if self.dist.is_nan() {
            f32::INFINITY
        } else {
            self.dist
        }
    }
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality and ordering can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.order_key()
            .total_cmp(&other.order_key())
            .then_with(|| self.id.cmp(&other.id))
    }
}
173
174pub struct DiskANN<D>
176where
177 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
178{
179 pub dim: usize,
181 pub num_vectors: usize,
183 pub max_degree: usize,
185 pub distance_name: String,
187
188 pub(crate) medoid_id: u32,
190 pub(crate) vectors_offset: u64,
192 pub(crate) adjacency_offset: u64,
193
194 pub(crate) storage: Storage,
196
197 pub(crate) dist: D,
199}
200
201impl<D> DiskANN<D>
204where
205 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
206{
207 pub fn build_index_default(
209 vectors: &[Vec<f32>],
210 dist: D,
211 file_path: &str,
212 ) -> Result<Self, DiskAnnError> {
213 Self::build_index(
214 vectors,
215 DISKANN_DEFAULT_MAX_DEGREE,
216 DISKANN_DEFAULT_BUILD_BEAM,
217 DISKANN_DEFAULT_ALPHA,
218 dist,
219 file_path,
220 )
221 }
222
223 pub fn build_index_with_params(
225 vectors: &[Vec<f32>],
226 dist: D,
227 file_path: &str,
228 p: DiskAnnParams,
229 ) -> Result<Self, DiskAnnError> {
230 Self::build_index(
231 vectors,
232 p.max_degree,
233 p.build_beam_width,
234 p.alpha,
235 dist,
236 file_path,
237 )
238 }
239}
240
241impl<D> DiskANN<D>
243where
244 D: Distance<f32> + Default + Send + Sync + Copy + Clone + 'static,
245{
246 pub fn build_index_default_metric(
248 vectors: &[Vec<f32>],
249 file_path: &str,
250 ) -> Result<Self, DiskAnnError> {
251 Self::build_index_default(vectors, D::default(), file_path)
252 }
253
254 pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
256 Self::open_index_with(path, D::default())
257 }
258}
259
260impl<D> DiskANN<D>
261where
262 D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
263{
264 pub fn build_index(
274 vectors: &[Vec<f32>],
275 max_degree: usize,
276 build_beam_width: usize,
277 alpha: f32,
278 dist: D,
279 file_path: &str,
280 ) -> Result<Self, DiskAnnError> {
281 if vectors.is_empty() {
282 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
283 }
284
285 let num_vectors = vectors.len();
286 let dim = vectors[0].len();
287 for (i, v) in vectors.iter().enumerate() {
288 if v.len() != dim {
289 return Err(DiskAnnError::IndexError(format!(
290 "Vector {} has dimension {} but expected {}",
291 i,
292 v.len(),
293 dim
294 )));
295 }
296 }
297
298 let mut file = OpenOptions::new()
299 .create(true)
300 .write(true)
301 .read(true)
302 .truncate(true)
303 .open(file_path)?;
304
305 let vectors_offset = 1024 * 1024;
307 let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
308
309 file.seek(SeekFrom::Start(vectors_offset))?;
311 for vector in vectors {
312 let bytes = bytemuck::cast_slice(vector);
313 file.write_all(bytes)?;
314 }
315
316 let medoid_id = calculate_medoid(vectors, dist);
318
319 let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
321 let graph = build_vamana_graph(
322 vectors,
323 max_degree,
324 build_beam_width,
325 alpha,
326 dist,
327 medoid_id as u32,
328 );
329
330 file.seek(SeekFrom::Start(adjacency_offset))?;
332 for neighbors in &graph {
333 let mut padded = neighbors.clone();
334 padded.resize(max_degree, PAD_U32);
335 let bytes = bytemuck::cast_slice(&padded);
336 file.write_all(bytes)?;
337 }
338
339 let metadata = Metadata {
341 dim,
342 num_vectors,
343 max_degree,
344 medoid_id: medoid_id as u32,
345 vectors_offset: vectors_offset as u64,
346 adjacency_offset,
347 distance_name: std::any::type_name::<D>().to_string(),
348 };
349
350 let md_bytes = bincode::serialize(&metadata)?;
351 file.seek(SeekFrom::Start(0))?;
352 let md_len = md_bytes.len() as u64;
353 file.write_all(&md_len.to_le_bytes())?;
354 file.write_all(&md_bytes)?;
355 file.sync_all()?;
356
357 let mmap = unsafe { memmap2::Mmap::map(&file)? };
359
360 Ok(Self {
361 dim,
362 num_vectors,
363 max_degree,
364 distance_name: metadata.distance_name,
365 medoid_id: metadata.medoid_id,
366 vectors_offset: metadata.vectors_offset,
367 adjacency_offset: metadata.adjacency_offset,
368 storage: Storage::Mmap(mmap),
369 dist,
370 })
371 }
372
373 pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
375 let mut file = OpenOptions::new().read(true).write(false).open(path)?;
376
377 let mut buf8 = [0u8; 8];
379 file.seek(SeekFrom::Start(0))?;
380 file.read_exact(&mut buf8)?;
381 let md_len = u64::from_le_bytes(buf8);
382
383 let mut md_bytes = vec![0u8; md_len as usize];
385 file.read_exact(&mut md_bytes)?;
386 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
387
388 let mmap = unsafe { memmap2::Mmap::map(&file)? };
389
390 let expected = std::any::type_name::<D>();
392 if metadata.distance_name != expected {
393 eprintln!(
394 "Warning: index recorded distance `{}` but you opened with `{}`",
395 metadata.distance_name, expected
396 );
397 }
398
399 Ok(Self {
400 dim: metadata.dim,
401 num_vectors: metadata.num_vectors,
402 max_degree: metadata.max_degree,
403 distance_name: metadata.distance_name,
404 medoid_id: metadata.medoid_id,
405 vectors_offset: metadata.vectors_offset,
406 adjacency_offset: metadata.adjacency_offset,
407 storage: Storage::Mmap(mmap),
408 dist,
409 })
410 }
411
412 pub fn from_bytes(bytes: Vec<u8>, dist: D) -> Result<Self, DiskAnnError> {
414 let metadata = Self::parse_metadata(&bytes)?;
415
416 let expected = std::any::type_name::<D>();
417 if metadata.distance_name != expected {
418 eprintln!(
419 "Warning: index recorded distance `{}` but you opened with `{}`",
420 metadata.distance_name, expected
421 );
422 }
423
424 Ok(Self {
425 dim: metadata.dim,
426 num_vectors: metadata.num_vectors,
427 max_degree: metadata.max_degree,
428 distance_name: metadata.distance_name,
429 medoid_id: metadata.medoid_id,
430 vectors_offset: metadata.vectors_offset,
431 adjacency_offset: metadata.adjacency_offset,
432 storage: Storage::Owned(bytes),
433 dist,
434 })
435 }
436
437 pub fn from_shared_bytes(bytes: Arc<[u8]>, dist: D) -> Result<Self, DiskAnnError> {
439 let metadata = Self::parse_metadata(&bytes)?;
440
441 let expected = std::any::type_name::<D>();
442 if metadata.distance_name != expected {
443 eprintln!(
444 "Warning: index recorded distance `{}` but you opened with `{}`",
445 metadata.distance_name, expected
446 );
447 }
448
449 Ok(Self {
450 dim: metadata.dim,
451 num_vectors: metadata.num_vectors,
452 max_degree: metadata.max_degree,
453 distance_name: metadata.distance_name,
454 medoid_id: metadata.medoid_id,
455 vectors_offset: metadata.vectors_offset,
456 adjacency_offset: metadata.adjacency_offset,
457 storage: Storage::Shared(bytes),
458 dist,
459 })
460 }
461
462 pub fn to_bytes(&self) -> Vec<u8> {
464 self.storage.to_vec()
465 }
466
467 fn parse_metadata(bytes: &[u8]) -> Result<Metadata, DiskAnnError> {
469 if bytes.len() < 8 {
470 return Err(DiskAnnError::IndexError("Buffer too small for metadata length".into()));
471 }
472 let md_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
473 if bytes.len() < 8 + md_len {
474 return Err(DiskAnnError::IndexError("Buffer too small for metadata".into()));
475 }
476 let metadata: Metadata = bincode::deserialize(&bytes[8..8 + md_len])?;
477 Ok(metadata)
478 }
479
480 pub fn search_with_dists(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
485 assert_eq!(
486 query.len(),
487 self.dim,
488 "Query dim {} != index dim {}",
489 query.len(),
490 self.dim
491 );
492
493 #[derive(Clone, Copy)]
494 struct Candidate {
495 dist: f32,
496 id: u32,
497 }
498 impl PartialEq for Candidate {
499 fn eq(&self, o: &Self) -> bool {
500 self.dist == o.dist && self.id == o.id
501 }
502 }
503 impl Eq for Candidate {}
504 impl PartialOrd for Candidate {
505 fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
506 self.dist.partial_cmp(&o.dist)
507 }
508 }
509 impl Ord for Candidate {
510 fn cmp(&self, o: &Self) -> Ordering {
511 self.partial_cmp(o).unwrap_or(Ordering::Equal)
512 }
513 }
514
515 let mut visited = HashSet::new();
516 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = self.distance_to(query, self.medoid_id as usize);
521 let start = Candidate {
522 dist: start_dist,
523 id: self.medoid_id,
524 };
525 frontier.push(Reverse(start));
526 w.push(start);
527 visited.insert(self.medoid_id);
528
529 while let Some(Reverse(best)) = frontier.peek().copied() {
531 if w.len() >= beam_width {
532 if let Some(worst) = w.peek() {
533 if best.dist >= worst.dist {
534 break;
535 }
536 }
537 }
538 let Reverse(current) = frontier.pop().unwrap();
539
540 for &nb in self.get_neighbors(current.id) {
541 if nb == PAD_U32 {
542 continue;
543 }
544 if !visited.insert(nb) {
545 continue;
546 }
547
548 let d = self.distance_to(query, nb as usize);
549 let cand = Candidate { dist: d, id: nb };
550
551 if w.len() < beam_width {
552 w.push(cand);
553 frontier.push(Reverse(cand));
554 } else if d < w.peek().unwrap().dist {
555 w.pop();
556 w.push(cand);
557 frontier.push(Reverse(cand));
558 }
559 }
560 }
561
562 let mut results: Vec<_> = w.into_vec();
564 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
565 results.truncate(k);
566 results.into_iter().map(|c| (c.id, c.dist)).collect()
567 }
568 pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
570 self.search_with_dists(query, k, beam_width)
571 .into_iter()
572 .map(|(id, _dist)| id)
573 .collect()
574 }
575
576 fn get_neighbors(&self, node_id: u32) -> &[u32] {
578 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
579 let start = offset as usize;
580 let end = start + (self.max_degree * 4);
581 let bytes = &self.storage[start..end];
582 bytemuck::cast_slice(bytes)
583 }
584
585 fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
587 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
588 let start = offset as usize;
589 let end = start + (self.dim * 4);
590 let bytes = &self.storage[start..end];
591 let vector: &[f32] = bytemuck::cast_slice(bytes);
592 self.dist.eval(query, vector)
593 }
594
595 pub fn get_vector(&self, idx: usize) -> Vec<f32> {
597 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
598 let start = offset as usize;
599 let end = start + (self.dim * 4);
600 let bytes = &self.storage[start..end];
601 let vector: &[f32] = bytemuck::cast_slice(bytes);
602 vector.to_vec()
603 }
604}
605
606fn calculate_medoid<D: Distance<f32> + Copy + Sync>(vectors: &[Vec<f32>], dist: D) -> usize {
609 let dim = vectors[0].len();
610 let mut centroid = vec![0.0f32; dim];
611
612 for v in vectors {
613 for (i, &val) in v.iter().enumerate() {
614 centroid[i] += val;
615 }
616 }
617 for val in &mut centroid {
618 *val /= vectors.len() as f32;
619 }
620
621 let (best_idx, _best_dist) = vectors
622 .par_iter()
623 .enumerate()
624 .map(|(idx, v)| (idx, dist.eval(¢roid, v)))
625 .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
626
627 best_idx
628}
629
630fn build_vamana_graph<D: Distance<f32> + Copy + Sync>(
635 vectors: &[Vec<f32>],
636 max_degree: usize,
637 build_beam_width: usize,
638 alpha: f32,
639 dist: D,
640 medoid_id: u32,
641) -> Vec<Vec<u32>> {
642 let n = vectors.len();
643 let mut graph = vec![Vec::<u32>::new(); n];
644
645 {
647 let mut rng = thread_rng();
648 for i in 0..n {
649 let mut s = HashSet::new();
650 let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
651 while s.len() < target {
652 let nb = rng.gen_range(0..n);
653 if nb != i {
654 s.insert(nb as u32);
655 }
656 }
657 graph[i] = s.into_iter().collect();
658 }
659 }
660
661 const PASSES: usize = 2;
663 const EXTRA_SEEDS: usize = 2;
664
665 let mut rng = thread_rng();
666 for _pass in 0..PASSES {
667 let mut order: Vec<usize> = (0..n).collect();
669 order.shuffle(&mut rng);
670
671 let snapshot = &graph;
673
674 let new_graph: Vec<Vec<u32>> = order
676 .par_iter()
677 .map(|&u| {
678 let mut candidates: Vec<(u32, f32)> =
679 Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));
680
681 for &nb in &snapshot[u] {
683 let d = dist.eval(&vectors[u], &vectors[nb as usize]);
684 candidates.push((nb, d));
685 }
686
687 let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
689 seeds.push(medoid_id as usize);
690 let mut trng = thread_rng();
691 for _ in 0..EXTRA_SEEDS {
692 seeds.push(trng.gen_range(0..n));
693 }
694
695 for start in seeds {
697 let mut part = greedy_search(
698 &vectors[u],
699 vectors,
700 snapshot,
701 start,
702 build_beam_width,
703 dist,
704 );
705 candidates.append(&mut part);
706 }
707
708 candidates.sort_by(|a, b| a.0.cmp(&b.0));
710 candidates.dedup_by(|a, b| {
711 if a.0 == b.0 {
712 if a.1 < b.1 {
713 *b = *a;
714 }
715 true
716 } else {
717 false
718 }
719 });
720
721 prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
723 })
724 .collect();
725
726 let mut pos_of = vec![0usize; n];
729 for (pos, &u) in order.iter().enumerate() {
730 pos_of[u] = pos;
731 }
732
733 let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);
735
736 graph = (0..n)
738 .into_par_iter()
739 .map(|u| {
740 let ng = &new_graph[pos_of[u]]; let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]]; let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
745 pool_ids.extend_from_slice(ng);
746 pool_ids.extend_from_slice(inc);
747 pool_ids.sort_unstable();
748 pool_ids.dedup();
749
750 let pool: Vec<(u32, f32)> = pool_ids
752 .into_iter()
753 .filter(|&id| id as usize != u)
754 .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
755 .collect();
756
757 prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
758 })
759 .collect();
760 }
761
762 graph
764 .into_par_iter()
765 .enumerate()
766 .map(|(u, neigh)| {
767 if neigh.len() <= max_degree {
768 return neigh;
769 }
770 let pool: Vec<(u32, f32)> = neigh
771 .iter()
772 .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
773 .collect();
774 prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
775 })
776 .collect()
777}
778
779fn greedy_search<D: Distance<f32> + Copy>(
782 query: &[f32],
783 vectors: &[Vec<f32>],
784 graph: &[Vec<u32>],
785 start_id: usize,
786 beam_width: usize,
787 dist: D,
788) -> Vec<(u32, f32)> {
789 let mut visited = HashSet::new();
790 let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); let start_dist = dist.eval(query, &vectors[start_id]);
794 let start = Candidate {
795 dist: start_dist,
796 id: start_id as u32,
797 };
798 frontier.push(Reverse(start));
799 w.push(start);
800 visited.insert(start_id as u32);
801
802 while let Some(Reverse(best)) = frontier.peek().copied() {
803 if w.len() >= beam_width {
804 if let Some(worst) = w.peek() {
805 if best.dist >= worst.dist {
806 break;
807 }
808 }
809 }
810 let Reverse(cur) = frontier.pop().unwrap();
811
812 for &nb in &graph[cur.id as usize] {
813 if !visited.insert(nb) {
814 continue;
815 }
816 let d = dist.eval(query, &vectors[nb as usize]);
817 let cand = Candidate { dist: d, id: nb };
818
819 if w.len() < beam_width {
820 w.push(cand);
821 frontier.push(Reverse(cand));
822 } else if d < w.peek().unwrap().dist {
823 w.pop();
824 w.push(cand);
825 frontier.push(Reverse(cand));
826 }
827 }
828 }
829
830 let mut v = w.into_vec();
831 v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
832 v.into_iter().map(|c| (c.id, c.dist)).collect()
833}
834
835fn prune_neighbors<D: Distance<f32> + Copy>(
837 node_id: usize,
838 candidates: &[(u32, f32)],
839 vectors: &[Vec<f32>],
840 max_degree: usize,
841 alpha: f32,
842 dist: D,
843) -> Vec<u32> {
844 if candidates.is_empty() {
845 return Vec::new();
846 }
847
848 let mut sorted = candidates.to_vec();
849 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
850
851 let mut pruned = Vec::<u32>::new();
852
853 for &(cand_id, cand_dist) in &sorted {
854 if cand_id as usize == node_id {
855 continue;
856 }
857 let mut ok = true;
858 for &sel in &pruned {
859 let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
860 if d < alpha * cand_dist {
861 ok = false;
862 break;
863 }
864 }
865 if ok {
866 pruned.push(cand_id);
867 if pruned.len() >= max_degree {
868 break;
869 }
870 }
871 }
872
873 for &(cand_id, _) in &sorted {
875 if cand_id as usize == node_id {
876 continue;
877 }
878 if !pruned.contains(&cand_id) {
879 pruned.push(cand_id);
880 if pruned.len() >= max_degree {
881 break;
882 }
883 }
884 }
885
886 pruned
887}
888
/// Builds a CSR (compressed sparse row) index of the reverse edges of `new_graph`.
///
/// `new_graph[pos]` holds the out-neighbours of node `order[pos]`. Returns
/// `(incoming_flat, incoming_off)`, where the sources of all edges into node `v`
/// are `incoming_flat[incoming_off[v]..incoming_off[v + 1]]`, listed in
/// ascending `pos` order. `incoming_off` has `n + 1` entries.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    // Pass 1: in-degree of every target node.
    let mut indegree = vec![0usize; n];
    for pos in 0..order.len() {
        for &target in &new_graph[pos] {
            indegree[target as usize] += 1;
        }
    }

    // Exclusive prefix sum over in-degrees gives each node's slice start.
    let offsets: Vec<usize> = std::iter::once(0)
        .chain(indegree.iter().scan(0usize, |running, &deg| {
            *running += deg;
            Some(*running)
        }))
        .collect();

    // Pass 2: scatter each edge's source into the next free slot of its target.
    let mut next_free = offsets.clone();
    let mut sources = vec![0u32; offsets[n]];
    for (pos, &source) in order.iter().enumerate() {
        for &target in &new_graph[pos] {
            let slot = &mut next_free[target as usize];
            sources[*slot] = source as u32;
            *slot += 1;
        }
    }

    (sources, offsets)
}
914
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;
    use std::path::PathBuf;

    /// Index file path in the system temp directory that is deleted on drop.
    ///
    /// Declare it before the index so the index (and its mmap) is dropped first.
    /// The fixture is then cleaned up even when an assertion fails, and
    /// deletion does not race an open mapping.
    struct TempIndexFile(PathBuf);

    impl TempIndexFile {
        /// Reserves `<tmp>/diskann_<pid>_<name>`, removing any stale file from earlier runs.
        fn new(name: &str) -> Self {
            let path =
                std::env::temp_dir().join(format!("diskann_{}_{}", std::process::id(), name));
            let _ = fs::remove_file(&path);
            Self(path)
        }

        /// The path as `&str`, the form the index constructors accept.
        fn path(&self) -> &str {
            self.0.to_str().expect("temp path is not valid UTF-8")
        }
    }

    impl Drop for TempIndexFile {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let tmp = TempIndexFile::new("test_small_l2.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index =
            DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, tmp.path()).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let tmp = TempIndexFile::new("test_cosine.db");

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<DistCosine>::build_index_default(&vectors, DistCosine {}, tmp.path())
                .unwrap();

        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let tmp = TempIndexFile::new("test_persist.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            let _idx =
                DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, tmp.path()).unwrap();
        }

        let idx2 = DiskANN::<DistL2>::open_index_default_metric(tmp.path()).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let tmp = TempIndexFile::new("test_grid.db");

        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            tmp.path(),
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let tmp = TempIndexFile::new("test_medium.db");

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            tmp.path(),
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);
    }

    #[test]
    fn test_to_bytes_from_bytes_round_trip() {
        let tmp = TempIndexFile::new("test_bytes_rt.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index =
            DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, tmp.path()).unwrap();
        let bytes = index.to_bytes();

        let index2 = DiskANN::<DistL2>::from_bytes(bytes, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors, 5);
        assert_eq!(index2.dim, 2);

        let q = vec![0.9, 0.9];
        let res1 = index.search(&q, 3, 8);
        let res2 = index2.search(&q, 3, 8);
        assert_eq!(res1, res2);
    }

    #[test]
    fn test_from_shared_bytes() {
        let tmp = TempIndexFile::new("test_shared_bytes.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let index =
            DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, tmp.path()).unwrap();
        let bytes = index.to_bytes();
        let shared: std::sync::Arc<[u8]> = bytes.into();

        let index2 = DiskANN::<DistL2>::from_shared_bytes(shared, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors, 4);
        assert_eq!(index2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = index2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_rejects_empty_and_ragged_input() {
        let tmp = TempIndexFile::new("test_bad_input.db");

        let empty: Vec<Vec<f32>> = Vec::new();
        assert!(DiskANN::<DistL2>::build_index_default(&empty, DistL2 {}, tmp.path()).is_err());

        let ragged = vec![vec![0.0, 0.0], vec![1.0]];
        assert!(DiskANN::<DistL2>::build_index_default(&ragged, DistL2 {}, tmp.path()).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_short_buffer() {
        assert!(DiskANN::<DistL2>::from_bytes(vec![1, 2, 3], DistL2 {}).is_err());
    }
}