use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
/// Distance function used to compare two vectors.
///
/// Every variant returns a value for which *smaller means closer*, so callers
/// can rank candidates the same way regardless of the metric.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Euclidean distance: `sqrt(sum((a[i] - b[i])^2))`.
    L2,
    /// Cosine distance: `1 - dot(a, b) / (|a| * |b|)`, or `f32::INFINITY`
    /// when the norm product is zero.
    Cosine,
    /// Negated dot product, so that a larger dot product ranks as closer.
    Dot,
}

impl DistanceMetric {
    /// Computes the distance between `a` and `b` under this metric.
    ///
    /// Both slices are expected to have the same length.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions, panics when the lengths differ. In
    /// every build, panics when `b` is shorter than `a`; when `b` is longer,
    /// its extra components are ignored.
    pub fn compute(self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len(), "vector dim mismatch in HNSW distance");
        // One up-front slice keeps the "b must cover a" check of the indexed
        // loops while letting the zipped iterators below run without
        // per-element bounds checks.
        let b = &b[..a.len()];
        match self {
            DistanceMetric::L2 => a
                .iter()
                .zip(b)
                .fold(0.0f32, |sum, (x, y)| {
                    let d = x - y;
                    sum + d * d
                })
                .sqrt(),
            DistanceMetric::Cosine => {
                let mut dot = 0.0f32;
                let mut na = 0.0f32;
                let mut nb = 0.0f32;
                for (&x, &y) in a.iter().zip(b) {
                    dot += x * y;
                    na += x * x;
                    nb += y * y;
                }
                // Multiply the two norms rather than taking the root of
                // `na * nb`: the product of the squared norms leaves the f32
                // range (underflowing to 0 or overflowing to inf) long before
                // either norm does, which reported INFINITY / 1.0 for
                // perfectly aligned small / large vectors.
                let denom = na.sqrt() * nb.sqrt();
                if denom == 0.0 {
                    f32::INFINITY
                } else {
                    1.0 - dot / denom
                }
            }
            DistanceMetric::Dot => -a.iter().zip(b).fold(0.0f32, |dot, (x, y)| dot + x * y),
        }
    }
}
/// Adjacency lists of a single graph node, one list per layer.
#[derive(Debug, Clone, Default)]
pub struct Node {
    /// `layers[l]` holds the ids this node links to on layer `l`.
    pub layers: Vec<Vec<i64>>,
}

impl Node {
    /// Returns the index of the highest layer this node has a list for.
    ///
    /// A node without any layer list (for example `Node::default()`) reports
    /// `0` instead of underflowing, matching how `from_persisted_nodes`
    /// derives the same value with `saturating_sub`.
    pub fn max_layer(&self) -> usize {
        self.layers.len().saturating_sub(1)
    }
}
/// Tuning parameters for building and querying the graph.
#[derive(Debug, Clone, Copy)]
pub struct HnswParams {
    /// How many of the nearest candidates a new node is linked to per layer;
    /// also the neighbor-list cap on every layer above 0.
    pub m: usize,
    /// Neighbor-list cap on layer 0.
    pub m_max0: usize,
    /// Candidate-list size used while inserting.
    pub ef_construction: usize,
    /// Lower bound for the layer-0 candidate-list size while searching.
    pub ef_search: usize,
    /// Multiplier applied to `-ln(u)` when drawing a node's layer.
    pub m_l: f32,
}

impl Default for HnswParams {
    /// `m = 16`, `m_max0 = 2 * m`, `ef_construction = 200`, `ef_search = 50`
    /// and `m_l = 1 / ln(m)`.
    fn default() -> Self {
        const M: usize = 16;
        HnswParams {
            m: M,
            m_max0: M * 2,
            ef_construction: 200,
            ef_search: 50,
            m_l: (M as f32).ln().recip(),
        }
    }
}
#[derive(Debug, Clone)]
pub struct HnswIndex {
pub params: HnswParams,
pub distance: DistanceMetric,
pub entry_point: Option<i64>,
pub top_layer: usize,
pub nodes: HashMap<i64, Node>,
rng_state: u64,
}
impl HnswIndex {
pub fn new(distance: DistanceMetric, seed: u64) -> Self {
let seed = if seed == 0 { 0x9E3779B97F4A7C15 } else { seed };
Self {
params: HnswParams::default(),
distance,
entry_point: None,
top_layer: 0,
nodes: HashMap::new(),
rng_state: seed,
}
}
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
pub fn len(&self) -> usize {
self.nodes.len()
}
pub fn serialize_nodes(&self) -> Vec<(i64, Vec<Vec<i64>>)> {
let mut out: Vec<(i64, Vec<Vec<i64>>)> = self
.nodes
.iter()
.map(|(id, n)| (*id, n.layers.clone()))
.collect();
out.sort_by_key(|(id, _)| *id);
out
}
pub fn from_persisted_nodes<I>(distance: DistanceMetric, seed: u64, nodes: I) -> Self
where
I: IntoIterator<Item = (i64, Vec<Vec<i64>>)>,
{
let mut idx = Self::new(distance, seed);
let mut top_layer = 0usize;
let mut entry_point: Option<i64> = None;
for (id, layers) in nodes {
let max_layer = layers.len().saturating_sub(1);
if max_layer > top_layer || entry_point.is_none() {
top_layer = max_layer;
entry_point = Some(id);
}
idx.nodes.insert(id, Node { layers });
}
idx.top_layer = top_layer;
idx.entry_point = entry_point;
idx
}
pub fn insert<F>(&mut self, node_id: i64, vec: &[f32], get_vec: F)
where
F: Fn(i64) -> Vec<f32>,
{
if self.nodes.contains_key(&node_id) {
return;
}
if self.is_empty() {
self.nodes.insert(
node_id,
Node {
layers: vec![Vec::new()],
},
);
self.entry_point = Some(node_id);
self.top_layer = 0;
return;
}
let target_layer = self.pick_layer();
let new_node = Node {
layers: vec![Vec::new(); target_layer + 1],
};
self.nodes.insert(node_id, new_node);
let mut entry = self.entry_point.expect("non-empty index has entry point");
for layer in (target_layer + 1..=self.top_layer).rev() {
let nearest = self.search_layer(vec, &[entry], 1, layer, &get_vec);
if let Some((_, id)) = nearest.into_iter().next() {
entry = id;
}
}
let mut entries = vec![entry];
for layer in (0..=target_layer).rev() {
let candidates =
self.search_layer(vec, &entries, self.params.ef_construction, layer, &get_vec);
let m_max = if layer == 0 {
self.params.m_max0
} else {
self.params.m
};
let neighbors: Vec<i64> = candidates
.iter()
.take(self.params.m)
.map(|(_, id)| *id)
.collect();
self.nodes.get_mut(&node_id).expect("just inserted").layers[layer] = neighbors.clone();
for &nb in &neighbors {
let nb_layers = &mut self.nodes.get_mut(&nb).expect("neighbor must exist").layers;
if layer >= nb_layers.len() {
continue;
}
nb_layers[layer].push(node_id);
if nb_layers[layer].len() > m_max {
let nb_vec = get_vec(nb);
let mut by_dist: Vec<(f32, i64)> = nb_layers[layer]
.iter()
.map(|id| (self.distance.compute(&nb_vec, &get_vec(*id)), *id))
.collect();
by_dist
.sort_by(|(da, _), (db, _)| da.partial_cmp(db).unwrap_or(Ordering::Equal));
by_dist.truncate(m_max);
nb_layers[layer] = by_dist.into_iter().map(|(_, id)| id).collect();
}
}
entries = candidates.into_iter().map(|(_, id)| id).collect();
}
if target_layer > self.top_layer {
self.top_layer = target_layer;
self.entry_point = Some(node_id);
}
}
pub fn search<F>(&self, query: &[f32], k: usize, get_vec: F) -> Vec<i64>
where
F: Fn(i64) -> Vec<f32>,
{
if self.is_empty() || k == 0 {
return Vec::new();
}
let mut entry = self.entry_point.expect("non-empty index has entry point");
for layer in (1..=self.top_layer).rev() {
let nearest = self.search_layer(query, &[entry], 1, layer, &get_vec);
if let Some((_, id)) = nearest.into_iter().next() {
entry = id;
}
}
let ef = self.params.ef_search.max(k);
let candidates = self.search_layer(query, &[entry], ef, 0, &get_vec);
candidates.into_iter().take(k).map(|(_, id)| id).collect()
}
fn search_layer<F>(
&self,
query: &[f32],
entries: &[i64],
ef: usize,
layer: usize,
get_vec: &F,
) -> Vec<(f32, i64)>
where
F: Fn(i64) -> Vec<f32>,
{
let mut visited: HashSet<i64> = HashSet::with_capacity(ef * 2);
let mut candidates: BinaryHeap<MinHeapItem> = BinaryHeap::with_capacity(ef * 2);
let mut results: BinaryHeap<MaxHeapItem> = BinaryHeap::with_capacity(ef);
for &id in entries {
if !visited.insert(id) {
continue;
}
let d = self.distance.compute(query, &get_vec(id));
candidates.push(MinHeapItem { dist: d, id });
results.push(MaxHeapItem { dist: d, id });
}
while let Some(MinHeapItem {
dist: c_dist,
id: c_id,
}) = candidates.pop()
{
if let Some(worst) = results.peek() {
if results.len() >= ef && c_dist > worst.dist {
break;
}
}
let neighbors = self
.nodes
.get(&c_id)
.and_then(|n| n.layers.get(layer))
.cloned()
.unwrap_or_default();
for nb in neighbors {
if !visited.insert(nb) {
continue;
}
let d = self.distance.compute(query, &get_vec(nb));
let admit = if results.len() < ef {
true
} else {
d < results.peek().unwrap().dist
};
if admit {
candidates.push(MinHeapItem { dist: d, id: nb });
results.push(MaxHeapItem { dist: d, id: nb });
if results.len() > ef {
results.pop();
}
}
}
}
let mut out: Vec<(f32, i64)> = Vec::with_capacity(results.len());
while let Some(item) = results.pop() {
out.push((item.dist, item.id));
}
out.reverse();
out
}
fn pick_layer(&mut self) -> usize {
let u = self.next_uniform().max(1e-6); let layer = (-u.ln() * self.params.m_l).floor() as usize;
layer.min(self.top_layer + 1)
}
fn next_uniform(&mut self) -> f32 {
let mut x = self.rng_state;
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
self.rng_state = x;
((x >> 40) as u32) as f32 / (1u32 << 24) as f32
}
}
/// Entry of the open-candidate queue.
///
/// The ordering is reversed so that `BinaryHeap`, a max-heap, pops the item
/// with the *smallest* distance first. Ties on distance are broken by id, also
/// reversed.
#[derive(Debug, Clone, Copy)]
struct MinHeapItem {
    dist: f32,
    id: i64,
}

impl PartialEq for MinHeapItem {
    /// Defined through `cmp` so that equality agrees with the ordering and
    /// stays reflexive even for NaN distances, as `Eq` requires.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MinHeapItem {}

impl PartialOrd for MinHeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MinHeapItem {
    /// Uses `f32::total_cmp`, which is a genuine total order. Treating
    /// incomparable (NaN) distances as equal is not transitive and breaks the
    /// contract `BinaryHeap` relies on. Under `total_cmp`, `-0.0` sorts before
    /// `+0.0` and a positive NaN sorts after every number.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.id.cmp(&self.id))
    }
}
/// Entry of the kept-results heap.
///
/// Ordered by distance, then id, so that `BinaryHeap`, a max-heap, exposes the
/// *farthest* kept result at the top.
#[derive(Debug, Clone, Copy)]
struct MaxHeapItem {
    dist: f32,
    id: i64,
}

impl PartialEq for MaxHeapItem {
    /// Defined through `cmp` so that equality agrees with the ordering and
    /// stays reflexive even for NaN distances, as `Eq` requires.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxHeapItem {}

impl PartialOrd for MaxHeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxHeapItem {
    /// Uses `f32::total_cmp`, which is a genuine total order. Treating
    /// incomparable (NaN) distances as equal is not transitive and breaks the
    /// contract `BinaryHeap` relies on. Under `total_cmp`, `-0.0` sorts before
    /// `+0.0` and a positive NaN sorts after every number.
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws `dim` components in `[0, 1)` from a xorshift generator, advancing
    /// `state` once per component.
    fn random_vec(state: &mut u64, dim: usize) -> Vec<f32> {
        (0..dim)
            .map(|_| {
                let mut x = *state;
                x ^= x << 13;
                x ^= x >> 7;
                x ^= x << 17;
                *state = x;
                ((x >> 40) as u32) as f32 / (1u32 << 24) as f32
            })
            .collect()
    }

    /// Exact top-`k` by exhaustive scan; ids are positions in `vectors`.
    fn brute_force_topk(
        vectors: &[Vec<f32>],
        query: &[f32],
        k: usize,
        metric: DistanceMetric,
    ) -> Vec<i64> {
        let mut by_dist: Vec<(f32, i64)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (metric.compute(query, v), i as i64))
            .collect();
        by_dist.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(Ordering::Equal));
        by_dist.into_iter().take(k).map(|(_, id)| id).collect()
    }

    /// Fraction of `baseline` ids that also appear in `hnsw_result`.
    fn recall_at_k(hnsw_result: &[i64], baseline: &[i64]) -> f32 {
        let baseline_set: HashSet<i64> = baseline.iter().copied().collect();
        let hits = hnsw_result
            .iter()
            // The closure receives `&&i64`; dereference once so `contains`
            // gets the `&i64` it can look up in a `HashSet<i64>`.
            .filter(|id| baseline_set.contains(*id))
            .count();
        hits as f32 / baseline.len() as f32
    }

    #[test]
    fn empty_index_returns_empty_search() {
        let idx = HnswIndex::new(DistanceMetric::L2, 42);
        let vectors: Vec<Vec<f32>> = vec![];
        let result = idx.search(&[0.0; 4], 5, |id| vectors[id as usize].clone());
        assert!(result.is_empty());
    }

    #[test]
    fn single_node_returns_only_itself() {
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let v0 = vec![1.0, 2.0, 3.0];
        let vectors = vec![v0.clone()];
        idx.insert(0, &v0, |id| vectors[id as usize].clone());
        let result = idx.search(&[0.0; 3], 5, |id| vectors[id as usize].clone());
        assert_eq!(result, vec![0]);
    }

    #[test]
    fn duplicate_insert_is_noop() {
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let v0 = vec![1.0, 2.0];
        let vectors = vec![v0.clone()];
        idx.insert(0, &v0, |id| vectors[id as usize].clone());
        idx.insert(0, &v0, |id| vectors[id as usize].clone());
        assert_eq!(idx.len(), 1);
    }

    #[test]
    fn k_zero_returns_empty() {
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }
        let result = idx.search(&[0.5, 0.5], 0, |id| vectors[id as usize].clone());
        assert!(result.is_empty());
    }

    #[test]
    fn small_graph_finds_exact_nearest() {
        let vectors: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0],
            vec![10.0, 0.0],
            vec![0.0, 10.0],
            vec![10.0, 10.0],
            vec![5.0, 5.0],
        ];
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }
        let result = idx.search(&[1.0, 1.0], 1, |id| vectors[id as usize].clone());
        assert_eq!(result, vec![0]);
        let result = idx.search(&[5.5, 5.5], 3, |id| vectors[id as usize].clone());
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], 4, "closest to (5.5,5.5) should be id=4");
    }

    #[test]
    fn recall_at_10_is_high_on_random_vectors_l2() {
        let mut state: u64 = 0xDEADBEEF;
        let dim = 8;
        let n = 1000;
        let queries = 20;
        let k = 10;
        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }
        let mut total_recall = 0.0f32;
        for _ in 0..queries {
            let q = random_vec(&mut state, dim);
            let hnsw_top = idx.search(&q, k, |id| vectors[id as usize].clone());
            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::L2);
            total_recall += recall_at_k(&hnsw_top, &baseline);
        }
        let avg_recall = total_recall / queries as f32;
        assert!(
            avg_recall >= 0.95,
            "recall@{k} dropped below 0.95: avg={avg_recall:.3}"
        );
    }

    #[test]
    fn recall_at_10_is_high_on_random_vectors_cosine() {
        let mut state: u64 = 0xC0FFEE;
        let dim = 16;
        let n = 500;
        let queries = 20;
        let k = 10;
        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();
        let mut idx = HnswIndex::new(DistanceMetric::Cosine, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx.insert(i as i64, v, |id| vectors[id as usize].clone());
        }
        let mut total_recall = 0.0f32;
        for _ in 0..queries {
            let q = random_vec(&mut state, dim);
            let hnsw_top = idx.search(&q, k, |id| vectors[id as usize].clone());
            let baseline = brute_force_topk(&vectors, &q, k, DistanceMetric::Cosine);
            total_recall += recall_at_k(&hnsw_top, &baseline);
        }
        let avg_recall = total_recall / queries as f32;
        assert!(
            avg_recall >= 0.95,
            "cosine recall@{k} dropped below 0.95: avg={avg_recall:.3}"
        );
    }

    /// After every insert the entry point must live on the current top layer.
    #[test]
    fn entry_point_promotes_when_higher_layer_node_inserted() {
        let mut state: u64 = 0xABCDEF;
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let dim = 4;
        let mut vectors: Vec<Vec<f32>> = Vec::new();
        for i in 0..50 {
            vectors.push(random_vec(&mut state, dim));
            let v = vectors[i].clone();
            idx.insert(i as i64, &v, |id| vectors[id as usize].clone());
            let entry = idx.entry_point.expect("non-empty");
            let entry_max = idx.nodes[&entry].max_layer();
            assert_eq!(
                entry_max, idx.top_layer,
                "entry-point invariant broken at step {i}: entry {entry} has max_layer {entry_max}, top_layer is {}",
                idx.top_layer
            );
        }
    }

    /// No neighbor list may exceed `m_max0` on layer 0 or `m` above it.
    #[test]
    fn neighbor_lists_respect_m_max() {
        let mut state: u64 = 0x123456;
        let mut idx = HnswIndex::new(DistanceMetric::L2, 42);
        let dim = 4;
        let mut vectors: Vec<Vec<f32>> = Vec::new();
        for i in 0..200 {
            vectors.push(random_vec(&mut state, dim));
            let v = vectors[i].clone();
            idx.insert(i as i64, &v, |id| vectors[id as usize].clone());
        }
        for (id, node) in &idx.nodes {
            for (layer, neighbors) in node.layers.iter().enumerate() {
                let cap = if layer == 0 {
                    idx.params.m_max0
                } else {
                    idx.params.m
                };
                assert!(
                    neighbors.len() <= cap,
                    "node {id} layer {layer} has {} > cap {cap}",
                    neighbors.len()
                );
            }
        }
    }

    /// Two indexes built with the same seed and inputs agree on layer layout.
    #[test]
    fn deterministic_with_fixed_seed() {
        let mut state: u64 = 0x999;
        let dim = 4;
        let n = 50;
        let vectors: Vec<Vec<f32>> = (0..n).map(|_| random_vec(&mut state, dim)).collect();
        let mut idx_a = HnswIndex::new(DistanceMetric::L2, 42);
        let mut idx_b = HnswIndex::new(DistanceMetric::L2, 42);
        for (i, v) in vectors.iter().enumerate() {
            idx_a.insert(i as i64, v, |id| vectors[id as usize].clone());
            idx_b.insert(i as i64, v, |id| vectors[id as usize].clone());
        }
        assert_eq!(idx_a.top_layer, idx_b.top_layer);
        assert_eq!(idx_a.entry_point, idx_b.entry_point);
        assert_eq!(idx_a.nodes.len(), idx_b.nodes.len());
        for (id, node_a) in &idx_a.nodes {
            let node_b = idx_b.nodes.get(id).expect("missing id");
            assert_eq!(node_a.max_layer(), node_b.max_layer(), "id={id}");
        }
    }
}