use crate::error::{IndexError, IndexResult};
use crate::metric::Metric;
use crate::PointId;
use alloc::collections::BinaryHeap;
use alloc::vec::Vec;
use core::cmp::Ordering;
use rand::Rng;
use std::collections::{HashMap, HashSet};
/// Tuning parameters for [`HnswIndex`].
#[derive(Debug, Clone, Copy)]
pub struct HnswConfig {
    /// Maximum number of neighbours kept per node on layers above 0. Must be > 0.
    pub m: usize,
    /// Maximum number of neighbours kept per node on layer 0.
    pub m_max0: usize,
    /// Size of the candidate list explored on each layer while inserting a point.
    /// Must be >= `m`.
    pub ef_construction: usize,
    /// Default candidate-list size used by `search` on layer 0 (it is raised to at
    /// least `k` per query). Must be > 0.
    pub ef_search: usize,
    /// Scale for random level assignment: a node's top layer is
    /// `floor(-ln(U) * level_lambda)` with `U` uniform in `[f32::MIN_POSITIVE, 1)`.
    /// Must be finite and positive.
    pub level_lambda: f32,
}
impl Default for HnswConfig {
    /// Default parameters: `m = 16`, `m_max0 = 2 * m`, `ef_construction = 200`,
    /// `ef_search = 50` and `level_lambda = 1 / ln(m)`.
    fn default() -> Self {
        const M: usize = 16;
        Self {
            level_lambda: (M as f32).ln().recip(),
            ef_search: 50,
            ef_construction: 200,
            m_max0: M * 2,
            m: M,
        }
    }
}
impl HnswConfig {
    /// Checks that the parameters describe a usable index.
    ///
    /// # Errors
    ///
    /// Returns `IndexError::InvalidConfig` when `m` or `m_max0` is zero, when
    /// `ef_construction < m`, when `ef_search` is zero, or when `level_lambda` is
    /// not a finite positive number.
    fn validate(&self) -> IndexResult<()> {
        if self.m == 0 {
            return Err(IndexError::InvalidConfig("m must be > 0"));
        }
        // With a zero layer-0 budget no base-layer edges are ever created, so
        // searches could never move past the region reached on upper layers.
        if self.m_max0 == 0 {
            return Err(IndexError::InvalidConfig("m_max0 must be > 0"));
        }
        if self.ef_construction < self.m {
            return Err(IndexError::InvalidConfig("ef_construction must be >= m"));
        }
        if self.ef_search == 0 {
            return Err(IndexError::InvalidConfig("ef_search must be > 0"));
        }
        if !self.level_lambda.is_finite() || self.level_lambda <= 0.0 {
            return Err(IndexError::InvalidConfig(
                "level_lambda must be finite and positive",
            ));
        }
        Ok(())
    }
}
/// A search hit: a point id together with its distance to the query.
///
/// Neighbours are ordered by `distance` using [`f32::total_cmp`], with ties
/// broken by `id`, so ascending order means "closest first". Because
/// `total_cmp` is a total order, sorting stays well-defined even if a metric
/// produces NaN: a positive NaN sorts after every finite distance. Equality
/// agrees with this ordering.
#[derive(Debug, Clone, Copy)]
pub struct Neighbor {
    /// Identifier of the matched point.
    pub id: PointId,
    /// Distance from the query to the point, as reported by the metric.
    pub distance: f32,
}

impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        // Derived equality on `f32` is not reflexive for NaN, which would violate
        // the `Eq` contract below; tie equality to the total order instead.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        // `partial_cmp(..).unwrap_or(Equal)` would treat NaN as equal to every
        // distance, breaking transitivity (and `sort` may panic on such an Ord).
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.id.cmp(&other.id))
    }
}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// `BinaryHeap` adapter that keeps `Neighbor`'s own order, so the farthest
/// neighbour sits on top of the (max-)heap and can be peeked or evicted first.
#[derive(Debug, Clone, Copy, PartialEq)]
struct MaxHeapEntry(Neighbor);

impl Eq for MaxHeapEntry {}

impl Ord for MaxHeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (&self.0, &other.0);
        Ord::cmp(lhs, rhs)
    }
}

impl PartialOrd for MaxHeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// `BinaryHeap` adapter that reverses `Neighbor`'s order, so the closest
/// neighbour sits on top of the (max-)heap and is popped first.
#[derive(Debug, Clone, Copy, PartialEq)]
struct MinHeapEntry(Neighbor);

impl Eq for MinHeapEntry {}

impl Ord for MinHeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.cmp(&other.0).reverse()
    }
}

impl PartialOrd for MinHeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// A stored point together with its per-layer adjacency lists.
struct Node<P> {
    /// The point itself, as passed to `insert`.
    point: P,
    /// `neighbors[l]` holds the ids linked from this node on layer `l`; it is
    /// created with one (initially empty) list per layer from 0 up to the node's
    /// top layer.
    neighbors: Vec<Vec<PointId>>,
}
/// Hierarchical Navigable Small World index over points of type `P`, with
/// distances computed by the metric `M`.
pub struct HnswIndex<P, M>
where
    M: Metric<Point = P>,
{
    /// Construction and search parameters (validated by `new`).
    config: HnswConfig,
    /// Distance function and dimension probe for points.
    metric: M,
    /// Every stored node, keyed by its id.
    nodes: HashMap<PointId, Node<P>>,
    /// Id and top layer of the node where searches start; `None` while empty.
    entry_point: Option<(PointId, usize)>,
    /// Dimension fixed by the first inserted point; `None` while empty.
    dim: Option<usize>,
}
impl<P, M> HnswIndex<P, M>
where
    M: Metric<Point = P>,
{
    /// Creates an empty index that uses `metric` for all distance computations.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `HnswConfig::validate` when `config` is
    /// rejected.
    pub fn new(config: HnswConfig, metric: M) -> IndexResult<Self> {
        config.validate()?;
        Ok(Self {
            config,
            metric,
            nodes: HashMap::new(),
            entry_point: None,
            dim: None,
        })
    }

    /// Number of points stored in the index.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if no point has been inserted.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns the point stored under `id`, or `None` if there is none.
    pub fn get(&self, id: PointId) -> Option<&P> {
        self.nodes.get(&id).map(|n| &n.point)
    }

    /// Inserts `point` under `id` and links it into the graph.
    ///
    /// The first inserted point fixes the index dimension (as reported by
    /// `Metric::dim`). The node's top layer is drawn by `random_level`. The node
    /// is then connected on every layer it shares with the current entry point,
    /// from the highest shared layer down to layer 0. A reverse edge is added to
    /// each chosen neighbour, pruning that neighbour's list if it overflows. A
    /// node whose top layer exceeds the entry point's becomes the new entry point.
    ///
    /// # Errors
    ///
    /// * `IndexError::DuplicateId` if `id` is already stored.
    /// * `IndexError::DimensionMismatch` if the point's dimension differs from the
    ///   index dimension.
    pub fn insert(&mut self, id: PointId, point: P) -> IndexResult<()> {
        if self.nodes.contains_key(&id) {
            return Err(IndexError::DuplicateId(id));
        }
        let point_dim = self.metric.dim(&point);
        match self.dim {
            None => self.dim = Some(point_dim),
            Some(d) if d != point_dim => {
                return Err(IndexError::DimensionMismatch {
                    expected: d,
                    actual: point_dim,
                });
            }
            Some(_) => {}
        }

        let level = self.random_level(&mut rand::thread_rng());
        let neighbors = (0..=level)
            .map(|lvl| Vec::with_capacity(self.max_neighbors(lvl)))
            .collect();
        self.nodes.insert(id, Node { point, neighbors });

        let Some((entry_id, entry_level)) = self.entry_point else {
            // First node: it becomes the entry point and has nothing to link to.
            self.entry_point = Some((id, level));
            return Ok(());
        };

        // Greedy descent through the layers that only the entry point reaches.
        let mut nearest = entry_id;
        for lvl in ((level + 1)..=entry_level).rev() {
            nearest = self.greedy_search_one_level(id, nearest, lvl);
        }

        // Link the new node on every layer it shares with the entry point.
        for lvl in (0..=level.min(entry_level)).rev() {
            let candidates = self.search_layer(id, &[nearest], lvl, self.config.ef_construction);
            let selected =
                self.select_neighbors_heuristic(candidates, self.max_neighbors(lvl), true);
            // `selected` is sorted closest-first, so its head seeds the next layer down.
            if let Some(closest) = selected.first() {
                nearest = closest.id;
            }
            let new_neighbors: Vec<PointId> = selected.into_iter().map(|n| n.id).collect();
            // Adding reverse edges reads only node points, never this node's own
            // list, so the edges can be added before the list is stored (no clone).
            for &neighbor in &new_neighbors {
                self.add_back_edge(neighbor, id, lvl);
            }
            self.nodes
                .get_mut(&id)
                .expect("new node was inserted above")
                .neighbors[lvl] = new_neighbors;
        }

        if level > entry_level {
            self.entry_point = Some((id, level));
        }
        Ok(())
    }

    /// Returns up to `k` approximate nearest neighbours of `query`, closest first,
    /// using the configured `ef_search` as the candidate-list size.
    pub fn search(&self, query: &P, k: usize) -> Vec<Neighbor> {
        self.search_with_ef(query, k, self.config.ef_search)
    }

    /// Like `search`, but with an explicit candidate-list size `ef` (raised to at
    /// least `k`). Larger values explore more of the graph.
    ///
    /// Returns an empty vector for an empty index or for `k == 0`. The query's
    /// dimension is not checked against the index dimension here.
    pub fn search_with_ef(&self, query: &P, k: usize, ef: usize) -> Vec<Neighbor> {
        let Some((entry_id, entry_level)) = self.entry_point else {
            return Vec::new();
        };
        if k == 0 {
            return Vec::new();
        }
        let ef = ef.max(k);
        // Greedy descent: on each upper layer, move to the locally closest node.
        let nearest_id = (1..=entry_level)
            .rev()
            .fold(entry_id, |cur, lvl| self.greedy_search_one_level_query(query, cur, lvl));
        let mut found = self.search_layer_query(query, &[nearest_id], 0, ef);
        found.sort();
        found.truncate(k);
        found
    }

    /// Maximum neighbour-list length on `level`: `m_max0` on layer 0, `m` above it.
    fn max_neighbors(&self, level: usize) -> usize {
        if level == 0 {
            self.config.m_max0
        } else {
            self.config.m
        }
    }

    /// Neighbour ids of node `id` on `level`, or an empty slice if the node does
    /// not reach that layer.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not stored in the index.
    fn neighbors_at(&self, id: PointId, level: usize) -> &[PointId] {
        self.nodes[&id]
            .neighbors
            .get(level)
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Draws a node's top layer as `floor(-ln(U) * level_lambda)`, with `U`
    /// uniform in `[f32::MIN_POSITIVE, 1)`. The probability of reaching layer
    /// `L` or higher is therefore `exp(-L / level_lambda)`. The lower bound on
    /// `U` keeps `ln(U)` finite.
    fn random_level<R: Rng>(&self, rng: &mut R) -> usize {
        let r: f32 = rng.gen_range(f32::MIN_POSITIVE..1.0);
        (-r.ln() * self.config.level_lambda).floor() as usize
    }

    /// Greedy search on `level` towards the stored node `query_id`, starting at `entry`.
    fn greedy_search_one_level(&self, query_id: PointId, entry: PointId, level: usize) -> PointId {
        let query = &self.nodes[&query_id].point;
        self.greedy_search_one_level_query(query, entry, level)
    }

    /// Hill-climbs on `level` from `entry`. Each round scans the current node's
    /// neighbours and moves to one whenever it is strictly closer to `query`. The
    /// climb stops when a full scan brings no improvement, and returns the local
    /// minimum reached.
    fn greedy_search_one_level_query(&self, query: &P, entry: PointId, level: usize) -> PointId {
        let mut current = entry;
        let mut current_dist = self.metric.distance(query, &self.nodes[&entry].point);
        loop {
            let mut improved = false;
            // The slice borrows `self`, not `current`, so `current` may be updated mid-scan.
            for &nbr in self.neighbors_at(current, level) {
                let d = self.metric.distance(query, &self.nodes[&nbr].point);
                if d < current_dist {
                    current_dist = d;
                    current = nbr;
                    improved = true;
                }
            }
            if !improved {
                return current;
            }
        }
    }

    /// Beam search on `level` for the stored node `query_id`, never returning the
    /// node itself.
    fn search_layer(
        &self,
        query_id: PointId,
        entry_points: &[PointId],
        level: usize,
        ef: usize,
    ) -> Vec<Neighbor> {
        let query = &self.nodes[&query_id].point;
        self.search_layer_query_with_exclude(query, entry_points, level, ef, Some(query_id))
    }

    /// Beam search on `level` for an arbitrary query point.
    fn search_layer_query(
        &self,
        query: &P,
        entry_points: &[PointId],
        level: usize,
        ef: usize,
    ) -> Vec<Neighbor> {
        self.search_layer_query_with_exclude(query, entry_points, level, ef, None)
    }

    /// Best-first beam search on one layer.
    ///
    /// Candidates are expanded closest-first, and the `ef` closest points seen are
    /// kept. The search stops once the closest unexpanded candidate is farther than
    /// the worst kept result (while `ef` results are held). The node `exclude`, if
    /// set, is never visited or returned. The returned vector is in no particular
    /// order.
    fn search_layer_query_with_exclude(
        &self,
        query: &P,
        entry_points: &[PointId],
        level: usize,
        ef: usize,
        exclude: Option<PointId>,
    ) -> Vec<Neighbor> {
        // The visited set can never exceed the node count; `saturating_mul` also
        // guards against overflow for a huge caller-supplied `ef`.
        let mut visited: HashSet<PointId> =
            HashSet::with_capacity(ef.saturating_mul(2).min(self.nodes.len()));
        // Candidates still to expand, closest on top.
        let mut frontier: BinaryHeap<MinHeapEntry> = BinaryHeap::new();
        // Best results so far, worst on top so it can be evicted cheaply.
        let mut results: BinaryHeap<MaxHeapEntry> = BinaryHeap::new();

        for &ep in entry_points {
            if Some(ep) == exclude || !visited.insert(ep) {
                continue;
            }
            let n = Neighbor {
                id: ep,
                distance: self.metric.distance(query, &self.nodes[&ep].point),
            };
            frontier.push(MinHeapEntry(n));
            results.push(MaxHeapEntry(n));
        }

        while let Some(MinHeapEntry(closest)) = frontier.pop() {
            if results.len() >= ef {
                if let Some(MaxHeapEntry(worst)) = results.peek() {
                    if closest.distance > worst.distance {
                        break;
                    }
                }
            }
            for &nbr in self.neighbors_at(closest.id, level) {
                if Some(nbr) == exclude || !visited.insert(nbr) {
                    continue;
                }
                let d = self.metric.distance(query, &self.nodes[&nbr].point);
                let should_push = match results.peek() {
                    Some(MaxHeapEntry(worst)) => d < worst.distance || results.len() < ef,
                    None => true,
                };
                if should_push {
                    let cand = Neighbor {
                        id: nbr,
                        distance: d,
                    };
                    frontier.push(MinHeapEntry(cand));
                    results.push(MaxHeapEntry(cand));
                    if results.len() > ef {
                        results.pop();
                    }
                }
            }
        }
        results.into_iter().map(|MaxHeapEntry(n)| n).collect()
    }

    /// Chooses up to `m` neighbours from `candidates` with a diversity heuristic.
    ///
    /// Candidates are visited closest-first. A candidate is selected only if it
    /// is strictly closer to the base point (its `distance`) than to every
    /// neighbour already selected; otherwise it is set aside. When `keep_pruned`
    /// is true, set-aside candidates back-fill the result, closest first, until
    /// it holds `m` entries. The result is in selection order, which starts with
    /// the closest candidate.
    fn select_neighbors_heuristic(
        &self,
        mut candidates: Vec<Neighbor>,
        m: usize,
        keep_pruned: bool,
    ) -> Vec<Neighbor> {
        candidates.sort();
        let mut selected: Vec<Neighbor> = Vec::with_capacity(m);
        let mut discarded: Vec<Neighbor> = Vec::new();
        for cand in candidates {
            if selected.len() >= m {
                break;
            }
            let cand_point = &self.nodes[&cand.id].point;
            let dominated = selected.iter().any(|r| {
                self.metric.distance(cand_point, &self.nodes[&r.id].point) <= cand.distance
            });
            if dominated {
                discarded.push(cand);
            } else {
                selected.push(cand);
            }
        }
        if keep_pruned {
            let room = m.saturating_sub(selected.len());
            selected.extend(discarded.into_iter().take(room));
        }
        selected
    }

    /// Adds the edge `from -> to` on `level`.
    ///
    /// If `from`'s list then exceeds the layer's budget, the list is re-pruned
    /// with the selection heuristic (back-filling pruned entries). Does nothing if
    /// the edge already exists.
    ///
    /// # Panics
    ///
    /// Panics if `from` is not stored in the index.
    fn add_back_edge(&mut self, from: PointId, to: PointId, level: usize) {
        let cap = self.max_neighbors(level);
        let mut links: Vec<PointId> = {
            let node = self
                .nodes
                .get_mut(&from)
                .expect("from id exists in nodes map");
            if node.neighbors.len() <= level {
                node.neighbors.resize_with(level + 1, Vec::new);
            }
            if node.neighbors[level].contains(&to) {
                return;
            }
            // Move the list out so it can be edited while `self` is borrowed immutably.
            core::mem::take(&mut node.neighbors[level])
        };
        links.push(to);
        if links.len() > cap {
            let from_point = &self.nodes[&from].point;
            let scored: Vec<Neighbor> = links
                .iter()
                .map(|&cid| Neighbor {
                    id: cid,
                    distance: self.metric.distance(from_point, &self.nodes[&cid].point),
                })
                .collect();
            links = self
                .select_neighbors_heuristic(scored, cap, true)
                .into_iter()
                .map(|n| n.id)
                .collect();
        }
        self.nodes
            .get_mut(&from)
            .expect("from still present")
            .neighbors[level] = links;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metric::L2;
    use rand::{rngs::StdRng, SeedableRng};
    use rand_distr::{Distribution, StandardNormal};

    /// Fresh index with the default configuration and the L2 metric.
    fn make_index() -> HnswIndex<Vec<f32>, L2> {
        HnswIndex::new(HnswConfig::default(), L2).expect("default config valid")
    }

    /// `dim` independent standard-normal samples drawn from `rng`, in order.
    fn gaussian_vector(rng: &mut StdRng, dim: usize) -> Vec<f32> {
        (0..dim).map(|_| StandardNormal.sample(&mut *rng)).collect()
    }

    /// Recall@`k` for one query. Returns the fraction of the exact `k` nearest
    /// points, found by brute force over `points` (a point's id is its
    /// position), that `idx.search` also returns.
    // `Metric::Point` is `Vec<f32>`, so the query has to be passed as `&Vec<f32>`.
    #[allow(clippy::ptr_arg)]
    fn recall_at_k(
        idx: &HnswIndex<Vec<f32>, L2>,
        points: &[Vec<f32>],
        query: &Vec<f32>,
        k: usize,
    ) -> f32 {
        let approx_ids: HashSet<PointId> =
            idx.search(query, k).into_iter().map(|n| n.id).collect();
        let mut exact: Vec<Neighbor> = points
            .iter()
            .enumerate()
            .map(|(i, p)| Neighbor {
                id: i as u64,
                distance: L2.distance(query, p),
            })
            .collect();
        exact.sort();
        let exact_ids: HashSet<PointId> = exact.into_iter().take(k).map(|n| n.id).collect();
        approx_ids.intersection(&exact_ids).count() as f32 / k as f32
    }

    #[test]
    fn empty_index_search_returns_empty() {
        let hits = make_index().search(&vec![1.0, 2.0, 3.0], 5);
        assert!(hits.is_empty());
    }

    #[test]
    fn single_point_returns_itself() {
        let mut idx = make_index();
        let point = vec![1.0, 2.0, 3.0];
        idx.insert(42, point.clone()).unwrap();
        let results = idx.search(&point, 5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 42);
        assert_eq!(results[0].distance, 0.0);
    }

    #[test]
    fn duplicate_id_rejected() {
        let mut idx = make_index();
        idx.insert(7, vec![0.0, 0.0]).unwrap();
        let err = idx.insert(7, vec![1.0, 1.0]).unwrap_err();
        assert!(matches!(err, IndexError::DuplicateId(7)));
    }

    #[test]
    fn dim_mismatch_rejected() {
        let mut idx = make_index();
        idx.insert(0, vec![0.0_f32; 64]).unwrap();
        let err = idx.insert(1, vec![0.0_f32; 32]).unwrap_err();
        let is_expected = matches!(
            err,
            IndexError::DimensionMismatch {
                expected: 64,
                actual: 32
            }
        );
        assert!(is_expected, "expected DimensionMismatch, got {err:?}");
    }

    #[test]
    fn nearest_neighbor_is_correct_on_grid() {
        let mut idx = make_index();
        // Row-major 5x5 grid: the point (x, y) gets id 5 * x + y.
        let grid = (0..5).flat_map(|x| (0..5).map(move |y| (x, y)));
        for (id, (x, y)) in grid.enumerate() {
            idx.insert(id as u64, vec![x as f32, y as f32]).unwrap();
        }
        let res = idx.search(&vec![2.1, 2.1], 1);
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].id, 12, "nearest to (2.1, 2.1) should be (2,2)");
    }

    #[test]
    fn k_nearest_neighbors_sorted_by_distance() {
        let mut idx = make_index();
        for i in 0..20u64 {
            idx.insert(i, vec![i as f32, 0.0]).unwrap();
        }
        let res = idx.search(&vec![10.0, 0.0], 5);
        assert_eq!(res.len(), 5);
        for w in res.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
        assert_eq!(res[0].id, 10);
    }

    #[test]
    fn recall_against_brute_force_random_data() {
        const N: usize = 500;
        const DIM: usize = 16;
        const K: usize = 10;
        const N_QUERIES: usize = 10;

        let mut rng = StdRng::seed_from_u64(42);
        let points: Vec<Vec<f32>> = (0..N).map(|_| gaussian_vector(&mut rng, DIM)).collect();
        let mut idx = make_index();
        for (i, p) in points.iter().enumerate() {
            idx.insert(i as u64, p.clone()).unwrap();
        }

        let total_recall = (0..N_QUERIES)
            .map(|_| {
                let query = gaussian_vector(&mut rng, DIM);
                recall_at_k(&idx, &points, &query, K)
            })
            .fold(0.0f32, |acc, r| acc + r);
        let avg_recall = total_recall / N_QUERIES as f32;
        assert!(
            avg_recall >= 0.95,
            "recall {avg_recall:.3} below threshold; check HNSW correctness"
        );
    }

    #[test]
    #[ignore = "slow: ~400 s in debug; run with `cargo test -- --ignored`"]
    fn recall_at_realistic_scale() {
        const N: usize = 5000;
        const DIM: usize = 64;
        const K: usize = 10;
        const N_QUERIES: usize = 20;
        const SEEDS: [u64; 5] = [1, 2, 3, 4, 5];

        let mut total_recall = 0.0f32;
        for seed in SEEDS {
            let mut rng = StdRng::seed_from_u64(seed);
            let points: Vec<Vec<f32>> =
                (0..N).map(|_| gaussian_vector(&mut rng, DIM)).collect();
            let mut idx = make_index();
            for (i, p) in points.iter().enumerate() {
                idx.insert(i as u64, p.clone()).unwrap();
            }
            for _ in 0..N_QUERIES {
                let query = gaussian_vector(&mut rng, DIM);
                total_recall += recall_at_k(&idx, &points, &query, K);
            }
        }

        let mean_recall = total_recall / (SEEDS.len() * N_QUERIES) as f32;
        println!("recall_at_realistic_scale: mean_recall = {mean_recall:.4}");
        assert!(
            mean_recall >= 0.90,
            "mean recall {mean_recall:.3} below 0.90; HNSW graph quality degraded"
        );
        assert!(
            mean_recall < 0.999,
            "mean recall {mean_recall:.3} implausibly perfect; test is no longer exercising ANN"
        );
    }
}