use serde::{Deserialize, Serialize};
use super::super::distance::distance;
/// Serializable image of an `HnswIndex`, written by `HnswIndex::checkpoint_to_bytes`
/// and read back by `HnswIndex::from_checkpoint`.
///
/// The `HnswParams` fields are stored flattened as `params_*`, with the distance
/// metric reduced to a `u8` code.
#[derive(Serialize, Deserialize)]
struct HnswSnapshot {
/// Vector dimensionality recorded on the index.
dim: usize,
/// Copy of `HnswParams::m`.
params_m: usize,
/// Copy of `HnswParams::m0`.
params_m0: usize,
/// Copy of `HnswParams::ef_construction`.
params_ef_construction: usize,
/// Metric cast to `u8`; `from_checkpoint` decodes 0 = L2, 1 = Cosine, 2 = InnerProduct.
params_metric: u8,
/// Id of the entry node, if any.
entry_point: Option<u32>,
/// Highest layer index recorded on the index.
max_layer: usize,
/// Raw state word of the index's `Xorshift64` generator.
rng_state: u64,
/// One entry per node, in id order (a node's id is its position).
nodes: Vec<NodeSnapshot>,
}
/// Serializable mirror of `Node`; the three fields are copied one-to-one in both directions.
#[derive(Serialize, Deserialize)]
struct NodeSnapshot {
/// The node's stored vector.
vector: Vec<f32>,
/// Neighbor ids, one inner list per layer.
neighbors: Vec<Vec<u32>>,
/// Tombstone flag.
deleted: bool,
}
pub use nodedb_types::hnsw::HnswParams;
/// A node id paired with a distance value.
///
/// NOTE(review): presumably one hit returned by a search routine defined elsewhere;
/// nothing in this file constructs it — confirm against the search module.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Node id (an index into the index's node list).
pub id: u32,
/// Distance associated with `id`.
pub distance: f32,
}
/// One graph node: its vector, its per-layer adjacency lists and a tombstone flag.
pub(super) struct Node {
/// The stored vector, handed to `distance` by `HnswIndex::dist_to_node`.
pub vector: Vec<f32>,
/// `neighbors[layer]` lists the ids of this node's neighbors on that layer;
/// `HnswIndex::compact` treats `neighbors.len() - 1` as the node's top layer.
pub neighbors: Vec<Vec<u32>>,
/// Set by `HnswIndex::delete`, cleared by `HnswIndex::undelete`; tombstoned nodes
/// stay in the node list until `HnswIndex::compact` removes them.
pub deleted: bool,
}
/// Layered neighbor-graph vector index.
///
/// Nodes are stored in `nodes` and addressed by their position, used as a `u32` id.
/// Deleting a node only sets its tombstone flag; `compact` physically removes
/// tombstoned nodes and renumbers the survivors.
pub struct HnswIndex {
/// Construction parameters (`m`, `m0`, `ef_construction`, `metric`).
pub(super) params: HnswParams,
/// Vector dimensionality supplied at construction.
pub(super) dim: usize,
/// Node storage; a node's id is its index in this vector.
pub(super) nodes: Vec<Node>,
/// Id of the entry node; `None` on a freshly constructed index.
pub(super) entry_point: Option<u32>,
/// Highest layer index; `compact` recomputes it as the largest `neighbors.len() - 1`.
pub(super) max_layer: usize,
/// Generator consumed by `random_layer`; its state is persisted in checkpoints.
pub(super) rng: Xorshift64,
}
/// Xorshift pseudo-random generator holding a single `u64` state word.
///
/// The state is readable through `HnswIndex::rng_state` and is saved in checkpoints.
pub(super) struct Xorshift64(u64);
impl Xorshift64 {
    /// Creates a generator whose state is `seed`, or `1` when `seed` is zero.
    ///
    /// A zero state would stay zero under the xor/shift update in `next_f64`.
    pub fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        Self(state)
    }

    /// Advances the state by one xorshift step (shifts 13, 7, 17) and returns the
    /// new state divided by `u64::MAX`, i.e. a value in `[0.0, 1.0]`.
    pub fn next_f64(&mut self) -> f64 {
        let mut state = self.0;
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        self.0 = state;
        (state as f64) / (u64::MAX as f64)
    }
}
/// A `(dist, id)` pair with a total ordering: by `dist` first, then by `id`
/// (see the `Ord` impl).
///
/// NOTE(review): presumably the element type of the search/insert priority queues —
/// confirm against the callers, which are outside this file.
#[derive(Clone, Copy, PartialEq)]
pub(super) struct Candidate {
/// Distance value used as the primary sort key.
pub dist: f32,
/// Node id, used as the tie-breaker.
pub id: u32,
}
// Marker impl required for `Ord`. NOTE(review): the derived `PartialEq` compares `dist`
// as a plain `f32`, so a candidate whose distance is NaN is not equal to itself.
impl Eq for Candidate {}
impl PartialOrd for Candidate {
/// Delegates to `Ord::cmp`, so the result is always `Some`.
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Candidate {
    /// Orders candidates by ascending `dist`, breaking ties by ascending `id`.
    ///
    /// Distances that `partial_cmp` cannot order (at least one is NaN) fall back to
    /// `f32::total_cmp`, which places positive NaN after every other value and
    /// negative NaN before. Treating NaN as equal to everything would let the id
    /// tie-break decide and make the ordering non-transitive, which `Ord` forbids.
    /// Comparisons between non-NaN distances are unaffected (`-0.0` and `0.0` still tie).
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist
            .partial_cmp(&other.dist)
            .unwrap_or_else(|| self.dist.total_cmp(&other.dist))
            .then_with(|| self.id.cmp(&other.id))
    }
}
impl HnswIndex {
pub fn new(dim: usize, params: HnswParams) -> Self {
Self {
dim,
nodes: Vec::new(),
entry_point: None,
max_layer: 0,
rng: Xorshift64::new(42),
params,
}
}
/// Creates an empty index for `dim`-dimensional vectors whose layer generator is
/// seeded with `seed`.
///
/// The index starts with no nodes, no entry point and `max_layer == 0`.
pub fn with_seed(dim: usize, params: HnswParams, seed: u64) -> Self {
    let rng = Xorshift64::new(seed);
    Self {
        params,
        dim,
        nodes: Vec::new(),
        entry_point: None,
        max_layer: 0,
        rng,
    }
}
/// Number of stored nodes, tombstoned ones included.
///
/// Note that `is_empty` is based on `live_count`, so `len() > 0` does not imply
/// `!is_empty()`.
pub fn len(&self) -> usize {
self.nodes.len()
}
/// Number of nodes that are not tombstoned. Scans every node.
pub fn live_count(&self) -> usize {
    self.nodes.iter().filter(|node| !node.deleted).count()
}
/// Number of nodes whose `deleted` flag is set. Scans every node.
pub fn tombstone_count(&self) -> usize {
    self.nodes
        .iter()
        .map(|node| usize::from(node.deleted))
        .sum()
}
/// Fraction of stored nodes that are tombstoned, in `[0.0, 1.0]`.
///
/// Returns `0.0` for an index with no nodes.
pub fn tombstone_ratio(&self) -> f64 {
    let total = self.nodes.len();
    if total == 0 {
        return 0.0;
    }
    self.tombstone_count() as f64 / total as f64
}
/// `true` when the index holds no live node, i.e. it has no nodes at all or every
/// node is tombstoned.
pub fn is_empty(&self) -> bool {
    self.nodes.iter().all(|node| node.deleted)
}
/// Tombstones node `id`.
///
/// Returns `true` if the node existed and was live; `false` if `id` is out of
/// range or the node was already tombstoned. The node stays in the graph until
/// `compact` runs.
pub fn delete(&mut self, id: u32) -> bool {
    match self.nodes.get_mut(id as usize) {
        Some(node) if !node.deleted => {
            node.deleted = true;
            true
        }
        _ => false,
    }
}
/// `true` if node `id` is tombstoned or if `id` does not refer to any node.
pub fn is_deleted(&self, id: u32) -> bool {
    match self.nodes.get(id as usize) {
        Some(node) => node.deleted,
        None => true,
    }
}
pub fn undelete(&mut self, id: u32) -> bool {
if let Some(node) = self.nodes.get_mut(id as usize)
&& node.deleted
{
node.deleted = false;
return true;
}
false
}
/// Vector dimensionality the index was created with.
pub fn dim(&self) -> usize {
self.dim
}
/// Returns the stored vector of node `id`, or `None` if `id` is out of range.
///
/// The tombstone flag is not consulted: a deleted node's vector is still returned.
pub fn get_vector(&self, id: u32) -> Option<&[f32]> {
    let node = self.nodes.get(id as usize)?;
    Some(node.vector.as_slice())
}
/// Borrows the parameters the index was created with.
pub fn params(&self) -> &HnswParams {
&self.params
}
/// Id of the current entry node, or `None` if none is set.
///
/// `delete` does not update this field, so the entry node may be tombstoned.
pub fn entry_point(&self) -> Option<u32> {
self.entry_point
}
/// Highest layer index currently recorded on the index (`0` for a fresh index).
pub fn max_layer(&self) -> usize {
self.max_layer
}
/// Raw state word of the layer generator, the same value that
/// `checkpoint_to_bytes` stores as `rng_state`.
pub fn rng_state(&self) -> u64 {
self.rng.0
}
/// Clones every node's vector, in id order. Tombstoned nodes are included.
pub fn export_vectors(&self) -> Vec<Vec<f32>> {
    self.nodes
        .iter()
        .map(|node| &node.vector)
        .cloned()
        .collect()
}
/// Clones every node's per-layer neighbor lists, in id order. Tombstoned nodes
/// are included.
pub fn export_neighbors(&self) -> Vec<Vec<Vec<u32>>> {
    self.nodes
        .iter()
        .map(|node| &node.neighbors)
        .cloned()
        .collect()
}
/// Serializes the whole index (parameters, graph, tombstones and generator state)
/// into a named MessagePack buffer via `rmp_serde::to_vec_named`.
///
/// If serialization fails the error is discarded and an empty vector is returned.
pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
    let mut nodes = Vec::with_capacity(self.nodes.len());
    for node in &self.nodes {
        nodes.push(NodeSnapshot {
            vector: node.vector.clone(),
            neighbors: node.neighbors.clone(),
            deleted: node.deleted,
        });
    }

    let params = &self.params;
    let snapshot = HnswSnapshot {
        dim: self.dim,
        params_m: params.m,
        params_m0: params.m0,
        params_ef_construction: params.ef_construction,
        // The metric is stored as its enum discriminant.
        params_metric: params.metric as u8,
        entry_point: self.entry_point,
        max_layer: self.max_layer,
        rng_state: self.rng.0,
        nodes,
    };

    match rmp_serde::to_vec_named(&snapshot) {
        Ok(bytes) => bytes,
        Err(_) => Vec::new(),
    }
}
/// Rebuilds an index from bytes produced by `checkpoint_to_bytes`.
///
/// Returns `None` when the bytes do not deserialize into a snapshot, or when the
/// snapshot's graph is inconsistent: an entry point or neighbor id that does not
/// refer to one of the snapshot's nodes. Such ids would otherwise load silently and
/// panic later, when `dist_to_node` or `compact` index the node list with them.
pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
    use super::super::distance::DistanceMetric;

    let snapshot: HnswSnapshot = rmp_serde::from_slice(bytes).ok()?;

    // Every id stored in the graph must address an existing node.
    let node_count = snapshot.nodes.len();
    let in_range = |id: u32| (id as usize) < node_count;
    if snapshot.entry_point.is_some_and(|ep| !in_range(ep)) {
        return None;
    }
    let edges_valid = snapshot
        .nodes
        .iter()
        .all(|node| node.neighbors.iter().flatten().all(|&id| in_range(id)));
    if !edges_valid {
        return None;
    }

    let metric = match snapshot.params_metric {
        0 => DistanceMetric::L2,
        1 => DistanceMetric::Cosine,
        2 => DistanceMetric::InnerProduct,
        // NOTE(review): unknown codes silently become Cosine. If `DistanceMetric`
        // has further variants, a checkpoint written with one of them is restored
        // with the wrong metric — confirm whether this fallback is intended.
        _ => DistanceMetric::Cosine,
    };
    let params = HnswParams {
        m: snapshot.params_m,
        m0: snapshot.params_m0,
        ef_construction: snapshot.params_ef_construction,
        metric,
    };
    let nodes: Vec<Node> = snapshot
        .nodes
        .into_iter()
        .map(|n| Node {
            vector: n.vector,
            neighbors: n.neighbors,
            deleted: n.deleted,
        })
        .collect();

    Some(Self {
        dim: snapshot.dim,
        params,
        nodes,
        entry_point: snapshot.entry_point,
        max_layer: snapshot.max_layer,
        // `Xorshift64::new` maps a zero state to 1, as it does for fresh seeds.
        rng: Xorshift64::new(snapshot.rng_state),
    })
}
/// Draws a random layer as `floor(-ln(r) / ln(m))`, where `r` is the next value
/// from the index's generator, clamped away from zero.
///
/// Returns layer 0 when `m < 2`: `ln(1) == 0` would make the multiplier infinite
/// and the cast would saturate to `usize::MAX`. The generator is advanced exactly
/// once per call in every case.
pub(super) fn random_layer(&mut self) -> usize {
    // Draw before the guard so the random stream does not depend on `m`.
    let r = self.rng.next_f64().max(f64::MIN_POSITIVE);
    if self.params.m < 2 {
        return 0;
    }
    let ml = 1.0 / (self.params.m as f64).ln();
    (-r.ln() * ml).floor() as usize
}
/// Distance between `query` and the stored vector of node `node_id`, computed by
/// `distance` under the index's configured metric.
///
/// # Panics
///
/// Panics if `node_id` is not a valid index into the node list.
pub(super) fn dist_to_node(&self, query: &[f32], node_id: u32) -> f32 {
    let node = &self.nodes[node_id as usize];
    distance(query, &node.vector, self.params.metric)
}
/// Physically removes all tombstoned nodes and returns how many were removed.
///
/// Surviving nodes keep their relative order and are renumbered densely from 0,
/// so node ids obtained before the call are no longer valid afterwards. Neighbor
/// lists are rewritten with the new ids; references to removed nodes are dropped,
/// as are ids that do not refer to any node (instead of panicking on them).
///
/// If the entry point was removed (or did not refer to a node), the surviving node
/// with the most layers becomes the new entry point. `max_layer` is recomputed from
/// the survivors. Returns 0 and changes nothing when there are no tombstones.
pub fn compact(&mut self) -> usize {
    let removed = self.tombstone_count();
    if removed == 0 {
        return 0;
    }

    // Old id -> new id; `None` marks a tombstoned node.
    let mut next_id = 0u32;
    let id_map: Vec<Option<u32>> = self
        .nodes
        .iter()
        .map(|node| {
            if node.deleted {
                None
            } else {
                let assigned = next_id;
                next_id += 1;
                Some(assigned)
            }
        })
        .collect();
    // Checked lookup: tombstoned and out-of-range ids both map to `None`.
    let remap = |old: u32| id_map.get(old as usize).copied().flatten();

    let survivors: Vec<Node> = std::mem::take(&mut self.nodes)
        .into_iter()
        .filter(|node| !node.deleted)
        .map(|node| Node {
            vector: node.vector,
            neighbors: node
                .neighbors
                .into_iter()
                .map(|layer| layer.into_iter().filter_map(&remap).collect())
                .collect(),
            deleted: false,
        })
        .collect();

    self.entry_point = match self.entry_point {
        None => None,
        Some(old_ep) => remap(old_ep).or_else(|| {
            // The old entry point is gone: fall back to the node with the most layers.
            survivors
                .iter()
                .enumerate()
                .max_by_key(|(_, node)| node.neighbors.len())
                .map(|(idx, _)| idx as u32)
        }),
    };
    self.max_layer = survivors
        .iter()
        .map(|node| node.neighbors.len().saturating_sub(1))
        .max()
        .unwrap_or(0);
    self.nodes = survivors;
    removed
}
/// Neighbor-count limit for `layer`: `params.m0` on layer 0, `params.m` on every
/// higher layer.
pub(super) fn max_neighbors(&self, layer: usize) -> usize {
    match layer {
        0 => self.params.m0,
        _ => self.params.m,
    }
}
}