1use bytemuck;
37use memmap2::Mmap;
38use rand::prelude::*;
39use serde::{Deserialize, Serialize};
40use std::cmp::Ordering;
41use std::collections::{BinaryHeap, HashSet};
42use std::{
43 fs::OpenOptions,
44 io::{Seek, SeekFrom, Write},
45};
46use thiserror::Error;
47
48#[derive(Debug, Error)]
50pub enum DiskAnnError {
51 #[error("I/O error: {0}")]
53 Io(#[from] std::io::Error),
54
55 #[error("Serialization error: {0}")]
57 Bincode(#[from] bincode::Error),
58
59 #[error("Index error: {0}")]
61 IndexError(String),
62}
63
64#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
66pub enum DistanceMetric {
67 Euclidean,
69 Cosine,
71}
72
73#[derive(Serialize, Deserialize, Debug)]
75struct Metadata {
76 dim: usize,
77 num_vectors: usize,
78 max_degree: usize,
79 distance_metric: DistanceMetric,
80 medoid_id: u32,
81 vectors_offset: u64,
82 adjacency_offset: u64,
83}
84
85pub struct DiskANN {
87 pub dim: usize,
89 pub num_vectors: usize,
91 pub max_degree: usize,
93 pub distance_metric: DistanceMetric,
95 medoid_id: u32,
97 vectors_offset: u64,
98 adjacency_offset: u64,
99 mmap: Mmap,
100}
101
/// A graph node paired with its distance to the current query.
///
/// The ordering is *reversed* on distance, so `BinaryHeap<Candidate>` (a max-heap) pops the
/// closest candidate first. Equal distances are broken by id, and `total_cmp` gives NaN a
/// definite position. This makes `Ord` a true total order, consistent with
/// `PartialEq`/`Eq` as the `Ord` contract requires.
#[derive(Clone, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so that `a == b` exactly when `a.cmp(b) == Equal`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed: a smaller distance compares as "greater", so it is popped first.
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.id.cmp(&self.id))
    }
}
129
130impl DiskANN {
131 pub fn build_index(
146 vectors: &[Vec<f32>],
147 max_degree: usize,
148 build_beam_width: usize,
149 alpha: f32,
150 distance_metric: DistanceMetric,
151 file_path: &str,
152 ) -> Result<Self, DiskAnnError> {
153 if vectors.is_empty() {
154 return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
155 }
156
157 let num_vectors = vectors.len();
158 let dim = vectors[0].len();
159
160 for (i, v) in vectors.iter().enumerate() {
162 if v.len() != dim {
163 return Err(DiskAnnError::IndexError(format!(
164 "Vector {} has dimension {} but expected {}",
165 i,
166 v.len(),
167 dim
168 )));
169 }
170 }
171
172 println!(
173 "Building index for {} vectors of dimension {} with max_degree={}",
174 num_vectors, dim, max_degree
175 );
176
177 let mut file = OpenOptions::new()
178 .create(true)
179 .write(true)
180 .read(true)
181 .truncate(true)
182 .open(file_path)?;
183
184 let vectors_offset = 1024 * 1024; let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
187
188 file.seek(SeekFrom::Start(vectors_offset))?;
190 for vector in vectors {
191 let bytes = bytemuck::cast_slice(vector);
192 file.write_all(bytes)?;
193 }
194
195 let medoid_id = calculate_medoid(vectors, distance_metric);
197 println!("Calculated medoid: {}", medoid_id);
198
199 let adjacency_offset = vectors_offset + total_vector_bytes;
201 let graph = build_vamana_graph(
202 vectors,
203 max_degree,
204 build_beam_width,
205 alpha,
206 distance_metric,
207 medoid_id as u32,
208 );
209
210 file.seek(SeekFrom::Start(adjacency_offset))?;
212 for neighbors in &graph {
213 let mut padded = neighbors.clone();
215 padded.resize(max_degree, 0);
216 let bytes = bytemuck::cast_slice(&padded);
217 file.write_all(bytes)?;
218 }
219
220 let metadata = Metadata {
222 dim,
223 num_vectors,
224 max_degree,
225 distance_metric,
226 medoid_id: medoid_id as u32,
227 vectors_offset,
228 adjacency_offset,
229 };
230
231 let md_bytes = bincode::serialize(&metadata)?;
232 file.seek(SeekFrom::Start(0))?;
233 let md_len = md_bytes.len() as u64;
234 file.write_all(&md_len.to_le_bytes())?;
235 file.write_all(&md_bytes)?;
236 file.sync_all()?;
237
238 let mmap = unsafe { memmap2::Mmap::map(&file)? };
240
241 Ok(Self {
242 dim,
243 num_vectors,
244 max_degree,
245 distance_metric,
246 medoid_id: metadata.medoid_id,
247 vectors_offset,
248 adjacency_offset,
249 mmap,
250 })
251 }
252
253 pub fn open_index(path: &str) -> Result<Self, DiskAnnError> {
263 let file = OpenOptions::new().read(true).write(false).open(path)?;
264
265 let mut buf8 = [0u8; 8];
267 use std::os::unix::fs::FileExt;
268 file.read_exact_at(&mut buf8, 0)?;
269 let md_len = u64::from_le_bytes(buf8);
270
271 let mut md_bytes = vec![0u8; md_len as usize];
273 file.read_exact_at(&mut md_bytes, 8)?;
274 let metadata: Metadata = bincode::deserialize(&md_bytes)?;
275
276 let mmap = unsafe { memmap2::Mmap::map(&file)? };
277
278 Ok(Self {
279 dim: metadata.dim,
280 num_vectors: metadata.num_vectors,
281 max_degree: metadata.max_degree,
282 distance_metric: metadata.distance_metric,
283 medoid_id: metadata.medoid_id,
284 vectors_offset: metadata.vectors_offset,
285 adjacency_offset: metadata.adjacency_offset,
286 mmap,
287 })
288 }
289
290 pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
302 if query.len() != self.dim {
303 panic!(
304 "Query dimension {} does not match index dimension {}",
305 query.len(),
306 self.dim
307 );
308 }
309
310 let mut visited = HashSet::new();
312 let mut candidates = BinaryHeap::new();
313 let mut w = BinaryHeap::new(); let start_dist = self.distance_to(query, self.medoid_id as usize);
317 candidates.push(Candidate {
318 dist: start_dist,
319 id: self.medoid_id,
320 });
321 w.push(Candidate {
322 dist: start_dist,
323 id: self.medoid_id,
324 });
325 visited.insert(self.medoid_id);
326
327 let mut best_dist = start_dist;
329 let mut iterations_without_improvement = 0;
330 const MAX_ITERATIONS_WITHOUT_IMPROVEMENT: usize = 5;
331
332 while let Some(current) = candidates.pop() {
333 if current.dist > best_dist {
335 iterations_without_improvement += 1;
336 if iterations_without_improvement > MAX_ITERATIONS_WITHOUT_IMPROVEMENT {
337 break;
338 }
339 } else {
340 best_dist = current.dist;
341 iterations_without_improvement = 0;
342 }
343
344 let neighbors = self.get_neighbors(current.id);
346
347 for &neighbor_id in neighbors {
348 if neighbor_id == 0 || visited.contains(&neighbor_id) {
349 continue;
350 }
351
352 visited.insert(neighbor_id);
353 let dist = self.distance_to(query, neighbor_id as usize);
354
355 w.push(Candidate {
357 dist,
358 id: neighbor_id,
359 });
360
361 if w.len() > beam_width {
363 let mut temp = Vec::new();
365 for _ in 0..beam_width {
366 if let Some(c) = w.pop() {
367 temp.push(c);
368 }
369 }
370 w.clear();
371 for c in temp {
372 w.push(c);
373 }
374 }
375
376 if w.len() < beam_width || dist < w.peek().unwrap().dist {
378 candidates.push(Candidate {
379 dist,
380 id: neighbor_id,
381 });
382 }
383 }
384 }
385
386 let mut results: Vec<_> = w.into_vec();
388 results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
389 results.truncate(k);
390 results.into_iter().map(|c| c.id).collect()
391 }
392
393 fn get_neighbors(&self, node_id: u32) -> &[u32] {
395 let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
396 let start = offset as usize;
397 let end = start + (self.max_degree * 4);
398 let bytes = &self.mmap[start..end];
399 bytemuck::cast_slice(bytes)
400 }
401
402 fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
404 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
405 let start = offset as usize;
406 let end = start + (self.dim * 4);
407 let bytes = &self.mmap[start..end];
408 let vector: &[f32] = bytemuck::cast_slice(bytes);
409
410 match self.distance_metric {
411 DistanceMetric::Euclidean => euclidean_distance(query, vector),
412 DistanceMetric::Cosine => 1.0 - cosine_similarity(query, vector),
413 }
414 }
415
416 pub fn get_vector(&self, idx: usize) -> Vec<f32> {
418 let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
419 let start = offset as usize;
420 let end = start + (self.dim * 4);
421 let bytes = &self.mmap[start..end];
422 let vector: &[f32] = bytemuck::cast_slice(bytes);
423 vector.to_vec()
424 }
425}
426
427fn calculate_medoid(vectors: &[Vec<f32>], distance_metric: DistanceMetric) -> usize {
429 let dim = vectors[0].len();
430 let mut centroid = vec![0.0; dim];
431
432 for vector in vectors {
434 for (i, &val) in vector.iter().enumerate() {
435 centroid[i] += val;
436 }
437 }
438 for val in &mut centroid {
439 *val /= vectors.len() as f32;
440 }
441
442 let mut best_idx = 0;
444 let mut best_dist = f32::MAX;
445
446 for (idx, vector) in vectors.iter().enumerate() {
447 let dist = match distance_metric {
448 DistanceMetric::Euclidean => euclidean_distance(¢roid, vector),
449 DistanceMetric::Cosine => 1.0 - cosine_similarity(¢roid, vector),
450 };
451 if dist < best_dist {
452 best_dist = dist;
453 best_idx = idx;
454 }
455 }
456
457 best_idx
458}
459
460fn build_vamana_graph(
462 vectors: &[Vec<f32>],
463 max_degree: usize,
464 beam_width: usize,
465 alpha: f32,
466 distance_metric: DistanceMetric,
467 medoid_id: u32,
468) -> Vec<Vec<u32>> {
469 let num_vectors = vectors.len();
470 let mut graph = vec![Vec::new(); num_vectors];
471
472 let mut rng = thread_rng();
474 for i in 0..num_vectors {
475 let mut neighbors = HashSet::new();
476 while neighbors.len() < max_degree.min(num_vectors - 1) {
477 let neighbor = rng.gen_range(0..num_vectors);
478 if neighbor != i {
479 neighbors.insert(neighbor as u32);
480 }
481 }
482 graph[i] = neighbors.into_iter().collect();
483 }
484
485 println!("Building Vamana graph with beam_width={}, alpha={}", beam_width, alpha);
486
487 for iteration in 0..2 {
489 println!("Graph building iteration {}", iteration + 1);
490
491 let mut node_order: Vec<usize> = (0..num_vectors).collect();
493 node_order.shuffle(&mut rng);
494
495 for &node_id in &node_order {
496 let neighbors = greedy_search(
498 &vectors[node_id],
499 vectors,
500 &graph,
501 medoid_id as usize,
502 beam_width,
503 distance_metric,
504 );
505
506 let pruned = prune_neighbors(
508 node_id,
509 &neighbors,
510 vectors,
511 max_degree,
512 alpha,
513 distance_metric,
514 );
515
516 graph[node_id] = pruned;
517
518 let current_neighbors = graph[node_id].clone();
520 for neighbor in current_neighbors {
521 if !graph[neighbor as usize].contains(&(node_id as u32)) {
522 graph[neighbor as usize].push(node_id as u32);
523
524 if graph[neighbor as usize].len() > max_degree {
526 let neighbors_of_neighbor: Vec<_> = graph[neighbor as usize]
527 .iter()
528 .map(|&id| (id, {
529 let dist = match distance_metric {
530 DistanceMetric::Euclidean => {
531 euclidean_distance(&vectors[neighbor as usize], &vectors[id as usize])
532 }
533 DistanceMetric::Cosine => {
534 1.0 - cosine_similarity(&vectors[neighbor as usize], &vectors[id as usize])
535 }
536 };
537 dist
538 }))
539 .collect();
540
541 let pruned = prune_neighbors(
542 neighbor as usize,
543 &neighbors_of_neighbor,
544 vectors,
545 max_degree,
546 alpha,
547 distance_metric,
548 );
549 graph[neighbor as usize] = pruned;
550 }
551 }
552 }
553 }
554 }
555
556 graph
557}
558
559fn greedy_search(
561 query: &[f32],
562 vectors: &[Vec<f32>],
563 graph: &[Vec<u32>],
564 start_id: usize,
565 beam_width: usize,
566 distance_metric: DistanceMetric,
567) -> Vec<(u32, f32)> {
568 let mut visited = HashSet::new();
569 let mut candidates = BinaryHeap::new();
570 let mut w = BinaryHeap::new();
571
572 let start_dist = match distance_metric {
574 DistanceMetric::Euclidean => euclidean_distance(query, &vectors[start_id]),
575 DistanceMetric::Cosine => 1.0 - cosine_similarity(query, &vectors[start_id]),
576 };
577
578 candidates.push(Candidate {
579 dist: start_dist,
580 id: start_id as u32,
581 });
582 w.push(Candidate {
583 dist: start_dist,
584 id: start_id as u32,
585 });
586 visited.insert(start_id as u32);
587
588 while let Some(current) = candidates.pop() {
589 for &neighbor_id in &graph[current.id as usize] {
590 if visited.contains(&neighbor_id) {
591 continue;
592 }
593
594 visited.insert(neighbor_id);
595 let dist = match distance_metric {
596 DistanceMetric::Euclidean => euclidean_distance(query, &vectors[neighbor_id as usize]),
597 DistanceMetric::Cosine => 1.0 - cosine_similarity(query, &vectors[neighbor_id as usize]),
598 };
599
600 w.push(Candidate { dist, id: neighbor_id });
601
602 if w.len() > beam_width {
603 let mut temp = Vec::new();
604 for _ in 0..beam_width {
605 if let Some(c) = w.pop() {
606 temp.push(c);
607 }
608 }
609 w.clear();
610 for c in temp {
611 w.push(c);
612 }
613 }
614
615 if w.len() < beam_width || dist < w.peek().unwrap().dist {
616 candidates.push(Candidate { dist, id: neighbor_id });
617 }
618 }
619 }
620
621 w.into_vec()
622 .into_iter()
623 .map(|c| (c.id, c.dist))
624 .collect()
625}
626
627fn prune_neighbors(
629 node_id: usize,
630 candidates: &[(u32, f32)],
631 vectors: &[Vec<f32>],
632 max_degree: usize,
633 alpha: f32,
634 distance_metric: DistanceMetric,
635) -> Vec<u32> {
636 if candidates.is_empty() {
637 return Vec::new();
638 }
639
640 let mut sorted_candidates = candidates.to_vec();
641 sorted_candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
642
643 let mut pruned = Vec::new();
644
645 for &(candidate_id, candidate_dist) in &sorted_candidates {
646 if candidate_id as usize == node_id {
647 continue;
648 }
649
650 let mut should_add = true;
652 for &selected_id in &pruned {
653 let dist_to_selected = match distance_metric {
654 DistanceMetric::Euclidean => {
655 euclidean_distance(&vectors[candidate_id as usize], &vectors[selected_id as usize])
656 }
657 DistanceMetric::Cosine => {
658 1.0 - cosine_similarity(&vectors[candidate_id as usize], &vectors[selected_id as usize])
659 }
660 };
661
662 if dist_to_selected < alpha * candidate_dist {
663 should_add = false;
664 break;
665 }
666 }
667
668 if should_add {
669 pruned.push(candidate_id);
670 if pruned.len() >= max_degree {
671 break;
672 }
673 }
674 }
675
676 for &(candidate_id, _) in &sorted_candidates {
678 if !pruned.contains(&candidate_id) && candidate_id as usize != node_id {
679 pruned.push(candidate_id);
680 if pruned.len() >= max_degree {
681 break;
682 }
683 }
684 }
685
686 pruned
687}
688
/// Euclidean (L2) distance between `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` components are compared (`zip` semantics), so
/// callers must pass equal-length slices.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
697
/// Cosine similarity `a·b / (|a|·|b|)` of `a` and `b`.
///
/// Returns `0.0` when either vector has zero norm. Only the overlapping prefix of the two
/// slices is used (`zip` semantics).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        0.0
    } else {
        dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())
    }
}
713
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;

    /// Index file in the system temp directory that is deleted when dropped.
    ///
    /// Writing into the working directory polluted the checkout. Also, the trailing
    /// `remove_file` calls never ran when an assertion failed; `Drop` runs during unwinding.
    struct TempIndexFile(PathBuf);

    impl TempIndexFile {
        fn new(name: &str) -> Self {
            // The pid keeps concurrent test runs from clobbering each other's files.
            let path =
                std::env::temp_dir().join(format!("diskann_{}_{}", std::process::id(), name));
            let _ = fs::remove_file(&path);
            TempIndexFile(path)
        }

        fn path(&self) -> &str {
            self.0.to_str().expect("temp dir path should be valid UTF-8")
        }
    }

    impl Drop for TempIndexFile {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_small_index() {
        let test_file = TempIndexFile::new("test_small.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::build_index(
            &vectors,
            3,   // max_degree
            4,   // build_beam_width
            1.2, // alpha
            DistanceMetric::Euclidean,
            test_file.path(),
        )
        .unwrap();

        let query = vec![0.1, 0.1];
        let neighbors = index.search(&query, 3, 4);
        assert_eq!(neighbors.len(), 3);

        let first_vector = index.get_vector(neighbors[0] as usize);
        let dist = euclidean_distance(&query, &first_vector);
        assert!(dist < 1.0, "First neighbor should be close to query");
    }

    #[test]
    fn test_memory_efficiency() {
        let test_file = TempIndexFile::new("test_memory.db");

        let num_vectors = 1000;
        let dim = 128;
        let mut rng = thread_rng();
        let vectors: Vec<Vec<f32>> = (0..num_vectors)
            .map(|_| (0..dim).map(|_| rng.gen()).collect())
            .collect();

        let index = DiskANN::build_index(
            &vectors,
            32,  // max_degree
            64,  // build_beam_width
            1.2, // alpha
            DistanceMetric::Euclidean,
            test_file.path(),
        )
        .unwrap();

        let query: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
        let k = 10;
        let beam_width = 32;

        let start = std::time::Instant::now();
        let neighbors = index.search(&query, k, beam_width);
        let elapsed = start.elapsed();

        assert_eq!(neighbors.len(), k);
        assert!(elapsed.as_millis() < 100, "Search took too long: {:?}", elapsed);

        // Results must come back ordered by ascending distance.
        let distances: Vec<f32> = neighbors
            .iter()
            .map(|&id| index.distance_to(&query, id as usize))
            .collect();
        let mut sorted = distances.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(distances, sorted);
    }

    #[test]
    fn test_cosine_similarity() {
        let test_file = TempIndexFile::new("test_cosine.db");

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index = DiskANN::build_index(
            &vectors,
            3,
            4,
            1.2,
            DistanceMetric::Cosine,
            test_file.path(),
        )
        .unwrap();

        // Same direction as vectors[0]; cosine ignores the magnitude.
        let query = vec![2.0, 0.0, 0.0];
        let neighbors = index.search(&query, 2, 4);
        assert_eq!(neighbors.len(), 2);

        let first_vector = index.get_vector(neighbors[0] as usize);
        let similarity = cosine_similarity(&query, &first_vector);
        assert!(similarity > 0.7, "First neighbor should have high cosine similarity");
    }

    #[test]
    fn test_persistence() {
        let test_file = TempIndexFile::new("test_persist.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            // Build and drop the index so the reopen below reads it back from disk.
            let _index = DiskANN::build_index(
                &vectors,
                2,
                4,
                1.2,
                DistanceMetric::Euclidean,
                test_file.path(),
            )
            .unwrap();
        }

        let index = DiskANN::open_index(test_file.path()).unwrap();
        assert_eq!(index.num_vectors, 4);
        assert_eq!(index.dim, 2);

        let query = vec![0.9, 0.9];
        let neighbors = index.search(&query, 2, 4);
        assert_eq!(neighbors[0], 3); // [1.0, 1.0] is the closest vector
    }

    #[test]
    fn test_graph_connectivity() {
        let test_file = TempIndexFile::new("test_graph.db");

        // 5x5 integer grid.
        let vectors: Vec<Vec<f32>> = (0..5)
            .flat_map(|i| (0..5).map(move |j| vec![i as f32, j as f32]))
            .collect();

        let index = DiskANN::build_index(
            &vectors,
            4,   // max_degree
            8,   // build_beam_width
            1.5, // alpha
            DistanceMetric::Euclidean,
            test_file.path(),
        )
        .unwrap();

        for (target_idx, query) in vectors.iter().enumerate() {
            let neighbors = index.search(query, 10, 32);

            if !neighbors.contains(&(target_idx as u32)) {
                // Approximate search may miss the exact point, but must land nearby.
                let first_vec = index.get_vector(neighbors[0] as usize);
                let dist = euclidean_distance(query, &first_vec);
                assert!(
                    dist < 2.0,
                    "Vector {} not found but nearest neighbor at distance {} is too far",
                    target_idx,
                    dist
                );
            }

            for &neighbor_id in neighbors.iter().take(5) {
                let neighbor_vec = index.get_vector(neighbor_id as usize);
                let dist = euclidean_distance(query, &neighbor_vec);
                assert!(dist < 5.0, "Neighbor should be reasonably close");
            }
        }
    }
}