Skip to main content

nodedb_spatial/
persist.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! R-tree checkpoint/restore for durable persistence.
4//!
5//! ## On-disk framing
6//!
//! **Plaintext** R-tree checkpoints start with the 6-byte magic `RKSPT\0`,
//! followed by a 1-byte format version and then the rkyv payload. The first
//! 4 bytes (`RKSP`) are never `SEGV`, so detection is unambiguous.
9//!
10//! **Encrypted** checkpoints use the same SEGV framing as the vector engine:
11//!
12//! ```text
13//! [SEGV (4B)] [version_u16_le (2B)] [cipher_alg_u8 (1B)] [kid_u8 (1B)]
14//! [epoch (4B)] [reserved (4B)] [AES-256-GCM ciphertext of the inner payload]
15//! ```
16//!
//! The inner payload is the existing plaintext format, unchanged: for the
//! R-tree that is `RKSPT\0` + version byte + rkyv bytes; for geohash it is the
//! msgpack bytes. The nonce is `(epoch, lsn=0)`, and the 16-byte preamble is
//! used as AAD.
20//!
21//! Storage key scheme (in redb under Namespace::Spatial):
22//! - `{collection}\x00{field}\x00rtree` → serialized R-tree entries
23//! - `{collection}\x00{field}\x00meta`  → SpatialIndexMeta
24
25use nodedb_types::BoundingBox;
26use serde::{Deserialize, Serialize};
27use zerompk::{FromMessagePack, ToMessagePack};
28
29use crate::rtree::{RTree, RTreeEntry};
30
31// ── SEGV framing constants ─────────────────────────────────────────────────
32//
33// The encrypted envelope itself lives in `nodedb_wal::crypto`; only the
34// magic constant is local to this module.
35
/// Four-byte tag that opens every *encrypted* spatial checkpoint.
///
/// The same tag is used by `nodedb-vector` for its collection checkpoints, so
/// both engines share one envelope format. `RTree::from_checkpoint` looks for
/// this tag to tell encrypted blobs apart from plaintext ones.
const SEGV_MAGIC: [u8; 4] = *b"SEGV";
39
40// ── Plaintext inner-format constants ───────────────────────────────────────
41
/// Six-byte tag (`RKSPT` plus a trailing NUL) that opens a plaintext,
/// rkyv-encoded R-tree snapshot. A single format-version byte follows it.
const RTREE_RKYV_MAGIC: &[u8; 6] = b"RKSPT\0";
44
/// Format version written right after the `RKSPT\0` magic.
///
/// Restoring a snapshot whose version byte differs from this value fails with
/// `RTreeCheckpointError::UnsupportedVersion`.
pub const RTREE_FORMAT_VERSION: u8 = 1;
47
48// ── SEGV framing helpers ───────────────────────────────────────────────────
49
50fn encrypt_payload(
51    key: &nodedb_wal::crypto::WalEncryptionKey,
52    plaintext: &[u8],
53) -> Result<Vec<u8>, RTreeCheckpointError> {
54    nodedb_wal::crypto::encrypt_segment_envelope(key, &SEGV_MAGIC, plaintext)
55        .map_err(|e| RTreeCheckpointError::EncryptionFailed(e.to_string()))
56}
57
58fn decrypt_payload(
59    key: &nodedb_wal::crypto::WalEncryptionKey,
60    blob: &[u8],
61) -> Result<Vec<u8>, RTreeCheckpointError> {
62    nodedb_wal::crypto::decrypt_segment_envelope(key, &SEGV_MAGIC, blob)
63        .map_err(|e| RTreeCheckpointError::DecryptionFailed(e.to_string()))
64}
65
66/// Encrypt a geohash msgpack payload (called from `geohash_index.rs`).
67pub(crate) fn encrypt_geohash_payload(
68    key: &nodedb_wal::crypto::WalEncryptionKey,
69    plaintext: &[u8],
70) -> Result<Vec<u8>, RTreeCheckpointError> {
71    encrypt_payload(key, plaintext)
72}
73
74/// Decrypt a geohash msgpack payload (called from `geohash_index.rs`).
75pub(crate) fn decrypt_geohash_payload(
76    key: &nodedb_wal::crypto::WalEncryptionKey,
77    blob: &[u8],
78) -> Result<Vec<u8>, RTreeCheckpointError> {
79    decrypt_payload(key, blob)
80}
81
82// ── Metadata types ─────────────────────────────────────────────────────────
83
84/// Metadata for a persisted spatial index.
85#[derive(Debug, Clone, Serialize, Deserialize, ToMessagePack, FromMessagePack)]
86pub struct SpatialIndexMeta {
87    /// Collection this index belongs to.
88    pub collection: String,
89    /// Geometry field name being indexed.
90    pub field: String,
91    /// Index type.
92    pub index_type: SpatialIndexType,
93    /// Number of entries at last checkpoint.
94    pub entry_count: u64,
95    /// Bounding box of all indexed geometries (spatial extent).
96    pub extent: Option<BoundingBox>,
97}
98
99/// Type of spatial index.
100#[derive(
101    Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToMessagePack, FromMessagePack,
102)]
103#[msgpack(c_enum)]
104#[repr(u8)]
105#[non_exhaustive]
106pub enum SpatialIndexType {
107    RTree = 0,
108    Geohash = 1,
109}
110
111impl SpatialIndexType {
112    pub fn as_str(&self) -> &'static str {
113        match self {
114            Self::RTree => "rtree",
115            Self::Geohash => "geohash",
116        }
117    }
118}
119
120impl std::fmt::Display for SpatialIndexType {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.write_str(self.as_str())
123    }
124}
125
126// ── rkyv snapshot type ─────────────────────────────────────────────────────
127
128/// rkyv-serialized R-tree snapshot.
129#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
130struct RTreeSnapshotRkyv {
131    entries: Vec<RTreeEntry>,
132}
133
134// ── RTree checkpoint impl ──────────────────────────────────────────────────
135
136impl RTree {
137    /// Serialize the R-tree to bytes for checkpointing.
138    ///
139    /// When `kek` is `Some`, the inner rkyv payload is wrapped in an AES-256-GCM
140    /// encrypted SEGV envelope. When `None`, the raw rkyv bytes (with `RKSPT\0`
141    /// inner magic) are returned (existing plaintext format).
142    pub fn checkpoint_to_bytes(
143        &self,
144        kek: Option<&nodedb_wal::crypto::WalEncryptionKey>,
145    ) -> Result<Vec<u8>, RTreeCheckpointError> {
146        let snapshot = RTreeSnapshotRkyv {
147            entries: self.entries().into_iter().cloned().collect(),
148        };
149        let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
150            .map_err(|e| RTreeCheckpointError::RkyvSerialize(e.to_string()))?;
151
152        // Build inner plaintext: magic + version + rkyv payload.
153        let inner_len = RTREE_RKYV_MAGIC.len() + 1 + rkyv_bytes.len();
154        let _guard = self
155            .governor()
156            .and_then(|gov| gov.reserve(nodedb_mem::EngineId::Spatial, inner_len).ok());
157        let mut inner = Vec::with_capacity(inner_len);
158        inner.extend_from_slice(RTREE_RKYV_MAGIC);
159        inner.push(RTREE_FORMAT_VERSION);
160        inner.extend_from_slice(&rkyv_bytes);
161
162        if let Some(key) = kek {
163            return encrypt_payload(key, &inner);
164        }
165
166        Ok(inner)
167    }
168
169    /// Restore an R-tree from checkpoint bytes.
170    ///
171    /// `kek` controls the expected framing:
172    /// - `None` → file must be plaintext (starting with `RKSPT\0`). If it
173    ///   starts with `SEGV`, returns `Err(MissingKek)`.
174    /// - `Some(key)` → encryption is **required**. If the file starts with
175    ///   `SEGV`, it is decrypted. If plaintext, returns `Err(KekRequired)`.
176    pub fn from_checkpoint(
177        bytes: &[u8],
178        kek: Option<&nodedb_wal::crypto::WalEncryptionKey>,
179    ) -> Result<Self, RTreeCheckpointError> {
180        let is_encrypted = bytes.len() >= 4 && bytes[0..4] == SEGV_MAGIC;
181
182        let inner: Vec<u8>;
183        let inner_ref: &[u8];
184
185        if is_encrypted {
186            if let Some(key) = kek {
187                inner = decrypt_payload(key, bytes)?;
188                inner_ref = &inner;
189            } else {
190                return Err(RTreeCheckpointError::MissingKek);
191            }
192        } else if kek.is_some() {
193            return Err(RTreeCheckpointError::KekRequired);
194        } else {
195            inner_ref = bytes;
196        }
197
198        Self::decode_plaintext_inner(inner_ref)
199    }
200
201    fn decode_plaintext_inner(bytes: &[u8]) -> Result<Self, RTreeCheckpointError> {
202        let header_len = RTREE_RKYV_MAGIC.len() + 1; // magic + version byte
203        if bytes.len() <= header_len || &bytes[..RTREE_RKYV_MAGIC.len()] != RTREE_RKYV_MAGIC {
204            return Err(RTreeCheckpointError::UnrecognizedFormat);
205        }
206        let version = bytes[RTREE_RKYV_MAGIC.len()];
207        if version != RTREE_FORMAT_VERSION {
208            return Err(RTreeCheckpointError::UnsupportedVersion {
209                found: version,
210                expected: RTREE_FORMAT_VERSION,
211            });
212        }
213        let rkyv_bytes = &bytes[header_len..];
214        let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(rkyv_bytes.len());
215        aligned.extend_from_slice(rkyv_bytes);
216        let snapshot: RTreeSnapshotRkyv =
217            rkyv::from_bytes::<RTreeSnapshotRkyv, rkyv::rancor::Error>(&aligned)
218                .map_err(|e| RTreeCheckpointError::RkyvDeserialize(e.to_string()))?;
219        Ok(RTree::bulk_load(snapshot.entries))
220    }
221}
222
223// ── Storage key helpers ────────────────────────────────────────────────────
224
/// Build the storage key for an R-tree checkpoint.
///
/// Format: `{collection}\0{field}\0rtree`
///
/// The buffer is sized from the pieces actually written, so the capacity
/// cannot drift out of sync with the suffix.
///
/// NOTE(review): no escaping is applied, so a `collection` or `field` that
/// itself contains a NUL byte would produce an ambiguous key — presumably
/// callers validate names upstream; confirm.
pub fn rtree_storage_key(collection: &str, field: &str) -> Vec<u8> {
    // Separator before the suffix is folded into the literal.
    const SUFFIX: &[u8] = b"\0rtree";

    let mut key = Vec::with_capacity(collection.len() + 1 + field.len() + SUFFIX.len());
    key.extend_from_slice(collection.as_bytes());
    key.push(0);
    key.extend_from_slice(field.as_bytes());
    key.extend_from_slice(SUFFIX);
    key
}
237
/// Build the storage key for spatial index metadata.
///
/// Format: `{collection}\0{field}\0meta`
///
/// The buffer is sized from the pieces actually written, so the capacity
/// cannot drift out of sync with the suffix.
///
/// NOTE(review): no escaping is applied, so a `collection` or `field` that
/// itself contains a NUL byte would produce an ambiguous key — presumably
/// callers validate names upstream; confirm.
pub fn meta_storage_key(collection: &str, field: &str) -> Vec<u8> {
    // Separator before the suffix is folded into the literal.
    const SUFFIX: &[u8] = b"\0meta";

    let mut key = Vec::with_capacity(collection.len() + 1 + field.len() + SUFFIX.len());
    key.extend_from_slice(collection.as_bytes());
    key.push(0);
    key.extend_from_slice(field.as_bytes());
    key.extend_from_slice(SUFFIX);
    key
}
250
251/// Serialize index metadata to bytes.
252pub fn serialize_meta(meta: &SpatialIndexMeta) -> Result<Vec<u8>, RTreeCheckpointError> {
253    zerompk::to_msgpack_vec(meta).map_err(RTreeCheckpointError::Serialize)
254}
255
256/// Deserialize index metadata from bytes.
257pub fn deserialize_meta(bytes: &[u8]) -> Result<SpatialIndexMeta, RTreeCheckpointError> {
258    zerompk::from_msgpack(bytes).map_err(RTreeCheckpointError::Deserialize)
259}
260
261// ── Error type ─────────────────────────────────────────────────────────────
262
263/// Errors during R-tree checkpoint operations.
264#[derive(Debug, thiserror::Error)]
265#[non_exhaustive]
266pub enum RTreeCheckpointError {
267    #[error("R-tree checkpoint serialization failed: {0}")]
268    Serialize(zerompk::Error),
269    #[error("R-tree checkpoint deserialization failed: {0}")]
270    Deserialize(zerompk::Error),
271    #[error("R-tree rkyv serialization failed: {0}")]
272    RkyvSerialize(String),
273    #[error("R-tree rkyv deserialization failed: {0}")]
274    RkyvDeserialize(String),
275    #[error("unsupported R-tree checkpoint version {found}; expected {expected}")]
276    UnsupportedVersion { found: u8, expected: u8 },
277    #[error("unrecognized R-tree checkpoint format (missing RKSPT\\0 magic)")]
278    UnrecognizedFormat,
279    /// Checkpoint is encrypted (starts with `SEGV`) but no KEK was supplied.
280    #[error(
281        "spatial checkpoint is encrypted but no encryption key was provided; \
282         cannot load an encrypted checkpoint without a key"
283    )]
284    MissingKek,
285    /// Checkpoint is plaintext but a KEK is configured (policy violation).
286    #[error(
287        "spatial checkpoint is plaintext but an encryption key is configured; \
288         refusing to load an unencrypted checkpoint when encryption is required"
289    )]
290    KekRequired,
291    /// AES-256-GCM encryption failed.
292    #[error("spatial checkpoint encryption failed: {0}")]
293    EncryptionFailed(String),
294    /// AES-256-GCM decryption failed.
295    #[error("spatial checkpoint decryption failed: {0}")]
296    DecryptionFailed(String),
297}
298
299// ── Tests ──────────────────────────────────────────────────────────────────
300
#[cfg(test)]
mod tests {
    use super::*;

    // ── Fixtures ────────────────────────────────────────────────────────────

    /// Entry with the given id whose box comes from `BoundingBox::from_point`.
    fn make_entry(id: u64, lng: f64, lat: f64) -> RTreeEntry {
        let bbox = BoundingBox::from_point(lng, lat);
        RTreeEntry { id, bbox }
    }

    /// Fixed 32-byte key used by the encryption tests.
    fn make_test_kek() -> nodedb_wal::crypto::WalEncryptionKey {
        let raw = [0x42u8; 32];
        nodedb_wal::crypto::WalEncryptionKey::from_bytes(&raw).unwrap()
    }

    /// Query box spanning the full longitude/latitude range.
    fn whole_world() -> BoundingBox {
        BoundingBox::new(-180.0, -90.0, 180.0, 90.0)
    }

    /// Tree holding ids `0..count`, each placed at `place(id as f64)`.
    fn tree_of(count: u64, place: impl Fn(f64) -> (f64, f64)) -> RTree {
        let mut tree = RTree::new();
        for id in 0..count {
            let (lng, lat) = place(id as f64);
            tree.insert(make_entry(id, lng, lat));
        }
        tree
    }

    /// Tree holding the single entry id=1 at (10, 20).
    fn single_point_tree() -> RTree {
        let mut tree = RTree::new();
        tree.insert(make_entry(1, 10.0, 20.0));
        tree
    }

    // ── Plaintext round-trips ───────────────────────────────────────────────

    #[test]
    fn checkpoint_roundtrip_empty() {
        let blob = RTree::new().checkpoint_to_bytes(None).unwrap();
        let reloaded = RTree::from_checkpoint(&blob, None).unwrap();
        assert_eq!(reloaded.len(), 0);
    }

    #[test]
    fn checkpoint_roundtrip_entries() {
        let original = tree_of(100, |x| (x * 0.5, x * 0.3));
        assert_eq!(original.len(), 100);

        let blob = original.checkpoint_to_bytes(None).unwrap();
        let reloaded = RTree::from_checkpoint(&blob, None).unwrap();
        assert_eq!(reloaded.len(), 100);

        // Every entry must still be reachable through a search.
        let hits = reloaded.search(&whole_world());
        assert_eq!(hits.len(), 100);
    }

    #[test]
    fn checkpoint_preserves_ids() {
        let mut original = RTree::new();
        original.insert(make_entry(42, 10.0, 20.0));
        original.insert(make_entry(99, 30.0, 40.0));

        let blob = original.checkpoint_to_bytes(None).unwrap();
        let reloaded = RTree::from_checkpoint(&blob, None).unwrap();

        // The window covers only the first point.
        let window = BoundingBox::new(5.0, 15.0, 15.0, 25.0);
        let hits = reloaded.search(&window);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 42);
    }

    #[test]
    fn corrupted_bytes_returns_error() {
        let garbage = [0xFF, 0xFF, 0xFF];
        let outcome = RTree::from_checkpoint(&garbage, None);
        assert!(matches!(
            outcome,
            Err(RTreeCheckpointError::UnrecognizedFormat)
        ));
    }

    // ── Metadata and storage keys ───────────────────────────────────────────

    #[test]
    fn meta_roundtrip() {
        let original = SpatialIndexMeta {
            collection: "buildings".to_string(),
            field: "geom".to_string(),
            index_type: SpatialIndexType::RTree,
            entry_count: 1000,
            extent: Some(whole_world()),
        };
        let encoded = serialize_meta(&original).unwrap();
        let decoded = deserialize_meta(&encoded).unwrap();
        assert_eq!(decoded.collection, "buildings");
        assert_eq!(decoded.entry_count, 1000);
        assert_eq!(decoded.index_type, SpatialIndexType::RTree);
    }

    #[test]
    fn storage_key_format() {
        let key = rtree_storage_key("buildings", "geom");
        assert_eq!(key, b"buildings\0geom\0rtree");

        let meta_key = meta_storage_key("buildings", "geom");
        assert_eq!(meta_key, b"buildings\0geom\0meta");
    }

    // ── Plaintext frame layout ──────────────────────────────────────────────

    #[test]
    fn checkpoint_size_reasonable() {
        let blob = tree_of(1000, |x| (x * 0.01, x * 0.01))
            .checkpoint_to_bytes(None)
            .unwrap();
        // Per entry: id (8 bytes) + four f64 (32 bytes) ≈ 40 bytes plus rkyv
        // overhead, so 1000 entries should land around 40-60 KB.
        let size = blob.len();
        assert!(size < 100_000, "checkpoint too large: {} bytes", size);
        assert!(size > 10_000, "checkpoint too small: {} bytes", size);
    }

    #[test]
    fn golden_header_layout() {
        let blob = single_point_tree().checkpoint_to_bytes(None).unwrap();
        // Bytes 0..6 hold the magic.
        assert_eq!(&blob[0..6], b"RKSPT\0");
        // Byte 6 holds the format version.
        assert_eq!(blob[6], super::RTREE_FORMAT_VERSION);
        // The rkyv payload starts right after it.
        assert!(blob.len() > 7);
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut blob = single_point_tree().checkpoint_to_bytes(None).unwrap();
        // Overwrite the version byte with a value this build does not write.
        blob[6] = 0;
        match RTree::from_checkpoint(&blob, None) {
            Err(RTreeCheckpointError::UnsupportedVersion { found, expected }) => {
                assert_eq!(found, 0);
                assert_eq!(expected, super::RTREE_FORMAT_VERSION);
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("expected UnsupportedVersion error, got Ok"),
        }
    }

    // ── Encrypted checkpoints ───────────────────────────────────────────────

    #[test]
    fn spatial_rtree_checkpoint_encrypted_at_rest() {
        let kek = make_test_kek();
        let original = tree_of(50, |x| (x, x * 0.5));

        let sealed = original.checkpoint_to_bytes(Some(&kek)).unwrap();

        // An encrypted blob opens with SEGV rather than RKSPT.
        assert_eq!(&sealed[0..4], b"SEGV");

        // Decrypting must bring back every entry.
        let reloaded = RTree::from_checkpoint(&sealed, Some(&kek)).unwrap();
        assert_eq!(reloaded.len(), 50);
        let hits = reloaded.search(&whole_world());
        assert_eq!(hits.len(), 50);
    }

    #[test]
    fn spatial_rtree_refuses_plaintext_when_kek_required() {
        let kek = make_test_kek();

        // Checkpoint written without encryption...
        let plain = single_point_tree().checkpoint_to_bytes(None).unwrap();

        // ...must be rejected when the loader is given a key.
        let outcome = RTree::from_checkpoint(&plain, Some(&kek));
        assert!(matches!(outcome, Err(RTreeCheckpointError::KekRequired)));
    }

    #[test]
    fn spatial_rtree_refuses_encrypted_without_kek() {
        let kek = make_test_kek();
        let sealed = single_point_tree()
            .checkpoint_to_bytes(Some(&kek))
            .unwrap();

        // An encrypted blob cannot be loaded without a key.
        let outcome = RTree::from_checkpoint(&sealed, None);
        assert!(matches!(outcome, Err(RTreeCheckpointError::MissingKek)));
    }

    #[test]
    fn spatial_rtree_tampered_ciphertext_rejected() {
        let kek = make_test_kek();
        let mut sealed = single_point_tree()
            .checkpoint_to_bytes(Some(&kek))
            .unwrap();

        // Corrupt one byte past the 16-byte preamble, i.e. inside the
        // ciphertext region.
        sealed[20] ^= 0xFF;

        let outcome = RTree::from_checkpoint(&sealed, Some(&kek));
        assert!(matches!(
            outcome,
            Err(RTreeCheckpointError::DecryptionFailed(_))
        ));
    }
}