use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::distance::Distance;
pub type NodeId = u64;
/// Tuning parameters for an [`HnswIndex`].
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct HnswParams {
/// Maximum number of links kept per node on every layer above layer 0.
/// Also determines the level multiplier (`1 / ln(m)`) computed in `HnswIndex::new`.
pub m: usize,
/// Maximum number of links kept per node on layer 0.
pub m0: usize,
/// Candidate-list size (`ef`) used while linking a newly inserted node.
pub ef_construction: usize,
/// Candidate-list size used by `search` when the caller passes `ef = None`.
pub ef_search: usize,
/// Initial state of the index's internal pseudo-random generator.
pub seed: u64,
}
impl Default for HnswParams {
fn default() -> Self {
Self {
m: 16,
m0: 32,
ef_construction: 200,
ef_search: 50,
seed: 0xDEAD_BEEF_CAFE_F00D,
}
}
}
/// One stored vector together with its per-layer adjacency lists.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct HnswNode {
/// Caller-supplied identifier of this vector.
id: NodeId,
/// The stored vector data.
vector: Vec<f32>,
/// `levels[lc]` holds the positions (indices into the index's node list) of
/// this node's neighbours on layer `lc`.
levels: Vec<Vec<usize>>,
/// Soft-delete marker; the node stays in the graph but is hidden from results.
deleted: bool,
}
impl HnswNode {
    /// Returns the highest layer this node participates in
    /// (`levels.len() - 1`, or 0 when `levels` is empty).
    fn level(&self) -> usize {
        match self.levels.len() {
            0 => 0,
            layer_count => layer_count - 1,
        }
    }
}
/// Layered proximity-graph index over `f32` vectors with soft deletion.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HnswIndex {
/// Tuning parameters supplied at construction.
params: HnswParams,
/// Metric used to score a query against a stored vector.
distance: Distance,
/// All nodes ever inserted, including soft-deleted ones; positions are stable.
nodes: Vec<HnswNode>,
/// Maps a caller-supplied id to its position in `nodes`.
id_to_idx: HashMap<NodeId, usize>,
/// Position of the node searches start from; `None` until the first insert.
entry: Option<usize>,
/// Level multiplier used when drawing a random level for a new node.
ml: f64,
/// Current state of the internal pseudo-random generator.
rng_state: u64,
/// Vector length recorded on the first insert (0 before that).
dim: u16,
}
/// Errors reported by [`HnswIndex`] operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum IndexError {
/// The supplied vector or query length differs from the index dimension.
#[error("dimension mismatch: index has {expected}, got {got}")]
DimensionMismatch {
/// Dimension recorded by the index.
expected: u16,
/// Length of the offending input (saturated to `u16::MAX` by the callers in this file).
got: u16,
},
/// The id is already registered; soft-deleted ids stay registered.
#[error("id {0} already present in the index")]
Duplicate(NodeId),
/// An empty vector was passed to `insert`.
#[error("empty vector")]
Empty,
}
/// A single hit returned by `HnswIndex::search`.
#[derive(Clone, Debug, PartialEq)]
pub struct SearchResult {
/// Identifier the matching vector was inserted under.
pub id: NodeId,
/// Value returned by the index's distance metric for the query and this
/// vector; `search` returns hits in ascending score order.
pub score: f32,
}
/// Frontier entry for best-first traversal.
///
/// The ordering is reversed on `score` so that a max-oriented `BinaryHeap`
/// pops the entry with the *lowest* score first. NaN scores are treated as
/// the worst possible value, which makes the ordering a genuine total order
/// (the `Ord` contract) and keeps `eq` consistent with `cmp`.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    /// Position of the node in the index's node list.
    idx: usize,
    /// Score of the node against the current query.
    score: f32,
}

impl PartialEq for Candidate {
    /// Two candidates are equal when their scores rank equally (`idx` is ignored).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped to reverse the natural order. NaN-ness is
        // compared first so NaN never compares `Equal` to a real number;
        // `partial_cmp` can then only fail when both sides are NaN.
        other
            .score
            .is_nan()
            .cmp(&self.score.is_nan())
            .then_with(|| {
                other
                    .score
                    .partial_cmp(&self.score)
                    .unwrap_or(Ordering::Equal)
            })
    }
}
/// Result-set entry ordered by ascending `score`.
///
/// In a `BinaryHeap` the entry with the *highest* score sits on top, so the
/// worst retained result can be inspected and evicted cheaply. NaN scores
/// rank above every real number (worst), which makes the ordering a genuine
/// total order (the `Ord` contract) and keeps `eq` consistent with `cmp`.
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    /// Position of the node in the index's node list.
    idx: usize,
    /// Score of the node against the current query.
    score: f32,
}

impl PartialEq for MaxCandidate {
    /// Two candidates are equal when their scores rank equally (`idx` is ignored).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // NaN-ness is compared first so NaN never compares `Equal` to a real
        // number; `partial_cmp` can then only fail when both sides are NaN.
        self.score
            .is_nan()
            .cmp(&other.score.is_nan())
            .then_with(|| {
                self.score
                    .partial_cmp(&other.score)
                    .unwrap_or(Ordering::Equal)
            })
    }
}
impl HnswIndex {
/// Creates an empty index that scores vectors with `distance`.
///
/// The level multiplier is `1 / ln(m)` when `params.m > 1` and `1.0`
/// otherwise; the random generator starts from `params.seed`.
#[must_use]
pub fn new(distance: Distance, params: HnswParams) -> Self {
    let ml = match params.m {
        0 | 1 => 1.0,
        m => {
            // `m` values beyond `u32::MAX` are saturated before the conversion.
            let m_as_f64 = f64::from(u32::try_from(m).unwrap_or(u32::MAX));
            1.0 / m_as_f64.ln()
        }
    };
    let rng_state = params.seed;
    Self {
        params,
        distance,
        nodes: Vec::new(),
        id_to_idx: HashMap::new(),
        entry: None,
        ml,
        rng_state,
        dim: 0,
    }
}
/// Returns the number of stored vectors that are not soft-deleted.
#[must_use]
pub fn len(&self) -> usize {
    self.nodes
        .iter()
        .map(|node| usize::from(!node.deleted))
        .sum()
}
/// Returns `true` when the index holds no live (non-deleted) vectors.
#[must_use]
pub fn is_empty(&self) -> bool {
    self.nodes.iter().all(|node| node.deleted)
}
/// Returns the vector length recorded by the first insert, or 0 if nothing
/// has been inserted yet.
#[must_use]
pub fn dim(&self) -> u16 {
self.dim
}
/// Returns the distance metric this index was constructed with.
#[must_use]
pub fn distance(&self) -> Distance {
self.distance
}
/// Stores `vector` under `id` and links it into the graph.
///
/// The first insert fixes the index dimension. A random top level is drawn
/// for the node; the graph is descended greedily from the entry point down
/// to that level, and on each layer from there to layer 0 the node is
/// connected to its closest candidates (at most `m0` on layer 0, `m` above).
/// Neighbours whose link list overflows the cap are pruned back to it.
///
/// # Errors
///
/// * [`IndexError::Empty`] if `vector` is empty.
/// * [`IndexError::DimensionMismatch`] if `vector` has a different length
///   than the vectors already stored. The reported `got` saturates at
///   `u16::MAX`.
/// * [`IndexError::Duplicate`] if `id` was inserted before, even if it has
///   since been soft-deleted.
pub fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
    if vector.is_empty() {
        return Err(IndexError::Empty);
    }
    let got = u16::try_from(vector.len()).unwrap_or(u16::MAX);
    match self.nodes.first() {
        None => self.dim = got,
        // Compare the real lengths: the saturated `u16` values would treat
        // every length above `u16::MAX` as the same dimension.
        Some(first) if first.vector.len() != vector.len() => {
            return Err(IndexError::DimensionMismatch {
                expected: self.dim,
                got,
            });
        }
        Some(_) => {}
    }
    if self.id_to_idx.contains_key(&id) {
        return Err(IndexError::Duplicate(id));
    }

    let level = self.random_level();
    let new_idx = self.nodes.len();
    self.nodes.push(HnswNode {
        id,
        vector,
        levels: vec![Vec::new(); level + 1],
        deleted: false,
    });
    self.id_to_idx.insert(id, new_idx);

    // The very first node simply becomes the entry point.
    let Some(entry) = self.entry else {
        self.entry = Some(new_idx);
        return Ok(());
    };

    // Greedy descent through the layers the new node does not live on.
    // The range is empty when the new node reaches at least the entry's level.
    let entry_level = self.nodes[entry].level();
    let mut current = entry;
    for lc in (level + 1..=entry_level).rev() {
        current = self.greedy_search_layer(current, new_idx, lc);
    }

    let start_layer = level.min(entry_level);
    let mut entry_points = vec![current];
    for lc in (0..=start_layer).rev() {
        let found = self.search_layer(
            new_idx,
            &entry_points,
            lc,
            self.params.ef_construction,
            true,
        );
        let cap = if lc == 0 {
            self.params.m0
        } else {
            self.params.m
        };
        let selected = Self::select_neighbours(&found, cap);
        for &nb in &selected {
            // Links are created in both directions; only the neighbour can
            // overflow, because `selected` already respects `cap`.
            self.nodes[new_idx].levels[lc].push(nb);
            self.nodes[nb].levels[lc].push(new_idx);
            if self.nodes[nb].levels[lc].len() > cap {
                self.shrink_connections(nb, lc, cap);
            }
        }
        // The chosen neighbours seed the next layer down; fall back to the
        // descent result when nothing was selected.
        entry_points = if selected.is_empty() {
            vec![current]
        } else {
            selected
        };
    }

    if level > entry_level {
        self.entry = Some(new_idx);
    }
    Ok(())
}
/// Soft-deletes the vector stored under `id`.
///
/// Returns `true` if a live entry was marked deleted, and `false` if the id
/// is unknown or was already deleted. The node itself stays in the graph.
pub fn delete(&mut self, id: NodeId) -> bool {
    match self.id_to_idx.get(&id).copied() {
        Some(idx) if !self.nodes[idx].deleted => {
            self.nodes[idx].deleted = true;
            true
        }
        _ => false,
    }
}
/// Returns up to `k` live vectors closest to `query`, best (lowest score) first.
///
/// `ef` is the size of the candidate list explored on layer 0; it defaults
/// to `params.ef_search` and is raised to `k` if smaller. Soft-deleted nodes
/// are still traversed but never occupy a slot in the candidate list, so
/// they cannot crowd live results out of the answer.
///
/// An empty `query` or an index without any inserted vector yields
/// `Ok(vec![])`.
///
/// # Errors
///
/// [`IndexError::DimensionMismatch`] if `query` has a different length than
/// the stored vectors. The reported `got` saturates at `u16::MAX`.
pub fn search(
    &self,
    query: &[f32],
    k: usize,
    ef: Option<usize>,
) -> Result<Vec<SearchResult>, IndexError> {
    if query.is_empty() {
        return Ok(Vec::new());
    }
    let Some(first) = self.nodes.first() else {
        return Ok(Vec::new());
    };
    // Compare the real lengths: saturated `u16` values would treat every
    // length above `u16::MAX` as the same dimension.
    if first.vector.len() != query.len() {
        return Err(IndexError::DimensionMismatch {
            expected: self.dim,
            got: u16::try_from(query.len()).unwrap_or(u16::MAX),
        });
    }

    let mut entry = self.entry.unwrap_or(0);
    let entry_level = self.nodes[entry].level();
    let ef = ef.unwrap_or(self.params.ef_search).max(k);

    // Greedy descent through the upper layers down to layer 1.
    for lc in (1..=entry_level).rev() {
        entry = self.greedy_search_layer_against(query, entry, lc);
    }

    let mut hits = self.search_layer_against(query, &[entry], 0, ef, false);
    // NaN-last total order: an inconsistent comparator is allowed to make
    // `sort_by` panic.
    hits.sort_by(|a, b| {
        a.score.is_nan().cmp(&b.score.is_nan()).then_with(|| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });
    Ok(hits
        .into_iter()
        .take(k)
        .map(|c| SearchResult {
            id: self.nodes[c.idx].id,
            score: c.score,
        })
        .collect())
}
/// Returns `true` if `id` is stored in the index and has not been soft-deleted.
#[must_use]
pub fn contains(&self, id: NodeId) -> bool {
    match self.id_to_idx.get(&id) {
        Some(&idx) => !self.nodes[idx].deleted,
        None => false,
    }
}
/// Draws the top layer for a new node as `floor(-ln(u) * ml)`, limited to
/// the range `0..=16`, where `u` is the next sample from `rand_unit`.
fn random_level(&mut self) -> usize {
    const MAX_LEVEL: f64 = 16_f64;
    // Keep the sample strictly positive so the logarithm stays finite.
    let sample = self.rand_unit().max(f64::MIN_POSITIVE);
    let raw_level = (-sample.ln() * self.ml).floor();
    #[allow(
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss,
        reason = "clamped to [0, 16]"
    )]
    let level = raw_level.clamp(0.0, MAX_LEVEL) as usize;
    level
}
/// Advances the internal xorshift64* generator and returns a value in `[0, 1)`.
///
/// A state of zero is a fixed point of the xorshift step: it would yield
/// `0.0` on every call and push every inserted node to the maximum level.
/// A zero state (for example from `seed == 0`) is therefore replaced by a
/// fixed non-zero constant before stepping.
fn rand_unit(&mut self) -> f64 {
    const ZERO_STATE_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;
    let mut x = if self.rng_state == 0 {
        ZERO_STATE_REPLACEMENT
    } else {
        self.rng_state
    };
    x ^= x >> 12;
    x ^= x << 25;
    x ^= x >> 27;
    self.rng_state = x;
    let r = x.wrapping_mul(0x2545_F491_4F6C_DD1D);
    // The top 53 bits fill an f64 mantissa exactly.
    let bits = r >> 11;
    #[allow(
        clippy::cast_precision_loss,
        reason = "bits is in [0, 2^53), exactly representable as f64"
    )]
    let f = (bits as f64) / ((1_u64 << 53) as f64);
    f
}
/// Greedy descent on layer `lc` towards the vector stored at `query_idx`,
/// starting from `entry`. Returns the position of the closest node reached.
fn greedy_search_layer(&self, entry: usize, query_idx: usize, lc: usize) -> usize {
    // `self` is only borrowed immutably, so the stored vector can be used
    // in place without copying it.
    self.greedy_search_layer_against(&self.nodes[query_idx].vector, entry, lc)
}
/// Greedy descent on layer `lc` towards `query`, starting from `entry`.
///
/// Each pass scans the neighbours of the node that was current when the
/// pass began and moves to any neighbour that scores strictly lower than the
/// best seen so far. The walk stops after a pass without improvement and
/// returns the position of the best node found.
fn greedy_search_layer_against(&self, query: &[f32], entry: usize, lc: usize) -> usize {
    let mut current = entry;
    let mut current_score = self.distance.score(query, &self.nodes[current].vector);
    loop {
        let mut improved = false;
        // The adjacency list is borrowed, not cloned. It is resolved once per
        // pass, so updating `current` below does not change the list being
        // scanned.
        if let Some(neighbours) = self.nodes[current].levels.get(lc) {
            for &nb in neighbours {
                let s = self.distance.score(query, &self.nodes[nb].vector);
                if s < current_score {
                    current_score = s;
                    current = nb;
                    improved = true;
                }
            }
        }
        if !improved {
            break;
        }
    }
    current
}
/// Runs `search_layer_against` on layer `lc` using the vector stored at
/// `query_idx` as the query.
fn search_layer(
    &self,
    query_idx: usize,
    entry_points: &[usize],
    lc: usize,
    ef: usize,
    include_deleted: bool,
) -> Vec<MaxCandidate> {
    // `self` is only borrowed immutably, so the stored vector can be used
    // in place without copying it.
    self.search_layer_against(
        &self.nodes[query_idx].vector,
        entry_points,
        lc,
        ef,
        include_deleted,
    )
}
/// Best-first search on layer `lc` that keeps roughly the `ef` closest nodes.
///
/// * `entry_points` – positions the search starts from (duplicates are ignored).
/// * `include_deleted` – when `false`, soft-deleted nodes are still expanded
///   but never placed in the result set.
///
/// Returns the retained candidates in unspecified order; callers sort them.
fn search_layer_against(
    &self,
    query: &[f32],
    entry_points: &[usize],
    lc: usize,
    ef: usize,
    include_deleted: bool,
) -> Vec<MaxCandidate> {
    let mut visited: HashSet<usize> = HashSet::new();
    // `frontier` pops the closest unexpanded node; `top` exposes the worst kept result.
    let mut frontier: BinaryHeap<Candidate> = BinaryHeap::new();
    let mut top: BinaryHeap<MaxCandidate> = BinaryHeap::new();

    for &ep in entry_points {
        if !visited.insert(ep) {
            continue;
        }
        let s = self.distance.score(query, &self.nodes[ep].vector);
        frontier.push(Candidate { idx: ep, score: s });
        if include_deleted || !self.nodes[ep].deleted {
            top.push(MaxCandidate { idx: ep, score: s });
        }
    }

    while let Some(c) = frontier.pop() {
        // Stop once the result set is full and the closest remaining
        // frontier node is worse than the worst kept result.
        if top.len() >= ef && top.peek().is_some_and(|worst| c.score > worst.score) {
            break;
        }
        // Borrow the adjacency list instead of cloning it per expansion.
        let Some(neighbours) = self.nodes[c.idx].levels.get(lc) else {
            continue;
        };
        for &nb in neighbours {
            if !visited.insert(nb) {
                continue;
            }
            let s = self.distance.score(query, &self.nodes[nb].vector);
            let admit = top.len() < ef || top.peek().map_or(true, |worst| s < worst.score);
            if !admit {
                continue;
            }
            frontier.push(Candidate { idx: nb, score: s });
            if include_deleted || !self.nodes[nb].deleted {
                top.push(MaxCandidate { idx: nb, score: s });
                if top.len() > ef {
                    top.pop();
                }
            }
        }
    }
    top.into_vec()
}
/// Returns the positions of the `m` lowest-scoring candidates, best first.
///
/// Candidates with a NaN score rank last. The sort is stable, so candidates
/// with equal scores keep their input order.
fn select_neighbours(candidates: &[MaxCandidate], m: usize) -> Vec<usize> {
    let mut ranked: Vec<MaxCandidate> = candidates.to_vec();
    // NaN-last total order: an inconsistent comparator is allowed to make
    // `sort_by` panic.
    ranked.sort_by(|a, b| {
        a.score.is_nan().cmp(&b.score.is_nan()).then_with(|| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });
    ranked.into_iter().take(m).map(|c| c.idx).collect()
}
/// Prunes the layer-`lc` link list of node `idx` down to its `cap` closest
/// neighbours (by score against the node's own vector).
///
/// Neighbours with a NaN score rank last and are dropped first. Only this
/// node's list is modified; the dropped neighbours' own lists are untouched.
fn shrink_connections(&mut self, idx: usize, lc: usize, cap: usize) {
    let neighbours = std::mem::take(&mut self.nodes[idx].levels[lc]);
    // The node's vector is borrowed rather than cloned; the borrow ends
    // before the pruned list is written back.
    let origin = &self.nodes[idx].vector;
    let mut scored: Vec<(usize, f32)> = neighbours
        .into_iter()
        .map(|nb| (nb, self.distance.score(origin, &self.nodes[nb].vector)))
        .collect();
    // NaN-last total order: an inconsistent comparator is allowed to make
    // `sort_by` panic.
    scored.sort_by(|a, b| {
        a.1.is_nan().cmp(&b.1.is_nan()).then_with(|| {
            a.1.partial_cmp(&b.1)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });
    scored.truncate(cap);
    self.nodes[idx].levels[lc] = scored.into_iter().map(|(nb, _)| nb).collect();
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::distance::Distance;
/// Builds a deterministic pseudo-random vector of `dim` components, each in
/// `[-1, 1]`, from a xorshift sequence started at `seed`.
fn unit(seed: u64, dim: usize) -> Vec<f32> {
    let mut state = seed;
    (0..dim)
        .map(|_| {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            let mantissa = (state >> 11) & ((1_u64 << 53) - 1);
            #[allow(
                clippy::cast_precision_loss,
                clippy::cast_possible_truncation,
                reason = "test fixture; PRNG output narrowed to f32"
            )]
            let sample = ((mantissa as f64) / ((1_u64 << 53) as f64)) * 2.0 - 1.0;
            #[allow(
                clippy::cast_possible_truncation,
                reason = "test fixture; f64 -> f32 narrowing is intentional"
            )]
            let component = sample as f32;
            component
        })
        .collect()
}
// A vector that was inserted must come back as its own nearest neighbour.
#[test]
fn insert_and_search_small() {
    let mut index = HnswIndex::new(Distance::Euclidean, HnswParams::default());
    let needle = unit(42, 8);
    index.insert(0, needle.clone()).unwrap();
    (1..50_u64).for_each(|i| {
        index
            .insert(i, unit(i.wrapping_mul(1_000_003) + 1, 8))
            .unwrap();
    });
    let hits = index.search(&needle, 3, None).unwrap();
    assert!(!hits.is_empty());
    assert_eq!(hits[0].id, 0);
}
// After deleting the best hit, it must no longer appear in search results.
#[test]
fn delete_excludes_from_search() {
    let mut index = HnswIndex::new(Distance::Euclidean, HnswParams::default());
    for id in 0..30_u64 {
        index.insert(id, unit(id + 1, 8)).unwrap();
    }
    let query = unit(1, 8);
    let initial_hits = index.search(&query, 5, None).unwrap();
    let victim = initial_hits[0].id;
    assert!(index.delete(victim));
    let remaining = index.search(&query, 5, None).unwrap();
    assert!(remaining.iter().all(|hit| hit.id != victim));
}
// Inserting a vector of a different length than the first one is an error.
#[test]
fn dimension_mismatch_rejected() {
    let mut index = HnswIndex::new(Distance::Euclidean, HnswParams::default());
    index.insert(0, vec![0.1, 0.2, 0.3]).unwrap();
    let outcome = index.insert(1, vec![0.1, 0.2]);
    assert!(matches!(
        outcome,
        Err(IndexError::DimensionMismatch { .. })
    ));
}
// Re-using an id that is already stored is an error.
#[test]
fn duplicate_id_rejected() {
    let mut index = HnswIndex::new(Distance::Euclidean, HnswParams::default());
    index.insert(7, vec![0.1, 0.2]).unwrap();
    let outcome = index.insert(7, vec![0.3, 0.4]);
    assert!(matches!(outcome, Err(IndexError::Duplicate(7))));
}
// Searching an index with no vectors succeeds and yields no hits.
#[test]
fn empty_index_search_is_empty() {
    let index = HnswIndex::new(Distance::Euclidean, HnswParams::default());
    let hits = index.search(&[0.1, 0.2], 5, None).unwrap();
    assert!(hits.is_empty());
}
}