diskann_rs/
lib.rs

1//! # DiskAnnRS
2//!
3//! A DiskANN-like Rust library implementing approximate nearest neighbor search with
4//! single-file storage support. The library provides both Euclidean distance and
5//! Cosine similarity metrics, using the Vamana graph algorithm for efficient search.
6//!
7//! ## Features
8//!
9//! - Single-file storage format with memory-mapped access
10//! - Support for both Euclidean and Cosine distance metrics
11//! - Vamana graph construction with pruning
12//! - Efficient beam search with medoid entry points
13//! - Minimal memory footprint during search
14//!
15//! ## Example
16//!
17//! ```rust,no_run
18//! use diskann_rs::{DiskANN, DistanceMetric};
19//!
20//! // Build a new index from vectors
21//! let vectors = vec![vec![0.0; 128]; 1000]; // Your vectors
22//! let index = DiskANN::build_index(
23//!     &vectors,
24//!     32,      // maximum degree
//!     128,     // build-time beam width
//!     1.2,     // alpha parameter for pruning (typically 1.2-2.0)
27//!     DistanceMetric::Euclidean,
28//!     "index.db"
29//! ).unwrap();
30//!
31//! // Search the index
32//! let query = vec![0.0; 128];  // your query vector
33//! let neighbors = index.search(&query, 10, 64);  // find top 10 with beam width 64
34//! ```
35
36use bytemuck;
37use memmap2::Mmap;
38use rand::prelude::*;
39use serde::{Deserialize, Serialize};
40use std::cmp::Ordering;
41use std::collections::{BinaryHeap, HashSet};
42use std::{
43    fs::OpenOptions,
44    io::{Seek, SeekFrom, Write},
45};
46use thiserror::Error;
47
48/// Custom error type for DiskAnnRS operations
49#[derive(Debug, Error)]
50pub enum DiskAnnError {
51    /// Represents I/O errors during file operations
52    #[error("I/O error: {0}")]
53    Io(#[from] std::io::Error),
54
55    /// Represents serialization/deserialization errors
56    #[error("Serialization error: {0}")]
57    Bincode(#[from] bincode::Error),
58
59    /// Represents index-specific errors
60    #[error("Index error: {0}")]
61    IndexError(String),
62}
63
64/// Supported distance metrics for vector comparison
65#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
66pub enum DistanceMetric {
67    /// Standard Euclidean distance
68    Euclidean,
69    /// Cosine similarity (converted to distance as 1 - similarity)
70    Cosine,
71}
72
73/// Internal metadata structure stored in the index file
74#[derive(Serialize, Deserialize, Debug)]
75struct Metadata {
76    dim: usize,
77    num_vectors: usize,
78    max_degree: usize,
79    distance_metric: DistanceMetric,
80    medoid_id: u32,
81    vectors_offset: u64,
82    adjacency_offset: u64,
83}
84
85/// Main struct representing a DiskANN index
86pub struct DiskANN {
87    /// Dimensionality of vectors in the index
88    pub dim: usize,
89    /// Number of vectors in the index
90    pub num_vectors: usize,
91    /// Maximum number of edges per node
92    pub max_degree: usize,
93    /// Distance metric used by this index
94    pub distance_metric: DistanceMetric,
95    /// ID of the medoid (used as entry point)
96    medoid_id: u32,
97    vectors_offset: u64,
98    adjacency_offset: u64,
99    mmap: Mmap,
100}
101
/// A node ID paired with its distance to the current query.
///
/// The ordering is reversed on distance. As a result, `BinaryHeap<Candidate>` (a
/// max-heap) pops the *closest* candidate first, while `BinaryHeap<Reverse<Candidate>>`
/// keeps the *farthest* candidate on top.
///
/// Distances are compared with `f32::total_cmp`, so the order is total, as `Ord`
/// requires, even when a distance is NaN. Equal distances are tie-broken by ID, with the
/// smaller ID ranking as closer. Consequently `cmp` returns `Equal` exactly when `==`
/// holds.
#[derive(Clone, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality and ordering can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Arguments are swapped on purpose: a smaller distance means higher priority.
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.id.cmp(&self.id))
    }
}
129
130impl DiskANN {
131    /// Builds a new index from provided vectors
132    ///
133    /// # Arguments
134    ///
135    /// * `vectors` - The vectors to index
136    /// * `max_degree` - Maximum number of edges per node (typically 32-64)
137    /// * `build_beam_width` - Beam width during construction (typically 128)
138    /// * `alpha` - Pruning parameter (typically 1.2-2.0)
139    /// * `distance_metric` - Distance metric to use
140    /// * `file_path` - Path where the index file will be created
141    ///
142    /// # Returns
143    ///
144    /// Returns `Result<DiskANN, DiskAnnError>`
145    pub fn build_index(
146        vectors: &[Vec<f32>],
147        max_degree: usize,
148        build_beam_width: usize,
149        alpha: f32,
150        distance_metric: DistanceMetric,
151        file_path: &str,
152    ) -> Result<Self, DiskAnnError> {
153        if vectors.is_empty() {
154            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
155        }
156
157        let num_vectors = vectors.len();
158        let dim = vectors[0].len();
159
160        // Validate all vectors have same dimension
161        for (i, v) in vectors.iter().enumerate() {
162            if v.len() != dim {
163                return Err(DiskAnnError::IndexError(format!(
164                    "Vector {} has dimension {} but expected {}",
165                    i,
166                    v.len(),
167                    dim
168                )));
169            }
170        }
171
172        println!(
173            "Building index for {} vectors of dimension {} with max_degree={}",
174            num_vectors, dim, max_degree
175        );
176
177        let mut file = OpenOptions::new()
178            .create(true)
179            .write(true)
180            .read(true)
181            .truncate(true)
182            .open(file_path)?;
183
184        // Reserve space for metadata (we'll write it at the end)
185        let vectors_offset = 1024 * 1024; // 1MB for metadata
186        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
187
188        // Write vectors to file
189        file.seek(SeekFrom::Start(vectors_offset))?;
190        for vector in vectors {
191            let bytes = bytemuck::cast_slice(vector);
192            file.write_all(bytes)?;
193        }
194
195        // Calculate medoid (centroid closest to mean of all vectors)
196        let medoid_id = calculate_medoid(vectors, distance_metric);
197        println!("Calculated medoid: {}", medoid_id);
198
199        // Build Vamana graph
200        let adjacency_offset = vectors_offset + total_vector_bytes;
201        let graph = build_vamana_graph(
202            vectors,
203            max_degree,
204            build_beam_width,
205            alpha,
206            distance_metric,
207            medoid_id as u32,
208        );
209
210        // Write adjacency lists
211        file.seek(SeekFrom::Start(adjacency_offset))?;
212        for neighbors in &graph {
213            // Pad with zeros if needed
214            let mut padded = neighbors.clone();
215            padded.resize(max_degree, 0);
216            let bytes = bytemuck::cast_slice(&padded);
217            file.write_all(bytes)?;
218        }
219
220        // Write metadata
221        let metadata = Metadata {
222            dim,
223            num_vectors,
224            max_degree,
225            distance_metric,
226            medoid_id: medoid_id as u32,
227            vectors_offset,
228            adjacency_offset,
229        };
230
231        let md_bytes = bincode::serialize(&metadata)?;
232        file.seek(SeekFrom::Start(0))?;
233        let md_len = md_bytes.len() as u64;
234        file.write_all(&md_len.to_le_bytes())?;
235        file.write_all(&md_bytes)?;
236        file.sync_all()?;
237
238        // Memory map the file
239        let mmap = unsafe { memmap2::Mmap::map(&file)? };
240
241        Ok(Self {
242            dim,
243            num_vectors,
244            max_degree,
245            distance_metric,
246            medoid_id: metadata.medoid_id,
247            vectors_offset,
248            adjacency_offset,
249            mmap,
250        })
251    }
252
253    /// Opens an existing index file
254    ///
255    /// # Arguments
256    ///
257    /// * `path` - Path to the index file
258    ///
259    /// # Returns
260    ///
261    /// Returns `Result<DiskANN, DiskAnnError>`
262    pub fn open_index(path: &str) -> Result<Self, DiskAnnError> {
263        let file = OpenOptions::new().read(true).write(false).open(path)?;
264        
265        // Read metadata length
266        let mut buf8 = [0u8; 8];
267        use std::os::unix::fs::FileExt;
268        file.read_exact_at(&mut buf8, 0)?;
269        let md_len = u64::from_le_bytes(buf8);
270        
271        // Read metadata
272        let mut md_bytes = vec![0u8; md_len as usize];
273        file.read_exact_at(&mut md_bytes, 8)?;
274        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
275
276        let mmap = unsafe { memmap2::Mmap::map(&file)? };
277
278        Ok(Self {
279            dim: metadata.dim,
280            num_vectors: metadata.num_vectors,
281            max_degree: metadata.max_degree,
282            distance_metric: metadata.distance_metric,
283            medoid_id: metadata.medoid_id,
284            vectors_offset: metadata.vectors_offset,
285            adjacency_offset: metadata.adjacency_offset,
286            mmap,
287        })
288    }
289
290    /// Searches the index for nearest neighbors using beam search
291    ///
292    /// # Arguments
293    ///
294    /// * `query` - Query vector
295    /// * `k` - Number of nearest neighbors to return
296    /// * `beam_width` - Beam width for the search (typically 32-128)
297    ///
298    /// # Returns
299    ///
300    /// Returns a vector of node IDs representing the nearest neighbors
301    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
302        if query.len() != self.dim {
303            panic!(
304                "Query dimension {} does not match index dimension {}",
305                query.len(),
306                self.dim
307            );
308        }
309
310        // Initialize with medoid as entry point
311        let mut visited = HashSet::new();
312        let mut candidates = BinaryHeap::new();
313        let mut w = BinaryHeap::new(); // Working set
314
315        // Start from medoid
316        let start_dist = self.distance_to(query, self.medoid_id as usize);
317        candidates.push(Candidate {
318            dist: start_dist,
319            id: self.medoid_id,
320        });
321        w.push(Candidate {
322            dist: start_dist,
323            id: self.medoid_id,
324        });
325        visited.insert(self.medoid_id);
326
327        // Beam search
328        let mut best_dist = start_dist;
329        let mut iterations_without_improvement = 0;
330        const MAX_ITERATIONS_WITHOUT_IMPROVEMENT: usize = 5;
331
332        while let Some(current) = candidates.pop() {
333            // Early termination if no improvement
334            if current.dist > best_dist {
335                iterations_without_improvement += 1;
336                if iterations_without_improvement > MAX_ITERATIONS_WITHOUT_IMPROVEMENT {
337                    break;
338                }
339            } else {
340                best_dist = current.dist;
341                iterations_without_improvement = 0;
342            }
343
344            // Get neighbors of current node
345            let neighbors = self.get_neighbors(current.id);
346
347            for &neighbor_id in neighbors {
348                if neighbor_id == 0 || visited.contains(&neighbor_id) {
349                    continue;
350                }
351
352                visited.insert(neighbor_id);
353                let dist = self.distance_to(query, neighbor_id as usize);
354
355                // Update working set
356                w.push(Candidate {
357                    dist,
358                    id: neighbor_id,
359                });
360
361                // Prune working set to beam width
362                if w.len() > beam_width {
363                    // Keep only top beam_width candidates
364                    let mut temp = Vec::new();
365                    for _ in 0..beam_width {
366                        if let Some(c) = w.pop() {
367                            temp.push(c);
368                        }
369                    }
370                    w.clear();
371                    for c in temp {
372                        w.push(c);
373                    }
374                }
375
376                // Add to candidates if promising
377                if w.len() < beam_width || dist < w.peek().unwrap().dist {
378                    candidates.push(Candidate {
379                        dist,
380                        id: neighbor_id,
381                    });
382                }
383            }
384        }
385
386        // Extract top-k results
387        let mut results: Vec<_> = w.into_vec();
388        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
389        results.truncate(k);
390        results.into_iter().map(|c| c.id).collect()
391    }
392
393    /// Gets the neighbors of a node from the graph
394    fn get_neighbors(&self, node_id: u32) -> &[u32] {
395        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
396        let start = offset as usize;
397        let end = start + (self.max_degree * 4);
398        let bytes = &self.mmap[start..end];
399        bytemuck::cast_slice(bytes)
400    }
401
402    /// Computes distance between query and a vector in the index
403    fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
404        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
405        let start = offset as usize;
406        let end = start + (self.dim * 4);
407        let bytes = &self.mmap[start..end];
408        let vector: &[f32] = bytemuck::cast_slice(bytes);
409
410        match self.distance_metric {
411            DistanceMetric::Euclidean => euclidean_distance(query, vector),
412            DistanceMetric::Cosine => 1.0 - cosine_similarity(query, vector),
413        }
414    }
415
416    /// Gets a vector from the index (for testing)
417    pub fn get_vector(&self, idx: usize) -> Vec<f32> {
418        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
419        let start = offset as usize;
420        let end = start + (self.dim * 4);
421        let bytes = &self.mmap[start..end];
422        let vector: &[f32] = bytemuck::cast_slice(bytes);
423        vector.to_vec()
424    }
425}
426
427/// Calculates the medoid (vector closest to the centroid)
428fn calculate_medoid(vectors: &[Vec<f32>], distance_metric: DistanceMetric) -> usize {
429    let dim = vectors[0].len();
430    let mut centroid = vec![0.0; dim];
431
432    // Calculate centroid
433    for vector in vectors {
434        for (i, &val) in vector.iter().enumerate() {
435            centroid[i] += val;
436        }
437    }
438    for val in &mut centroid {
439        *val /= vectors.len() as f32;
440    }
441
442    // Find vector closest to centroid
443    let mut best_idx = 0;
444    let mut best_dist = f32::MAX;
445
446    for (idx, vector) in vectors.iter().enumerate() {
447        let dist = match distance_metric {
448            DistanceMetric::Euclidean => euclidean_distance(&centroid, vector),
449            DistanceMetric::Cosine => 1.0 - cosine_similarity(&centroid, vector),
450        };
451        if dist < best_dist {
452            best_dist = dist;
453            best_idx = idx;
454        }
455    }
456
457    best_idx
458}
459
460/// Builds the Vamana graph using greedy search and pruning
461fn build_vamana_graph(
462    vectors: &[Vec<f32>],
463    max_degree: usize,
464    beam_width: usize,
465    alpha: f32,
466    distance_metric: DistanceMetric,
467    medoid_id: u32,
468) -> Vec<Vec<u32>> {
469    let num_vectors = vectors.len();
470    let mut graph = vec![Vec::new(); num_vectors];
471
472    // Initialize with random graph
473    let mut rng = thread_rng();
474    for i in 0..num_vectors {
475        let mut neighbors = HashSet::new();
476        while neighbors.len() < max_degree.min(num_vectors - 1) {
477            let neighbor = rng.gen_range(0..num_vectors);
478            if neighbor != i {
479                neighbors.insert(neighbor as u32);
480            }
481        }
482        graph[i] = neighbors.into_iter().collect();
483    }
484
485    println!("Building Vamana graph with beam_width={}, alpha={}", beam_width, alpha);
486
487    // Iterative improvement
488    for iteration in 0..2 {
489        println!("Graph building iteration {}", iteration + 1);
490        
491        // Process nodes in random order
492        let mut node_order: Vec<usize> = (0..num_vectors).collect();
493        node_order.shuffle(&mut rng);
494
495        for &node_id in &node_order {
496            // Search for nearest neighbors using current graph
497            let neighbors = greedy_search(
498                &vectors[node_id],
499                vectors,
500                &graph,
501                medoid_id as usize,
502                beam_width,
503                distance_metric,
504            );
505
506            // Prune neighbors using α-pruning
507            let pruned = prune_neighbors(
508                node_id,
509                &neighbors,
510                vectors,
511                max_degree,
512                alpha,
513                distance_metric,
514            );
515
516            graph[node_id] = pruned;
517
518            // Make graph undirected by adding reverse edges
519            let current_neighbors = graph[node_id].clone();
520            for neighbor in current_neighbors {
521                if !graph[neighbor as usize].contains(&(node_id as u32)) {
522                    graph[neighbor as usize].push(node_id as u32);
523                    
524                    // Prune if degree exceeds max
525                    if graph[neighbor as usize].len() > max_degree {
526                        let neighbors_of_neighbor: Vec<_> = graph[neighbor as usize]
527                            .iter()
528                            .map(|&id| (id, {
529                                let dist = match distance_metric {
530                                    DistanceMetric::Euclidean => {
531                                        euclidean_distance(&vectors[neighbor as usize], &vectors[id as usize])
532                                    }
533                                    DistanceMetric::Cosine => {
534                                        1.0 - cosine_similarity(&vectors[neighbor as usize], &vectors[id as usize])
535                                    }
536                                };
537                                dist
538                            }))
539                            .collect();
540                        
541                        let pruned = prune_neighbors(
542                            neighbor as usize,
543                            &neighbors_of_neighbor,
544                            vectors,
545                            max_degree,
546                            alpha,
547                            distance_metric,
548                        );
549                        graph[neighbor as usize] = pruned;
550                    }
551                }
552            }
553        }
554    }
555
556    graph
557}
558
559/// Performs greedy search on the graph during construction
560fn greedy_search(
561    query: &[f32],
562    vectors: &[Vec<f32>],
563    graph: &[Vec<u32>],
564    start_id: usize,
565    beam_width: usize,
566    distance_metric: DistanceMetric,
567) -> Vec<(u32, f32)> {
568    let mut visited = HashSet::new();
569    let mut candidates = BinaryHeap::new();
570    let mut w = BinaryHeap::new();
571
572    // Start from medoid
573    let start_dist = match distance_metric {
574        DistanceMetric::Euclidean => euclidean_distance(query, &vectors[start_id]),
575        DistanceMetric::Cosine => 1.0 - cosine_similarity(query, &vectors[start_id]),
576    };
577
578    candidates.push(Candidate {
579        dist: start_dist,
580        id: start_id as u32,
581    });
582    w.push(Candidate {
583        dist: start_dist,
584        id: start_id as u32,
585    });
586    visited.insert(start_id as u32);
587
588    while let Some(current) = candidates.pop() {
589        for &neighbor_id in &graph[current.id as usize] {
590            if visited.contains(&neighbor_id) {
591                continue;
592            }
593
594            visited.insert(neighbor_id);
595            let dist = match distance_metric {
596                DistanceMetric::Euclidean => euclidean_distance(query, &vectors[neighbor_id as usize]),
597                DistanceMetric::Cosine => 1.0 - cosine_similarity(query, &vectors[neighbor_id as usize]),
598            };
599
600            w.push(Candidate { dist, id: neighbor_id });
601
602            if w.len() > beam_width {
603                let mut temp = Vec::new();
604                for _ in 0..beam_width {
605                    if let Some(c) = w.pop() {
606                        temp.push(c);
607                    }
608                }
609                w.clear();
610                for c in temp {
611                    w.push(c);
612                }
613            }
614
615            if w.len() < beam_width || dist < w.peek().unwrap().dist {
616                candidates.push(Candidate { dist, id: neighbor_id });
617            }
618        }
619    }
620
621    w.into_vec()
622        .into_iter()
623        .map(|c| (c.id, c.dist))
624        .collect()
625}
626
627/// Prunes neighbors using the α-pruning strategy
628fn prune_neighbors(
629    node_id: usize,
630    candidates: &[(u32, f32)],
631    vectors: &[Vec<f32>],
632    max_degree: usize,
633    alpha: f32,
634    distance_metric: DistanceMetric,
635) -> Vec<u32> {
636    if candidates.is_empty() {
637        return Vec::new();
638    }
639
640    let mut sorted_candidates = candidates.to_vec();
641    sorted_candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
642
643    let mut pruned = Vec::new();
644    
645    for &(candidate_id, candidate_dist) in &sorted_candidates {
646        if candidate_id as usize == node_id {
647            continue;
648        }
649
650        // Check if this candidate is diverse enough from already selected neighbors
651        let mut should_add = true;
652        for &selected_id in &pruned {
653            let dist_to_selected = match distance_metric {
654                DistanceMetric::Euclidean => {
655                    euclidean_distance(&vectors[candidate_id as usize], &vectors[selected_id as usize])
656                }
657                DistanceMetric::Cosine => {
658                    1.0 - cosine_similarity(&vectors[candidate_id as usize], &vectors[selected_id as usize])
659                }
660            };
661
662            if dist_to_selected < alpha * candidate_dist {
663                should_add = false;
664                break;
665            }
666        }
667
668        if should_add {
669            pruned.push(candidate_id);
670            if pruned.len() >= max_degree {
671                break;
672            }
673        }
674    }
675
676    // Fill remaining slots with closest candidates if needed
677    for &(candidate_id, _) in &sorted_candidates {
678        if !pruned.contains(&candidate_id) && candidate_id as usize != node_id {
679            pruned.push(candidate_id);
680            if pruned.len() >= max_degree {
681                break;
682            }
683        }
684    }
685
686    pruned
687}
688
/// Straight-line (L2) distance between `a` and `b`.
///
/// It is the square root of the summed squared component differences. When the slices
/// differ in length, only their common prefix is compared.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let delta = x - y;
            delta * delta
        })
        .sum();
    squared.sqrt()
}
697
/// Cosine similarity of `a` and `b`: their dot product divided by the product of their
/// L2 norms.
///
/// Returns `0.0` when either vector has zero norm, where the similarity would otherwise
/// be undefined. When the slices differ in length, only their common prefix is used.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    let degenerate = sq_norm_a == 0.0 || sq_norm_b == 0.0;
    if degenerate {
        0.0
    } else {
        dot / (sq_norm_a.sqrt() * sq_norm_b.sqrt())
    }
}
713
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;

    /// Index file path under the system temp directory that is removed on drop.
    ///
    /// Removal happens even when a test panics partway through. The process ID in the
    /// file name keeps concurrent runs of the suite from clobbering each other's files.
    /// Declare the guard *before* the index so the index (and its mapping) is dropped
    /// first.
    struct TempIndexFile(PathBuf);

    impl TempIndexFile {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("diskann_rs_{}_{}", std::process::id(), name));
            // Start clean in case a stale file with the same name exists.
            let _ = fs::remove_file(&path);
            TempIndexFile(path)
        }

        fn path(&self) -> &str {
            self.0.to_str().expect("temp dir path should be valid UTF-8")
        }
    }

    impl Drop for TempIndexFile {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_small_index() {
        let test_file = TempIndexFile::new("test_small.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::build_index(
            &vectors,
            3,   // max_degree
            4,   // beam_width
            1.2, // alpha
            DistanceMetric::Euclidean,
            test_file.path(),
        )
        .unwrap();

        let query = vec![0.1, 0.1];
        let neighbors = index.search(&query, 3, 4);

        // The search is approximate, so only require the right count and a close first hit.
        assert_eq!(neighbors.len(), 3);
        let first_vector = index.get_vector(neighbors[0] as usize);
        let dist = euclidean_distance(&query, &first_vector);
        assert!(dist < 1.0, "First neighbor should be close to query");
    }

    #[test]
    fn test_memory_efficiency() {
        let test_file = TempIndexFile::new("test_memory.db");

        let num_vectors = 1000;
        let dim = 128;
        let mut rng = thread_rng();
        let vectors: Vec<Vec<f32>> = (0..num_vectors)
            .map(|_| (0..dim).map(|_| rng.gen()).collect())
            .collect();

        let index = DiskANN::build_index(
            &vectors,
            32,  // max_degree
            64,  // beam_width
            1.2, // alpha
            DistanceMetric::Euclidean,
            test_file.path(),
        )
        .unwrap();

        // Search should touch only a small part of the index and therefore be fast.
        let query: Vec<f32> = (0..dim).map(|_| rng.gen()).collect();
        let k = 10;
        let beam_width = 32;

        let start = std::time::Instant::now();
        let neighbors = index.search(&query, k, beam_width);
        let elapsed = start.elapsed();

        assert_eq!(neighbors.len(), k);
        assert!(elapsed.as_millis() < 100, "Search took too long: {:?}", elapsed);

        // Results must come back ordered by distance.
        let distances: Vec<f32> = neighbors
            .iter()
            .map(|&id| index.distance_to(&query, id as usize))
            .collect();
        let mut sorted = distances.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(distances, sorted);
    }

    #[test]
    fn test_cosine_similarity() {
        let test_file = TempIndexFile::new("test_cosine.db");

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index = DiskANN::build_index(
            &vectors,
            3,
            4,
            1.2,
            DistanceMetric::Cosine,
            test_file.path(),
        )
        .unwrap();

        // The query is parallel to [1, 0, 0], so their cosine similarity is 1.
        let query = vec![2.0, 0.0, 0.0];
        let neighbors = index.search(&query, 2, 4);
        assert_eq!(neighbors.len(), 2);

        let first_vector = index.get_vector(neighbors[0] as usize);
        let similarity = cosine_similarity(&query, &first_vector);
        assert!(similarity > 0.7, "First neighbor should have high cosine similarity");
    }

    #[test]
    fn test_persistence() {
        let test_file = TempIndexFile::new("test_persist.db");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        // Build the index, then drop it so only the file remains.
        {
            let _index = DiskANN::build_index(
                &vectors,
                2,
                4,
                1.2,
                DistanceMetric::Euclidean,
                test_file.path(),
            )
            .unwrap();
        }

        let index = DiskANN::open_index(test_file.path()).unwrap();
        assert_eq!(index.num_vectors, 4);
        assert_eq!(index.dim, 2);

        let query = vec![0.9, 0.9];
        let neighbors = index.search(&query, 2, 4);
        assert_eq!(neighbors[0], 3); // [1,1] is closest to [0.9,0.9]
    }

    #[test]
    fn test_graph_connectivity() {
        let test_file = TempIndexFile::new("test_graph.db");

        // A 5x5 grid of 2-D points.
        let vectors: Vec<Vec<f32>> = (0..5)
            .flat_map(|i| (0..5).map(move |j| vec![i as f32, j as f32]))
            .collect();

        let index = DiskANN::build_index(
            &vectors,
            4,   // max_degree
            8,   // beam_width
            1.5, // alpha
            DistanceMetric::Euclidean,
            test_file.path(),
        )
        .unwrap();

        for (target_idx, query) in vectors.iter().enumerate() {
            // Use a wide beam for better recall.
            let neighbors = index.search(query, 10, 32);

            // If the exact vector is missed, the best hit must still be very close.
            if !neighbors.contains(&(target_idx as u32)) {
                let first_vec = index.get_vector(neighbors[0] as usize);
                let dist = euclidean_distance(query, &first_vec);
                assert!(
                    dist < 2.0,
                    "Vector {} not found but nearest neighbor at distance {} is too far",
                    target_idx,
                    dist
                );
            }

            for &neighbor_id in neighbors.iter().take(5) {
                let neighbor_vec = index.get_vector(neighbor_id as usize);
                let dist = euclidean_distance(query, &neighbor_vec);
                assert!(dist < 5.0, "Neighbor should be reasonably close");
            }
        }
    }
}