use crate::distance::{DistanceMetric, distance};
pub use nodedb_types::hnsw::HnswParams;
#[derive(Debug, Clone)]
pub struct SearchResult {
pub id: u32,
pub distance: f32,
}
pub(crate) struct Node {
pub vector: Vec<f32>,
pub neighbors: Vec<Vec<u32>>,
pub deleted: bool,
}
pub struct HnswIndex {
pub(crate) params: HnswParams,
pub(crate) dim: usize,
pub(crate) nodes: Vec<Node>,
pub(crate) entry_point: Option<u32>,
pub(crate) max_layer: usize,
pub(crate) rng: Xorshift64,
}
pub(crate) struct Xorshift64(pub u64);
impl Xorshift64 {
pub fn new(seed: u64) -> Self {
Self(seed.max(1))
}
pub fn next_f64(&mut self) -> f64 {
self.0 ^= self.0 << 13;
self.0 ^= self.0 >> 7;
self.0 ^= self.0 << 17;
(self.0 as f64) / (u64::MAX as f64)
}
}
#[derive(Clone, Copy, PartialEq)]
pub(crate) struct Candidate {
pub dist: f32,
pub id: u32,
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Candidate {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.dist
.partial_cmp(&other.dist)
.unwrap_or(std::cmp::Ordering::Equal)
.then(self.id.cmp(&other.id))
}
}
impl HnswIndex {
pub fn new(dim: usize, params: HnswParams) -> Self {
Self {
dim,
nodes: Vec::new(),
entry_point: None,
max_layer: 0,
rng: Xorshift64::new(42),
params,
}
}
pub fn with_seed(dim: usize, params: HnswParams, seed: u64) -> Self {
Self {
dim,
nodes: Vec::new(),
entry_point: None,
max_layer: 0,
rng: Xorshift64::new(seed),
params,
}
}
pub fn len(&self) -> usize {
self.nodes.len()
}
pub fn live_count(&self) -> usize {
self.nodes.len() - self.tombstone_count()
}
pub fn tombstone_count(&self) -> usize {
self.nodes.iter().filter(|n| n.deleted).count()
}
pub fn is_empty(&self) -> bool {
self.live_count() == 0
}
pub fn delete(&mut self, id: u32) -> bool {
if let Some(node) = self.nodes.get_mut(id as usize) {
if node.deleted {
return false;
}
node.deleted = true;
true
} else {
false
}
}
pub fn is_deleted(&self, id: u32) -> bool {
self.nodes.get(id as usize).is_none_or(|n| n.deleted)
}
pub fn undelete(&mut self, id: u32) -> bool {
if let Some(node) = self.nodes.get_mut(id as usize)
&& node.deleted
{
node.deleted = false;
return true;
}
false
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
self.nodes.get(id as usize).map(|n| n.vector.as_slice())
}
pub fn params(&self) -> &HnswParams {
&self.params
}
pub fn entry_point(&self) -> Option<u32> {
self.entry_point
}
pub fn max_layer(&self) -> usize {
self.max_layer
}
pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
struct Snapshot {
dim: usize,
m: usize,
m0: usize,
ef_construction: usize,
metric: u8,
entry_point: Option<u32>,
max_layer: usize,
rng_state: u64,
nodes: Vec<NodeSnap>,
}
#[derive(Serialize, Deserialize)]
struct NodeSnap {
vector: Vec<f32>,
neighbors: Vec<Vec<u32>>,
deleted: bool,
}
let snapshot = Snapshot {
dim: self.dim,
m: self.params.m,
m0: self.params.m0,
ef_construction: self.params.ef_construction,
metric: self.params.metric as u8,
entry_point: self.entry_point,
max_layer: self.max_layer,
rng_state: self.rng.0,
nodes: self
.nodes
.iter()
.map(|n| NodeSnap {
vector: n.vector.clone(),
neighbors: n.neighbors.clone(),
deleted: n.deleted,
})
.collect(),
};
match rmp_serde::to_vec_named(&snapshot) {
Ok(bytes) => bytes,
Err(e) => {
tracing::error!(error = %e, "HNSW checkpoint serialization failed");
Vec::new()
}
}
}
pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
struct Snapshot {
dim: usize,
m: usize,
m0: usize,
ef_construction: usize,
metric: u8,
entry_point: Option<u32>,
max_layer: usize,
rng_state: u64,
nodes: Vec<NodeSnap>,
}
#[derive(Serialize, Deserialize)]
struct NodeSnap {
vector: Vec<f32>,
neighbors: Vec<Vec<u32>>,
deleted: bool,
}
let snap: Snapshot = rmp_serde::from_slice(bytes).ok()?;
let metric = match snap.metric {
0 => DistanceMetric::L2,
1 => DistanceMetric::Cosine,
2 => DistanceMetric::InnerProduct,
_ => DistanceMetric::Cosine,
};
let nodes: Vec<Node> = snap
.nodes
.into_iter()
.map(|n| Node {
vector: n.vector,
neighbors: n.neighbors,
deleted: n.deleted,
})
.collect();
Some(Self {
dim: snap.dim,
params: HnswParams {
m: snap.m,
m0: snap.m0,
ef_construction: snap.ef_construction,
metric,
},
nodes,
entry_point: snap.entry_point,
max_layer: snap.max_layer,
rng: Xorshift64::new(snap.rng_state),
})
}
pub(crate) fn random_layer(&mut self) -> usize {
let ml = 1.0 / (self.params.m as f64).ln();
let r = self.rng.next_f64().max(f64::MIN_POSITIVE);
(-r.ln() * ml).floor() as usize
}
pub(crate) fn dist_to_node(&self, query: &[f32], node_id: u32) -> f32 {
distance(
query,
&self.nodes[node_id as usize].vector,
self.params.metric,
)
}
pub(crate) fn max_neighbors(&self, layer: usize) -> usize {
if layer == 0 {
self.params.m0
} else {
self.params.m
}
}
pub fn compact(&mut self) -> usize {
let tombstone_count = self.tombstone_count();
if tombstone_count == 0 {
return 0;
}
let mut id_map: Vec<u32> = Vec::with_capacity(self.nodes.len());
let mut new_id = 0u32;
for node in &self.nodes {
if node.deleted {
id_map.push(u32::MAX);
} else {
id_map.push(new_id);
new_id += 1;
}
}
let mut new_nodes: Vec<Node> = Vec::with_capacity(new_id as usize);
for node in self.nodes.drain(..) {
if node.deleted {
continue;
}
let remapped_neighbors: Vec<Vec<u32>> = node
.neighbors
.into_iter()
.map(|layer_neighbors| {
layer_neighbors
.into_iter()
.filter_map(|old_nid| {
let new_nid = id_map[old_nid as usize];
if new_nid == u32::MAX {
None
} else {
Some(new_nid)
}
})
.collect()
})
.collect();
new_nodes.push(Node {
vector: node.vector,
neighbors: remapped_neighbors,
deleted: false,
});
}
self.entry_point = if let Some(old_ep) = self.entry_point {
let new_ep = id_map[old_ep as usize];
if new_ep == u32::MAX {
new_nodes
.iter()
.enumerate()
.max_by_key(|(_, n)| n.neighbors.len())
.map(|(i, _)| i as u32)
} else {
Some(new_ep)
}
} else {
None
};
self.max_layer = new_nodes
.iter()
.map(|n| n.neighbors.len().saturating_sub(1))
.max()
.unwrap_or(0);
self.nodes = new_nodes;
tombstone_count
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn create_empty_index() {
let idx = HnswIndex::new(3, HnswParams::default());
assert_eq!(idx.len(), 0);
assert!(idx.is_empty());
assert!(idx.entry_point().is_none());
}
#[test]
fn params_default() {
let p = HnswParams::default();
assert_eq!(p.m, 16);
assert_eq!(p.m0, 32);
assert_eq!(p.ef_construction, 200);
assert_eq!(p.metric, DistanceMetric::Cosine);
}
#[test]
fn candidate_ordering() {
let a = Candidate { dist: 0.1, id: 1 };
let b = Candidate { dist: 0.5, id: 2 };
assert!(a < b);
}
}