1use ipfrs_core::{Cid, Error, Result};
13use memmap2::MmapMut;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fs::OpenOptions;
17use std::path::Path;
18use std::sync::{Arc, RwLock};
19
/// Tuning parameters for a [`DiskANNIndex`].
///
/// The config is embedded in `IndexHeader`, which `DiskANNIndex::create` encodes
/// with `oxicode`; reordering fields presumably changes the on-disk header format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiskANNConfig {
    /// Number of `f32` components in every stored vector.
    pub dimension: usize,
    /// Upper bound on the number of neighbors kept per graph node after pruning.
    pub max_degree: usize,
    /// Beam (candidate-list) size used by greedy search while inserting;
    /// `search` uses `max(k, queue_size)`.
    pub queue_size: usize,
    /// RobustPrune factor: a candidate `c` is skipped when an already-selected
    /// neighbor `s` satisfies `alpha * d(c, s) < d(node, c)`.
    pub alpha: f32,
    /// Maximum number of search entry points retained by `insert`.
    pub num_entry_points: usize,
}
34
35impl Default for DiskANNConfig {
36 fn default() -> Self {
37 Self {
38 dimension: 768,
39 max_degree: 64,
40 queue_size: 100,
41 alpha: 1.2,
42 num_entry_points: 4,
43 }
44 }
45}
46
/// Header that `DiskANNIndex::create` serializes (via `oxicode`) at the start of
/// the main index file and that `DiskANNIndex::load` decodes and validates.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct IndexHeader {
    /// Must equal `IndexHeader::MAGIC` (`b"DISKANN1"`).
    magic: [u8; 8],
    /// Format version; `validate` accepts only 1.
    version: u32,
    /// Configuration the index was created with.
    config: DiskANNConfig,
    /// Vector count stored in the header. `create` writes 0 and nothing in this
    /// module rewrites the header afterwards.
    num_vectors: usize,
    /// Reserved section offset; always written as 0 in this module.
    graph_offset: u64,
    /// Reserved section offset; always written as 0 in this module.
    vector_offset: u64,
    /// Reserved section offset; always written as 0 in this module.
    cid_mapping_offset: u64,
}
65
66impl IndexHeader {
67 const MAGIC: [u8; 8] = *b"DISKANN1";
68
69 fn new(config: DiskANNConfig) -> Self {
70 Self {
71 magic: Self::MAGIC,
72 version: 1,
73 config,
74 num_vectors: 0,
75 graph_offset: 0,
76 vector_offset: 0,
77 cid_mapping_offset: 0,
78 }
79 }
80
81 fn validate(&self) -> Result<()> {
82 if self.magic != Self::MAGIC {
83 return Err(Error::InvalidInput(
84 "Invalid DiskANN index file format".to_string(),
85 ));
86 }
87 if self.version != 1 {
88 return Err(Error::InvalidInput(format!(
89 "Unsupported DiskANN version: {}",
90 self.version
91 )));
92 }
93 Ok(())
94 }
95}
96
// Not constructed anywhere in this module: `DiskANNIndex` stores adjacency as
// `Vec<Vec<usize>>` instead. Kept (hence the allow) as a node representation.
#[allow(dead_code)]
/// A graph vertex: its internal id and its out-neighbor ids.
#[derive(Debug, Clone)]
struct GraphNode {
    /// Internal (dense) vector id.
    id: usize,
    /// Internal ids of this node's out-neighbors.
    neighbors: Vec<usize>,
}
106
/// Fixed-size header at the start of the `<index>.vectors` sidecar file.
///
/// Encoded by `as_bytes` as 24 bytes: `magic` at 0..8, `num_vectors` as a
/// little-endian `u64` at 8..16, and `dimension` as a little-endian `u64` at 16..24.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct VectorFileHeader {
    /// Must equal `VectorFileHeader::MAGIC` (`b"VECDATA1"`).
    magic: [u8; 8],
    /// Number of vectors currently stored after the header.
    num_vectors: u64,
    /// Number of `f32` components per vector.
    dimension: u64,
}
118
119impl VectorFileHeader {
120 const MAGIC: [u8; 8] = *b"VECDATA1";
121 const SIZE: usize = 24; fn new(dimension: usize) -> Self {
124 Self {
125 magic: Self::MAGIC,
126 num_vectors: 0,
127 dimension: dimension as u64,
128 }
129 }
130
131 #[allow(dead_code)]
132 fn validate(&self, expected_dim: usize) -> Result<()> {
133 if self.magic != Self::MAGIC {
134 return Err(Error::InvalidInput(
135 "Invalid vector file format".to_string(),
136 ));
137 }
138 if self.dimension != expected_dim as u64 {
139 return Err(Error::InvalidInput(format!(
140 "Vector dimension mismatch: expected {}, got {}",
141 expected_dim, self.dimension
142 )));
143 }
144 Ok(())
145 }
146
147 fn as_bytes(&self) -> [u8; Self::SIZE] {
148 let mut bytes = [0u8; Self::SIZE];
149 bytes[0..8].copy_from_slice(&self.magic);
150 bytes[8..16].copy_from_slice(&self.num_vectors.to_le_bytes());
151 bytes[16..24].copy_from_slice(&self.dimension.to_le_bytes());
152 bytes
153 }
154
155 #[allow(dead_code)]
156 fn from_bytes(bytes: &[u8]) -> Result<Self> {
157 if bytes.len() < Self::SIZE {
158 return Err(Error::InvalidInput(
159 "Vector file header too small".to_string(),
160 ));
161 }
162
163 let mut magic = [0u8; 8];
164 magic.copy_from_slice(&bytes[0..8]);
165
166 let num_vectors = u64::from_le_bytes(bytes[8..16].try_into().unwrap());
167 let dimension = u64::from_le_bytes(bytes[16..24].try_into().unwrap());
168
169 Ok(Self {
170 magic,
171 num_vectors,
172 dimension,
173 })
174 }
175}
176
/// Vamana-graph (DiskANN-style) approximate nearest-neighbor index keyed by CID.
///
/// Vectors are stored in a memory-mapped `<index_path>.vectors` sidecar file.
/// The neighbor graph, CID maps and entry points are kept in memory and are
/// only written out by `save`.
pub struct DiskANNIndex {
    /// Parameters fixed at construction time.
    config: DiskANNConfig,
    /// Path of the main index file (set by `create` / `load`).
    index_path: Arc<RwLock<Option<String>>>,
    /// Writable mapping of the main index file, which begins with the serialized `IndexHeader`.
    graph_mmap: Arc<RwLock<Option<MmapMut>>>,
    /// Writable mapping of the vector sidecar file (header plus contiguous `f32` data).
    vector_mmap: Arc<RwLock<Option<MmapMut>>>,
    /// Path of the vector sidecar file; needed to grow and remap it.
    vector_file_path: Arc<RwLock<Option<String>>>,
    /// Internal id -> CID.
    id_to_cid: Arc<RwLock<HashMap<usize, Cid>>>,
    /// CID -> internal id (used to reject duplicate inserts).
    cid_to_id: Arc<RwLock<HashMap<Cid, usize>>>,
    /// Adjacency lists indexed by internal id.
    graph: Arc<RwLock<Vec<Vec<usize>>>>,
    /// Node ids where greedy search starts.
    entry_points: Arc<RwLock<Vec<usize>>>,
    /// Id that the next inserted vector will receive.
    next_id: Arc<RwLock<usize>>,
    /// True once `create` or `load` has succeeded.
    loaded: Arc<RwLock<bool>>,
}
201
202impl DiskANNIndex {
203 pub fn new(config: DiskANNConfig) -> Self {
205 Self {
206 config,
207 index_path: Arc::new(RwLock::new(None)),
208 graph_mmap: Arc::new(RwLock::new(None)),
209 vector_mmap: Arc::new(RwLock::new(None)),
210 vector_file_path: Arc::new(RwLock::new(None)),
211 id_to_cid: Arc::new(RwLock::new(HashMap::new())),
212 cid_to_id: Arc::new(RwLock::new(HashMap::new())),
213 graph: Arc::new(RwLock::new(Vec::new())),
214 entry_points: Arc::new(RwLock::new(Vec::new())),
215 next_id: Arc::new(RwLock::new(0)),
216 loaded: Arc::new(RwLock::new(false)),
217 }
218 }
219
220 fn get_vector_file_path(index_path: &str) -> String {
222 format!("{}.vectors", index_path)
223 }
224
225 fn vector_offset(&self, vector_id: usize) -> usize {
227 VectorFileHeader::SIZE + (vector_id * self.config.dimension * std::mem::size_of::<f32>())
228 }
229
230 fn read_vector(&self, vector_id: usize) -> Result<Vec<f32>> {
232 let mmap = self.vector_mmap.read().unwrap();
233 let mmap = mmap
234 .as_ref()
235 .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?;
236
237 let offset = self.vector_offset(vector_id);
238 let vec_size_bytes = self.config.dimension * std::mem::size_of::<f32>();
239
240 if offset + vec_size_bytes > mmap.len() {
241 return Err(Error::InvalidInput(format!(
242 "Vector {} out of bounds",
243 vector_id
244 )));
245 }
246
247 let bytes = &mmap[offset..offset + vec_size_bytes];
248 let floats: Vec<f32> = bytes
249 .chunks_exact(4)
250 .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
251 .collect();
252
253 Ok(floats)
254 }
255
256 fn write_vector(&self, vector_id: usize, vector: &[f32]) -> Result<()> {
258 if vector.len() != self.config.dimension {
259 return Err(Error::InvalidInput(format!(
260 "Vector dimension {} doesn't match expected {}",
261 vector.len(),
262 self.config.dimension
263 )));
264 }
265
266 let mut mmap = self.vector_mmap.write().unwrap();
267 let mmap = mmap
268 .as_mut()
269 .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?;
270
271 let offset = self.vector_offset(vector_id);
272 let vec_size_bytes = self.config.dimension * std::mem::size_of::<f32>();
273
274 if offset + vec_size_bytes > mmap.len() {
275 return Err(Error::InvalidInput(format!(
276 "Vector {} out of bounds (mmap size: {}, needed: {})",
277 vector_id,
278 mmap.len(),
279 offset + vec_size_bytes
280 )));
281 }
282
283 let bytes = &mut mmap[offset..offset + vec_size_bytes];
284 for (i, &val) in vector.iter().enumerate() {
285 let val_bytes = val.to_le_bytes();
286 bytes[i * 4..(i + 1) * 4].copy_from_slice(&val_bytes);
287 }
288
289 Ok(())
290 }
291
292 fn update_vector_count(&self, count: usize) -> Result<()> {
294 let mut mmap = self.vector_mmap.write().unwrap();
295 let mmap = mmap
296 .as_mut()
297 .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?;
298
299 let count_bytes = (count as u64).to_le_bytes();
300 mmap[8..16].copy_from_slice(&count_bytes);
301
302 Ok(())
303 }
304
305 fn get_vector_count(&self) -> Result<usize> {
307 let mmap = self.vector_mmap.read().unwrap();
308 let mmap = mmap
309 .as_ref()
310 .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?;
311
312 let count_bytes: [u8; 8] = mmap[8..16].try_into().unwrap();
313 Ok(u64::from_le_bytes(count_bytes) as usize)
314 }
315
316 fn ensure_vector_capacity(&self, required_count: usize) -> Result<()> {
318 let mmap = self.vector_mmap.read().unwrap();
319 let current_size = mmap
320 .as_ref()
321 .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?
322 .len();
323 drop(mmap);
324
325 let required_size = VectorFileHeader::SIZE
326 + (required_count * self.config.dimension * std::mem::size_of::<f32>());
327
328 if required_size > current_size {
329 let new_capacity = (required_count * 2).max(required_count + 1000); let new_size = VectorFileHeader::SIZE
332 + (new_capacity * self.config.dimension * std::mem::size_of::<f32>());
333
334 let vec_path = self
336 .vector_file_path
337 .read()
338 .unwrap()
339 .clone()
340 .ok_or_else(|| Error::InvalidInput("No vector file path set".to_string()))?;
341
342 *self.vector_mmap.write().unwrap() = None;
344
345 let vec_file = OpenOptions::new()
347 .read(true)
348 .write(true)
349 .open(&vec_path)
350 .map_err(Error::Io)?;
351 vec_file.set_len(new_size as u64).map_err(Error::Io)?;
352
353 let new_mmap = unsafe {
355 MmapMut::map_mut(&vec_file)
356 .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
357 };
358
359 *self.vector_mmap.write().unwrap() = Some(new_mmap);
360 }
361
362 Ok(())
363 }
364
365 fn num_vectors(&self) -> usize {
367 self.get_vector_count()
368 .unwrap_or_else(|_| *self.next_id.read().unwrap())
369 }
370
371 pub fn with_defaults(dimension: usize) -> Self {
373 let config = DiskANNConfig {
374 dimension,
375 ..Default::default()
376 };
377 Self::new(config)
378 }
379
380 pub fn create(&mut self, path: impl AsRef<Path>) -> Result<()> {
382 let path = path.as_ref();
383 let path_str = path.to_string_lossy().to_string();
384
385 let file = OpenOptions::new()
387 .read(true)
388 .write(true)
389 .create(true)
390 .truncate(true)
391 .open(path)
392 .map_err(Error::Io)?;
393
394 let header = IndexHeader::new(self.config.clone());
396 let header_bytes = oxicode::serde::encode_to_vec(&header, oxicode::config::standard())
397 .map_err(|e| Error::Serialization(e.to_string()))?;
398
399 let initial_size = header_bytes.len() + 1024 * 1024; file.set_len(initial_size as u64).map_err(Error::Io)?;
402
403 let mut mmap = unsafe {
405 MmapMut::map_mut(&file).map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
406 };
407
408 mmap[..header_bytes.len()].copy_from_slice(&header_bytes);
410
411 let vec_path = Self::get_vector_file_path(&path_str);
413 let vec_file = OpenOptions::new()
414 .read(true)
415 .write(true)
416 .create(true)
417 .truncate(true)
418 .open(&vec_path)
419 .map_err(Error::Io)?;
420
421 let vec_header = VectorFileHeader::new(self.config.dimension);
423 let initial_vec_count = 1000;
424 let vec_file_size = VectorFileHeader::SIZE
425 + (initial_vec_count * self.config.dimension * std::mem::size_of::<f32>());
426 vec_file.set_len(vec_file_size as u64).map_err(Error::Io)?;
427
428 let mut vec_mmap = unsafe {
430 MmapMut::map_mut(&vec_file)
431 .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
432 };
433
434 let header_bytes = vec_header.as_bytes();
436 vec_mmap[..VectorFileHeader::SIZE].copy_from_slice(&header_bytes);
437
438 *self.index_path.write().unwrap() = Some(path_str.clone());
439 *self.vector_file_path.write().unwrap() = Some(vec_path);
440 *self.graph_mmap.write().unwrap() = Some(mmap);
441 *self.vector_mmap.write().unwrap() = Some(vec_mmap);
442 *self.loaded.write().unwrap() = true;
443
444 Ok(())
445 }
446
447 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
449 let path = path.as_ref();
450
451 let file = OpenOptions::new()
453 .read(true)
454 .write(true)
455 .open(path)
456 .map_err(Error::Io)?;
457
458 let mmap = unsafe {
460 MmapMut::map_mut(&file).map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
461 };
462
463 let header: IndexHeader =
465 oxicode::serde::decode_owned_from_slice(&mmap[..1024], oxicode::config::standard())
466 .map(|(v, _)| v)
467 .map_err(|e| Error::Serialization(e.to_string()))?;
468
469 header.validate()?;
470
471 let index = Self::new(header.config);
473 *index.index_path.write().unwrap() = Some(path.to_string_lossy().to_string());
474 *index.graph_mmap.write().unwrap() = Some(mmap);
475 *index.next_id.write().unwrap() = header.num_vectors;
476 *index.loaded.write().unwrap() = true;
477
478 Ok(index)
479 }
480
481 pub fn insert(&mut self, cid: &Cid, vector: &[f32]) -> Result<()> {
483 if !*self.loaded.read().unwrap() {
484 return Err(Error::InvalidInput(
485 "Index not created or loaded".to_string(),
486 ));
487 }
488
489 if vector.len() != self.config.dimension {
490 return Err(Error::InvalidInput(format!(
491 "Vector dimension {} doesn't match index dimension {}",
492 vector.len(),
493 self.config.dimension
494 )));
495 }
496
497 if self.cid_to_id.read().unwrap().contains_key(cid) {
499 return Err(Error::InvalidInput(format!(
500 "CID already in index: {}",
501 cid
502 )));
503 }
504
505 let id = *self.next_id.read().unwrap();
507
508 self.ensure_vector_capacity(id + 1)?;
510
511 self.write_vector(id, vector)?;
513
514 *self.next_id.write().unwrap() += 1;
516 self.update_vector_count(id + 1)?;
517
518 self.id_to_cid.write().unwrap().insert(id, *cid);
520 self.cid_to_id.write().unwrap().insert(*cid, id);
521
522 self.graph.write().unwrap().push(Vec::new());
524
525 if id == 0 {
527 self.entry_points.write().unwrap().push(0);
528 return Ok(());
529 }
530
531 self.vamana_insert(id, vector)?;
533
534 if id.is_multiple_of(1000) && id < 10000 {
536 self.entry_points.write().unwrap().push(id);
537 let mut eps = self.entry_points.write().unwrap();
539 let num_to_drain = if eps.len() > self.config.num_entry_points {
540 eps.len() - self.config.num_entry_points
541 } else {
542 0
543 };
544 if num_to_drain > 0 {
545 eps.drain(0..num_to_drain);
546 }
547 }
548
549 Ok(())
550 }
551
552 fn vamana_insert(&self, new_id: usize, new_vec: &[f32]) -> Result<()> {
554 let neighbors =
556 self.greedy_search_internal(new_vec, self.config.queue_size, self.config.queue_size)?;
557
558 let pruned = self.robust_prune(new_id, new_vec, &neighbors)?;
560
561 let mut graph = self.graph.write().unwrap();
563 graph[new_id] = pruned.clone();
564
565 for &neighbor_id in &pruned {
567 if neighbor_id >= graph.len() {
568 continue;
569 }
570
571 if !graph[neighbor_id].contains(&new_id) {
573 graph[neighbor_id].push(new_id);
574
575 if graph[neighbor_id].len() > self.config.max_degree {
577 let neighbor_vec = self.read_vector(neighbor_id)?;
578 let candidates = graph[neighbor_id].clone();
579
580 let pruned_neighbors =
581 self.robust_prune(neighbor_id, &neighbor_vec, &candidates)?;
582 graph[neighbor_id] = pruned_neighbors;
583 }
584 }
585 }
586
587 Ok(())
588 }
589
590 fn robust_prune(
592 &self,
593 node_id: usize,
594 node_vec: &[f32],
595 candidates: &[usize],
596 ) -> Result<Vec<usize>> {
597 let alpha = self.config.alpha;
598 let max_degree = self.config.max_degree;
599 let num_vecs = self.num_vectors();
600
601 let mut dists: Vec<(usize, f32)> = candidates
603 .iter()
604 .filter(|&&c| c != node_id && c < num_vecs)
605 .filter_map(|&c| {
606 self.read_vector(c).ok().map(|vec| {
607 let dist = self.l2_distance(node_vec, &vec);
608 (c, dist)
609 })
610 })
611 .collect();
612
613 dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
615
616 let mut pruned = Vec::new();
617
618 for (cand_id, cand_dist) in dists {
619 if pruned.len() >= max_degree {
620 break;
621 }
622
623 let mut should_add = true;
625 let cand_vec = self.read_vector(cand_id).ok();
626 if let Some(ref c_vec) = cand_vec {
627 for &selected_id in &pruned {
628 if let Ok(sel_vec) = self.read_vector(selected_id) {
629 let selected_dist = self.l2_distance(c_vec, &sel_vec);
630 if alpha * selected_dist < cand_dist {
631 should_add = false;
632 break;
633 }
634 }
635 }
636 } else {
637 should_add = false;
638 }
639
640 if should_add {
641 pruned.push(cand_id);
642 }
643 }
644
645 Ok(pruned)
646 }
647
648 fn l2_distance<T: AsRef<[f32]>, U: AsRef<[f32]>>(&self, a: T, b: U) -> f32 {
650 a.as_ref()
651 .iter()
652 .zip(b.as_ref().iter())
653 .map(|(x, y)| (x - y) * (x - y))
654 .sum::<f32>()
655 .sqrt()
656 }
657
658 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
660 if !*self.loaded.read().unwrap() {
661 return Err(Error::InvalidInput(
662 "Index not created or loaded".to_string(),
663 ));
664 }
665
666 if query.len() != self.config.dimension {
667 return Err(Error::InvalidInput(format!(
668 "Query dimension {} doesn't match index dimension {}",
669 query.len(),
670 self.config.dimension
671 )));
672 }
673
674 let num_vectors = self.num_vectors();
675 if num_vectors == 0 {
676 return Ok(Vec::new());
677 }
678
679 let search_list_size = k.max(self.config.queue_size);
681 let result_ids = self.greedy_search_internal(query, k, search_list_size)?;
682
683 let id_to_cid = self.id_to_cid.read().unwrap();
685 let results: Vec<SearchResult> = result_ids
686 .iter()
687 .filter_map(|&id| {
688 id_to_cid.get(&id).and_then(|cid| {
689 self.read_vector(id).ok().map(|vec| SearchResult {
690 cid: *cid,
691 distance: self.l2_distance(query, &vec),
692 })
693 })
694 })
695 .collect();
696
697 Ok(results)
698 }
699
700 fn greedy_search_internal(
702 &self,
703 query: &[f32],
704 k: usize,
705 search_list_size: usize,
706 ) -> Result<Vec<usize>> {
707 let graph = self.graph.read().unwrap();
708 let entry_points = self.entry_points.read().unwrap();
709 let num_vecs = self.num_vectors();
710
711 if num_vecs == 0 {
712 return Ok(Vec::new());
713 }
714
715 let start_nodes: Vec<usize> = if entry_points.is_empty() {
717 vec![0]
718 } else {
719 entry_points.clone()
720 };
721
722 let mut visited = vec![false; num_vecs];
724
725 let mut candidates: Vec<(f32, usize)> = Vec::new();
727 let mut results: Vec<(f32, usize)> = Vec::new();
728
729 for &node_id in &start_nodes {
731 if node_id >= num_vecs {
732 continue;
733 }
734 if let Ok(vec) = self.read_vector(node_id) {
735 let dist = self.l2_distance(query, &vec);
736 candidates.push((dist, node_id));
737 results.push((dist, node_id));
738 visited[node_id] = true;
739 }
740 }
741
742 candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
744 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
745
746 while !candidates.is_empty() {
748 let (current_dist, current_id) = candidates.remove(0);
750
751 if results.len() >= search_list_size {
753 let furthest_dist = results[search_list_size - 1].0;
754 if current_dist > furthest_dist {
755 break;
756 }
757 }
758
759 if current_id >= graph.len() {
761 continue;
762 }
763
764 for &neighbor_id in &graph[current_id] {
765 if neighbor_id >= num_vecs || visited[neighbor_id] {
766 continue;
767 }
768
769 visited[neighbor_id] = true;
770 let dist = if let Ok(vec) = self.read_vector(neighbor_id) {
771 self.l2_distance(query, &vec)
772 } else {
773 continue;
774 };
775
776 candidates.push((dist, neighbor_id));
778 candidates
779 .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
780
781 results.push((dist, neighbor_id));
783 results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
784
785 if results.len() > search_list_size {
787 results.truncate(search_list_size);
788 }
789 }
790 }
791
792 Ok(results.iter().take(k).map(|(_, id)| *id).collect())
794 }
795
796 pub fn stats(&self) -> DiskANNStats {
798 DiskANNStats {
799 num_vectors: *self.next_id.read().unwrap(),
800 dimension: self.config.dimension,
801 max_degree: self.config.max_degree,
802 index_loaded: *self.loaded.read().unwrap(),
803 estimated_disk_size: self.estimate_disk_size(),
804 }
805 }
806
807 fn estimate_disk_size(&self) -> usize {
809 let num_vectors = *self.next_id.read().unwrap();
810
811 let header_size = 1024;
813
814 let vector_size = num_vectors * self.config.dimension * 4;
816
817 let graph_size = num_vectors * self.config.max_degree * 4;
819
820 let mapping_size = num_vectors * 40;
822
823 header_size + vector_size + graph_size + mapping_size
824 }
825
826 pub fn is_loaded(&self) -> bool {
828 *self.loaded.read().unwrap()
829 }
830
831 pub fn config(&self) -> &DiskANNConfig {
833 &self.config
834 }
835
836 pub fn save(&self) -> Result<()> {
838 if !*self.loaded.read().unwrap() {
839 return Err(Error::InvalidInput("Index not loaded".to_string()));
840 }
841
842 let path = self
843 .index_path
844 .read()
845 .unwrap()
846 .clone()
847 .ok_or_else(|| Error::InvalidInput("No index path set".to_string()))?;
848
849 let num_vecs = self.num_vectors();
851 let mut vectors = Vec::with_capacity(num_vecs);
852 for i in 0..num_vecs {
853 if let Ok(vec) = self.read_vector(i) {
854 vectors.push(vec);
855 }
856 }
857
858 let graph = self.graph.read().unwrap();
859 let id_to_cid = self.id_to_cid.read().unwrap();
860 let entry_points = self.entry_points.read().unwrap();
861
862 let data = DiskANNData::from_index(
863 vectors,
864 graph.clone(),
865 id_to_cid.clone(),
866 entry_points.clone(),
867 );
868
869 let serialized = oxicode::serde::encode_to_vec(&data, oxicode::config::standard())
871 .map_err(|e| Error::Serialization(e.to_string()))?;
872
873 let temp_path = format!("{}.tmp", path);
875 std::fs::write(&temp_path, &serialized).map_err(Error::Io)?;
876 std::fs::rename(&temp_path, &path).map_err(Error::Io)?;
877
878 Ok(())
879 }
880
881 pub fn flush(&self) -> Result<()> {
883 if let Some(ref mut mmap) = *self.graph_mmap.write().unwrap() {
884 mmap.flush()
885 .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?;
886 }
887 Ok(())
888 }
889
890 pub fn compact(&mut self) -> Result<CompactionStats> {
897 if !*self.loaded.read().unwrap() {
898 return Err(Error::InvalidInput("Index not loaded".to_string()));
899 }
900
901 let start_time = std::time::Instant::now();
902 let old_size = self.num_vectors();
903 let graph = self.graph.read().unwrap();
904
905 let old_graph_edges: usize = graph.iter().map(|neighbors| neighbors.len()).sum();
906
907 let stats = CompactionStats {
910 duration_ms: start_time.elapsed().as_millis() as u64,
911 vectors_before: old_size,
912 vectors_after: old_size,
913 graph_edges_before: old_graph_edges,
914 graph_edges_after: old_graph_edges,
915 bytes_saved: 0,
916 };
917
918 Ok(stats)
919 }
920
921 pub fn prune_graph(&mut self, quality_threshold: f32) -> Result<usize> {
926 if !*self.loaded.read().unwrap() {
927 return Err(Error::InvalidInput("Index not loaded".to_string()));
928 }
929
930 let mut graph = self.graph.write().unwrap();
931 let num_vecs = self.num_vectors();
932 let mut total_pruned = 0;
933
934 for node_id in 0..graph.len() {
935 if node_id >= num_vecs {
936 continue;
937 }
938
939 let node_vec = match self.read_vector(node_id) {
940 Ok(v) => v,
941 Err(_) => continue,
942 };
943 let neighbors = &graph[node_id];
944
945 let mut neighbor_dists: Vec<(usize, f32)> = neighbors
947 .iter()
948 .filter(|&&n| n < num_vecs)
949 .filter_map(|&n| {
950 self.read_vector(n).ok().map(|vec| {
951 let dist = self.l2_distance(&node_vec, &vec);
952 (n, dist)
953 })
954 })
955 .collect();
956
957 neighbor_dists
959 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
960
961 if let Some(&(_, best_dist)) = neighbor_dists.first() {
963 let threshold_dist = best_dist * (1.0 + quality_threshold);
964 let keep_count = neighbor_dists
965 .iter()
966 .filter(|(_, d)| *d <= threshold_dist)
967 .count();
968
969 if keep_count < neighbors.len() {
970 total_pruned += neighbors.len() - keep_count;
971 graph[node_id] = neighbor_dists
972 .iter()
973 .take(keep_count)
974 .map(|(n, _)| *n)
975 .collect();
976 }
977 }
978 }
979
980 Ok(total_pruned)
981 }
982
983 pub fn len(&self) -> usize {
985 *self.next_id.read().unwrap()
986 }
987
988 pub fn is_empty(&self) -> bool {
990 self.len() == 0
991 }
992}
993
/// Serializable snapshot produced by `DiskANNIndex::save` (encoded with
/// `oxicode`). Field order is presumably significant to the encoded format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DiskANNData {
    /// Stored vectors, indexed by internal id.
    vectors: Vec<Vec<f32>>,
    /// Adjacency lists, indexed by internal id.
    graph: Vec<Vec<usize>>,
    /// Internal id -> CID in its string form.
    id_to_cid: HashMap<usize, String>,
    /// Search entry points.
    entry_points: Vec<usize>,
}
1002
1003impl DiskANNData {
1004 fn from_index(
1005 vectors: Vec<Vec<f32>>,
1006 graph: Vec<Vec<usize>>,
1007 id_to_cid: HashMap<usize, Cid>,
1008 entry_points: Vec<usize>,
1009 ) -> Self {
1010 let id_to_cid_str = id_to_cid
1011 .into_iter()
1012 .map(|(k, v)| (k, v.to_string()))
1013 .collect();
1014 Self {
1015 vectors,
1016 graph,
1017 id_to_cid: id_to_cid_str,
1018 entry_points,
1019 }
1020 }
1021
1022 #[allow(dead_code)]
1023 fn to_cid_map(&self) -> Result<HashMap<usize, Cid>> {
1024 self.id_to_cid
1025 .iter()
1026 .map(|(k, v)| {
1027 v.parse::<Cid>()
1028 .map(|cid| (*k, cid))
1029 .map_err(|e| Error::InvalidInput(format!("Invalid CID: {}", e)))
1030 })
1031 .collect()
1032 }
1033}
1034
/// Report returned by `DiskANNIndex::compact`.
#[derive(Debug, Clone)]
pub struct CompactionStats {
    /// Wall-clock time spent, in milliseconds.
    pub duration_ms: u64,
    /// Vector count before compaction.
    pub vectors_before: usize,
    /// Vector count after compaction.
    pub vectors_after: usize,
    /// Total graph edges before compaction.
    pub graph_edges_before: usize,
    /// Total graph edges after compaction.
    pub graph_edges_after: usize,
    /// Bytes reclaimed by compaction.
    pub bytes_saved: usize,
}
1051
/// One hit returned by `DiskANNIndex::search`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// CID the matching vector was inserted under.
    pub cid: Cid,
    /// Euclidean (L2) distance between the query and the stored vector.
    pub distance: f32,
}
1060
/// Summary returned by `DiskANNIndex::stats`.
#[derive(Debug, Clone)]
pub struct DiskANNStats {
    /// Number of vectors inserted.
    pub num_vectors: usize,
    /// Vector dimension from the configuration.
    pub dimension: usize,
    /// Maximum node degree from the configuration.
    pub max_degree: usize,
    /// Whether the index has been created or loaded.
    pub index_loaded: bool,
    /// Rough on-disk size estimate in bytes (not measured from the files).
    pub estimated_disk_size: usize,
}
1075
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a per-test index path inside the platform temp directory, rather
    /// than a hard-coded `/tmp`, which does not exist on every platform.
    fn temp_index_path(name: &str) -> String {
        std::env::temp_dir()
            .join(name)
            .to_string_lossy()
            .into_owned()
    }

    /// Removes an index file together with the `<path>.vectors` sidecar that
    /// `DiskANNIndex::create` writes next to it.
    fn remove_index_files(path: &str) {
        std::fs::remove_file(path).ok();
        std::fs::remove_file(format!("{}.vectors", path)).ok();
    }

    #[test]
    fn test_diskann_create() {
        let config = DiskANNConfig::default();
        let mut index = DiskANNIndex::new(config);

        let temp_file = temp_index_path("test_diskann_index.dat");
        assert!(index.create(&temp_file).is_ok());
        assert!(index.is_loaded());

        remove_index_files(&temp_file);
    }

    #[test]
    fn test_diskann_stats() {
        let index = DiskANNIndex::with_defaults(128);
        let stats = index.stats();

        assert_eq!(stats.dimension, 128);
        assert_eq!(stats.num_vectors, 0);
        assert!(!stats.index_loaded);
    }

    #[test]
    fn test_index_header() {
        let config = DiskANNConfig::default();
        let header = IndexHeader::new(config);

        assert_eq!(header.magic, IndexHeader::MAGIC);
        assert_eq!(header.version, 1);
        assert!(header.validate().is_ok());

        // A corrupted magic value must be rejected.
        let mut bad_header = header.clone();
        bad_header.magic = [0; 8];
        assert!(bad_header.validate().is_err());
    }

    #[test]
    fn test_diskann_load_restores_config() {
        let config = DiskANNConfig {
            dimension: 16,
            ..Default::default()
        };
        let mut index = DiskANNIndex::new(config);
        let temp_file = temp_index_path("test_diskann_load.dat");
        index.create(&temp_file).unwrap();

        let loaded = DiskANNIndex::load(&temp_file).unwrap();
        assert!(loaded.is_loaded());
        assert_eq!(loaded.config().dimension, 16);

        remove_index_files(&temp_file);
    }

    #[test]
    fn test_diskann_insert_and_search() {
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 16,
            queue_size: 50,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        let temp_file = temp_index_path("test_diskann_vamana.dat");
        index.create(&temp_file).unwrap();

        // Two tight clusters: around the x axis and around the z axis.
        let vectors = [
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.9, 0.1],
        ];

        let base_cids = [
            "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
            "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354",
            "bafybeihvvulpp6bcs5kum72jh5tkfo35dz2ow3lrqw4hmqyqbmfyvdqvdq",
            "bafybeiakou6e7kkxc5qycjkqwucq4zfkfvzmlbf2vlihvqqnfjfzpqrkmq",
            "bafybeibscyh5z3uk6fvdidffhybzsxmckblkjhajy4y4uzcglmfwqx67b4",
        ];
        for (i, vec) in vectors.iter().enumerate() {
            let cid: Cid = base_cids[i].parse().unwrap();
            index.insert(&cid, vec).unwrap();
        }

        assert_eq!(index.stats().num_vectors, 5);

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);
        // The exact match (distance 0) should be found first.
        assert!(results[0].distance < 0.2);

        remove_index_files(&temp_file);
    }

    #[test]
    fn test_vamana_graph_construction() {
        let config = DiskANNConfig {
            dimension: 8,
            max_degree: 8,
            queue_size: 20,
            alpha: 1.2,
            ..Default::default()
        };

        let max_degree = config.max_degree;
        let mut index = DiskANNIndex::new(config);
        let temp_file = temp_index_path("test_vamana_graph.dat");
        index.create(&temp_file).unwrap();

        let base_cids: Vec<&str> = vec![
            "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
            "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354",
            "bafybeihvvulpp6bcs5kum72jh5tkfo35dz2ow3lrqw4hmqyqbmfyvdqvdq",
            "bafybeiakou6e7kkxc5qycjkqwucq4zfkfvzmlbf2vlihvqqnfjfzpqrkmq",
            "bafybeibscyh5z3uk6fvdidffhybzsxmckblkjhajy4y4uzcglmfwqx67b4",
            "bafybeiezkzpo2uy4teyix63fjc3vgpxlvhbmwjicxhxx6vaf3ywvkyz5ia",
            "bafybeifmyetvpv2uovt7ncnvjcwvshwqrr7zmyh5wpqwmf5mwy3m42xkre",
            "bafybeia7lv6vknr6fqjq2jlj3ygbdgzdqxqt7xo3u7dzz6ihfzd3zhd6pi",
            "bafybeif2ewg3nqa33yvecifp7jw7p2utbnkh34j7ku44mzs3lpmcbdkjzq",
            "bafybeid5cg74fzlh7okcaabfwexdvkiuocwbqhwrqc4x65jyplwsxzvvdq",
            "bafybeicy6rxfqlcdadwjfjjvvb7wlbnlrzuzsogpv5snwt46zpqrmihtnq",
            "bafybeie2kj53f4wmefncg3rvrvfegwk265iw2psfszftvq3slajlwkjfpm",
            "bafybeigk7gjp4y4m4gwvmblvf7mlufsqtfgwyjdqwvwudytucvx7wtnz4e",
            "bafybeihbsq7kdawlkzvfj7xttx27t4p52pkllmfevn5l2scgbvmgqcfmfy",
            "bafybeiej5vfvbkjbzyeouqxkn25yb2xzdz2igdwmawcbhv66kwfwqnvhzi",
            "bafybeigbkbpcxqbrvx56fqf7jb25r5wunzowl45uwmzcbxkwdtixlbtwim",
            "bafybeihyfvtf3uiilqvqsvhbphfdudqy7qrjkxqglh26xxvjhtxrkhhbxe",
            "bafybeicflzm3r35m4kj5chxjvdwgajq6ljhqpsjq6wdyqnlpfjwwb5nowi",
            "bafybeic73hjrp52jxz33zxlz5qthfxumqpyuvqfvawdcskqiqlpuww3vxi",
            "bafybeicbh5dkdyiq3gqufk46cktiwwucwl6mzhv6e5xhzmuvzojvykokpy",
        ];
        for (i, &cid_str) in base_cids.iter().enumerate() {
            let cid: Cid = cid_str.parse().unwrap();
            let vec: Vec<f32> = (0..8).map(|j| (i as f32 + j as f32) * 0.1).collect();
            index.insert(&cid, &vec).unwrap();
        }

        let graph = index.graph.read().unwrap();
        assert_eq!(graph.len(), 20);

        // Node 0 is inserted with no links and only gains reverse edges, and
        // the last node may be pruned away, so check nodes 1..19.
        for (i, neighbors) in graph.iter().enumerate().skip(1) {
            if i < 19 {
                assert!(!neighbors.is_empty(), "Node {} should have neighbors", i);
                assert!(
                    neighbors.len() <= max_degree,
                    "Node {} has too many neighbors: {}",
                    i,
                    neighbors.len()
                );
            }
        }
        drop(graph);

        remove_index_files(&temp_file);
    }

    #[test]
    fn test_robust_pruning() {
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 3,
            alpha: 1.2,
            ..Default::default()
        };

        let max_degree = config.max_degree;
        let mut index = DiskANNIndex::new(config);
        let temp_file = temp_index_path("test_robust_prune.dat");
        index.create(&temp_file).unwrap();

        // Write vectors directly, bypassing graph construction.
        index.ensure_vector_capacity(4).unwrap();
        index.write_vector(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
        index.write_vector(1, &[0.9, 0.1, 0.0, 0.0]).unwrap();
        index.write_vector(2, &[0.8, 0.2, 0.0, 0.0]).unwrap();
        index.write_vector(3, &[0.0, 1.0, 0.0, 0.0]).unwrap();
        index.update_vector_count(4).unwrap();

        let node_vec = vec![1.0, 0.0, 0.0, 0.0];
        let candidates = vec![1, 2, 3];

        let pruned = index.robust_prune(0, &node_vec, &candidates).unwrap();

        assert!(pruned.len() <= max_degree);
        // The nearest candidate is always selected.
        assert!(pruned.contains(&1));

        remove_index_files(&temp_file);
    }

    #[test]
    fn test_diskann_save_and_load() {
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 16,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        let temp_file = temp_index_path("test_diskann_save.dat");
        index.create(&temp_file).unwrap();

        let vectors = [
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let base_cids = [
            "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
            "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354",
            "bafybeihvvulpp6bcs5kum72jh5tkfo35dz2ow3lrqw4hmqyqbmfyvdqvdq",
        ];

        for (i, vec) in vectors.iter().enumerate() {
            let cid: Cid = base_cids[i].parse().unwrap();
            index.insert(&cid, vec).unwrap();
        }

        assert!(index.save().is_ok());

        remove_index_files(&temp_file);
    }

    #[test]
    fn test_diskann_flush() {
        let config = DiskANNConfig {
            dimension: 4,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        let temp_file = temp_index_path("test_diskann_flush.dat");
        index.create(&temp_file).unwrap();

        assert!(index.flush().is_ok());

        remove_index_files(&temp_file);
    }

    #[test]
    fn test_diskann_compact() {
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 16,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        let temp_file = temp_index_path("test_diskann_compact.dat");
        index.create(&temp_file).unwrap();

        let vectors = [
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let base_cids = [
            "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
            "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354",
            "bafybeihvvulpp6bcs5kum72jh5tkfo35dz2ow3lrqw4hmqyqbmfyvdqvdq",
        ];

        for (i, vec) in vectors.iter().enumerate() {
            let cid: Cid = base_cids[i].parse().unwrap();
            index.insert(&cid, vec).unwrap();
        }

        let stats = index.compact().unwrap();
        assert_eq!(stats.vectors_before, 3);
        assert_eq!(stats.vectors_after, 3);

        remove_index_files(&temp_file);
    }

    #[test]
    fn test_diskann_prune_graph() {
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 16,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        let temp_file = temp_index_path("test_diskann_prune.dat");
        index.create(&temp_file).unwrap();

        let vectors = [
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.8, 0.2, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        let base_cids = [
            "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
            "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354",
            "bafybeihvvulpp6bcs5kum72jh5tkfo35dz2ow3lrqw4hmqyqbmfyvdqvdq",
            "bafybeiakou6e7kkxc5qycjkqwucq4zfkfvzmlbf2vlihvqqnfjfzpqrkmq",
        ];

        for (i, vec) in vectors.iter().enumerate() {
            let cid: Cid = base_cids[i].parse().unwrap();
            index.insert(&cid, vec).unwrap();
        }

        // Only checks that pruning runs without error on a small graph.
        let _pruned = index.prune_graph(0.5).unwrap();

        remove_index_files(&temp_file);
    }

    #[test]
    fn test_diskann_len_and_is_empty() {
        let config = DiskANNConfig {
            dimension: 4,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        let temp_file = temp_index_path("test_diskann_len.dat");
        index.create(&temp_file).unwrap();

        assert_eq!(index.len(), 0);
        assert!(index.is_empty());

        let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse()
            .unwrap();
        let vec = vec![1.0, 0.0, 0.0, 0.0];
        index.insert(&cid, &vec).unwrap();

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());

        remove_index_files(&temp_file);
    }
}