1use crate::distance::DistanceMetric;
6use crate::hnsw::flat_neighbors::FlatNeighborStore;
7use crate::hnsw::graph::{HnswIndex, Node, Xorshift64};
8
// Six-byte tag written at the start of every rkyv-encoded checkpoint.
// `from_checkpoint` uses it to tell rkyv checkpoints apart from the older
// MessagePack encoding, which carries no such prefix.
const HNSW_RKYV_MAGIC: &[u8; 6] = b"RKHNS\0";
11
/// Owned, serializable image of an `HnswIndex` used as the checkpoint payload.
///
/// Per-node data is stored as three parallel vectors (`node_vectors`,
/// `node_neighbors`, `node_deleted`) indexed by node position rather than as a
/// vector of node structs.
///
/// NOTE(review): the rkyv encoding is derived from this definition, so
/// reordering or retyping fields presumably changes the checkpoint format —
/// confirm before editing.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub(crate) struct HnswSnapshotRkyv {
    /// Vector dimensionality of the index (`HnswIndex::dim`).
    pub dim: usize,
    /// `HnswParams::m` of the index.
    pub m: usize,
    /// `HnswParams::m0` of the index.
    pub m0: usize,
    /// `HnswParams::ef_construction` of the index.
    pub ef_construction: usize,
    /// `DistanceMetric` cast to `u8`; decoded as 0 = L2, 1 = Cosine,
    /// 2 = InnerProduct.
    pub metric: u8,
    /// Entry point of the graph, or `None` when the index has none.
    pub entry_point: Option<u32>,
    /// `HnswIndex::max_layer` at checkpoint time.
    pub max_layer: usize,
    /// Raw state word of the index's `Xorshift64` generator.
    pub rng_state: u64,
    /// One stored vector per node, in node order.
    pub node_vectors: Vec<Vec<f32>>,
    /// Nested neighbor lists, one outer entry per node
    /// (presumably one inner list per graph layer — confirm).
    pub node_neighbors: Vec<Vec<Vec<u32>>>,
    /// Deletion flag per node, in node order.
    pub node_deleted: Vec<bool>,
}
31
32impl HnswIndex {
33 pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
37 let snapshot = HnswSnapshotRkyv {
38 dim: self.dim,
39 m: self.params.m,
40 m0: self.params.m0,
41 ef_construction: self.params.ef_construction,
42 metric: self.params.metric as u8,
43 entry_point: self.entry_point,
44 max_layer: self.max_layer,
45 rng_state: self.rng.0,
46 node_vectors: self.nodes.iter().map(|n| n.vector.clone()).collect(),
47 node_neighbors: if let Some(ref flat) = self.flat_neighbors {
48 flat.to_nested(self.nodes.len())
49 } else {
50 self.nodes.iter().map(|n| n.neighbors.clone()).collect()
51 },
52 node_deleted: self.nodes.iter().map(|n| n.deleted).collect(),
53 };
54 let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
55 .expect("HNSW rkyv serialization should not fail");
56 let mut buf = Vec::with_capacity(HNSW_RKYV_MAGIC.len() + rkyv_bytes.len());
57 buf.extend_from_slice(HNSW_RKYV_MAGIC);
58 buf.extend_from_slice(&rkyv_bytes);
59 buf
60 }
61
62 pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
66 if bytes.len() > HNSW_RKYV_MAGIC.len() && &bytes[..HNSW_RKYV_MAGIC.len()] == HNSW_RKYV_MAGIC
67 {
68 return Self::from_rkyv_checkpoint(&bytes[HNSW_RKYV_MAGIC.len()..]);
69 }
70 Self::from_msgpack_checkpoint(bytes)
71 }
72
73 fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
75 let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
76 aligned.extend_from_slice(bytes);
77 let snap: HnswSnapshotRkyv =
78 rkyv::from_bytes::<HnswSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
79 Self::from_hnsw_snapshot(snap)
80 }
81
82 fn from_msgpack_checkpoint(bytes: &[u8]) -> Option<Self> {
84 use zerompk::{FromMessagePack, ToMessagePack};
85
86 #[derive(ToMessagePack, FromMessagePack)]
87 struct Snapshot {
88 dim: usize,
89 m: usize,
90 m0: usize,
91 ef_construction: usize,
92 metric: u8,
93 entry_point: Option<u32>,
94 max_layer: usize,
95 rng_state: u64,
96 nodes: Vec<NodeSnap>,
97 }
98
99 #[derive(ToMessagePack, FromMessagePack)]
100 struct NodeSnap {
101 vector: Vec<f32>,
102 neighbors: Vec<Vec<u32>>,
103 deleted: bool,
104 }
105
106 let snap: Snapshot = zerompk::from_msgpack(bytes).ok()?;
107 Self::from_hnsw_snapshot(HnswSnapshotRkyv {
108 dim: snap.dim,
109 m: snap.m,
110 m0: snap.m0,
111 ef_construction: snap.ef_construction,
112 metric: snap.metric,
113 entry_point: snap.entry_point,
114 max_layer: snap.max_layer,
115 rng_state: snap.rng_state,
116 node_vectors: snap.nodes.iter().map(|n| n.vector.clone()).collect(),
117 node_neighbors: snap.nodes.iter().map(|n| n.neighbors.clone()).collect(),
118 node_deleted: snap.nodes.iter().map(|n| n.deleted).collect(),
119 })
120 }
121
122 fn from_hnsw_snapshot(snap: HnswSnapshotRkyv) -> Option<Self> {
124 use nodedb_types::hnsw::HnswParams;
125
126 let metric = match snap.metric {
127 0 => DistanceMetric::L2,
128 1 => DistanceMetric::Cosine,
129 2 => DistanceMetric::InnerProduct,
130 _ => DistanceMetric::Cosine,
131 };
132
133 let flat = FlatNeighborStore::from_nested(&snap.node_neighbors);
134
135 let nodes: Vec<Node> = snap
136 .node_vectors
137 .into_iter()
138 .zip(snap.node_deleted)
139 .map(|(vector, deleted)| Node {
140 vector,
141 neighbors: Vec::new(),
142 deleted,
143 })
144 .collect();
145
146 Some(Self {
147 dim: snap.dim,
148 params: HnswParams {
149 m: snap.m,
150 m0: snap.m0,
151 ef_construction: snap.ef_construction,
152 metric,
153 },
154 nodes,
155 entry_point: snap.entry_point,
156 max_layer: snap.max_layer,
157 rng: Xorshift64::new(snap.rng_state),
158 flat_neighbors: Some(flat),
159 })
160 }
161}
162
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Creates an empty 3-dimensional L2 index with a fixed RNG seed.
    fn make_index() -> HnswIndex {
        let params = HnswParams {
            m: 4,
            m0: 8,
            ef_construction: 32,
            metric: DistanceMetric::L2,
        };
        HnswIndex::with_seed(3, params, 12345)
    }

    /// A checkpointed index restores with the same size, shape, and search
    /// results as the index it was taken from.
    #[test]
    fn checkpoint_roundtrip() {
        let mut original = make_index();
        for i in 0..50 {
            let t = i as f32;
            original.insert(vec![t * 0.1, t * 0.2, t * 0.3]).unwrap();
        }

        let checkpoint = original.checkpoint_to_bytes();
        assert!(!checkpoint.is_empty());

        let restored = HnswIndex::from_checkpoint(&checkpoint).unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), original.entry_point());
        assert_eq!(restored.max_layer(), original.max_layer());

        // Same query against both indexes must yield the same ranked hits.
        let query = vec![1.0, 2.0, 3.0];
        let expected = original.search(&query, 5, 32);
        let actual = restored.search(&query, 5, 32);
        assert_eq!(expected.len(), actual.len());
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert_eq!(want.id, got.id);
            assert!((want.distance - got.distance).abs() < 1e-6);
        }
    }
}