Skip to main content

nodedb_vector/hnsw/
checkpoint.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! HNSW checkpoint serialization and deserialization.
4//!
5//! Serialization format: rkyv with a `RKHNS\0` magic header.
6
7use std::cell::RefCell;
8
9use crate::distance::DistanceMetric;
10use crate::hnsw::arena::BeamSearchArena;
11use crate::hnsw::flat_neighbors::FlatNeighborStore;
12use crate::hnsw::graph::{ARENA_INITIAL_CAPACITY, HnswIndex, Node, NodeStorage, Xorshift64};
13
/// Magic header that prefixes every rkyv-serialized HNSW snapshot
/// (6 bytes, including the trailing NUL).
const HNSW_RKYV_MAGIC: &[u8; 6] = b"RKHNS\0";
/// Current format version for rkyv-serialized HNSW snapshots.
///
/// Written as a single byte immediately after the magic header;
/// `HnswIndex::from_checkpoint` rejects any other value.
pub const HNSW_FORMAT_VERSION: u8 = 1;
18
/// rkyv-serialized HNSW snapshot. SoA layout for better rkyv compatibility
/// (flat Vecs instead of Vec<struct>).
///
/// The three `node_*` vectors are parallel: index `i` in each describes the
/// same node. Both checkpoint writers in this file emit one entry per node.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub(crate) struct HnswSnapshotRkyv {
    /// Index dimensionality, copied from `HnswIndex::dim`.
    pub dim: usize,
    /// `HnswParams::m` at checkpoint time.
    pub m: usize,
    /// `HnswParams::m0` at checkpoint time.
    pub m0: usize,
    /// `HnswParams::ef_construction` at checkpoint time.
    pub ef_construction: usize,
    /// Distance metric, stored as the `DistanceMetric` value cast with `as u8`.
    pub metric: u8,
    /// Entry-point node id, or `None` when the index has no entry point.
    pub entry_point: Option<u32>,
    /// `HnswIndex::max_layer` at checkpoint time.
    pub max_layer: usize,
    /// Raw `Xorshift64` state; restored with `Xorshift64::new`.
    pub rng_state: u64,
    /// Per-node vectors (SoA). Graph-only checkpoints store an empty
    /// placeholder `Vec` for every node.
    pub node_vectors: Vec<Vec<f32>>,
    /// Per-node neighbor lists (SoA).
    // NOTE(review): the inner nesting is presumably per-layer lists of
    // neighbor ids; confirm against `FlatNeighborStore::to_nested`.
    pub node_neighbors: Vec<Vec<Vec<u32>>>,
    /// Per-node deleted flags (SoA).
    pub node_deleted: Vec<bool>,
}
38
39impl HnswIndex {
40    /// Serialize ONLY the graph topology (neighbors, entry point, params) to
41    /// rkyv bytes.  Vector data is intentionally omitted — `node_vectors` is
42    /// filled with empty `Vec<f32>` placeholders.
43    ///
44    /// Pair with [`Self::from_graph_checkpoint`] on restore.  The same
45    /// `RKHNS\0` magic and version byte are used so `from_checkpoint` can
46    /// decode the bytes, but the resulting index has empty node storage.
47    ///
48    /// This is the Lite pagedb-segment write path: vectors live in the segment,
49    /// graph topology lives in the B+ tree.  Origin never calls this method.
50    pub fn graph_checkpoint_to_bytes(&self) -> Vec<u8> {
51        let snapshot = HnswSnapshotRkyv {
52            dim: self.dim,
53            m: self.params.m,
54            m0: self.params.m0,
55            ef_construction: self.params.ef_construction,
56            metric: self.params.metric as u8,
57            entry_point: self.entry_point,
58            max_layer: self.max_layer,
59            rng_state: self.rng.0,
60            // Placeholder: one empty vec per node so the node count is preserved.
61            node_vectors: vec![Vec::new(); self.nodes.len()],
62            node_neighbors: if let Some(ref flat) = self.flat_neighbors {
63                flat.to_nested(self.nodes.len())
64            } else {
65                self.nodes.iter().map(|n| n.neighbors.clone()).collect()
66            },
67            node_deleted: self.nodes.iter().map(|n| n.deleted).collect(),
68        };
69        let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
70            .expect("HNSW graph-only rkyv serialization should not fail");
71        let mut buf = Vec::with_capacity(HNSW_RKYV_MAGIC.len() + 1 + rkyv_bytes.len());
72        buf.extend_from_slice(HNSW_RKYV_MAGIC);
73        buf.push(HNSW_FORMAT_VERSION);
74        buf.extend_from_slice(&rkyv_bytes);
75        buf
76    }
77
78    /// Serialize the index to rkyv bytes (with magic header) for storage.
79    ///
80    /// Magic header `RKHNS\0` allows `from_checkpoint` to detect format.
81    pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
82        let snapshot = HnswSnapshotRkyv {
83            dim: self.dim,
84            m: self.params.m,
85            m0: self.params.m0,
86            ef_construction: self.params.ef_construction,
87            metric: self.params.metric as u8,
88            entry_point: self.entry_point,
89            max_layer: self.max_layer,
90            rng_state: self.rng.0,
91            node_vectors: self.export_vectors(),
92            node_neighbors: if let Some(ref flat) = self.flat_neighbors {
93                flat.to_nested(self.nodes.len())
94            } else {
95                self.nodes.iter().map(|n| n.neighbors.clone()).collect()
96            },
97            node_deleted: self.nodes.iter().map(|n| n.deleted).collect(),
98        };
99        let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
100            .expect("HNSW rkyv serialization should not fail");
101        let mut buf = Vec::with_capacity(HNSW_RKYV_MAGIC.len() + 1 + rkyv_bytes.len());
102        buf.extend_from_slice(HNSW_RKYV_MAGIC);
103        buf.push(HNSW_FORMAT_VERSION);
104        buf.extend_from_slice(&rkyv_bytes);
105        buf
106    }
107
108    /// Restore an index from a checkpoint snapshot.
109    ///
110    /// Returns:
111    /// - `Ok(Some(index))` — successfully decoded (rkyv format).
112    /// - `Ok(None)` — bytes do not start with the `RKHNS\0` magic header.
113    /// - `Err(VectorError::UnsupportedVersion)` — magic matches `RKHNS\0` but
114    ///   the version byte is not `HNSW_FORMAT_VERSION`; the caller must reject
115    ///   the buffer.
116    pub fn from_checkpoint(bytes: &[u8]) -> Result<Option<Self>, crate::error::VectorError> {
117        let header_len = HNSW_RKYV_MAGIC.len() + 1; // magic + version byte
118        if bytes.len() > header_len && &bytes[..HNSW_RKYV_MAGIC.len()] == HNSW_RKYV_MAGIC {
119            let version = bytes[HNSW_RKYV_MAGIC.len()];
120            if version != HNSW_FORMAT_VERSION {
121                return Err(crate::error::VectorError::UnsupportedVersion {
122                    found: version,
123                    expected: HNSW_FORMAT_VERSION,
124                });
125            }
126            return Ok(Self::from_rkyv_checkpoint(&bytes[header_len..]));
127        }
128        // No recognized magic prefix — no index to restore.
129        Ok(None)
130    }
131
132    /// Restore from rkyv-serialized bytes.
133    fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
134        let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
135        aligned.extend_from_slice(bytes);
136        let snap: HnswSnapshotRkyv =
137            rkyv::from_bytes::<HnswSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
138        Self::from_hnsw_snapshot(snap)
139    }
140
141    /// Reconstruct HnswIndex from deserialized snapshot fields.
142    fn from_hnsw_snapshot(snap: HnswSnapshotRkyv) -> Option<Self> {
143        use nodedb_types::hnsw::HnswParams;
144
145        let metric = match snap.metric {
146            0 => DistanceMetric::L2,
147            1 => DistanceMetric::Cosine,
148            2 => DistanceMetric::InnerProduct,
149            _ => DistanceMetric::Cosine,
150        };
151
152        let flat = FlatNeighborStore::from_nested(&snap.node_neighbors);
153
154        let nodes: Vec<Node> = snap
155            .node_vectors
156            .into_iter()
157            .zip(snap.node_deleted)
158            .map(|(vector, deleted)| Node {
159                storage: NodeStorage::F32(vector),
160                neighbors: Vec::new(),
161                deleted,
162            })
163            .collect();
164
165        let initial_capacity = snap.ef_construction.max(ARENA_INITIAL_CAPACITY);
166        Some(Self {
167            dim: snap.dim,
168            params: HnswParams {
169                m: snap.m,
170                m0: snap.m0,
171                ef_construction: snap.ef_construction,
172                metric,
173                dtype: nodedb_types::vector_dtype::VectorStorageDtype::F32,
174            },
175            nodes,
176            entry_point: snap.entry_point,
177            max_layer: snap.max_layer,
178            rng: Xorshift64::new(snap.rng_state),
179            flat_neighbors: Some(flat),
180            arena: RefCell::new(BeamSearchArena::new(initial_capacity)),
181            #[cfg(not(target_arch = "wasm32"))]
182            backing: None,
183        })
184    }
185}
186
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Small 3-dimensional L2 index with a fixed seed so every run builds the
    /// same graph.
    fn make_index() -> HnswIndex {
        HnswIndex::with_seed(
            3,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 32,
                metric: DistanceMetric::L2,
                dtype: nodedb_types::vector_dtype::VectorStorageDtype::F32,
            },
            12345,
        )
    }

    /// A full checkpoint restores to an index that answers queries the same
    /// way as the original.
    #[test]
    fn checkpoint_roundtrip() {
        let mut idx = make_index();
        for i in 0..50 {
            idx.insert(vec![(i as f32) * 0.1, (i as f32) * 0.2, (i as f32) * 0.3])
                .unwrap();
        }

        let bytes = idx.checkpoint_to_bytes();
        assert!(!bytes.is_empty());

        let restored = HnswIndex::from_checkpoint(&bytes).unwrap().unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
        assert_eq!(restored.max_layer(), idx.max_layer());

        let query = vec![1.0, 2.0, 3.0];
        let orig_results = idx.search(&query, 5, 32);
        let rest_results = restored.search(&query, 5, 32);
        assert_eq!(orig_results.len(), rest_results.len());
        for (a, b) in orig_results.iter().zip(rest_results.iter()) {
            assert_eq!(a.id, b.id);
            assert!((a.distance - b.distance).abs() < 1e-6);
        }
    }

    /// Pins the on-disk header: 6-byte magic, 1 version byte, then payload.
    #[test]
    fn golden_header_layout() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 2.0, 3.0]).unwrap();
        let bytes = idx.checkpoint_to_bytes();
        // Magic at bytes[0..6].
        assert_eq!(&bytes[0..6], b"RKHNS\0");
        // Version byte at bytes[6].
        assert_eq!(bytes[6], super::HNSW_FORMAT_VERSION);
        // rkyv payload follows immediately.
        assert!(bytes.len() > 7);
    }

    /// A recognized magic with a foreign version byte must be rejected with
    /// `UnsupportedVersion`, not silently ignored.
    #[test]
    fn version_mismatch_returns_error() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 2.0, 3.0]).unwrap();
        let mut bytes = idx.checkpoint_to_bytes();
        // Corrupt the version byte to an unsupported value.
        bytes[6] = 0;
        match HnswIndex::from_checkpoint(&bytes) {
            Err(crate::error::VectorError::UnsupportedVersion { found, expected }) => {
                assert_eq!(found, 0);
                assert_eq!(expected, super::HNSW_FORMAT_VERSION);
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("expected UnsupportedVersion error, got Ok"),
        }
    }

    /// Buffers that do not start with the magic are reported as "no index",
    /// not as an error.
    #[test]
    fn unrecognized_magic_returns_none() {
        assert!(HnswIndex::from_checkpoint(b"").unwrap().is_none());
        assert!(
            HnswIndex::from_checkpoint(b"definitely not an HNSW checkpoint")
                .unwrap()
                .is_none()
        );
    }

    /// The graph-only writer uses the same header and keeps the topology
    /// metadata intact through `from_checkpoint`.
    #[test]
    fn graph_checkpoint_preserves_topology_metadata() {
        let mut idx = make_index();
        for i in 0..20 {
            idx.insert(vec![i as f32, (i as f32) * 0.5, (i as f32) * 0.25])
                .unwrap();
        }

        let bytes = idx.graph_checkpoint_to_bytes();
        assert_eq!(&bytes[0..6], b"RKHNS\0");
        assert_eq!(bytes[6], super::HNSW_FORMAT_VERSION);

        let restored = HnswIndex::from_checkpoint(&bytes).unwrap().unwrap();
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
        assert_eq!(restored.max_layer(), idx.max_layer());
    }
}