use std::collections::HashSet;
use std::path::Path;
use std::sync::Arc;
use rand::seq::SliceRandom;
use rand::Rng;
use smallvec::SmallVec;
use crate::RetrieveError;
/// In-memory DiskANN (Vamana-style) graph index over dense `f32` vectors.
///
/// Vectors are appended with `add`/`add_slice`, the proximity graph is
/// constructed with `build` (or `build_parallel`), after which the index can
/// be searched in memory or persisted with `save` for on-disk search via
/// `DiskANNSearcher`.
pub struct DiskANNIndex {
    // Number of f32 components per vector.
    dimension: usize,
    // Construction/search parameters (degree, beams, alpha, seed).
    params: DiskANNParams,
    // Set by build(); gates add/save/search.
    built: bool,
    // Flat row-major vector storage: vector i occupies
    // [i * dimension .. (i + 1) * dimension].
    vectors: Vec<f32>,
    num_vectors: usize,
    // External document id for each internal node index.
    doc_ids: Vec<u32>,
    // Adjacency lists of the proximity graph; inline capacity 32 matches the
    // default max degree `m`.
    adj: Vec<SmallVec<[u32; 32]>>,
    // Entry point for greedy search (the medoid after build()).
    start_node: u32,
}
impl DiskANNIndex {
    /// Number of `f32` components per vector.
    #[inline]
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Number of vectors currently stored in the index.
    #[inline]
    pub fn num_vectors(&self) -> usize {
        self.num_vectors
    }

    /// Default search beam width from the construction parameters.
    #[inline]
    pub fn ef_search(&self) -> usize {
        self.params.ef_search
    }

    /// Approximate in-memory footprint in bytes (vector data + adjacency).
    #[inline]
    pub fn size_bytes(&self) -> usize {
        self.vectors.len() * std::mem::size_of::<f32>()
            + self
                .adj
                .iter()
                .map(|n| n.len() * std::mem::size_of::<u32>())
                .sum::<usize>()
    }

    /// Persists the built index into `output_dir` as four files:
    /// `vectors.bin` (little-endian f32s), `graph.index` (adjacency lists),
    /// `doc_ids.bin` (little-endian u32s) and `metadata.json`.
    ///
    /// All multi-byte values are written explicitly little-endian so the
    /// files are portable: `DiskANNSearcher::load` decodes them with
    /// `from_le_bytes`. (The previous implementation dumped raw native-endian
    /// memory with `slice::from_raw_parts`, which produced identical bytes on
    /// little-endian hosts but corrupt files on big-endian ones.)
    ///
    /// # Errors
    /// Returns `InvalidParameter` if the index has not been built, plus any
    /// I/O or serialization error encountered while writing.
    pub fn save(&self, output_dir: &Path) -> Result<(), RetrieveError> {
        use std::io::Write;
        if !self.built {
            return Err(RetrieveError::InvalidParameter(
                "cannot save unbuilt index".into(),
            ));
        }
        if !output_dir.exists() {
            std::fs::create_dir_all(output_dir)?;
        }
        // Vector data: one little-endian f32 at a time into a single buffer,
        // then one write_all call.
        let vectors_path = output_dir.join("vectors.bin");
        let mut vectors_file = std::fs::File::create(&vectors_path)?;
        let mut vectors_bytes =
            Vec::with_capacity(self.vectors.len() * std::mem::size_of::<f32>());
        for &v in &self.vectors {
            vectors_bytes.extend_from_slice(&v.to_le_bytes());
        }
        vectors_file.write_all(&vectors_bytes)?;
        // Graph structure goes through the dedicated on-disk writer.
        let graph_path = output_dir.join("graph.index");
        let mut graph_writer = super::disk_io::DiskGraphWriter::new(
            &graph_path,
            self.num_vectors,
            self.params.m,
            self.start_node,
        )
        .map_err(|e| {
            RetrieveError::Io(Arc::new(std::io::Error::other(format!(
                "failed to create graph writer: {}",
                e
            ))))
        })?;
        for neighbors in &self.adj {
            graph_writer.write_adjacency(neighbors).map_err(|e| {
                RetrieveError::Io(Arc::new(std::io::Error::other(format!(
                    "failed to write adjacency: {}",
                    e
                ))))
            })?;
        }
        graph_writer.flush().map_err(|e| {
            RetrieveError::Io(Arc::new(std::io::Error::other(format!(
                "failed to flush graph: {}",
                e
            ))))
        })?;
        // Doc id mapping, little-endian u32s (matches from_le_bytes in load).
        let doc_ids_path = output_dir.join("doc_ids.bin");
        let mut doc_ids_file = std::fs::File::create(&doc_ids_path)?;
        let mut doc_ids_bytes =
            Vec::with_capacity(self.doc_ids.len() * std::mem::size_of::<u32>());
        for &id in &self.doc_ids {
            doc_ids_bytes.extend_from_slice(&id.to_le_bytes());
        }
        doc_ids_file.write_all(&doc_ids_bytes)?;
        // Human-readable metadata; load() reads dimensions and params from it.
        let metadata_path = output_dir.join("metadata.json");
        let metadata = serde_json::json!({
            "dimension": self.dimension,
            "num_vectors": self.num_vectors,
            "start_node": self.start_node,
            "params": {
                "m": self.params.m,
                "ef_construction": self.params.ef_construction,
                "alpha": self.params.alpha,
                "ef_search": self.params.ef_search
            }
        });
        let metadata_file = std::fs::File::create(&metadata_path)?;
        serde_json::to_writer_pretty(metadata_file, &metadata)
            .map_err(|e| RetrieveError::Serialization(e.to_string()))?;
        Ok(())
    }
}
/// On-disk searcher over an index written by `DiskANNIndex::save`.
///
/// Keeps `vectors.bin` and the graph open as files and reads vectors on
/// demand, so memory usage stays proportional to `doc_ids` plus two small
/// scratch buffers. Methods take `&mut self` because the scratch buffers and
/// the file cursor are reused across calls (not thread-safe to share).
pub struct DiskANNSearcher {
    // Number of f32 components per vector (from metadata.json).
    dimension: usize,
    // Graph entry point recorded at save time.
    start_node: u32,
    // Parameters recovered from metadata (ef_search acts as a search floor).
    params: DiskANNParams,
    // On-disk adjacency reader.
    graph_reader: super::disk_io::DiskGraphReader,
    // Open handle to vectors.bin; seeked per vector read.
    vectors_file: std::fs::File,
    // node index -> external doc id (identity if doc_ids.bin was absent).
    doc_ids: Vec<u32>,
    // Reusable raw-byte buffer of dimension * 4 bytes.
    read_buf: Vec<u8>,
    // Reusable decoded-f32 buffer of `dimension` elements.
    vec_buf: Vec<f32>,
}
impl DiskANNSearcher {
    /// Opens an index directory previously written by `DiskANNIndex::save`.
    ///
    /// Reads `metadata.json` and `doc_ids.bin` eagerly (falling back to the
    /// identity doc-id mapping when the latter is absent) and keeps
    /// `vectors.bin` plus the graph open for on-demand access.
    ///
    /// # Errors
    /// Propagates I/O failures, JSON parse errors, missing metadata fields,
    /// and a `doc_ids.bin` whose size does not match `num_vectors`.
    pub fn load(index_dir: &Path) -> Result<Self, RetrieveError> {
        let metadata_path = index_dir.join("metadata.json");
        let metadata_file = std::fs::File::open(&metadata_path)?;
        let metadata: serde_json::Value = serde_json::from_reader(metadata_file)
            .map_err(|e| RetrieveError::Serialization(e.to_string()))?;
        // ok_or_else: don't allocate the error String on the success path.
        let dimension = metadata["dimension"]
            .as_u64()
            .ok_or_else(|| RetrieveError::FormatError("Missing dimension".to_string()))?
            as usize;
        let num_vectors = metadata["num_vectors"]
            .as_u64()
            .ok_or_else(|| RetrieveError::FormatError("Missing num_vectors".to_string()))?
            as usize;
        let start_node = metadata["start_node"]
            .as_u64()
            .ok_or_else(|| RetrieveError::FormatError("Missing start_node".to_string()))?
            as u32;
        // Params are optional in metadata; fall back to the documented defaults.
        let params_val = &metadata["params"];
        let params = DiskANNParams {
            m: params_val["m"].as_u64().unwrap_or(32) as usize,
            ef_construction: params_val["ef_construction"].as_u64().unwrap_or(100) as usize,
            alpha: params_val["alpha"].as_f64().unwrap_or(1.2) as f32,
            ef_search: params_val["ef_search"].as_u64().unwrap_or(100) as usize,
            seed: None,
        };
        let graph_path = index_dir.join("graph.index");
        let graph_reader = super::disk_io::DiskGraphReader::open(&graph_path).map_err(|e| {
            RetrieveError::Io(Arc::new(std::io::Error::other(format!(
                "failed to open graph: {}",
                e
            ))))
        })?;
        let vectors_path = index_dir.join("vectors.bin");
        let vectors_file = std::fs::File::open(&vectors_path)?;
        let doc_ids_path = index_dir.join("doc_ids.bin");
        let doc_ids = if doc_ids_path.exists() {
            let bytes = std::fs::read(&doc_ids_path)?;
            if bytes.len() != num_vectors * 4 {
                return Err(RetrieveError::FormatError(format!(
                    "doc_ids.bin size mismatch: expected {} bytes, got {}",
                    num_vectors * 4,
                    bytes.len()
                )));
            }
            bytes
                .chunks_exact(4)
                .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect()
        } else {
            // No mapping file: internal node index doubles as the doc id.
            (0..num_vectors as u32).collect()
        };
        Ok(Self {
            read_buf: vec![0u8; dimension * 4],
            vec_buf: vec![0.0f32; dimension],
            dimension,
            start_node,
            params,
            graph_reader,
            doc_ids,
            vectors_file,
        })
    }

    /// Beam search over the on-disk graph; returns up to `k`
    /// `(doc_id, squared L2 distance)` pairs, closest first.
    ///
    /// The effective beam is `max(ef_search, k, params.ef_search)` — the
    /// stored `ef_search` acts as a floor (note: the in-memory
    /// `DiskANNIndex::search` does not apply this floor).
    ///
    /// # Errors
    /// Propagates graph/vector read failures.
    pub fn search(
        &mut self,
        query: &[f32],
        k: usize,
        ef_search: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        let ef = ef_search.max(k).max(self.params.ef_search);
        let mut visited = HashSet::new();
        let mut retset: Vec<Candidate> = Vec::with_capacity(ef + 1);
        let start_dist = {
            let v = self.read_vector(self.start_node)?;
            crate::simd::l2_distance_squared(query, v)
        };
        retset.push(Candidate {
            id: self.start_node,
            dist: start_dist,
        });
        visited.insert(self.start_node);
        let mut current_idx = 0;
        retset.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
        // Expand candidates in order; the working set is re-sorted and
        // truncated to `ef` after each expansion. NOTE(review): re-sorting
        // can move already-expanded nodes past current_idx (extra disk reads,
        // no duplicates thanks to `visited`) — confirm this matches the
        // intended beam-search variant.
        while current_idx < retset.len() {
            let current = retset[current_idx];
            current_idx += 1;
            let neighbors = self.graph_reader.get_neighbors(current.id)?;
            for neighbor in neighbors {
                // insert() returns false if already present: one hash lookup
                // instead of contains() + insert().
                if !visited.insert(neighbor) {
                    continue;
                }
                let dist = {
                    let v = self.read_vector(neighbor)?;
                    crate::simd::l2_distance_squared(query, v)
                };
                retset.push(Candidate { id: neighbor, dist });
            }
            retset.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
            if retset.len() > ef {
                retset.truncate(ef);
            }
        }
        Ok(retset
            .into_iter()
            .take(k)
            .filter_map(|c| {
                let doc_id = self.doc_ids.get(c.id as usize).copied()?;
                Some((doc_id, c.dist))
            })
            .collect())
    }

    /// Reads vector `idx` from `vectors.bin` into the reusable scratch
    /// buffer and returns a borrow of the decoded f32s (little-endian).
    fn read_vector(&mut self, idx: u32) -> Result<&[f32], RetrieveError> {
        use std::io::{Read, Seek, SeekFrom};
        let offset = idx as u64 * self.dimension as u64 * 4;
        self.vectors_file.seek(SeekFrom::Start(offset))?;
        self.vectors_file.read_exact(&mut self.read_buf)?;
        // chunks_exact guarantees 4-byte chunks, letting the optimizer drop
        // per-element bounds checks.
        for (dst, chunk) in self.vec_buf.iter_mut().zip(self.read_buf.chunks_exact(4)) {
            *dst = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
        }
        Ok(&self.vec_buf)
    }
}
/// Construction and search parameters for the DiskANN/Vamana index.
#[derive(Clone, Debug)]
pub struct DiskANNParams {
    /// Maximum out-degree of each graph node.
    pub m: usize,
    /// Beam width used during graph construction.
    pub ef_construction: usize,
    /// Robust-prune slack factor (>= 1.0); larger keeps more diverse edges.
    pub alpha: f32,
    /// Default beam width used at query time.
    pub ef_search: usize,
    /// Optional RNG seed for deterministic builds; `None` uses a thread RNG.
    pub seed: Option<u64>,
}
impl Default for DiskANNParams {
fn default() -> Self {
Self {
m: 32,
ef_construction: 100,
alpha: 1.2,
ef_search: 100,
seed: None,
}
}
}
/// A search candidate: graph node id plus its squared L2 distance to the
/// current query. 8 bytes and `Copy`, so heaps of these stay cache-friendly.
#[derive(Clone, Copy, PartialEq)]
struct Candidate {
    id: u32,
    dist: f32,
}
// `dist` is compared with total_cmp, which is a total order even for NaN,
// so Eq/Ord are sound to implement.
impl Eq for Candidate {}
impl Ord for Candidate {
    // Order by distance only (ids are ignored); total_cmp handles NaN/-0.0.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist.total_cmp(&other.dist)
    }
}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl DiskANNIndex {
    /// Creates an empty index for vectors of the given dimensionality.
    ///
    /// # Errors
    /// Returns `InvalidParameter` if `dimension` is 0.
    pub fn new(dimension: usize, params: DiskANNParams) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::InvalidParameter(
                "dimension must be greater than 0".to_string(),
            ));
        }
        Ok(Self {
            dimension,
            params,
            built: false,
            vectors: Vec::new(),
            num_vectors: 0,
            doc_ids: Vec::new(),
            adj: Vec::new(),
            start_node: 0,
        })
    }

    /// Adds a vector by value; see [`DiskANNIndex::add_slice`].
    pub fn add(&mut self, doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
        self.add_slice(doc_id, &vector)
    }

    /// Appends one vector together with its external `doc_id`.
    ///
    /// # Errors
    /// Fails if the index is already built or the vector length does not
    /// match the index dimension.
    pub fn add_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
        if self.built {
            return Err(RetrieveError::InvalidParameter(
                "cannot add vectors after index is built".into(),
            ));
        }
        if vector.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: vector.len(),
                doc_dim: self.dimension,
            });
        }
        self.vectors.extend_from_slice(vector);
        self.doc_ids.push(doc_id);
        self.num_vectors += 1;
        self.adj.push(SmallVec::new());
        Ok(())
    }

    /// Builds the Vamana graph: random init, medoid entry point, then two
    /// pruning passes (alpha = 1.0 followed by params.alpha), and finally a
    /// BFS reordering for disk/cache locality. Idempotent once built.
    ///
    /// # Errors
    /// Returns `EmptyIndex` if no vectors were added.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.built {
            return Ok(());
        }
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        self.initialize_random_graph();
        self.start_node = self.compute_medoid();
        // First pass with alpha = 1.0 builds a tight graph; the second pass
        // with the configured alpha adds longer "shortcut" edges.
        self.vamana_pass(1.0)?;
        self.vamana_pass(self.params.alpha)?;
        self.built = true;
        self.reorder_for_locality();
        Ok(())
    }

    /// Parallel variant of [`DiskANNIndex::build`], processing nodes in
    /// batches of `batch_size` with rayon.
    ///
    /// # Errors
    /// Returns `EmptyIndex` if no vectors were added.
    #[cfg(feature = "parallel")]
    pub fn build_parallel(&mut self, batch_size: usize) -> Result<(), RetrieveError> {
        if self.built {
            return Ok(());
        }
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        self.initialize_random_graph();
        self.start_node = self.compute_medoid();
        self.vamana_pass_parallel(1.0, batch_size)?;
        self.vamana_pass_parallel(self.params.alpha, batch_size)?;
        self.built = true;
        self.reorder_for_locality();
        Ok(())
    }

    /// One batched-parallel Vamana pass: each batch searches and prunes
    /// against a frozen graph snapshot, then edges (forward + reverse) are
    /// applied serially and over-degree nodes are re-pruned in parallel.
    #[cfg(feature = "parallel")]
    fn vamana_pass_parallel(&mut self, alpha: f32, batch_size: usize) -> Result<(), RetrieveError> {
        use rayon::prelude::*;
        let mut nodes: Vec<u32> = (0..self.num_vectors as u32).collect();
        {
            use rand::SeedableRng;
            // seed + 1 keeps pass shuffles distinct from graph-init randomness.
            let mut rng: Box<dyn rand::RngCore> = match self.params.seed {
                Some(s) => Box::new(rand::rngs::StdRng::seed_from_u64(s.wrapping_add(1))),
                None => Box::new(rand::rng()),
            };
            nodes.shuffle(&mut *rng);
        }
        let m = self.params.m;
        let ef_c = self.params.ef_construction;
        let start_node = self.start_node;
        let batch_sz = batch_size.max(1);
        for batch_start in (0..nodes.len()).step_by(batch_sz) {
            let batch_end = (batch_start + batch_sz).min(nodes.len());
            let batch = &nodes[batch_start..batch_end];
            // Read-only phase: search + prune over &self, safe to parallelize.
            let results: Vec<(u32, Vec<u32>)> = batch
                .par_iter()
                .map(|&i| {
                    let query_vec = self.get_vector(i);
                    let (visited, _) = self.greedy_search(query_vec, ef_c, start_node);
                    let new_neighbors = self.robust_prune(i, &visited, alpha, m);
                    (i, new_neighbors)
                })
                .collect();
            // Write phase: apply new adjacency and best-effort reverse edges.
            for (i, new_neighbors) in &results {
                let i = *i;
                self.adj[i as usize] = new_neighbors.iter().copied().collect();
                for &j in new_neighbors {
                    if !self.adj[j as usize].contains(&i) {
                        self.adj[j as usize].push(i);
                    }
                }
            }
            // Reverse edges may push nodes over degree m; re-prune those.
            let overweight: Vec<u32> = (0..self.num_vectors as u32)
                .filter(|&id| self.adj[id as usize].len() > m)
                .collect();
            if !overweight.is_empty() {
                let pruned: Vec<(u32, Vec<u32>)> = overweight
                    .par_iter()
                    .map(|&id| {
                        let candidates: Vec<u32> = self.adj[id as usize].to_vec();
                        let pruned = self.robust_prune(id, &candidates, alpha, m);
                        (id, pruned)
                    })
                    .collect();
                for (id, new_adj) in pruned {
                    self.adj[id as usize] = new_adj.into_iter().collect();
                }
            }
        }
        Ok(())
    }

    /// Renumbers nodes in BFS order from the entry point so that graph
    /// neighbors tend to be adjacent in `vectors`, improving locality.
    /// Rewrites vectors, doc_ids, adjacency and start_node consistently;
    /// unreachable nodes are appended after the BFS order.
    fn reorder_for_locality(&mut self) {
        if self.num_vectors <= 1 {
            return;
        }
        let n = self.num_vectors;
        let dim = self.dimension;
        let ep = self.start_node as usize;
        let mut new_order: Vec<u32> = Vec::with_capacity(n);
        let mut visited = vec![false; n];
        let mut queue = std::collections::VecDeque::with_capacity(n);
        queue.push_back(ep);
        visited[ep] = true;
        while let Some(node) = queue.pop_front() {
            new_order.push(node as u32);
            for &nb in &self.adj[node] {
                let nb = nb as usize;
                if nb < n && !visited[nb] {
                    visited[nb] = true;
                    queue.push_back(nb);
                }
            }
        }
        // Nodes not reachable from the entry point keep their relative order
        // at the tail.
        for (i, &v) in visited.iter().enumerate() {
            if !v {
                new_order.push(i as u32);
            }
        }
        let mut old_to_new = vec![0u32; n];
        for (new_idx, &old_idx) in new_order.iter().enumerate() {
            old_to_new[old_idx as usize] = new_idx as u32;
        }
        let mut new_vectors = vec![0.0f32; self.vectors.len()];
        for (new_idx, &old_idx) in new_order.iter().enumerate() {
            let src = old_idx as usize * dim;
            let dst = new_idx * dim;
            new_vectors[dst..dst + dim].copy_from_slice(&self.vectors[src..src + dim]);
        }
        self.vectors = new_vectors;
        let new_doc_ids: Vec<u32> = new_order
            .iter()
            .map(|&old| self.doc_ids[old as usize])
            .collect();
        self.doc_ids = new_doc_ids;
        let mut new_adj: Vec<SmallVec<[u32; 32]>> = vec![SmallVec::new(); n];
        for (old_idx, nbs) in self.adj.iter().enumerate() {
            if old_idx < n {
                let new_idx = old_to_new[old_idx] as usize;
                new_adj[new_idx] = nbs.iter().map(|&nb| old_to_new[nb as usize]).collect();
            }
        }
        self.adj = new_adj;
        self.start_node = old_to_new[ep];
    }

    /// Seeds every node with up to `m` random distinct neighbors (excluding
    /// itself) as the starting graph for the Vamana passes.
    fn initialize_random_graph(&mut self) {
        use rand::SeedableRng;
        let mut rng: Box<dyn rand::RngCore> = match self.params.seed {
            Some(s) => Box::new(rand::rngs::StdRng::seed_from_u64(s)),
            None => Box::new(rand::rng()),
        };
        let r = self.params.m;
        for i in 0..self.num_vectors {
            let mut neighbors: HashSet<u32> = HashSet::with_capacity(r);
            // Rejection-sample distinct ids; the second bound stops the loop
            // when fewer than r other nodes exist.
            while neighbors.len() < r && neighbors.len() < self.num_vectors - 1 {
                let n = rng.random_range(0..self.num_vectors) as u32;
                if n != i as u32 {
                    neighbors.insert(n);
                }
            }
            self.adj[i] = neighbors.into_iter().collect();
        }
    }

    /// Returns the medoid: the stored vector closest to the dataset
    /// centroid. Used as the fixed entry point for all graph searches.
    fn compute_medoid(&self) -> u32 {
        let n = self.num_vectors;
        let dim = self.dimension;
        let mut centroid = vec![0.0f32; dim];
        for i in 0..n {
            let v = self.get_vector(i as u32);
            for (c, &x) in centroid.iter_mut().zip(v.iter()) {
                *c += x;
            }
        }
        let inv_n = 1.0 / n as f32;
        for c in centroid.iter_mut() {
            *c *= inv_n;
        }
        let mut best_id = 0u32;
        let mut best_dist = f32::INFINITY;
        for i in 0..n {
            // Fixed: source had the mis-encoded token `¢roid` (HTML entity
            // damage of `&centroid`), which did not compile.
            let d = self.dist(&centroid, self.get_vector(i as u32));
            if d < best_dist {
                best_dist = d;
                best_id = i as u32;
            }
        }
        best_id
    }

    /// One serial Vamana pass: for each node (in shuffled order) search the
    /// current graph, robust-prune the visited set into its new neighbor
    /// list, and insert pruned reverse edges.
    fn vamana_pass(&mut self, alpha: f32) -> Result<(), RetrieveError> {
        let mut nodes: Vec<u32> = (0..self.num_vectors as u32).collect();
        {
            use rand::SeedableRng;
            // seed + 1 keeps pass shuffles distinct from graph-init randomness.
            let mut rng: Box<dyn rand::RngCore> = match self.params.seed {
                Some(s) => Box::new(rand::rngs::StdRng::seed_from_u64(s.wrapping_add(1))),
                None => Box::new(rand::rng()),
            };
            nodes.shuffle(&mut *rng);
        }
        for &i in &nodes {
            let query_vec = self.get_vector(i);
            let (visited, _) =
                self.greedy_search(query_vec, self.params.ef_construction, self.start_node);
            let new_neighbors = self.robust_prune(i, &visited, alpha, self.params.m);
            let neighbors_for_i = new_neighbors.clone();
            self.adj[i as usize] = new_neighbors.into_iter().collect();
            // Reverse edges: adding i to each neighbor j goes through
            // robust_prune so j never exceeds degree m.
            for j in neighbors_for_i {
                if !self.adj[j as usize].contains(&i) {
                    let mut rev_candidates: Vec<u32> = self.adj[j as usize].to_vec();
                    rev_candidates.push(i);
                    let pruned = self.robust_prune(j, &rev_candidates, alpha, self.params.m);
                    self.adj[j as usize] = pruned.into_iter().collect();
                }
            }
        }
        Ok(())
    }

    /// Robust pruning (Vamana): from `candidates` plus `node`'s current
    /// neighbors, greedily keep the closest candidate and drop any later
    /// candidate c for which alpha * d(kept, c) <= d(node, c), up to
    /// `max_degree` survivors. Larger alpha keeps more diverse long edges.
    fn robust_prune(
        &self,
        node: u32,
        candidates: &[u32],
        alpha: f32,
        max_degree: usize,
    ) -> Vec<u32> {
        let node_vec = self.get_vector(node);
        let candidate_set: HashSet<u32> = candidates.iter().copied().collect();
        let mut candidates_with_dist: Vec<Candidate> = candidates
            .iter()
            .filter(|&&c| c != node)
            .map(|&c| Candidate {
                id: c,
                dist: self.dist(node_vec, self.get_vector(c)),
            })
            .collect();
        // Existing neighbors also compete for the degree budget.
        for &neighbor in &self.adj[node as usize] {
            if !candidate_set.contains(&neighbor) {
                candidates_with_dist.push(Candidate {
                    id: neighbor,
                    dist: self.dist(node_vec, self.get_vector(neighbor)),
                });
            }
        }
        candidates_with_dist.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
        let mut new_neighbors: Vec<u32> = Vec::with_capacity(max_degree);
        // dedup_by only removes consecutive duplicates, which is sufficient
        // here because equal ids have equal distances and are adjacent after
        // the sort.
        candidates_with_dist.dedup_by(|a, b| a.id == b.id);
        for cand in candidates_with_dist {
            if new_neighbors.len() >= max_degree {
                break;
            }
            let mut prune = false;
            let cand_vec = self.get_vector(cand.id);
            for &existing_neighbor in &new_neighbors {
                let dist_existing_cand = self.dist(self.get_vector(existing_neighbor), cand_vec);
                if alpha * dist_existing_cand <= cand.dist {
                    prune = true;
                    break;
                }
            }
            if !prune {
                new_neighbors.push(cand.id);
            }
        }
        new_neighbors
    }

    /// Best-first beam search over the in-memory graph.
    ///
    /// Returns (visited node ids sorted by distance, same as Candidates).
    /// Uses a thread-local generation-stamped visited array so repeated
    /// searches avoid reallocating; the generation wraps by clearing marks.
    fn greedy_search(
        &self,
        query: &[f32],
        l_size: usize,
        start_node: u32,
    ) -> (Vec<u32>, Vec<Candidate>) {
        use std::cmp::Reverse;
        use std::collections::BinaryHeap;
        let dist_fn = self.dist_fn();
        let num_vectors = self.num_vectors;
        thread_local! {
            // (per-node generation marks, current generation). A node is
            // "visited" when its mark equals the current generation, so a
            // fresh search only bumps the counter instead of zeroing memory.
            static VISITED: std::cell::RefCell<(Vec<u8>, u8)> =
                const { std::cell::RefCell::new((Vec::new(), 1)) };
        }
        VISITED.with(|cell| {
            let (marks, gen) = &mut *cell.borrow_mut();
            if marks.len() < num_vectors {
                marks.resize(num_vectors, 0);
            }
            // u8 generation overflow: reset all marks and restart at 1.
            if let Some(next) = gen.checked_add(1) {
                *gen = next;
            } else {
                marks.fill(0);
                *gen = 1;
            }
            let generation = *gen;
            // Returns true if `id` was not yet visited (and marks it).
            // Out-of-range ids are reported as unvisited without marking.
            let mut visited_insert = |id: u32| -> bool {
                let idx = id as usize;
                if idx < marks.len() && marks[idx] != generation {
                    marks[idx] = generation;
                    true
                } else {
                    idx >= marks.len()
                }
            };
            // frontier: min-heap of nodes to expand; results: max-heap capped
            // at l_size, so peek() is the current worst kept distance.
            let mut frontier: BinaryHeap<Reverse<Candidate>> =
                BinaryHeap::with_capacity(l_size * 2);
            let mut results: BinaryHeap<Candidate> = BinaryHeap::with_capacity(l_size + 1);
            let start_dist = dist_fn(query, self.get_vector(start_node));
            visited_insert(start_node);
            frontier.push(Reverse(Candidate {
                id: start_node,
                dist: start_dist,
            }));
            results.push(Candidate {
                id: start_node,
                dist: start_dist,
            });
            while let Some(Reverse(current)) = frontier.pop() {
                // Terminate once the best unexpanded node cannot improve the
                // worst of the l_size kept results.
                if results.len() >= l_size {
                    if let Some(worst) = results.peek() {
                        if current.dist >= worst.dist {
                            break;
                        }
                    }
                }
                let neighbors = &self.adj[current.id as usize];
                for (i, &neighbor) in neighbors.iter().enumerate() {
                    // Software-prefetch the next neighbor's vector so its
                    // cache line is in flight while the current distance is
                    // being computed.
                    if i + 1 < neighbors.len() {
                        let next_id = neighbors[i + 1] as usize;
                        if next_id < num_vectors {
                            let ptr = self.vectors.as_ptr().wrapping_add(next_id * self.dimension);
                            // SAFETY: prefetch is a hint; it never faults and
                            // has no architectural side effects, even for a
                            // stale/invalid address.
                            #[cfg(target_arch = "aarch64")]
                            unsafe {
                                std::arch::asm!(
                                    "prfm pldl1keep, [{ptr}]",
                                    ptr = in(reg) ptr,
                                    options(nostack, preserves_flags)
                                );
                            }
                            #[cfg(target_arch = "x86_64")]
                            unsafe {
                                std::arch::x86_64::_mm_prefetch(
                                    ptr as *const i8,
                                    std::arch::x86_64::_MM_HINT_T0,
                                );
                            }
                        }
                    }
                    if !visited_insert(neighbor) {
                        continue;
                    }
                    let dist = dist_fn(query, self.get_vector(neighbor));
                    frontier.push(Reverse(Candidate { id: neighbor, dist }));
                    results.push(Candidate { id: neighbor, dist });
                    if results.len() > l_size {
                        results.pop();
                    }
                }
            }
            let mut result_vec: Vec<Candidate> = results.into_vec();
            result_vec.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
            let ids: Vec<u32> = result_vec.iter().map(|c| c.id).collect();
            (ids, result_vec)
        })
    }

    /// In-memory k-NN search; returns up to `k` `(doc_id, squared L2
    /// distance)` pairs, closest first, using beam width `max(ef_search, k)`.
    ///
    /// # Errors
    /// Fails if the index is unbuilt or the query dimension is wrong.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        ef_search: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::InvalidParameter(
                "index must be built before search".into(),
            ));
        }
        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dimension,
            });
        }
        let ef = ef_search.max(k);
        let (_, candidates) = self.greedy_search(query, ef, self.start_node);
        let result = candidates
            .into_iter()
            .take(k)
            .filter_map(|c| {
                let doc_id = self.doc_ids.get(c.id as usize).copied()?;
                Some((doc_id, c.dist))
            })
            .collect();
        Ok(result)
    }

    /// Borrows the `idx`-th vector from flat storage. Panics if out of range.
    #[inline]
    fn get_vector(&self, idx: u32) -> &[f32] {
        let start = idx as usize * self.dimension;
        &self.vectors[start..start + self.dimension]
    }

    /// Squared L2 distance (monotonic in true L2, cheaper to compute).
    #[inline]
    fn dist(&self, a: &[f32], b: &[f32]) -> f32 {
        crate::simd::l2_distance_squared(a, b)
    }

    /// Returns the distance function as a plain fn pointer so hot loops can
    /// call it without going through `&self`.
    #[inline(always)]
    fn dist_fn(&self) -> fn(&[f32], &[f32]) -> f32 {
        crate::simd::l2_distance_squared
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::error::RetrieveError;

    /// Deterministic pseudo-random unit vector: component j is derived by
    /// hashing (42, i, j), then the vector is L2-normalized. Shared by the
    /// self-query test for both indexing and querying.
    fn hashed_unit_vec(i: u32, dim: usize) -> Vec<f32> {
        use std::hash::{Hash, Hasher};
        let raw: Vec<f32> = (0..dim)
            .map(|j| {
                let mut h = std::collections::hash_map::DefaultHasher::new();
                (42u64, i, j).hash(&mut h);
                (h.finish() as f64 / u64::MAX as f64 * 2.0 - 1.0) as f32
            })
            .collect();
        let norm: f32 = raw.iter().map(|x| x * x).sum::<f32>().sqrt();
        raw.iter().map(|x| x / norm).collect()
    }

    #[test]
    fn test_create_index() {
        let index = DiskANNIndex::new(4, DiskANNParams::default());
        assert!(index.is_ok());
        let index = index.unwrap();
        assert_eq!(index.dimension(), 4);
        assert_eq!(index.num_vectors(), 0);
    }

    #[test]
    fn test_add_and_search() {
        // All five fields are specified, so no `..Default::default()` needed.
        let params = DiskANNParams {
            m: 4,
            ef_construction: 20,
            alpha: 1.2,
            ef_search: 20,
            seed: None,
        };
        let mut index = DiskANNIndex::new(4, params).unwrap();
        for i in 0..10u32 {
            let v = vec![i as f32, (i as f32) * 0.5, 1.0, 0.0];
            index.add(i, v).unwrap();
        }
        index.build().unwrap();
        let query = vec![0.0, 0.0, 1.0, 0.0];
        let results = index.search(&query, 3, 20).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 3);
        // Vector 0 equals the query, so it must rank first.
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_zero_dimension_error() {
        let result = DiskANNIndex::new(0, DiskANNParams::default());
        match result {
            Err(RetrieveError::InvalidParameter(_)) => {}
            Err(other) => panic!("Expected InvalidParameter, got {:?}", other),
            Ok(_) => panic!("Expected error for dimension 0"),
        }
    }

    #[test]
    fn test_max_degree_enforced() {
        let m = 4;
        let params = DiskANNParams {
            m,
            ef_construction: 20,
            alpha: 1.2,
            ef_search: 20,
            seed: None,
        };
        let mut index = DiskANNIndex::new(4, params).unwrap();
        for i in 0..30u32 {
            let v = vec![i as f32, (i as f32) * 0.3, 1.0, (i as f32) * 0.1];
            index.add(i, v).unwrap();
        }
        index.build().unwrap();
        for (node, neighbors) in index.adj.iter().enumerate() {
            assert!(
                neighbors.len() <= m,
                "Node {} has {} neighbors, max is {}",
                node,
                neighbors.len(),
                m
            );
        }
    }

    #[test]
    fn test_self_query_in_results() {
        let params = DiskANNParams {
            m: 32,
            ef_construction: 100,
            alpha: 1.2,
            ef_search: 100,
            seed: None,
        };
        let dim = 16;
        let n = 100u32;
        let mut index = DiskANNIndex::new(dim, params).unwrap();
        for i in 0..n {
            index.add(i, hashed_unit_vec(i, dim)).unwrap();
        }
        index.build().unwrap();
        // Querying with an indexed vector must return that vector itself
        // at near-zero distance.
        for &i in &[0, 1, n / 2, n - 1] {
            let v = hashed_unit_vec(i, dim);
            let results = index.search(&v, 5, 100).unwrap();
            let found = results.iter().any(|&(id, dist)| id == i && dist < 1e-4);
            assert!(
                found,
                "Self-query doc_id={} not found in top-5: {:?}",
                i, results
            );
        }
    }

    #[test]
    fn test_neighbor_ids_in_bounds() {
        let params = DiskANNParams {
            m: 8,
            ef_construction: 30,
            alpha: 1.2,
            ef_search: 30,
            seed: None,
        };
        let mut index = DiskANNIndex::new(4, params).unwrap();
        let n = 25u32;
        for i in 0..n {
            let v = vec![i as f32, (i as f32) * 0.4, 1.0, 0.0];
            index.add(i, v).unwrap();
        }
        index.build().unwrap();
        for (node, neighbors) in index.adj.iter().enumerate() {
            for &nbr in neighbors {
                assert!(nbr < n, "Node {} has out-of-bounds neighbor {}", node, nbr);
            }
        }
    }

    #[test]
    fn test_medoid_is_not_always_zero() {
        let params = DiskANNParams {
            m: 4,
            ef_construction: 20,
            alpha: 1.2,
            ef_search: 20,
            seed: None,
        };
        let mut index = DiskANNIndex::new(2, params).unwrap();
        // Node 0 is a far outlier; the medoid must be one of the cluster
        // points near (1, 0).
        index.add(0, vec![100.0, 100.0]).unwrap();
        index.add(1, vec![1.0, 0.0]).unwrap();
        index.add(2, vec![1.1, 0.0]).unwrap();
        index.add(3, vec![0.9, 0.0]).unwrap();
        index.add(4, vec![1.0, 0.1]).unwrap();
        index.add(5, vec![1.0, -0.1]).unwrap();
        index.build().unwrap();
        let medoid_idx = index.start_node as usize;
        let medoid_vec = &index.vectors[medoid_idx * 2..(medoid_idx + 1) * 2];
        assert!(
            medoid_vec[0] < 50.0,
            "medoid should not be the outlier at [100,100]; got vec={:?}",
            medoid_vec
        );
    }

    #[test]
    fn test_reverse_edges_improve_recall() {
        let params = DiskANNParams {
            m: 4,
            ef_construction: 20,
            alpha: 1.2,
            ef_search: 20,
            seed: None,
        };
        let n = 20u32;
        let mut index = DiskANNIndex::new(4, params).unwrap();
        for i in 0..n {
            let v = vec![i as f32 * 0.1, 0.0, 0.0, 0.0];
            index.add(i, v).unwrap();
        }
        index.build().unwrap();
        // Without reverse edges many nodes would be unreachable sinks.
        let isolated: Vec<_> = index
            .adj
            .iter()
            .enumerate()
            .filter(|(_, nbrs)| nbrs.is_empty())
            .map(|(i, _)| i)
            .collect();
        assert!(
            isolated.len() <= 1,
            "too many isolated nodes ({}): reverse edges may be missing",
            isolated.len()
        );
        let mut self_hits = 0;
        for i in 0..n {
            let q = vec![i as f32 * 0.1, 0.0, 0.0, 0.0];
            let results = index.search(&q, 3, 20).unwrap();
            if results.iter().any(|(id, _)| *id == i) {
                self_hits += 1;
            }
        }
        assert!(
            self_hits >= (n * 8 / 10),
            "self-recall too low ({}/{}): reverse edges may be missing",
            self_hits,
            n
        );
    }
}