1use std::cell::RefCell;
8
9use crate::distance::DistanceMetric;
10use crate::hnsw::arena::BeamSearchArena;
11use crate::hnsw::flat_neighbors::FlatNeighborStore;
12use crate::hnsw::graph::{ARENA_INITIAL_CAPACITY, HnswIndex, Node, Xorshift64};
13
/// Six-byte prefix that marks a blob as an rkyv-encoded HNSW checkpoint.
const HNSW_RKYV_MAGIC: &[u8; 6] = b"RKHNS\x00";
/// Format version byte written directly after [`HNSW_RKYV_MAGIC`].
///
/// Checkpoints carrying any other version byte are rejected with
/// `VectorError::UnsupportedVersion` when loaded.
pub const HNSW_FORMAT_VERSION: u8 = 0x01;
18
/// Owned, rkyv-serializable snapshot of an [`HnswIndex`], used as the checkpoint payload.
///
/// NOTE(review): the rkyv-derived archive layout follows this field list, so changing the
/// fields (or their order) presumably requires bumping `HNSW_FORMAT_VERSION` — confirm.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub(crate) struct HnswSnapshotRkyv {
    /// Vector dimensionality, copied from `HnswIndex::dim`.
    pub dim: usize,
    /// Copied from `HnswParams::m`.
    pub m: usize,
    /// Copied from `HnswParams::m0`.
    pub m0: usize,
    /// Copied from `HnswParams::ef_construction`; on restore it also sets the minimum
    /// initial capacity of the beam-search arena.
    pub ef_construction: usize,
    /// Distance metric stored as `DistanceMetric as u8`; the loader decodes
    /// 0 = L2, 1 = Cosine, 2 = InnerProduct.
    pub metric: u8,
    /// Copied from `HnswIndex::entry_point` (`None` presumably means an empty graph — confirm).
    pub entry_point: Option<u32>,
    /// Copied from `HnswIndex::max_layer`.
    pub max_layer: usize,
    /// Raw state of the index's `Xorshift64` RNG (`rng.0`), restored via `Xorshift64::new`.
    pub rng_state: u64,
    /// One vector per node, in the same order as `HnswIndex::nodes`.
    pub node_vectors: Vec<Vec<f32>>,
    /// Nested neighbor-id lists per node (same order as `HnswIndex::nodes`); the inner
    /// level mirrors `Node::neighbors` — presumably one list per layer, confirm.
    pub node_neighbors: Vec<Vec<Vec<u32>>>,
    /// Per-node `deleted` flag, in the same order as `HnswIndex::nodes`.
    pub node_deleted: Vec<bool>,
}
38
39impl HnswIndex {
40 pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
44 let snapshot = HnswSnapshotRkyv {
45 dim: self.dim,
46 m: self.params.m,
47 m0: self.params.m0,
48 ef_construction: self.params.ef_construction,
49 metric: self.params.metric as u8,
50 entry_point: self.entry_point,
51 max_layer: self.max_layer,
52 rng_state: self.rng.0,
53 node_vectors: self.nodes.iter().map(|n| n.vector.clone()).collect(),
54 node_neighbors: if let Some(ref flat) = self.flat_neighbors {
55 flat.to_nested(self.nodes.len())
56 } else {
57 self.nodes.iter().map(|n| n.neighbors.clone()).collect()
58 },
59 node_deleted: self.nodes.iter().map(|n| n.deleted).collect(),
60 };
61 let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
62 .expect("HNSW rkyv serialization should not fail");
63 let mut buf = Vec::with_capacity(HNSW_RKYV_MAGIC.len() + 1 + rkyv_bytes.len());
65 buf.extend_from_slice(HNSW_RKYV_MAGIC);
66 buf.push(HNSW_FORMAT_VERSION);
67 buf.extend_from_slice(&rkyv_bytes);
68 buf
69 }
70
71 pub fn from_checkpoint(bytes: &[u8]) -> Result<Option<Self>, crate::error::VectorError> {
80 let header_len = HNSW_RKYV_MAGIC.len() + 1; if bytes.len() > header_len && &bytes[..HNSW_RKYV_MAGIC.len()] == HNSW_RKYV_MAGIC {
82 let version = bytes[HNSW_RKYV_MAGIC.len()];
83 if version != HNSW_FORMAT_VERSION {
84 return Err(crate::error::VectorError::UnsupportedVersion {
85 found: version,
86 expected: HNSW_FORMAT_VERSION,
87 });
88 }
89 return Ok(Self::from_rkyv_checkpoint(&bytes[header_len..]));
90 }
91 Ok(None)
93 }
94
95 fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
97 let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
98 aligned.extend_from_slice(bytes);
99 let snap: HnswSnapshotRkyv =
100 rkyv::from_bytes::<HnswSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
101 Self::from_hnsw_snapshot(snap)
102 }
103
104 fn from_hnsw_snapshot(snap: HnswSnapshotRkyv) -> Option<Self> {
106 use nodedb_types::hnsw::HnswParams;
107
108 let metric = match snap.metric {
109 0 => DistanceMetric::L2,
110 1 => DistanceMetric::Cosine,
111 2 => DistanceMetric::InnerProduct,
112 _ => DistanceMetric::Cosine,
113 };
114
115 let flat = FlatNeighborStore::from_nested(&snap.node_neighbors);
116
117 let nodes: Vec<Node> = snap
118 .node_vectors
119 .into_iter()
120 .zip(snap.node_deleted)
121 .map(|(vector, deleted)| Node {
122 vector,
123 neighbors: Vec::new(),
124 deleted,
125 })
126 .collect();
127
128 let initial_capacity = snap.ef_construction.max(ARENA_INITIAL_CAPACITY);
129 Some(Self {
130 dim: snap.dim,
131 params: HnswParams {
132 m: snap.m,
133 m0: snap.m0,
134 ef_construction: snap.ef_construction,
135 metric,
136 },
137 nodes,
138 entry_point: snap.entry_point,
139 max_layer: snap.max_layer,
140 rng: Xorshift64::new(snap.rng_state),
141 flat_neighbors: Some(flat),
142 arena: RefCell::new(BeamSearchArena::new(initial_capacity)),
143 })
144 }
145}
146
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Small deterministic 3-D L2 index used by every test.
    fn make_index() -> HnswIndex {
        HnswIndex::with_seed(
            3,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 32,
                metric: DistanceMetric::L2,
            },
            12345,
        )
    }

    #[test]
    fn checkpoint_roundtrip() {
        let mut idx = make_index();
        for i in 0..50 {
            idx.insert(vec![(i as f32) * 0.1, (i as f32) * 0.2, (i as f32) * 0.3])
                .unwrap();
        }

        let bytes = idx.checkpoint_to_bytes();
        assert!(!bytes.is_empty());

        let restored = HnswIndex::from_checkpoint(&bytes).unwrap().unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
        assert_eq!(restored.max_layer(), idx.max_layer());

        let query = vec![1.0, 2.0, 3.0];
        let orig_results = idx.search(&query, 5, 32);
        let rest_results = restored.search(&query, 5, 32);
        assert_eq!(orig_results.len(), rest_results.len());
        for (a, b) in orig_results.iter().zip(rest_results.iter()) {
            assert_eq!(a.id, b.id);
            assert!((a.distance - b.distance).abs() < 1e-6);
        }
    }

    #[test]
    fn golden_header_layout() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 2.0, 3.0]).unwrap();
        let bytes = idx.checkpoint_to_bytes();
        assert_eq!(&bytes[0..6], b"RKHNS\0");
        assert_eq!(bytes[6], super::HNSW_FORMAT_VERSION);
        assert!(bytes.len() > 7);
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 2.0, 3.0]).unwrap();
        let mut bytes = idx.checkpoint_to_bytes();
        bytes[6] = 0;
        match HnswIndex::from_checkpoint(&bytes) {
            Err(crate::error::VectorError::UnsupportedVersion { found, expected }) => {
                assert_eq!(found, 0);
                assert_eq!(expected, super::HNSW_FORMAT_VERSION);
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("expected UnsupportedVersion error, got Ok"),
        }
    }

    /// Input without the magic header is "not a checkpoint", not an error.
    #[test]
    fn foreign_bytes_are_not_a_checkpoint() {
        assert!(matches!(HnswIndex::from_checkpoint(&[]), Ok(None)));
        assert!(matches!(
            HnswIndex::from_checkpoint(b"not an hnsw checkpoint"),
            Ok(None)
        ));
    }

    /// A valid header followed by bytes rkyv cannot decode yields `Ok(None)`.
    #[test]
    fn undecodable_payload_returns_none() {
        let mut bytes = b"RKHNS\0".to_vec();
        bytes.push(super::HNSW_FORMAT_VERSION);
        bytes.extend_from_slice(&[0xAB; 3]);
        assert!(matches!(HnswIndex::from_checkpoint(&bytes), Ok(None)));
    }
}