Skip to main content

diskann_rs/
lib.rs

1//! # DiskAnn (generic over `anndists::Distance<f32>`)
2//!
3//! A minimal, on-disk, DiskANN-like library that:
4//! - Builds a Vamana-style graph (greedy + α-pruning) in memory
5//! - Writes vectors + fixed-degree adjacency to a single file
6//! - Memory-maps the file for low-overhead reads
7//! - Is **generic over any `Distance<f32>`** from `anndists` (L2, Cosine, Hamming, Dot, …)
8//! - Supports **incremental updates** (add/delete vectors without full rebuild)
9//!
10//! ## Example
11//! ```no_run
12//! use anndists::dist::{DistL2, DistCosine};
13//! use diskann_rs::{DiskANN, DiskAnnParams};
14//!
15//! // Build a new index from vectors, using L2 and default params
16//! let vectors = vec![vec![0.0; 128]; 1000];
17//! let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2{}, "index.db").unwrap();
18//!
19//! // Or with custom params
20//! let index2 = DiskANN::<DistCosine>::build_index_with_params(
21//!     &vectors,
22//!     DistCosine{},
23//!     "index_cos.db",
24//!     DiskAnnParams { max_degree: 48, ..Default::default() },
25//! ).unwrap();
26//!
27//! // Search the index
28//! let query = vec![0.0; 128];
29//! let neighbors = index.search(&query, 10, 64);
30//!
31//! // Open later (provide the same distance type)
32//! let _reopened = DiskANN::<DistL2>::open_index_default_metric("index.db").unwrap();
33//! ```
34//!
35//! ## Incremental Updates
36//! ```no_run
37//! use anndists::dist::DistL2;
38//! use diskann_rs::IncrementalDiskANN;
39//!
40//! // Build initial index
41//! let vectors = vec![vec![0.0; 128]; 1000];
42//! let mut index = IncrementalDiskANN::<DistL2>::build_default(&vectors, "index.db").unwrap();
43//!
44//! // Add vectors without rebuilding
45//! let new_ids = index.add_vectors(&[vec![1.0; 128]]).unwrap();
46//!
47//! // Delete vectors (lazy tombstoning)
48//! index.delete_vectors(&[0, 1, 2]).unwrap();
49//!
50//! // Compact when needed
51//! if index.should_compact() {
52//!     index.compact("index_v2.db").unwrap();
53//! }
54//! ```
55//!
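//! ## File Layout
//! All integers are little-endian.
//! ```text
//! [ magic:u32 ][ version:u32 ][ metadata_len:u64 ][ metadata (bincode) ][ padding up to vectors_offset ]
//! [ vectors (num * dim * f32) ][ adjacency (num * max_degree * u32) ]
//! ```
//! Unused adjacency slots are filled with `u32::MAX`. Older files without the
//! magic/version prefix (starting directly with `metadata_len:u64`) are still readable.
//!
//! `build_index` places `vectors_offset` at a fixed 1 MiB gap; the adjacency block
//! starts immediately after the vectors.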
61
62mod incremental;
63mod filtered;
64pub mod simd;
65pub mod pq;
66pub mod storage;
67pub mod sq;
68pub mod formats;
69mod quantized;
70
71pub use quantized::{QuantizedDiskANN, QuantizedConfig};
72
73pub use incremental::{
74    IncrementalDiskANN, IncrementalConfig, IncrementalStats,
75    IncrementalQuantizedConfig, QuantizerKind,
76    is_delta_id, delta_local_idx,
77};
78
79pub use filtered::{FilteredDiskANN, Filter};
80
81pub use simd::{SimdL2, SimdDot, SimdCosine, simd_info};
82
83pub use pq::{ProductQuantizer, PQConfig, PQStats};
84
85pub use storage::Storage;
86
87pub use sq::{VectorQuantizer, F16Quantizer, Int8Quantizer};
88
89use anndists::prelude::Distance;
90use bytemuck;
91use rand::prelude::*;
92use rayon::prelude::*;
93use serde::{Deserialize, Serialize};
94use std::cmp::{Ordering, Reverse};
95use std::collections::{BinaryHeap, HashSet};
96use std::fs::OpenOptions;
97use std::io::{Read, Seek, SeekFrom, Write};
98use std::sync::Arc;
99use thiserror::Error;
100
/// Sentinel stored in unused adjacency slots.
///
/// `u32::MAX` is used instead of `0` so padding can never be confused with a
/// real edge to node 0. As a consequence, valid node ids must stay below it.
pub(crate) const PAD_U32: u32 = u32::MAX;

/// Magic number that opens the current core index format.
///
/// The value is the ASCII string "DANN" read big-endian. The header is written
/// with `to_le_bytes`, so the first four bytes on disk are `b"NNAD"`.
const CORE_MAGIC: u32 = u32::from_be_bytes(*b"DANN");

/// Version of the core index format, written right after `CORE_MAGIC`.
const CORE_FORMAT_VERSION: u32 = 1;

/// Default maximum out-degree (M) for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default construction beam width (L) for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default α pruning parameter for in-memory DiskANN builds.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
113
114/// Optional bag of knobs if you want to override just a few.
115#[derive(Clone, Copy, Debug)]
116pub struct DiskAnnParams {
117    pub max_degree: usize,
118    pub build_beam_width: usize,
119    pub alpha: f32,
120}
121impl Default for DiskAnnParams {
122    fn default() -> Self {
123        Self {
124            max_degree: DISKANN_DEFAULT_MAX_DEGREE,
125            build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
126            alpha: DISKANN_DEFAULT_ALPHA,
127        }
128    }
129}
130
131/// Custom error type for DiskAnnRS operations
132#[derive(Debug, Error)]
133pub enum DiskAnnError {
134    /// Represents I/O errors during file operations
135    #[error("I/O error: {0}")]
136    Io(#[from] std::io::Error),
137
138    /// Represents serialization/deserialization errors
139    #[error("Serialization error: {0}")]
140    Bincode(#[from] bincode::Error),
141
142    /// Represents index-specific errors
143    #[error("Index error: {0}")]
144    IndexError(String),
145}
146
147/// Internal metadata structure stored in the index file
148#[derive(Serialize, Deserialize, Debug)]
149struct Metadata {
150    dim: usize,
151    num_vectors: usize,
152    max_degree: usize,
153    medoid_id: u32,
154    vectors_offset: u64,
155    adjacency_offset: u64,
156    distance_name: String,
157}
158
/// Candidate node for the search frontier and working-set heaps.
///
/// Ordering is by `dist` ascending. NaN distances (of either sign) are ordered
/// after every real distance, and ties are broken by `id`. Equality is derived
/// from the same ordering, so `Eq` and `Ord` agree as `BinaryHeap` and sorting
/// require.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Candidate {
    pub dist: f32,
    pub id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // `partial_cmp` fails only when a side is NaN; rank NaN after real values
        // so it neither breaks the total order nor outranks a genuine neighbor.
        let by_dist = self
            .dist
            .partial_cmp(&other.dist)
            .unwrap_or_else(|| self.dist.is_nan().cmp(&other.dist.is_nan()));
        by_dist.then_with(|| self.id.cmp(&other.id))
    }
}
182
/// Read-only view of a graph-structured vector index addressed by `u32` ids.
///
/// It exposes what a generic graph walk needs: an entry point, per-node
/// distances, adjacency and the stored vectors.
// NOTE(review): `dead_code` is presumably allowed because not every method is
// used by every index variant; confirm before removing the attribute.
#[allow(dead_code)]
pub(crate) trait GraphIndex: Send + Sync {
    /// Number of vectors (nodes) in the index.
    fn num_vectors(&self) -> usize;
    /// Dimensionality shared by all stored vectors.
    fn dim(&self) -> usize;
    /// Node id at which searches start.
    fn entry_point(&self) -> u32;
    /// Distance between `query` and stored vector `id`.
    fn distance_to(&self, query: &[f32], id: u32) -> f32;
    /// Outgoing neighbors of `id`, with `PAD_U32` padding already removed.
    fn get_neighbors(&self, id: u32) -> Vec<u32>;
    /// Owned copy of stored vector `id`.
    fn get_vector(&self, id: u32) -> Vec<f32>;
    /// Whether `id` may be reported as a result; every node is live by default.
    fn is_live(&self, _id: u32) -> bool {
        true
    }
}
196
197impl<D> GraphIndex for DiskANN<D>
198where
199    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
200{
201    fn num_vectors(&self) -> usize {
202        self.num_vectors
203    }
204    fn dim(&self) -> usize {
205        self.dim
206    }
207    fn entry_point(&self) -> u32 {
208        self.medoid_id
209    }
210    fn distance_to(&self, query: &[f32], id: u32) -> f32 {
211        DiskANN::distance_to(self, query, id as usize)
212    }
213    fn get_neighbors(&self, id: u32) -> Vec<u32> {
214        DiskANN::get_neighbors(self, id)
215            .iter()
216            .copied()
217            .filter(|&nb| nb != PAD_U32)
218            .collect()
219    }
220    fn get_vector(&self, id: u32) -> Vec<f32> {
221        DiskANN::get_vector(self, id as usize)
222    }
223}
224
/// Configuration for the unified beam search.
///
/// `BeamSearchConfig::default()` (every field `None`) selects plain, unfiltered
/// beam search.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct BeamSearchConfig {
    /// If set, use an expanded working set of this size (for filtered search).
    /// Candidates in the expanded set participate in graph exploration but
    /// only candidates passing the filter are added to the results.
    pub expanded_beam: Option<usize>,
    /// Maximum iterations before forced termination (for filtered search).
    pub max_iterations: Option<usize>,
    /// Early termination factor: stop when best frontier > worst_result * factor
    /// (for filtered search).
    pub early_term_factor: Option<f32>,
}
247
248/// Unified beam search used by all search variants (base, quantized, filtered).
249///
250/// - `start_ids`: entry point nodes (typically medoid, or multiple seeds)
251/// - `beam_width`: working set size (number of closest candidates maintained)
252/// - `k`: number of results to return
253/// - `distance_fn`: computes distance from query to node id
254/// - `neighbors_fn`: returns neighbor ids for a node (filtered, no PAD_U32)
255/// - `filter_fn`: returns true if a candidate should be included in results
256/// - `config`: optional expanded beam / iteration limits for filtered search
257pub(crate) fn beam_search(
258    start_ids: &[u32],
259    beam_width: usize,
260    k: usize,
261    distance_fn: impl Fn(u32) -> f32,
262    neighbors_fn: impl Fn(u32) -> Vec<u32>,
263    filter_fn: impl Fn(u32) -> bool,
264    config: BeamSearchConfig,
265) -> Vec<(u32, f32)> {
266    let working_beam = config.expanded_beam.unwrap_or(beam_width);
267    let is_filtered = config.expanded_beam.is_some();
268
269    let mut visited = HashSet::new();
270    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
271    let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // working set (max-heap by dist)
272
273    // For filtered search, maintain a separate sorted results vec
274    let mut results: Vec<(u32, f32)> = if is_filtered {
275        Vec::with_capacity(k)
276    } else {
277        Vec::new() // unused in non-filtered mode
278    };
279
280    // Seed from all start nodes
281    for &sid in start_ids {
282        if !visited.insert(sid) {
283            continue;
284        }
285        let d = distance_fn(sid);
286        let cand = Candidate { dist: d, id: sid };
287        frontier.push(Reverse(cand));
288        w.push(cand);
289        if is_filtered && filter_fn(sid) {
290            results.push((sid, d));
291        }
292    }
293
294    let mut iterations = 0;
295    let max_iterations = config.max_iterations.unwrap_or(usize::MAX);
296    let early_term_factor = config.early_term_factor.unwrap_or(f32::MAX);
297
298    while let Some(Reverse(best)) = frontier.peek().copied() {
299        iterations += 1;
300        if iterations > max_iterations {
301            break;
302        }
303
304        // Filtered early termination: stop when best frontier can't improve worst result
305        if is_filtered && results.len() >= k {
306            if let Some((_, worst_dist)) = results.last() {
307                if best.dist > *worst_dist * early_term_factor {
308                    break;
309                }
310            }
311        }
312
313        // Standard beam termination
314        if w.len() >= working_beam {
315            if let Some(worst) = w.peek() {
316                if best.dist >= worst.dist {
317                    break;
318                }
319            }
320        }
321
322        let Reverse(current) = frontier.pop().unwrap();
323
324        for nb in neighbors_fn(current.id) {
325            if !visited.insert(nb) {
326                continue;
327            }
328
329            let d = distance_fn(nb);
330            let cand = Candidate { dist: d, id: nb };
331
332            // Always add to working set for graph exploration
333            if w.len() < working_beam {
334                w.push(cand);
335                frontier.push(Reverse(cand));
336            } else if d < w.peek().unwrap().dist {
337                w.pop();
338                w.push(cand);
339                frontier.push(Reverse(cand));
340            }
341
342            // For filtered search, maintain separate results
343            if is_filtered && filter_fn(nb) {
344                let pos = results
345                    .iter()
346                    .position(|(_, dist)| d < *dist)
347                    .unwrap_or(results.len());
348                if pos < k {
349                    results.insert(pos, (nb, d));
350                    if results.len() > k {
351                        results.pop();
352                    }
353                }
354            }
355        }
356    }
357
358    if is_filtered {
359        results
360    } else {
361        // Non-filtered: extract top-k from working set
362        let mut candidates: Vec<_> = w.into_vec();
363        candidates.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
364        candidates.truncate(k);
365        candidates.into_iter().map(|c| (c.id, c.dist)).collect()
366    }
367}
368
369/// Main struct representing a DiskANN index (generic over distance)
370pub struct DiskANN<D>
371where
372    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
373{
374    /// Dimensionality of vectors in the index
375    pub dim: usize,
376    /// Number of vectors in the index
377    pub num_vectors: usize,
378    /// Maximum number of edges per node
379    pub max_degree: usize,
380    /// Informational: type name of the distance (from metadata)
381    pub distance_name: String,
382
383    /// ID of the medoid (used as entry point)
384    pub(crate) medoid_id: u32,
385    // Offsets
386    pub(crate) vectors_offset: u64,
387    pub(crate) adjacency_offset: u64,
388
389    /// Backing storage (mmap, owned bytes, or shared bytes)
390    pub(crate) storage: Storage,
391
392    /// The distance strategy
393    pub(crate) dist: D,
394}
395
396// constructors
397
398impl<D> DiskANN<D>
399where
400    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
401{
402    /// Build with default parameters: (M=32, L=256, alpha=1.2).
403    pub fn build_index_default(
404        vectors: &[Vec<f32>],
405        dist: D,
406        file_path: &str,
407    ) -> Result<Self, DiskAnnError> {
408        Self::build_index(
409            vectors,
410            DISKANN_DEFAULT_MAX_DEGREE,
411            DISKANN_DEFAULT_BUILD_BEAM,
412            DISKANN_DEFAULT_ALPHA,
413            dist,
414            file_path,
415        )
416    }
417
418    /// Build with a `DiskAnnParams` bundle.
419    pub fn build_index_with_params(
420        vectors: &[Vec<f32>],
421        dist: D,
422        file_path: &str,
423        p: DiskAnnParams,
424    ) -> Result<Self, DiskAnnError> {
425        Self::build_index(
426            vectors,
427            p.max_degree,
428            p.build_beam_width,
429            p.alpha,
430            dist,
431            file_path,
432        )
433    }
434}
435
436/// Extra sugar when your distance type implements `Default` (most unit-struct metrics do).
437impl<D> DiskANN<D>
438where
439    D: Distance<f32> + Default + Send + Sync + Copy + Clone + 'static,
440{
441    /// Build with default params **and** `D::default()` metric.
442    pub fn build_index_default_metric(
443        vectors: &[Vec<f32>],
444        file_path: &str,
445    ) -> Result<Self, DiskAnnError> {
446        Self::build_index_default(vectors, D::default(), file_path)
447    }
448
449    /// Open an index using `D::default()` as the distance (matches what you built with).
450    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
451        Self::open_index_with(path, D::default())
452    }
453}
454
455impl<D> DiskANN<D>
456where
457    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
458{
459    /// Builds a new index from provided vectors
460    ///
461    /// # Arguments
462    /// * `vectors` - The vectors to index (slice of `Vec<f32>`)
463    /// * `max_degree` - Maximum edges per node (M ~ 24-64)
464    /// * `build_beam_width` - Construction L (e.g., 128-400)
465    /// * `alpha` - Pruning parameter (1.2–2.0)
466    /// * `dist` - Any `anndists::Distance<f32>` (e.g., `DistL2`)
467    /// * `file_path` - Path of index file
468    pub fn build_index(
469        vectors: &[Vec<f32>],
470        max_degree: usize,
471        build_beam_width: usize,
472        alpha: f32,
473        dist: D,
474        file_path: &str,
475    ) -> Result<Self, DiskAnnError> {
476        if vectors.is_empty() {
477            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
478        }
479
480        let num_vectors = vectors.len();
481        let dim = vectors[0].len();
482        for (i, v) in vectors.iter().enumerate() {
483            if v.len() != dim {
484                return Err(DiskAnnError::IndexError(format!(
485                    "Vector {} has dimension {} but expected {}",
486                    i,
487                    v.len(),
488                    dim
489                )));
490            }
491        }
492
493        let mut file = OpenOptions::new()
494            .create(true)
495            .write(true)
496            .read(true)
497            .truncate(true)
498            .open(file_path)?;
499
500        // Reserve space for metadata (we'll write it after data)
501        let vectors_offset = 1024 * 1024;
502        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
503
504        // Write vectors contiguous (sequential I/O is fastest)
505        file.seek(SeekFrom::Start(vectors_offset))?;
506        for vector in vectors {
507            let bytes = bytemuck::cast_slice(vector);
508            file.write_all(bytes)?;
509        }
510
511        // Compute medoid using provided distance (parallelized distance eval)
512        let medoid_id = calculate_medoid(vectors, dist);
513
514        // Build Vamana-like graph (stronger refinement, parallel inner loops)
515        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
516        let graph = build_vamana_graph(
517            vectors,
518            max_degree,
519            build_beam_width,
520            alpha,
521            dist,
522            medoid_id as u32,
523        );
524
525        // Write adjacency lists (fixed max_degree, pad with PAD_U32)
526        file.seek(SeekFrom::Start(adjacency_offset))?;
527        for neighbors in &graph {
528            let mut padded = neighbors.clone();
529            padded.resize(max_degree, PAD_U32);
530            let bytes = bytemuck::cast_slice(&padded);
531            file.write_all(bytes)?;
532        }
533
534        // Write metadata
535        let metadata = Metadata {
536            dim,
537            num_vectors,
538            max_degree,
539            medoid_id: medoid_id as u32,
540            vectors_offset: vectors_offset as u64,
541            adjacency_offset,
542            distance_name: std::any::type_name::<D>().to_string(),
543        };
544
545        let md_bytes = bincode::serialize(&metadata)?;
546        file.seek(SeekFrom::Start(0))?;
547        file.write_all(&CORE_MAGIC.to_le_bytes())?;
548        file.write_all(&CORE_FORMAT_VERSION.to_le_bytes())?;
549        let md_len = md_bytes.len() as u64;
550        file.write_all(&md_len.to_le_bytes())?;
551        file.write_all(&md_bytes)?;
552        file.sync_all()?;
553
554        // Memory map the file
555        let mmap = unsafe { memmap2::Mmap::map(&file)? };
556
557        Ok(Self {
558            dim,
559            num_vectors,
560            max_degree,
561            distance_name: metadata.distance_name,
562            medoid_id: metadata.medoid_id,
563            vectors_offset: metadata.vectors_offset,
564            adjacency_offset: metadata.adjacency_offset,
565            storage: Storage::Mmap(mmap),
566            dist,
567        })
568    }
569
570    /// Opens an existing index file, supplying the distance strategy explicitly.
571    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
572        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
573
574        // Read first 4 bytes to detect format (magic or old-style md_len)
575        let mut buf4 = [0u8; 4];
576        file.seek(SeekFrom::Start(0))?;
577        file.read_exact(&mut buf4)?;
578        let first_u32 = u32::from_le_bytes(buf4);
579
580        let md_offset = if first_u32 == CORE_MAGIC {
581            // New format: [magic:u32][version:u32][md_len:u64][metadata...]
582            let mut ver_buf = [0u8; 4];
583            file.read_exact(&mut ver_buf)?;
584            let version = u32::from_le_bytes(ver_buf);
585            if version != CORE_FORMAT_VERSION {
586                return Err(DiskAnnError::IndexError(format!(
587                    "Unsupported core format version: {}", version
588                )));
589            }
590            8u64 // magic + version = 8 bytes, then md_len starts
591        } else {
592            // Old format: [md_len:u64][metadata...]
593            file.seek(SeekFrom::Start(0))?;
594            0u64
595        };
596
597        // Read metadata length
598        let mut buf8 = [0u8; 8];
599        file.seek(SeekFrom::Start(md_offset))?;
600        file.read_exact(&mut buf8)?;
601        let md_len = u64::from_le_bytes(buf8);
602
603        // Sanity check: metadata length must be reasonable (< 1 MiB and < file size)
604        let file_size = file.seek(SeekFrom::End(0))?;
605        if md_len > 1024 * 1024 || md_offset + 8 + md_len > file_size {
606            return Err(DiskAnnError::IndexError(format!(
607                "Invalid metadata length {} (file size {})",
608                md_len, file_size
609            )));
610        }
611        file.seek(SeekFrom::Start(md_offset + 8))?;
612
613        // Read metadata
614        let mut md_bytes = vec![0u8; md_len as usize];
615        file.read_exact(&mut md_bytes)?;
616        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
617
618        let mmap = unsafe { memmap2::Mmap::map(&file)? };
619
620        // Optional sanity/logging: warn if type differs from recorded name
621        let expected = std::any::type_name::<D>();
622        if metadata.distance_name != expected {
623            eprintln!(
624                "Warning: index recorded distance `{}` but you opened with `{}`",
625                metadata.distance_name, expected
626            );
627        }
628
629        Ok(Self {
630            dim: metadata.dim,
631            num_vectors: metadata.num_vectors,
632            max_degree: metadata.max_degree,
633            distance_name: metadata.distance_name,
634            medoid_id: metadata.medoid_id,
635            vectors_offset: metadata.vectors_offset,
636            adjacency_offset: metadata.adjacency_offset,
637            storage: Storage::Mmap(mmap),
638            dist,
639        })
640    }
641
642    /// Load an index from an owned byte buffer (no file needed).
643    pub fn from_bytes(bytes: Vec<u8>, dist: D) -> Result<Self, DiskAnnError> {
644        let metadata = Self::parse_metadata(&bytes)?;
645
646        let expected = std::any::type_name::<D>();
647        if metadata.distance_name != expected {
648            eprintln!(
649                "Warning: index recorded distance `{}` but you opened with `{}`",
650                metadata.distance_name, expected
651            );
652        }
653
654        Ok(Self {
655            dim: metadata.dim,
656            num_vectors: metadata.num_vectors,
657            max_degree: metadata.max_degree,
658            distance_name: metadata.distance_name,
659            medoid_id: metadata.medoid_id,
660            vectors_offset: metadata.vectors_offset,
661            adjacency_offset: metadata.adjacency_offset,
662            storage: Storage::Owned(bytes),
663            dist,
664        })
665    }
666
667    /// Load an index from a shared byte buffer (cheap clone, multi-reader).
668    pub fn from_shared_bytes(bytes: Arc<[u8]>, dist: D) -> Result<Self, DiskAnnError> {
669        let metadata = Self::parse_metadata(&bytes)?;
670
671        let expected = std::any::type_name::<D>();
672        if metadata.distance_name != expected {
673            eprintln!(
674                "Warning: index recorded distance `{}` but you opened with `{}`",
675                metadata.distance_name, expected
676            );
677        }
678
679        Ok(Self {
680            dim: metadata.dim,
681            num_vectors: metadata.num_vectors,
682            max_degree: metadata.max_degree,
683            distance_name: metadata.distance_name,
684            medoid_id: metadata.medoid_id,
685            vectors_offset: metadata.vectors_offset,
686            adjacency_offset: metadata.adjacency_offset,
687            storage: Storage::Shared(bytes),
688            dist,
689        })
690    }
691
692    /// Serialize the index to a byte vector.
693    pub fn to_bytes(&self) -> Vec<u8> {
694        self.storage.to_vec()
695    }
696
697    /// Parse metadata from raw bytes (shared helper for from_bytes / from_shared_bytes).
698    /// Handles both new format (with magic/version) and old format (raw md_len).
699    fn parse_metadata(bytes: &[u8]) -> Result<Metadata, DiskAnnError> {
700        if bytes.len() < 8 {
701            return Err(DiskAnnError::IndexError("Buffer too small for metadata length".into()));
702        }
703
704        // Detect format: check first 4 bytes for magic
705        let first_u32 = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
706        let md_offset = if first_u32 == CORE_MAGIC {
707            // New format: skip magic(4) + version(4)
708            if bytes.len() < 16 {
709                return Err(DiskAnnError::IndexError("Buffer too small for header".into()));
710            }
711            let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
712            if version != CORE_FORMAT_VERSION {
713                return Err(DiskAnnError::IndexError(format!(
714                    "Unsupported core format version: {}", version
715                )));
716            }
717            8
718        } else {
719            0
720        };
721
722        if bytes.len() < md_offset + 8 {
723            return Err(DiskAnnError::IndexError("Buffer too small for metadata length".into()));
724        }
725        let md_len = u64::from_le_bytes(bytes[md_offset..md_offset + 8].try_into().unwrap()) as usize;
726        if bytes.len() < md_offset + 8 + md_len {
727            return Err(DiskAnnError::IndexError("Buffer too small for metadata".into()));
728        }
729        let metadata: Metadata = bincode::deserialize(&bytes[md_offset + 8..md_offset + 8 + md_len])?;
730        Ok(metadata)
731    }
732
733    /// Searches the index for nearest neighbors using a best-first beam search.
734    /// Like `search` but also returns the distance for each neighbor.
735    pub fn search_with_dists(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
736        assert_eq!(
737            query.len(),
738            self.dim,
739            "Query dim {} != index dim {}",
740            query.len(),
741            self.dim
742        );
743
744        beam_search(
745            &[self.medoid_id],
746            beam_width,
747            k,
748            |id| self.distance_to(query, id as usize),
749            |id| self.get_neighbors(id).iter().copied().filter(|&nb| nb != PAD_U32).collect(),
750            |_| true,
751            BeamSearchConfig::default(),
752        )
753    }
754    /// search but only return neighbor ids
755    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
756        self.search_with_dists(query, k, beam_width)
757            .into_iter()
758            .map(|(id, _dist)| id)
759            .collect()
760    }
761
762    /// Gets the neighbors of a node from the (fixed-degree) adjacency region
763    pub(crate) fn get_neighbors(&self, node_id: u32) -> &[u32] {
764        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
765        let start = offset as usize;
766        let end = start + (self.max_degree * 4);
767        let bytes = &self.storage[start..end];
768        bytemuck::cast_slice(bytes)
769    }
770
771    /// Computes distance between `query` and vector `idx`
772    pub(crate) fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
773        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
774        let start = offset as usize;
775        let end = start + (self.dim * 4);
776        let bytes = &self.storage[start..end];
777        let vector: &[f32] = bytemuck::cast_slice(bytes);
778        self.dist.eval(query, vector)
779    }
780
781    /// Gets a vector from the index (useful for tests)
782    pub fn get_vector(&self, idx: usize) -> Vec<f32> {
783        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
784        let start = offset as usize;
785        let end = start + (self.dim * 4);
786        let bytes = &self.storage[start..end];
787        let vector: &[f32] = bytemuck::cast_slice(bytes);
788        vector.to_vec()
789    }
790}
791
792/// Calculates the medoid (vector closest to the centroid) using distance `D`
793/// Parallelizes the per-vector distance evaluations.
794fn calculate_medoid<D: Distance<f32> + Copy + Sync>(vectors: &[Vec<f32>], dist: D) -> usize {
795    let dim = vectors[0].len();
796    let mut centroid = vec![0.0f32; dim];
797
798    for v in vectors {
799        for (i, &val) in v.iter().enumerate() {
800            centroid[i] += val;
801        }
802    }
803    for val in &mut centroid {
804        *val /= vectors.len() as f32;
805    }
806
807    let (best_idx, _best_dist) = vectors
808        .par_iter()
809        .enumerate()
810        .map(|(idx, v)| (idx, dist.eval(&centroid, v)))
811        .reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
812
813    best_idx
814}
815
816/// Builds a strengthened Vamana-like graph using multi-pass refinement.
817/// - Multi-seed candidate gathering (medoid + random seeds)
818/// - Union with current adjacency before α-prune
819/// - 2 refinement passes with symmetrization after each pass
820fn build_vamana_graph<D: Distance<f32> + Copy + Sync>(
821    vectors: &[Vec<f32>],
822    max_degree: usize,
823    build_beam_width: usize,
824    alpha: f32,
825    dist: D,
826    medoid_id: u32,
827) -> Vec<Vec<u32>> {
828    let n = vectors.len();
829    let mut graph = vec![Vec::<u32>::new(); n];
830
831    // Light random bootstrap to avoid disconnected starts
832    {
833        let mut rng = thread_rng();
834        for i in 0..n {
835            let mut s = HashSet::new();
836            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
837            while s.len() < target {
838                let nb = rng.gen_range(0..n);
839                if nb != i {
840                    s.insert(nb as u32);
841                }
842            }
843            graph[i] = s.into_iter().collect();
844        }
845    }
846
847    // Refinement passes
848    const PASSES: usize = 2;
849    const EXTRA_SEEDS: usize = 2;
850
851    let mut rng = thread_rng();
852    for _pass in 0..PASSES {
853        // Shuffle visit order each pass
854        let mut order: Vec<usize> = (0..n).collect();
855        order.shuffle(&mut rng);
856
857        // Snapshot read of graph for parallel candidate building
858        let snapshot = &graph;
859
860        // Build new neighbor proposals in parallel
861        let new_graph: Vec<Vec<u32>> = order
862            .par_iter()
863            .map(|&u| {
864                let mut candidates: Vec<(u32, f32)> =
865                    Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));
866
867                // Include current adjacency with distances
868                for &nb in &snapshot[u] {
869                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
870                    candidates.push((nb, d));
871                }
872
873                // Seeds: always medoid + some random starts
874                let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
875                seeds.push(medoid_id as usize);
876                let mut trng = thread_rng();
877                for _ in 0..EXTRA_SEEDS {
878                    seeds.push(trng.gen_range(0..n));
879                }
880
881                // Gather candidates from greedy searches
882                for start in seeds {
883                    let mut part = greedy_search(
884                        &vectors[u],
885                        vectors,
886                        snapshot,
887                        start,
888                        build_beam_width,
889                        dist,
890                    );
891                    candidates.append(&mut part);
892                }
893
894                // Deduplicate by id keeping best distance
895                candidates.sort_by(|a, b| a.0.cmp(&b.0));
896                candidates.dedup_by(|a, b| {
897                    if a.0 == b.0 {
898                        if a.1 < b.1 {
899                            *b = *a;
900                        }
901                        true
902                    } else {
903                        false
904                    }
905                });
906
907                // α-prune around u
908                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
909            })
910            .collect();
911
912        // Symmetrize: union incoming + outgoing, then α-prune again (parallel)
913        // Build inverse map: node-id -> position in `order`
914        let mut pos_of = vec![0usize; n];
915        for (pos, &u) in order.iter().enumerate() {
916            pos_of[u] = pos;
917        }
918
919        // Build incoming as CSR
920        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);
921
922        // Union + prune in parallel
923        graph = (0..n)
924            .into_par_iter()
925            .map(|u| {
926                let ng = &new_graph[pos_of[u]]; // outgoing from this pass
927                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]]; // incoming to u
928
929                // pool = union(outgoing ∪ incoming) with tiny, cache-friendly ops
930                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
931                pool_ids.extend_from_slice(ng);
932                pool_ids.extend_from_slice(inc);
933                pool_ids.sort_unstable();
934                pool_ids.dedup();
935
936                // compute distances once, then α-prune
937                let pool: Vec<(u32, f32)> = pool_ids
938                    .into_iter()
939                    .filter(|&id| id as usize != u)
940                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
941                    .collect();
942
943                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
944            })
945            .collect();
946    }
947
948    // Final cleanup (ensure <= max_degree everywhere)
949    graph
950        .into_par_iter()
951        .enumerate()
952        .map(|(u, neigh)| {
953            if neigh.len() <= max_degree {
954                return neigh;
955            }
956            let pool: Vec<(u32, f32)> = neigh
957                .iter()
958                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
959                .collect();
960            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
961        })
962        .collect()
963}
964
965/// Greedy search used during construction (read-only on `graph`)
966/// Same termination rule as query-time search.
967fn greedy_search<D: Distance<f32> + Copy>(
968    query: &[f32],
969    vectors: &[Vec<f32>],
970    graph: &[Vec<u32>],
971    start_id: usize,
972    beam_width: usize,
973    dist: D,
974) -> Vec<(u32, f32)> {
975    let mut visited = HashSet::new();
976    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new(); // min-heap by dist
977    let mut w: BinaryHeap<Candidate> = BinaryHeap::new(); // max-heap by dist
978
979    let start_dist = dist.eval(query, &vectors[start_id]);
980    let start = Candidate {
981        dist: start_dist,
982        id: start_id as u32,
983    };
984    frontier.push(Reverse(start));
985    w.push(start);
986    visited.insert(start_id as u32);
987
988    while let Some(Reverse(best)) = frontier.peek().copied() {
989        if w.len() >= beam_width {
990            if let Some(worst) = w.peek() {
991                if best.dist >= worst.dist {
992                    break;
993                }
994            }
995        }
996        let Reverse(cur) = frontier.pop().unwrap();
997
998        for &nb in &graph[cur.id as usize] {
999            if !visited.insert(nb) {
1000                continue;
1001            }
1002            let d = dist.eval(query, &vectors[nb as usize]);
1003            let cand = Candidate { dist: d, id: nb };
1004
1005            if w.len() < beam_width {
1006                w.push(cand);
1007                frontier.push(Reverse(cand));
1008            } else if d < w.peek().unwrap().dist {
1009                w.pop();
1010                w.push(cand);
1011                frontier.push(Reverse(cand));
1012            }
1013        }
1014    }
1015
1016    let mut v = w.into_vec();
1017    v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
1018    v.into_iter().map(|c| (c.id, c.dist)).collect()
1019}
1020
1021/// α-pruning from DiskANN/Vamana
1022fn prune_neighbors<D: Distance<f32> + Copy>(
1023    node_id: usize,
1024    candidates: &[(u32, f32)],
1025    vectors: &[Vec<f32>],
1026    max_degree: usize,
1027    alpha: f32,
1028    dist: D,
1029) -> Vec<u32> {
1030    if candidates.is_empty() {
1031        return Vec::new();
1032    }
1033
1034    let mut sorted = candidates.to_vec();
1035    sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
1036
1037    let mut pruned = Vec::<u32>::new();
1038
1039    for &(cand_id, cand_dist) in &sorted {
1040        if cand_id as usize == node_id {
1041            continue;
1042        }
1043        let mut ok = true;
1044        for &sel in &pruned {
1045            let d = dist.eval(&vectors[cand_id as usize], &vectors[sel as usize]);
1046            if d < alpha * cand_dist {
1047                ok = false;
1048                break;
1049            }
1050        }
1051        if ok {
1052            pruned.push(cand_id);
1053            if pruned.len() >= max_degree {
1054                break;
1055            }
1056        }
1057    }
1058
1059    // fill with closest if still not full
1060    for &(cand_id, _) in &sorted {
1061        if pruned.len() >= max_degree {
1062            break;
1063        }
1064        if cand_id as usize == node_id {
1065            continue;
1066        }
1067        if !pruned.contains(&cand_id) {
1068            pruned.push(cand_id);
1069        }
1070    }
1071
1072    pruned
1073}
1074
/// Inverts one pass's neighbour proposals into an incoming-edge list in CSR layout.
///
/// `new_graph[pos]` lists the out-neighbours proposed for node `order[pos]`. Returns
/// `(sources, offsets)`: the nodes with an edge into `v` are
/// `sources[offsets[v]..offsets[v + 1]]`, listed in the order their source nodes occur
/// in `order`. `offsets` has `n + 1` entries, and every target id must be `< n`.
fn build_incoming_csr(
    order: &[usize],
    new_graph: &[Vec<u32>],
    n: usize,
) -> (Vec<u32>, Vec<usize>) {
    let proposals = &new_graph[..order.len()];

    // Count in-degrees one slot to the right, so an in-place running sum turns the
    // counts directly into segment start offsets (offsets[0] stays 0).
    let mut offsets = vec![0usize; n + 1];
    for &target in proposals.iter().flatten() {
        offsets[target as usize + 1] += 1;
    }
    let mut running = 0usize;
    for slot in offsets.iter_mut() {
        running += *slot;
        *slot = running;
    }

    // Scatter every source id into its target's segment, advancing a per-target cursor.
    let mut next_slot = offsets.clone();
    let mut sources = vec![0u32; offsets[n]];
    for (&src, targets) in order.iter().zip(proposals) {
        for &target in targets {
            let slot = &mut next_slot[target as usize];
            sources[*slot] = src as u32;
            *slot += 1;
        }
    }

    (sources, offsets)
}
1100
// Tests: end-to-end build/search/persistence checks, which write small `.db` files
// under relative paths (the current working directory) and remove them afterwards,
// plus file-free unit tests for the graph primitives.
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let path = "test_small_l2.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();

        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);

        // Verify the first neighbor is quite close in L2
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);

        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_cosine() {
        let path = "test_cosine.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];

        let index =
            DiskANN::<DistCosine>::build_index_default(&vectors, DistCosine {}, path).unwrap();

        let q = vec![2.0, 0.0, 0.0]; // parallel to [1,0,0]
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);

        // Top neighbor should have high cosine similarity (close direction)
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);

        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_persistence_and_open() {
        let path = "test_persist.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        {
            let _idx = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        }

        let idx2 = DiskANN::<DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        // [1,1] should be best
        assert_eq!(res[0], 3);

        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_grid_connectivity() {
        let path = "test_grid.db";
        let _ = fs::remove_file(path);

        // 5x5 grid
        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
            },
        )
        .unwrap();

        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }

        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_medium_random() {
        let path = "test_medium.db";
        let _ = fs::remove_file(path);

        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();

        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
            },
        )
        .unwrap();

        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);

        // Ensure distances are nondecreasing
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);

        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_to_bytes_from_bytes_round_trip() {
        let path = "test_bytes_rt.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];

        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let bytes = index.to_bytes();

        let index2 = DiskANN::<DistL2>::from_bytes(bytes, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors, 5);
        assert_eq!(index2.dim, 2);

        let q = vec![0.9, 0.9];
        let res1 = index.search(&q, 3, 8);
        let res2 = index2.search(&q, 3, 8);
        assert_eq!(res1, res2);

        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_from_shared_bytes() {
        let path = "test_shared_bytes.db";
        let _ = fs::remove_file(path);

        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let bytes = index.to_bytes();
        let shared: std::sync::Arc<[u8]> = bytes.into();

        let index2 = DiskANN::<DistL2>::from_shared_bytes(shared, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors, 4);
        assert_eq!(index2.dim, 2);

        let q = vec![0.9, 0.9];
        let res = index2.search(&q, 2, 8);
        assert_eq!(res[0], 3); // [1,1]

        let _ = fs::remove_file(path);
    }

    // ================================================================
    // Unit tests for graph algorithms (no file I/O)
    // ================================================================

    #[test]
    fn test_candidate_ordering() {
        use std::cmp::Reverse;
        use std::collections::BinaryHeap;

        let a = Candidate { dist: 1.0, id: 0 };
        let b = Candidate { dist: 2.0, id: 1 };
        let c = Candidate { dist: 0.5, id: 2 };

        // Natural ordering: smaller dist is "less"
        assert!(a < b);
        assert!(c < a);

        // Min-heap via Reverse
        let mut min_heap: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
        min_heap.push(Reverse(a));
        min_heap.push(Reverse(b));
        min_heap.push(Reverse(c));
        assert_eq!(min_heap.pop().unwrap().0.id, 2); // dist 0.5
        assert_eq!(min_heap.pop().unwrap().0.id, 0); // dist 1.0
        assert_eq!(min_heap.pop().unwrap().0.id, 1); // dist 2.0

        // Max-heap (natural order)
        let mut max_heap: BinaryHeap<Candidate> = BinaryHeap::new();
        max_heap.push(a);
        max_heap.push(b);
        max_heap.push(c);
        assert_eq!(max_heap.peek().unwrap().id, 1); // dist 2.0 at top
    }

    #[test]
    fn test_beam_search_small_graph() {
        // Hand-crafted 5-node graph:
        //   0 --1.0-- 1 --1.0-- 2
        //   |                   |
        //  2.0                 1.0
        //   |                   |
        //   3 ------1.5------- 4
        //
        // Node positions: 0=(0,0), 1=(1,0), 2=(2,0), 3=(0,2), 4=(2,1)
        let positions: Vec<[f32; 2]> = vec![
            [0.0, 0.0], // 0
            [1.0, 0.0], // 1
            [2.0, 0.0], // 2
            [0.0, 2.0], // 3
            [2.0, 1.0], // 4
        ];

        let neighbors: Vec<Vec<u32>> = vec![
            vec![1, 3],    // 0 -> 1, 3
            vec![0, 2],    // 1 -> 0, 2
            vec![1, 4],    // 2 -> 1, 4
            vec![0, 4],    // 3 -> 0, 4
            vec![2, 3],    // 4 -> 2, 3
        ];

        // Query near node 4: (2.1, 0.9)
        let query = [2.1f32, 0.9];

        let results = beam_search(
            &[0], // start from node 0
            5,
            3,
            |id| {
                let p = &positions[id as usize];
                ((query[0] - p[0]).powi(2) + (query[1] - p[1]).powi(2)).sqrt()
            },
            |id| neighbors[id as usize].clone(),
            |_| true,
            BeamSearchConfig::default(),
        );

        assert_eq!(results.len(), 3);
        // Node 4 (2,1) should be closest to query (2.1, 0.9)
        assert_eq!(results[0].0, 4);
        // Node 2 (2,0) should be second closest
        assert_eq!(results[1].0, 2);
        // Distances should be sorted
        assert!(results[0].1 <= results[1].1);
        assert!(results[1].1 <= results[2].1);
    }

    #[test]
    fn test_beam_search_with_filter() {
        // Same 5-node graph as above
        let positions: Vec<[f32; 2]> = vec![
            [0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 1.0],
        ];
        let neighbors: Vec<Vec<u32>> = vec![
            vec![1, 3], vec![0, 2], vec![1, 4], vec![0, 4], vec![2, 3],
        ];

        // Query near node 4, but only odd IDs (nodes 1 and 3) pass the filter, so the
        // even nodes 0, 2 and 4 must never appear in the results.
        let query = [2.1f32, 0.9];

        let results = beam_search(
            &[0],
            5,
            3,
            |id| {
                let p = &positions[id as usize];
                ((query[0] - p[0]).powi(2) + (query[1] - p[1]).powi(2)).sqrt()
            },
            |id| neighbors[id as usize].clone(),
            |id| id % 2 == 1, // only odd IDs: 1 and 3
            BeamSearchConfig {
                expanded_beam: Some(10),
                max_iterations: Some(20),
                early_term_factor: Some(1.5),
            },
        );

        // Should only contain odd IDs
        for (id, _) in &results {
            assert!(id % 2 == 1, "Expected only odd IDs, got {}", id);
        }
        // Should find at least nodes 1 and 3
        let ids: HashSet<u32> = results.iter().map(|(id, _)| *id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&3));
    }

    #[test]
    fn test_prune_neighbors_alpha() {
        // 3 candidates around node 0:
        //   node 1 at distance 1.0
        //   node 2 at distance 1.2 but close to node 1 (should be pruned by the α-rule)
        //   node 3 at distance 2.0 but far from node 1 (should survive)
        let vectors = vec![
            vec![0.0, 0.0], // node 0 (center)
            vec![1.0, 0.0], // node 1
            vec![1.2, 0.0], // node 2 (close to node 1)
            vec![0.0, 2.0], // node 3 (far from node 1)
        ];

        let candidates: Vec<(u32, f32)> = vec![
            (1, DistL2 {}.eval(&vectors[0], &vectors[1])),
            (2, DistL2 {}.eval(&vectors[0], &vectors[2])),
            (3, DistL2 {}.eval(&vectors[0], &vectors[3])),
        ];

        // With alpha=1.0 (strict pruning), node 2 is covered by node 1 because
        // dist(1,2) < alpha * dist(0,2), so the α-rule keeps the more diverse node 3.
        // With max_degree=2 the fill phase has no spare slot, so the pruning is observable.
        let pruned = prune_neighbors(0, &candidates, &vectors, 2, 1.0, DistL2 {});
        assert_eq!(pruned, vec![1, 3]);

        // With max_degree=3 the fill phase tops the list up with the pruned node 2,
        // after the α-selected neighbours.
        let pruned = prune_neighbors(0, &candidates, &vectors, 3, 1.0, DistL2 {});
        assert_eq!(pruned, vec![1, 3, 2]);
    }

    #[test]
    fn test_prune_neighbors_max_degree() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 0.0],
            vec![0.0, 2.0],
        ];

        let candidates: Vec<(u32, f32)> = (1..6)
            .map(|i| (i as u32, DistL2 {}.eval(&vectors[0], &vectors[i])))
            .collect();

        // max_degree=2: should return at most 2 neighbors
        let pruned = prune_neighbors(0, &candidates, &vectors, 2, 1.2, DistL2 {});
        assert_eq!(pruned.len(), 2);

        // max_degree=5: should return all 5
        let pruned = prune_neighbors(0, &candidates, &vectors, 5, 1.2, DistL2 {});
        assert_eq!(pruned.len(), 5);

        // max_degree=1: should return exactly 1, one of the two closest (nodes 1 and 2 tie)
        let pruned = prune_neighbors(0, &candidates, &vectors, 1, 1.2, DistL2 {});
        assert_eq!(pruned.len(), 1);
        assert!(pruned[0] == 1 || pruned[0] == 2);
    }

    #[test]
    fn test_core_magic_number_in_bytes() {
        let path = "test_magic.db";
        let _ = fs::remove_file(path);

        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let bytes = index.to_bytes();

        // First 4 bytes should be CORE_MAGIC
        let magic = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert_eq!(magic, CORE_MAGIC, "Expected magic 0x{:08X}, got 0x{:08X}", CORE_MAGIC, magic);

        // Next 4 bytes should be version
        let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
        assert_eq!(version, CORE_FORMAT_VERSION);

        let _ = fs::remove_file(path);
    }
}