1use std::cell::RefCell;
8
9use crate::distance::DistanceMetric;
10use crate::hnsw::arena::BeamSearchArena;
11use crate::hnsw::flat_neighbors::FlatNeighborStore;
12use crate::hnsw::graph::{ARENA_INITIAL_CAPACITY, HnswIndex, Node, NodeStorage, Xorshift64};
13
/// Six-byte magic prefix that starts every rkyv-encoded HNSW checkpoint.
const HNSW_RKYV_MAGIC: &[u8; 6] = b"RKHNS\0";
/// Format version byte written immediately after the magic prefix.
/// `HnswIndex::from_checkpoint` rejects checkpoints carrying any other value.
pub const HNSW_FORMAT_VERSION: u8 = 1;
18
/// Serializable snapshot of an `HnswIndex`, encoded with rkyv after the
/// magic/version header.
///
/// The archived layout follows field declaration order, so fields must not be
/// reordered, removed, or retyped without bumping `HNSW_FORMAT_VERSION`.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub(crate) struct HnswSnapshotRkyv {
    /// Index dimensionality (`HnswIndex::dim`).
    pub dim: usize,
    /// `HnswParams::m`.
    pub m: usize,
    /// `HnswParams::m0`.
    pub m0: usize,
    /// `HnswParams::ef_construction`.
    pub ef_construction: usize,
    /// `DistanceMetric` encoded with an `as u8` cast; the decoder maps
    /// 0 => L2, 1 => Cosine, 2 => InnerProduct.
    pub metric: u8,
    /// Entry node id, `None` for an index without an entry point.
    pub entry_point: Option<u32>,
    /// Highest layer in the graph.
    pub max_layer: usize,
    /// Raw `Xorshift64` state (`rng.0`) at checkpoint time.
    pub rng_state: u64,
    /// One vector per node; every entry is empty in graph-only checkpoints.
    pub node_vectors: Vec<Vec<f32>>,
    /// One entry per node holding that node's neighbor id lists
    /// (presumably one inner list per layer — confirm against `Node::neighbors`).
    pub node_neighbors: Vec<Vec<Vec<u32>>>,
    /// One tombstone flag per node.
    pub node_deleted: Vec<bool>,
}
38
39impl HnswIndex {
40 pub fn graph_checkpoint_to_bytes(&self) -> Vec<u8> {
51 let snapshot = HnswSnapshotRkyv {
52 dim: self.dim,
53 m: self.params.m,
54 m0: self.params.m0,
55 ef_construction: self.params.ef_construction,
56 metric: self.params.metric as u8,
57 entry_point: self.entry_point,
58 max_layer: self.max_layer,
59 rng_state: self.rng.0,
60 node_vectors: vec![Vec::new(); self.nodes.len()],
62 node_neighbors: if let Some(ref flat) = self.flat_neighbors {
63 flat.to_nested(self.nodes.len())
64 } else {
65 self.nodes.iter().map(|n| n.neighbors.clone()).collect()
66 },
67 node_deleted: self.nodes.iter().map(|n| n.deleted).collect(),
68 };
69 let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
70 .expect("HNSW graph-only rkyv serialization should not fail");
71 let mut buf = Vec::with_capacity(HNSW_RKYV_MAGIC.len() + 1 + rkyv_bytes.len());
72 buf.extend_from_slice(HNSW_RKYV_MAGIC);
73 buf.push(HNSW_FORMAT_VERSION);
74 buf.extend_from_slice(&rkyv_bytes);
75 buf
76 }
77
78 pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
82 let snapshot = HnswSnapshotRkyv {
83 dim: self.dim,
84 m: self.params.m,
85 m0: self.params.m0,
86 ef_construction: self.params.ef_construction,
87 metric: self.params.metric as u8,
88 entry_point: self.entry_point,
89 max_layer: self.max_layer,
90 rng_state: self.rng.0,
91 node_vectors: self.export_vectors(),
92 node_neighbors: if let Some(ref flat) = self.flat_neighbors {
93 flat.to_nested(self.nodes.len())
94 } else {
95 self.nodes.iter().map(|n| n.neighbors.clone()).collect()
96 },
97 node_deleted: self.nodes.iter().map(|n| n.deleted).collect(),
98 };
99 let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
100 .expect("HNSW rkyv serialization should not fail");
101 let mut buf = Vec::with_capacity(HNSW_RKYV_MAGIC.len() + 1 + rkyv_bytes.len());
102 buf.extend_from_slice(HNSW_RKYV_MAGIC);
103 buf.push(HNSW_FORMAT_VERSION);
104 buf.extend_from_slice(&rkyv_bytes);
105 buf
106 }
107
108 pub fn from_checkpoint(bytes: &[u8]) -> Result<Option<Self>, crate::error::VectorError> {
117 let header_len = HNSW_RKYV_MAGIC.len() + 1; if bytes.len() > header_len && &bytes[..HNSW_RKYV_MAGIC.len()] == HNSW_RKYV_MAGIC {
119 let version = bytes[HNSW_RKYV_MAGIC.len()];
120 if version != HNSW_FORMAT_VERSION {
121 return Err(crate::error::VectorError::UnsupportedVersion {
122 found: version,
123 expected: HNSW_FORMAT_VERSION,
124 });
125 }
126 return Ok(Self::from_rkyv_checkpoint(&bytes[header_len..]));
127 }
128 Ok(None)
130 }
131
132 fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
134 let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
135 aligned.extend_from_slice(bytes);
136 let snap: HnswSnapshotRkyv =
137 rkyv::from_bytes::<HnswSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
138 Self::from_hnsw_snapshot(snap)
139 }
140
141 fn from_hnsw_snapshot(snap: HnswSnapshotRkyv) -> Option<Self> {
143 use nodedb_types::hnsw::HnswParams;
144
145 let metric = match snap.metric {
146 0 => DistanceMetric::L2,
147 1 => DistanceMetric::Cosine,
148 2 => DistanceMetric::InnerProduct,
149 _ => DistanceMetric::Cosine,
150 };
151
152 let flat = FlatNeighborStore::from_nested(&snap.node_neighbors);
153
154 let nodes: Vec<Node> = snap
155 .node_vectors
156 .into_iter()
157 .zip(snap.node_deleted)
158 .map(|(vector, deleted)| Node {
159 storage: NodeStorage::F32(vector),
160 neighbors: Vec::new(),
161 deleted,
162 })
163 .collect();
164
165 let initial_capacity = snap.ef_construction.max(ARENA_INITIAL_CAPACITY);
166 Some(Self {
167 dim: snap.dim,
168 params: HnswParams {
169 m: snap.m,
170 m0: snap.m0,
171 ef_construction: snap.ef_construction,
172 metric,
173 dtype: nodedb_types::vector_dtype::VectorStorageDtype::F32,
174 },
175 nodes,
176 entry_point: snap.entry_point,
177 max_layer: snap.max_layer,
178 rng: Xorshift64::new(snap.rng_state),
179 flat_neighbors: Some(flat),
180 arena: RefCell::new(BeamSearchArena::new(initial_capacity)),
181 #[cfg(not(target_arch = "wasm32"))]
182 backing: None,
183 })
184 }
185}
186
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Small deterministic 3-d L2 index used by every test.
    fn make_index() -> HnswIndex {
        HnswIndex::with_seed(
            3,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 32,
                metric: DistanceMetric::L2,
                dtype: nodedb_types::vector_dtype::VectorStorageDtype::F32,
            },
            12345,
        )
    }

    #[test]
    fn checkpoint_roundtrip() {
        let mut idx = make_index();
        for i in 0..50 {
            idx.insert(vec![(i as f32) * 0.1, (i as f32) * 0.2, (i as f32) * 0.3])
                .unwrap();
        }

        let bytes = idx.checkpoint_to_bytes();
        assert!(!bytes.is_empty());

        let restored = HnswIndex::from_checkpoint(&bytes).unwrap().unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
        assert_eq!(restored.max_layer(), idx.max_layer());

        // A restored index must answer queries exactly like the original.
        let query = vec![1.0, 2.0, 3.0];
        let orig_results = idx.search(&query, 5, 32);
        let rest_results = restored.search(&query, 5, 32);
        assert_eq!(orig_results.len(), rest_results.len());
        for (a, b) in orig_results.iter().zip(rest_results.iter()) {
            assert_eq!(a.id, b.id);
            assert!((a.distance - b.distance).abs() < 1e-6);
        }
    }

    #[test]
    fn graph_checkpoint_roundtrip_preserves_topology() {
        let mut idx = make_index();
        for i in 0..20 {
            idx.insert(vec![(i as f32) * 0.5, (i as f32) * 0.25, 1.0])
                .unwrap();
        }

        let bytes = idx.graph_checkpoint_to_bytes();
        assert_eq!(&bytes[0..6], b"RKHNS\0");
        assert_eq!(bytes[6], super::HNSW_FORMAT_VERSION);

        let restored = HnswIndex::from_checkpoint(&bytes).unwrap().unwrap();
        assert_eq!(restored.len(), 20);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
        assert_eq!(restored.max_layer(), idx.max_layer());
    }

    #[test]
    fn golden_header_layout() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 2.0, 3.0]).unwrap();
        let bytes = idx.checkpoint_to_bytes();
        // Magic prefix, then the version byte, then a non-empty rkyv payload.
        assert_eq!(&bytes[0..6], b"RKHNS\0");
        assert_eq!(bytes[6], super::HNSW_FORMAT_VERSION);
        assert!(bytes.len() > 7);
    }

    #[test]
    fn foreign_bytes_are_not_recognized() {
        // Anything without the magic prefix is "not an rkyv HNSW checkpoint".
        assert!(matches!(HnswIndex::from_checkpoint(&[]), Ok(None)));
        assert!(matches!(
            HnswIndex::from_checkpoint(b"not an hnsw checkpoint"),
            Ok(None)
        ));
    }

    #[test]
    fn undecodable_payload_returns_none() {
        // Valid header followed by a payload too short to hold a snapshot.
        let mut bytes = b"RKHNS\0".to_vec();
        bytes.push(super::HNSW_FORMAT_VERSION);
        bytes.extend_from_slice(&[0u8; 3]);
        assert!(matches!(HnswIndex::from_checkpoint(&bytes), Ok(None)));
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 2.0, 3.0]).unwrap();
        let mut bytes = idx.checkpoint_to_bytes();
        bytes[6] = 0;
        match HnswIndex::from_checkpoint(&bytes) {
            Err(crate::error::VectorError::UnsupportedVersion { found, expected }) => {
                assert_eq!(found, 0);
                assert_eq!(expected, super::HNSW_FORMAT_VERSION);
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("expected UnsupportedVersion error, got Ok"),
        }
    }
}