use crate::primitives::Vector;
use rand::Rng;
use std::collections::{HashMap, HashSet};
/// Approximate nearest-neighbour index organised as a layered proximity
/// graph (HNSW). Items are `Vector<f64>` values keyed by a string id and
/// compared with cosine distance.
#[derive(Debug)]
pub struct HNSWIndex {
/// Maximum number of links kept per node on every layer above layer 0.
m: usize,
/// Maximum number of links kept per node on layer 0 (set to `2 * m` by `new`).
max_m0: usize,
/// Size of the candidate list explored while inserting; `search` also uses
/// it as the lower bound for its own candidate list size.
ef_construction: usize,
/// Multiplier applied to `-ln(r)` when drawing a random top layer for a new node.
ml: f64,
/// All inserted nodes; a node's position in this vector is its graph index.
nodes: Vec<Node>,
/// Maps an item id to the index of the node most recently added under that id.
item_to_node: HashMap<String, usize>,
/// Node from which every search starts; `None` while the index is empty.
entry_point: Option<usize>,
/// Random source used to draw the top layer of each inserted node.
rng: rand::rngs::ThreadRng,
}
/// A single stored item together with its outgoing links on each layer.
#[derive(Debug, Clone)]
struct Node {
/// Caller-supplied identifier returned from searches.
item_id: String,
/// The indexed vector.
vector: Vector<f64>,
/// `connections[l]` holds the indices of this node's neighbours on layer `l`;
/// the vector has one entry per layer the node was assigned to (`layer + 1`).
connections: Vec<Vec<usize>>,
}
impl HNSWIndex {
/// Creates an empty index.
///
/// * `m` - link budget per node on layers above 0; layer 0 allows `2 * m`.
/// * `ef_construction` - candidate list size used while inserting.
/// * `_seed` - currently ignored: layers are drawn from the thread-local
///   generator returned by `rand::rng()`.
///
/// NOTE(review): the level multiplier is fixed at `1 / ln 2` regardless of
/// `m`; the HNSW paper suggests `1 / ln m` - confirm this is intentional.
#[must_use]
pub fn new(m: usize, ef_construction: usize, _seed: f64) -> Self {
    let max_m0 = 2 * m;
    let ml = 2.0_f64.ln().recip();
    Self {
        nodes: Vec::new(),
        item_to_node: HashMap::new(),
        entry_point: None,
        rng: rand::rng(),
        m,
        max_m0,
        ef_construction,
        ml,
    }
}
/// Stores `vector` under `item_id` and links it into the graph.
///
/// The very first item simply becomes the entry point; every later item is
/// wired in through `insert_node`.
///
/// NOTE(review): adding an id that is already present appends a second node
/// and repoints `item_to_node` at it; the earlier node stays in the graph.
/// Confirm whether callers rely on that.
pub fn add(&mut self, item_id: impl Into<String>, vector: Vector<f64>) {
    let id: String = item_id.into();
    // Draw the layer first so the RNG is consumed before any state changes.
    let level = self.random_layer();
    let idx = self.nodes.len();

    let fresh = Node {
        item_id: id.clone(),
        vector,
        connections: vec![Vec::new(); level + 1],
    };
    self.nodes.push(fresh);
    self.item_to_node.insert(id, idx);

    match self.entry_point {
        Some(_) => self.insert_node(idx, level),
        None => self.entry_point = Some(idx),
    }
}
/// Returns up to `k` `(item_id, distance)` pairs for the items nearest to
/// `query`, ordered by ascending cosine distance.
///
/// An empty index yields an empty vector. The layer-0 candidate list is
/// sized `max(k, ef_construction)`.
#[must_use]
pub fn search(&self, query: &Vector<f64>, k: usize) -> Vec<(String, f64)> {
    let Some(entry) = self.entry_point else {
        return Vec::new();
    };
    if self.nodes.is_empty() {
        return Vec::new();
    }

    // Greedy descent: on every layer above 0 hop to the single closest node.
    let top_layer = self.nodes[entry].connections.len().saturating_sub(1);
    let start = (1..=top_layer).rev().fold(entry, |nearest, level| {
        self.search_layer(query, nearest, 1, level)
            .first()
            .copied()
            .unwrap_or(nearest)
    });

    // Wide search on the base layer, then rank the hits by distance.
    let ef = k.max(self.ef_construction);
    let mut hits: Vec<(String, f64)> = Vec::new();
    for idx in self.search_layer(query, start, ef, 0) {
        let node = &self.nodes[idx];
        hits.push((node.item_id.clone(), Self::distance(query, &node.vector)));
    }
    hits.sort_by(|lhs, rhs| {
        lhs.1
            .partial_cmp(&rhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    hits.truncate(k);
    hits
}
/// Number of nodes stored in the index.
#[must_use]
pub fn len(&self) -> usize {
self.nodes.len()
}
/// Returns `true` when no node has been added yet.
#[must_use]
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
/// The `m` value this index was constructed with.
#[must_use]
pub fn m(&self) -> usize {
self.m
}
/// The `ef_construction` value this index was constructed with.
#[must_use]
pub fn ef_construction(&self) -> usize {
self.ef_construction
}
/// Draws the top layer for a new node as `floor(-ln(r) * ml)`, with `r`
/// uniform in `(0, 1]`.
fn random_layer(&mut self) -> usize {
    // `random_range(0.0..1.0)` may return exactly 0.0; `ln(0)` is `-inf`, and
    // the cast below would then saturate to `usize::MAX`, which breaks the
    // `layer + 1` allocation in `add`. Mirroring the sample to `(0, 1]` keeps
    // the distribution unchanged while ruling that value out.
    let r: f64 = 1.0 - self.rng.random_range(0.0..1.0);
    (-r.ln() * self.ml).floor() as usize
}
/// Links the already-stored node `node_idx` into the graph.
///
/// `layer` is the highest layer assigned to the node. The routine first
/// descends greedily from the entry point down to `layer + 1`, then on each
/// layer from `min(layer, top_layer)` down to 0 it connects the node to the
/// closest candidates (at most `max_m0` on layer 0, `m` elsewhere) and prunes
/// any neighbour whose link list grew past that limit. If `layer` exceeds the
/// current top layer the node becomes the new entry point.
///
/// # Panics
///
/// Panics if the index has no entry point; `add` only calls this after one
/// has been set.
fn insert_node(&mut self, node_idx: usize, layer: usize) {
    let ep = self.entry_point.expect("Entry point exists");
    let top_layer = self.nodes[ep].connections.len().saturating_sub(1);

    // Phase 1: greedy descent through the layers the new node is not part of.
    let mut curr = ep;
    for lc in (layer + 1..=top_layer).rev() {
        curr = self
            .search_layer_node(node_idx, curr, 1, lc)
            .into_iter()
            .next()
            .unwrap_or(curr);
    }

    // Phase 2: connect on shared layers only. Layers above `top_layer` hold
    // no other node yet, so searching there could only hand back the old
    // entry point and create a one-way link to a node absent from that layer.
    for lc in (0..=layer.min(top_layer)).rev() {
        let candidates = self.search_layer_node(node_idx, curr, self.ef_construction, lc);
        let m = if lc == 0 { self.max_m0 } else { self.m };
        let neighbors: Vec<usize> = candidates.into_iter().take(m).collect();
        for &neighbor in &neighbors {
            self.nodes[node_idx].connections[lc].push(neighbor);
            if lc < self.nodes[neighbor].connections.len() {
                self.nodes[neighbor].connections[lc].push(node_idx);
                self.prune_connections(neighbor, lc, m);
            }
        }
        // Continue the next (lower) layer from the closest neighbour found here.
        if let Some(&first) = neighbors.first() {
            curr = first;
        }
    }

    if layer > top_layer {
        self.entry_point = Some(node_idx);
    }
}
/// Best-first search restricted to one `layer`, starting at node `entry`.
///
/// Returns node indices ordered nearest-first with respect to `query`; for
/// `ef >= 1` the result holds at most `ef` entries and always contains at
/// least one node.
fn search_layer(
    &self,
    query: &Vector<f64>,
    entry: usize,
    ef: usize,
    layer: usize,
) -> Vec<usize> {
    let mut visited = HashSet::new();
    // `candidates` is kept sorted farthest-first so `pop` yields the nearest.
    let mut candidates = Vec::new();
    // `best` is kept sorted nearest-first and capped at `ef` entries.
    let mut best = Vec::new();

    let entry_dist = Self::distance(query, &self.nodes[entry].vector);
    candidates.push((entry, entry_dist));
    best.push((entry, entry_dist));
    visited.insert(entry);

    // The distance stored alongside each candidate is reused here instead of
    // being recomputed for every popped node.
    while let Some((curr, curr_dist)) = candidates.pop() {
        let worst_best_dist = best
            .iter()
            .map(|&(_, d)| d)
            .fold(f64::NEG_INFINITY, f64::max);
        // Every remaining candidate is at least as far as `curr`, so once the
        // result set is full and `curr` cannot improve it, the search is done.
        if curr_dist > worst_best_dist && best.len() >= ef {
            break;
        }

        // Nodes that do not reach up to this layer have nothing to expand.
        let Some(links) = self.nodes[curr].connections.get(layer) else {
            continue;
        };
        for &neighbor in links {
            if !visited.insert(neighbor) {
                continue;
            }
            let neighbor_dist = Self::distance(query, &self.nodes[neighbor].vector);
            if neighbor_dist < worst_best_dist || best.len() < ef {
                candidates.push((neighbor, neighbor_dist));
                best.push((neighbor, neighbor_dist));
                candidates.sort_by(|a, b| {
                    b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
                });
                best.sort_by(|a, b| {
                    a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
                });
                if best.len() > ef {
                    best.truncate(ef);
                }
            }
        }
    }

    best.into_iter().map(|(idx, _)| idx).collect()
}
/// Runs `search_layer` using the vector of the stored node `node_idx` as the
/// query; the remaining arguments are forwarded unchanged.
fn search_layer_node(
    &self,
    node_idx: usize,
    entry: usize,
    ef: usize,
    layer: usize,
) -> Vec<usize> {
    let query = &self.nodes[node_idx].vector;
    self.search_layer(query, entry, ef, layer)
}
/// Shrinks the link list of `node_idx` on `layer` to its `max_m` nearest
/// neighbours. Does nothing when the list is already within the limit.
fn prune_connections(&mut self, node_idx: usize, layer: usize, max_m: usize) {
    let node = &self.nodes[node_idx];
    if node.connections[layer].len() <= max_m {
        return;
    }

    // Rank neighbours by distance to this node. Only shared borrows are
    // needed here, so the node's vector is read in place rather than cloned.
    let mut neighbors: Vec<(usize, f64)> = node.connections[layer]
        .iter()
        .map(|&neighbor| {
            let dist = Self::distance(&node.vector, &self.nodes[neighbor].vector);
            (neighbor, dist)
        })
        .collect();
    neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

    self.nodes[node_idx].connections[layer] = neighbors
        .into_iter()
        .take(max_m)
        .map(|(idx, _)| idx)
        .collect();
}
/// Cosine distance `1 - cos(a, b)`, a value in `[0, 2]`.
///
/// Returns `f64::INFINITY` when the distance is undefined: the lengths
/// differ, either vector has zero norm, or the similarity evaluates to NaN
/// (NaN components, or components large enough to overflow the sums).
fn distance(a: &Vector<f64>, b: &Vector<f64>) -> f64 {
    if a.len() != b.len() {
        return f64::INFINITY;
    }
    let dot: f64 = a
        .as_slice()
        .iter()
        .zip(b.as_slice().iter())
        .map(|(x, y)| x * y)
        .sum();
    let norm_a: f64 = a.as_slice().iter().map(|x| x * x).sum::<f64>().sqrt();
    let norm_b: f64 = b.as_slice().iter().map(|x| x * x).sum::<f64>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return f64::INFINITY;
    }
    let cos_sim = dot / (norm_a * norm_b);
    // `clamp` passes NaN through, and a NaN distance would make the
    // `partial_cmp(..).unwrap_or(Equal)` comparators used for sorting
    // inconsistent. Treat it like every other undefined case instead.
    if cos_sim.is_nan() {
        return f64::INFINITY;
    }
    1.0 - cos_sim.clamp(-1.0, 1.0)
}
}
// Unit tests are kept in the sibling file `hnsw_tests.rs` and compiled only
// for test builds.
#[cfg(test)]
#[path = "hnsw_tests.rs"]
mod tests;