1use crate::vector_distance::DistanceMetric;
4use serde::{Deserialize, Serialize};
5
/// Build-time tuning parameters for an HNSW index.
///
/// NOTE(review): the field descriptions below follow standard HNSW terminology
/// (Malkov & Yashunin). Confirm they match how the index builder consumes them.
// Field order is kept as-is: with derived serde impls it can be part of the
// encoded layout for positional (non-self-describing) formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswParams {
    /// Presumably the maximum number of neighbor links per node on layers above 0.
    pub m: usize,
    /// Presumably the maximum number of neighbor links per node on layer 0.
    /// The `Default` impl uses 32 alongside `m = 16`.
    pub m0: usize,
    /// Presumably the size of the candidate list explored while inserting a node.
    pub ef_construction: usize,
    /// Distance function used to compare vectors.
    pub metric: DistanceMetric,
}
18
19impl Default for HnswParams {
20 fn default() -> Self {
21 Self {
22 m: 16,
23 m0: 32,
24 ef_construction: 200,
25 metric: DistanceMetric::Cosine,
26 }
27 }
28}
29
/// Serializable snapshot of a single HNSW graph node, stored in
/// `HnswCheckpoint::nodes`.
///
/// Derives both serde and `zerompk` MessagePack (de)serialization.
/// NOTE(review): if `zerompk` encodes structs positionally, field order is part
/// of the on-disk format. Do not reorder fields without checking.
#[derive(
    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct HnswNodeSnapshot {
    /// The node's embedding vector.
    /// Presumably its length equals `HnswCheckpoint::dim` — confirm.
    pub vector: Vec<f32>,
    /// Neighbor ids, one `u32` list per outer entry.
    /// Presumably `neighbors[layer]` holds this node's links on that layer, and the
    /// ids index into `HnswCheckpoint::nodes` — confirm.
    pub neighbors: Vec<Vec<u32>>,
    /// Deletion flag.
    /// Presumably a tombstone, so the node keeps its id while being excluded
    /// from results — confirm.
    pub deleted: bool,
}
43
/// Serializable snapshot of an entire HNSW index. It holds the index
/// parameters, the graph entry point, saved RNG state and every node.
///
/// Derives both serde and `zerompk` MessagePack (de)serialization.
/// NOTE(review): if `zerompk` encodes structs positionally, field order is part
/// of the on-disk format. Do not reorder fields without checking.
#[derive(
    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct HnswCheckpoint {
    /// Presumably the dimensionality of every stored vector — confirm.
    pub dim: usize,
    /// Presumably mirrors `HnswParams::m`.
    pub m: usize,
    /// Presumably mirrors `HnswParams::m0`.
    pub m0: usize,
    /// Presumably mirrors `HnswParams::ef_construction`.
    pub ef_construction: usize,
    /// Distance metric stored as a raw `u8` tag rather than `DistanceMetric`.
    /// NOTE(review): the tag-to-variant mapping lives outside this file — confirm it.
    pub metric: u8,
    /// Id of the node where searches start, or `None`.
    /// Presumably `None` means the index was empty when checkpointed.
    pub entry_point: Option<u32>,
    /// Presumably the highest layer index present in the graph — confirm.
    pub max_layer: usize,
    /// Saved random-number-generator state.
    /// Presumably restored so level assignment continues deterministically — confirm.
    pub rng_state: u64,
    /// Snapshots of all nodes. Presumably position `i` is node id `i`.
    pub nodes: Vec<HnswNodeSnapshot>,
}
59
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_params() {
        let p = HnswParams::default();
        assert_eq!(p.m, 16);
        assert_eq!(p.m0, 32);
        assert_eq!(p.ef_construction, 200);
        // `matches!` avoids requiring `PartialEq` on `DistanceMetric`.
        assert!(matches!(p.metric, DistanceMetric::Cosine));
    }

    /// Builds a small but non-trivial checkpoint:
    /// - two nodes, one of them marked deleted
    /// - multi-layer neighbor lists
    /// - non-default scalar fields
    fn sample_checkpoint() -> HnswCheckpoint {
        HnswCheckpoint {
            dim: 3,
            m: 16,
            m0: 32,
            ef_construction: 200,
            metric: 1,
            entry_point: Some(0),
            max_layer: 1,
            rng_state: 42,
            nodes: vec![
                HnswNodeSnapshot {
                    vector: vec![0.1, 0.2, 0.3],
                    neighbors: vec![vec![1], vec![1]],
                    deleted: false,
                },
                HnswNodeSnapshot {
                    vector: vec![-1.5, 0.0, 2.25],
                    neighbors: vec![vec![0]],
                    deleted: true,
                },
            ],
        }
    }

    /// Encodes `snap` to MessagePack and decodes it again.
    ///
    /// Asserts that every field, including each node's contents, survives the
    /// round trip unchanged.
    fn assert_msgpack_roundtrip(snap: &HnswCheckpoint) {
        let bytes = zerompk::to_msgpack_vec(snap).expect("encode HnswCheckpoint");
        let restored: HnswCheckpoint =
            zerompk::from_msgpack(&bytes).expect("decode HnswCheckpoint");

        assert_eq!(restored.dim, snap.dim);
        assert_eq!(restored.m, snap.m);
        assert_eq!(restored.m0, snap.m0);
        assert_eq!(restored.ef_construction, snap.ef_construction);
        assert_eq!(restored.metric, snap.metric);
        assert_eq!(restored.entry_point, snap.entry_point);
        assert_eq!(restored.max_layer, snap.max_layer);
        assert_eq!(restored.rng_state, snap.rng_state);

        assert_eq!(restored.nodes.len(), snap.nodes.len());
        for (got, want) in restored.nodes.iter().zip(&snap.nodes) {
            assert_eq!(got.vector, want.vector);
            assert_eq!(got.neighbors, want.neighbors);
            assert_eq!(got.deleted, want.deleted);
        }
    }

    #[test]
    fn checkpoint_serde_roundtrip() {
        assert_msgpack_roundtrip(&sample_checkpoint());
    }

    #[test]
    fn empty_checkpoint_roundtrip() {
        // An index with no nodes: exercises `entry_point: None` and an empty `nodes` list.
        let snap = HnswCheckpoint {
            dim: 4,
            m: 16,
            m0: 32,
            ef_construction: 200,
            metric: 0,
            entry_point: None,
            max_layer: 0,
            rng_state: 0,
            nodes: Vec::new(),
        };
        assert_msgpack_roundtrip(&snap);
    }
}