use crate::vector_distance::DistanceMetric;
use serde::{Deserialize, Serialize};
/// Tunable parameters for an HNSW index.
///
/// The `Default` implementation in this file yields `m = 16`, `m0 = 32`,
/// `ef_construction = 200` and `DistanceMetric::Cosine`.
///
/// NOTE(review): the per-field descriptions below follow conventional HNSW
/// naming; the code that consumes them is not in this file — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswParams {
/// Presumably the maximum number of neighbors per node on upper layers — confirm.
pub m: usize,
/// Presumably the maximum number of neighbors per node on layer 0 — confirm.
pub m0: usize,
/// Presumably the candidate-list size used while inserting — confirm.
pub ef_construction: usize,
/// Distance function used to compare vectors.
pub metric: DistanceMetric,
}
impl Default for HnswParams {
fn default() -> Self {
Self {
m: 16,
m0: 32,
ef_construction: 200,
metric: DistanceMetric::Cosine,
}
}
}
/// Serializable snapshot of a single graph node, stored inside
/// [`HnswCheckpoint::nodes`].
#[derive(
Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct HnswNodeSnapshot {
/// The node's vector components.
pub vector: Vec<f32>,
/// Neighbor id lists; presumably one inner list per layer, indexed by
/// layer number — TODO confirm against the index implementation.
pub neighbors: Vec<Vec<u32>>,
/// Deletion flag; presumably a tombstone marker rather than physical
/// removal — confirm.
pub deleted: bool,
}
/// Serializable checkpoint of an HNSW index.
///
/// Carries the same `m` / `m0` / `ef_construction` values as [`HnswParams`]
/// together with the node snapshots. The tests in this file round-trip it
/// through `zerompk::to_msgpack_vec` and `zerompk::from_msgpack`.
#[derive(
Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct HnswCheckpoint {
/// Presumably the dimensionality of every stored vector — confirm.
pub dim: usize,
/// See [`HnswParams::m`].
pub m: usize,
/// See [`HnswParams::m0`].
pub m0: usize,
/// See [`HnswParams::ef_construction`].
pub ef_construction: usize,
/// Metric stored as a raw `u8`; presumably a numeric code for
/// `DistanceMetric` — the mapping is not defined in this file, confirm.
pub metric: u8,
/// Entry node id, or `None`; presumably `None` means the graph is empty — confirm.
pub entry_point: Option<u32>,
/// Presumably the highest layer currently present in the graph — confirm.
pub max_layer: usize,
/// Saved random-number-generator state; presumably used so layer
/// assignment continues deterministically after restore — confirm.
pub rng_state: u64,
/// Snapshots of the nodes; presumably indexed by node id — confirm.
pub nodes: Vec<HnswNodeSnapshot>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field produced by `HnswParams::default()` is pinned, including
    /// the metric.
    #[test]
    fn default_params() {
        let p = HnswParams::default();
        assert_eq!(p.m, 16);
        assert_eq!(p.m0, 32);
        assert_eq!(p.ef_construction, 200);
        // `matches!` is used so the test does not require `PartialEq` on
        // `DistanceMetric`.
        assert!(matches!(p.metric, DistanceMetric::Cosine));
    }

    /// A checkpoint must survive a MessagePack round trip with every field
    /// intact, not just the container sizes.
    #[test]
    fn checkpoint_serde_roundtrip() {
        let snap = HnswCheckpoint {
            dim: 128,
            m: 16,
            m0: 32,
            ef_construction: 200,
            metric: 1,
            entry_point: Some(0),
            max_layer: 3,
            rng_state: 42,
            nodes: vec![HnswNodeSnapshot {
                vector: vec![0.1, 0.2, 0.3],
                neighbors: vec![vec![1, 2], vec![3]],
                deleted: false,
            }],
        };

        let bytes = zerompk::to_msgpack_vec(&snap).unwrap();
        let restored: HnswCheckpoint = zerompk::from_msgpack(&bytes).unwrap();

        // Scalar header fields.
        assert_eq!(restored.dim, 128);
        assert_eq!(restored.m, snap.m);
        assert_eq!(restored.m0, snap.m0);
        assert_eq!(restored.ef_construction, snap.ef_construction);
        assert_eq!(restored.metric, snap.metric);
        assert_eq!(restored.entry_point, snap.entry_point);
        assert_eq!(restored.max_layer, snap.max_layer);
        assert_eq!(restored.rng_state, snap.rng_state);

        // Node payload: compare contents, not only lengths.
        assert_eq!(restored.nodes.len(), 1);
        let (got, want) = (&restored.nodes[0], &snap.nodes[0]);
        assert_eq!(got.vector.len(), 3);
        assert_eq!(got.vector, want.vector);
        assert_eq!(got.neighbors, want.neighbors);
        assert_eq!(got.deleted, want.deleted);
    }
}