use common::types::{DistanceMetric, VectorId};
use parking_lot::RwLock;
use rand::Rng;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use crate::distance::calculate_distance;
/// Turns the score produced by `calculate_distance` into a "smaller is closer"
/// value, which is the orientation the search heaps and result sorting expect.
///
/// NOTE(review): this assumes `calculate_distance` returns a similarity
/// (higher = closer) for every metric — e.g. a negated distance for
/// `Euclidean`. Confirm against `crate::distance`.
#[inline]
fn similarity_to_distance(similarity: f32, metric: DistanceMetric) -> f32 {
    match metric {
        // Shift so that a perfect match (similarity 1) maps to distance 0.
        DistanceMetric::Cosine => 1.0 - similarity,
        // No natural upper bound: just flip the sign so larger scores sort first.
        DistanceMetric::Euclidean | DistanceMetric::DotProduct => -similarity,
    }
}
/// Tuning parameters for an [`HnswIndex`].
///
/// The constructors derive `m_max0` (as `2 * m`) and `level_multiplier`
/// (as `1 / ln(m)`) from `m`. A deserialized config keeps whatever values
/// were stored.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HnswConfig {
/// Neighbour cap on layers above 0, and the number of neighbours selected
/// for a new node on those layers during insert.
pub m: usize,
/// Neighbour cap on layer 0, also used as the selection count there.
pub m_max0: usize,
/// Size of the candidate list explored while inserting.
pub ef_construction: usize,
/// Default candidate-list size used by `HnswIndex::search`.
pub ef_search: usize,
/// Scale of the exponential draw used to pick each node's top layer.
pub level_multiplier: f64,
/// Metric passed to `calculate_distance` for every comparison.
pub distance_metric: DistanceMetric,
}
impl Default for HnswConfig {
fn default() -> Self {
let m = 16;
Self {
m,
m_max0: m * 2,
ef_construction: 200,
ef_search: 50,
level_multiplier: 1.0 / (m as f64).ln(),
distance_metric: DistanceMetric::Cosine,
}
}
}
impl HnswConfig {
pub fn new(m: usize, ef_construction: usize, ef_search: usize) -> Self {
Self {
m,
m_max0: m * 2,
ef_construction,
ef_search,
level_multiplier: 1.0 / (m as f64).ln(),
distance_metric: DistanceMetric::Cosine,
}
}
pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
self.distance_metric = metric;
self
}
}
/// One vector in the graph, stored at a fixed position in `HnswIndex::nodes`.
///
/// `delete` never removes a slot: it clears `vector` and `connections`, and
/// other code (e.g. `stats`) treats an empty `vector` as "deleted".
#[derive(Debug)]
struct HnswNode {
/// External identifier reported back from searches.
id: VectorId,
/// The stored embedding; empty once the node has been deleted.
vector: Vec<f32>,
/// `connections[layer]` holds positions (into the node vector) of this
/// node's neighbours on that layer; one entry per layer `0..=max_layer`.
connections: Vec<Vec<usize>>,
/// Highest layer this node participates in (drawn by `random_level`).
max_layer: usize,
}
/// A node scored against a query, ordered so that [`BinaryHeap`] pops the
/// *closest* node first (a min-heap on `distance`).
///
/// The ordering is total: non-NaN distances compare numerically (so `0.0` and
/// `-0.0` are equal), and every NaN distance is treated as further than any
/// real distance. `eq` is derived from `cmp` so that `Eq` and `Ord` always
/// agree. An inconsistent `Ord` is a logic error for `BinaryHeap`.
#[derive(Debug, Clone)]
struct Candidate {
    node_idx: usize,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Arguments are swapped so a smaller distance ranks as "greater"
        // (popped first). The NaN fallback ranks NaN as the largest distance,
        // i.e. the least attractive candidate.
        other
            .distance
            .partial_cmp(&self.distance)
            .unwrap_or_else(|| other.distance.is_nan().cmp(&self.distance.is_nan()))
    }
}
/// A node scored against a query, ordered so that [`BinaryHeap`] pops the
/// *furthest* node first (a max-heap on `distance`). It is used to cap a
/// result set at `ef` entries.
///
/// The ordering is total: non-NaN distances compare numerically (so `0.0` and
/// `-0.0` are equal), and every NaN distance ranks above any real distance,
/// so NaN entries are evicted first. `eq` is derived from `cmp` so that
/// `Eq` and `Ord` always agree.
#[derive(Debug, Clone)]
struct FurthestCandidate {
    node_idx: usize,
    distance: f32,
}

impl PartialEq for FurthestCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for FurthestCandidate {}

impl PartialOrd for FurthestCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for FurthestCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Only reached when at least one side is NaN; NaN ranks as the largest distance.
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or_else(|| self.distance.is_nan().cmp(&other.distance.is_nan()))
    }
}
/// Hierarchical Navigable Small World graph for approximate nearest-neighbour search.
///
/// Each piece of state sits behind its own `parking_lot::RwLock`. `insert`
/// and `delete` hold the `nodes` write lock while they update the other fields.
pub struct HnswIndex {
/// Tuning parameters fixed at construction.
config: HnswConfig,
/// All node slots; a node's position here is the index used in every adjacency list.
nodes: RwLock<Vec<HnswNode>>,
/// Position of the node where searches start, or `None` when nothing is searchable.
entry_point: RwLock<Option<usize>>,
/// Highest layer present in the graph (the entry point's layer after an insert).
max_level: RwLock<usize>,
/// Maps external ids to node positions; deleted ids are removed from it.
id_to_idx: RwLock<HashMap<VectorId, usize>>,
/// Vector length recorded by the first insert (or restored from a snapshot).
dimension: RwLock<Option<usize>>,
}
impl HnswIndex {
pub fn new() -> Self {
Self::with_config(HnswConfig::default())
}
pub fn with_config(config: HnswConfig) -> Self {
Self {
config,
nodes: RwLock::new(Vec::new()),
entry_point: RwLock::new(None),
max_level: RwLock::new(0),
id_to_idx: RwLock::new(HashMap::new()),
dimension: RwLock::new(None),
}
}
/// Draws the top layer for a new node as `floor(-ln(U) * level_multiplier)`
/// with `U` uniform in `(0, 1]`.
///
/// Returns 0 whenever the product is not finite, e.g. when
/// `level_multiplier` is infinite because `m == 1`. Otherwise the
/// saturating `as usize` cast would yield `usize::MAX` and the caller would
/// try to allocate that many layers.
fn random_level(&self) -> usize {
    // `gen::<f64>()` samples [0, 1). Mirroring the sample into (0, 1] keeps
    // `ln` finite; the distribution is unchanged.
    let uniform: f64 = 1.0 - rand::thread_rng().gen::<f64>();
    let level = -uniform.ln() * self.config.level_multiplier;
    if level.is_finite() {
        level.floor() as usize
    } else {
        0
    }
}
/// Distance ("smaller is closer") from `query` to the node stored at
/// `nodes[node_idx]`, using the configured metric.
fn distance(&self, query: &[f32], node_idx: usize, nodes: &[HnswNode]) -> f32 {
    let metric = self.config.distance_metric;
    let target = &nodes[node_idx].vector;
    let score = calculate_distance(query, target, metric);
    similarity_to_distance(score, metric)
}
/// Best-first search confined to a single graph `layer` (SEARCH-LAYER in the
/// HNSW paper).
///
/// Starting from `entry_points`, it repeatedly expands the closest unexplored
/// candidate. It stops once that candidate is further than the worst of the
/// `ef` best results collected so far.
///
/// Returns at most `ef` nodes (at least 1 is kept even if `ef == 0`), sorted
/// by ascending distance to `query`. Deleted slots (empty `vector`) are never
/// scored or returned. This matters because `delete` only cleans the deleted
/// node's own neighbour lists, so other nodes can keep one-way edges to it.
fn search_layer(
    &self,
    query: &[f32],
    entry_points: Vec<usize>,
    ef: usize,
    layer: usize,
    nodes: &[HnswNode],
) -> Vec<Candidate> {
    // With ef == 0 the result heap would be emptied after every push. The
    // "furthest" bound would then stay at f32::MAX and the loop would scan the
    // whole reachable graph while returning nothing.
    let ef = ef.max(1);
    let mut visited: HashSet<usize> = HashSet::new();
    let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
    let mut results: BinaryHeap<FurthestCandidate> = BinaryHeap::new();

    for ep in entry_points {
        if !visited.insert(ep) || nodes[ep].vector.is_empty() {
            continue;
        }
        let distance = self.distance(query, ep, nodes);
        candidates.push(Candidate { node_idx: ep, distance });
        results.push(FurthestCandidate { node_idx: ep, distance });
    }
    while results.len() > ef {
        results.pop();
    }

    while let Some(current) = candidates.pop() {
        let furthest = results.peek().map_or(f32::MAX, |r| r.distance);
        if results.len() >= ef && current.distance > furthest {
            break;
        }
        let neighbors = match nodes[current.node_idx].connections.get(layer) {
            Some(list) => list,
            None => continue,
        };
        for &neighbor_idx in neighbors {
            if !visited.insert(neighbor_idx) || nodes[neighbor_idx].vector.is_empty() {
                continue;
            }
            let distance = self.distance(query, neighbor_idx, nodes);
            let furthest = results.peek().map_or(f32::MAX, |r| r.distance);
            if results.len() < ef || distance < furthest {
                candidates.push(Candidate { node_idx: neighbor_idx, distance });
                results.push(FurthestCandidate { node_idx: neighbor_idx, distance });
                // One push at a time, so at most one eviction keeps the cap.
                if results.len() > ef {
                    results.pop();
                }
            }
        }
    }

    // `into_sorted_vec` yields ascending `FurthestCandidate` order, i.e.
    // closest first, without a separate NaN-sensitive sort.
    results
        .into_sorted_vec()
        .into_iter()
        .map(|fc| Candidate {
            node_idx: fc.node_idx,
            distance: fc.distance,
        })
        .collect()
}
/// Keeps the node indices of the first `m` entries of `candidates`.
///
/// Callers pass a list already sorted by ascending distance, so this selects
/// the `m` closest.
fn select_neighbors_simple(&self, candidates: &[Candidate], m: usize) -> Vec<usize> {
    let keep = m.min(candidates.len());
    candidates[..keep].iter().map(|c| c.node_idx).collect()
}
/// Picks up to `m` neighbours for `query` from `candidates` using the HNSW
/// diversity heuristic (SELECT-NEIGHBORS-HEURISTIC).
///
/// A candidate is accepted only if it is at least as far from every
/// already-selected neighbour as it is from `query`. That favours neighbours
/// in different directions. If fewer than `m` pass, the closest rejected
/// candidates fill the remaining slots.
///
/// `candidates` must be sorted by ascending distance unless
/// `extend_candidates` is set, in which case the working list is re-sorted
/// after adding the candidates' neighbours (all layers, deleted slots skipped).
fn select_neighbors_heuristic(
    &self,
    query: &[f32],
    candidates: &[Candidate],
    m: usize,
    nodes: &[HnswNode],
    extend_candidates: bool,
) -> Vec<usize> {
    let mut working_candidates = candidates.to_vec();
    if extend_candidates {
        let mut seen: HashSet<usize> = working_candidates.iter().map(|c| c.node_idx).collect();
        for candidate in candidates.iter().take(m) {
            for layer_connections in &nodes[candidate.node_idx].connections {
                for &neighbor in layer_connections {
                    if nodes[neighbor].vector.is_empty() || !seen.insert(neighbor) {
                        continue;
                    }
                    working_candidates.push(Candidate {
                        node_idx: neighbor,
                        distance: self.distance(query, neighbor, nodes),
                    });
                }
            }
        }
        // NaN-tolerant total order: real distances ascending, NaN last.
        working_candidates.sort_by(|a, b| {
            a.distance
                .partial_cmp(&b.distance)
                .unwrap_or_else(|| a.distance.is_nan().cmp(&b.distance.is_nan()))
        });
    }

    let mut selected: Vec<usize> = Vec::with_capacity(m);
    for candidate in &working_candidates {
        if selected.len() >= m {
            break;
        }
        let candidate_vector = &nodes[candidate.node_idx].vector;
        // Must use the same "smaller is closer" scale as `candidate.distance`
        // (which came from `self.distance`). A raw `calculate_distance` score
        // is a similarity and cannot be compared with it.
        let is_diverse = selected
            .iter()
            .all(|&sel_idx| self.distance(candidate_vector, sel_idx, nodes) >= candidate.distance);
        if is_diverse {
            selected.push(candidate.node_idx);
        }
    }

    // Keep-pruned-connections: top up with the closest rejected candidates.
    if selected.len() < m {
        for candidate in &working_candidates {
            if selected.len() >= m {
                break;
            }
            if !selected.contains(&candidate.node_idx) {
                selected.push(candidate.node_idx);
            }
        }
    }
    selected
}
/// Links `from_idx` and `to_idx` in both directions on `layer`.
///
/// Each side's neighbour list is capped at `m_max0` (layer 0) or `m`
/// (other layers). Nodes that do not reach `layer` are left untouched.
fn add_connection(&self, from_idx: usize, to_idx: usize, layer: usize, nodes: &mut [HnswNode]) {
    let m_max = if layer == 0 {
        self.config.m_max0
    } else {
        self.config.m
    };
    self.link_one_way(from_idx, to_idx, layer, m_max, nodes);
    self.link_one_way(to_idx, from_idx, layer, m_max, nodes);
}

/// Appends `target` to `owner`'s neighbour list on `layer` unless it is
/// already there. If the list then exceeds `m_max`, it is pruned back to the
/// `m_max` closest neighbours, and deleted slots (empty `vector`) are dropped.
fn link_one_way(&self, owner: usize, target: usize, layer: usize, m_max: usize, nodes: &mut [HnswNode]) {
    match nodes[owner].connections.get(layer) {
        Some(links) if !links.contains(&target) => {}
        // Owner does not reach this layer, or the edge already exists.
        _ => return,
    }
    nodes[owner].connections[layer].push(target);
    if nodes[owner].connections[layer].len() <= m_max {
        return;
    }

    let pruned = {
        // Shared view for scoring; it ends before the list is overwritten below.
        let view: &[HnswNode] = nodes;
        let owner_vector = &view[owner].vector;
        let mut scored: Vec<Candidate> = view[owner].connections[layer]
            .iter()
            .filter(|&&idx| !view[idx].vector.is_empty())
            .map(|&idx| Candidate {
                node_idx: idx,
                distance: self.distance(owner_vector, idx, view),
            })
            .collect();
        // NaN-tolerant total order: real distances ascending, NaN last.
        scored.sort_by(|a, b| {
            a.distance
                .partial_cmp(&b.distance)
                .unwrap_or_else(|| a.distance.is_nan().cmp(&b.distance.is_nan()))
        });
        self.select_neighbors_simple(&scored, m_max)
    };
    nodes[owner].connections[layer] = pruned;
}
pub fn insert(&self, id: VectorId, vector: Vec<f32>) {
let vector_dim = vector.len();
{
let mut dim = self.dimension.write();
if let Some(d) = *dim {
if d != vector_dim {
tracing::error!("Dimension mismatch: expected {}, got {}", d, vector_dim);
return;
}
} else {
*dim = Some(vector_dim);
}
}
let new_level = self.random_level();
let new_node = HnswNode {
id: id.clone(),
vector: vector.clone(),
connections: (0..=new_level).map(|_| Vec::new()).collect(),
max_layer: new_level,
};
let mut nodes = self.nodes.write();
let new_idx = nodes.len();
nodes.push(new_node);
self.id_to_idx.write().insert(id, new_idx);
let entry = *self.entry_point.read();
let entry_idx = match entry {
None => {
*self.entry_point.write() = Some(new_idx);
*self.max_level.write() = new_level;
return;
}
Some(idx) => idx,
};
let current_max_level = *self.max_level.read();
let mut current_entry = vec![entry_idx];
for layer in (new_level + 1..=current_max_level).rev() {
let nearest = self.search_layer(&vector, current_entry.clone(), 1, layer, &nodes);
if !nearest.is_empty() {
current_entry = vec![nearest[0].node_idx];
}
}
for layer in (0..=new_level.min(current_max_level)).rev() {
let candidates = self.search_layer(
&vector,
current_entry.clone(),
self.config.ef_construction,
layer,
&nodes,
);
let m = if layer == 0 {
self.config.m_max0
} else {
self.config.m
};
let neighbors = self.select_neighbors_heuristic(&vector, &candidates, m, &nodes, false);
for &neighbor_idx in &neighbors {
self.add_connection(new_idx, neighbor_idx, layer, &mut nodes);
}
if !candidates.is_empty() {
current_entry = candidates.iter().take(1).map(|c| c.node_idx).collect();
}
}
if new_level > current_max_level {
*self.entry_point.write() = Some(new_idx);
*self.max_level.write() = new_level;
}
}
/// Returns up to `k` nearest ids with their distances, using the configured
/// `ef_search` as the candidate-list size.
pub fn search(&self, query: &[f32], k: usize) -> Vec<(VectorId, f32)> {
    let ef = self.config.ef_search;
    self.search_with_ef(query, k, ef)
}
pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(VectorId, f32)> {
let nodes = self.nodes.read();
if nodes.is_empty() {
return Vec::new();
}
let entry = *self.entry_point.read();
let entry_idx = match entry {
None => return Vec::new(),
Some(idx) => idx,
};
let max_level = *self.max_level.read();
let mut current_entry = vec![entry_idx];
for layer in (1..=max_level).rev() {
let nearest = self.search_layer(query, current_entry.clone(), 1, layer, &nodes);
if !nearest.is_empty() {
current_entry = vec![nearest[0].node_idx];
}
}
let candidates = self.search_layer(query, current_entry, ef.max(k), 0, &nodes);
candidates
.into_iter()
.take(k)
.map(|c| (nodes[c.node_idx].id.clone(), c.distance))
.collect()
}
/// Removes `id` from the index. Returns `false` if it was not present.
///
/// The node's slot is kept as a tombstone (empty `vector` and `connections`)
/// so that positions stored in other adjacency lists stay valid. Back-links
/// are removed only from the neighbours listed in the deleted node's own
/// lists, so nodes holding a one-directional edge keep a stale position. The
/// graph is not re-wired around the removed node. If the removed node was the
/// entry point, the live node with the highest layer takes over and
/// `max_level` follows it.
pub fn delete(&self, id: &VectorId) -> bool {
    // Cheap miss path that does not block searches.
    if !self.id_to_idx.read().contains_key(id) {
        return false;
    }

    // Same lock order as `insert` (nodes, then id map). The id is resolved
    // again under the write locks because another thread may have deleted or
    // re-inserted it since the check above.
    let mut nodes = self.nodes.write();
    let mut id_map = self.id_to_idx.write();
    let idx = match id_map.remove(id) {
        Some(idx) => idx,
        None => return false,
    };

    let layers = std::mem::take(&mut nodes[idx].connections);
    for (layer, neighbors) in layers.into_iter().enumerate() {
        for neighbor_idx in neighbors {
            if let Some(list) = nodes
                .get_mut(neighbor_idx)
                .and_then(|n| n.connections.get_mut(layer))
            {
                list.retain(|&other| other != idx);
            }
        }
    }
    // Replace rather than `clear()` so the vector's heap buffer is released;
    // an empty vector is the tombstone marker.
    nodes[idx].vector = Vec::new();

    let mut entry_point = self.entry_point.write();
    if *entry_point == Some(idx) {
        let replacement = nodes
            .iter()
            .enumerate()
            .filter(|(_, n)| !n.vector.is_empty())
            .max_by_key(|(_, n)| n.max_layer);
        *entry_point = replacement.map(|(i, _)| i);
        // Keep `max_level` equal to the entry point's top layer, as `insert` does.
        *self.max_level.write() = replacement.map_or(0, |(_, n)| n.max_layer);
    }
    true
}
/// Number of ids currently registered (deleted ids are not counted).
pub fn len(&self) -> usize {
    let id_map = self.id_to_idx.read();
    id_map.len()
}
/// `true` when no ids are registered.
pub fn is_empty(&self) -> bool {
    let count = self.len();
    count == 0
}
/// Summarises the graph: per-layer population of live nodes, total and
/// average out-degree, and the registered-id count.
///
/// Layers above the recorded `max_level` are not counted.
pub fn stats(&self) -> HnswStats {
    let nodes = self.nodes.read();
    let max_level = *self.max_level.read();
    let layer_slots = max_level + 1;
    let mut level_counts = vec![0usize; layer_slots];
    let mut total_connections = 0usize;

    for node in nodes.iter().filter(|n| !n.vector.is_empty()) {
        for (layer, links) in node.connections.iter().enumerate().take(layer_slots) {
            level_counts[layer] += 1;
            total_connections += links.len();
        }
    }

    // `id_to_idx` is only written while `nodes` is write-locked, so its size
    // cannot change while the read guard above is held.
    let num_vectors = self.len();
    let avg_connections = match num_vectors {
        0 => 0.0,
        n => total_connections as f64 / n as f64,
    };

    HnswStats {
        num_vectors,
        max_level,
        level_counts,
        total_connections,
        avg_connections,
    }
}
/// Borrows the configuration this index was built with.
pub fn config(&self) -> &HnswConfig {
    &self.config
}
/// Vector length recorded by the first insert (or restored from a snapshot).
pub fn dimension(&self) -> Option<usize> {
    let guard = self.dimension.read();
    *guard
}
/// Position of the node searches start from, if any.
pub fn entry_point(&self) -> Option<usize> {
    let guard = self.entry_point.read();
    *guard
}
/// Highest layer currently recorded for the graph.
pub fn max_level(&self) -> usize {
    let guard = self.max_level.read();
    *guard
}
/// Copies every node slot, including deleted ones (empty `vector`), in
/// position order so that adjacency indices remain meaningful.
pub(crate) fn nodes_read(&self) -> Vec<NodeSnapshot> {
    let guard = self.nodes.read();
    let mut snapshots = Vec::with_capacity(guard.len());
    for node in guard.iter() {
        snapshots.push(NodeSnapshot {
            id: node.id.clone(),
            vector: node.vector.clone(),
            connections: node.connections.clone(),
            max_layer: node.max_layer,
        });
    }
    snapshots
}
pub fn from_snapshot(snapshot: crate::persistence::HnswFullSnapshot) -> Result<Self, String> {
use std::collections::HashMap;
let mut nodes = Vec::with_capacity(snapshot.nodes.len());
let mut id_to_idx = HashMap::with_capacity(snapshot.nodes.len());
for (idx, snode) in snapshot.nodes.into_iter().enumerate() {
id_to_idx.insert(snode.id.clone(), idx);
nodes.push(HnswNode {
id: snode.id,
vector: snode.vector,
connections: snode.connections,
max_layer: snode.max_layer,
});
}
let dimension = if nodes.is_empty() {
None
} else {
Some(snapshot.dimension)
};
Ok(Self {
config: snapshot.config,
nodes: RwLock::new(nodes),
entry_point: RwLock::new(snapshot.entry_point),
max_level: RwLock::new(snapshot.max_level),
id_to_idx: RwLock::new(id_to_idx),
dimension: RwLock::new(dimension),
})
}
}
#[derive(Debug, Clone)]
pub(crate) struct NodeSnapshot {
pub id: String,
pub vector: Vec<f32>,
pub connections: Vec<Vec<usize>>,
pub max_layer: usize,
}
impl Default for HnswIndex {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of graph statistics returned by `HnswIndex::stats`.
#[derive(Debug, Clone)]
pub struct HnswStats {
/// Number of ids registered in the index (from the id map).
pub num_vectors: usize,
/// Highest layer recorded for the graph.
pub max_level: usize,
/// `level_counts[l]` = number of non-deleted nodes that have a layer `l`.
pub level_counts: Vec<usize>,
/// Sum of neighbour-list lengths over non-deleted nodes and counted layers.
pub total_connections: usize,
/// `total_connections / num_vectors`, or 0.0 for an empty index.
pub avg_connections: f64,
}
// Unit tests. Vectors come from `rand::thread_rng()` with no fixed seed, so
// the data (and the index's random level draws) differ on every run.
#[cfg(test)]
mod tests {
use super::*;
/// Returns a `dim`-length vector whose components come from `rng.gen::<f32>()`.
fn random_vector(dim: usize) -> Vec<f32> {
let mut rng = rand::thread_rng();
(0..dim).map(|_| rng.gen::<f32>()).collect()
}
/// Scales `v` in place to unit L2 norm; an all-zero vector is left unchanged.
fn normalize(v: &mut Vec<f32>) {
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in v.iter_mut() {
*x /= norm;
}
}
}
/// Inserts 100 unit vectors, then checks that a top-10 search returns 10
/// results in non-decreasing distance order.
#[test]
fn test_hnsw_basic_operations() {
let index = HnswIndex::new();
for i in 0..100 {
let mut vec = random_vector(128);
normalize(&mut vec);
index.insert(format!("vec_{}", i), vec);
}
assert_eq!(index.len(), 100);
assert!(!index.is_empty());
let mut query = random_vector(128);
normalize(&mut query);
let results = index.search(&query, 10);
assert_eq!(results.len(), 10);
for i in 1..results.len() {
assert!(results[i - 1].1 <= results[i].1);
}
}
/// Deleting a known id shrinks `len()`; deleting an unknown id returns false.
#[test]
fn test_hnsw_delete() {
let index = HnswIndex::new();
for i in 0..10 {
let mut vec = random_vector(64);
normalize(&mut vec);
index.insert(format!("vec_{}", i), vec);
}
assert_eq!(index.len(), 10);
assert!(index.delete(&"vec_5".to_string()));
assert_eq!(index.len(), 9);
assert!(!index.delete(&"vec_999".to_string()));
}
/// Compares HNSW top-k with a brute-force ranking built from the same
/// similarity-to-distance conversion and requires average recall@10 >= 80%
/// over 10 random queries.
/// NOTE(review): randomized and threshold-based, so it can in principle flake.
#[test]
fn test_hnsw_recall() {
let dim = 128;
let n_vectors = 1000;
let index = HnswIndex::with_config(HnswConfig::new(16, 200, 100));
let mut vectors: Vec<(VectorId, Vec<f32>)> = Vec::new();
for i in 0..n_vectors {
let mut vec = random_vector(dim);
normalize(&mut vec);
let id: VectorId = format!("vec_{}", i);
vectors.push((id.clone(), vec.clone()));
index.insert(id, vec);
}
let n_queries = 10;
let k = 10;
let mut total_recall = 0.0;
for _ in 0..n_queries {
let mut query = random_vector(dim);
normalize(&mut query);
let hnsw_results: HashSet<String> = index
.search(&query, k)
.into_iter()
.map(|(id, _)| id)
.collect();
// Exhaustive ground truth over every inserted vector.
let mut exact: Vec<(String, f32)> = vectors
.iter()
.map(|(id, vec)| {
let sim = calculate_distance(&query, vec, DistanceMetric::Cosine);
(
id.clone(),
similarity_to_distance(sim, DistanceMetric::Cosine),
)
})
.collect();
exact.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let exact_results: HashSet<String> =
exact.into_iter().take(k).map(|(id, _)| id).collect();
let overlap = hnsw_results.intersection(&exact_results).count();
total_recall += overlap as f64 / k as f64;
}
let avg_recall = total_recall / n_queries as f64;
println!("Average recall@{}: {:.2}%", k, avg_recall * 100.0);
assert!(
avg_recall >= 0.80,
"Recall too low: {:.2}%",
avg_recall * 100.0
);
}
/// Checks `stats()` on a 50-vector index. `max_level` is only read because
/// its value depends on the random level draws.
#[test]
fn test_hnsw_stats() {
let index = HnswIndex::new();
for i in 0..50 {
let mut vec = random_vector(64);
normalize(&mut vec);
index.insert(format!("vec_{}", i), vec);
}
let stats = index.stats();
assert_eq!(stats.num_vectors, 50);
let _ = stats.max_level;
assert!(stats.avg_connections > 0.0);
println!("HNSW Stats: {:?}", stats);
}
/// Both a small and a large explicit `ef` must still yield k results.
#[test]
fn test_hnsw_custom_ef() {
let index = HnswIndex::new();
for i in 0..100 {
let mut vec = random_vector(64);
normalize(&mut vec);
index.insert(format!("vec_{}", i), vec);
}
let mut query = random_vector(64);
normalize(&mut query);
let results_low_ef = index.search_with_ef(&query, 10, 10);
let results_high_ef = index.search_with_ef(&query, 10, 200);
assert_eq!(results_low_ef.len(), 10);
assert_eq!(results_high_ef.len(), 10);
}
/// Searching an index with no vectors returns nothing.
#[test]
fn test_hnsw_empty_search() {
let index = HnswIndex::new();
let query = random_vector(64);
let results = index.search(&query, 10);
assert!(results.is_empty());
}
/// A lone vector is returned for itself at (near-)zero cosine distance.
#[test]
fn test_hnsw_single_vector() {
let index = HnswIndex::new();
let mut vec = random_vector(64);
normalize(&mut vec);
index.insert("single".to_string(), vec.clone());
let results = index.search(&vec, 5);
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, "single".to_string());
assert!(
results[0].1.abs() < 0.1,
"Distance to self was {}",
results[0].1
);
}
}