Skip to main content

nodedb_vector/collection/
checkpoint.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Checkpoint serialization and deserialization for `VectorCollection`.
4//!
5//! ## On-disk framing
6//!
7//! **Plaintext** checkpoints are raw MessagePack bytes (existing format).
8//! The first 4 bytes are never `SEGV`, so detection is unambiguous.
9//!
10//! **Encrypted** checkpoints use the following layout:
11//!
12//! ```text
13//! [SEGV (4B)] [version_u16_le (2B)] [cipher_alg_u8 (1B)] [kid_u8 (1B)]
14//! [epoch (4B)] [reserved (4B)] [AES-256-GCM ciphertext of msgpack payload]
15//! ```
16//!
17//! The first 16 bytes form a `SegmentPreamble` (reusing the existing preamble
18//! layout with a distinct `SEGV` magic). These 16 bytes are included as AAD,
19//! preventing preamble-swap attacks. The nonce is `(epoch, lsn=0)` — epoch
20//! provides per-write uniqueness even without an LSN.
21
22use std::collections::HashMap;
23
24use nodedb_types::{Surrogate, VectorQuantization};
25use serde::{Deserialize, Serialize};
26
27use crate::collection::payload_index::PayloadIndexSetSnapshot;
28use crate::collection::segment::{DEFAULT_SEAL_THRESHOLD, SealedSegment};
29use crate::collection::tier::StorageTier;
30use crate::distance::DistanceMetric;
31use crate::error::VectorError;
32use crate::flat::FlatIndex;
33use crate::hnsw::{HnswIndex, HnswParams};
34use crate::quantize::pq::PqCodec;
35use crate::quantize::sq8::Sq8Codec;
36
37use super::lifecycle::VectorCollection;
38
/// Four-byte magic (`SEGV` in ASCII) that opens every encrypted vector
/// checkpoint. Shared with `nodedb-spatial`'s SEGV checkpoint format.
///
/// Plaintext checkpoints are told apart from encrypted ones by the absence
/// of this prefix.
const SEGV_MAGIC: [u8; 4] = [b'S', b'E', b'G', b'V'];
42
43/// Encrypt `plaintext` into the SEGV envelope from [`nodedb_wal::crypto`].
44fn encrypt_checkpoint(
45    key: &nodedb_wal::crypto::WalEncryptionKey,
46    plaintext: &[u8],
47) -> Result<Vec<u8>, VectorError> {
48    nodedb_wal::crypto::encrypt_segment_envelope(key, &SEGV_MAGIC, plaintext).map_err(|e| {
49        VectorError::CheckpointEncryptionError {
50            detail: e.to_string(),
51        }
52    })
53}
54
55/// Decrypt an encrypted checkpoint blob (starting at byte 0, which is `SEGV`).
56fn decrypt_checkpoint(
57    key: &nodedb_wal::crypto::WalEncryptionKey,
58    blob: &[u8],
59) -> Result<Vec<u8>, VectorError> {
60    nodedb_wal::crypto::decrypt_segment_envelope(key, &SEGV_MAGIC, blob).map_err(|e| {
61        VectorError::CheckpointEncryptionError {
62            detail: e.to_string(),
63        }
64    })
65}
66
/// Serializable image of a whole `VectorCollection`.
///
/// Written by `VectorCollection::checkpoint_to_bytes` and read back by
/// `VectorCollection::from_checkpoint`.
///
/// NOTE(review): field order is presumably part of the persisted MessagePack
/// layout — confirm against `zerompk` before reordering or removing fields.
#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub(crate) struct CollectionSnapshot {
    /// Vector dimensionality of the collection.
    pub dim: usize,
    /// HNSW `m` parameter.
    pub params_m: usize,
    /// HNSW `m0` parameter.
    pub params_m0: usize,
    /// HNSW `ef_construction` parameter.
    pub params_ef_construction: usize,
    /// Distance metric stored as `DistanceMetric as u8`; decoded back to a
    /// `DistanceMetric` in `from_checkpoint`.
    pub params_metric: u8,
    /// The collection's `next_id` counter at checkpoint time.
    pub next_id: u32,
    /// The collection's `growing_base_id` at checkpoint time.
    pub growing_base_id: u32,
    /// Raw vectors of the growing (flat) segment, in local-index order.
    pub growing_vectors: Vec<Vec<f32>>,
    /// Tombstone flags for the growing segment, matched to
    /// `growing_vectors` by position. A missing entry is treated as
    /// "not deleted" on restore.
    pub growing_deleted: Vec<bool>,
    /// One entry per sealed segment.
    pub sealed_segments: Vec<SealedSnapshot>,
    /// One entry per segment whose HNSW build had not completed at
    /// checkpoint time.
    pub building_segments: Vec<BuildingSnapshot>,
    /// `(global_vector_id, surrogate_u32)` pairs.
    #[serde(default)]
    pub surrogate_map: Vec<(u32, u32)>,
    /// `(document_surrogate_u32, [global_vector_ids])` pairs.
    #[serde(default)]
    pub multi_doc_map: Vec<(u32, Vec<u32>)>,
    /// Quantization mode for the collection-level codec-dispatch index.
    /// Serialised as a u8 matching `VectorQuantization` discriminants.
    /// 0 = None (default, backward-compatible).
    #[serde(default)]
    pub quantization_tag: u8,
    /// Serialised `PayloadIndexSetSnapshot` (msgpack bytes).
    /// Empty vec = no payload indexes (default, backward-compatible).
    #[serde(default)]
    pub payload_index_bytes: Vec<u8>,
}
96
/// Serializable image of one sealed segment: its HNSW index plus whichever
/// quantizer (PQ or SQ8) was attached to it.
///
/// On restore a quantizer is only reinstated when both its codec bytes and
/// its codes are present, and SQ8 is only considered when PQ was not
/// restored.
#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub(crate) struct SealedSnapshot {
    /// The segment's `base_id`.
    pub base_id: u32,
    /// Output of the segment index's own `checkpoint_to_bytes()`.
    pub hnsw_bytes: Vec<u8>,
    /// Serialized [`PqCodec`] bytes; `None` when the segment has no PQ
    /// codec (or it could not be serialized).
    #[serde(default)]
    pub pq_bytes: Option<Vec<u8>>,
    /// PQ codes stored alongside the codec; `None` when PQ is absent.
    #[serde(default)]
    pub pq_codes: Option<Vec<u8>>,
    /// Serialized [`Sq8Codec`] bytes (magic + version + msgpack).
    /// Present when SQ8 quantization is active and PQ is absent.
    /// `None` means no SQ8 quantization for this segment.
    #[serde(default)]
    pub sq8_bytes: Option<Vec<u8>>,
    /// Pre-quantized SQ8 codes for all vectors in this segment.
    /// Layout: `[v0_d0, v0_d1, ..., v1_d0, ...]` (dim bytes per vector).
    /// `None` when SQ8 is not configured.
    #[serde(default)]
    pub sq8_codes: Option<Vec<u8>>,
}
116
/// Serializable image of a segment whose HNSW build was still in flight at
/// checkpoint time.
///
/// Only the raw vectors and tombstone flags are stored; on restore the HNSW
/// index is rebuilt from `vectors`, the tombstones are replayed onto it, and
/// the result is loaded as a sealed segment.
#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub(crate) struct BuildingSnapshot {
    /// The segment's `base_id`.
    pub base_id: u32,
    /// Raw vectors of the segment's flat index, in local-index order.
    pub vectors: Vec<Vec<f32>>,
    /// Tombstone flags matched to `vectors` by position. Defaults to empty
    /// when absent from the serialized form.
    #[serde(default)]
    pub deleted: Vec<bool>,
}
124
125impl VectorCollection {
126    /// Serialize all segments for checkpointing.
127    ///
128    /// When `kek` is `Some`, the MessagePack payload is wrapped in an
129    /// AES-256-GCM encrypted envelope with a `SEGV` preamble. When `None`,
130    /// raw MessagePack bytes are returned (existing plaintext format).
131    ///
132    /// Returns an empty `Vec` on serialization failure (callers treat this as a
133    /// skip signal, consistent with the pre-existing error handling).
134    pub fn checkpoint_to_bytes(
135        &self,
136        kek: Option<&nodedb_wal::crypto::WalEncryptionKey>,
137    ) -> Vec<u8> {
138        let snapshot = CollectionSnapshot {
139            dim: self.dim,
140            params_m: self.params.m,
141            params_m0: self.params.m0,
142            params_ef_construction: self.params.ef_construction,
143            params_metric: self.params.metric as u8,
144            next_id: self.next_id,
145            growing_base_id: self.growing_base_id,
146            growing_vectors: (0..self.growing.len() as u32)
147                .filter_map(|i| self.growing.get_vector_raw(i).map(|v| v.to_vec()))
148                .collect(),
149            growing_deleted: (0..self.growing.len() as u32)
150                .map(|i| self.growing.is_deleted(i))
151                .collect(),
152            sealed_segments: self
153                .sealed
154                .iter()
155                .map(|s| {
156                    let (pq_bytes, pq_codes) = match &s.pq {
157                        Some((codec, codes)) => (codec.to_bytes().ok(), Some(codes.clone())),
158                        None => (None, None),
159                    };
160                    // Only serialize SQ8 when PQ is absent — a segment never carries both.
161                    let (sq8_bytes, sq8_codes) = if pq_bytes.is_none() {
162                        match &s.sq8 {
163                            Some((codec, codes)) => (Some(codec.to_bytes()), Some(codes.clone())),
164                            None => (None, None),
165                        }
166                    } else {
167                        (None, None)
168                    };
169                    SealedSnapshot {
170                        base_id: s.base_id,
171                        hnsw_bytes: s.index.checkpoint_to_bytes(),
172                        pq_bytes,
173                        pq_codes,
174                        sq8_bytes,
175                        sq8_codes,
176                    }
177                })
178                .collect(),
179            building_segments: self
180                .building
181                .iter()
182                .map(|b| BuildingSnapshot {
183                    base_id: b.base_id,
184                    vectors: (0..b.flat.len() as u32)
185                        .filter_map(|i| b.flat.get_vector_raw(i).map(|v| v.to_vec()))
186                        .collect(),
187                    deleted: (0..b.flat.len() as u32)
188                        .map(|i| b.flat.is_deleted(i))
189                        .collect(),
190                })
191                .collect(),
192            surrogate_map: self
193                .surrogate_map
194                .iter()
195                .map(|(&k, s)| (k, s.as_u32()))
196                .collect(),
197            multi_doc_map: self
198                .multi_doc_map
199                .iter()
200                .map(|(k, v)| (k.as_u32(), v.clone()))
201                .collect(),
202            quantization_tag: quantization_to_tag(self.quantization),
203            payload_index_bytes: {
204                let snap = self.payload.to_snapshot();
205                match zerompk::to_msgpack_vec(&snap) {
206                    Ok(bytes) => bytes,
207                    Err(e) => {
208                        tracing::warn!(
209                            error = %e,
210                            "vector payload index snapshot serialization failed"
211                        );
212                        return Vec::new();
213                    }
214                }
215            },
216        };
217        let msgpack = match zerompk::to_msgpack_vec(&snapshot) {
218            Ok(bytes) => bytes,
219            Err(e) => {
220                tracing::warn!(error = %e, "vector collection checkpoint serialization failed");
221                return Vec::new();
222            }
223        };
224
225        if let Some(key) = kek {
226            match encrypt_checkpoint(key, &msgpack) {
227                Ok(encrypted) => encrypted,
228                Err(e) => {
229                    tracing::warn!(error = %e, "vector collection checkpoint encryption failed");
230                    Vec::new()
231                }
232            }
233        } else {
234            msgpack
235        }
236    }
237
238    /// Restore a collection from checkpoint bytes.
239    ///
240    /// `kek` controls the expected framing:
241    /// - `None` → the file must be plaintext MessagePack (starting with bytes
242    ///   that are NOT `SEGV`). If the file starts with `SEGV` and no key is
243    ///   provided, returns `Err(CheckpointEncryptedNoKey)`.
244    /// - `Some(key)` → encryption is **required**. If the file starts with
245    ///   `SEGV`, it is decrypted with `key`. If the file is plaintext, returns
246    ///   `Err(CheckpointPlaintextKeyRequired)` — refuse to silently load
247    ///   unencrypted data when the operator has enabled at-rest encryption.
248    pub fn from_checkpoint(
249        bytes: &[u8],
250        kek: Option<&nodedb_wal::crypto::WalEncryptionKey>,
251    ) -> Result<Option<Self>, VectorError> {
252        let is_encrypted = bytes.len() >= 4 && bytes[0..4] == SEGV_MAGIC;
253
254        let msgpack: Vec<u8>;
255        let msgpack_ref: &[u8];
256
257        if is_encrypted {
258            if let Some(key) = kek {
259                msgpack = decrypt_checkpoint(key, bytes)?;
260                msgpack_ref = &msgpack;
261            } else {
262                return Err(VectorError::CheckpointEncryptedNoKey);
263            }
264        } else if kek.is_some() {
265            return Err(VectorError::CheckpointPlaintextKeyRequired);
266        } else {
267            msgpack_ref = bytes;
268        }
269
270        let snap: CollectionSnapshot = match zerompk::from_msgpack(msgpack_ref) {
271            Ok(s) => s,
272            Err(_) => return Ok(None),
273        };
274        let metric = match snap.params_metric {
275            0 => DistanceMetric::L2,
276            1 => DistanceMetric::Cosine,
277            2 => DistanceMetric::InnerProduct,
278            3 => DistanceMetric::Manhattan,
279            4 => DistanceMetric::Chebyshev,
280            5 => DistanceMetric::Hamming,
281            6 => DistanceMetric::Jaccard,
282            7 => DistanceMetric::Pearson,
283            _ => DistanceMetric::Cosine,
284        };
285        let params = HnswParams {
286            m: snap.params_m,
287            m0: snap.params_m0,
288            ef_construction: snap.params_ef_construction,
289            metric,
290        };
291
292        let mut growing = FlatIndex::new(snap.dim, metric);
293        for (i, v) in snap.growing_vectors.iter().enumerate() {
294            let deleted = snap.growing_deleted.get(i).copied().unwrap_or(false);
295            if deleted {
296                growing.insert_tombstoned(v.clone());
297            } else {
298                growing.insert(v.clone());
299            }
300        }
301
302        // no-governor: cold restore path; segment count bounded by collection config
303        let mut sealed = Vec::with_capacity(snap.sealed_segments.len());
304        for ss in &snap.sealed_segments {
305            if let Some(index) = HnswIndex::from_checkpoint(&ss.hnsw_bytes).ok().flatten() {
306                let pq = match (&ss.pq_bytes, &ss.pq_codes) {
307                    (Some(bytes), Some(codes)) => PqCodec::from_bytes(bytes)
308                        .ok()
309                        .map(|codec| (codec, codes.clone())),
310                    _ => None,
311                };
312                // Restore SQ8 from persisted bytes — never recompute on load.
313                // A segment never carries both PQ and SQ8.
314                let sq8 = if pq.is_some() {
315                    None
316                } else {
317                    match (&ss.sq8_bytes, &ss.sq8_codes) {
318                        (Some(codec_bytes), Some(codes)) => Sq8Codec::from_bytes(codec_bytes)
319                            .ok()
320                            .map(|codec| (codec, codes.clone())),
321                        _ => None,
322                    }
323                };
324                sealed.push(SealedSegment {
325                    index,
326                    base_id: ss.base_id,
327                    sq8,
328                    pq,
329                    tier: StorageTier::L0Ram,
330                    mmap_vectors: None,
331                });
332            }
333        }
334
335        for bs in &snap.building_segments {
336            let mut index = HnswIndex::new(snap.dim, params.clone());
337            for v in &bs.vectors {
338                index
339                    .insert(v.clone())
340                    .expect("dimension guaranteed by checkpoint");
341            }
342            // Replay building-segment tombstones onto the HNSW index.
343            for (i, &dead) in bs.deleted.iter().enumerate() {
344                if dead {
345                    index.delete(i as u32);
346                }
347            }
348            let sq8 = VectorCollection::build_sq8_for_index(&index);
349            sealed.push(SealedSegment {
350                index,
351                base_id: bs.base_id,
352                sq8,
353                pq: None,
354                tier: StorageTier::L0Ram,
355                mmap_vectors: None,
356            });
357        }
358
359        let next_segment_id = (sealed.len() + 1) as u32;
360
361        let index_config = crate::index_config::IndexConfig {
362            hnsw: params.clone(),
363            ..crate::index_config::IndexConfig::default()
364        };
365        Ok(Some(Self {
366            growing,
367            growing_base_id: snap.growing_base_id,
368            sealed,
369            building: Vec::new(),
370            params,
371            next_id: snap.next_id,
372            next_segment_id,
373            dim: snap.dim,
374            data_dir: None,
375            ram_budget_bytes: 0,
376            mmap_fallback_count: 0,
377            mmap_segment_count: 0,
378            surrogate_map: snap
379                .surrogate_map
380                .iter()
381                .map(|&(k, s)| (k, Surrogate::new(s)))
382                .collect(),
383            surrogate_to_local: snap
384                .surrogate_map
385                .iter()
386                .map(|&(k, s)| (Surrogate::new(s), k))
387                .collect(),
388            multi_doc_map: snap
389                .multi_doc_map
390                .into_iter()
391                .map(|(k, v)| (Surrogate::new(k), v))
392                .collect::<HashMap<_, _>>(),
393            seal_threshold: DEFAULT_SEAL_THRESHOLD,
394            index_config,
395            codec_dispatch: None,
396            quantization: quantization_from_tag(snap.quantization_tag),
397            payload: if snap.payload_index_bytes.is_empty() {
398                super::payload_index::PayloadIndexSet::default()
399            } else {
400                zerompk::from_msgpack::<PayloadIndexSetSnapshot>(&snap.payload_index_bytes)
401                    .map(super::payload_index::PayloadIndexSet::from_snapshot)
402                    .unwrap_or_default()
403            },
404            arena_index: None,
405        }))
406    }
407}
408
409/// Encode a `VectorQuantization` to a u8 tag for storage.
410fn quantization_to_tag(q: VectorQuantization) -> u8 {
411    match q {
412        VectorQuantization::None => 0,
413        VectorQuantization::Sq8 => 1,
414        VectorQuantization::Pq => 2,
415        VectorQuantization::RaBitQ => 3,
416        VectorQuantization::Bbq => 4,
417        VectorQuantization::Binary => 5,
418        VectorQuantization::Ternary => 6,
419        VectorQuantization::Opq => 7,
420        _ => 0,
421    }
422}
423
424/// Decode a u8 tag back to `VectorQuantization`.
425fn quantization_from_tag(tag: u8) -> VectorQuantization {
426    match tag {
427        0 => VectorQuantization::None,
428        1 => VectorQuantization::Sq8,
429        2 => VectorQuantization::Pq,
430        3 => VectorQuantization::RaBitQ,
431        4 => VectorQuantization::Bbq,
432        5 => VectorQuantization::Binary,
433        6 => VectorQuantization::Ternary,
434        7 => VectorQuantization::Opq,
435        _ => VectorQuantization::None,
436    }
437}
438
#[cfg(test)]
mod tests {
    use crate::collection::lifecycle::VectorCollection;
    use crate::distance::DistanceMetric;
    use crate::error::VectorError;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// HNSW parameters with the L2 metric and defaults for everything else.
    fn l2_params() -> HnswParams {
        HnswParams {
            metric: DistanceMetric::L2,
            ..HnswParams::default()
        }
    }

    /// SQ8 calibration data must survive a checkpoint round-trip without
    /// being recomputed. Verifies that the O(N*dim) rebuild-on-restart bug
    /// is eliminated: `sq8` on the restored sealed segment is `Some` and
    /// its serialized codec bytes match the original exactly.
    #[test]
    fn checkpoint_roundtrip_preserves_sq8() {
        // 1024 vectors of dim=8 — enough to pass the ≥1000 threshold in
        // `build_sq8_for_index`, so sq8 will be Some after complete_build.
        // Uses plain HNSW (default IndexType::Hnsw) so SQ8 is selected.
        let mut coll = VectorCollection::with_seal_threshold(8, l2_params(), 1024);
        for i in 0..1024u32 {
            let mut v = vec![0.0f32; 8];
            for (d, slot) in v.iter_mut().enumerate() {
                *slot = ((i as f32) * 0.01 + (d as f32) * 0.1).sin();
            }
            coll.insert(v);
        }
        let req = coll.seal("sq8_test").expect("seal produced request");
        let mut idx = HnswIndex::new(req.dim, req.params.clone());
        for v in &req.vectors {
            idx.insert(v.clone()).unwrap();
        }
        coll.complete_build(req.segment_id, idx);

        let sealed = coll.sealed_segments();
        assert!(!sealed.is_empty(), "expected at least one sealed segment");
        let orig_sq8 = sealed[0]
            .sq8
            .as_ref()
            .expect("sq8 must be Some after complete_build with ≥1000 vectors");
        let orig_dim = orig_sq8.0.dim();
        // Capture serialized form as ground truth.
        let orig_bytes = orig_sq8.0.to_bytes();

        let checkpoint = coll.checkpoint_to_bytes(None);
        let restored = VectorCollection::from_checkpoint(&checkpoint, None)
            .unwrap()
            .unwrap();

        let restored_sealed = restored.sealed_segments();
        assert!(!restored_sealed.is_empty());
        let restored_sq8 = restored_sealed[0]
            .sq8
            .as_ref()
            .expect("sq8 must be Some after restoring checkpoint — never recomputed");

        assert_eq!(restored_sq8.0.dim(), orig_dim, "dim mismatch after restore");
        // Byte-level equality guarantees calibration data is persisted, not recomputed.
        assert_eq!(
            restored_sq8.0.to_bytes(),
            orig_bytes,
            "sq8 codec bytes differ — calibration data was recomputed rather than persisted"
        );
    }

    /// A plaintext checkpoint of a growing-only collection restores the same
    /// length and dimension, and still answers nearest-neighbour queries.
    #[test]
    fn checkpoint_roundtrip() {
        let mut coll = VectorCollection::new(3, l2_params());
        for i in 0..50u32 {
            coll.insert(vec![i as f32, 0.0, 0.0]);
        }
        let bytes = coll.checkpoint_to_bytes(None);
        let restored = VectorCollection::from_checkpoint(&bytes, None)
            .unwrap()
            .unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);

        let results = restored.search(&[25.0, 0.0, 0.0], 1, 64);
        // Fail with a readable message instead of an index-out-of-bounds panic.
        assert!(
            !results.is_empty(),
            "search on restored collection returned no results"
        );
        assert_eq!(results[0].id, 25);
    }

    /// Payload bitmap indexes registered on a vector-primary collection
    /// must survive a checkpoint round-trip — otherwise `WHERE` filters
    /// would silently return zero rows after a node restart.
    #[test]
    fn checkpoint_roundtrip_preserves_payload_bitmap() {
        use crate::collection::PayloadIndexKind;
        use crate::collection::payload_index::FilterPredicate;
        use nodedb_types::Value;
        use std::collections::HashMap;

        let mut coll = VectorCollection::new(3, l2_params());
        coll.payload
            .add_index("category".to_string(), PayloadIndexKind::Equality);
        for i in 0u32..10 {
            let node_id = coll.insert(vec![i as f32, 0.0, 0.0]);
            let mut fields = HashMap::new();
            let cat = if i % 2 == 0 { "A" } else { "B" };
            fields.insert("category".to_string(), Value::String(cat.to_string()));
            coll.payload.insert_row(node_id, &fields);
        }

        let bytes = coll.checkpoint_to_bytes(None);
        let restored = VectorCollection::from_checkpoint(&bytes, None)
            .unwrap()
            .unwrap();

        let pred = FilterPredicate::Eq {
            field: "category".to_string(),
            value: Value::String("A".to_string()),
        };
        let bm = restored
            .payload
            .pre_filter(&pred)
            .expect("payload index 'category' must be present after restore");
        assert_eq!(
            bm.len(),
            5,
            "5 rows of category=A must survive checkpoint round-trip"
        );
    }

    /// A blob carrying the `SEGV` magic must be refused when no key is
    /// supplied, before any attempt to decode it.
    #[test]
    fn encrypted_checkpoint_without_key_is_rejected() {
        let blob = b"SEGV-not-a-real-envelope";
        let result = VectorCollection::from_checkpoint(blob, None);
        assert!(matches!(
            result,
            Err(VectorError::CheckpointEncryptedNoKey)
        ));
    }

    /// Empty input is neither encrypted nor a decodable snapshot, so it is
    /// reported as "no checkpoint" rather than as an error.
    #[test]
    fn empty_input_is_not_a_checkpoint() {
        let result = VectorCollection::from_checkpoint(&[], None);
        assert!(matches!(result, Ok(None)));
    }
}