mod incremental;
mod filtered;
pub mod simd;
pub mod pq;
pub mod storage;
pub mod sq;
pub mod formats;
mod quantized;
pub use quantized::{QuantizedDiskANN, QuantizedConfig};
pub use incremental::{
IncrementalDiskANN, IncrementalConfig, IncrementalStats,
IncrementalQuantizedConfig, QuantizerKind,
is_delta_id, delta_local_idx,
};
pub use filtered::{FilteredDiskANN, Filter};
pub use simd::{SimdL2, SimdDot, SimdCosine, simd_info};
pub use pq::{ProductQuantizer, PQConfig, PQStats};
pub use storage::Storage;
pub use sq::{VectorQuantizer, F16Quantizer, Int8Quantizer};
use anndists::prelude::Distance;
use bytemuck;
use rand::prelude::*;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::cmp::{Ordering, Reverse};
use std::collections::{BinaryHeap, HashSet};
use std::fs::OpenOptions;
use std::io::{Read, Seek, SeekFrom, Write};
use std::sync::Arc;
use thiserror::Error;
pub(crate) const PAD_U32: u32 = u32::MAX;
const CORE_MAGIC: u32 = 0x44414E4E;
const CORE_FORMAT_VERSION: u32 = 1;
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Tunable parameters for building a `DiskANN` index.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum number of out-neighbors stored per node (adjacency row width).
    pub max_degree: usize,
    /// Beam width used by greedy search during graph construction.
    pub build_beam_width: usize,
    /// Pruning slack factor; larger values keep more long-range edges.
    pub alpha: f32,
}
impl Default for DiskAnnParams {
fn default() -> Self {
Self {
max_degree: DISKANN_DEFAULT_MAX_DEGREE,
build_beam_width: DISKANN_DEFAULT_BUILD_BEAM,
alpha: DISKANN_DEFAULT_ALPHA,
}
}
}
/// Errors produced while building, opening, or deserializing an index.
#[derive(Debug, Error)]
pub enum DiskAnnError {
    /// Underlying file or mmap I/O failure.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// Metadata (de)serialization failure.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),
    /// Logical index error (empty input, inconsistent dimensions,
    /// unsupported format version, corrupt metadata, ...).
    #[error("Index error: {0}")]
    IndexError(String),
}
/// Index header, bincode-serialized into the file/buffer after the
/// magic + version words (or at offset 0 for pre-header legacy files).
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    /// Dimensionality of every stored vector.
    dim: usize,
    /// Total number of vectors in the index.
    num_vectors: usize,
    /// Fixed adjacency-row width; shorter rows are padded with `PAD_U32`.
    max_degree: usize,
    /// Node id used as the entry point for every search.
    medoid_id: u32,
    /// Byte offset of the raw f32 vector data.
    vectors_offset: u64,
    /// Byte offset of the adjacency rows.
    adjacency_offset: u64,
    /// `type_name` of the distance functor recorded at build time.
    distance_name: String,
}
/// A `(distance, id)` pair used in the search heaps.
///
/// Ordered primarily by ascending distance, with `id` as a deterministic
/// tie-breaker. The previous implementation reported `Ordering::Equal` from
/// `partial_cmp` for equal distances with *different* ids, while `eq`
/// required both fields to match — violating the consistency contract
/// between `PartialOrd` and `PartialEq` that sorting and dedup rely on.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Candidate {
    pub dist: f32,
    pub id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.dist == other.dist && self.id == other.id
    }
}

impl Eq for Candidate {}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // NaN distances compare as equal (kept from the old behavior);
        // ids break every remaining tie so the order is total.
        self.dist
            .partial_cmp(&other.dist)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.id.cmp(&other.id))
    }
}

impl PartialOrd for Candidate {
    // Canonical form: delegate to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Minimal read-only view of a searchable graph index, implemented by the
/// concrete index types so generic search code can walk any of them.
#[allow(dead_code)]
pub(crate) trait GraphIndex: Send + Sync {
    /// Number of vectors stored in the index.
    fn num_vectors(&self) -> usize;
    /// Dimensionality of each stored vector.
    fn dim(&self) -> usize;
    /// Node id where searches begin.
    fn entry_point(&self) -> u32;
    /// Distance from `query` to the vector stored at `id`.
    fn distance_to(&self, query: &[f32], id: u32) -> f32;
    /// Out-neighbors of `id` (padding entries already removed).
    fn get_neighbors(&self, id: u32) -> Vec<u32>;
    /// Owned copy of the vector stored at `id`.
    fn get_vector(&self, id: u32) -> Vec<f32>;
    /// Whether `id` is still live; static indexes default to `true`.
    fn is_live(&self, _id: u32) -> bool {
        true
    }
}
// `GraphIndex` facade over an on-disk `DiskANN`; mostly thin delegation,
// except that `get_neighbors` strips the fixed-width row padding.
impl<D> GraphIndex for DiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
    fn num_vectors(&self) -> usize {
        self.num_vectors
    }
    fn dim(&self) -> usize {
        self.dim
    }
    fn entry_point(&self) -> u32 {
        // Searches start from the medoid chosen at build time.
        self.medoid_id
    }
    fn distance_to(&self, query: &[f32], id: u32) -> f32 {
        DiskANN::distance_to(self, query, id as usize)
    }
    fn get_neighbors(&self, id: u32) -> Vec<u32> {
        // The stored row is fixed-width; drop the PAD_U32 filler entries.
        DiskANN::get_neighbors(self, id)
            .iter()
            .copied()
            .filter(|&nb| nb != PAD_U32)
            .collect()
    }
    fn get_vector(&self, id: u32) -> Vec<f32> {
        DiskANN::get_vector(self, id as usize)
    }
}
/// Optional tuning knobs for `beam_search`; `None` means "not set".
///
/// The previous hand-written `Default` impl produced exactly what the
/// derived one does (all fields `None`), so it is now derived
/// (clippy: `derivable_impls`).
#[derive(Default)]
pub(crate) struct BeamSearchConfig {
    /// Wider working beam; setting this also switches `beam_search` into
    /// filtered mode.
    pub expanded_beam: Option<usize>,
    /// Hard cap on the number of search-loop iterations.
    pub max_iterations: Option<usize>,
    /// Early-termination slack: in filtered mode, stop once the closest
    /// frontier node exceeds `worst_result_dist * factor`.
    pub early_term_factor: Option<f32>,
}
/// Best-first beam search shared by the plain and filtered search paths.
///
/// * `start_ids` — entry nodes seeding the frontier (deduplicated).
/// * `beam_width` — nominal beam size; `config.expanded_beam` overrides it.
/// * `k` — number of results to return.
/// * `distance_fn` / `neighbors_fn` — graph access callbacks.
/// * `filter_fn` — acceptance predicate; only consulted in filtered mode.
///
/// Filtered mode is enabled by setting `config.expanded_beam` (the wider
/// beam compensates for candidates the filter rejects). In that mode,
/// passing candidates are accumulated in sorted order as they are visited;
/// otherwise the top-`k` of the final working set is returned.
pub(crate) fn beam_search(
    start_ids: &[u32],
    beam_width: usize,
    k: usize,
    distance_fn: impl Fn(u32) -> f32,
    neighbors_fn: impl Fn(u32) -> Vec<u32>,
    filter_fn: impl Fn(u32) -> bool,
    config: BeamSearchConfig,
) -> Vec<(u32, f32)> {
    let working_beam = config.expanded_beam.unwrap_or(beam_width);
    // Presence of `expanded_beam` doubles as the "filtered search" flag.
    let is_filtered = config.expanded_beam.is_some();
    let mut visited = HashSet::new();
    // `frontier` is a min-heap of nodes to expand (closest first);
    // `w` is a max-heap holding the best `working_beam` candidates so far.
    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
    let mut w: BinaryHeap<Candidate> = BinaryHeap::new();
    let mut results: Vec<(u32, f32)> = if is_filtered {
        Vec::with_capacity(k)
    } else {
        Vec::new()
    };
    // Seed the search with the start nodes.
    for &sid in start_ids {
        if !visited.insert(sid) {
            continue;
        }
        let d = distance_fn(sid);
        let cand = Candidate { dist: d, id: sid };
        frontier.push(Reverse(cand));
        w.push(cand);
        if is_filtered && filter_fn(sid) {
            results.push((sid, d));
        }
    }
    let mut iterations = 0;
    let max_iterations = config.max_iterations.unwrap_or(usize::MAX);
    let early_term_factor = config.early_term_factor.unwrap_or(f32::MAX);
    while let Some(Reverse(best)) = frontier.peek().copied() {
        iterations += 1;
        if iterations > max_iterations {
            break;
        }
        // Early termination: once k filtered hits exist and the closest
        // unexplored node is far beyond the current worst hit, stop.
        if is_filtered && results.len() >= k {
            if let Some((_, worst_dist)) = results.last() {
                if best.dist > *worst_dist * early_term_factor {
                    break;
                }
            }
        }
        // Standard stop condition: the frontier can no longer improve `w`.
        if w.len() >= working_beam {
            if let Some(worst) = w.peek() {
                if best.dist >= worst.dist {
                    break;
                }
            }
        }
        let Reverse(current) = frontier.pop().unwrap();
        for nb in neighbors_fn(current.id) {
            if !visited.insert(nb) {
                continue;
            }
            let d = distance_fn(nb);
            let cand = Candidate { dist: d, id: nb };
            if w.len() < working_beam {
                w.push(cand);
                frontier.push(Reverse(cand));
            } else if d < w.peek().unwrap().dist {
                // Evict the current worst working candidate.
                w.pop();
                w.push(cand);
                frontier.push(Reverse(cand));
            }
            if is_filtered && filter_fn(nb) {
                // Keep `results` sorted ascending by distance, capped at k.
                let pos = results
                    .iter()
                    .position(|(_, dist)| d < *dist)
                    .unwrap_or(results.len());
                if pos < k {
                    results.insert(pos, (nb, d));
                    if results.len() > k {
                        results.pop();
                    }
                }
            }
        }
    }
    if is_filtered {
        results
    } else {
        // Unfiltered: return the k closest members of the working set.
        let mut candidates: Vec<_> = w.into_vec();
        candidates.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
        candidates.truncate(k);
        candidates.into_iter().map(|c| (c.id, c.dist)).collect()
    }
}
/// A DiskANN-style vector index backed by a memory map, an owned buffer,
/// or a shared buffer (see `Storage`).
///
/// The backing bytes follow the layout recorded in `Metadata`: header at
/// offset 0, raw f32 vectors at `vectors_offset`, then fixed-width u32
/// adjacency rows at `adjacency_offset`.
pub struct DiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
    /// Vector dimensionality.
    pub dim: usize,
    /// Number of vectors stored.
    pub num_vectors: usize,
    /// Adjacency-row width; rows are padded with `PAD_U32`.
    pub max_degree: usize,
    /// `type_name` of the distance functor recorded at build time.
    pub distance_name: String,
    /// Entry point for searches.
    pub(crate) medoid_id: u32,
    /// Byte offset of the vector data within `storage`.
    pub(crate) vectors_offset: u64,
    /// Byte offset of the adjacency data within `storage`.
    pub(crate) adjacency_offset: u64,
    /// Index bytes: mmap, owned, or shared buffer.
    pub(crate) storage: Storage,
    /// Distance functor used for all comparisons.
    pub(crate) dist: D,
}
impl<D> DiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
    /// Builds an index at `file_path` using the module-level default
    /// construction parameters.
    pub fn build_index_default(
        vectors: &[Vec<f32>],
        dist: D,
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        Self::build_index(
            vectors,
            DISKANN_DEFAULT_MAX_DEGREE,
            DISKANN_DEFAULT_BUILD_BEAM,
            DISKANN_DEFAULT_ALPHA,
            dist,
            file_path,
        )
    }
    /// Builds an index at `file_path` with explicit `DiskAnnParams`.
    pub fn build_index_with_params(
        vectors: &[Vec<f32>],
        dist: D,
        file_path: &str,
        p: DiskAnnParams,
    ) -> Result<Self, DiskAnnError> {
        Self::build_index(
            vectors,
            p.max_degree,
            p.build_beam_width,
            p.alpha,
            dist,
            file_path,
        )
    }
}
// Convenience constructors for distance functors that implement `Default`.
impl<D> DiskANN<D>
where
    D: Distance<f32> + Default + Send + Sync + Copy + Clone + 'static,
{
    /// Builds an index using `D::default()` as the distance functor.
    pub fn build_index_default_metric(
        vectors: &[Vec<f32>],
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        Self::build_index_default(vectors, D::default(), file_path)
    }
    /// Opens an existing index file using `D::default()` as the distance.
    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
        Self::open_index_with(path, D::default())
    }
}
impl<D> DiskANN<D>
where
    D: Distance<f32> + Send + Sync + Copy + Clone + 'static,
{
    /// Builds a Vamana graph over `vectors`, writes the index to
    /// `file_path`, and returns it backed by a memory map of that file.
    ///
    /// File layout: magic + version + metadata length + bincode `Metadata`
    /// at offset 0; raw f32 vectors at a fixed 1 MiB offset; then one
    /// `max_degree`-wide u32 row per node, padded with `PAD_U32`.
    ///
    /// # Errors
    /// `IndexError` if `vectors` is empty or dimensions are inconsistent;
    /// `Io`/`Bincode` on write or serialization failure.
    pub fn build_index(
        vectors: &[Vec<f32>],
        max_degree: usize,
        build_beam_width: usize,
        alpha: f32,
        dist: D,
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        if vectors.is_empty() {
            return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
        }
        let num_vectors = vectors.len();
        let dim = vectors[0].len();
        // Every vector must share the dimensionality of the first one.
        for (i, v) in vectors.iter().enumerate() {
            if v.len() != dim {
                return Err(DiskAnnError::IndexError(format!(
                    "Vector {} has dimension {} but expected {}",
                    i,
                    v.len(),
                    dim
                )));
            }
        }
        let mut file = OpenOptions::new()
            .create(true)
            .write(true)
            .read(true)
            .truncate(true)
            .open(file_path)?;
        // Vector data starts at a fixed 1 MiB offset, leaving room for the
        // header + metadata written at the end.
        let vectors_offset = 1024 * 1024;
        let total_vector_bytes = (num_vectors as u64) * (dim as u64) * 4;
        file.seek(SeekFrom::Start(vectors_offset))?;
        for vector in vectors {
            let bytes = bytemuck::cast_slice(vector);
            file.write_all(bytes)?;
        }
        // The medoid (vector closest to the centroid) is the search entry.
        let medoid_id = calculate_medoid(vectors, dist);
        let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
        let graph = build_vamana_graph(
            vectors,
            max_degree,
            build_beam_width,
            alpha,
            dist,
            medoid_id as u32,
        );
        // Write fixed-width adjacency rows, padding short rows with PAD_U32.
        file.seek(SeekFrom::Start(adjacency_offset))?;
        for neighbors in &graph {
            let mut padded = neighbors.clone();
            padded.resize(max_degree, PAD_U32);
            let bytes = bytemuck::cast_slice(&padded);
            file.write_all(bytes)?;
        }
        let metadata = Metadata {
            dim,
            num_vectors,
            max_degree,
            medoid_id: medoid_id as u32,
            vectors_offset: vectors_offset as u64,
            adjacency_offset,
            distance_name: std::any::type_name::<D>().to_string(),
        };
        // Header: magic, format version, metadata length, metadata bytes.
        let md_bytes = bincode::serialize(&metadata)?;
        file.seek(SeekFrom::Start(0))?;
        file.write_all(&CORE_MAGIC.to_le_bytes())?;
        file.write_all(&CORE_FORMAT_VERSION.to_le_bytes())?;
        let md_len = md_bytes.len() as u64;
        file.write_all(&md_len.to_le_bytes())?;
        file.write_all(&md_bytes)?;
        file.sync_all()?;
        // SAFETY: the file was fully written and synced above; this code
        // does not modify it while mapped (concurrent external writers are
        // assumed absent, per memmap2's contract).
        let mmap = unsafe { memmap2::Mmap::map(&file)? };
        Ok(Self {
            dim,
            num_vectors,
            max_degree,
            distance_name: metadata.distance_name,
            medoid_id: metadata.medoid_id,
            vectors_offset: metadata.vectors_offset,
            adjacency_offset: metadata.adjacency_offset,
            storage: Storage::Mmap(mmap),
            dist,
        })
    }
    /// Opens an index file, validating the magic/version header when
    /// present. Files written before the header existed (no magic word)
    /// are still readable: metadata is then assumed to start at offset 0.
    pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
        let mut file = OpenOptions::new().read(true).write(false).open(path)?;
        let mut buf4 = [0u8; 4];
        file.seek(SeekFrom::Start(0))?;
        file.read_exact(&mut buf4)?;
        let first_u32 = u32::from_le_bytes(buf4);
        let md_offset = if first_u32 == CORE_MAGIC {
            let mut ver_buf = [0u8; 4];
            file.read_exact(&mut ver_buf)?;
            let version = u32::from_le_bytes(ver_buf);
            if version != CORE_FORMAT_VERSION {
                return Err(DiskAnnError::IndexError(format!(
                    "Unsupported core format version: {}", version
                )));
            }
            8u64
        } else {
            // Legacy layout: metadata length begins at offset 0.
            file.seek(SeekFrom::Start(0))?;
            0u64
        };
        let mut buf8 = [0u8; 8];
        file.seek(SeekFrom::Start(md_offset))?;
        file.read_exact(&mut buf8)?;
        let md_len = u64::from_le_bytes(buf8);
        // Sanity-check the metadata length before allocating or reading.
        let file_size = file.seek(SeekFrom::End(0))?;
        if md_len > 1024 * 1024 || md_offset + 8 + md_len > file_size {
            return Err(DiskAnnError::IndexError(format!(
                "Invalid metadata length {} (file size {})",
                md_len, file_size
            )));
        }
        file.seek(SeekFrom::Start(md_offset + 8))?;
        let mut md_bytes = vec![0u8; md_len as usize];
        file.read_exact(&mut md_bytes)?;
        let metadata: Metadata = bincode::deserialize(&md_bytes)?;
        // SAFETY: read-only mapping; this code never mutates the file while
        // mapped (external writers are assumed absent, per memmap2).
        let mmap = unsafe { memmap2::Mmap::map(&file)? };
        // A mismatched distance name is reported but not fatal.
        let expected = std::any::type_name::<D>();
        if metadata.distance_name != expected {
            eprintln!(
                "Warning: index recorded distance `{}` but you opened with `{}`",
                metadata.distance_name, expected
            );
        }
        Ok(Self {
            dim: metadata.dim,
            num_vectors: metadata.num_vectors,
            max_degree: metadata.max_degree,
            distance_name: metadata.distance_name,
            medoid_id: metadata.medoid_id,
            vectors_offset: metadata.vectors_offset,
            adjacency_offset: metadata.adjacency_offset,
            storage: Storage::Mmap(mmap),
            dist,
        })
    }
    /// Deserializes an index from an owned in-memory buffer (no file).
    pub fn from_bytes(bytes: Vec<u8>, dist: D) -> Result<Self, DiskAnnError> {
        let metadata = Self::parse_metadata(&bytes)?;
        // A mismatched distance name is reported but not fatal.
        let expected = std::any::type_name::<D>();
        if metadata.distance_name != expected {
            eprintln!(
                "Warning: index recorded distance `{}` but you opened with `{}`",
                metadata.distance_name, expected
            );
        }
        Ok(Self {
            dim: metadata.dim,
            num_vectors: metadata.num_vectors,
            max_degree: metadata.max_degree,
            distance_name: metadata.distance_name,
            medoid_id: metadata.medoid_id,
            vectors_offset: metadata.vectors_offset,
            adjacency_offset: metadata.adjacency_offset,
            storage: Storage::Owned(bytes),
            dist,
        })
    }
    /// Like `from_bytes`, but keeps the buffer behind an `Arc` so multiple
    /// indexes (or threads) can share the same bytes.
    pub fn from_shared_bytes(bytes: Arc<[u8]>, dist: D) -> Result<Self, DiskAnnError> {
        let metadata = Self::parse_metadata(&bytes)?;
        // A mismatched distance name is reported but not fatal.
        let expected = std::any::type_name::<D>();
        if metadata.distance_name != expected {
            eprintln!(
                "Warning: index recorded distance `{}` but you opened with `{}`",
                metadata.distance_name, expected
            );
        }
        Ok(Self {
            dim: metadata.dim,
            num_vectors: metadata.num_vectors,
            max_degree: metadata.max_degree,
            distance_name: metadata.distance_name,
            medoid_id: metadata.medoid_id,
            vectors_offset: metadata.vectors_offset,
            adjacency_offset: metadata.adjacency_offset,
            storage: Storage::Shared(bytes),
            dist,
        })
    }
    /// Returns the complete serialized index (header + vectors + graph).
    pub fn to_bytes(&self) -> Vec<u8> {
        self.storage.to_vec()
    }
    /// Parses and bounds-checks the header + metadata from a raw buffer,
    /// accepting both the current (magic + version) and legacy layouts.
    fn parse_metadata(bytes: &[u8]) -> Result<Metadata, DiskAnnError> {
        if bytes.len() < 8 {
            return Err(DiskAnnError::IndexError("Buffer too small for metadata length".into()));
        }
        let first_u32 = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
        let md_offset = if first_u32 == CORE_MAGIC {
            if bytes.len() < 16 {
                return Err(DiskAnnError::IndexError("Buffer too small for header".into()));
            }
            let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
            if version != CORE_FORMAT_VERSION {
                return Err(DiskAnnError::IndexError(format!(
                    "Unsupported core format version: {}", version
                )));
            }
            8
        } else {
            // Legacy buffer without magic/version header.
            0
        };
        if bytes.len() < md_offset + 8 {
            return Err(DiskAnnError::IndexError("Buffer too small for metadata length".into()));
        }
        let md_len = u64::from_le_bytes(bytes[md_offset..md_offset + 8].try_into().unwrap()) as usize;
        if bytes.len() < md_offset + 8 + md_len {
            return Err(DiskAnnError::IndexError("Buffer too small for metadata".into()));
        }
        let metadata: Metadata = bincode::deserialize(&bytes[md_offset + 8..md_offset + 8 + md_len])?;
        Ok(metadata)
    }
    /// Beam-searches for the `k` nearest neighbors of `query`, returning
    /// `(id, distance)` pairs sorted by ascending distance.
    ///
    /// # Panics
    /// If `query.len() != self.dim`.
    pub fn search_with_dists(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
        assert_eq!(
            query.len(),
            self.dim,
            "Query dim {} != index dim {}",
            query.len(),
            self.dim
        );
        beam_search(
            &[self.medoid_id],
            beam_width,
            k,
            |id| self.distance_to(query, id as usize),
            |id| self.get_neighbors(id).iter().copied().filter(|&nb| nb != PAD_U32).collect(),
            |_| true,
            BeamSearchConfig::default(),
        )
    }
    /// Convenience wrapper over `search_with_dists` returning ids only.
    pub fn search(&self, query: &[f32], k: usize, beam_width: usize) -> Vec<u32> {
        self.search_with_dists(query, k, beam_width)
            .into_iter()
            .map(|(id, _dist)| id)
            .collect()
    }
    /// Raw fixed-width adjacency row for `node_id`; may contain `PAD_U32`
    /// padding entries.
    pub(crate) fn get_neighbors(&self, node_id: u32) -> &[u32] {
        let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
        let start = offset as usize;
        let end = start + (self.max_degree * 4);
        let bytes = &self.storage[start..end];
        bytemuck::cast_slice(bytes)
    }
    /// Distance from `query` to the stored vector `idx`, evaluated on a
    /// slice of the storage buffer without copying.
    pub(crate) fn distance_to(&self, query: &[f32], idx: usize) -> f32 {
        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
        let start = offset as usize;
        let end = start + (self.dim * 4);
        let bytes = &self.storage[start..end];
        let vector: &[f32] = bytemuck::cast_slice(bytes);
        self.dist.eval(query, vector)
    }
    /// Owned copy of the stored vector at `idx`.
    pub fn get_vector(&self, idx: usize) -> Vec<f32> {
        let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * 4);
        let start = offset as usize;
        let end = start + (self.dim * 4);
        let bytes = &self.storage[start..end];
        let vector: &[f32] = bytemuck::cast_slice(bytes);
        vector.to_vec()
    }
}
fn calculate_medoid<D: Distance<f32> + Copy + Sync>(vectors: &[Vec<f32>], dist: D) -> usize {
let dim = vectors[0].len();
let mut centroid = vec![0.0f32; dim];
for v in vectors {
for (i, &val) in v.iter().enumerate() {
centroid[i] += val;
}
}
for val in &mut centroid {
*val /= vectors.len() as f32;
}
let (best_idx, _best_dist) = vectors
.par_iter()
.enumerate()
.map(|(idx, v)| (idx, dist.eval(¢roid, v)))
.reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
best_idx
}
/// Constructs the Vamana graph: random initialization, then `PASSES`
/// refinement rounds (greedy search + alpha pruning, with reverse edges
/// folded back in), and a final degree-capping prune.
fn build_vamana_graph<D: Distance<f32> + Copy + Sync>(
    vectors: &[Vec<f32>],
    max_degree: usize,
    build_beam_width: usize,
    alpha: f32,
    dist: D,
    medoid_id: u32,
) -> Vec<Vec<u32>> {
    let n = vectors.len();
    let mut graph = vec![Vec::<u32>::new(); n];
    // Seed every node with random neighbors (half of max_degree, clamped)
    // so the first refinement pass has something to search over.
    {
        let mut rng = thread_rng();
        for i in 0..n {
            let mut s = HashSet::new();
            let target = (max_degree / 2).max(2).min(n.saturating_sub(1));
            while s.len() < target {
                let nb = rng.gen_range(0..n);
                if nb != i {
                    s.insert(nb as u32);
                }
            }
            graph[i] = s.into_iter().collect();
        }
    }
    const PASSES: usize = 2;
    const EXTRA_SEEDS: usize = 2;
    let mut rng = thread_rng();
    for _pass in 0..PASSES {
        // Visit nodes in a fresh random order each pass.
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);
        let snapshot = &graph;
        // Phase 1 (parallel): per node, gather candidates from its current
        // neighbors plus greedy searches from the medoid and random seeds,
        // then alpha-prune. Rows are indexed by position in `order`.
        let new_graph: Vec<Vec<u32>> = order
            .par_iter()
            .map(|&u| {
                let mut candidates: Vec<(u32, f32)> =
                    Vec::with_capacity(build_beam_width * (2 + EXTRA_SEEDS));
                for &nb in &snapshot[u] {
                    let d = dist.eval(&vectors[u], &vectors[nb as usize]);
                    candidates.push((nb, d));
                }
                let mut seeds = Vec::with_capacity(1 + EXTRA_SEEDS);
                seeds.push(medoid_id as usize);
                let mut trng = thread_rng();
                for _ in 0..EXTRA_SEEDS {
                    seeds.push(trng.gen_range(0..n));
                }
                for start in seeds {
                    let mut part = greedy_search(
                        &vectors[u],
                        vectors,
                        snapshot,
                        start,
                        build_beam_width,
                        dist,
                    );
                    candidates.append(&mut part);
                }
                // Deduplicate by id, keeping the smaller distance.
                candidates.sort_by(|a, b| a.0.cmp(&b.0));
                candidates.dedup_by(|a, b| {
                    if a.0 == b.0 {
                        if a.1 < b.1 {
                            *b = *a;
                        }
                        true
                    } else {
                        false
                    }
                });
                prune_neighbors(u, &candidates, vectors, max_degree, alpha, dist)
            })
            .collect();
        // Map node id -> its position in `order` (and thus in `new_graph`).
        let mut pos_of = vec![0usize; n];
        for (pos, &u) in order.iter().enumerate() {
            pos_of[u] = pos;
        }
        // Phase 2 (parallel): merge each node's reverse edges (via a CSR of
        // incoming links) into its pool, then re-prune to max_degree.
        let (incoming_flat, incoming_off) = build_incoming_csr(&order, &new_graph, n);
        graph = (0..n)
            .into_par_iter()
            .map(|u| {
                let ng = &new_graph[pos_of[u]];
                let inc = &incoming_flat[incoming_off[u]..incoming_off[u + 1]];
                let mut pool_ids: Vec<u32> = Vec::with_capacity(ng.len() + inc.len());
                pool_ids.extend_from_slice(ng);
                pool_ids.extend_from_slice(inc);
                pool_ids.sort_unstable();
                pool_ids.dedup();
                let pool: Vec<(u32, f32)> = pool_ids
                    .into_iter()
                    .filter(|&id| id as usize != u)
                    .map(|id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                    .collect();
                prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
            })
            .collect();
    }
    // Final safety pass: cap any node that still exceeds max_degree.
    graph
        .into_par_iter()
        .enumerate()
        .map(|(u, neigh)| {
            if neigh.len() <= max_degree {
                return neigh;
            }
            let pool: Vec<(u32, f32)> = neigh
                .iter()
                .map(|&id| (id, dist.eval(&vectors[u], &vectors[id as usize])))
                .collect();
            prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
        })
        .collect()
}
/// Best-first search over an in-memory graph snapshot (used during
/// construction); returns up to `beam_width` candidates sorted by
/// ascending distance to `query`.
fn greedy_search<D: Distance<f32> + Copy>(
    query: &[f32],
    vectors: &[Vec<f32>],
    graph: &[Vec<u32>],
    start_id: usize,
    beam_width: usize,
    dist: D,
) -> Vec<(u32, f32)> {
    let mut visited = HashSet::new();
    // `frontier` is a min-heap of nodes to expand; `w` is a max-heap
    // holding the current best `beam_width` candidates.
    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
    let mut w: BinaryHeap<Candidate> = BinaryHeap::new();
    let start_dist = dist.eval(query, &vectors[start_id]);
    let start = Candidate {
        dist: start_dist,
        id: start_id as u32,
    };
    frontier.push(Reverse(start));
    w.push(start);
    visited.insert(start_id as u32);
    while let Some(Reverse(best)) = frontier.peek().copied() {
        // Stop once the frontier can no longer improve the working set.
        if w.len() >= beam_width {
            if let Some(worst) = w.peek() {
                if best.dist >= worst.dist {
                    break;
                }
            }
        }
        let Reverse(cur) = frontier.pop().unwrap();
        for &nb in &graph[cur.id as usize] {
            if !visited.insert(nb) {
                continue;
            }
            let d = dist.eval(query, &vectors[nb as usize]);
            let cand = Candidate { dist: d, id: nb };
            if w.len() < beam_width {
                w.push(cand);
                frontier.push(Reverse(cand));
            } else if d < w.peek().unwrap().dist {
                // Evict the current worst working candidate.
                w.pop();
                w.push(cand);
                frontier.push(Reverse(cand));
            }
        }
    }
    // Drain the working set in ascending-distance order.
    let mut v = w.into_vec();
    v.sort_by(|a, b| a.dist.partial_cmp(&b.dist).unwrap());
    v.into_iter().map(|c| (c.id, c.dist)).collect()
}
/// Alpha-pruning of a candidate pool down to at most `max_degree` ids.
///
/// Pass 1 applies the diversity rule: a candidate is kept only if no
/// already-selected neighbor is closer to it than `alpha * its distance to
/// the node`. Pass 2 backfills with the nearest remaining candidates if the
/// degree budget was not exhausted.
fn prune_neighbors<D: Distance<f32> + Copy>(
    node_id: usize,
    candidates: &[(u32, f32)],
    vectors: &[Vec<f32>],
    max_degree: usize,
    alpha: f32,
    dist: D,
) -> Vec<u32> {
    if candidates.is_empty() {
        return Vec::new();
    }
    // Work on a copy sorted by ascending distance to `node_id`.
    let mut by_dist = candidates.to_vec();
    by_dist.sort_by(|x, y| x.1.partial_cmp(&y.1).unwrap());
    let mut selected = Vec::<u32>::new();
    // Pass 1: alpha diversity rule.
    for &(id, d_node) in &by_dist {
        if id as usize == node_id {
            continue;
        }
        let dominated = selected
            .iter()
            .any(|&s| dist.eval(&vectors[id as usize], &vectors[s as usize]) < alpha * d_node);
        if !dominated {
            selected.push(id);
            if selected.len() == max_degree {
                return selected;
            }
        }
    }
    // Pass 2: backfill with the closest candidates not yet selected.
    for &(id, _) in &by_dist {
        if selected.len() >= max_degree {
            break;
        }
        if id as usize != node_id && !selected.contains(&id) {
            selected.push(id);
        }
    }
    selected
}
/// Builds a CSR view of incoming edges: `new_graph[pos]` holds the
/// out-neighbors of node `order[pos]`, and the result maps each target node
/// `v` to the slice `flat[off[v]..off[v + 1]]` of source node ids.
fn build_incoming_csr(order: &[usize], new_graph: &[Vec<u32>], n: usize) -> (Vec<u32>, Vec<usize>) {
    // Count the in-degree of each target directly into the offset array.
    let mut off = vec![0usize; n + 1];
    for pos in 0..order.len() {
        for &v in &new_graph[pos] {
            off[v as usize + 1] += 1;
        }
    }
    // Prefix-sum the counts into exclusive offsets.
    for i in 0..n {
        off[i + 1] += off[i];
    }
    // Scatter source ids into their buckets, tracking write cursors.
    let mut cursor = off.clone();
    let mut flat = vec![0u32; off[n]];
    for (pos, &u) in order.iter().enumerate() {
        for &v in &new_graph[pos] {
            flat[cursor[v as usize]] = u as u32;
            cursor[v as usize] += 1;
        }
    }
    (flat, off)
}
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;
    /// Plain Euclidean distance, used to verify results independently of
    /// the index's own distance functor.
    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }
    // Build + search a tiny 2-D index; the top hit must be near the query.
    #[test]
    fn test_small_index_l2() {
        let path = "test_small_l2.db";
        let _ = fs::remove_file(path);
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];
        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
        let _ = fs::remove_file(path);
    }
    // Cosine metric: the top hit should be nearly parallel to the query.
    #[test]
    fn test_cosine() {
        let path = "test_cosine.db";
        let _ = fs::remove_file(path);
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];
        let index =
            DiskANN::<DistCosine>::build_index_default(&vectors, DistCosine {}, path).unwrap();
        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);
        let v = index.get_vector(nns[0] as usize);
        // Verify with an independently computed cosine similarity.
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
        let _ = fs::remove_file(path);
    }
    // Close and reopen the index file; searches must still work.
    #[test]
    fn test_persistence_and_open() {
        let path = "test_persist.db";
        let _ = fs::remove_file(path);
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];
        {
            // Drop the builder so only the file remains.
            let _idx = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        }
        let idx2 = DiskANN::<DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);
        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
        let _ = fs::remove_file(path);
    }
    // A 5x5 grid: every point should find near neighbors via the graph.
    #[test]
    fn test_grid_connectivity() {
        let path = "test_grid.db";
        let _ = fs::remove_file(path);
        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }
        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
            },
        )
        .unwrap();
        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            // If the exact point was missed, the best hit must still be close.
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
        let _ = fs::remove_file(path);
    }
    // Random data: results must come back sorted by ascending distance.
    #[test]
    fn test_medium_random() {
        let path = "test_medium.db";
        let _ = fs::remove_file(path);
        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();
        let index = DiskANN::<DistL2>::build_index_with_params(
            &vectors,
            DistL2 {},
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
            },
        )
        .unwrap();
        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(dists, sorted);
        let _ = fs::remove_file(path);
    }
    // to_bytes -> from_bytes must reproduce identical search results.
    #[test]
    fn test_to_bytes_from_bytes_round_trip() {
        let path = "test_bytes_rt.db";
        let _ = fs::remove_file(path);
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];
        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let bytes = index.to_bytes();
        let index2 = DiskANN::<DistL2>::from_bytes(bytes, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors, 5);
        assert_eq!(index2.dim, 2);
        let q = vec![0.9, 0.9];
        let res1 = index.search(&q, 3, 8);
        let res2 = index2.search(&q, 3, 8);
        assert_eq!(res1, res2);
        let _ = fs::remove_file(path);
    }
    // Same round trip through the Arc<[u8]>-backed constructor.
    #[test]
    fn test_from_shared_bytes() {
        let path = "test_shared_bytes.db";
        let _ = fs::remove_file(path);
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];
        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let bytes = index.to_bytes();
        let shared: std::sync::Arc<[u8]> = bytes.into();
        let index2 = DiskANN::<DistL2>::from_shared_bytes(shared, DistL2 {}).unwrap();
        assert_eq!(index2.num_vectors, 4);
        assert_eq!(index2.dim, 2);
        let q = vec![0.9, 0.9];
        let res = index2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
        let _ = fs::remove_file(path);
    }
    // Candidate must order by distance in both heap orientations.
    #[test]
    fn test_candidate_ordering() {
        use std::cmp::Reverse;
        use std::collections::BinaryHeap;
        let a = Candidate { dist: 1.0, id: 0 };
        let b = Candidate { dist: 2.0, id: 1 };
        let c = Candidate { dist: 0.5, id: 2 };
        assert!(a < b);
        assert!(c < a);
        // Min-heap (via Reverse) pops in ascending distance order.
        let mut min_heap: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
        min_heap.push(Reverse(a));
        min_heap.push(Reverse(b));
        min_heap.push(Reverse(c));
        assert_eq!(min_heap.pop().unwrap().0.id, 2);
        assert_eq!(min_heap.pop().unwrap().0.id, 0);
        assert_eq!(min_heap.pop().unwrap().0.id, 1);
        // Max-heap keeps the farthest candidate on top.
        let mut max_heap: BinaryHeap<Candidate> = BinaryHeap::new();
        max_heap.push(a);
        max_heap.push(b);
        max_heap.push(c);
        assert_eq!(max_heap.peek().unwrap().id, 1);
    }
    // Unfiltered beam search on a tiny hand-built graph.
    #[test]
    fn test_beam_search_small_graph() {
        let positions: Vec<[f32; 2]> = vec![
            [0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 1.0],
        ];
        let neighbors: Vec<Vec<u32>> = vec![
            vec![1, 3], vec![0, 2], vec![1, 4], vec![0, 4], vec![2, 3],
        ];
        let query = [2.1f32, 0.9];
        let results = beam_search(
            &[0],
            5,
            3,
            |id| {
                let p = &positions[id as usize];
                ((query[0] - p[0]).powi(2) + (query[1] - p[1]).powi(2)).sqrt()
            },
            |id| neighbors[id as usize].clone(),
            |_| true,
            BeamSearchConfig::default(),
        );
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 4);
        assert_eq!(results[1].0, 2);
        assert!(results[0].1 <= results[1].1);
        assert!(results[1].1 <= results[2].1);
    }
    // Filtered mode: only ids passing the predicate may be returned.
    #[test]
    fn test_beam_search_with_filter() {
        let positions: Vec<[f32; 2]> = vec![
            [0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 1.0],
        ];
        let neighbors: Vec<Vec<u32>> = vec![
            vec![1, 3], vec![0, 2], vec![1, 4], vec![0, 4], vec![2, 3],
        ];
        let query = [2.1f32, 0.9];
        let results = beam_search(
            &[0],
            5,
            3,
            |id| {
                let p = &positions[id as usize];
                ((query[0] - p[0]).powi(2) + (query[1] - p[1]).powi(2)).sqrt()
            },
            |id| neighbors[id as usize].clone(),
            |id| id % 2 == 1,
            BeamSearchConfig {
                expanded_beam: Some(10),
                max_iterations: Some(20),
                early_term_factor: Some(1.5),
            },
        );
        for (id, _) in &results {
            assert!(id % 2 == 1, "Expected only odd IDs, got {}", id);
        }
        let ids: HashSet<u32> = results.iter().map(|(id, _)| *id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&3));
    }
    // Alpha pruning: node 2 sits just behind node 1, so with alpha = 1.0 the
    // diverse picks are 1 and 3.
    #[test]
    fn test_prune_neighbors_alpha() {
        let vectors = vec![
            vec![0.0, 0.0], vec![1.0, 0.0], vec![1.2, 0.0], vec![0.0, 2.0],
        ];
        let candidates: Vec<(u32, f32)> = vec![
            (1, DistL2 {}.eval(&vectors[0], &vectors[1])),
            (2, DistL2 {}.eval(&vectors[0], &vectors[2])),
            (3, DistL2 {}.eval(&vectors[0], &vectors[3])),
        ];
        let pruned = prune_neighbors(0, &candidates, &vectors, 3, 1.0, DistL2 {});
        assert!(pruned.contains(&1));
        assert!(pruned.contains(&3));
    }
    // The degree budget must be honored exactly for various max_degree.
    #[test]
    fn test_prune_neighbors_max_degree() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![2.0, 0.0],
            vec![0.0, 2.0],
        ];
        let candidates: Vec<(u32, f32)> = (1..6)
            .map(|i| (i as u32, DistL2 {}.eval(&vectors[0], &vectors[i])))
            .collect();
        let pruned = prune_neighbors(0, &candidates, &vectors, 2, 1.2, DistL2 {});
        assert_eq!(pruned.len(), 2);
        assert!(!pruned.is_empty());
        let pruned = prune_neighbors(0, &candidates, &vectors, 5, 1.2, DistL2 {});
        assert_eq!(pruned.len(), 5);
        let pruned = prune_neighbors(0, &candidates, &vectors, 1, 1.2, DistL2 {});
        assert_eq!(pruned.len(), 1);
    }
    // The serialized buffer must begin with the magic word and version.
    #[test]
    fn test_core_magic_number_in_bytes() {
        let path = "test_magic.db";
        let _ = fs::remove_file(path);
        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let index = DiskANN::<DistL2>::build_index_default(&vectors, DistL2 {}, path).unwrap();
        let bytes = index.to_bytes();
        let magic = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert_eq!(magic, CORE_MAGIC, "Expected magic 0x{:08X}, got 0x{:08X}", CORE_MAGIC, magic);
        let version = u32::from_le_bytes(bytes[4..8].try_into().unwrap());
        assert_eq!(version, CORE_FORMAT_VERSION);
        let _ = fs::remove_file(path);
    }
}