Skip to main content

nodedb_vector/hnsw/
checkpoint.rs

1//! HNSW checkpoint serialization and deserialization.
2//!
3//! Supports rkyv (current format) and legacy MessagePack for backward compat.
4
5use crate::distance::DistanceMetric;
6use crate::hnsw::flat_neighbors::FlatNeighborStore;
7use crate::hnsw::graph::{HnswIndex, Node, Xorshift64};
8
/// Six-byte header that prefixes every rkyv-encoded HNSW checkpoint.
///
/// Spells ASCII `RKHNS` followed by a NUL byte. Checkpoints that begin with
/// it are decoded as rkyv; anything else is handed to the legacy MessagePack
/// decoder.
const HNSW_RKYV_MAGIC: &[u8; 6] = &[b'R', b'K', b'H', b'N', b'S', 0x00];
11
/// rkyv-serialized HNSW snapshot. SoA layout for better rkyv compatibility
/// (flat Vecs instead of Vec<struct>).
///
/// Produced by `HnswIndex::checkpoint_to_bytes` and also used as the common
/// intermediate form when decoding legacy MessagePack checkpoints.
///
/// The three per-node vectors (`node_vectors`, `node_neighbors`,
/// `node_deleted`) are parallel arrays indexed by node id.
///
/// NOTE(review): rkyv archives are not self-describing and the only format
/// marker is the magic header, so reordering, renaming the type of, or
/// removing fields here presumably makes previously written checkpoints
/// unreadable — treat this layout as frozen.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
pub(crate) struct HnswSnapshotRkyv {
    /// Vector dimension (copied from `HnswIndex::dim`).
    pub dim: usize,
    /// `HnswParams::m`.
    pub m: usize,
    /// `HnswParams::m0`.
    pub m0: usize,
    /// `HnswParams::ef_construction`.
    pub ef_construction: usize,
    /// Distance metric tag, written as `DistanceMetric as u8`.
    pub metric: u8,
    /// Search entry node id, if the index has one.
    pub entry_point: Option<u32>,
    /// Highest layer of the graph (copied from `HnswIndex::max_layer`).
    pub max_layer: usize,
    /// Raw state of the index's `Xorshift64` RNG (`rng.0`).
    pub rng_state: u64,
    /// Per-node vectors (SoA).
    pub node_vectors: Vec<Vec<f32>>,
    /// Per-node neighbor lists (SoA).
    pub node_neighbors: Vec<Vec<Vec<u32>>>,
    /// Per-node deleted flags (SoA).
    pub node_deleted: Vec<bool>,
}
31
32impl HnswIndex {
33    /// Serialize the index to rkyv bytes (with magic header) for storage.
34    ///
35    /// Magic header `RKHNS\0` allows `from_checkpoint` to detect format.
36    pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
37        let snapshot = HnswSnapshotRkyv {
38            dim: self.dim,
39            m: self.params.m,
40            m0: self.params.m0,
41            ef_construction: self.params.ef_construction,
42            metric: self.params.metric as u8,
43            entry_point: self.entry_point,
44            max_layer: self.max_layer,
45            rng_state: self.rng.0,
46            node_vectors: self.nodes.iter().map(|n| n.vector.clone()).collect(),
47            node_neighbors: if let Some(ref flat) = self.flat_neighbors {
48                flat.to_nested(self.nodes.len())
49            } else {
50                self.nodes.iter().map(|n| n.neighbors.clone()).collect()
51            },
52            node_deleted: self.nodes.iter().map(|n| n.deleted).collect(),
53        };
54        let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
55            .expect("HNSW rkyv serialization should not fail");
56        let mut buf = Vec::with_capacity(HNSW_RKYV_MAGIC.len() + rkyv_bytes.len());
57        buf.extend_from_slice(HNSW_RKYV_MAGIC);
58        buf.extend_from_slice(&rkyv_bytes);
59        buf
60    }
61
62    /// Restore an index from a checkpoint snapshot.
63    ///
64    /// Auto-detects format: rkyv (magic `RKHNS\0`) or legacy MessagePack.
65    pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
66        if bytes.len() > HNSW_RKYV_MAGIC.len() && &bytes[..HNSW_RKYV_MAGIC.len()] == HNSW_RKYV_MAGIC
67        {
68            return Self::from_rkyv_checkpoint(&bytes[HNSW_RKYV_MAGIC.len()..]);
69        }
70        Self::from_msgpack_checkpoint(bytes)
71    }
72
73    /// Restore from rkyv-serialized bytes.
74    fn from_rkyv_checkpoint(bytes: &[u8]) -> Option<Self> {
75        let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(bytes.len());
76        aligned.extend_from_slice(bytes);
77        let snap: HnswSnapshotRkyv =
78            rkyv::from_bytes::<HnswSnapshotRkyv, rkyv::rancor::Error>(&aligned).ok()?;
79        Self::from_hnsw_snapshot(snap)
80    }
81
82    /// Restore from legacy MessagePack bytes.
83    fn from_msgpack_checkpoint(bytes: &[u8]) -> Option<Self> {
84        use zerompk::{FromMessagePack, ToMessagePack};
85
86        #[derive(ToMessagePack, FromMessagePack)]
87        struct Snapshot {
88            dim: usize,
89            m: usize,
90            m0: usize,
91            ef_construction: usize,
92            metric: u8,
93            entry_point: Option<u32>,
94            max_layer: usize,
95            rng_state: u64,
96            nodes: Vec<NodeSnap>,
97        }
98
99        #[derive(ToMessagePack, FromMessagePack)]
100        struct NodeSnap {
101            vector: Vec<f32>,
102            neighbors: Vec<Vec<u32>>,
103            deleted: bool,
104        }
105
106        let snap: Snapshot = zerompk::from_msgpack(bytes).ok()?;
107        Self::from_hnsw_snapshot(HnswSnapshotRkyv {
108            dim: snap.dim,
109            m: snap.m,
110            m0: snap.m0,
111            ef_construction: snap.ef_construction,
112            metric: snap.metric,
113            entry_point: snap.entry_point,
114            max_layer: snap.max_layer,
115            rng_state: snap.rng_state,
116            node_vectors: snap.nodes.iter().map(|n| n.vector.clone()).collect(),
117            node_neighbors: snap.nodes.iter().map(|n| n.neighbors.clone()).collect(),
118            node_deleted: snap.nodes.iter().map(|n| n.deleted).collect(),
119        })
120    }
121
122    /// Reconstruct HnswIndex from deserialized snapshot fields.
123    fn from_hnsw_snapshot(snap: HnswSnapshotRkyv) -> Option<Self> {
124        use nodedb_types::hnsw::HnswParams;
125
126        let metric = match snap.metric {
127            0 => DistanceMetric::L2,
128            1 => DistanceMetric::Cosine,
129            2 => DistanceMetric::InnerProduct,
130            _ => DistanceMetric::Cosine,
131        };
132
133        let flat = FlatNeighborStore::from_nested(&snap.node_neighbors);
134
135        let nodes: Vec<Node> = snap
136            .node_vectors
137            .into_iter()
138            .zip(snap.node_deleted)
139            .map(|(vector, deleted)| Node {
140                vector,
141                neighbors: Vec::new(),
142                deleted,
143            })
144            .collect();
145
146        Some(Self {
147            dim: snap.dim,
148            params: HnswParams {
149                m: snap.m,
150                m0: snap.m0,
151                ef_construction: snap.ef_construction,
152                metric,
153            },
154            nodes,
155            entry_point: snap.entry_point,
156            max_layer: snap.max_layer,
157            rng: Xorshift64::new(snap.rng_state),
158            flat_neighbors: Some(flat),
159        })
160    }
161}
162
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Small deterministic 3-D L2 index shared by all tests.
    fn make_index() -> HnswIndex {
        HnswIndex::with_seed(
            3,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 32,
                metric: DistanceMetric::L2,
            },
            12345,
        )
    }

    /// Insert `count` points along the ray (0.1, 0.2, 0.3) * i.
    fn populate(idx: &mut HnswIndex, count: usize) {
        for i in 0..count {
            let t = i as f32;
            idx.insert(vec![t * 0.1, t * 0.2, t * 0.3]).unwrap();
        }
    }

    #[test]
    fn checkpoint_roundtrip() {
        let mut idx = make_index();
        populate(&mut idx, 50);

        let bytes = idx.checkpoint_to_bytes();
        assert!(!bytes.is_empty());

        let restored = HnswIndex::from_checkpoint(&bytes).unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
        assert_eq!(restored.max_layer(), idx.max_layer());

        let query = vec![1.0, 2.0, 3.0];
        let orig_results = idx.search(&query, 5, 32);
        let rest_results = restored.search(&query, 5, 32);
        assert_eq!(orig_results.len(), rest_results.len());
        for (a, b) in orig_results.iter().zip(rest_results.iter()) {
            assert_eq!(a.id, b.id);
            assert!((a.distance - b.distance).abs() < 1e-6);
        }
    }

    #[test]
    fn checkpoint_starts_with_magic() {
        let bytes = make_index().checkpoint_to_bytes();
        assert!(bytes.starts_with(super::HNSW_RKYV_MAGIC));
    }

    #[test]
    fn empty_index_roundtrip() {
        let idx = make_index();
        let restored = HnswIndex::from_checkpoint(&idx.checkpoint_to_bytes()).unwrap();
        assert_eq!(restored.len(), 0);
        assert_eq!(restored.dim(), 3);
        assert_eq!(restored.entry_point(), idx.entry_point());
    }

    #[test]
    fn checkpoint_of_restored_index_roundtrips() {
        // A restored index keeps its neighbors in the flat store, so
        // re-checkpointing it exercises the `to_nested` serialization path.
        let mut idx = make_index();
        populate(&mut idx, 30);
        let once = HnswIndex::from_checkpoint(&idx.checkpoint_to_bytes()).unwrap();
        let twice = HnswIndex::from_checkpoint(&once.checkpoint_to_bytes()).unwrap();
        assert_eq!(twice.len(), idx.len());
        assert_eq!(twice.entry_point(), idx.entry_point());

        let query = vec![0.5, 1.0, 1.5];
        let expected = idx.search(&query, 5, 32);
        let actual = twice.search(&query, 5, 32);
        assert_eq!(expected.len(), actual.len());
        for (a, b) in expected.iter().zip(actual.iter()) {
            assert_eq!(a.id, b.id);
        }
    }

    #[test]
    fn rejects_garbage_and_truncated_input() {
        assert!(HnswIndex::from_checkpoint(&[]).is_none());
        assert!(HnswIndex::from_checkpoint(b"not a checkpoint").is_none());

        let mut idx = make_index();
        populate(&mut idx, 10);
        let bytes = idx.checkpoint_to_bytes();
        assert!(HnswIndex::from_checkpoint(&bytes[..bytes.len() / 2]).is_none());
    }
}