Skip to main content

nodedb_vector/hnsw/
checkpoint.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! HNSW checkpoint serialization and deserialization.
4//!
5//! Serialization format: rkyv with a `RKHNS\0` magic header.
6
7use std::cell::RefCell;
8
9use crate::distance::DistanceMetric;
10use crate::hnsw::arena::BeamSearchArena;
11use crate::hnsw::flat_neighbors::FlatNeighborStore;
12use crate::hnsw::graph::{ARENA_INITIAL_CAPACITY, HnswIndex, Node, Xorshift64};
13
/// Six-byte magic prefix (`RKHNS` followed by a NUL) that marks a buffer as an
/// rkyv-encoded HNSW snapshot.
const HNSW_RKYV_MAGIC: &[u8; 6] = &[b'R', b'K', b'H', b'N', b'S', 0];
/// Format version byte written directly after the `RKHNS\0` magic; readers
/// reject any other value.
pub const HNSW_FORMAT_VERSION: u8 = 0x01;
18
/// rkyv-serialized HNSW snapshot: the payload that follows the `RKHNS\0`
/// magic and the format-version byte in a checkpoint buffer.
///
/// Stored in struct-of-arrays (SoA) form — one flat `Vec` per node field
/// instead of a `Vec` of node structs — for better rkyv compatibility. Entry
/// `i` of every per-node `Vec` describes the same node.
///
/// Field order and field types determine the rkyv archived layout, so any
/// change here must come with a bump of `HNSW_FORMAT_VERSION`.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub(crate) struct HnswSnapshotRkyv {
    /// Vector dimensionality of the index.
    pub dim: usize,
    /// `HnswParams::m`.
    pub m: usize,
    /// `HnswParams::m0`.
    pub m0: usize,
    /// `HnswParams::ef_construction`.
    pub ef_construction: usize,
    /// Distance metric written as `metric as u8`. The loader decodes
    /// 0 = L2, 1 = Cosine, 2 = InnerProduct; any other value is loaded as Cosine.
    pub metric: u8,
    /// Graph entry node, or `None` for an index with no entry point.
    pub entry_point: Option<u32>,
    /// Highest graph layer at checkpoint time.
    pub max_layer: usize,
    /// Raw `Xorshift64` state (`rng.0`), handed back to `Xorshift64::new` on restore.
    pub rng_state: u64,
    /// Per-node vectors (SoA).
    pub node_vectors: Vec<Vec<f32>>,
    /// Per-node neighbor lists (SoA); presumably `node_neighbors[i][layer]`
    /// holds node `i`'s neighbor ids on that layer — TODO confirm.
    pub node_neighbors: Vec<Vec<Vec<u32>>>,
    /// Per-node deleted flags (SoA).
    pub node_deleted: Vec<bool>,
}
38
39impl HnswIndex {
40    /// Serialize the index to rkyv bytes (with magic header) for storage.
41    ///
42    /// Magic header `RKHNS\0` allows `from_checkpoint` to detect format.
43    pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
44        let snapshot = HnswSnapshotRkyv {
45            dim: self.dim,
46            m: self.params.m,
47            m0: self.params.m0,
48            ef_construction: self.params.ef_construction,
49            metric: self.params.metric as u8,
50            entry_point: self.entry_point,
51            max_layer: self.max_layer,
52            rng_state: self.rng.0,
53            node_vectors: self.nodes.iter().map(|n| n.vector.clone()).collect(),
54            node_neighbors: if let Some(ref flat) = self.flat_neighbors {
55                flat.to_nested(self.nodes.len())
56            } else {
57                self.nodes.iter().map(|n| n.neighbors.clone()).collect()
58            },
59            node_deleted: self.nodes.iter().map(|n| n.deleted).collect(),
60        };
61        let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
62            .expect("HNSW rkyv serialization should not fail");
63        // no-governor: cold checkpoint serialize; fixed header + rkyv payload, governed at checkpoint call site
64        let mut buf = Vec::with_capacity(HNSW_RKYV_MAGIC.len() + 1 + rkyv_bytes.len());
65        buf.extend_from_slice(HNSW_RKYV_MAGIC);
66        buf.push(HNSW_FORMAT_VERSION);
67        buf.extend_from_slice(&rkyv_bytes);
68        buf
69    }
70
71    /// Restore an index from a checkpoint snapshot.
72    ///
73    /// Returns:
74    /// - `Ok(Some(index))` — successfully decoded (rkyv format).
75    /// - `Ok(None)` — bytes do not start with the `RKHNS\0` magic header.
76    /// - `Err(VectorError::UnsupportedVersion)` — magic matches `RKHNS\0` but
77    ///   the version byte is not `HNSW_FORMAT_VERSION`; the caller must reject
78    ///   the buffer.
79    pub fn from_checkpoint(bytes: &[u8]) -> Result<Option<Self>, crate::error::VectorError> {
80        let header_len = HNSW_RKYV_MAGIC.len() + 1; // magic + version byte
81        if bytes.len() > header_len && &bytes[..HNSW_RKYV_MAGIC.len()] == HNSW_RKYV_MAGIC {
82            let version = bytes[HNSW_RKYV_MAGIC.len()];
83            if version != HNSW_FORMAT_VERSION {
84                return Err(crate::error::VectorError::UnsupportedVersion {
85                    found: version,
86                    expected: HNSW_FORMAT_VERSION,
87                });
88            }
89            return Ok(Self::from_rkyv_checkpoint(&bytes[header_len..]));
90        }
91        // No recognized magic prefix — no index to restore.
92        Ok(None)
93    }
94
95    /// Restore from rkyv-serialized bytes.
96    fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
97        let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
98        aligned.extend_from_slice(bytes);
99        let snap: HnswSnapshotRkyv =
100            rkyv::from_bytes::<HnswSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
101        Self::from_hnsw_snapshot(snap)
102    }
103
104    /// Reconstruct HnswIndex from deserialized snapshot fields.
105    fn from_hnsw_snapshot(snap: HnswSnapshotRkyv) -> Option<Self> {
106        use nodedb_types::hnsw::HnswParams;
107
108        let metric = match snap.metric {
109            0 => DistanceMetric::L2,
110            1 => DistanceMetric::Cosine,
111            2 => DistanceMetric::InnerProduct,
112            _ => DistanceMetric::Cosine,
113        };
114
115        let flat = FlatNeighborStore::from_nested(&snap.node_neighbors);
116
117        let nodes: Vec<Node> = snap
118            .node_vectors
119            .into_iter()
120            .zip(snap.node_deleted)
121            .map(|(vector, deleted)| Node {
122                vector,
123                neighbors: Vec::new(),
124                deleted,
125            })
126            .collect();
127
128        let initial_capacity = snap.ef_construction.max(ARENA_INITIAL_CAPACITY);
129        Some(Self {
130            dim: snap.dim,
131            params: HnswParams {
132                m: snap.m,
133                m0: snap.m0,
134                ef_construction: snap.ef_construction,
135                metric,
136            },
137            nodes,
138            entry_point: snap.entry_point,
139            max_layer: snap.max_layer,
140            rng: Xorshift64::new(snap.rng_state),
141            flat_neighbors: Some(flat),
142            arena: RefCell::new(BeamSearchArena::new(initial_capacity)),
143        })
144    }
145}
146
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Small deterministic 3-dimensional L2 index used by every test.
    fn make_index() -> HnswIndex {
        HnswIndex::with_seed(
            3,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 32,
                metric: DistanceMetric::L2,
            },
            12345,
        )
    }

    #[test]
    fn checkpoint_roundtrip() {
        let mut idx = make_index();
        for i in 0..50 {
            idx.insert(vec![(i as f32) * 0.1, (i as f32) * 0.2, (i as f32) * 0.3])
                .unwrap();
        }

        let bytes = idx.checkpoint_to_bytes();
        assert!(!bytes.is_empty());

        let restored = HnswIndex::from_checkpoint(&bytes).unwrap().unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
        assert_eq!(restored.max_layer(), idx.max_layer());

        let query = vec![1.0, 2.0, 3.0];
        let orig_results = idx.search(&query, 5, 32);
        let rest_results = restored.search(&query, 5, 32);
        assert_eq!(orig_results.len(), rest_results.len());
        for (a, b) in orig_results.iter().zip(rest_results.iter()) {
            assert_eq!(a.id, b.id);
            assert!((a.distance - b.distance).abs() < 1e-6);
        }
    }

    #[test]
    fn empty_index_roundtrip() {
        let idx = make_index();
        let bytes = idx.checkpoint_to_bytes();
        let restored = HnswIndex::from_checkpoint(&bytes).unwrap().unwrap();
        assert_eq!(restored.len(), 0);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
    }

    #[test]
    fn unrecognized_bytes_return_none() {
        assert!(HnswIndex::from_checkpoint(&[]).unwrap().is_none());
        assert!(
            HnswIndex::from_checkpoint(b"not a checkpoint")
                .unwrap()
                .is_none()
        );
        // Magic alone, with no version byte, is not a restorable checkpoint.
        assert!(HnswIndex::from_checkpoint(b"RKHNS\0").unwrap().is_none());
    }

    #[test]
    fn golden_header_layout() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 2.0, 3.0]).unwrap();
        let bytes = idx.checkpoint_to_bytes();
        // Magic at bytes[0..6].
        assert_eq!(&bytes[0..6], b"RKHNS\0");
        // Version byte at bytes[6].
        assert_eq!(bytes[6], super::HNSW_FORMAT_VERSION);
        // rkyv payload follows immediately.
        assert!(bytes.len() > 7);
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut idx = make_index();
        idx.insert(vec![1.0, 2.0, 3.0]).unwrap();
        let mut bytes = idx.checkpoint_to_bytes();
        // Corrupt the version byte to an unsupported value.
        bytes[6] = 0;
        match HnswIndex::from_checkpoint(&bytes) {
            Err(crate::error::VectorError::UnsupportedVersion { found, expected }) => {
                assert_eq!(found, 0);
                assert_eq!(expected, super::HNSW_FORMAT_VERSION);
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("expected UnsupportedVersion error, got Ok"),
        }
    }
}