Skip to main content

diskann_rs/
lib.rs

1//! # DiskAnn (generic over `anndists::Distance<f32>`)
2//!
3//! A minimal, on-disk, DiskANN-like library that:
4//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
5//! - Writes vectors + fixed-degree adjacency to a single file
6//! - Memory-maps the file for low-overhead reads
7//! - Is **generic over any `Distance<f32>`** from `anndists` (L2, Cosine, Hamming, Dot, …)
8//! - Supports **incremental updates** (add/delete vectors without full rebuild)
9//!
10//! ## Example
11//! ```no_run
12//! use anndists::dist::{DistL2, DistCosine};
13//! use diskann_rs::{DiskANN, DiskAnnParams};
14//!
15//! // Build a new index from vectors, using L2 and default params
16//! let vectors = vec![vec![0.0; 128]; 1000];
17//! let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2{}, "index.db").unwrap();
18//!
19//! // Or with custom params
20//! let index2 = DiskANN::<DistCosine>::build_index_with_params(
21//!     &vectors,
22//!     DistCosine{},
23//!     "index_cos.db",
24//!     DiskAnnParams { max_degree: 48, ..Default::default() },
25//! ).unwrap();
26//!
27//! // Search the index
28//! let query = vec![0.0; 128];
29//! let neighbors = index.search(&query, 10, 64);
30//!
31//! // Open later (provide the same distance type)
32//! let _reopened = DiskANN::<DistL2>::open_index_default_metric("index.db").unwrap();
33//! ```
34//!
35//! ## Incremental Updates
36//! ```no_run
37//! use anndists::dist::DistL2;
38//! use diskann_rs::IncrementalDiskANN;
39//!
40//! // Build initial index
41//! let vectors = vec![vec![0.0; 128]; 1000];
42//! let mut index = IncrementalDiskANN::<DistL2>::build_default(&vectors, "index.db").unwrap();
43//!
44//! // Add vectors without rebuilding
45//! let new_ids = index.add_vectors(&[vec![1.0; 128]]).unwrap();
46//!
47//! // Delete vectors (lazy tombstoning)
48//! index.delete_vectors(&[0, 1, 2]).unwrap();
49//!
50//! // Compact when needed
51//! if index.should_compact() {
52//!     index.compact("index_v2.db").unwrap();
53//! }
54//! ```
55//!
//! ## File Layout
//! ```text
//! [ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
//! [ vectors (num * dim * f32) ][ adjacency (num * max_degree * u32) ]
//! ```
//!
//! `vectors_offset` is fixed at 1 MiB by the builder, which leaves a gap after the
//! metadata header. Unused adjacency slots are padded with `u32::MAX`.
61
62mod incremental;
63mod filtered;
64pub mod simd;
65pub mod pq;
66pub mod storage;
67pub mod sq;
68pub mod formats;
69
70pub use incremental::{
71    IncrementalDiskANN, IncrementalConfig, IncrementalStats,
72    is_delta_id, delta_local_idx,
73};
74
75pub use filtered::{FilteredDiskANN, Filter};
76
77pub use simd::{SimdL2, SimdDot, SimdCosine, simd_info};
78
79pub use pq::{ProductQuantizer, PQConfig, PQStats};
80
81pub use storage::Storage;
82
83pub use sq::{VectorQuantizer, F16Quantizer, Int8Quantizer};
84
85use anndists::prelude::Distance;
86use bytemuck;
87use rand::prelude::*;
88use rayon::prelude::*;
89use serde::{Deserialize, Serialize};
90use std::cmp::{Ordering, Reverse};
91use std::collections::{BinaryHeap, HashSet};
92use std::fs::OpenOptions;
93use std::io::{Read, Seek, SeekFrom, Write};
94use std::sync::Arc;
95use thiserror::Error;
96
/// Padding sentinel for unused adjacency slots.
///
/// `u32::MAX` is used (rather than `0`) so that padding never collides with node 0.
/// Readers skip any slot equal to this value when walking a neighbor row.
const PAD_U32: u32 = u32::MAX;

/// Default maximum number of edges per node (`max_degree`) for index builds.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width used while constructing the graph (`build_beam_width`).
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default α used by the pruning step during construction (`alpha`).
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
104
105/// Optional bag of knobs if you want to override just a few.
106#[derive(Clone, Copy, Debug)]
107pub struct DiskAnnParams {
108    pub max_degree: usize,
109    pub build_beam_width: usize,
110    pub alpha: f32,
111}
112impl Default for DiskAnnParams {
113    fn default() -> Self {
114        Self {
115            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
116            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
117            alpha: DISKANN_DEFAULT_ALPHA,
118        }
119    }
120}
121
/// Error type returned by the fallible DiskANN operations in this crate.
///
/// The `#[from]` attributes let `?` convert `std::io::Error` and `bincode::Error`
/// into this type automatically.
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// An I/O failure while creating, reading, writing, or memory-mapping an index file.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The metadata header could not be serialized or deserialized.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),

    /// An index-level problem described by the message
    /// (e.g. no input vectors, mismatched dimensions, or a buffer too small to hold metadata).
    #[error("Index error: {0}")]
    IndexError(String),
}
137
/// Internal metadata header stored at the start of the index file.
///
/// It is written with `bincode::serialize` right after an 8-byte little-endian length
/// prefix, and read back with `bincode::deserialize`.
///
/// NOTE(review): bincode presumably encodes fields positionally, so reordering or
/// retyping these fields would break previously written index files. Confirm before
/// changing the layout.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    /// Number of `f32` components per vector.
    dim: usize,
    /// Number of vectors (graph nodes) stored in the file.
    num_vectors: usize,
    /// Fixed number of `u32` adjacency slots stored per node.
    max_degree: usize,
    /// Node id used as the search entry point.
    medoid_id: u32,
    /// Byte offset of the first vector in the file.
    vectors_offset: u64,
    /// Byte offset of the first adjacency row in the file.
    adjacency_offset: u64,
    /// `std::any::type_name` of the distance the index was built with (informational).
    distance_name: String,
}
149
/// Candidate for search/frontier queues: a node id paired with its distance to the query.
///
/// The ordering is a genuine *total* order, as `BinaryHeap` and sorting require:
/// - a smaller distance sorts first;
/// - a NaN distance sorts after every real distance, so it is treated as "worst";
/// - equal distances are tie-broken by `id`, which keeps `Ord` consistent with `Eq`.
#[derive(Clone, Copy)]
struct Candidate {
    dist: f32,
    id: u32,
}
impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `a == b` holds exactly when `a.cmp(b) == Equal`.
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        let by_dist = match self.dist.partial_cmp(&other.dist) {
            Some(ord) => ord,
            // `partial_cmp` is `None` only when at least one side is NaN.
            // `false < true`, so the non-NaN side sorts first and two NaNs tie.
            None => self.dist.is_nan().cmp(&other.dist.is_nan()),
        };
        by_dist.then_with(|| self.id.cmp(&other.id))
    }
}
173
/// Main struct representing a DiskANN index, generic over the distance `D`.
///
/// The vectors and the fixed-degree adjacency rows live in `storage`; the remaining
/// fields are copied from the metadata header and describe how to address that storage.
pub struct DiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
    /// Dimensionality of vectors in the index
    pub dim: usize,
    /// Number of vectors in the index
    pub num_vectors: usize,
    /// Maximum number of edges per node (also the number of adjacency slots per row)
    pub max_degree: usize,
    /// Informational: type name of the distance (from metadata)
    pub distance_name: String,

    /// ID of the medoid (used as the search entry point)
    pub(crate) medoid_id: u32,
    /// Byte offset of the first vector inside `storage`
    pub(crate) vectors_offset: u64,
    /// Byte offset of the first adjacency row inside `storage`
    pub(crate) adjacency_offset: u64,

    /// Backing storage (mmap, owned bytes, or shared bytes)
    pub(crate) storage: Storage,

    /// The distance strategy used to compare a query against stored vectors
    pub(crate) dist: D,
}
200
201// constructors
202
203impl<D> DiskANN<D>
204where
205    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
206{
207    /// Build with default parameters: (M=32, L=256, alpha=1.2).
208    pub fn build_index_default(
209        vectors: &[Vec<f32>],
210        dist: D,
211        file_path: &str,
212    ) -> Result<Self, DiskAnnError> {
213        Self::build_index(
214            vectors,
215            DISKANN_DEFAULT_MAX_DEGREE,
216            DISKANN_DEFAULT_BUILD_BEAM,
217            DISKANN_DEFAULT_ALPHA,
218            dist,
219            file_path,
220        )
221    }
222
223    /// Build with a `DiskAnnParams` bundle.
224    pub fn build_index_with_params(
225        vectors: &[Vec<f32>],
226        dist: D,
227        file_path: &str,
228        p: DiskAnnParams,
229    ) -> Result<Self, DiskAnnError> {
230        Self::build_index(
231            vectors,
232            p.max_degree,
233            p.build_beam_width,
234            p.alpha,
235            dist,
236            file_path,
237        )
238    }
239}
240
241/// Extra sugar when your distance type implements `Default` (most unit-struct metrics do).
242impl<D> DiskANN<D>
243where
244    D: Distance<f32> + Default + Send + Sync + Copy + Clone + 'static,
245{
246    /// Build with default params **and** `D::default()` metric.
247    pub fn build_index_default_metric(
248        vectors: &[Vec<f32>],
249        file_path: &str,
250    ) -> Result<Self, DiskAnnError> {
251        Self::build_index_default(vectors, D::default(), file_path)
252    }
253
254    /// Open an index using `D::default()` as the distance (matches what you built with).
255    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
256        Self::open_index_with(path, D::default())
257    }
258}
259
260impl<D> DiskANN<D>
261where
262    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
263{
264    /// Builds a new index from provided vectors
265    ///
266    /// # Arguments
267    /// * `vectors` - The vectors to index (slice of `Vec<f32>`)
268    /// * `max_degree` - Maximum edges per node (M ~ 24-64)
269    /// * `build_beam_width` - Construction L (e.g., 128-400)
270    /// * `alpha` - Pruning parameter (1.2–2.0)
271    /// * `dist` - Any `anndists::Distance<f32>` (e.g., `DistL2`)
272    /// * `file_path` - Path of index file
273    pub fn build_index(
274        vectors: &[Vec<f32>],
275        max_degree: usize,
276        build_beam_width: usize,
277        alpha: f32,
278        dist: D,
279        file_path: &str,
280    ) -> Result<Self, DiskAnnError> {
281        if vectors.is_empty() {
282            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
283        }
284
285        let num_vectors = vectors.len();
286        let dim = vectors[0].len();
287        for (i, v) in vectors.iter().enumerate() {
288            if v.len() != dim {
289                return Err(DiskAnnError::IndexError(format!(
290                    "Vector {} has dimension {} but expected {}",
291                    i,
292                    v.len(),
293                    dim
294                )));
295            }
296        }
297
298        let mut file = OpenOptions::new()
299            .create(true)
300            .write(true)
301            .read(true)
302            .truncate(true)
303            .open(file_path)?;
304
305        // Reserve space for metadata (we'll write it after data)
306        let vectors_offset = 1024 * 1024;
307        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
308
309        // Write vectors contiguous (sequential I/O is fastest)
310        file.seek(SeekFrom::Start(vectors_offset))?;
311        for vector in vectors {
312            let bytes = bytemuck::cast_slice(vector);
313            file.write_all(bytes)?;
314        }
315
316        // Compute medoid using provided distance (parallelized distance eval)
317        let medoid_id = calculate_medoid(vectors, dist);
318
319        // Build Vamana-like graph (stronger refinement, parallel inner loops)
320        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
321        let graph = build_vamana_graph(
322            vectors,
323            max_degree,
324            build_beam_width,
325            alpha,
326            dist,
327            medoid_id as u32,
328        );
329
330        // Write adjacency lists (fixed max_degree, pad with PAD_U32)
331        file.seek(SeekFrom::Start(adjacency_offset))?;
332        for neighbors in &graph {
333            let mut padded = neighbors.clone();
334            padded.resize(max_degree, PAD_U32);
335            let bytes = bytemuck::cast_slice(&padded);
336            file.write_all(bytes)?;
337        }
338
339        // Write metadata
340        let metadata = Metadata {
341            dim,
342            num_vectors,
343            max_degree,
344            medoid_id: medoid_id as u32,
345            vectors_offset: vectors_offset as u64,
346            adjacency_offset,
347            distance_name: std::any::type_name::<D>().to_string(),
348        };
349
350        let md_bytes = bincode::serialize(&metadata)?;
351        file.seek(SeekFrom::Start(0))?;
352        let md_len = md_bytes.len() as u64;
353        file.write_all(&md_len.to_le_bytes())?;
354        file.write_all(&md_bytes)?;
355        file.sync_all()?;
356
357        // Memory map the file
358        let mmap = unsafe { memmap2::Mmap::map(&file)? };
359
360        Ok(Self {
361            dim,
362            num_vectors,
363            max_degree,
364            distance_name: metadata.distance_name,
365            medoid_id: metadata.medoid_id,
366            vectors_offset: metadata.vectors_offset,
367            adjacency_offset: metadata.adjacency_offset,
368            storage: Storage::Mmap(mmap),
369            dist,
370        })
371    }
372
373    /// Opens an existing index file, supplying the distance strategy explicitly.
374    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
375        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
376
377        // Read metadata length
378        let mut buf8 = [0u8; 8];
379        file.seek(SeekFrom::Start(0))?;
380        file.read_exact(&mut buf8)?;
381        let md_len = u64::from_le_bytes(buf8);
382
383        // Read metadata
384        let mut md_bytes = vec![0u8; md_len as usize];
385        file.read_exact(&mut md_bytes)?;
386        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
387
388        let mmap = unsafe { memmap2::Mmap::map(&file)? };
389
390        // Optional sanity/logging: warn if type differs from recorded name
391        let expected = std::any::type_name::<D>();
392        if metadata.distance_name != expected {
393            eprintln!(
394                "Warning: index recorded distance `{}` but you opened with `{}`",
395                metadata.distance_name, expected
396            );
397        }
398
399        Ok(Self {
400            dim: metadata.dim,
401            num_vectors: metadata.num_vectors,
402            max_degree: metadata.max_degree,
403            distance_name: metadata.distance_name,
404            medoid_id: metadata.medoid_id,
405            vectors_offset: metadata.vectors_offset,
406            adjacency_offset: metadata.adjacency_offset,
407            storage: Storage::Mmap(mmap),
408            dist,
409        })
410    }
411
412    /// Load an index from an owned byte buffer (no file needed).
413    pub fn from_bytes(bytes: Vec<u8>, dist: D) -> Result<Self, DiskAnnError> {
414        let metadata = Self::parse_metadata(&bytes)?;
415
416        let expected = std::any::type_name::<D>();
417        if metadata.distance_name != expected {
418            eprintln!(
419                "Warning: index recorded distance `{}` but you opened with `{}`",
420                metadata.distance_name, expected
421            );
422        }
423
424        Ok(Self {
425            dim: metadata.dim,
426            num_vectors: metadata.num_vectors,
427            max_degree: metadata.max_degree,
428            distance_name: metadata.distance_name,
429            medoid_id: metadata.medoid_id,
430            vectors_offset: metadata.vectors_offset,
431            adjacency_offset: metadata.adjacency_offset,
432            storage: Storage::Owned(bytes),
433            dist,
434        })
435    }
436
437    /// Load an index from a shared byte buffer (cheap clone, multi-reader).
438    pub fn from_shared_bytes(bytes: Arc<[u8]>, dist: D) -> Result<Self, DiskAnnError> {
439        let metadata = Self::parse_metadata(&bytes)?;
440
441        let expected = std::any::type_name::<D>();
442        if metadata.distance_name != expected {
443            eprintln!(
444                "Warning: index recorded distance `{}` but you opened with `{}`",
445                metadata.distance_name, expected
446            );
447        }
448
449        Ok(Self {
450            dim: metadata.dim,
451            num_vectors: metadata.num_vectors,
452            max_degree: metadata.max_degree,
453            distance_name: metadata.distance_name,
454            medoid_id: metadata.medoid_id,
455            vectors_offset: metadata.vectors_offset,
456            adjacency_offset: metadata.adjacency_offset,
457            storage: Storage::Shared(bytes),
458            dist,
459        })
460    }
461
    /// Serialize the index to a byte vector.
    ///
    /// Returns a copy of the backing storage via `to_vec()`.
    /// NOTE(review): presumably this is the complete file image (header, vectors,
    /// adjacency) and can be passed back to `from_bytes`; confirm against `Storage`.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.storage.to_vec()
    }
466
467    /// Parse metadata from raw bytes (shared helper for from_bytes / from_shared_bytes).
468    fn parse_metadata(bytes: &[u8]) -> Result<Metadata, DiskAnnError> {
469        if bytes.len() < 8 {
470            return Err(DiskAnnError::IndexError("Buffer too small for metadata length".into()));
471        }
472        let md_len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
473        if bytes.len() < 8 + md_len {
474            return Err(DiskAnnError::IndexError("Buffer too small for metadata".into()));
475        }
476        let metadata: Metadata = bincode::deserialize(&bytes[8..8 + md_len])?;
477        Ok(metadata)
478    }
479
480    /// Searches the index for nearest neighbors using a best-first beam search.
481    /// Termination rule: continue while the best frontier can still improve the worst in working set.
482    /// Like `search` but also returns the distance for each neighbor.
483    /// Distances are exactly the ones computed during the beam search.
484    pub fn search_with_dists(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
485        assert_eq!(
486            query.len(),
487            self.dim,
488            "Query dim {} != index dim {}",
489            query.len(),
490            self.dim
491        );
492
493        #[derive(Clone, Copy)]
494        struct Candidate {
495            dist: f32,
496            id: u32,
497        }
498        impl PartialEq for Candidate {
499            fn eq(&self, o: &Self) -> bool {
500                self.dist == o.dist && self.id == o.id
501            }
502        }
503        impl Eq for Candidate {}
504        impl PartialOrd for Candidate {
505            fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
506                self.dist.partial_cmp(&o.dist)
507            }
508        }
509        impl Ord for Candidate {
510            fn cmp(&self, o: &Self) -> Ordering {
511                self.partial_cmp(o).unwrap_or(Ordering::Equal)
512            }
513        }
514
515        let mut visited = HashSet::new();
516        let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // best-first by dist
517        let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // working set, max-heap by dist
518
519        // seed from medoid
520        let start_dist = self.distance_to(query, self.medoid_id as usize);
521        let start = Candidate {
522            dist: start_dist,
523            id: self.medoid_id,
524        };
525        frontier.push(Reverse(start));
526        w.push(start);
527        visited.insert(self.medoid_id);
528
529        // expand while best frontier can still improve worst in working set
530        while let Some(Reverse(best)) = frontier.peek().copied() {
531            if w.len() >= beam_width {
532                if let Some(worst) = w.peek() {
533                    if best.dist >= worst.dist {
534                        break;
535                    }
536                }
537            }
538            let Reverse(current) = frontier.pop().unwrap();
539
540            for &nb in self.get_neighbors(current.id) {
541                if nb == PAD_U32 {
542                    continue;
543                }
544                if !visited.insert(nb) {
545                    continue;
546                }
547
548                let d = self.distance_to(query, nb as usize);
549                let cand = Candidate { dist: d, id: nb };
550
551                if w.len() < beam_width {
552                    w.push(cand);
553                    frontier.push(Reverse(cand));
554                } else if d < w.peek().unwrap().dist {
555                    w.pop();
556                    w.push(cand);
557                    frontier.push(Reverse(cand));
558                }
559            }
560        }
561
562        // top-k by distance, keep distances
563        let mut results: Vec<_> = w.into_vec();
564        results.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
565        results.truncate(k);
566        results.into_iter().map(|c| (c.id, c.dist)).collect()
567    }
568    /// search but only return neighbor ids
569    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
570        self.search_with_dists(query, k, beam_width)
571            .into_iter()
572            .map(|(id, _dist)| id)
573            .collect()
574    }
575
576    /// Gets the neighbors of a node from the (fixed-degree) adjacency region
577    fn get_neighbors(&self, node_id: u32) -> &[u32] {
578        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
579        let start = offset as usize;
580        let end = start + (self.max_degree * 4);
581        let bytes = &self.storage[start..end];
582        bytemuck::cast_slice(bytes)
583    }
584
585    /// Computes distance between `query` and vector `idx`
586    fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
587        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
588        let start = offset as usize;
589        let end = start + (self.dim * 4);
590        let bytes = &self.storage[start..end];
591        let vector: &[f32] = bytemuck::cast_slice(bytes);
592        self.dist.eval(query, vector)
593    }
594
595    /// Gets a vector from the index (useful for tests)
596    pub fn get_vector(&self, idx: usize) -> Vec<f32> {
597        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
598        let start = offset as usize;
599        let end = start + (self.dim * 4);
600        let bytes = &self.storage[start..end];
601        let vector: &[f32] = bytemuck::cast_slice(bytes);
602        vector.to_vec()
603    }
604}
605
606/// Calculates the medoid (vector closest to the centroid) using distance `D`
607/// Parallelizes the per-vector distance evaluations.
608fn calculate_medoid<D: Distance<f32> + Copy + Sync>(vectors: &[Vec<f32>], dist: D) -> usize {
609    let dim = vectors[0].len();
610    let mut centroid = vec![0.0f32; dim];
611
612    for v in vectors {
613        for (i, &val) in v.iter().enumerate() {
614            centroid[i] += val;
615        }
616    }
617    for val in &mut centroid {
618        *val /= vectors.len() as f32;
619    }
620
621    let (best_idx, _best_dist) = vectors
622        .par_iter()
623        .enumerate()
624        .map(|(idx, v)| (idx, dist.eval(&centroid, v)))
625        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
626
627    best_idx
628}
629
/// Builds a strengthened Vamana-like graph using multi-pass refinement.
/// - Multi-seed candidate gathering (medoid + random seeds)
/// - Union with current adjacency before α-prune
/// - 2 refinement passes with symmetrization after each pass
///
/// Returns one neighbor list per input vector, indexed by node id.
/// The result is not deterministic: the bootstrap, the visit order and the extra
/// seeds all come from `thread_rng()`.
fn build_vamana_graph<D: Distance<f32> + Copy + Sync>(
    vectors: &[Vec<f32>],
    max_degree: usize,
    build_beam_width: usize,
    alpha: f32,
    dist: D,
    medoid_id: u32,
) -> Vec<Vec<u32>> {
    let n = vectors.len();
    let mut graph = vec![Vec::<u32>::new(); n];

    // Light random bootstrap to avoid disconnected starts
    {
        let mut rng = thread_rng();
        for i in 0..n {
            let mut s = HashSet::new();
            // Aim for max_degree/2 random neighbors (at least 2), but never more than
            // the n-1 other nodes that exist, otherwise the loop below could not end.
            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
            while s.len() < target {
                let nb = rng.gen_range(0..n);
                if nb != i {
                    s.insert(nb as u32);
                }
            }
            graph[i] = s.into_iter().collect();
        }
    }

    // Refinement passes
    const PASSES: usize = 2;
    const EXTRA_SEEDS: usize = 2;

    let mut rng = thread_rng();
    for _pass in 0..PASSES {
        // Shuffle visit order each pass
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);

        // Snapshot read of graph for parallel candidate building:
        // every node in this pass searches the graph as it was at the start of the pass.
        let snapshot = &graph;

        // Build new neighbor proposals in parallel.
        // NOTE: `new_graph[pos]` is the proposal for node `order[pos]`, not for node `pos`.
        let new_graph: Vec<Vec<u32>> = order
            .par_iter()
            .map(|&u| {
                let mut candidates: Vec<(u32, f32)> =
                    Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));

                // Include current adjacency with distances
                for &nb in &snapshot[u] {
                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
                    candidates.push((nb, d));
                }

                // Seeds: always medoid + some random starts
                let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
                seeds.push(medoid_id as usize);
                let mut trng = thread_rng();
                for _ in 0..EXTRA_SEEDS {
                    seeds.push(trng.gen_range(0..n));
                }

                // Gather candidates from greedy searches
                for start in seeds {
                    let mut part = greedy_search(
                        &vectors[u],
                        vectors,
                        snapshot,
                        start,
                        build_beam_width,
                        dist,
                    );
                    candidates.append(&mut part);
                }

                // Deduplicate by id keeping best distance.
                // In `dedup_by`, `a` is the element being examined and `b` the one kept
                // so far; returning `true` drops `a`, so its distance is copied into `b`
                // first when it is the smaller one.
                candidates.sort_by(|a, b| a.0.cmp(&b.0));
                candidates.dedup_by(|a, b| {
                    if a.0 == b.0 {
                        if a.1 < b.1 {
                            *b = *a;
                        }
                        true
                    } else {
                        false
                    }
                });

                // α-prune around u
                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
            })
            .collect();

        // Symmetrize: union incoming + outgoing, then α-prune again (parallel)
        // Build inverse map: node-id -> position in `order`
        let mut pos_of = vec![0usize; n];
        for (pos, &u) in order.iter().enumerate() {
            pos_of[u] = pos;
        }

        // Build incoming as CSR
        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);

        // Union + prune in parallel
        graph = (0..n)
            .into_par_iter()
            .map(|u| {
                let ng = &new_graph[pos_of[u]]; // outgoing from this pass
                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]]; // incoming to u

                // pool = union(outgoing ∪ incoming) with tiny, cache-friendly ops
                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
                pool_ids.extend_from_slice(ng);
                pool_ids.extend_from_slice(inc);
                pool_ids.sort_unstable();
                pool_ids.dedup();

                // compute distances once, then α-prune (self-loops are dropped here)
                let pool: Vec<(u32, f32)> = pool_ids
                    .into_iter()
                    .filter(|&id| id as usize != u)
                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                    .collect();

                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
            })
            .collect();
    }

    // Final cleanup (ensure <= max_degree everywhere):
    // any list still longer than max_degree is pruned once more.
    graph
        .into_par_iter()
        .enumerate()
        .map(|(u, neigh)| {
            if neigh.len() <= max_degree {
                return neigh;
            }
            let pool: Vec<(u32, f32)> = neigh
                .iter()
                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                .collect();
            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
        })
        .collect()
}
778
779/// Greedy search used during construction (read-only on `graph`)
780/// Same termination rule as query-time search.
781fn greedy_search<D: Distance<f32> + Copy>(
782    query: &[f32],
783    vectors: &[Vec<f32>],
784    graph: &[Vec<u32>],
785    start_id: usize,
786    beam_width: usize,
787    dist: D,
788) -> Vec<(u32, f32)> {
789    let mut visited = HashSet::new();
790    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // min-heap by dist
791    let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // max-heap by dist
792
793    let start_dist = dist.eval(query, &vectors[start_id]);
794    let start = Candidate {
795        dist: start_dist,
796        id: start_id as u32,
797    };
798    frontier.push(Reverse(start));
799    w.push(start);
800    visited.insert(start_id as u32);
801
802    while let Some(Reverse(best)) = frontier.peek().copied() {
803        if w.len() >= beam_width {
804            if let Some(worst) = w.peek() {
805                if best.dist >= worst.dist {
806                    break;
807                }
808            }
809        }
810        let Reverse(cur) = frontier.pop().unwrap();
811
812        for &nb in &graph[cur.id as usize] {
813            if !visited.insert(nb) {
814                continue;
815            }
816            let d = dist.eval(query, &vectors[nb as usize]);
817            let cand = Candidate { dist: d, id: nb };
818
819            if w.len() < beam_width {
820                w.push(cand);
821                frontier.push(Reverse(cand));
822            } else if d < w.peek().unwrap().dist {
823                w.pop();
824                w.push(cand);
825                frontier.push(Reverse(cand));
826            }
827        }
828    }
829
830    let mut v = w.into_vec();
831    v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
832    v.into_iter().map(|c| (c.id, c.dist)).collect()
833}
834
835/// α-pruning from DiskANN/Vamana
836fn prune_neighbors<D: Distance<f32> + Copy>(
837    node_id: usize,
838    candidates: &[(u32, f32)],
839    vectors: &[Vec<f32>],
840    max_degree: usize,
841    alpha: f32,
842    dist: D,
843) -> Vec<u32> {
844    if candidates.is_empty() {
845        return Vec::new();
846    }
847
848    let mut sorted = candidates.to_vec();
849    sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
850
851    let mut pruned = Vec::<u32>::new();
852
853    for &(cand_id, cand_dist) in &sorted {
854        if cand_id as usize == node_id {
855            continue;
856        }
857        let mut ok = true;
858        for &sel in &pruned {
859            let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
860            if d < alpha * cand_dist {
861                ok = false;
862                break;
863            }
864        }
865        if ok {
866            pruned.push(cand_id);
867            if pruned.len() >= max_degree {
868                break;
869            }
870        }
871    }
872
873    // fill with closest if still not full
874    for &(cand_id, _) in &sorted {
875        if cand_id as usize == node_id {
876            continue;
877        }
878        if !pruned.contains(&cand_id) {
879            pruned.push(cand_id);
880            if pruned.len() >= max_degree {
881                break;
882            }
883        }
884    }
885
886    pruned
887}
888
/// Builds the reverse (incoming-edge) adjacency of `new_graph` in CSR form.
///
/// `new_graph[pos]` holds the outgoing edges of node `order[pos]`. The result is
/// `(incoming_flat, offsets)`, where the nodes pointing at `v` are
/// `incoming_flat[offsets[v]..offsets[v + 1]]`, listed in `order` traversal order.
/// `offsets` has `n + 1` entries.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    // Only the rows that have a matching entry in `order` take part.
    let rows = &new_graph[..order.len()];

    // Count in-degrees one slot to the right, then turn the counts into offsets in place.
    let mut off = vec![0usize; n + 1];
    for &target in rows.iter().flatten() {
        off[target as usize + 1] += 1;
    }
    for i in 0..n {
        off[i + 1] += off[i];
    }

    // Scatter every edge `src -> target` into the next free slot of `target`'s range.
    let mut next_free = off.clone();
    let mut incoming_flat = vec![0u32; off[n]];
    for (&src, row) in order.iter().zip(rows) {
        for &target in row {
            let slot = &mut next_free[target as usize];
            incoming_flat[*slot] = src as u32;
            *slot += 1;
        }
    }

    (incoming_flat, off)
}
914
915#[cfg(test)]
916mod tests {
917    use super::*;
918    use anndists::dist::{DistCosine, DistL2};
919    use rand::Rng;
920    use std::fs;
921
922    fn euclid(a: &[f32], b: &[f32]) -> f32 {
923        a.iter()
924            .zip(b)
925            .map(|(x, y)| (x - y) * (x - y))
926            .sum::<f32>()
927            .sqrt()
928    }
929
930    #[test]
931    fn test_small_index_l2() {
932        let path = "test_small_l2.db";
933        let _ = fs::remove_file(path);
934
935        let vectors = vec![
936            vec![0.0, 0.0],
937            vec![1.0, 0.0],
938            vec![0.0, 1.0],
939            vec![1.0, 1.0],
940            vec![0.5, 0.5],
941        ];
942
943        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
944
945        let q = vec![0.1, 0.1];
946        let nns = index.search(&q, 3, 8);
947        assert_eq!(nns.len(), 3);
948
949        // Verify the first neighbor is quite close in L2
950        let v = index.get_vector(nns[0] as usize);
951        assert!(euclid(&q, &v) < 1.0);
952
953        let _ = fs::remove_file(path);
954    }
955
956    #[test]
957    fn test_cosine() {
958        let path = "test_cosine.db";
959        let _ = fs::remove_file(path);
960
961        let vectors = vec![
962            vec![1.0, 0.0, 0.0],
963            vec![0.0, 1.0, 0.0],
964            vec![0.0, 0.0, 1.0],
965            vec![1.0, 1.0, 0.0],
966            vec![1.0, 0.0, 1.0],
967        ];
968
969        let index =
970            DiskANN::<DistCosine>::build_index_default(&vectors, DistCosine {}, path).unwrap();
971
972        let q = vec![2.0, 0.0, 0.0]; // parallel to [1,0,0]
973        let nns = index.search(&q, 2, 8);
974        assert_eq!(nns.len(), 2);
975
976        // Top neighbor should have high cosine similarity (close direction)
977        let v = index.get_vector(nns[0] as usize);
978        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
979        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
980        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
981        let cos = dot / (n1 * n2);
982        assert!(cos > 0.7);
983
984        let _ = fs::remove_file(path);
985    }
986
987    #[test]
988    fn test_persistence_and_open() {
989        let path = "test_persist.db";
990        let _ = fs::remove_file(path);
991
992        let vectors = vec![
993            vec![0.0, 0.0],
994            vec![1.0, 0.0],
995            vec![0.0, 1.0],
996            vec![1.0, 1.0],
997        ];
998
999        {
1000            let _idx = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
1001        }
1002
1003        let idx2 = DiskANN::<DistL2>::open_index_default_metric(path).unwrap();
1004        assert_eq!(idx2.num_vectors, 4);
1005        assert_eq!(idx2.dim, 2);
1006
1007        let q = vec![0.9, 0.9];
1008        let res = idx2.search(&q, 2, 8);
1009        // [1,1] should be best
1010        assert_eq!(res[0], 3);
1011
1012        let _ = fs::remove_file(path);
1013    }
1014
1015    #[test]
1016    fn test_grid_connectivity() {
1017        let path = "test_grid.db";
1018        let _ = fs::remove_file(path);
1019
1020        // 5x5 grid
1021        let mut vectors = Vec::new();
1022        for i in 0..5 {
1023            for j in 0..5 {
1024                vectors.push(vec![i as f32, j as f32]);
1025            }
1026        }
1027
1028        let index = DiskANN::<DistL2>::build_index_with_params(
1029            &vectors,
1030            DistL2 {},
1031            path,
1032            DiskAnnParams {
1033                max_degree: 4,
1034                build_beam_width: 64,
1035                alpha: 1.5,
1036            },
1037        )
1038        .unwrap();
1039
1040        for target in 0..vectors.len() {
1041            let q = &vectors[target];
1042            let nns = index.search(q, 10, 32);
1043            if !nns.contains(&(target as u32)) {
1044                let v = index.get_vector(nns[0] as usize);
1045                assert!(euclid(q, &v) < 2.0);
1046            }
1047            for &nb in nns.iter().take(5) {
1048                let v = index.get_vector(nb as usize);
1049                assert!(euclid(q, &v) < 5.0);
1050            }
1051        }
1052
1053        let _ = fs::remove_file(path);
1054    }
1055
1056    #[test]
1057    fn test_medium_random() {
1058        let path = "test_medium.db";
1059        let _ = fs::remove_file(path);
1060
1061        let n = 200usize;
1062        let d = 32usize;
1063        let mut rng = rand::thread_rng();
1064        let vectors: Vec<Vec<f32>> = (0..n)
1065            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
1066            .collect();
1067
1068        let index = DiskANN::<DistL2>::build_index_with_params(
1069            &vectors,
1070            DistL2 {},
1071            path,
1072            DiskAnnParams {
1073                max_degree: 32,
1074                build_beam_width: 128,
1075                alpha: 1.2,
1076            },
1077        )
1078        .unwrap();
1079
1080        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
1081        let res = index.search(&q, 10, 64);
1082        assert_eq!(res.len(), 10);
1083
1084        // Ensure distances are nondecreasing
1085        let dists: Vec<f32> = res
1086            .iter()
1087            .map(|&id| {
1088                let v = index.get_vector(id as usize);
1089                euclid(&q, &v)
1090            })
1091            .collect();
1092        let mut sorted = dists.clone();
1093        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1094        assert_eq!(dists, sorted);
1095
1096        let _ = fs::remove_file(path);
1097    }
1098
1099    #[test]
1100    fn test_to_bytes_from_bytes_round_trip() {
1101        let path = "test_bytes_rt.db";
1102        let _ = fs::remove_file(path);
1103
1104        let vectors = vec![
1105            vec![0.0, 0.0],
1106            vec![1.0, 0.0],
1107            vec![0.0, 1.0],
1108            vec![1.0, 1.0],
1109            vec![0.5, 0.5],
1110        ];
1111
1112        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
1113        let bytes = index.to_bytes();
1114
1115        let index2 = DiskANN::<DistL2>::from_bytes(bytes, DistL2 {}).unwrap();
1116        assert_eq!(index2.num_vectors, 5);
1117        assert_eq!(index2.dim, 2);
1118
1119        let q = vec![0.9, 0.9];
1120        let res1 = index.search(&q, 3, 8);
1121        let res2 = index2.search(&q, 3, 8);
1122        assert_eq!(res1, res2);
1123
1124        let _ = fs::remove_file(path);
1125    }
1126
1127    #[test]
1128    fn test_from_shared_bytes() {
1129        let path = "test_shared_bytes.db";
1130        let _ = fs::remove_file(path);
1131
1132        let vectors = vec![
1133            vec![0.0, 0.0],
1134            vec![1.0, 0.0],
1135            vec![0.0, 1.0],
1136            vec![1.0, 1.0],
1137        ];
1138
1139        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
1140        let bytes = index.to_bytes();
1141        let shared: std::sync::Arc<[u8]> = bytes.into();
1142
1143        let index2 = DiskANN::<DistL2>::from_shared_bytes(shared, DistL2 {}).unwrap();
1144        assert_eq!(index2.num_vectors, 4);
1145        assert_eq!(index2.dim, 2);
1146
1147        let q = vec![0.9, 0.9];
1148        let res = index2.search(&q, 2, 8);
1149        assert_eq!(res[0], 3); // [1,1]
1150
1151        let _ = fs::remove_file(path);
1152    }
1153}