use anndists::prelude::Distance;
use memmap2::Mmap;
use rand::{prelude::*, thread_rng};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::cmp::{Ordering, Reverse};
use std::collections::{BinaryHeap, HashSet};
use std::fs::OpenOptions;
use std::io::{Read, Seek, SeekFrom, Write};
use std::marker::PhantomData;
use thiserror::Error;
/// Sentinel stored in unused adjacency slots; skipped when reading neighbors.
const PAD_U32: u32 = !0;
/// Default maximum out-degree of each graph node.
pub const DISKANN_DEFAULT_MAX_DEGREE: usize = 64;
/// Default beam width of the greedy searches run during construction.
pub const DISKANN_DEFAULT_BUILD_BEAM: usize = 128;
/// Default occlusion factor used by robust pruning.
pub const DISKANN_DEFAULT_ALPHA: f32 = 1.2;
/// Default number of refinement passes over the dataset.
pub const DISKANN_DEFAULT_PASSES: usize = 1;
/// Default number of random search seeds used in addition to the medoid.
pub const DISKANN_DEFAULT_EXTRA_SEEDS: usize = 1;
/// Intermediate adjacency lists may grow to this multiple of `max_degree`
/// before they are re-pruned during construction.
const GRAPH_SLACK_FACTOR: f32 = 1.3;
/// Number of nodes refined in parallel against one graph snapshot.
const MICRO_BATCH_CHUNK_SIZE: usize = 256;

/// Tunable construction parameters for a DiskANN index.
#[derive(Clone, Copy, Debug)]
pub struct DiskAnnParams {
    /// Maximum number of out-neighbors kept per node.
    pub max_degree: usize,
    /// Beam width of the construction-time greedy searches.
    pub build_beam_width: usize,
    /// Robust-prune occlusion factor.
    pub alpha: f32,
    /// Number of refinement passes (at least one is always run).
    pub passes: usize,
    /// Random seeds searched from in addition to the medoid.
    pub extra_seeds: usize,
}

impl Default for DiskAnnParams {
    /// Parameters equal to the `DISKANN_DEFAULT_*` constants.
    fn default() -> Self {
        let (max_degree, build_beam_width) =
            (DISKANN_DEFAULT_MAX_DEGREE, DISKANN_DEFAULT_BUILD_BEAM);
        DiskAnnParams {
            max_degree,
            build_beam_width,
            alpha: DISKANN_DEFAULT_ALPHA,
            passes: DISKANN_DEFAULT_PASSES,
            extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
        }
    }
}
/// Errors returned when building, writing, or opening a `DiskANN` index.
#[derive(Debug, Error)]
pub enum DiskAnnError {
/// File-system or memory-mapping failure (converted from `std::io::Error`).
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// Failure to (de)serialize the metadata header with `bincode`.
#[error("Serialization error: {0}")]
Bincode(#[from] bincode::Error),
/// Invalid input or inconsistent index contents, with a human-readable message.
#[error("Index error: {0}")]
IndexError(String),
}
/// Header record stored with `bincode` at the start of the index file,
/// right after a little-endian `u64` that holds its encoded byte length.
///
/// NOTE: bincode encodes struct fields in declaration order without names,
/// so the field order below is part of the on-disk format — do not reorder.
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
/// Number of components per vector.
dim: usize,
/// Number of vectors (graph nodes) stored.
num_vectors: usize,
/// Fixed length of every stored adjacency list, including padding.
max_degree: usize,
/// Node every search starts from.
medoid_id: u32,
/// Byte offset of the row-major vector block.
vectors_offset: u64,
/// Byte offset of the adjacency block.
adjacency_offset: u64,
/// `size_of::<T>()` at build time; compared against `T` when opening.
elem_size: u8,
/// `std::any::type_name` of the distance type used at build time.
distance_name: String,
}
/// A node id paired with its distance to the current query.
///
/// Ordered by distance first (IEEE-754 `total_cmp`, so NaN and signed zeros
/// have a defined position) and by id second, giving a deterministic total
/// order. Equality is defined through the same order (bitwise-equal distance
/// and equal id), so `PartialEq`, `Eq`, `PartialOrd` and `Ord` all agree.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    dist: f32,
    id: u32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // `total_cmp` reports `Equal` exactly when the bit patterns match.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: the order is total, so simply delegate to `Ord`.
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
/// Contiguous row-major copy of `n` vectors with `dim` elements each.
#[derive(Clone, Debug)]
struct FlatVectors<T> {
    data: Vec<T>,
    dim: usize,
    n: usize,
}

impl<T: Copy> FlatVectors<T> {
    /// Copies `vectors` into a single contiguous buffer.
    ///
    /// # Errors
    /// `IndexError` if `vectors` is empty, if the vectors have zero
    /// components, or if any vector's length differs from the first one's.
    fn from_vecs(vectors: &[Vec<T>]) -> Result<Self, DiskAnnError> {
        let first = match vectors.first() {
            Some(v) => v,
            None => {
                return Err(DiskAnnError::IndexError("No vectors provided".to_string()));
            }
        };
        let dim = first.len();
        // Zero-length vectors make every distance degenerate; reject them early.
        if dim == 0 {
            return Err(DiskAnnError::IndexError(
                "Vectors must have at least one dimension".to_string(),
            ));
        }
        if let Some((i, v)) = vectors.iter().enumerate().find(|(_, v)| v.len() != dim) {
            return Err(DiskAnnError::IndexError(format!(
                "Vector {} has dimension {} but expected {}",
                i,
                v.len(),
                dim
            )));
        }
        Ok(Self {
            // `concat` allocates the exact total length once.
            data: vectors.concat(),
            dim,
            n: vectors.len(),
        })
    }

    /// Borrows row `idx`.
    ///
    /// # Panics
    /// If `idx >= n`.
    #[inline]
    fn row(&self, idx: usize) -> &[T] {
        let start = idx * self.dim;
        &self.data[start..start + self.dim]
    }
}
#[derive(Default, Debug)]
struct OrderedBeam {
items: Vec<Candidate>,
}
impl OrderedBeam {
#[inline]
fn clear(&mut self) {
self.items.clear();
}
#[inline]
fn len(&self) -> usize {
self.items.len()
}
#[inline]
fn is_empty(&self) -> bool {
self.items.is_empty()
}
#[inline]
fn best(&self) -> Option<Candidate> {
self.items.last().copied()
}
#[inline]
fn worst(&self) -> Option<Candidate> {
self.items.first().copied()
}
#[inline]
fn pop_best(&mut self) -> Option<Candidate> {
self.items.pop()
}
#[inline]
fn reserve(&mut self, cap: usize) {
if self.items.capacity() < cap {
self.items.reserve(cap - self.items.capacity());
}
}
#[inline]
fn insert_unbounded(&mut self, cand: Candidate) {
let pos = self.items.partition_point(|x| {
x.dist > cand.dist || (x.dist.to_bits() == cand.dist.to_bits() && x.id > cand.id)
});
self.items.insert(pos, cand);
}
#[inline]
fn insert_capped(&mut self, cand: Candidate, cap: usize) {
if cap == 0 {
return;
}
if self.items.len() < cap {
self.insert_unbounded(cand);
return;
}
let worst = self.items[0];
if cand.dist >= worst.dist {
return;
}
self.insert_unbounded(cand);
if self.items.len() > cap {
self.items.remove(0);
}
}
}
/// Reusable state for construction-time greedy searches (one search at a time).
///
/// Node `i` counts as visited in the current search exactly when
/// `marks[i] == epoch`; advancing `epoch` forgets every mark at once, and the
/// array is only zeroed when the counter wraps around.
#[derive(Debug)]
struct BuildScratch {
    /// Per-node visit stamps, compared against `epoch`.
    marks: Vec<u32>,
    /// Stamp of the current search; never 0.
    epoch: u32,
    /// Ids evaluated during the current search, in visit order.
    visited_ids: Vec<u32>,
    /// Distances matching `visited_ids` index for index.
    visited_dists: Vec<f32>,
    /// Candidates waiting to be expanded.
    frontier: OrderedBeam,
    /// Best candidates found so far.
    work: OrderedBeam,
    /// Search start nodes for the node currently being refined.
    seeds: Vec<usize>,
    /// `(id, distance)` pool handed to pruning.
    candidates: Vec<(u32, f32)>,
}

impl BuildScratch {
    /// Allocates scratch for a graph of `n` nodes, pre-sizing the buffers from
    /// the beam width, the degree bound and the number of extra seeds.
    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
        let make_beam = || {
            let mut beam = OrderedBeam::default();
            beam.reserve(2 * beam_width);
            beam
        };
        let visited_capacity = 4 * beam_width;
        Self {
            marks: vec![0u32; n],
            epoch: 1,
            visited_ids: Vec::with_capacity(visited_capacity),
            visited_dists: Vec::with_capacity(visited_capacity),
            frontier: make_beam(),
            work: make_beam(),
            seeds: Vec::with_capacity(extra_seeds + 1),
            candidates: Vec::with_capacity(beam_width * (extra_seeds + 4) + 2 * max_degree),
        }
    }

    /// Starts a fresh search: invalidates all marks and empties the visited
    /// lists and both beams.
    #[inline]
    fn reset_search(&mut self) {
        self.epoch = match self.epoch.checked_add(1) {
            Some(next) => next,
            None => {
                // The stamp wrapped; stale marks could now collide, so clear them.
                self.marks.fill(0);
                1
            }
        };
        self.visited_ids.clear();
        self.visited_dists.clear();
        self.frontier.clear();
        self.work.clear();
    }

    /// True if node `idx` was already visited in the current search.
    #[inline]
    fn is_marked(&self, idx: usize) -> bool {
        let stamp = self.marks[idx];
        stamp == self.epoch
    }

    /// Marks node `idx` visited and records it with its distance.
    #[inline]
    fn mark_with_dist(&mut self, idx: usize, dist: f32) {
        self.marks[idx] = self.epoch;
        self.visited_ids.push(idx as u32);
        self.visited_dists.push(dist);
    }
}
/// Per-worker scratch used while refining one node during graph construction;
/// it currently wraps a single `BuildScratch`.
#[derive(Debug)]
struct IncrementalInsertScratch {
    build: BuildScratch,
}

impl IncrementalInsertScratch {
    /// Creates the scratch for a graph of `n` nodes.
    fn new(n: usize, beam_width: usize, max_degree: usize, extra_seeds: usize) -> Self {
        let build = BuildScratch::new(n, beam_width, max_degree, extra_seeds);
        IncrementalInsertScratch { build }
    }
}
/// Read-only, memory-mapped DiskANN index over vectors with element type `T`,
/// searched with distance `D`.
///
/// `mmap` covers the whole index file: a metadata header, the row-major
/// vectors starting at `vectors_offset`, and fixed-width (`max_degree` x `u32`)
/// adjacency lists starting at `adjacency_offset`, padded with `PAD_U32`.
pub struct DiskANN<T, D>
where
T: bytemuck::Pod + Copy + Send + Sync + 'static,
D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
/// Number of components per vector.
pub dim: usize,
/// Number of indexed vectors; ids are row indices `0..num_vectors`.
pub num_vectors: usize,
/// Width of every stored adjacency list, including padding.
pub max_degree: usize,
/// `std::any::type_name` of the distance type recorded at build time.
pub distance_name: String,
/// Node every search starts from.
medoid_id: u32,
/// Byte offset of the vector block within `mmap`.
vectors_offset: u64,
/// Byte offset of the adjacency block within `mmap`.
adjacency_offset: u64,
/// Memory map of the entire index file.
mmap: Mmap,
/// Distance used by searches.
dist: D,
_phantom: PhantomData<T>,
}
impl<T, D> DiskANN<T, D>
where
T: bytemuck::Pod + Copy + Send + Sync + 'static,
D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
pub fn build_index_default(
vectors: &[Vec<T>],
dist: D,
file_path: &str,
) -> Result<Self, DiskAnnError> {
Self::build_index(
vectors,
DISKANN_DEFAULT_MAX_DEGREE,
DISKANN_DEFAULT_BUILD_BEAM,
DISKANN_DEFAULT_ALPHA,
DISKANN_DEFAULT_PASSES,
DISKANN_DEFAULT_EXTRA_SEEDS,
dist,
file_path,
)
}
pub fn build_index_with_params(
vectors: &[Vec<T>],
dist: D,
file_path: &str,
p: DiskAnnParams,
) -> Result<Self, DiskAnnError> {
Self::build_index(
vectors,
p.max_degree,
p.build_beam_width,
p.alpha,
p.passes,
p.extra_seeds,
dist,
file_path,
)
}
pub fn open_index_with(path: &str, dist: D) -> Result<Self, DiskAnnError> {
let mut file = OpenOptions::new().read(true).write(false).open(path)?;
let mut buf8 = [0u8; 8];
file.seek(SeekFrom::Start(0))?;
file.read_exact(&mut buf8)?;
let md_len = u64::from_le_bytes(buf8);
let mut md_bytes = vec![0u8; md_len as usize];
file.read_exact(&mut md_bytes)?;
let metadata: Metadata = bincode::deserialize(&md_bytes)?;
let mmap = unsafe { memmap2::Mmap::map(&file)? };
let want = std::mem::size_of::<T>() as u8;
if metadata.elem_size != want {
return Err(DiskAnnError::IndexError(format!(
"element size mismatch: file has {}B, T is {}B",
metadata.elem_size, want
)));
}
let expected = std::any::type_name::<D>();
if metadata.distance_name != expected {
eprintln!(
"Warning: index recorded distance `{}` but you opened with `{}`",
metadata.distance_name, expected
);
}
Ok(Self {
dim: metadata.dim,
num_vectors: metadata.num_vectors,
max_degree: metadata.max_degree,
distance_name: metadata.distance_name,
medoid_id: metadata.medoid_id,
vectors_offset: metadata.vectors_offset,
adjacency_offset: metadata.adjacency_offset,
mmap,
dist,
_phantom: PhantomData,
})
}
}
impl<T, D> DiskANN<T, D>
where
    T: bytemuck::Pod + Copy + Send + Sync + 'static,
    D: Distance<T> + Default + Send + Sync + Copy + Clone + 'static,
{
    /// Same as `build_index_default`, with the metric built by `D::default()`.
    pub fn build_index_default_metric(
        vectors: &[Vec<T>],
        file_path: &str,
    ) -> Result<Self, DiskAnnError> {
        let metric = D::default();
        Self::build_index_default(vectors, metric, file_path)
    }

    /// Same as `open_index_with`, with the metric built by `D::default()`.
    pub fn open_index_default_metric(path: &str) -> Result<Self, DiskAnnError> {
        let metric: D = Default::default();
        Self::open_index_with(path, metric)
    }
}
impl<T, D> DiskANN<T, D>
where
T: bytemuck::Pod + Copy + Send + Sync + 'static,
D: Distance<T> + Send + Sync + Copy + Clone + 'static,
{
pub fn build_index(
vectors: &[Vec<T>],
max_degree: usize,
build_beam_width: usize,
alpha: f32,
passes: usize,
extra_seeds: usize,
dist: D,
file_path: &str,
) -> Result<Self, DiskAnnError> {
let flat = FlatVectors::from_vecs(vectors)?;
let num_vectors = flat.n;
let dim = flat.dim;
let mut file = OpenOptions::new()
.create(true)
.write(true)
.read(true)
.truncate(true)
.open(file_path)?;
let vectors_offset = 1024 * 1024;
assert_eq!(
(vectors_offset as usize) % std::mem::align_of::<T>(),
0,
"vectors_offset must be aligned for T"
);
let elem_sz = std::mem::size_of::<T>() as u64;
let total_vector_bytes = (num_vectors as u64) * (dim as u64) * elem_sz;
file.seek(SeekFrom::Start(vectors_offset as u64))?;
file.write_all(bytemuck::cast_slice::<T, u8>(&flat.data))?;
let medoid_id = calculate_medoid(&flat, dist);
let adjacency_offset = vectors_offset as u64 + total_vector_bytes;
let graph = build_vamana_graph(
&flat,
max_degree,
build_beam_width,
alpha,
passes,
extra_seeds,
dist,
medoid_id as u32,
);
file.seek(SeekFrom::Start(adjacency_offset))?;
for neighbors in &graph {
let mut padded = neighbors.clone();
padded.resize(max_degree, PAD_U32);
let bytes = bytemuck::cast_slice::<u32, u8>(&padded);
file.write_all(bytes)?;
}
let metadata = Metadata {
dim,
num_vectors,
max_degree,
medoid_id: medoid_id as u32,
vectors_offset: vectors_offset as u64,
adjacency_offset,
elem_size: std::mem::size_of::<T>() as u8,
distance_name: std::any::type_name::<D>().to_string(),
};
let md_bytes = bincode::serialize(&metadata)?;
file.seek(SeekFrom::Start(0))?;
let md_len = md_bytes.len() as u64;
file.write_all(&md_len.to_le_bytes())?;
file.write_all(&md_bytes)?;
file.sync_all()?;
let mmap = unsafe { memmap2::Mmap::map(&file)? };
Ok(Self {
dim,
num_vectors,
max_degree,
distance_name: metadata.distance_name,
medoid_id: metadata.medoid_id,
vectors_offset: metadata.vectors_offset,
adjacency_offset: metadata.adjacency_offset,
mmap,
dist,
_phantom: PhantomData,
})
}
pub fn search_with_dists(&self, query: &[T], k: usize, beam_width: usize) -> Vec<(u32, f32)> {
assert_eq!(
query.len(),
self.dim,
"Query dim {} != index dim {}",
query.len(),
self.dim
);
let mut visited = HashSet::new();
let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
let mut w: BinaryHeap<Candidate> = BinaryHeap::new();
let start_dist = self.distance_to(query, self.medoid_id as usize);
let start = Candidate {
dist: start_dist,
id: self.medoid_id,
};
frontier.push(Reverse(start));
w.push(start);
visited.insert(self.medoid_id);
while let Some(Reverse(best)) = frontier.peek().copied() {
if w.len() >= beam_width {
if let Some(worst) = w.peek() {
if best.dist >= worst.dist {
break;
}
}
}
let Reverse(current) = frontier.pop().unwrap();
for &nb in self.get_neighbors(current.id) {
if nb == PAD_U32 {
continue;
}
if !visited.insert(nb) {
continue;
}
let d = self.distance_to(query, nb as usize);
let cand = Candidate { dist: d, id: nb };
if w.len() < beam_width {
w.push(cand);
frontier.push(Reverse(cand));
} else if d < w.peek().unwrap().dist {
w.pop();
w.push(cand);
frontier.push(Reverse(cand));
}
}
}
let mut results: Vec<_> = w.into_vec();
results.sort_by(|a, b| a.dist.total_cmp(&b.dist));
results.truncate(k);
results.into_iter().map(|c| (c.id, c.dist)).collect()
}
pub fn search(&self, query: &[T], k: usize, beam_width: usize) -> Vec<u32> {
self.search_with_dists(query, k, beam_width)
.into_iter()
.map(|(id, _dist)| id)
.collect()
}
fn get_neighbors(&self, node_id: u32) -> &[u32] {
let offset = self.adjacency_offset + (node_id as u64 * self.max_degree as u64 * 4);
let start = offset as usize;
let end = start + (self.max_degree * 4);
let bytes = &self.mmap[start..end];
bytemuck::cast_slice(bytes)
}
fn distance_to(&self, query: &[T], idx: usize) -> f32 {
let elem_sz = std::mem::size_of::<T>();
let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
let start = offset as usize;
let end = start + (self.dim * elem_sz);
let bytes = &self.mmap[start..end];
let vector: &[T] = bytemuck::cast_slice(bytes);
self.dist.eval(query, vector)
}
pub fn get_vector(&self, idx: usize) -> Vec<T> {
let elem_sz = std::mem::size_of::<T>();
let offset = self.vectors_offset + (idx as u64 * self.dim as u64 * elem_sz as u64);
let start = offset as usize;
let end = start + (self.dim * elem_sz);
let bytes = &self.mmap[start..end];
let vector: &[T] = bytemuck::cast_slice(bytes);
vector.to_vec()
}
}
fn calculate_medoid<T, D>(vectors: &FlatVectors<T>, dist: D) -> usize
where
T: bytemuck::Pod + Copy + Send + Sync,
D: Distance<T> + Copy + Sync,
{
let n = vectors.n;
let k = 8.min(n);
let mut rng = thread_rng();
let pivots: Vec<usize> = (0..k).map(|_| rng.gen_range(0..n)).collect();
let (best_idx, _best_score) = (0..n)
.into_par_iter()
.map(|i| {
let vi = vectors.row(i);
let score: f32 = pivots.iter().map(|&p| dist.eval(vi, vectors.row(p))).sum();
(i, score)
})
.reduce(|| (0usize, f32::MAX), |a, b| if a.1 <= b.1 { a } else { b });
best_idx
}
/// Collapses `cands` to one entry per id, keeping that id's smallest distance
/// (compared with `f32::total_cmp`). On return the vector is sorted by id.
fn dedup_keep_best_by_id_in_place(cands: &mut Vec<(u32, f32)>) {
    if cands.is_empty() {
        return;
    }
    // Within each id run the best distance comes first. Entries comparing
    // equal here are bitwise identical, so an unstable sort is safe.
    cands.sort_unstable_by(|lhs, rhs| lhs.0.cmp(&rhs.0).then(lhs.1.total_cmp(&rhs.1)));
    // `dedup_by_key` keeps the first element of each run of equal ids.
    cands.dedup_by_key(|entry| entry.0);
}
/// Merges one micro-batch of refined adjacency lists into `graph`.
///
/// First, for each `u = chunk_nodes[i]`, `graph[u]` is replaced by
/// `chunk_pruned[i]`, and `u` is recorded as an incoming (reverse) edge of
/// each of its new neighbors; self-loops are skipped.
///
/// Then every touched node's list is rebuilt in parallel as the union of its
/// current list and its incoming edges: `PAD_U32` and self ids are removed,
/// and the result is sorted and deduplicated. Lists of at most `slack_limit`
/// ids are kept as-is; longer ones are re-pruned to `max_degree` with
/// `prune_neighbors`.
///
/// `merge` provides reusable buffers. Incoming edges of node `u` are grouped
/// in `merge.incoming_flat[merge.incoming_offsets[u]..merge.incoming_offsets[u + 1]]`.
fn merge_chunk_updates_into_graph_reuse<T, D>(
graph: &mut [Vec<u32>],
chunk_nodes: &[usize],
chunk_pruned: &[Vec<u32>],
vectors: &FlatVectors<T>,
max_degree: usize,
slack_limit: usize,
alpha: f32,
dist: D,
merge: &mut MergeScratch,
) where
T: bytemuck::Pod + Copy + Send + Sync,
D: Distance<T> + Copy + Sync,
{
merge.reset();
// Every node refined in this batch is affected, even with no incoming edges.
for &u in chunk_nodes {
merge.mark_affected(u);
}
// Pass 1: count incoming edges per destination (self-loops skipped).
let mut total_incoming = 0usize;
for (local_idx, &u) in chunk_nodes.iter().enumerate() {
for &dst in &chunk_pruned[local_idx] {
let dst_usize = dst as usize;
if dst_usize == u {
continue;
}
merge.mark_affected(dst_usize);
merge.incoming_counts[dst_usize] += 1;
total_incoming += 1;
}
}
// Prefix-sum the counts into per-node [start, end) ranges. Nodes are visited
// in sorted order so their ranges are laid out contiguously.
merge.affected_nodes.sort_unstable();
let mut running = 0usize;
for &u in &merge.affected_nodes {
merge.incoming_offsets[u] = running;
running += merge.incoming_counts[u];
merge.incoming_offsets[u + 1] = running;
}
merge.incoming_flat.resize(total_incoming, PAD_U32);
// Per-destination write cursors start at each range's beginning.
for &u in &merge.affected_nodes {
merge.incoming_write[u] = merge.incoming_offsets[u];
}
// Pass 2: scatter each source id into its destination's range.
for (local_idx, &u) in chunk_nodes.iter().enumerate() {
for &dst in &chunk_pruned[local_idx] {
let dst_usize = dst as usize;
if dst_usize == u {
continue;
}
let pos = merge.incoming_write[dst_usize];
merge.incoming_flat[pos] = u as u32;
merge.incoming_write[dst_usize] += 1;
}
}
// Install the refined forward lists before unioning with incoming edges.
for (local_idx, &u) in chunk_nodes.iter().enumerate() {
graph[u] = chunk_pruned[local_idx].clone();
}
// Owned copy of the affected-node list that drives the parallel rebuild.
let affected = merge.affected_nodes.clone();
let updated_pairs: Vec<(usize, Vec<u32>)> = affected
.into_par_iter()
.map(|u| {
let start = merge.incoming_offsets[u];
let end = merge.incoming_offsets[u + 1];
let mut ids: Vec<u32> = Vec::with_capacity(graph[u].len() + (end - start));
ids.extend_from_slice(&graph[u]);
if start < end {
ids.extend_from_slice(&merge.incoming_flat[start..end]);
}
ids.retain(|&id| id != PAD_U32 && id as usize != u);
ids.sort_unstable();
ids.dedup();
if ids.is_empty() {
return (u, Vec::new());
}
// Lists within the slack budget are kept unpruned for now.
if ids.len() <= slack_limit {
return (u, ids);
}
let mut pool = Vec::<(u32, f32)>::with_capacity(ids.len());
for id in ids {
let d = dist.eval(vectors.row(u), vectors.row(id as usize));
pool.push((id, d));
}
let pruned = prune_neighbors(u, &pool, vectors, max_degree, alpha, dist);
(u, pruned)
})
.collect();
for (u, neigh) in updated_pairs {
graph[u] = neigh;
}
// Reset only the counters touched this batch (O(affected), not O(n)).
for &u in &merge.affected_nodes {
merge.incoming_counts[u] = 0;
merge.incoming_offsets[u + 1] = 0;
}
}
/// Reusable buffers for merging a micro-batch of adjacency updates.
///
/// Holds per-node incoming-edge counts, range offsets and write cursors, a
/// flat incoming-edge array, and an epoch-stamped set of affected nodes.
/// Node `u` is affected in the current round exactly when
/// `affected_marks[u] == affected_epoch`.
#[derive(Debug)]
struct MergeScratch {
    incoming_counts: Vec<usize>,
    incoming_offsets: Vec<usize>,
    incoming_write: Vec<usize>,
    incoming_flat: Vec<u32>,
    affected_marks: Vec<u32>,
    affected_epoch: u32,
    affected_nodes: Vec<usize>,
}

impl MergeScratch {
    /// Allocates scratch for a graph of `n` nodes (`n + 1` offsets).
    fn new(n: usize) -> Self {
        let zeros = |len: usize| vec![0usize; len];
        Self {
            incoming_counts: zeros(n),
            incoming_offsets: zeros(n + 1),
            incoming_write: zeros(n),
            incoming_flat: Vec::new(),
            affected_marks: vec![0u32; n],
            affected_epoch: 1,
            affected_nodes: Vec::new(),
        }
    }

    /// Begins a new round: forgets the affected set and the incoming edges.
    #[inline]
    fn reset(&mut self) {
        self.affected_epoch = match self.affected_epoch.checked_add(1) {
            Some(next) => next,
            None => {
                // The stamp wrapped; clear old marks so none look current.
                self.affected_marks.fill(0);
                1
            }
        };
        self.affected_nodes.clear();
        self.incoming_flat.clear();
    }

    /// Adds `u` to this round's affected set (idempotent). Its incoming count
    /// is zeroed the first time it is marked in a round.
    #[inline]
    fn mark_affected(&mut self, u: usize) {
        if self.affected_marks[u] == self.affected_epoch {
            return;
        }
        self.affected_marks[u] = self.affected_epoch;
        self.affected_nodes.push(u);
        self.incoming_counts[u] = 0;
    }
}
fn build_vamana_graph<T, D>(
vectors: &FlatVectors<T>,
max_degree: usize,
build_beam_width: usize,
alpha: f32,
passes: usize,
extra_seeds: usize,
dist: D,
medoid_id: u32,
) -> Vec<Vec<u32>>
where
T: bytemuck::Pod + Copy + Send + Sync,
D: Distance<T> + Copy + Sync,
{
let n = vectors.n;
let mut graph = vec![Vec::<u32>::new(); n];
{
let mut rng = thread_rng();
let target = max_degree.min(n.saturating_sub(1));
for i in 0..n {
let mut s = HashSet::with_capacity(target);
while s.len() < target {
let nb = rng.gen_range(0..n);
if nb != i {
s.insert(nb as u32);
}
}
graph[i] = s.into_iter().collect();
}
}
let passes = passes.max(1);
let mut rng = thread_rng();
let slack_limit = ((GRAPH_SLACK_FACTOR * max_degree as f32).ceil() as usize).max(max_degree);
let mut merge_scratch = MergeScratch::new(n);
for pass_idx in 0..passes {
let pass_alpha = if passes == 1 {
alpha
} else if pass_idx == 0 {
1.0
} else {
alpha
};
let mut order: Vec<usize> = (0..n).collect();
order.shuffle(&mut rng);
for chunk in order.chunks(MICRO_BATCH_CHUNK_SIZE) {
let snapshot = &graph;
let chunk_results: Vec<(usize, Vec<u32>)> = chunk
.par_iter()
.map_init(
|| IncrementalInsertScratch::new(n, build_beam_width, max_degree, extra_seeds),
|scratch, &u| {
let bs = &mut scratch.build;
bs.candidates.clear();
for &nb in &snapshot[u] {
let d = dist.eval(vectors.row(u), vectors.row(nb as usize));
bs.candidates.push((nb, d));
}
bs.seeds.clear();
bs.seeds.push(medoid_id as usize);
let mut local_rng = thread_rng();
while bs.seeds.len() < 1 + extra_seeds {
let s = local_rng.gen_range(0..n);
if !bs.seeds.contains(&s) {
bs.seeds.push(s);
}
}
let seeds_len = bs.seeds.len();
for si in 0..seeds_len {
let start = bs.seeds[si];
greedy_search_visited_collect(
vectors.row(u),
vectors,
snapshot,
start,
build_beam_width,
dist,
bs,
);
for i in 0..bs.visited_ids.len() {
bs.candidates.push((bs.visited_ids[i], bs.visited_dists[i]));
}
}
dedup_keep_best_by_id_in_place(&mut bs.candidates);
let pruned = prune_neighbors(
u,
&bs.candidates,
vectors,
max_degree,
pass_alpha,
dist,
);
(u, pruned)
},
)
.collect();
let mut chunk_nodes = Vec::<usize>::with_capacity(chunk_results.len());
let mut chunk_pruned = Vec::<Vec<u32>>::with_capacity(chunk_results.len());
for (u, pruned) in chunk_results {
chunk_nodes.push(u);
chunk_pruned.push(pruned);
}
merge_chunk_updates_into_graph_reuse(
&mut graph,
&chunk_nodes,
&chunk_pruned,
vectors,
max_degree,
slack_limit,
pass_alpha,
dist,
&mut merge_scratch,
);
}
}
graph
.into_par_iter()
.enumerate()
.map(|(u, neigh)| {
if neigh.len() <= max_degree {
return neigh;
}
let mut ids = neigh;
ids.sort_unstable();
ids.dedup();
let pool: Vec<(u32, f32)> = ids
.into_iter()
.filter(|&id| id as usize != u)
.map(|id| (id, dist.eval(vectors.row(u), vectors.row(id as usize))))
.collect();
prune_neighbors(u, &pool, vectors, max_degree, alpha, dist)
})
.collect()
}
/// Best-first beam search over the in-memory build graph, from `start_id`
/// toward `query`.
///
/// State kept in `scratch`, which is reset on entry:
/// - `scratch.work` holds the best (at most `beam_width`) candidates so far.
/// - `scratch.frontier` holds candidates admitted to `work` that have not
///   been expanded yet; it is unbounded and may still hold candidates later
///   evicted from `work`.
///
/// The search stops when the frontier is empty. Once `work` is full, it also
/// stops as soon as the closest unexpanded candidate is no closer than the
/// worst kept one.
///
/// Every node whose distance was evaluated, including the start node, is
/// recorded with its distance in `scratch.visited_ids` / `scratch.visited_dists`
/// in visit order. Callers use these visited lists as a pruning pool.
fn greedy_search_visited_collect<T, D>(
query: &[T],
vectors: &FlatVectors<T>,
graph: &[Vec<u32>],
start_id: usize,
beam_width: usize,
dist: D,
scratch: &mut BuildScratch,
) where
T: bytemuck::Pod + Copy + Send + Sync,
D: Distance<T> + Copy,
{
scratch.reset_search();
let start_dist = dist.eval(query, vectors.row(start_id));
let start = Candidate {
dist: start_dist,
id: start_id as u32,
};
scratch.frontier.insert_unbounded(start);
scratch.work.insert_capped(start, beam_width);
scratch.mark_with_dist(start_id, start_dist);
while !scratch.frontier.is_empty() {
let best = scratch.frontier.best().unwrap();
// Termination: a full beam whose worst entry already beats the closest
// unexpanded candidate cannot be improved further.
if scratch.work.len() >= beam_width {
if let Some(worst) = scratch.work.worst() {
if best.dist >= worst.dist {
break;
}
}
}
let cur = scratch.frontier.pop_best().unwrap();
for &nb in &graph[cur.id as usize] {
let nb_usize = nb as usize;
if scratch.is_marked(nb_usize) {
continue;
}
// Every evaluated node is recorded, even if it does not enter the beam.
let d = dist.eval(query, vectors.row(nb_usize));
scratch.mark_with_dist(nb_usize, d);
let cand = Candidate { dist: d, id: nb };
if scratch.work.len() < beam_width {
scratch.work.insert_unbounded(cand);
scratch.frontier.insert_unbounded(cand);
} else if let Some(worst) = scratch.work.worst() {
// Beam is full: admit only strict improvements (evicting the worst).
if d < worst.dist {
scratch.work.insert_capped(cand, beam_width);
scratch.frontier.insert_unbounded(cand);
}
}
}
}
}
fn prune_neighbors<T, D>(
node_id: usize,
candidates: &[(u32, f32)],
vectors: &FlatVectors<T>,
max_degree: usize,
alpha: f32,
dist: D,
) -> Vec<u32>
where
T: bytemuck::Pod + Copy + Send + Sync,
D: Distance<T> + Copy,
{
if candidates.is_empty() || max_degree == 0 {
return Vec::new();
}
let mut sorted = candidates.to_vec();
sorted.sort_by(|a, b| a.1.total_cmp(&b.1));
let mut uniq = Vec::<(u32, f32)>::with_capacity(sorted.len());
let mut last_id: Option<u32> = None;
for &(cand_id, cand_dist) in &sorted {
if cand_id as usize == node_id {
continue;
}
if last_id == Some(cand_id) {
continue;
}
uniq.push((cand_id, cand_dist));
last_id = Some(cand_id);
}
if uniq.is_empty() {
return Vec::new();
}
let mut pruned = Vec::<u32>::with_capacity(max_degree);
for &(cand_id, cand_dist_to_node) in &uniq {
let mut occluded = false;
for &sel_id in &pruned {
let d_cand_sel = dist.eval(
vectors.row(cand_id as usize),
vectors.row(sel_id as usize),
);
if alpha * d_cand_sel <= cand_dist_to_node {
occluded = true;
break;
}
}
if !occluded {
pruned.push(cand_id);
if pruned.len() >= max_degree {
return pruned;
}
}
}
if pruned.len() < max_degree {
for &(cand_id, _) in &uniq {
if pruned.contains(&cand_id) {
continue;
}
pruned.push(cand_id);
if pruned.len() >= max_degree {
break;
}
}
}
pruned
}
#[cfg(test)]
mod tests {
    use super::*;
    use anndists::dist::{DistCosine, DistL2};
    use rand::Rng;
    use std::fs;

    /// Index file path in the OS temp dir. The file is removed on creation
    /// and again on drop, including when a test panics, so test runs do not
    /// litter the working directory and concurrent runs do not collide.
    struct TempIndexPath(String);

    impl TempIndexPath {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("diskann_{}_{}", std::process::id(), name))
                .to_string_lossy()
                .into_owned();
            let _ = fs::remove_file(&path);
            Self(path)
        }

        fn as_str(&self) -> &str {
            &self.0
        }
    }

    impl Drop for TempIndexPath {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    fn euclid(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    #[test]
    fn test_small_index_l2() {
        let tmp = TempIndexPath::new("test_small_l2.db");
        let path = tmp.as_str();
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![0.5, 0.5],
        ];
        let index = DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
        let q = vec![0.1, 0.1];
        let nns = index.search(&q, 3, 8);
        assert_eq!(nns.len(), 3);
        let v = index.get_vector(nns[0] as usize);
        assert!(euclid(&q, &v) < 1.0);
    }

    #[test]
    fn test_cosine() {
        let tmp = TempIndexPath::new("test_cosine.db");
        let path = tmp.as_str();
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0],
        ];
        let index =
            DiskANN::<f32, DistCosine>::build_index_default(&vectors, DistCosine, path).unwrap();
        let q = vec![2.0, 0.0, 0.0];
        let nns = index.search(&q, 2, 8);
        assert_eq!(nns.len(), 2);
        let v = index.get_vector(nns[0] as usize);
        let dot = v.iter().zip(&q).map(|(a, b)| a * b).sum::<f32>();
        let n1 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let n2 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let cos = dot / (n1 * n2);
        assert!(cos > 0.7);
    }

    #[test]
    fn test_persistence_and_open() {
        let tmp = TempIndexPath::new("test_persist.db");
        let path = tmp.as_str();
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];
        {
            let _idx =
                DiskANN::<f32, DistL2>::build_index_default(&vectors, DistL2, path).unwrap();
        }
        let idx2 = DiskANN::<f32, DistL2>::open_index_default_metric(path).unwrap();
        assert_eq!(idx2.num_vectors, 4);
        assert_eq!(idx2.dim, 2);
        let q = vec![0.9, 0.9];
        let res = idx2.search(&q, 2, 8);
        assert_eq!(res[0], 3);
    }

    #[test]
    fn test_grid_connectivity() {
        let tmp = TempIndexPath::new("test_grid.db");
        let path = tmp.as_str();
        let mut vectors = Vec::new();
        for i in 0..5 {
            for j in 0..5 {
                vectors.push(vec![i as f32, j as f32]);
            }
        }
        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 4,
                build_beam_width: 64,
                alpha: 1.5,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();
        for target in 0..vectors.len() {
            let q = &vectors[target];
            let nns = index.search(q, 10, 32);
            if !nns.contains(&(target as u32)) {
                let v = index.get_vector(nns[0] as usize);
                assert!(euclid(q, &v) < 2.0);
            }
            for &nb in nns.iter().take(5) {
                let v = index.get_vector(nb as usize);
                assert!(euclid(q, &v) < 5.0);
            }
        }
    }

    #[test]
    fn test_medium_random() {
        let tmp = TempIndexPath::new("test_medium.db");
        let path = tmp.as_str();
        let n = 200usize;
        let d = 32usize;
        let mut rng = rand::thread_rng();
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|_| (0..d).map(|_| rng.r#gen::<f32>()).collect())
            .collect();
        let index = DiskANN::<f32, DistL2>::build_index_with_params(
            &vectors,
            DistL2,
            path,
            DiskAnnParams {
                max_degree: 32,
                build_beam_width: 128,
                alpha: 1.2,
                passes: DISKANN_DEFAULT_PASSES,
                extra_seeds: DISKANN_DEFAULT_EXTRA_SEEDS,
            },
        )
        .unwrap();
        let q: Vec<f32> = (0..d).map(|_| rng.r#gen::<f32>()).collect();
        let res = index.search(&q, 10, 64);
        assert_eq!(res.len(), 10);
        let dists: Vec<f32> = res
            .iter()
            .map(|&id| {
                let v = index.get_vector(id as usize);
                euclid(&q, &v)
            })
            .collect();
        let mut sorted = dists.clone();
        sorted.sort_by(|a, b| a.total_cmp(b));
        assert_eq!(dists, sorted);
    }
}