use nodedb_types::BoundingBox;
use serde::{Deserialize, Serialize};
use zerompk::{FromMessagePack, ToMessagePack};
use crate::rtree::{RTree, RTreeEntry};
/// Magic tag of an encrypted spatial checkpoint envelope.
///
/// Passed as the envelope magic to `encrypt_segment_envelope` /
/// `decrypt_segment_envelope`. `RTree::from_checkpoint` also treats any blob
/// whose first four bytes equal this value as encrypted.
const SEGV_MAGIC: [u8; 4] = *b"SEGV";
/// Magic prefix of a plaintext R-tree checkpoint (`RKSPT` followed by a NUL byte).
const RTREE_RKYV_MAGIC: &[u8; 6] = b"RKSPT\0";
/// Format version byte written directly after the plaintext magic prefix.
///
/// Loading a checkpoint whose version byte differs fails with
/// `RTreeCheckpointError::UnsupportedVersion`.
pub const RTREE_FORMAT_VERSION: u8 = 1;
/// Seals `plaintext` in a crypto segment envelope tagged with `SEGV_MAGIC`.
///
/// # Errors
///
/// Any failure reported by the envelope layer is stringified into
/// [`RTreeCheckpointError::EncryptionFailed`].
fn encrypt_payload(
    key: &nodedb_wal::crypto::WalEncryptionKey,
    plaintext: &[u8],
) -> Result<Vec<u8>, RTreeCheckpointError> {
    let sealed = nodedb_wal::crypto::encrypt_segment_envelope(key, &SEGV_MAGIC, plaintext);
    sealed.map_err(|err| RTreeCheckpointError::EncryptionFailed(err.to_string()))
}
/// Opens an envelope produced by `encrypt_payload` and returns the plaintext.
///
/// # Errors
///
/// Any failure reported by the envelope layer is stringified into
/// [`RTreeCheckpointError::DecryptionFailed`]. This includes tampered
/// ciphertext, as exercised by this file's tests.
fn decrypt_payload(
    key: &nodedb_wal::crypto::WalEncryptionKey,
    blob: &[u8],
) -> Result<Vec<u8>, RTreeCheckpointError> {
    let opened = nodedb_wal::crypto::decrypt_segment_envelope(key, &SEGV_MAGIC, blob);
    opened.map_err(|err| RTreeCheckpointError::DecryptionFailed(err.to_string()))
}
/// Encrypts a geohash index payload for storage at rest.
///
/// Geohash blobs deliberately reuse the R-tree envelope (same `SEGV` magic),
/// so both index kinds are sealed identically.
///
/// # Errors
///
/// Returns [`RTreeCheckpointError::EncryptionFailed`] if sealing fails.
pub(crate) fn encrypt_geohash_payload(
    key: &nodedb_wal::crypto::WalEncryptionKey,
    plaintext: &[u8],
) -> Result<Vec<u8>, RTreeCheckpointError> {
    let envelope = encrypt_payload(key, plaintext)?;
    Ok(envelope)
}
/// Decrypts a geohash index payload sealed by `encrypt_geohash_payload`.
///
/// # Errors
///
/// Returns [`RTreeCheckpointError::DecryptionFailed`] if the envelope cannot
/// be opened with `key`.
pub(crate) fn decrypt_geohash_payload(
    key: &nodedb_wal::crypto::WalEncryptionKey,
    blob: &[u8],
) -> Result<Vec<u8>, RTreeCheckpointError> {
    let plaintext = decrypt_payload(key, blob)?;
    Ok(plaintext)
}
/// Descriptive metadata for one spatial index on a collection field.
///
/// Encoded as MessagePack by `serialize_meta` / `deserialize_meta`. The serde
/// derives additionally allow other serde formats.
// NOTE(review): the MessagePack derive presumably encodes fields in declaration
// order — confirm before reordering or inserting fields.
#[derive(Debug, Clone, Serialize, Deserialize, ToMessagePack, FromMessagePack)]
pub struct SpatialIndexMeta {
    /// Name of the collection the index belongs to.
    pub collection: String,
    /// Name of the indexed field within the collection.
    pub field: String,
    /// Which spatial index structure this metadata describes.
    pub index_type: SpatialIndexType,
    /// Number of entries in the index (presumably as of the last checkpoint — confirm with writers).
    pub entry_count: u64,
    /// Overall bounding box of the index, or `None` when not recorded.
    // NOTE(review): presumably the union of all entry bboxes; verify against the writer.
    pub extent: Option<BoundingBox>,
}
/// Kind of spatial index described by a [`SpatialIndexMeta`].
///
/// The enum is `#[repr(u8)]` with explicit discriminants and marked
/// `#[msgpack(c_enum)]`. That attribute presumably encodes a variant as its
/// integer discriminant (confirm with the zerompk docs), so existing variants
/// should never be renumbered.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToMessagePack, FromMessagePack,
)]
#[msgpack(c_enum)]
#[repr(u8)]
#[non_exhaustive]
pub enum SpatialIndexType {
    /// R-tree index (checkpointed via `RTree::checkpoint_to_bytes`).
    RTree = 0,
    /// Geohash-based index.
    Geohash = 1,
}
impl SpatialIndexType {
pub fn as_str(&self) -> &'static str {
match self {
Self::RTree => "rtree",
Self::Geohash => "geohash",
}
}
}
impl std::fmt::Display for SpatialIndexType {
    /// Writes the lowercase index-type name returned by [`SpatialIndexType::as_str`].
    ///
    /// Uses [`std::fmt::Formatter::pad`] rather than `write_str`, so width, fill,
    /// alignment and precision flags are honoured, e.g.
    /// `format!("{:<8}|", SpatialIndexType::RTree)` yields `"rtree   |"`.
    /// Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// rkyv-archived body of a plaintext R-tree checkpoint.
///
/// `RTree::checkpoint_to_bytes` archives this after the `RKSPT\0` magic and
/// the version byte. `RTree::decode_plaintext_inner` reads it back and
/// bulk-loads the entries.
// NOTE(review): changing this struct changes the archived layout; presumably
// that requires bumping `RTREE_FORMAT_VERSION`.
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
struct RTreeSnapshotRkyv {
    /// Every entry of the tree at checkpoint time.
    entries: Vec<RTreeEntry>,
}
impl RTree {
pub fn checkpoint_to_bytes(
&self,
kek: Option<&nodedb_wal::crypto::WalEncryptionKey>,
) -> Result<Vec<u8>, RTreeCheckpointError> {
let snapshot = RTreeSnapshotRkyv {
entries: self.entries().into_iter().cloned().collect(),
};
let rkyv_bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&snapshot)
.map_err(|e| RTreeCheckpointError::RkyvSerialize(e.to_string()))?;
let inner_len = RTREE_RKYV_MAGIC.len() + 1 + rkyv_bytes.len();
let _guard = self
.governor()
.and_then(|gov| gov.reserve(nodedb_mem::EngineId::Spatial, inner_len).ok());
let mut inner = Vec::with_capacity(inner_len);
inner.extend_from_slice(RTREE_RKYV_MAGIC);
inner.push(RTREE_FORMAT_VERSION);
inner.extend_from_slice(&rkyv_bytes);
if let Some(key) = kek {
return encrypt_payload(key, &inner);
}
Ok(inner)
}
pub fn from_checkpoint(
bytes: &[u8],
kek: Option<&nodedb_wal::crypto::WalEncryptionKey>,
) -> Result<Self, RTreeCheckpointError> {
let is_encrypted = bytes.len() >= 4 && bytes[0..4] == SEGV_MAGIC;
let inner: Vec<u8>;
let inner_ref: &[u8];
if is_encrypted {
if let Some(key) = kek {
inner = decrypt_payload(key, bytes)?;
inner_ref = &inner;
} else {
return Err(RTreeCheckpointError::MissingKek);
}
} else if kek.is_some() {
return Err(RTreeCheckpointError::KekRequired);
} else {
inner_ref = bytes;
}
Self::decode_plaintext_inner(inner_ref)
}
fn decode_plaintext_inner(bytes: &[u8]) -> Result<Self, RTreeCheckpointError> {
let header_len = RTREE_RKYV_MAGIC.len() + 1; if bytes.len() <= header_len || &bytes[..RTREE_RKYV_MAGIC.len()] != RTREE_RKYV_MAGIC {
return Err(RTreeCheckpointError::UnrecognizedFormat);
}
let version = bytes[RTREE_RKYV_MAGIC.len()];
if version != RTREE_FORMAT_VERSION {
return Err(RTreeCheckpointError::UnsupportedVersion {
found: version,
expected: RTREE_FORMAT_VERSION,
});
}
let rkyv_bytes = &bytes[header_len..];
let mut aligned = rkyv::util::AlignedVec::<16>::with_capacity(rkyv_bytes.len());
aligned.extend_from_slice(rkyv_bytes);
let snapshot: RTreeSnapshotRkyv =
rkyv::from_bytes::<RTreeSnapshotRkyv, rkyv::rancor::Error>(&aligned)
.map_err(|e| RTreeCheckpointError::RkyvDeserialize(e.to_string()))?;
Ok(RTree::bulk_load(snapshot.entries))
}
}
/// Builds the storage key under which a collection field's R-tree checkpoint is kept.
///
/// Layout: `collection` bytes, `\0`, `field` bytes, `\0`, `rtree`.
pub fn rtree_storage_key(collection: &str, field: &str) -> Vec<u8> {
    spatial_storage_key(collection, field, b"rtree")
}

/// Builds the storage key under which a collection field's index metadata is kept.
///
/// Layout: `collection` bytes, `\0`, `field` bytes, `\0`, `meta`.
pub fn meta_storage_key(collection: &str, field: &str) -> Vec<u8> {
    spatial_storage_key(collection, field, b"meta")
}

/// Joins `collection`, `field` and `suffix` with NUL separators.
///
/// The buffer is allocated once at its exact final length.
// NOTE(review): a NUL byte inside `collection` or `field` would let two
// different (collection, field) pairs map to the same key. Presumably names
// are validated upstream — confirm.
fn spatial_storage_key(collection: &str, field: &str, suffix: &[u8]) -> Vec<u8> {
    // Two separator bytes plus the three variable-length parts.
    let mut key = Vec::with_capacity(collection.len() + field.len() + suffix.len() + 2);
    key.extend_from_slice(collection.as_bytes());
    key.push(0);
    key.extend_from_slice(field.as_bytes());
    key.push(0);
    key.extend_from_slice(suffix);
    key
}
/// Encodes spatial index metadata as MessagePack.
///
/// # Errors
///
/// Returns [`RTreeCheckpointError::Serialize`] if `zerompk` cannot encode `meta`.
pub fn serialize_meta(meta: &SpatialIndexMeta) -> Result<Vec<u8>, RTreeCheckpointError> {
    let encoded = zerompk::to_msgpack_vec(meta).map_err(RTreeCheckpointError::Serialize)?;
    Ok(encoded)
}
/// Decodes spatial index metadata previously written by `serialize_meta`.
///
/// # Errors
///
/// Returns [`RTreeCheckpointError::Deserialize`] if `bytes` is not a valid
/// MessagePack encoding of [`SpatialIndexMeta`].
pub fn deserialize_meta(bytes: &[u8]) -> Result<SpatialIndexMeta, RTreeCheckpointError> {
    let decoded: SpatialIndexMeta =
        zerompk::from_msgpack(bytes).map_err(RTreeCheckpointError::Deserialize)?;
    Ok(decoded)
}
/// Errors produced while writing or reading spatial index checkpoints and metadata.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RTreeCheckpointError {
    /// MessagePack encoding failed (returned by `serialize_meta` in this file).
    #[error("R-tree checkpoint serialization failed: {0}")]
    Serialize(zerompk::Error),
    /// MessagePack decoding failed (returned by `deserialize_meta` in this file).
    #[error("R-tree checkpoint deserialization failed: {0}")]
    Deserialize(zerompk::Error),
    /// rkyv could not archive the R-tree entries.
    #[error("R-tree rkyv serialization failed: {0}")]
    RkyvSerialize(String),
    /// The rkyv archive inside a plaintext checkpoint could not be decoded.
    #[error("R-tree rkyv deserialization failed: {0}")]
    RkyvDeserialize(String),
    /// The checkpoint's version byte is not the version this build writes.
    #[error("unsupported R-tree checkpoint version {found}; expected {expected}")]
    UnsupportedVersion { found: u8, expected: u8 },
    /// The bytes lack the `RKSPT\0` magic, or end before any rkyv payload follows the header.
    ///
    /// The message names both causes. Previously it blamed only the magic,
    /// which misdiagnosed truncated checkpoints that carry a valid magic.
    #[error(
        "unrecognized R-tree checkpoint format \
         (missing RKSPT\\0 magic or truncated before the rkyv payload)"
    )]
    UnrecognizedFormat,
    /// An encrypted (`SEGV`) checkpoint was loaded without a key.
    #[error(
        "spatial checkpoint is encrypted but no encryption key was provided; \
         cannot load an encrypted checkpoint without a key"
    )]
    MissingKek,
    /// A plaintext checkpoint was loaded while a key was supplied.
    #[error(
        "spatial checkpoint is plaintext but an encryption key is configured; \
         refusing to load an unencrypted checkpoint when encryption is required"
    )]
    KekRequired,
    /// Sealing the checkpoint envelope failed; carries the underlying error text.
    #[error("spatial checkpoint encryption failed: {0}")]
    EncryptionFailed(String),
    /// Opening the checkpoint envelope failed (e.g. tampered ciphertext); carries the underlying error text.
    #[error("spatial checkpoint decryption failed: {0}")]
    DecryptionFailed(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a point entry at (`lng`, `lat`) with the given id.
    fn make_entry(id: u64, lng: f64, lat: f64) -> RTreeEntry {
        RTreeEntry {
            id,
            bbox: BoundingBox::from_point(lng, lat),
        }
    }

    #[test]
    fn checkpoint_roundtrip_empty() {
        let tree = RTree::new();
        let bytes = tree.checkpoint_to_bytes(None).unwrap();
        let restored = RTree::from_checkpoint(&bytes, None).unwrap();
        assert_eq!(restored.len(), 0);
    }

    #[test]
    fn checkpoint_roundtrip_entries() {
        let mut tree = RTree::new();
        for i in 0..100 {
            tree.insert(make_entry(i, (i as f64) * 0.5, (i as f64) * 0.3));
        }
        assert_eq!(tree.len(), 100);
        let bytes = tree.checkpoint_to_bytes(None).unwrap();
        let restored = RTree::from_checkpoint(&bytes, None).unwrap();
        assert_eq!(restored.len(), 100);
        let all = restored.search(&BoundingBox::new(-180.0, -90.0, 180.0, 90.0));
        assert_eq!(all.len(), 100);
    }

    #[test]
    fn checkpoint_preserves_ids() {
        let mut tree = RTree::new();
        tree.insert(make_entry(42, 10.0, 20.0));
        tree.insert(make_entry(99, 30.0, 40.0));
        let bytes = tree.checkpoint_to_bytes(None).unwrap();
        let restored = RTree::from_checkpoint(&bytes, None).unwrap();
        let results = restored.search(&BoundingBox::new(5.0, 15.0, 15.0, 25.0));
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 42);
    }

    #[test]
    fn corrupted_bytes_returns_error() {
        assert!(matches!(
            RTree::from_checkpoint(&[0xFF, 0xFF, 0xFF], None),
            Err(RTreeCheckpointError::UnrecognizedFormat)
        ));
    }

    /// Zero-length input must be rejected as an unknown format, not panic.
    #[test]
    fn empty_input_returns_unrecognized_format() {
        assert!(matches!(
            RTree::from_checkpoint(&[], None),
            Err(RTreeCheckpointError::UnrecognizedFormat)
        ));
    }

    /// Magic plus version byte with no archive behind them must be rejected
    /// before anything is handed to rkyv.
    #[test]
    fn header_only_returns_unrecognized_format() {
        let mut header = RTREE_RKYV_MAGIC.to_vec();
        header.push(RTREE_FORMAT_VERSION);
        assert!(matches!(
            RTree::from_checkpoint(&header, None),
            Err(RTreeCheckpointError::UnrecognizedFormat)
        ));
    }

    #[test]
    fn meta_roundtrip() {
        let meta = SpatialIndexMeta {
            collection: "buildings".to_string(),
            field: "geom".to_string(),
            index_type: SpatialIndexType::RTree,
            entry_count: 1000,
            extent: Some(BoundingBox::new(-180.0, -90.0, 180.0, 90.0)),
        };
        let bytes = serialize_meta(&meta).unwrap();
        let restored = deserialize_meta(&bytes).unwrap();
        assert_eq!(restored.collection, "buildings");
        assert_eq!(restored.entry_count, 1000);
        assert_eq!(restored.index_type, SpatialIndexType::RTree);
    }

    /// Covers the second enum variant and the `None` extent.
    #[test]
    fn meta_roundtrip_geohash_without_extent() {
        let meta = SpatialIndexMeta {
            collection: "poi".to_string(),
            field: "loc".to_string(),
            index_type: SpatialIndexType::Geohash,
            entry_count: 0,
            extent: None,
        };
        let restored = deserialize_meta(&serialize_meta(&meta).unwrap()).unwrap();
        assert_eq!(restored.collection, "poi");
        assert_eq!(restored.field, "loc");
        assert_eq!(restored.index_type, SpatialIndexType::Geohash);
        assert_eq!(restored.entry_count, 0);
        assert!(restored.extent.is_none());
    }

    #[test]
    fn index_type_names() {
        assert_eq!(SpatialIndexType::RTree.as_str(), "rtree");
        assert_eq!(SpatialIndexType::Geohash.as_str(), "geohash");
        assert_eq!(SpatialIndexType::RTree.to_string(), "rtree");
        assert_eq!(SpatialIndexType::Geohash.to_string(), "geohash");
    }

    #[test]
    fn storage_key_format() {
        let key = rtree_storage_key("buildings", "geom");
        assert_eq!(key, b"buildings\0geom\0rtree");
        let meta_key = meta_storage_key("buildings", "geom");
        assert_eq!(meta_key, b"buildings\0geom\0meta");
    }

    #[test]
    fn checkpoint_size_reasonable() {
        let mut tree = RTree::new();
        for i in 0..1000 {
            tree.insert(make_entry(i, (i as f64) * 0.01, (i as f64) * 0.01));
        }
        let bytes = tree.checkpoint_to_bytes(None).unwrap();
        assert!(
            bytes.len() < 100_000,
            "checkpoint too large: {} bytes",
            bytes.len()
        );
        assert!(
            bytes.len() > 10_000,
            "checkpoint too small: {} bytes",
            bytes.len()
        );
    }

    #[test]
    fn golden_header_layout() {
        let mut tree = RTree::new();
        tree.insert(make_entry(1, 10.0, 20.0));
        let bytes = tree.checkpoint_to_bytes(None).unwrap();
        assert_eq!(&bytes[0..6], b"RKSPT\0");
        assert_eq!(bytes[6], super::RTREE_FORMAT_VERSION);
        assert!(bytes.len() > 7);
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut tree = RTree::new();
        tree.insert(make_entry(1, 10.0, 20.0));
        let mut bytes = tree.checkpoint_to_bytes(None).unwrap();
        bytes[6] = 0;
        match RTree::from_checkpoint(&bytes, None) {
            Err(RTreeCheckpointError::UnsupportedVersion { found, expected }) => {
                assert_eq!(found, 0);
                assert_eq!(expected, super::RTREE_FORMAT_VERSION);
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("expected UnsupportedVersion error, got Ok"),
        }
    }

    /// Deterministic test key (all bytes 0x42).
    fn make_test_kek() -> nodedb_wal::crypto::WalEncryptionKey {
        nodedb_wal::crypto::WalEncryptionKey::from_bytes(&[0x42u8; 32]).unwrap()
    }

    #[test]
    fn spatial_rtree_checkpoint_encrypted_at_rest() {
        let kek = make_test_kek();
        let mut tree = RTree::new();
        for i in 0..50 {
            tree.insert(make_entry(i, i as f64, i as f64 * 0.5));
        }
        let enc_bytes = tree.checkpoint_to_bytes(Some(&kek)).unwrap();
        assert_eq!(&enc_bytes[0..4], b"SEGV");
        let restored = RTree::from_checkpoint(&enc_bytes, Some(&kek)).unwrap();
        assert_eq!(restored.len(), 50);
        let all = restored.search(&BoundingBox::new(-180.0, -90.0, 180.0, 90.0));
        assert_eq!(all.len(), 50);
    }

    #[test]
    fn spatial_rtree_encrypted_empty_roundtrip() {
        let kek = make_test_kek();
        let enc_bytes = RTree::new().checkpoint_to_bytes(Some(&kek)).unwrap();
        assert_eq!(&enc_bytes[0..4], b"SEGV");
        let restored = RTree::from_checkpoint(&enc_bytes, Some(&kek)).unwrap();
        assert_eq!(restored.len(), 0);
    }

    #[test]
    fn spatial_rtree_refuses_plaintext_when_kek_required() {
        let kek = make_test_kek();
        let mut tree = RTree::new();
        tree.insert(make_entry(1, 10.0, 20.0));
        let plain_bytes = tree.checkpoint_to_bytes(None).unwrap();
        assert!(matches!(
            RTree::from_checkpoint(&plain_bytes, Some(&kek)),
            Err(RTreeCheckpointError::KekRequired)
        ));
    }

    #[test]
    fn spatial_rtree_refuses_encrypted_without_kek() {
        let kek = make_test_kek();
        let mut tree = RTree::new();
        tree.insert(make_entry(1, 10.0, 20.0));
        let enc_bytes = tree.checkpoint_to_bytes(Some(&kek)).unwrap();
        assert!(matches!(
            RTree::from_checkpoint(&enc_bytes, None),
            Err(RTreeCheckpointError::MissingKek)
        ));
    }

    #[test]
    fn spatial_rtree_tampered_ciphertext_rejected() {
        let kek = make_test_kek();
        let mut tree = RTree::new();
        tree.insert(make_entry(1, 10.0, 20.0));
        let mut enc_bytes = tree.checkpoint_to_bytes(Some(&kek)).unwrap();
        enc_bytes[20] ^= 0xFF;
        assert!(matches!(
            RTree::from_checkpoint(&enc_bytes, Some(&kek)),
            Err(RTreeCheckpointError::DecryptionFailed(_))
        ));
    }

    /// A checkpoint sealed under one key must not open under another.
    #[test]
    fn spatial_rtree_wrong_kek_rejected() {
        let kek = make_test_kek();
        let other = nodedb_wal::crypto::WalEncryptionKey::from_bytes(&[0x24u8; 32]).unwrap();
        let mut tree = RTree::new();
        tree.insert(make_entry(1, 10.0, 20.0));
        let enc_bytes = tree.checkpoint_to_bytes(Some(&kek)).unwrap();
        assert!(matches!(
            RTree::from_checkpoint(&enc_bytes, Some(&other)),
            Err(RTreeCheckpointError::DecryptionFailed(_))
        ));
    }

    /// The geohash helpers share the `SEGV` envelope and must round-trip.
    #[test]
    fn geohash_payload_roundtrip() {
        let kek = make_test_kek();
        let plaintext = b"geohash cells".as_slice();
        let sealed = encrypt_geohash_payload(&kek, plaintext).unwrap();
        assert_eq!(&sealed[0..4], b"SEGV");
        assert_ne!(sealed.as_slice(), plaintext);
        assert_eq!(decrypt_geohash_payload(&kek, &sealed).unwrap(), plaintext);
    }
}