ipfrs_semantic/
diskann.rs

1//! DiskANN: Disk-based Approximate Nearest Neighbor Search
2//!
3//! This module provides on-disk graph-based indexing for handling
4//! datasets too large to fit in memory (100M+ vectors).
5//!
6//! Key features:
7//! - Memory-mapped graph access for constant memory usage
8//! - Vamana algorithm for efficient graph construction
9//! - Page cache optimization for fast queries
10//! - Index compaction and optimization
11
12use ipfrs_core::{Cid, Error, Result};
13use memmap2::MmapMut;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fs::OpenOptions;
17use std::path::Path;
18use std::sync::{Arc, RwLock};
19
20/// DiskANN index configuration
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct DiskANNConfig {
23    /// Vector dimension
24    pub dimension: usize,
25    /// Max degree of graph nodes (R parameter in Vamana)
26    pub max_degree: usize,
27    /// Queue size for graph construction (L parameter)
28    pub queue_size: usize,
29    /// Alpha parameter for pruning (typically 1.2)
30    pub alpha: f32,
31    /// Number of entry points for search
32    pub num_entry_points: usize,
33}
34
35impl Default for DiskANNConfig {
36    fn default() -> Self {
37        Self {
38            dimension: 768,
39            max_degree: 64,
40            queue_size: 100,
41            alpha: 1.2,
42            num_entry_points: 4,
43        }
44    }
45}
46
47/// On-disk index format header
48#[derive(Debug, Clone, Serialize, Deserialize)]
49struct IndexHeader {
50    /// Magic bytes for format validation
51    magic: [u8; 8],
52    /// Format version
53    version: u32,
54    /// Configuration
55    config: DiskANNConfig,
56    /// Number of vectors in index
57    num_vectors: usize,
58    /// Offset to graph data
59    graph_offset: u64,
60    /// Offset to vector data
61    vector_offset: u64,
62    /// Offset to CID mapping
63    cid_mapping_offset: u64,
64}
65
66impl IndexHeader {
67    const MAGIC: [u8; 8] = *b"DISKANN1";
68
69    fn new(config: DiskANNConfig) -> Self {
70        Self {
71            magic: Self::MAGIC,
72            version: 1,
73            config,
74            num_vectors: 0,
75            graph_offset: 0,
76            vector_offset: 0,
77            cid_mapping_offset: 0,
78        }
79    }
80
81    fn validate(&self) -> Result<()> {
82        if self.magic != Self::MAGIC {
83            return Err(Error::InvalidInput(
84                "Invalid DiskANN index file format".to_string(),
85            ));
86        }
87        if self.version != 1 {
88            return Err(Error::InvalidInput(format!(
89                "Unsupported DiskANN version: {}",
90                self.version
91            )));
92        }
93        Ok(())
94    }
95}
96
/// One graph vertex together with its out-edges.
///
/// NOTE(review): this type is not constructed in the code visible here (the index keeps
/// its adjacency as `Vec<Vec<usize>>`). The lint allowance below keeps it compiling
/// warning-free; confirm whether it is still needed.
#[allow(dead_code)]
#[derive(Clone, Debug)]
struct GraphNode {
    /// Vertex identifier (same ID space as vector slots).
    id: usize,
    /// Identifiers of the vertices this node links to.
    neighbors: Vec<usize>,
}
106
/// Fixed 24-byte header at the start of the `<index>.vectors` file.
///
/// The on-disk bytes are produced by `as_bytes` and parsed by `from_bytes`:
/// magic (8 bytes) | vector count (u64 LE) | dimension (u64 LE).
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct VectorFileHeader {
    /// Signature; must equal `VectorFileHeader::MAGIC`.
    magic: [u8; 8],
    /// How many vectors are stored in the file.
    num_vectors: u64,
    /// Number of `f32` components per vector.
    dimension: u64,
}
118
119impl VectorFileHeader {
120    const MAGIC: [u8; 8] = *b"VECDATA1";
121    const SIZE: usize = 24; // 8 + 8 + 8 bytes
122
123    fn new(dimension: usize) -> Self {
124        Self {
125            magic: Self::MAGIC,
126            num_vectors: 0,
127            dimension: dimension as u64,
128        }
129    }
130
131    #[allow(dead_code)]
132    fn validate(&self, expected_dim: usize) -> Result<()> {
133        if self.magic != Self::MAGIC {
134            return Err(Error::InvalidInput(
135                "Invalid vector file format".to_string(),
136            ));
137        }
138        if self.dimension != expected_dim as u64 {
139            return Err(Error::InvalidInput(format!(
140                "Vector dimension mismatch: expected {}, got {}",
141                expected_dim, self.dimension
142            )));
143        }
144        Ok(())
145    }
146
147    fn as_bytes(&self) -> [u8; Self::SIZE] {
148        let mut bytes = [0u8; Self::SIZE];
149        bytes[0..8].copy_from_slice(&self.magic);
150        bytes[8..16].copy_from_slice(&self.num_vectors.to_le_bytes());
151        bytes[16..24].copy_from_slice(&self.dimension.to_le_bytes());
152        bytes
153    }
154
155    #[allow(dead_code)]
156    fn from_bytes(bytes: &[u8]) -> Result<Self> {
157        if bytes.len() < Self::SIZE {
158            return Err(Error::InvalidInput(
159                "Vector file header too small".to_string(),
160            ));
161        }
162
163        let mut magic = [0u8; 8];
164        magic.copy_from_slice(&bytes[0..8]);
165
166        let num_vectors = u64::from_le_bytes(bytes[8..16].try_into().unwrap());
167        let dimension = u64::from_le_bytes(bytes[16..24].try_into().unwrap());
168
169        Ok(Self {
170            magic,
171            num_vectors,
172            dimension,
173        })
174    }
175}
176
177/// DiskANN index for large-scale vector search
178pub struct DiskANNIndex {
179    /// Configuration
180    config: DiskANNConfig,
181    /// Index file path
182    index_path: Arc<RwLock<Option<String>>>,
183    /// Memory-mapped graph data
184    graph_mmap: Arc<RwLock<Option<MmapMut>>>,
185    /// Memory-mapped vector data (true disk-based storage)
186    vector_mmap: Arc<RwLock<Option<MmapMut>>>,
187    /// Vector file path
188    vector_file_path: Arc<RwLock<Option<String>>>,
189    /// In-memory CID mapping (relatively small)
190    id_to_cid: Arc<RwLock<HashMap<usize, Cid>>>,
191    cid_to_id: Arc<RwLock<HashMap<Cid, usize>>>,
192    /// In-memory graph (adjacency list)
193    graph: Arc<RwLock<Vec<Vec<usize>>>>,
194    /// Entry points for search
195    entry_points: Arc<RwLock<Vec<usize>>>,
196    /// Next available ID
197    next_id: Arc<RwLock<usize>>,
198    /// Whether index is loaded
199    loaded: Arc<RwLock<bool>>,
200}
201
202impl DiskANNIndex {
203    /// Create a new DiskANN index
204    pub fn new(config: DiskANNConfig) -> Self {
205        Self {
206            config,
207            index_path: Arc::new(RwLock::new(None)),
208            graph_mmap: Arc::new(RwLock::new(None)),
209            vector_mmap: Arc::new(RwLock::new(None)),
210            vector_file_path: Arc::new(RwLock::new(None)),
211            id_to_cid: Arc::new(RwLock::new(HashMap::new())),
212            cid_to_id: Arc::new(RwLock::new(HashMap::new())),
213            graph: Arc::new(RwLock::new(Vec::new())),
214            entry_points: Arc::new(RwLock::new(Vec::new())),
215            next_id: Arc::new(RwLock::new(0)),
216            loaded: Arc::new(RwLock::new(false)),
217        }
218    }
219
220    /// Helper: Get vector file path from index path
221    fn get_vector_file_path(index_path: &str) -> String {
222        format!("{}.vectors", index_path)
223    }
224
225    /// Helper: Calculate byte offset for a vector in the mmap file
226    fn vector_offset(&self, vector_id: usize) -> usize {
227        VectorFileHeader::SIZE + (vector_id * self.config.dimension * std::mem::size_of::<f32>())
228    }
229
230    /// Helper: Read a vector from the memory-mapped file
231    fn read_vector(&self, vector_id: usize) -> Result<Vec<f32>> {
232        let mmap = self.vector_mmap.read().unwrap();
233        let mmap = mmap
234            .as_ref()
235            .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?;
236
237        let offset = self.vector_offset(vector_id);
238        let vec_size_bytes = self.config.dimension * std::mem::size_of::<f32>();
239
240        if offset + vec_size_bytes > mmap.len() {
241            return Err(Error::InvalidInput(format!(
242                "Vector {} out of bounds",
243                vector_id
244            )));
245        }
246
247        let bytes = &mmap[offset..offset + vec_size_bytes];
248        let floats: Vec<f32> = bytes
249            .chunks_exact(4)
250            .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
251            .collect();
252
253        Ok(floats)
254    }
255
256    /// Helper: Write a vector to the memory-mapped file
257    fn write_vector(&self, vector_id: usize, vector: &[f32]) -> Result<()> {
258        if vector.len() != self.config.dimension {
259            return Err(Error::InvalidInput(format!(
260                "Vector dimension {} doesn't match expected {}",
261                vector.len(),
262                self.config.dimension
263            )));
264        }
265
266        let mut mmap = self.vector_mmap.write().unwrap();
267        let mmap = mmap
268            .as_mut()
269            .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?;
270
271        let offset = self.vector_offset(vector_id);
272        let vec_size_bytes = self.config.dimension * std::mem::size_of::<f32>();
273
274        if offset + vec_size_bytes > mmap.len() {
275            return Err(Error::InvalidInput(format!(
276                "Vector {} out of bounds (mmap size: {}, needed: {})",
277                vector_id,
278                mmap.len(),
279                offset + vec_size_bytes
280            )));
281        }
282
283        let bytes = &mut mmap[offset..offset + vec_size_bytes];
284        for (i, &val) in vector.iter().enumerate() {
285            let val_bytes = val.to_le_bytes();
286            bytes[i * 4..(i + 1) * 4].copy_from_slice(&val_bytes);
287        }
288
289        Ok(())
290    }
291
292    /// Helper: Update the vector count in the header
293    fn update_vector_count(&self, count: usize) -> Result<()> {
294        let mut mmap = self.vector_mmap.write().unwrap();
295        let mmap = mmap
296            .as_mut()
297            .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?;
298
299        let count_bytes = (count as u64).to_le_bytes();
300        mmap[8..16].copy_from_slice(&count_bytes);
301
302        Ok(())
303    }
304
305    /// Helper: Get current vector count from mmap header
306    fn get_vector_count(&self) -> Result<usize> {
307        let mmap = self.vector_mmap.read().unwrap();
308        let mmap = mmap
309            .as_ref()
310            .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?;
311
312        let count_bytes: [u8; 8] = mmap[8..16].try_into().unwrap();
313        Ok(u64::from_le_bytes(count_bytes) as usize)
314    }
315
316    /// Helper: Ensure vector file has capacity for n vectors (expand if needed)
317    fn ensure_vector_capacity(&self, required_count: usize) -> Result<()> {
318        let mmap = self.vector_mmap.read().unwrap();
319        let current_size = mmap
320            .as_ref()
321            .ok_or_else(|| Error::InvalidInput("Vector file not mapped".to_string()))?
322            .len();
323        drop(mmap);
324
325        let required_size = VectorFileHeader::SIZE
326            + (required_count * self.config.dimension * std::mem::size_of::<f32>());
327
328        if required_size > current_size {
329            // Need to expand the file
330            let new_capacity = (required_count * 2).max(required_count + 1000); // Double capacity or add 1000
331            let new_size = VectorFileHeader::SIZE
332                + (new_capacity * self.config.dimension * std::mem::size_of::<f32>());
333
334            // Get file path and reopen/remap
335            let vec_path = self
336                .vector_file_path
337                .read()
338                .unwrap()
339                .clone()
340                .ok_or_else(|| Error::InvalidInput("No vector file path set".to_string()))?;
341
342            // Drop the current mmap before resizing
343            *self.vector_mmap.write().unwrap() = None;
344
345            // Resize the file
346            let vec_file = OpenOptions::new()
347                .read(true)
348                .write(true)
349                .open(&vec_path)
350                .map_err(Error::Io)?;
351            vec_file.set_len(new_size as u64).map_err(Error::Io)?;
352
353            // Remap
354            let new_mmap = unsafe {
355                MmapMut::map_mut(&vec_file)
356                    .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
357            };
358
359            *self.vector_mmap.write().unwrap() = Some(new_mmap);
360        }
361
362        Ok(())
363    }
364
365    /// Helper: Get number of vectors (from mmap header or fallback to next_id)
366    fn num_vectors(&self) -> usize {
367        self.get_vector_count()
368            .unwrap_or_else(|_| *self.next_id.read().unwrap())
369    }
370
371    /// Create with default configuration
372    pub fn with_defaults(dimension: usize) -> Self {
373        let config = DiskANNConfig {
374            dimension,
375            ..Default::default()
376        };
377        Self::new(config)
378    }
379
380    /// Create index file on disk
381    pub fn create(&mut self, path: impl AsRef<Path>) -> Result<()> {
382        let path = path.as_ref();
383        let path_str = path.to_string_lossy().to_string();
384
385        // Create index file
386        let file = OpenOptions::new()
387            .read(true)
388            .write(true)
389            .create(true)
390            .truncate(true)
391            .open(path)
392            .map_err(Error::Io)?;
393
394        // Write header
395        let header = IndexHeader::new(self.config.clone());
396        let header_bytes = oxicode::serde::encode_to_vec(&header, oxicode::config::standard())
397            .map_err(|e| Error::Serialization(e.to_string()))?;
398
399        // Initial file size: header + some space for growth
400        let initial_size = header_bytes.len() + 1024 * 1024; // 1MB initial
401        file.set_len(initial_size as u64).map_err(Error::Io)?;
402
403        // Memory-map the file
404        let mut mmap = unsafe {
405            MmapMut::map_mut(&file).map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
406        };
407
408        // Write header to mmap
409        mmap[..header_bytes.len()].copy_from_slice(&header_bytes);
410
411        // Create vector file
412        let vec_path = Self::get_vector_file_path(&path_str);
413        let vec_file = OpenOptions::new()
414            .read(true)
415            .write(true)
416            .create(true)
417            .truncate(true)
418            .open(&vec_path)
419            .map_err(Error::Io)?;
420
421        // Initial vector file size: header + space for 1000 vectors
422        let vec_header = VectorFileHeader::new(self.config.dimension);
423        let initial_vec_count = 1000;
424        let vec_file_size = VectorFileHeader::SIZE
425            + (initial_vec_count * self.config.dimension * std::mem::size_of::<f32>());
426        vec_file.set_len(vec_file_size as u64).map_err(Error::Io)?;
427
428        // Memory-map the vector file
429        let mut vec_mmap = unsafe {
430            MmapMut::map_mut(&vec_file)
431                .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
432        };
433
434        // Write vector header
435        let header_bytes = vec_header.as_bytes();
436        vec_mmap[..VectorFileHeader::SIZE].copy_from_slice(&header_bytes);
437
438        *self.index_path.write().unwrap() = Some(path_str.clone());
439        *self.vector_file_path.write().unwrap() = Some(vec_path);
440        *self.graph_mmap.write().unwrap() = Some(mmap);
441        *self.vector_mmap.write().unwrap() = Some(vec_mmap);
442        *self.loaded.write().unwrap() = true;
443
444        Ok(())
445    }
446
447    /// Load existing index from disk
448    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
449        let path = path.as_ref();
450
451        // Open index file
452        let file = OpenOptions::new()
453            .read(true)
454            .write(true)
455            .open(path)
456            .map_err(Error::Io)?;
457
458        // Memory-map the file
459        let mmap = unsafe {
460            MmapMut::map_mut(&file).map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?
461        };
462
463        // Read header
464        let header: IndexHeader =
465            oxicode::serde::decode_owned_from_slice(&mmap[..1024], oxicode::config::standard())
466                .map(|(v, _)| v)
467                .map_err(|e| Error::Serialization(e.to_string()))?;
468
469        header.validate()?;
470
471        // Create index
472        let index = Self::new(header.config);
473        *index.index_path.write().unwrap() = Some(path.to_string_lossy().to_string());
474        *index.graph_mmap.write().unwrap() = Some(mmap);
475        *index.next_id.write().unwrap() = header.num_vectors;
476        *index.loaded.write().unwrap() = true;
477
478        Ok(index)
479    }
480
481    /// Insert a vector using Vamana algorithm
482    pub fn insert(&mut self, cid: &Cid, vector: &[f32]) -> Result<()> {
483        if !*self.loaded.read().unwrap() {
484            return Err(Error::InvalidInput(
485                "Index not created or loaded".to_string(),
486            ));
487        }
488
489        if vector.len() != self.config.dimension {
490            return Err(Error::InvalidInput(format!(
491                "Vector dimension {} doesn't match index dimension {}",
492                vector.len(),
493                self.config.dimension
494            )));
495        }
496
497        // Check if CID already exists
498        if self.cid_to_id.read().unwrap().contains_key(cid) {
499            return Err(Error::InvalidInput(format!(
500                "CID already in index: {}",
501                cid
502            )));
503        }
504
505        // Get new ID
506        let id = *self.next_id.read().unwrap();
507
508        // Ensure vector file has enough space (expand if needed)
509        self.ensure_vector_capacity(id + 1)?;
510
511        // Write vector to memory-mapped file
512        self.write_vector(id, vector)?;
513
514        // Update vector count and next ID
515        *self.next_id.write().unwrap() += 1;
516        self.update_vector_count(id + 1)?;
517
518        // Add CID mapping
519        self.id_to_cid.write().unwrap().insert(id, *cid);
520        self.cid_to_id.write().unwrap().insert(*cid, id);
521
522        // Initialize graph node
523        self.graph.write().unwrap().push(Vec::new());
524
525        // If this is the first vector, make it an entry point
526        if id == 0 {
527            self.entry_points.write().unwrap().push(0);
528            return Ok(());
529        }
530
531        // Vamana graph construction
532        self.vamana_insert(id, vector)?;
533
534        // Update entry points if needed
535        if id.is_multiple_of(1000) && id < 10000 {
536            self.entry_points.write().unwrap().push(id);
537            // Keep only num_entry_points
538            let mut eps = self.entry_points.write().unwrap();
539            let num_to_drain = if eps.len() > self.config.num_entry_points {
540                eps.len() - self.config.num_entry_points
541            } else {
542                0
543            };
544            if num_to_drain > 0 {
545                eps.drain(0..num_to_drain);
546            }
547        }
548
549        Ok(())
550    }
551
552    /// Vamana graph construction for a new node
553    fn vamana_insert(&self, new_id: usize, new_vec: &[f32]) -> Result<()> {
554        // 1. Greedy search to find L nearest neighbors
555        let neighbors =
556            self.greedy_search_internal(new_vec, self.config.queue_size, self.config.queue_size)?;
557
558        // 2. Prune to R neighbors using robust pruning
559        let pruned = self.robust_prune(new_id, new_vec, &neighbors)?;
560
561        // 3. Add bidirectional edges
562        let mut graph = self.graph.write().unwrap();
563        graph[new_id] = pruned.clone();
564
565        // Add reverse edges and prune if needed
566        for &neighbor_id in &pruned {
567            if neighbor_id >= graph.len() {
568                continue;
569            }
570
571            // Add reverse edge
572            if !graph[neighbor_id].contains(&new_id) {
573                graph[neighbor_id].push(new_id);
574
575                // Prune if neighbor exceeds max degree
576                if graph[neighbor_id].len() > self.config.max_degree {
577                    let neighbor_vec = self.read_vector(neighbor_id)?;
578                    let candidates = graph[neighbor_id].clone();
579
580                    let pruned_neighbors =
581                        self.robust_prune(neighbor_id, &neighbor_vec, &candidates)?;
582                    graph[neighbor_id] = pruned_neighbors;
583                }
584            }
585        }
586
587        Ok(())
588    }
589
590    /// Robust pruning algorithm (RobustPrune from Vamana paper)
591    fn robust_prune(
592        &self,
593        node_id: usize,
594        node_vec: &[f32],
595        candidates: &[usize],
596    ) -> Result<Vec<usize>> {
597        let alpha = self.config.alpha;
598        let max_degree = self.config.max_degree;
599        let num_vecs = self.num_vectors();
600
601        // Compute distances from node to all candidates
602        let mut dists: Vec<(usize, f32)> = candidates
603            .iter()
604            .filter(|&&c| c != node_id && c < num_vecs)
605            .filter_map(|&c| {
606                self.read_vector(c).ok().map(|vec| {
607                    let dist = self.l2_distance(node_vec, &vec);
608                    (c, dist)
609                })
610            })
611            .collect();
612
613        // Sort by distance
614        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
615
616        let mut pruned = Vec::new();
617
618        for (cand_id, cand_dist) in dists {
619            if pruned.len() >= max_degree {
620                break;
621            }
622
623            // Check if candidate is alpha-close to any already selected neighbor
624            let mut should_add = true;
625            let cand_vec = self.read_vector(cand_id).ok();
626            if let Some(ref c_vec) = cand_vec {
627                for &selected_id in &pruned {
628                    if let Ok(sel_vec) = self.read_vector(selected_id) {
629                        let selected_dist = self.l2_distance(c_vec, &sel_vec);
630                        if alpha * selected_dist < cand_dist {
631                            should_add = false;
632                            break;
633                        }
634                    }
635                }
636            } else {
637                should_add = false;
638            }
639
640            if should_add {
641                pruned.push(cand_id);
642            }
643        }
644
645        Ok(pruned)
646    }
647
648    /// L2 distance between two vectors
649    fn l2_distance<T: AsRef<[f32]>, U: AsRef<[f32]>>(&self, a: T, b: U) -> f32 {
650        a.as_ref()
651            .iter()
652            .zip(b.as_ref().iter())
653            .map(|(x, y)| (x - y) * (x - y))
654            .sum::<f32>()
655            .sqrt()
656    }
657
658    /// Search for k nearest neighbors using greedy search
659    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
660        if !*self.loaded.read().unwrap() {
661            return Err(Error::InvalidInput(
662                "Index not created or loaded".to_string(),
663            ));
664        }
665
666        if query.len() != self.config.dimension {
667            return Err(Error::InvalidInput(format!(
668                "Query dimension {} doesn't match index dimension {}",
669                query.len(),
670                self.config.dimension
671            )));
672        }
673
674        let num_vectors = self.num_vectors();
675        if num_vectors == 0 {
676            return Ok(Vec::new());
677        }
678
679        // Greedy search with L = max(k, queue_size)
680        let search_list_size = k.max(self.config.queue_size);
681        let result_ids = self.greedy_search_internal(query, k, search_list_size)?;
682
683        // Convert to SearchResult with CIDs
684        let id_to_cid = self.id_to_cid.read().unwrap();
685        let results: Vec<SearchResult> = result_ids
686            .iter()
687            .filter_map(|&id| {
688                id_to_cid.get(&id).and_then(|cid| {
689                    self.read_vector(id).ok().map(|vec| SearchResult {
690                        cid: *cid,
691                        distance: self.l2_distance(query, &vec),
692                    })
693                })
694            })
695            .collect();
696
697        Ok(results)
698    }
699
700    /// Internal greedy search returning node IDs
701    fn greedy_search_internal(
702        &self,
703        query: &[f32],
704        k: usize,
705        search_list_size: usize,
706    ) -> Result<Vec<usize>> {
707        let graph = self.graph.read().unwrap();
708        let entry_points = self.entry_points.read().unwrap();
709        let num_vecs = self.num_vectors();
710
711        if num_vecs == 0 {
712            return Ok(Vec::new());
713        }
714
715        // Start from entry points
716        let start_nodes: Vec<usize> = if entry_points.is_empty() {
717            vec![0]
718        } else {
719            entry_points.clone()
720        };
721
722        // Visited set
723        let mut visited = vec![false; num_vecs];
724
725        // Priority queue: (distance, node_id)
726        let mut candidates: Vec<(f32, usize)> = Vec::new();
727        let mut results: Vec<(f32, usize)> = Vec::new();
728
729        // Initialize with entry points
730        for &node_id in &start_nodes {
731            if node_id >= num_vecs {
732                continue;
733            }
734            if let Ok(vec) = self.read_vector(node_id) {
735                let dist = self.l2_distance(query, &vec);
736                candidates.push((dist, node_id));
737                results.push((dist, node_id));
738                visited[node_id] = true;
739            }
740        }
741
742        // Sort by distance (ascending)
743        candidates.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
744        results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
745
746        // Greedy search
747        while !candidates.is_empty() {
748            // Get closest unvisited neighbor
749            let (current_dist, current_id) = candidates.remove(0);
750
751            // Stop if current is farther than the k-th result
752            if results.len() >= search_list_size {
753                let furthest_dist = results[search_list_size - 1].0;
754                if current_dist > furthest_dist {
755                    break;
756                }
757            }
758
759            // Explore neighbors
760            if current_id >= graph.len() {
761                continue;
762            }
763
764            for &neighbor_id in &graph[current_id] {
765                if neighbor_id >= num_vecs || visited[neighbor_id] {
766                    continue;
767                }
768
769                visited[neighbor_id] = true;
770                let dist = if let Ok(vec) = self.read_vector(neighbor_id) {
771                    self.l2_distance(query, &vec)
772                } else {
773                    continue;
774                };
775
776                // Add to candidates
777                candidates.push((dist, neighbor_id));
778                candidates
779                    .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
780
781                // Add to results
782                results.push((dist, neighbor_id));
783                results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
784
785                // Keep only search_list_size best results
786                if results.len() > search_list_size {
787                    results.truncate(search_list_size);
788                }
789            }
790        }
791
792        // Return top k
793        Ok(results.iter().take(k).map(|(_, id)| *id).collect())
794    }
795
796    /// Get index statistics
797    pub fn stats(&self) -> DiskANNStats {
798        DiskANNStats {
799            num_vectors: *self.next_id.read().unwrap(),
800            dimension: self.config.dimension,
801            max_degree: self.config.max_degree,
802            index_loaded: *self.loaded.read().unwrap(),
803            estimated_disk_size: self.estimate_disk_size(),
804        }
805    }
806
807    /// Estimate disk usage
808    fn estimate_disk_size(&self) -> usize {
809        let num_vectors = *self.next_id.read().unwrap();
810
811        // Header: ~1KB
812        let header_size = 1024;
813
814        // Vectors: num_vectors * dimension * 4 bytes
815        let vector_size = num_vectors * self.config.dimension * 4;
816
817        // Graph: num_vectors * max_degree * 4 bytes (assuming u32 node IDs)
818        let graph_size = num_vectors * self.config.max_degree * 4;
819
820        // CID mapping: num_vectors * ~40 bytes (CID size)
821        let mapping_size = num_vectors * 40;
822
823        header_size + vector_size + graph_size + mapping_size
824    }
825
826    /// Check if index is loaded
827    pub fn is_loaded(&self) -> bool {
828        *self.loaded.read().unwrap()
829    }
830
831    /// Get configuration
832    pub fn config(&self) -> &DiskANNConfig {
833        &self.config
834    }
835
836    /// Save index to disk (persist all in-memory data)
837    pub fn save(&self) -> Result<()> {
838        if !*self.loaded.read().unwrap() {
839            return Err(Error::InvalidInput("Index not loaded".to_string()));
840        }
841
842        let path = self
843            .index_path
844            .read()
845            .unwrap()
846            .clone()
847            .ok_or_else(|| Error::InvalidInput("No index path set".to_string()))?;
848
849        // Serialize all data (read all vectors from mmap)
850        let num_vecs = self.num_vectors();
851        let mut vectors = Vec::with_capacity(num_vecs);
852        for i in 0..num_vecs {
853            if let Ok(vec) = self.read_vector(i) {
854                vectors.push(vec);
855            }
856        }
857
858        let graph = self.graph.read().unwrap();
859        let id_to_cid = self.id_to_cid.read().unwrap();
860        let entry_points = self.entry_points.read().unwrap();
861
862        let data = DiskANNData::from_index(
863            vectors,
864            graph.clone(),
865            id_to_cid.clone(),
866            entry_points.clone(),
867        );
868
869        // Serialize to file
870        let serialized = oxicode::serde::encode_to_vec(&data, oxicode::config::standard())
871            .map_err(|e| Error::Serialization(e.to_string()))?;
872
873        // Write to a temp file first, then rename (atomic)
874        let temp_path = format!("{}.tmp", path);
875        std::fs::write(&temp_path, &serialized).map_err(Error::Io)?;
876        std::fs::rename(&temp_path, &path).map_err(Error::Io)?;
877
878        Ok(())
879    }
880
881    /// Flush changes to disk
882    pub fn flush(&self) -> Result<()> {
883        if let Some(ref mut mmap) = *self.graph_mmap.write().unwrap() {
884            mmap.flush()
885                .map_err(|e| Error::Io(std::io::Error::other(e.to_string())))?;
886        }
887        Ok(())
888    }
889
890    /// Compact the index by removing fragmentation
891    ///
892    /// This method:
893    /// - Removes gaps in the ID space
894    /// - Rebuilds the graph with contiguous IDs
895    /// - Optimizes memory layout
896    pub fn compact(&mut self) -> Result<CompactionStats> {
897        if !*self.loaded.read().unwrap() {
898            return Err(Error::InvalidInput("Index not loaded".to_string()));
899        }
900
901        let start_time = std::time::Instant::now();
902        let old_size = self.num_vectors();
903        let graph = self.graph.read().unwrap();
904
905        let old_graph_edges: usize = graph.iter().map(|neighbors| neighbors.len()).sum();
906
907        // For now, just report stats since we don't have fragmentation yet
908        // In a real implementation, we'd rebuild with contiguous IDs
909        let stats = CompactionStats {
910            duration_ms: start_time.elapsed().as_millis() as u64,
911            vectors_before: old_size,
912            vectors_after: old_size,
913            graph_edges_before: old_graph_edges,
914            graph_edges_after: old_graph_edges,
915            bytes_saved: 0,
916        };
917
918        Ok(stats)
919    }
920
921    /// Prune the graph to remove low-quality edges
922    ///
923    /// This helps reduce memory usage and can improve query performance
924    /// by removing edges that don't contribute to search quality.
925    pub fn prune_graph(&mut self, quality_threshold: f32) -> Result<usize> {
926        if !*self.loaded.read().unwrap() {
927            return Err(Error::InvalidInput("Index not loaded".to_string()));
928        }
929
930        let mut graph = self.graph.write().unwrap();
931        let num_vecs = self.num_vectors();
932        let mut total_pruned = 0;
933
934        for node_id in 0..graph.len() {
935            if node_id >= num_vecs {
936                continue;
937            }
938
939            let node_vec = match self.read_vector(node_id) {
940                Ok(v) => v,
941                Err(_) => continue,
942            };
943            let neighbors = &graph[node_id];
944
945            // Compute distances to all neighbors
946            let mut neighbor_dists: Vec<(usize, f32)> = neighbors
947                .iter()
948                .filter(|&&n| n < num_vecs)
949                .filter_map(|&n| {
950                    self.read_vector(n).ok().map(|vec| {
951                        let dist = self.l2_distance(&node_vec, &vec);
952                        (n, dist)
953                    })
954                })
955                .collect();
956
957            // Sort by distance
958            neighbor_dists
959                .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
960
961            // Keep only neighbors within quality threshold of the best
962            if let Some(&(_, best_dist)) = neighbor_dists.first() {
963                let threshold_dist = best_dist * (1.0 + quality_threshold);
964                let keep_count = neighbor_dists
965                    .iter()
966                    .filter(|(_, d)| *d <= threshold_dist)
967                    .count();
968
969                if keep_count < neighbors.len() {
970                    total_pruned += neighbors.len() - keep_count;
971                    graph[node_id] = neighbor_dists
972                        .iter()
973                        .take(keep_count)
974                        .map(|(n, _)| *n)
975                        .collect();
976                }
977            }
978        }
979
980        Ok(total_pruned)
981    }
982
983    /// Get number of vectors in the index
984    pub fn len(&self) -> usize {
985        *self.next_id.read().unwrap()
986    }
987
988    /// Check if index is empty
989    pub fn is_empty(&self) -> bool {
990        self.len() == 0
991    }
992}
993
/// Data stored in DiskANN index file (serializable version)
///
/// `save` builds this from the in-memory index via `from_index` and encodes it with
/// `oxicode`. CIDs are stored as strings here; `to_cid_map` parses them back.
///
/// NOTE(review): oxicode appears to be a bincode-style codec. If it encodes fields
/// positionally, reordering or inserting fields here changes the on-disk format and
/// breaks previously saved files. Confirm before editing this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DiskANNData {
    // Vectors read back from the index, in ascending ID order.
    // NOTE(review): `save` skips IDs whose `read_vector` call fails, so position `i` is
    // only guaranteed to be ID `i` when every read succeeds.
    vectors: Vec<Vec<f32>>,
    // Adjacency lists: `graph[i]` holds the neighbour IDs of node `i`.
    graph: Vec<Vec<usize>>,
    // Node ID -> CID in string form.
    id_to_cid: HashMap<usize, String>,
    // Node IDs used as search entry points.
    entry_points: Vec<usize>,
}
1002
1003impl DiskANNData {
1004    fn from_index(
1005        vectors: Vec<Vec<f32>>,
1006        graph: Vec<Vec<usize>>,
1007        id_to_cid: HashMap<usize, Cid>,
1008        entry_points: Vec<usize>,
1009    ) -> Self {
1010        let id_to_cid_str = id_to_cid
1011            .into_iter()
1012            .map(|(k, v)| (k, v.to_string()))
1013            .collect();
1014        Self {
1015            vectors,
1016            graph,
1017            id_to_cid: id_to_cid_str,
1018            entry_points,
1019        }
1020    }
1021
1022    #[allow(dead_code)]
1023    fn to_cid_map(&self) -> Result<HashMap<usize, Cid>> {
1024        self.id_to_cid
1025            .iter()
1026            .map(|(k, v)| {
1027                v.parse::<Cid>()
1028                    .map(|cid| (*k, cid))
1029                    .map_err(|e| Error::InvalidInput(format!("Invalid CID: {}", e)))
1030            })
1031            .collect()
1032    }
1033}
1034
/// Statistics from index compaction
///
/// Returned by `DiskANNIndex::compact`. Comparing each `*_before` field with its
/// `*_after` counterpart shows what compaction changed.
#[derive(Debug, Clone)]
pub struct CompactionStats {
    /// Time taken for compaction, in milliseconds
    pub duration_ms: u64,
    /// Number of vectors before compaction
    pub vectors_before: usize,
    /// Number of vectors after compaction
    pub vectors_after: usize,
    /// Number of graph edges before (sum of all neighbour-list lengths)
    pub graph_edges_before: usize,
    /// Number of graph edges after (sum of all neighbour-list lengths)
    pub graph_edges_after: usize,
    /// Bytes saved by compaction
    pub bytes_saved: usize,
}
1051
/// Search result from DiskANN
///
/// A single nearest-neighbour hit: the content it refers to and how far it is from the query.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Content ID
    pub cid: Cid,
    /// Distance between the stored vector and the query (smaller is closer)
    pub distance: f32,
}
1060
/// DiskANN index statistics
///
/// Snapshot returned by `DiskANNIndex::stats`.
#[derive(Debug, Clone)]
pub struct DiskANNStats {
    /// Number of vectors in index
    pub num_vectors: usize,
    /// Vector dimension
    pub dimension: usize,
    /// Maximum graph degree
    pub max_degree: usize,
    /// Whether index is loaded (`false` for a freshly constructed index)
    pub index_loaded: bool,
    /// Estimated disk size in bytes
    pub estimated_disk_size: usize,
}
1075
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid CID strings shared by the tests; each test uses a prefix of this list.
    const TEST_CIDS: [&str; 20] = [
        "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
        "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354",
        "bafybeihvvulpp6bcs5kum72jh5tkfo35dz2ow3lrqw4hmqyqbmfyvdqvdq",
        "bafybeiakou6e7kkxc5qycjkqwucq4zfkfvzmlbf2vlihvqqnfjfzpqrkmq",
        "bafybeibscyh5z3uk6fvdidffhybzsxmckblkjhajy4y4uzcglmfwqx67b4",
        "bafybeiezkzpo2uy4teyix63fjc3vgpxlvhbmwjicxhxx6vaf3ywvkyz5ia",
        "bafybeifmyetvpv2uovt7ncnvjcwvshwqrr7zmyh5wpqwmf5mwy3m42xkre",
        "bafybeia7lv6vknr6fqjq2jlj3ygbdgzdqxqt7xo3u7dzz6ihfzd3zhd6pi",
        "bafybeif2ewg3nqa33yvecifp7jw7p2utbnkh34j7ku44mzs3lpmcbdkjzq",
        "bafybeid5cg74fzlh7okcaabfwexdvkiuocwbqhwrqc4x65jyplwsxzvvdq",
        "bafybeicy6rxfqlcdadwjfjjvvb7wlbnlrzuzsogpv5snwt46zpqrmihtnq",
        "bafybeie2kj53f4wmefncg3rvrvfegwk265iw2psfszftvq3slajlwkjfpm",
        "bafybeigk7gjp4y4m4gwvmblvf7mlufsqtfgwyjdqwvwudytucvx7wtnz4e",
        "bafybeihbsq7kdawlkzvfj7xttx27t4p52pkllmfevn5l2scgbvmgqcfmfy",
        "bafybeiej5vfvbkjbzyeouqxkn25yb2xzdz2igdwmawcbhv66kwfwqnvhzi",
        "bafybeigbkbpcxqbrvx56fqf7jb25r5wunzowl45uwmzcbxkwdtixlbtwim",
        "bafybeihyfvtf3uiilqvqsvhbphfdudqy7qrjkxqglh26xxvjhtxrkhhbxe",
        "bafybeicflzm3r35m4kj5chxjvdwgajq6ljhqpsjq6wdyqnlpfjwwb5nowi",
        "bafybeic73hjrp52jxz33zxlz5qthfxumqpyuvqfvawdcskqiqlpuww3vxi",
        "bafybeicbh5dkdyiq3gqufk46cktiwwucwl6mzhv6e5xhzmuvzojvykokpy",
    ];

    /// Parse the `i`-th test CID.
    fn test_cid(i: usize) -> Cid {
        TEST_CIDS[i].parse().unwrap()
    }

    /// Insert `vectors[i]` under `test_cid(i)` for every `i`.
    fn insert_all(index: &mut DiskANNIndex, vectors: &[Vec<f32>]) {
        for (i, vec) in vectors.iter().enumerate() {
            index.insert(&test_cid(i), vec).unwrap();
        }
    }

    /// Total number of directed edges currently stored in the index graph.
    fn total_edges(index: &DiskANNIndex) -> usize {
        index.graph.read().unwrap().iter().map(Vec::len).sum()
    }

    /// A process-unique file in the system temp directory that is deleted on drop.
    ///
    /// Dropping also removes the `<path>.tmp` sibling that `save` writes before
    /// renaming, so cleanup happens even when an assertion panics mid-test.
    /// Declare it before the index so the index (and its mmap) is dropped first.
    struct TempFile(String);

    impl TempFile {
        fn new(name: &str) -> Self {
            // The process id keeps concurrent test runs from clobbering each other's files.
            let path = std::env::temp_dir().join(format!("{}_{}", std::process::id(), name));
            Self(path.to_string_lossy().into_owned())
        }

        fn path(&self) -> &str {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            std::fs::remove_file(&self.0).ok();
            std::fs::remove_file(format!("{}.tmp", self.0)).ok();
        }
    }

    #[test]
    fn test_diskann_create() {
        let temp_file = TempFile::new("test_diskann_index.dat");
        let config = DiskANNConfig::default();
        let mut index = DiskANNIndex::new(config);

        assert!(index.create(temp_file.path()).is_ok());
        assert!(index.is_loaded());
    }

    #[test]
    fn test_diskann_stats() {
        let index = DiskANNIndex::with_defaults(128);
        let stats = index.stats();

        assert_eq!(stats.dimension, 128);
        assert_eq!(stats.num_vectors, 0);
        assert!(!stats.index_loaded);
    }

    #[test]
    fn test_index_header() {
        let config = DiskANNConfig::default();
        let header = IndexHeader::new(config);

        assert_eq!(header.magic, IndexHeader::MAGIC);
        assert_eq!(header.version, 1);
        assert!(header.validate().is_ok());

        // Test invalid magic
        let mut bad_header = header.clone();
        bad_header.magic = [0; 8];
        assert!(bad_header.validate().is_err());
    }

    #[test]
    fn test_diskann_insert_and_search() {
        let temp_file = TempFile::new("test_diskann_vamana.dat");
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 16,
            queue_size: 50,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        index.create(temp_file.path()).unwrap();

        let vectors: [Vec<f32>; 5] = [
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.9, 0.1],
        ];
        insert_all(&mut index, &vectors);

        assert_eq!(index.stats().num_vectors, 5);

        // Search for nearest to first vector
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);
        // First result should be closest
        assert!(results[0].distance < 0.2);
    }

    #[test]
    fn test_vamana_graph_construction() {
        let temp_file = TempFile::new("test_vamana_graph.dat");
        let config = DiskANNConfig {
            dimension: 8,
            max_degree: 8,
            queue_size: 20,
            alpha: 1.2,
            ..Default::default()
        };

        let max_degree = config.max_degree;
        let mut index = DiskANNIndex::new(config);
        index.create(temp_file.path()).unwrap();

        // Insert 20 vectors
        for i in 0..TEST_CIDS.len() {
            let vec: Vec<f32> = (0..8).map(|j| (i as f32 + j as f32) * 0.1).collect();
            index.insert(&test_cid(i), &vec).unwrap();
        }

        // Check graph structure
        let graph = index.graph.read().unwrap();
        assert_eq!(graph.len(), 20);

        // Each node (except possibly the first) should have some neighbors
        for (i, neighbors) in graph.iter().enumerate().skip(1) {
            if i < 19 {
                // Not the last node
                assert!(!neighbors.is_empty(), "Node {} should have neighbors", i);
                assert!(
                    neighbors.len() <= max_degree,
                    "Node {} has too many neighbors: {}",
                    i,
                    neighbors.len()
                );
            }
        }
    }

    #[test]
    fn test_robust_pruning() {
        let temp_file = TempFile::new("test_robust_prune.dat");
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 3,
            alpha: 1.2,
            ..Default::default()
        };

        let max_degree = config.max_degree;
        let mut index = DiskANNIndex::new(config);
        index.create(temp_file.path()).unwrap();

        // Add some vectors manually (write to mmap)
        index.ensure_vector_capacity(4).unwrap();
        index.write_vector(0, &[1.0, 0.0, 0.0, 0.0]).unwrap();
        index.write_vector(1, &[0.9, 0.1, 0.0, 0.0]).unwrap();
        index.write_vector(2, &[0.8, 0.2, 0.0, 0.0]).unwrap();
        index.write_vector(3, &[0.0, 1.0, 0.0, 0.0]).unwrap();
        index.update_vector_count(4).unwrap();

        let node_vec = vec![1.0, 0.0, 0.0, 0.0];
        let candidates = vec![1, 2, 3];

        let pruned = index.robust_prune(0, &node_vec, &candidates).unwrap();

        // Should prune to max_degree neighbors
        assert!(pruned.len() <= max_degree);
        // Should include the closest neighbor
        assert!(pruned.contains(&1));
    }

    #[test]
    fn test_diskann_save_and_load() {
        let temp_file = TempFile::new("test_diskann_save.dat");
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 16,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        index.create(temp_file.path()).unwrap();

        let vectors: [Vec<f32>; 3] = [
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        insert_all(&mut index, &vectors);

        // Save the index
        assert!(index.save().is_ok());
        // `save` replaces the index file via a temp-file rename; the result must not be empty.
        let saved_len = std::fs::metadata(temp_file.path()).unwrap().len();
        assert!(saved_len > 0);

        // The save method overwrites the file, so loading can't be round-tripped here
        // without a load implementation that deserializes DiskANNData.
    }

    #[test]
    fn test_diskann_flush() {
        let temp_file = TempFile::new("test_diskann_flush.dat");
        let config = DiskANNConfig {
            dimension: 4,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        index.create(temp_file.path()).unwrap();

        // Flush should succeed
        assert!(index.flush().is_ok());
    }

    #[test]
    fn test_diskann_compact() {
        let temp_file = TempFile::new("test_diskann_compact.dat");
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 16,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        index.create(temp_file.path()).unwrap();

        let vectors: [Vec<f32>; 3] = [
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        insert_all(&mut index, &vectors);

        // Compact the index
        let stats = index.compact().unwrap();
        assert_eq!(stats.vectors_before, 3);
        assert_eq!(stats.vectors_after, 3);
    }

    #[test]
    fn test_diskann_prune_graph() {
        let temp_file = TempFile::new("test_diskann_prune.dat");
        let config = DiskANNConfig {
            dimension: 4,
            max_degree: 16,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        index.create(temp_file.path()).unwrap();

        let vectors: [Vec<f32>; 4] = [
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.8, 0.2, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        insert_all(&mut index, &vectors);

        let edges_before = total_edges(&index);
        let pruned = index.prune_graph(0.5).unwrap();
        let edges_after = total_edges(&index);

        // The returned count must match exactly the number of edges that disappeared.
        assert_eq!(edges_before - edges_after, pruned);
    }

    #[test]
    fn test_diskann_len_and_is_empty() {
        let temp_file = TempFile::new("test_diskann_len.dat");
        let config = DiskANNConfig {
            dimension: 4,
            ..Default::default()
        };

        let mut index = DiskANNIndex::new(config);
        index.create(temp_file.path()).unwrap();

        assert_eq!(index.len(), 0);
        assert!(index.is_empty());

        // Insert a vector
        let vec = vec![1.0, 0.0, 0.0, 0.0];
        index.insert(&test_cid(0), &vec).unwrap();

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
    }
}