use std::cmp::{Ordering, Reverse};
use std::collections::{BinaryHeap, HashMap, HashSet};
/// Total-order wrapper around `f64` so distances can be used as `BinaryHeap` keys.
///
/// Both equality and ordering are defined by [`f64::total_cmp`], so `a == b`
/// holds exactly when `a.cmp(&b) == Ordering::Equal`. In particular a NaN
/// compares equal to an identical NaN, and `-0.0` is distinct from (and sorts
/// before) `+0.0`.
#[derive(Debug, Clone, Copy)]
pub(super) struct OrderedF64(pub f64);

// Hand-written instead of derived: the derived impl would use IEEE `==`
// (NaN != NaN, 0.0 == -0.0), which contradicts the `Ord` impl below and makes
// the `Eq` marker a lie.
impl PartialEq for OrderedF64 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for OrderedF64 {}

impl PartialOrd for OrderedF64 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedF64 {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.total_cmp(&other.0)
    }
}
/// Adjacency record for one graph node.
#[derive(Clone)]
struct HnswNode {
// Highest layer this node was assigned on insert. Stored but not read by
// any code in this file, hence the lint suppression.
#[allow(dead_code)]
level: usize,
// `neighbors[l]` lists the ids linked to this node on layer `l`; `insert`
// allocates it with `level + 1` entries.
neighbors: Vec<Vec<u64>>,
}
/// Layered (HNSW-style) proximity graph over `f32` vectors, compared by cosine distance.
#[derive(Clone)]
pub struct HnswGraph {
// Neighbor-list size limit used on every layer above 0.
m: usize,
// Neighbor-list size limit used on layer 0.
m0: usize,
// Beam width handed to `search_layer` while linking a new node in `insert`.
ef_construction: usize,
// Level multiplier, computed in `new` as `1 / ln(m)` and consumed by `assign_level`.
ml: f64,
// Node every search starts from; `None` until the first insert.
entry_point: Option<u64>,
// Level recorded for the current entry point.
max_level: usize,
// Per-node adjacency lists, keyed by node id.
nodes: HashMap<u64, HnswNode>,
// Stored vector for each node id.
vectors: HashMap<u64, Vec<f32>>,
}
impl HnswGraph {
/// Creates an empty graph.
///
/// * `m` - neighbor limit per node on layers above 0; also sets the level
///   multiplier `1 / ln(m)`.
/// * `m0` - neighbor limit per node on layer 0.
/// * `ef_construction` - beam width used while linking newly inserted nodes.
///
/// For `m <= 1` the level multiplier is set to `0.0`, so every node lands on
/// layer 0. Without this guard `m == 1` would yield `1 / ln(1) = +inf` and
/// push practically every node to the maximum level.
pub fn new(m: usize, m0: usize, ef_construction: usize) -> Self {
    let ml = if m > 1 { 1.0 / (m as f64).ln() } else { 0.0 };
    Self {
        m,
        m0,
        ef_construction,
        ml,
        entry_point: None,
        max_level: 0,
        nodes: HashMap::new(),
        vectors: HashMap::new(),
    }
}
/// Returns the number of nodes stored in the graph.
pub fn len(&self) -> usize {
self.nodes.len()
}
/// Returns `true` when no node has been inserted yet.
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
/// Returns `true` when a node with id `node_id` is present in the graph.
pub fn contains(&self, node_id: u64) -> bool {
self.nodes.contains_key(&node_id)
}
/// Reports the length of one stored vector, or `None` when nothing is stored.
///
/// The vector inspected is whichever one the map iterator yields first; no
/// check is made that all stored vectors share that length.
pub fn infer_dim(&self) -> Option<usize> {
    let sample = self.vectors.values().next()?;
    Some(sample.len())
}
/// Iterates over every `(node id, stored vector)` pair, in map iteration order.
pub fn all_vectors(&self) -> impl Iterator<Item = (u64, &Vec<f32>)> {
    self.vectors.iter().map(|(&node_id, stored)| (node_id, stored))
}
/// Derives a node's top layer deterministically from `seed`.
///
/// The seed goes through one wrapping multiply-add step; the top 24 bits of
/// the result are scaled into `[0, 1)` and clamped away from zero so the
/// logarithm stays finite. The level is `floor(-ln(r) * ml)`, capped at 16.
fn assign_level(&self, seed: u64) -> usize {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    const LEVEL_CAP: usize = 16;

    let mixed = seed.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    let top_bits = (mixed >> 40) as f64;
    let scale = (1u64 << 24) as f64;
    let uniform = (top_bits / scale).clamp(f64::MIN_POSITIVE, 1.0);
    // Float-to-integer `as` saturates, so an oversized product still ends at the cap.
    let uncapped = (-uniform.ln() * self.ml).floor() as usize;
    uncapped.min(LEVEL_CAP)
}
/// Cosine distance `1 - cos(a, b)`, accumulated in `f64`.
///
/// Returns `1.0` when the inputs are empty, differ in length, or when either
/// squared norm is at most `f64::EPSILON`. Otherwise the result lies in
/// `[0.0, 2.0]` because the similarity is clamped to `[-1.0, 1.0]`.
fn cosine_distance(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 1.0;
    }

    // Single pass: dot product and both squared norms together.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&lhs, &rhs)| {
            let (lhs, rhs) = (f64::from(lhs), f64::from(rhs));
            (dot + lhs * rhs, sq_a + lhs * lhs, sq_b + rhs * rhs)
        },
    );

    if sq_a <= f64::EPSILON || sq_b <= f64::EPSILON {
        return 1.0;
    }

    let denominator = sq_a.sqrt() * sq_b.sqrt();
    let similarity = (dot / denominator).clamp(-1.0, 1.0);
    1.0 - similarity
}
/// Best-first search restricted to a single `layer`.
///
/// * `query` - vector being searched for.
/// * `entry_points` - `(node id, distance to query)` seeds; duplicates are ignored.
/// * `ef` - maximum number of results kept.
/// * `layer` - index into each node's per-layer neighbor lists.
///
/// Returns a max-heap keyed by distance (farthest result on top) holding at
/// most `ef` entries.
///
/// Pre-allocation is only a hint, so it is capped by the number of stored
/// nodes. Sizing it straight from a caller-supplied `ef` would panic with a
/// capacity overflow (or attempt a huge allocation) for very large values.
fn search_layer(
    &self,
    query: &[f32],
    entry_points: &[(u64, f64)],
    ef: usize,
    layer: usize,
) -> BinaryHeap<(OrderedF64, u64)> {
    let degree_hint = if layer == 0 { self.m0 } else { self.m };
    let entry_capacity = entry_points.len().max(1);
    // No search can touch more ids than there are nodes (or seeds).
    let reachable_cap = self.nodes.len().max(entry_capacity);
    let beam_capacity = ef.max(entry_capacity).min(reachable_cap);
    let visited_capacity = entry_capacity
        .saturating_add(beam_capacity.saturating_mul(degree_hint.max(1)))
        .min(reachable_cap);

    let mut visited: HashSet<u64> = HashSet::with_capacity(visited_capacity);
    // Min-heap (via `Reverse`): the closest unexpanded candidate pops first.
    let mut candidates: BinaryHeap<Reverse<(OrderedF64, u64)>> =
        BinaryHeap::with_capacity(beam_capacity);
    // One spare slot: `push_found_candidate` pushes before it evicts.
    let mut found: BinaryHeap<(OrderedF64, u64)> =
        BinaryHeap::with_capacity(beam_capacity.saturating_add(1));
    let mut worst_found = f64::MAX;

    for &(ep_id, ep_dist) in entry_points {
        if visited.insert(ep_id) {
            candidates.push(Reverse((OrderedF64(ep_dist), ep_id)));
            worst_found = Self::push_found_candidate(&mut found, ef, OrderedF64(ep_dist), ep_id);
        }
    }

    while let Some(Reverse((OrderedF64(c_dist), c_id))) = candidates.pop() {
        // The closest remaining candidate is already farther than the worst
        // kept result, so nothing left in the queue can improve the result set.
        if c_dist > worst_found {
            break;
        }
        let Some(node) = self.nodes.get(&c_id) else {
            continue;
        };
        // Nodes whose top level is below `layer` have no list here.
        let Some(neighbors) = node.neighbors.get(layer) else {
            continue;
        };
        for &nbr_id in neighbors {
            if !visited.insert(nbr_id) {
                continue;
            }
            let Some(nbr_vec) = self.vectors.get(&nbr_id) else {
                continue;
            };
            let dist = Self::cosine_distance(query, nbr_vec);
            // With a full result set, only strictly closer nodes are worth keeping.
            if found.len() >= ef && dist >= worst_found {
                continue;
            }
            candidates.push(Reverse((OrderedF64(dist), nbr_id)));
            worst_found = Self::push_found_candidate(&mut found, ef, OrderedF64(dist), nbr_id);
        }
    }
    found
}
fn push_found_candidate(
found: &mut BinaryHeap<(OrderedF64, u64)>,
ef: usize,
distance: OrderedF64,
node_id: u64,
) -> f64 {
found.push((distance, node_id));
if found.len() > ef {
found.pop();
}
found.peek().map(|(d, _)| d.0).unwrap_or(f64::MAX)
}
/// Picks the ids of the `m_limit` closest candidates, nearest first.
///
/// `candidates` holds `(node id, distance)` pairs. The sort is stable, so
/// candidates at equal distance keep their input order.
fn select_neighbors(candidates: &[(u64, f64)], m_limit: usize) -> Vec<u64> {
    let mut ranked: Vec<&(u64, f64)> = candidates.iter().collect();
    ranked.sort_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
    ranked
        .into_iter()
        .take(m_limit)
        .map(|&(id, _)| id)
        .collect()
}
/// Inserts `vector` under `node_id` and links it into the graph.
///
/// * `node_id` - id of the node. If it already exists, only its stored vector
///   is replaced; existing links are left untouched.
/// * `vector` - the embedding to store.
/// * `rng` - seed from which the node's top level is derived.
///
/// Linking runs in two phases: a greedy descent (beam width 1) through the
/// layers above the new node's level, then a beam search on each layer at or
/// below it. The closest results on a layer become the new node's neighbors
/// and receive a back-link. A neighbor whose list exceeds the layer limit is
/// trimmed back to its closest entries.
///
/// The construction beam width is clamped to at least 1. With a width of 0
/// the result set is always empty, so no node would ever be linked.
pub fn insert(&mut self, node_id: u64, vector: Vec<f32>, rng: u64) {
    if self.nodes.contains_key(&node_id) {
        self.vectors.insert(node_id, vector);
        return;
    }

    let level = self.assign_level(rng);
    let mut node = HnswNode {
        level,
        neighbors: vec![Vec::new(); level + 1],
    };
    // Store the vector before linking: pruning looks the new node up by id.
    self.vectors.insert(node_id, vector);

    let Some(ep) = self.entry_point else {
        // First node: it becomes the entry point and has nothing to link to.
        self.nodes.insert(node_id, node);
        self.entry_point = Some(node_id);
        self.max_level = level;
        return;
    };

    // Borrow the stored vector rather than keeping a second copy of it.
    let query: &[f32] = self
        .vectors
        .get(&node_id)
        .map(Vec::as_slice)
        .expect("vector was stored above");
    let ep_dist = Self::cosine_distance(
        query,
        self.vectors.get(&ep).expect("entry point has vector"),
    );
    let mut current_ep: Vec<(u64, f64)> = vec![(ep, ep_dist)];
    let top_level = self.max_level;

    // Phase 1: greedy descent through layers the new node does not live on.
    for lc in (level + 1..=top_level).rev() {
        let found = self.search_layer(query, &current_ep, 1, lc);
        let best = found.into_sorted_vec().into_iter().next();
        if let Some((OrderedF64(d), best_id)) = best {
            current_ep = vec![(best_id, d)];
        }
    }

    // Phase 2: beam search and bidirectional linking on each shared layer.
    let ef = self.ef_construction.max(1);
    for lc in (0..=level.min(top_level)).rev() {
        let found = self.search_layer(query, &current_ep, ef, lc);
        let mut candidates: Vec<(u64, f64)> = found
            .into_iter()
            .map(|(OrderedF64(d), id)| (id, d))
            .collect();
        candidates.sort_by(|a, b| a.1.total_cmp(&b.1));

        let m_max = if lc == 0 { self.m0 } else { self.m };
        let neighbors = Self::select_neighbors(&candidates, m_max);
        // Every result on this layer seeds the search on the next layer down.
        current_ep = candidates;

        for &nbr_id in &neighbors {
            let Some(nbr_node) = self.nodes.get_mut(&nbr_id) else {
                continue;
            };
            let Some(links) = nbr_node.neighbors.get_mut(lc) else {
                continue;
            };
            if links.contains(&node_id) {
                continue;
            }
            links.push(node_id);
            if links.len() > m_max {
                Self::prune_links(&self.vectors, nbr_id, links, m_max);
            }
        }
        node.neighbors[lc] = neighbors;
    }

    self.nodes.insert(node_id, node);
    if level > top_level {
        self.entry_point = Some(node_id);
        self.max_level = level;
    }
}

/// Trims `links` to the `m_max` entries closest to the vector stored for `owner_id`.
///
/// Entries with no stored vector are dropped. If `owner_id` itself has no
/// stored vector, `links` is left unchanged.
fn prune_links(
    vectors: &HashMap<u64, Vec<f32>>,
    owner_id: u64,
    links: &mut Vec<u64>,
    m_max: usize,
) {
    let Some(owner_vec) = vectors.get(&owner_id) else {
        return;
    };
    let scored: Vec<(u64, f64)> = links
        .iter()
        .filter_map(|&id| {
            vectors
                .get(&id)
                .map(|v| (id, Self::cosine_distance(owner_vec, v)))
        })
        .collect();
    *links = Self::select_neighbors(&scored, m_max);
}
pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(u64, f64)> {
if self.nodes.is_empty() || k == 0 {
return Vec::new();
}
let Some(ep) = self.entry_point else {
return Vec::new();
};
let ep_dist = Self::cosine_distance(
query,
self.vectors.get(&ep).expect("entry point has vector"),
);
let mut current_ep: Vec<(u64, f64)> = vec![(ep, ep_dist)];
for lc in (1..=self.max_level).rev() {
let found = self.search_layer(query, ¤t_ep, 1, lc);
let best = found.into_sorted_vec().into_iter().next();
if let Some((OrderedF64(d), best_id)) = best {
current_ep = vec![(best_id, d)];
}
}
let found = self.search_layer(query, ¤t_ep, ef.max(k), 0);
let mut results: Vec<(u64, f64)> = found
.into_iter()
.map(|(OrderedF64(dist), id)| (id, 1.0 - dist))
.collect();
results.sort_by(|a, b| b.1.total_cmp(&a.1));
results.truncate(k);
results
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a reproducible pseudo-random vector of `dim` components in `[-1.0, 1.0]`.
///
/// The state is seeded from `seed` and advanced by one wrapping multiply-add
/// step per component; the top 24 bits of each state are scaled into `[0, 1]`
/// and then mapped to `[-1, 1]`.
fn synthetic_vector(seed: u64, dim: usize) -> Vec<f32> {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let scale = (1u64 << 24) as f32;
    let mut state = seed.wrapping_mul(MULTIPLIER).wrapping_add(1);
    (0..dim)
        .map(|_| {
            state = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
            let unit = ((state >> 40) as f32) / scale;
            (unit * 2.0) - 1.0
        })
        .collect()
}
/// Searching a graph with no inserted nodes yields no results.
#[test]
fn hnsw_search_empty_returns_empty() {
    let empty_index = HnswGraph::new(16, 32, 200);
    let probe = [1.0_f32, 0.0, 0.0, 0.0];
    let hits = empty_index.search(&probe, 5, 64);
    assert!(hits.is_empty());
}
/// With 50 nodes indexed, asking for 3 results returns exactly 3.
#[test]
fn hnsw_search_respects_k() {
    let mut index = HnswGraph::new(16, 32, 200);
    (0..50u64).for_each(|id| index.insert(id, synthetic_vector(id, 4), id));

    let probe = synthetic_vector(999, 4);
    let hit_count = index.search(&probe, 3, 64).len();
    assert_eq!(hit_count, 3);
}
/// The exact nearest neighbor of the data centroid (found by brute force)
/// must appear in the graph's top-5 results.
#[test]
fn hnsw_insert_and_search_returns_nearest() {
    let dim = 4;
    let mut graph = HnswGraph::new(16, 32, 200);
    let mut vecs: Vec<(u64, Vec<f32>)> = Vec::new();
    for i in 0..100u64 {
        let v = synthetic_vector(i * 7 + 13, dim);
        graph.insert(i, v.clone(), i);
        vecs.push((i, v));
    }

    // Query with the component-wise mean of all inserted vectors.
    let mut centroid = vec![0.0_f32; dim];
    for (_, v) in &vecs {
        for (c, x) in centroid.iter_mut().zip(v.iter()) {
            *c += x;
        }
    }
    for c in centroid.iter_mut() {
        *c /= vecs.len() as f32;
    }

    // Ground truth: linear scan using the same distance function as the graph.
    let true_nearest = vecs
        .iter()
        .map(|(id, v)| {
            let d = OrderedF64(HnswGraph::cosine_distance(&centroid, v));
            (d, *id)
        })
        .min()
        .map(|(_, id)| id)
        .unwrap();

    let results = graph.search(&centroid, 5, 64);
    assert!(!results.is_empty(), "search returned no results");
    let returned_ids: Vec<u64> = results.iter().map(|(id, _)| *id).collect();
    assert!(
        returned_ids.contains(&true_nearest),
        "true nearest {} not in top-5: {:?}",
        true_nearest,
        returned_ids
    );
}
/// Running the same query twice gives identical results, and no node id
/// appears more than once in a result list.
#[test]
fn hnsw_repeated_search_is_deterministic_and_unique() {
    let mut index = HnswGraph::new(16, 32, 200);
    (0..128u64).for_each(|id| index.insert(id, synthetic_vector(id * 11 + 5, 8), id));

    let probe = synthetic_vector(777, 8);
    let first = index.search(&probe, 16, 64);
    let second = index.search(&probe, 16, 64);
    assert_eq!(first, second, "repeated HNSW search drifted");

    let distinct_ids = first
        .iter()
        .map(|&(id, _)| id)
        .collect::<std::collections::HashSet<u64>>();
    assert_eq!(distinct_ids.len(), first.len(), "HNSW search returned duplicates");
}
}