diskann_rs/
lib.rs

1//! # DiskAnn (generic over `anndists::Distance<f32>`)
2//!
3//! A minimal, on-disk, DiskANN-like library that:
4//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
5//! - Writes vectors + fixed-degree adjacency to a single file
6//! - Memory-maps the file for low-overhead reads
7//! - Is **generic over any `Distance<f32>`** from `anndists` (L2, Cosine, Hamming, Dot, …)
8//! - Supports **incremental updates** (add/delete vectors without full rebuild)
9//!
10//! ## Example
11//! ```no_run
12//! use anndists::dist::{DistL2, DistCosine};
13//! use diskann_rs::{DiskANN, DiskAnnParams};
14//!
15//! // Build a new index from vectors, using L2 and default params
16//! let vectors = vec![vec![0.0; 128]; 1000];
17//! let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2{}, "index.db").unwrap();
18//!
19//! // Or with custom params
20//! let index2 = DiskANN::<DistCosine>::build_index_with_params(
21//!     &vectors,
22//!     DistCosine{},
23//!     "index_cos.db",
24//!     DiskAnnParams { max_degree: 48, ..Default::default() },
25//! ).unwrap();
26//!
27//! // Search the index
28//! let query = vec![0.0; 128];
29//! let neighbors = index.search(&query, 10, 64);
30//!
31//! // Open later (provide the same distance type)
32//! let _reopened = DiskANN::<DistL2>::open_index_default_metric("index.db").unwrap();
33//! ```
34//!
35//! ## Incremental Updates
36//! ```no_run
37//! use anndists::dist::DistL2;
38//! use diskann_rs::IncrementalDiskANN;
39//!
40//! // Build initial index
41//! let vectors = vec![vec![0.0; 128]; 1000];
42//! let mut index = IncrementalDiskANN::<DistL2>::build_default(&vectors, "index.db").unwrap();
43//!
44//! // Add vectors without rebuilding
45//! let new_ids = index.add_vectors(&[vec![1.0; 128]]).unwrap();
46//!
47//! // Delete vectors (lazy tombstoning)
48//! index.delete_vectors(&[0, 1, 2]).unwrap();
49//!
50//! // Compact when needed
51//! if index.should_compact() {
52//!     index.compact("index_v2.db").unwrap();
53//! }
54//! ```
55//!
//! ## File Layout
//! ```text
//! [ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
//! [ vectors (num * dim * f32) ][ adjacency (num * max_degree * u32) ]
//! ```
//!
//! `vectors_offset` is fixed at 1 MiB, so the metadata header is followed by padding up to that offset.
61
62mod incremental;
63mod filtered;
64pub mod simd;
65pub mod pq;
66
67pub use incremental::{
68    IncrementalDiskANN, IncrementalConfig, IncrementalStats,
69    is_delta_id, delta_local_idx,
70};
71
72pub use filtered::{FilteredDiskANN, Filter};
73
74pub use simd::{SimdL2, SimdDot, SimdCosine, simd_info};
75
76pub use pq::{ProductQuantizer, PQConfig, PQStats};
77
78use anndists::prelude::Distance;
79use bytemuck;
80use memmap2::Mmap;
81use rand::prelude::*;
82use rayon::prelude::*;
83use serde::{Deserialize, Serialize};
84use std::cmp::{Ordering, Reverse};
85use std::collections::{BinaryHeap, HashSet};
86use std::fs::OpenOptions;
87use std::io::{Read, Seek, SeekFrom, Write};
88use thiserror::Error;
89
/// Sentinel written into unused adjacency slots.
///
/// `u32::MAX` is used instead of 0 so that padding can never be mistaken for node 0.
const PAD_U32: u32 = u32::MAX;

/// Default maximum number of edges per node for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width used while constructing the graph.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default α used by the pruning step.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;

/// Bundle of build-time knobs, handy when only a few defaults need overriding
/// (`DiskAnnParams { max_degree: 48, ..Default::default() }`).
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum number of edges kept per node.
    pub max_degree: usize,
    /// Beam width used during graph construction.
    pub build_beam_width: usize,
    /// α parameter of the pruning step.
    pub alpha: f32,
}

impl Default for DiskAnnParams {
    /// Returns the crate-wide defaults (`DISKANN_DEFAULT_*`).
    fn default() -> Self {
        DiskAnnParams {
            alpha: DISKANN_DEFAULT_ALPHA,
            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
        }
    }
}
114
/// Error type returned by the fallible DiskAnnRS operations in this file
/// (building and opening an index).
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// An I/O failure while creating, reading, writing or memory-mapping the index file.
    /// Converted automatically from `std::io::Error` by `?`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The metadata header could not be serialized or deserialized with bincode.
    /// Converted automatically from `bincode::Error` by `?`.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// An index-level problem described by the message, e.g. an empty input set
    /// or vectors of mismatched dimension.
    #[error("Index error: {0}")]
    IndexError(String),
}
130
/// Internal metadata structure stored in the index file.
///
/// It is bincode-encoded and written at the start of the file, right after a
/// little-endian `u64` holding its encoded length.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    // Number of `f32` components per stored vector.
    dim: usize,
    // Number of vectors (and adjacency rows) stored in the file.
    num_vectors: usize,
    // Number of `u32` slots in every adjacency row.
    max_degree: usize,
    // Node id used as the entry point of searches.
    medoid_id: u32,
    // Byte offset of the first vector in the file.
    vectors_offset: u64,
    // Byte offset of the first adjacency row in the file.
    adjacency_offset: u64,
    // `std::any::type_name` of the distance type the index was built with (informational).
    distance_name: String,
}
142
/// Candidate for search/frontier queues: a node id paired with its distance to the query.
///
/// The ordering is a genuine total order, as `BinaryHeap` requires:
/// - distances are compared with `f32::total_cmp`, so NaN has a defined position
///   (a positive NaN sorts after every finite value and +∞) instead of comparing
///   "equal" to everything;
/// - ties on distance are broken by `id`, so `cmp` returns `Equal` only for
///   candidates that are also `==`.
///
/// Equality is derived from the same comparison, which keeps `Eq`, `PartialOrd`
/// and `Ord` mutually consistent and makes `Eq` reflexive even for NaN distances.
#[derive(Clone, Copy)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Natural order by distance (smaller is "less"), then by id.
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}
166
167/// Main struct representing a DiskANN index (generic over distance)
168pub struct DiskANN<D>
169where
170    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
171{
172    /// Dimensionality of vectors in the index
173    pub dim: usize,
174    /// Number of vectors in the index
175    pub num_vectors: usize,
176    /// Maximum number of edges per node
177    pub max_degree: usize,
178    /// Informational: type name of the distance (from metadata)
179    pub distance_name: String,
180
181    /// ID of the medoid (used as entry point)
182    pub(crate) medoid_id: u32,
183    // Offsets
184    pub(crate) vectors_offset: u64,
185    pub(crate) adjacency_offset: u64,
186
187    /// Memory-mapped file
188    pub(crate) mmap: Mmap,
189
190    /// The distance strategy
191    pub(crate) dist: D,
192}
193
194// constructors
195
196impl<D> DiskANN<D>
197where
198    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
199{
200    /// Build with default parameters: (M=32, L=256, alpha=1.2).
201    pub fn build_index_default(
202        vectors: &[Vec<f32>],
203        dist: D,
204        file_path: &str,
205    ) -> Result<Self, DiskAnnError> {
206        Self::build_index(
207            vectors,
208            DISKANN_DEFAULT_MAX_DEGREE,
209            DISKANN_DEFAULT_BUILD_BEAM,
210            DISKANN_DEFAULT_ALPHA,
211            dist,
212            file_path,
213        )
214    }
215
216    /// Build with a `DiskAnnParams` bundle.
217    pub fn build_index_with_params(
218        vectors: &[Vec<f32>],
219        dist: D,
220        file_path: &str,
221        p: DiskAnnParams,
222    ) -> Result<Self, DiskAnnError> {
223        Self::build_index(
224            vectors,
225            p.max_degree,
226            p.build_beam_width,
227            p.alpha,
228            dist,
229            file_path,
230        )
231    }
232}
233
234/// Extra sugar when your distance type implements `Default` (most unit-struct metrics do).
235impl<D> DiskANN<D>
236where
237    D: Distance<f32> + Default + Send + Sync + Copy + Clone + 'static,
238{
239    /// Build with default params **and** `D::default()` metric.
240    pub fn build_index_default_metric(
241        vectors: &[Vec<f32>],
242        file_path: &str,
243    ) -> Result<Self, DiskAnnError> {
244        Self::build_index_default(vectors, D::default(), file_path)
245    }
246
247    /// Open an index using `D::default()` as the distance (matches what you built with).
248    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
249        Self::open_index_with(path, D::default())
250    }
251}
252
253impl<D> DiskANN<D>
254where
255    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
256{
257    /// Builds a new index from provided vectors
258    ///
259    /// # Arguments
260    /// * `vectors` - The vectors to index (slice of `Vec<f32>`)
261    /// * `max_degree` - Maximum edges per node (M ~ 24-64)
262    /// * `build_beam_width` - Construction L (e.g., 128-400)
263    /// * `alpha` - Pruning parameter (1.2–2.0)
264    /// * `dist` - Any `anndists::Distance<f32>` (e.g., `DistL2`)
265    /// * `file_path` - Path of index file
266    pub fn build_index(
267        vectors: &[Vec<f32>],
268        max_degree: usize,
269        build_beam_width: usize,
270        alpha: f32,
271        dist: D,
272        file_path: &str,
273    ) -> Result<Self, DiskAnnError> {
274        if vectors.is_empty() {
275            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
276        }
277
278        let num_vectors = vectors.len();
279        let dim = vectors[0].len();
280        for (i, v) in vectors.iter().enumerate() {
281            if v.len() != dim {
282                return Err(DiskAnnError::IndexError(format!(
283                    "Vector {} has dimension {} but expected {}",
284                    i,
285                    v.len(),
286                    dim
287                )));
288            }
289        }
290
291        let mut file = OpenOptions::new()
292            .create(true)
293            .write(true)
294            .read(true)
295            .truncate(true)
296            .open(file_path)?;
297
298        // Reserve space for metadata (we'll write it after data)
299        let vectors_offset = 1024 * 1024;
300        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
301
302        // Write vectors contiguous (sequential I/O is fastest)
303        file.seek(SeekFrom::Start(vectors_offset))?;
304        for vector in vectors {
305            let bytes = bytemuck::cast_slice(vector);
306            file.write_all(bytes)?;
307        }
308
309        // Compute medoid using provided distance (parallelized distance eval)
310        let medoid_id = calculate_medoid(vectors, dist);
311
312        // Build Vamana-like graph (stronger refinement, parallel inner loops)
313        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
314        let graph = build_vamana_graph(
315            vectors,
316            max_degree,
317            build_beam_width,
318            alpha,
319            dist,
320            medoid_id as u32,
321        );
322
323        // Write adjacency lists (fixed max_degree, pad with PAD_U32)
324        file.seek(SeekFrom::Start(adjacency_offset))?;
325        for neighbors in &graph {
326            let mut padded = neighbors.clone();
327            padded.resize(max_degree, PAD_U32);
328            let bytes = bytemuck::cast_slice(&padded);
329            file.write_all(bytes)?;
330        }
331
332        // Write metadata
333        let metadata = Metadata {
334            dim,
335            num_vectors,
336            max_degree,
337            medoid_id: medoid_id as u32,
338            vectors_offset: vectors_offset as u64,
339            adjacency_offset,
340            distance_name: std::any::type_name::<D>().to_string(),
341        };
342
343        let md_bytes = bincode::serialize(&metadata)?;
344        file.seek(SeekFrom::Start(0))?;
345        let md_len = md_bytes.len() as u64;
346        file.write_all(&md_len.to_le_bytes())?;
347        file.write_all(&md_bytes)?;
348        file.sync_all()?;
349
350        // Memory map the file
351        let mmap = unsafe { memmap2::Mmap::map(&file)? };
352
353        Ok(Self {
354            dim,
355            num_vectors,
356            max_degree,
357            distance_name: metadata.distance_name,
358            medoid_id: metadata.medoid_id,
359            vectors_offset: metadata.vectors_offset,
360            adjacency_offset: metadata.adjacency_offset,
361            mmap,
362            dist,
363        })
364    }
365
366    /// Opens an existing index file, supplying the distance strategy explicitly.
367    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
368        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
369
370        // Read metadata length
371        let mut buf8 = [0u8; 8];
372        file.seek(SeekFrom::Start(0))?;
373        file.read_exact(&mut buf8)?;
374        let md_len = u64::from_le_bytes(buf8);
375
376        // Read metadata
377        let mut md_bytes = vec![0u8; md_len as usize];
378        file.read_exact(&mut md_bytes)?;
379        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
380
381        let mmap = unsafe { memmap2::Mmap::map(&file)? };
382
383        // Optional sanity/logging: warn if type differs from recorded name
384        let expected = std::any::type_name::<D>();
385        if metadata.distance_name != expected {
386            eprintln!(
387                "Warning: index recorded distance `{}` but you opened with `{}`",
388                metadata.distance_name, expected
389            );
390        }
391
392        Ok(Self {
393            dim: metadata.dim,
394            num_vectors: metadata.num_vectors,
395            max_degree: metadata.max_degree,
396            distance_name: metadata.distance_name,
397            medoid_id: metadata.medoid_id,
398            vectors_offset: metadata.vectors_offset,
399            adjacency_offset: metadata.adjacency_offset,
400            mmap,
401            dist,
402        })
403    }
404
405    /// Searches the index for nearest neighbors using a best-first beam search.
406    /// Termination rule: continue while the best frontier can still improve the worst in working set.
407    /// Like `search` but also returns the distance for each neighbor.
408    /// Distances are exactly the ones computed during the beam search.
409    pub fn search_with_dists(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
410        assert_eq!(
411            query.len(),
412            self.dim,
413            "Query dim {} != index dim {}",
414            query.len(),
415            self.dim
416        );
417
418        #[derive(Clone, Copy)]
419        struct Candidate {
420            dist: f32,
421            id: u32,
422        }
423        impl PartialEq for Candidate {
424            fn eq(&self, o: &Self) -> bool {
425                self.dist == o.dist && self.id == o.id
426            }
427        }
428        impl Eq for Candidate {}
429        impl PartialOrd for Candidate {
430            fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
431                self.dist.partial_cmp(&o.dist)
432            }
433        }
434        impl Ord for Candidate {
435            fn cmp(&self, o: &Self) -> Ordering {
436                self.partial_cmp(o).unwrap_or(Ordering::Equal)
437            }
438        }
439
440        let mut visited = HashSet::new();
441        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // best-first by dist
442        let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // working set, max-heap by dist
443
444        // seed from medoid
445        let start_dist = self.distance_to(query, self.medoid_id as usize);
446        let start = Candidate {
447            dist: start_dist,
448            id: self.medoid_id,
449        };
450        frontier.push(Reverse(start));
451        w.push(start);
452        visited.insert(self.medoid_id);
453
454        // expand while best frontier can still improve worst in working set
455        while let Some(Reverse(best)) = frontier.peek().copied() {
456            if w.len() >= beam_width {
457                if let Some(worst) = w.peek() {
458                    if best.dist >= worst.dist {
459                        break;
460                    }
461                }
462            }
463            let Reverse(current) = frontier.pop().unwrap();
464
465            for &nb in self.get_neighbors(current.id) {
466                if nb == PAD_U32 {
467                    continue;
468                }
469                if !visited.insert(nb) {
470                    continue;
471                }
472
473                let d = self.distance_to(query, nb as usize);
474                let cand = Candidate { dist: d, id: nb };
475
476                if w.len() < beam_width {
477                    w.push(cand);
478                    frontier.push(Reverse(cand));
479                } else if d < w.peek().unwrap().dist {
480                    w.pop();
481                    w.push(cand);
482                    frontier.push(Reverse(cand));
483                }
484            }
485        }
486
487        // top-k by distance, keep distances
488        let mut results: Vec<_> = w.into_vec();
489        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
490        results.truncate(k);
491        results.into_iter().map(|c| (c.id, c.dist)).collect()
492    }
493    /// search but only return neighbor ids
494    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
495        self.search_with_dists(query, k, beam_width)
496            .into_iter()
497            .map(|(id, _dist)| id)
498            .collect()
499    }
500
501    /// Gets the neighbors of a node from the (fixed-degree) adjacency region
502    fn get_neighbors(&self, node_id: u32) -> &[u32] {
503        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
504        let start = offset as usize;
505        let end = start + (self.max_degree * 4);
506        let bytes = &self.mmap[start..end];
507        bytemuck::cast_slice(bytes)
508    }
509
510    /// Computes distance between `query` and vector `idx`
511    fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
512        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
513        let start = offset as usize;
514        let end = start + (self.dim * 4);
515        let bytes = &self.mmap[start..end];
516        let vector: &[f32] = bytemuck::cast_slice(bytes);
517        self.dist.eval(query, vector)
518    }
519
520    /// Gets a vector from the index (useful for tests)
521    pub fn get_vector(&self, idx: usize) -> Vec<f32> {
522        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
523        let start = offset as usize;
524        let end = start + (self.dim * 4);
525        let bytes = &self.mmap[start..end];
526        let vector: &[f32] = bytemuck::cast_slice(bytes);
527        vector.to_vec()
528    }
529}
530
531/// Calculates the medoid (vector closest to the centroid) using distance `D`
532/// Parallelizes the per-vector distance evaluations.
533fn calculate_medoid<D: Distance<f32> + Copy + Sync>(vectors: &[Vec<f32>], dist: D) -> usize {
534    let dim = vectors[0].len();
535    let mut centroid = vec![0.0f32; dim];
536
537    for v in vectors {
538        for (i, &val) in v.iter().enumerate() {
539            centroid[i] += val;
540        }
541    }
542    for val in &mut centroid {
543        *val /= vectors.len() as f32;
544    }
545
546    let (best_idx, _best_dist) = vectors
547        .par_iter()
548        .enumerate()
549        .map(|(idx, v)| (idx, dist.eval(&centroid, v)))
550        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
551
552    best_idx
553}
554
555/// Builds a strengthened Vamana-like graph using multi-pass refinement.
556/// - Multi-seed candidate gathering (medoid + random seeds)
557/// - Union with current adjacency before α-prune
558/// - 2 refinement passes with symmetrization after each pass
559fn build_vamana_graph<D: Distance<f32> + Copy + Sync>(
560    vectors: &[Vec<f32>],
561    max_degree: usize,
562    build_beam_width: usize,
563    alpha: f32,
564    dist: D,
565    medoid_id: u32,
566) -> Vec<Vec<u32>> {
567    let n = vectors.len();
568    let mut graph = vec![Vec::<u32>::new(); n];
569
570    // Light random bootstrap to avoid disconnected starts
571    {
572        let mut rng = thread_rng();
573        for i in 0..n {
574            let mut s = HashSet::new();
575            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
576            while s.len() < target {
577                let nb = rng.gen_range(0..n);
578                if nb != i {
579                    s.insert(nb as u32);
580                }
581            }
582            graph[i] = s.into_iter().collect();
583        }
584    }
585
586    // Refinement passes
587    const PASSES: usize = 2;
588    const EXTRA_SEEDS: usize = 2;
589
590    let mut rng = thread_rng();
591    for _pass in 0..PASSES {
592        // Shuffle visit order each pass
593        let mut order: Vec<usize> = (0..n).collect();
594        order.shuffle(&mut rng);
595
596        // Snapshot read of graph for parallel candidate building
597        let snapshot = &graph;
598
599        // Build new neighbor proposals in parallel
600        let new_graph: Vec<Vec<u32>> = order
601            .par_iter()
602            .map(|&u| {
603                let mut candidates: Vec<(u32, f32)> =
604                    Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));
605
606                // Include current adjacency with distances
607                for &nb in &snapshot[u] {
608                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
609                    candidates.push((nb, d));
610                }
611
612                // Seeds: always medoid + some random starts
613                let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
614                seeds.push(medoid_id as usize);
615                let mut trng = thread_rng();
616                for _ in 0..EXTRA_SEEDS {
617                    seeds.push(trng.gen_range(0..n));
618                }
619
620                // Gather candidates from greedy searches
621                for start in seeds {
622                    let mut part = greedy_search(
623                        &vectors[u],
624                        vectors,
625                        snapshot,
626                        start,
627                        build_beam_width,
628                        dist,
629                    );
630                    candidates.append(&mut part);
631                }
632
633                // Deduplicate by id keeping best distance
634                candidates.sort_by(|a, b| a.0.cmp(&b.0));
635                candidates.dedup_by(|a, b| {
636                    if a.0 == b.0 {
637                        if a.1 < b.1 {
638                            *b = *a;
639                        }
640                        true
641                    } else {
642                        false
643                    }
644                });
645
646                // α-prune around u
647                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
648            })
649            .collect();
650
651        // Symmetrize: union incoming + outgoing, then α-prune again (parallel)
652        // Build inverse map: node-id -> position in `order`
653        let mut pos_of = vec![0usize; n];
654        for (pos, &u) in order.iter().enumerate() {
655            pos_of[u] = pos;
656        }
657
658        // Build incoming as CSR
659        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);
660
661        // Union + prune in parallel
662        graph = (0..n)
663            .into_par_iter()
664            .map(|u| {
665                let ng = &new_graph[pos_of[u]]; // outgoing from this pass
666                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]]; // incoming to u
667
668                // pool = union(outgoing ∪ incoming) with tiny, cache-friendly ops
669                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
670                pool_ids.extend_from_slice(ng);
671                pool_ids.extend_from_slice(inc);
672                pool_ids.sort_unstable();
673                pool_ids.dedup();
674
675                // compute distances once, then α-prune
676                let pool: Vec<(u32, f32)> = pool_ids
677                    .into_iter()
678                    .filter(|&id| id as usize != u)
679                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
680                    .collect();
681
682                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
683            })
684            .collect();
685    }
686
687    // Final cleanup (ensure <= max_degree everywhere)
688    graph
689        .into_par_iter()
690        .enumerate()
691        .map(|(u, neigh)| {
692            if neigh.len() <= max_degree {
693                return neigh;
694            }
695            let pool: Vec<(u32, f32)> = neigh
696                .iter()
697                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
698                .collect();
699            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
700        })
701        .collect()
702}
703
704/// Greedy search used during construction (read-only on `graph`)
705/// Same termination rule as query-time search.
706fn greedy_search<D: Distance<f32> + Copy>(
707    query: &[f32],
708    vectors: &[Vec<f32>],
709    graph: &[Vec<u32>],
710    start_id: usize,
711    beam_width: usize,
712    dist: D,
713) -> Vec<(u32, f32)> {
714    let mut visited = HashSet::new();
715    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // min-heap by dist
716    let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // max-heap by dist
717
718    let start_dist = dist.eval(query, &vectors[start_id]);
719    let start = Candidate {
720        dist: start_dist,
721        id: start_id as u32,
722    };
723    frontier.push(Reverse(start));
724    w.push(start);
725    visited.insert(start_id as u32);
726
727    while let Some(Reverse(best)) = frontier.peek().copied() {
728        if w.len() >= beam_width {
729            if let Some(worst) = w.peek() {
730                if best.dist >= worst.dist {
731                    break;
732                }
733            }
734        }
735        let Reverse(cur) = frontier.pop().unwrap();
736
737        for &nb in &graph[cur.id as usize] {
738            if !visited.insert(nb) {
739                continue;
740            }
741            let d = dist.eval(query, &vectors[nb as usize]);
742            let cand = Candidate { dist: d, id: nb };
743
744            if w.len() < beam_width {
745                w.push(cand);
746                frontier.push(Reverse(cand));
747            } else if d < w.peek().unwrap().dist {
748                w.pop();
749                w.push(cand);
750                frontier.push(Reverse(cand));
751            }
752        }
753    }
754
755    let mut v = w.into_vec();
756    v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
757    v.into_iter().map(|c| (c.id, c.dist)).collect()
758}
759
760/// α-pruning from DiskANN/Vamana
761fn prune_neighbors<D: Distance<f32> + Copy>(
762    node_id: usize,
763    candidates: &[(u32, f32)],
764    vectors: &[Vec<f32>],
765    max_degree: usize,
766    alpha: f32,
767    dist: D,
768) -> Vec<u32> {
769    if candidates.is_empty() {
770        return Vec::new();
771    }
772
773    let mut sorted = candidates.to_vec();
774    sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
775
776    let mut pruned = Vec::<u32>::new();
777
778    for &(cand_id, cand_dist) in &sorted {
779        if cand_id as usize == node_id {
780            continue;
781        }
782        let mut ok = true;
783        for &sel in &pruned {
784            let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
785            if d < alpha * cand_dist {
786                ok = false;
787                break;
788            }
789        }
790        if ok {
791            pruned.push(cand_id);
792            if pruned.len() >= max_degree {
793                break;
794            }
795        }
796    }
797
798    // fill with closest if still not full
799    for &(cand_id, _) in &sorted {
800        if cand_id as usize == node_id {
801            continue;
802        }
803        if !pruned.contains(&cand_id) {
804            pruned.push(cand_id);
805            if pruned.len() >= max_degree {
806                break;
807            }
808        }
809    }
810
811    pruned
812}
813
/// Builds the reverse (incoming) edge lists of a graph in CSR form.
///
/// `new_graph[pos]` holds the outgoing edges of node `order[pos]`. The result is
/// `(incoming_flat, off)`, where the sources of all edges pointing at node `v`
/// are `incoming_flat[off[v]..off[v + 1]]` and `off` has `n + 1` entries.
/// Sources appear in the order they are encountered while walking `order`.
///
/// Panics if `new_graph` is shorter than `order` or if an edge targets a node id `>= n`.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    let rows = &new_graph[..order.len()];

    // 1) count in-degrees, stored one slot to the right so step 2 can sum in place
    let mut off = vec![0usize; n + 1];
    for &target in rows.iter().flatten() {
        off[target as usize + 1] += 1;
    }

    // 2) prefix sums → offsets
    for i in 0..n {
        off[i + 1] += off[i];
    }

    // 3) scatter every source into the next free slot of its target
    let mut cursor = off.clone();
    let mut incoming_flat = vec![0u32; off[n]];
    for (&source, targets) in order.iter().zip(rows) {
        for &target in targets {
            let slot = &mut cursor[target as usize];
            incoming_flat[*slot] = source as u32;
            *slot += 1;
        }
    }

    (incoming_flat, off)
}
839
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;
    use std::path::PathBuf;

    /// Scratch index file in the OS temp directory.
    ///
    /// The file is removed when the guard is dropped, so it is cleaned up even
    /// when an assertion fails. Declare the guard before the index: locals drop
    /// in reverse order, so the memory map is released before the file is deleted.
    struct ScratchIndex {
        path: PathBuf,
    }

    impl ScratchIndex {
        /// Creates a guard for a per-test, per-process file name and removes any stale file.
        fn new(name: &str) -> Self {
            let file_name = format!("diskann_rs_{}_{}.db", std::process::id(), name);
            let path = std::env::temp_dir().join(file_name);
            let _ = fs::remove_file(&path);
            Self { path }
        }

        /// Path as `&str`, the form the index API expects.
        fn path(&self) -> &str {
            self.path
                .to_str()
                .expect("temp dir path should be valid UTF-8")
        }
    }

    impl Drop for ScratchIndex {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.path);
        }
    }

    /// Plain Euclidean distance, used to check results independently of the index.
    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let scratch = ScratchIndex::new("test_small_l2");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index =
            DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, scratch.path()).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        // Verify the first neighbor is quite close in L2
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let scratch = ScratchIndex::new("test_cosine");

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<DistCosine>::build_index_default(&vectors, DistCosine {}, scratch.path())
                .unwrap();

        let q = vec![2.0, 0.0, 0.0]; // parallel to [1,0,0]
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // Top neighbor should have high cosine similarity (close direction)
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let scratch = ScratchIndex::new("test_persist");

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            // Build and drop immediately: only the file on disk should matter below.
            let _idx =
                DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, scratch.path())
                    .unwrap();
        }

        let idx2 = DiskANN::<DistL2>::open_index_default_metric(scratch.path()).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        // [1,1] should be best
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let scratch = ScratchIndex::new("test_grid");

        // 5x5 grid
        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            scratch.path(),
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            // If the exact point was missed, the best hit must still be nearby.
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let scratch = ScratchIndex::new("test_medium");

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            scratch.path(),
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Ensure distances are nondecreasing
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);
    }
}