1use crate::vector_distance::DistanceMetric;
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct HnswParams {
11 pub m: usize,
13 pub m0: usize,
15 pub ef_construction: usize,
17 pub metric: DistanceMetric,
19}
20
21impl Default for HnswParams {
22 fn default() -> Self {
23 Self {
24 m: 16,
25 m0: 32,
26 ef_construction: 200,
27 metric: DistanceMetric::Cosine,
28 }
29 }
30}
31
32#[derive(
38 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
39)]
40pub struct HnswNodeSnapshot {
41 pub vector: Vec<f32>,
42 pub neighbors: Vec<Vec<u32>>,
43 pub deleted: bool,
44}
45
46#[derive(
48 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
49)]
50pub struct HnswCheckpoint {
51 pub dim: usize,
52 pub m: usize,
53 pub m0: usize,
54 pub ef_construction: usize,
55 pub metric: u8,
56 pub entry_point: Option<u32>,
57 pub max_layer: usize,
58 pub rng_state: u64,
59 pub nodes: Vec<HnswNodeSnapshot>,
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// `HnswParams::default()` must produce the stock values, including the
    /// cosine metric.
    #[test]
    fn default_params() {
        let p = HnswParams::default();
        assert_eq!(p.m, 16);
        assert_eq!(p.m0, 32);
        assert_eq!(p.ef_construction, 200);
        // `matches!` avoids requiring `PartialEq` on `DistanceMetric`.
        assert!(matches!(p.metric, DistanceMetric::Cosine));
    }

    /// A checkpoint encoded to MessagePack and decoded again must equal the
    /// input in every field, not just in its overall shape.
    #[test]
    fn checkpoint_serde_roundtrip() {
        let snap = HnswCheckpoint {
            dim: 128,
            m: 16,
            m0: 32,
            ef_construction: 200,
            metric: 1,
            entry_point: Some(0),
            max_layer: 3,
            rng_state: 42,
            nodes: vec![HnswNodeSnapshot {
                vector: vec![0.1, 0.2, 0.3],
                neighbors: vec![vec![1, 2], vec![3]],
                deleted: false,
            }],
        };
        let bytes = zerompk::to_msgpack_vec(&snap).unwrap();
        let restored: HnswCheckpoint = zerompk::from_msgpack(&bytes).unwrap();

        // Scalar fields.
        assert_eq!(restored.dim, snap.dim);
        assert_eq!(restored.m, snap.m);
        assert_eq!(restored.m0, snap.m0);
        assert_eq!(restored.ef_construction, snap.ef_construction);
        assert_eq!(restored.metric, snap.metric);
        assert_eq!(restored.entry_point, snap.entry_point);
        assert_eq!(restored.max_layer, snap.max_layer);
        assert_eq!(restored.rng_state, snap.rng_state);

        // Node payloads; the length check guards against `zip` truncating.
        assert_eq!(restored.nodes.len(), snap.nodes.len());
        for (got, want) in restored.nodes.iter().zip(&snap.nodes) {
            assert_eq!(got.vector, want.vector);
            assert_eq!(got.neighbors, want.neighbors);
            assert_eq!(got.deleted, want.deleted);
        }
    }
}