use alloc::collections::BinaryHeap;
use alloc::vec::Vec;
use core::cmp::Ordering;
use hashbrown::HashMap;
use super::distance::{compute_distance, cosine_distance};
use super::types::{DistanceMetric, HNSW_MAX_NEIGHBORS, HnswNode, VectorError, VectorSearchResult};
/// Search candidate ordered so that a `BinaryHeap` pops the *furthest* entry first.
///
/// Used as the bounded result set in `search_layer`: the heap top is the worst
/// result kept so far.
#[derive(Clone, Debug)]
struct MaxCandidate {
    id: u64,
    distance: f32,
}
impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Equality must agree with `Ord` (required by the `Ord` contract),
        // so it is derived from `cmp` instead of comparing ids alone.
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MaxCandidate {}
impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxCandidate {
    /// Orders by distance ascending, breaking ties by id.
    ///
    /// `f32::total_cmp` is a true total order, so NaN distances get a fixed
    /// position instead of comparing "equal" to everything and breaking the
    /// heap invariant.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.id.cmp(&other.id))
    }
}
/// Search candidate ordered so that a `BinaryHeap` pops the *closest* entry first.
///
/// Used as the frontier of unexplored nodes in `search_layer`.
#[derive(Clone, Debug)]
struct MinCandidate {
    id: u64,
    distance: f32,
}
impl PartialEq for MinCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Keep equality consistent with `Ord` rather than comparing ids alone.
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MinCandidate {}
impl PartialOrd for MinCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MinCandidate {
    /// Reverse distance order (smaller distance is "greater"), ties broken by
    /// reverse id, so the max-heap yields the nearest candidate first.
    ///
    /// `f32::total_cmp` keeps the order total even when a distance is NaN.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .distance
            .total_cmp(&self.distance)
            .then_with(|| other.id.cmp(&self.id))
    }
}
/// Adjacency lists for one layer of the HNSW graph.
///
/// Every node registered on the layer has an entry (possibly empty). Links are
/// stored per node and are not guaranteed to be symmetric.
#[derive(Debug, Clone, Default)]
struct HnswLayer {
    neighbors: HashMap<u64, Vec<u64>>,
}
impl HnswLayer {
    /// Creates an empty layer.
    fn new() -> Self {
        Self {
            neighbors: HashMap::new(),
        }
    }

    /// Registers `id` on this layer with no links; existing links are kept if it
    /// is already present.
    fn add_node(&mut self, id: u64) {
        self.neighbors.entry(id).or_default();
    }

    /// Returns whether `id` is registered on this layer.
    fn contains(&self, id: u64) -> bool {
        self.neighbors.contains_key(&id)
    }

    /// Returns the outgoing links of `id`, or an empty slice if `id` is not on
    /// this layer.
    fn get_neighbors(&self, id: u64) -> &[u64] {
        self.neighbors.get(&id).map(|v| v.as_slice()).unwrap_or(&[])
    }

    /// Replaces the outgoing links of `id`, registering the node if needed.
    fn set_neighbors(&mut self, id: u64, neighbors: Vec<u64>) {
        self.neighbors.insert(id, neighbors);
    }

    /// Adds the link `a -> b` and `b -> a`.
    ///
    /// Each direction is added only if that endpoint is already registered, the
    /// link is not present yet, and its list holds fewer than `max_neighbors`.
    fn connect(&mut self, a: u64, b: u64, max_neighbors: usize) {
        if let Some(neighbors) = self.neighbors.get_mut(&a) {
            if !neighbors.contains(&b) && neighbors.len() < max_neighbors {
                neighbors.push(b);
            }
        }
        if let Some(neighbors) = self.neighbors.get_mut(&b) {
            if !neighbors.contains(&a) && neighbors.len() < max_neighbors {
                neighbors.push(a);
            }
        }
    }

    /// Removes `id` and every link that points to it.
    ///
    /// Links can be one-directional (a node's back-link may have been pruned),
    /// so every list is scanned rather than only the lists of `id`'s own
    /// neighbours; otherwise dangling references to the removed node would remain.
    /// Cost is linear in the number of links on the layer.
    fn remove_node(&mut self, id: u64) {
        self.neighbors.remove(&id);
        for links in self.neighbors.values_mut() {
            links.retain(|&x| x != id);
        }
    }

    /// Number of nodes registered on this layer.
    fn len(&self) -> usize {
        self.neighbors.len()
    }

    /// Returns whether no node is registered on this layer.
    fn is_empty(&self) -> bool {
        self.neighbors.is_empty()
    }
}
/// Approximate nearest-neighbour index built on a Hierarchical Navigable Small
/// World graph.
///
/// Vectors are stored by `u64` id. The graph is a stack of layers in which
/// layer 0 holds every node.
#[derive(Clone)]
pub struct HnswIndex {
    // Stored data.
    /// Stored embeddings, keyed by object id.
    vectors: HashMap<u64, Vec<f32>>,
    /// Required embedding length; 0 until the first vector is inserted.
    dimensions: usize,
    /// Metric used for every distance computation.
    metric: DistanceMetric,

    // Graph.
    /// Per-layer adjacency lists; index 0 is the bottom layer.
    layers: Vec<HnswLayer>,
    /// Node where every search starts (`None` while the index is empty).
    entry_point: Option<u64>,
    /// Highest layer that searches descend from.
    max_layer: usize,

    // Tuning.
    /// Link budget per node on layers above 0.
    m: usize,
    /// Link budget per node on layer 0 (set to `2 * m`).
    m_max0: usize,
    /// Candidate-list width used while inserting.
    ef_construction: usize,
    /// Default candidate-list width used by `search`.
    ef_search: usize,
    /// Level-generation factor (`1 / ln(m)`).
    ml: f32,
    /// Linear congruential generator state used to draw node levels.
    rng_state: u64,
}
impl HnswIndex {
pub fn new(m: usize, ef_construction: usize) -> Self {
let m = m.clamp(2, HNSW_MAX_NEIGHBORS);
Self {
layers: Vec::new(),
entry_point: None,
max_layer: 0,
m,
m_max0: m * 2,
ef_construction,
ef_search: 50,
ml: 1.0 / libm::logf(m as f32),
metric: DistanceMetric::Cosine,
vectors: HashMap::new(),
dimensions: 0,
rng_state: 12345678, }
}
pub fn with_params(
m: usize,
ef_construction: usize,
ef_search: usize,
metric: DistanceMetric,
) -> Self {
let mut index = Self::new(m, ef_construction);
index.ef_search = ef_search;
index.metric = metric;
index
}
pub fn set_ef_search(&mut self, ef: usize) {
self.ef_search = ef.max(1);
}
pub fn len(&self) -> usize {
self.vectors.len()
}
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn max_layer(&self) -> usize {
self.max_layer
}
fn seed_rng(&mut self, id: u64) {
self.rng_state = self
.rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(id);
}
fn random_f32(&mut self) -> f32 {
self.rng_state = self
.rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(1);
(self.rng_state >> 33) as f32 / (1u64 << 31) as f32
}
fn random_level(&mut self) -> usize {
let r = self.random_f32();
if r <= 0.0 {
return 0;
}
let level = (-libm::logf(r) * self.ml) as usize;
level.min(32) }
fn distance(&self, id_a: u64, id_b: u64) -> f32 {
let a = self.vectors.get(&id_a);
let b = self.vectors.get(&id_b);
match (a, b) {
(Some(va), Some(vb)) => compute_distance(va, vb, self.metric),
_ => f32::MAX,
}
}
fn distance_to_query(&self, query: &[f32], id: u64) -> f32 {
match self.vectors.get(&id) {
Some(v) => compute_distance(query, v, self.metric),
None => f32::MAX,
}
}
fn ensure_layers(&mut self, level: usize) {
while self.layers.len() <= level {
self.layers.push(HnswLayer::new());
}
}
fn search_layer_single(&self, query: &[f32], entry_point: u64, layer: usize) -> u64 {
let mut current = entry_point;
let mut current_dist = self.distance_to_query(query, current);
loop {
let mut changed = false;
let neighbors = self.layers[layer].get_neighbors(current);
for &neighbor in neighbors {
let dist = self.distance_to_query(query, neighbor);
if dist < current_dist {
current = neighbor;
current_dist = dist;
changed = true;
}
}
if !changed {
break;
}
}
current
}
fn search_layer(
&self,
query: &[f32],
entry_points: &[u64],
ef: usize,
layer: usize,
) -> Vec<(u64, f32)> {
let mut visited = hashbrown::HashSet::new();
let mut candidates: BinaryHeap<MinCandidate> = BinaryHeap::new();
let mut results: BinaryHeap<MaxCandidate> = BinaryHeap::new();
for &ep in entry_points {
if visited.insert(ep) {
let dist = self.distance_to_query(query, ep);
candidates.push(MinCandidate {
id: ep,
distance: dist,
});
results.push(MaxCandidate {
id: ep,
distance: dist,
});
}
}
while let Some(MinCandidate {
id: current,
distance: current_dist,
}) = candidates.pop()
{
let furthest_dist = results.peek().map(|r| r.distance).unwrap_or(f32::MAX);
if current_dist > furthest_dist {
break;
}
let neighbors = self.layers[layer].get_neighbors(current);
for &neighbor in neighbors {
if visited.insert(neighbor) {
let dist = self.distance_to_query(query, neighbor);
if dist < furthest_dist || results.len() < ef {
candidates.push(MinCandidate {
id: neighbor,
distance: dist,
});
results.push(MaxCandidate {
id: neighbor,
distance: dist,
});
while results.len() > ef {
results.pop();
}
}
}
}
}
let mut result_vec: Vec<_> = results.into_iter().map(|c| (c.id, c.distance)).collect();
result_vec.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
result_vec
}
fn select_neighbors_simple(&self, candidates: &[(u64, f32)], m: usize) -> Vec<u64> {
candidates.iter().take(m).map(|(id, _)| *id).collect()
}
fn select_neighbors_heuristic(
&self,
query_id: u64,
candidates: &[(u64, f32)],
m: usize,
layer: usize,
extend_candidates: bool,
) -> Vec<u64> {
if candidates.len() <= m {
return candidates.iter().map(|(id, _)| *id).collect();
}
let mut working_candidates = candidates.to_vec();
if extend_candidates {
let mut extended = hashbrown::HashSet::new();
for (id, _) in &working_candidates {
extended.insert(*id);
}
for (id, _) in candidates {
for &neighbor in self.layers[layer].get_neighbors(*id) {
if neighbor != query_id && extended.insert(neighbor) {
let dist = self.distance(query_id, neighbor);
working_candidates.push((neighbor, dist));
}
}
}
working_candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
}
let mut selected = Vec::with_capacity(m);
let mut selected_set = hashbrown::HashSet::new();
for (id, dist) in &working_candidates {
if selected.len() >= m {
break;
}
let mut good = true;
for &sel_id in &selected {
let dist_to_selected = self.distance(*id, sel_id);
if dist_to_selected < *dist {
good = false;
break;
}
}
if good && selected_set.insert(*id) {
selected.push(*id);
}
}
if selected.len() < m {
for (id, _) in &working_candidates {
if selected.len() >= m {
break;
}
if selected_set.insert(*id) {
selected.push(*id);
}
}
}
selected
}
pub fn insert(&mut self, id: u64, embedding: &[f32]) -> Result<(), VectorError> {
if self.dimensions == 0 {
self.dimensions = embedding.len();
} else if embedding.len() != self.dimensions {
return Err(VectorError::DimensionMismatch {
expected: self.dimensions,
actual: embedding.len(),
});
}
self.vectors.insert(id, embedding.to_vec());
self.seed_rng(id);
let level = self.random_level();
self.ensure_layers(level);
if self.entry_point.is_none() {
self.entry_point = Some(id);
self.max_layer = level;
for l in 0..=level {
self.layers[l].add_node(id);
}
return Ok(());
}
let entry_point = self.entry_point.unwrap();
let mut current = entry_point;
for l in (level + 1..=self.max_layer).rev() {
current = self.search_layer_single(embedding, current, l);
}
for l in (0..=level.min(self.max_layer)).rev() {
self.layers[l].add_node(id);
let candidates = self.search_layer(embedding, &[current], self.ef_construction, l);
let m = if l == 0 { self.m_max0 } else { self.m };
let neighbors = self.select_neighbors_heuristic(id, &candidates, m, l, true);
self.layers[l].set_neighbors(id, neighbors.clone());
for &neighbor in &neighbors {
let mut neighbor_neighbors: Vec<u64> =
self.layers[l].get_neighbors(neighbor).to_vec();
if !neighbor_neighbors.contains(&id) {
neighbor_neighbors.push(id);
if neighbor_neighbors.len() > m {
let mut with_dist: Vec<_> = neighbor_neighbors
.iter()
.map(|&n| (n, self.distance(neighbor, n)))
.collect();
with_dist.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
neighbor_neighbors =
with_dist.into_iter().take(m).map(|(n, _)| n).collect();
}
self.layers[l].set_neighbors(neighbor, neighbor_neighbors);
}
}
if !candidates.is_empty() {
current = candidates[0].0;
}
}
if level > self.max_layer {
self.entry_point = Some(id);
self.max_layer = level;
}
Ok(())
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<VectorSearchResult> {
if self.is_empty() || query.len() != self.dimensions {
return Vec::new();
}
let entry_point = match self.entry_point {
Some(ep) => ep,
None => return Vec::new(),
};
let mut current = entry_point;
for l in (1..=self.max_layer).rev() {
current = self.search_layer_single(query, current, l);
}
let ef = self.ef_search.max(k);
let candidates = self.search_layer(query, &[current], ef, 0);
candidates
.into_iter()
.take(k)
.map(|(id, distance)| {
let score = match self.metric {
DistanceMetric::Cosine => 1.0 - distance, DistanceMetric::DotProduct => -distance, _ => 1.0 / (1.0 + distance), };
VectorSearchResult::new(id, score, distance)
})
.collect()
}
pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<VectorSearchResult> {
if self.is_empty() || query.len() != self.dimensions {
return Vec::new();
}
let entry_point = match self.entry_point {
Some(ep) => ep,
None => return Vec::new(),
};
let mut current = entry_point;
for l in (1..=self.max_layer).rev() {
current = self.search_layer_single(query, current, l);
}
let candidates = self.search_layer(query, &[current], ef.max(k), 0);
candidates
.into_iter()
.take(k)
.map(|(id, distance)| {
let score = match self.metric {
DistanceMetric::Cosine => 1.0 - distance,
DistanceMetric::DotProduct => -distance,
_ => 1.0 / (1.0 + distance),
};
VectorSearchResult::new(id, score, distance)
})
.collect()
}
pub fn delete(&mut self, id: u64) -> Result<(), VectorError> {
if !self.vectors.contains_key(&id) {
return Err(VectorError::ObjectNotFound(id));
}
for layer in &mut self.layers {
layer.remove_node(id);
}
self.vectors.remove(&id);
if self.entry_point == Some(id) {
self.entry_point = None;
self.max_layer = 0;
for (l, layer) in self.layers.iter().enumerate().rev() {
if !layer.is_empty() {
if let Some(&first) = layer.neighbors.keys().next() {
self.entry_point = Some(first);
self.max_layer = l;
break;
}
}
}
}
Ok(())
}
pub fn get_vector(&self, id: u64) -> Option<&[f32]> {
self.vectors.get(&id).map(|v| v.as_slice())
}
pub fn contains(&self, id: u64) -> bool {
self.vectors.contains_key(&id)
}
pub fn get_ids(&self) -> Vec<u64> {
self.vectors.keys().copied().collect()
}
pub fn stats(&self) -> HnswStats {
let mut layer_sizes = Vec::new();
let mut total_edges = 0;
for layer in &self.layers {
layer_sizes.push(layer.len());
for neighbors in layer.neighbors.values() {
total_edges += neighbors.len();
}
}
HnswStats {
vector_count: self.vectors.len(),
dimensions: self.dimensions,
layer_count: self.layers.len(),
layer_sizes,
total_edges: total_edges / 2, m: self.m,
ef_construction: self.ef_construction,
ef_search: self.ef_search,
metric: self.metric,
entry_point: self.entry_point,
}
}
pub fn metric(&self) -> DistanceMetric {
self.metric
}
pub fn entry_point(&self) -> Option<u64> {
self.entry_point
}
}
/// Snapshot of index size and configuration, produced by `HnswIndex::stats`.
#[derive(Debug, Clone)]
pub struct HnswStats {
    /// Number of stored vectors.
    pub vector_count: usize,
    /// Embedding length enforced by the index (0 before the first insert).
    pub dimensions: usize,
    /// Number of allocated layers (may include empty layers left after deletions).
    pub layer_count: usize,
    /// Number of nodes registered on each layer, indexed by layer.
    pub layer_sizes: Vec<usize>,
    /// Sum of all per-node link-list lengths divided by 2. Links are not always
    /// symmetric, so treat this as an estimate of undirected edges.
    pub total_edges: usize,
    /// Link budget per node on layers above 0.
    pub m: usize,
    /// Candidate-list width used while inserting.
    pub ef_construction: usize,
    /// Default candidate-list width used by search.
    pub ef_search: usize,
    /// Distance metric of the index.
    pub metric: DistanceMetric,
    /// Current search entry point (`None` when the index is empty).
    pub entry_point: Option<u64>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Deterministic pseudo-random vector with components in `[-1, 1]`.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut rng = seed;
        (0..dim)
            .map(|_| {
                rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
                ((rng >> 33) as f32 / (1u64 << 31) as f32) * 2.0 - 1.0
            })
            .collect()
    }

    /// Scales `v` to unit length in place; zero vectors are left unchanged.
    fn normalize(v: &mut [f32]) {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
    }

    #[test]
    fn test_hnsw_empty() {
        let index = HnswIndex::new(16, 200);
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_hnsw_single_insert() {
        let mut index = HnswIndex::new(16, 200);
        let embedding = vec![1.0, 0.0, 0.0];
        index.insert(1, &embedding).unwrap();
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert!(index.contains(1));
        assert!(!index.contains(2));
    }

    #[test]
    fn test_hnsw_search_exact() {
        let mut index = HnswIndex::new(16, 200);
        index.insert(1, &[1.0, 0.0, 0.0]).unwrap();
        index.insert(2, &[0.0, 1.0, 0.0]).unwrap();
        index.insert(3, &[0.0, 0.0, 1.0]).unwrap();
        let results = index.search(&[1.0, 0.0, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].object_id, 1);
        assert!(results[0].distance < 0.01);
    }

    #[test]
    fn test_hnsw_search_nearest() {
        let mut index = HnswIndex::new(16, 200);
        index.insert(1, &[1.0, 0.0, 0.0]).unwrap();
        // Close to vector 1.
        index.insert(2, &[0.9, 0.1, 0.0]).unwrap();
        index.insert(3, &[0.0, 1.0, 0.0]).unwrap();
        index.insert(4, &[0.0, 0.0, 1.0]).unwrap();
        let results = index.search(&[0.95, 0.05, 0.0], 2);
        // Exactly `k` results must come back, never more.
        assert_eq!(results.len(), 2);
        assert!(results[0].object_id == 1 || results[0].object_id == 2);
    }

    #[test]
    fn test_hnsw_search_k_zero() {
        let mut index = HnswIndex::new(16, 200);
        index.insert(1, &[1.0, 0.0, 0.0]).unwrap();
        assert!(index.search(&[1.0, 0.0, 0.0], 0).is_empty());
    }

    #[test]
    fn test_hnsw_search_wrong_dimension() {
        let mut index = HnswIndex::new(16, 200);
        index.insert(1, &[1.0, 0.0, 0.0]).unwrap();
        assert!(index.search(&[1.0, 0.0], 1).is_empty());
    }

    #[test]
    fn test_hnsw_delete() {
        let mut index = HnswIndex::new(16, 200);
        index.insert(1, &[1.0, 0.0, 0.0]).unwrap();
        index.insert(2, &[0.0, 1.0, 0.0]).unwrap();
        index.insert(3, &[0.0, 0.0, 1.0]).unwrap();
        assert_eq!(index.len(), 3);
        index.delete(2).unwrap();
        assert_eq!(index.len(), 2);
        assert!(!index.contains(2));
        let results = index.search(&[0.0, 1.0, 0.0], 10);
        for r in &results {
            assert_ne!(r.object_id, 2);
        }
    }

    #[test]
    fn test_hnsw_delete_missing() {
        let mut index = HnswIndex::new(16, 200);
        assert!(matches!(
            index.delete(7),
            Err(VectorError::ObjectNotFound(7))
        ));
    }

    #[test]
    fn test_hnsw_delete_entry_point() {
        let mut index = HnswIndex::new(16, 200);
        index.insert(1, &[1.0, 0.0, 0.0]).unwrap();
        index.insert(2, &[0.0, 1.0, 0.0]).unwrap();
        index.insert(3, &[0.0, 0.0, 1.0]).unwrap();
        let old_entry = index.entry_point().expect("non-empty index has an entry point");
        index.delete(old_entry).unwrap();
        let new_entry = index.entry_point().expect("index still has nodes");
        assert_ne!(new_entry, old_entry);
        assert!(index.contains(new_entry));
        let results = index.search(&[1.0, 1.0, 1.0], 10);
        assert_eq!(results.len(), 2);
        for r in &results {
            assert_ne!(r.object_id, old_entry);
        }
    }

    #[test]
    fn test_hnsw_reinsert_replaces_vector() {
        let mut index = HnswIndex::new(16, 200);
        index.insert(1, &[1.0, 0.0, 0.0]).unwrap();
        index.insert(2, &[0.0, 1.0, 0.0]).unwrap();
        index.insert(1, &[0.0, 0.0, 1.0]).unwrap();
        assert_eq!(index.len(), 2);
        assert_eq!(index.get_vector(1).unwrap(), &[0.0, 0.0, 1.0]);
        let results = index.search(&[0.0, 0.0, 1.0], 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].object_id, 1);
    }

    #[test]
    fn test_hnsw_dimension_mismatch() {
        let mut index = HnswIndex::new(16, 200);
        index.insert(1, &[1.0, 0.0, 0.0]).unwrap();
        // One component short.
        let result = index.insert(2, &[1.0, 0.0]);
        assert!(matches!(result, Err(VectorError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_hnsw_larger_dataset() {
        let mut index = HnswIndex::new(16, 200);
        let dim = 64;
        let n = 100;
        for i in 0..n {
            let mut v = random_vector(dim, i as u64);
            normalize(&mut v);
            index.insert(i as u64, &v).unwrap();
        }
        assert_eq!(index.len(), n);
        let query = random_vector(dim, 999);
        let results = index.search(&query, 10);
        // With 100 stored vectors a k=10 query must return exactly 10 hits.
        assert_eq!(results.len(), 10);
        let stats = index.stats();
        assert_eq!(stats.vector_count, n);
        assert_eq!(stats.dimensions, dim);
    }

    #[test]
    fn test_hnsw_recall() {
        let mut index = HnswIndex::new(16, 200);
        index.set_ef_search(100);
        let dim = 32;
        let n = 200;
        let mut vectors: Vec<Vec<f32>> = Vec::new();
        for i in 0..n {
            let mut v = random_vector(dim, i as u64);
            normalize(&mut v);
            vectors.push(v.clone());
            index.insert(i as u64, &v).unwrap();
        }
        let mut total_recall = 0.0;
        let num_queries = 10;
        let k = 10;
        for q in 0..num_queries {
            let query = &vectors[q * 10];
            let approx_results = index.search(query, k);
            let approx_ids: hashbrown::HashSet<_> =
                approx_results.iter().map(|r| r.object_id).collect();
            // Brute-force ground truth.
            let mut exact: Vec<_> = (0..n as u64)
                .map(|i| {
                    let dist = cosine_distance(query, &vectors[i as usize]);
                    (i, dist)
                })
                .collect();
            exact.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
            let hits = exact
                .iter()
                .take(k)
                .filter(|(id, _)| approx_ids.contains(id))
                .count();
            total_recall += hits as f32 / k as f32;
        }
        let avg_recall = total_recall / num_queries as f32;
        assert!(
            avg_recall > 0.8,
            "Average recall {} is too low (expected > 0.8)",
            avg_recall
        );
    }

    #[test]
    fn test_hnsw_get_vector() {
        let mut index = HnswIndex::new(16, 200);
        let embedding = vec![1.0, 2.0, 3.0];
        index.insert(42, &embedding).unwrap();
        let retrieved = index.get_vector(42).unwrap();
        assert_eq!(retrieved, &[1.0, 2.0, 3.0]);
        assert!(index.get_vector(999).is_none());
    }
}