1use crate::vector_distance::DistanceMetric;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct HnswParams {
9 pub m: usize,
11 pub m0: usize,
13 pub ef_construction: usize,
15 pub metric: DistanceMetric,
17}
18
19impl Default for HnswParams {
20 fn default() -> Self {
21 Self {
22 m: 16,
23 m0: 32,
24 ef_construction: 200,
25 metric: DistanceMetric::Cosine,
26 }
27 }
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct HnswNodeSnapshot {
37 pub vector: Vec<f32>,
38 pub neighbors: Vec<Vec<u32>>,
39 pub deleted: bool,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct HnswCheckpoint {
45 pub dim: usize,
46 pub m: usize,
47 pub m0: usize,
48 pub ef_construction: usize,
49 pub metric: u8,
50 pub entry_point: Option<u32>,
51 pub max_layer: usize,
52 pub rng_state: u64,
53 pub nodes: Vec<HnswNodeSnapshot>,
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares two checkpoints field by field (the snapshot types do not
    /// derive `PartialEq`), so that any field lost or mangled by serde fails.
    fn assert_same_checkpoint(actual: &HnswCheckpoint, expected: &HnswCheckpoint) {
        assert_eq!(actual.dim, expected.dim);
        assert_eq!(actual.m, expected.m);
        assert_eq!(actual.m0, expected.m0);
        assert_eq!(actual.ef_construction, expected.ef_construction);
        assert_eq!(actual.metric, expected.metric);
        assert_eq!(actual.entry_point, expected.entry_point);
        assert_eq!(actual.max_layer, expected.max_layer);
        assert_eq!(actual.rng_state, expected.rng_state);
        assert_eq!(actual.nodes.len(), expected.nodes.len());
        for (a, e) in actual.nodes.iter().zip(&expected.nodes) {
            assert_eq!(a.vector, e.vector);
            assert_eq!(a.neighbors, e.neighbors);
            assert_eq!(a.deleted, e.deleted);
        }
    }

    /// Encodes with named MessagePack fields and decodes back.
    fn roundtrip(snap: &HnswCheckpoint) -> HnswCheckpoint {
        let bytes = rmp_serde::to_vec_named(snap).unwrap();
        rmp_serde::from_slice(&bytes).unwrap()
    }

    #[test]
    fn default_params() {
        let p = HnswParams::default();
        assert_eq!(p.m, 16);
        assert_eq!(p.m0, 32);
        assert_eq!(p.ef_construction, 200);
        assert!(matches!(p.metric, DistanceMetric::Cosine));
    }

    #[test]
    fn params_serde_roundtrip() {
        let params = HnswParams {
            m: 8,
            m0: 24,
            ef_construction: 64,
            metric: DistanceMetric::Cosine,
        };
        let bytes = rmp_serde::to_vec_named(&params).unwrap();
        let restored: HnswParams = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(restored.m, 8);
        assert_eq!(restored.m0, 24);
        assert_eq!(restored.ef_construction, 64);
        assert!(matches!(restored.metric, DistanceMetric::Cosine));
    }

    #[test]
    fn checkpoint_serde_roundtrip() {
        let snap = HnswCheckpoint {
            dim: 3,
            m: 16,
            m0: 32,
            ef_construction: 200,
            metric: 1,
            entry_point: Some(0),
            max_layer: 1,
            rng_state: 42,
            nodes: vec![
                HnswNodeSnapshot {
                    vector: vec![0.1, 0.2, 0.3],
                    neighbors: vec![vec![1], vec![1]],
                    deleted: false,
                },
                HnswNodeSnapshot {
                    vector: vec![-1.5, 0.0, 2.25],
                    neighbors: vec![vec![0], vec![]],
                    deleted: true,
                },
            ],
        };
        assert_same_checkpoint(&roundtrip(&snap), &snap);
    }

    #[test]
    fn empty_checkpoint_serde_roundtrip() {
        // No nodes, no entry point, and an RNG state at the top of the u64 range.
        let snap = HnswCheckpoint {
            dim: 0,
            m: 16,
            m0: 32,
            ef_construction: 200,
            metric: 0,
            entry_point: None,
            max_layer: 0,
            rng_state: u64::MAX,
            nodes: Vec::new(),
        };
        assert_same_checkpoint(&roundtrip(&snap), &snap);
    }
}