use crate::catalog::Catalog;
use crate::checkpoint::writer::CheckpointData;
use crate::commit::tx::{IdempotencyKey, IdempotencyRecord};
use crate::error::AedbError;
use crate::storage::keyspace::Keyspace;
use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// Loads a checkpoint from `path` without an encryption key.
///
/// This is a convenience wrapper around [`load_checkpoint_with_key`] with the key
/// set to `None`. A file that begins with the encrypted-checkpoint magic is
/// therefore rejected with an `AedbError::Validation` error instead of being
/// decrypted.
///
/// # Errors
/// Propagates every error returned by [`load_checkpoint_with_key`].
pub fn load_checkpoint(
    path: &Path,
) -> Result<(Keyspace, Catalog, u64, HashMap<IdempotencyKey, IdempotencyRecord>), AedbError> {
    let no_key: Option<&[u8; 32]> = None;
    load_checkpoint_with_key(path, no_key)
}
/// Loads a checkpoint from `path`, decrypting it with `encryption_key` when the
/// file uses the encrypted layout.
///
/// Two on-disk layouts are accepted:
/// - **encrypted**: the file starts with the `AEDBENC1` magic and is decrypted by
///   `decrypt_checkpoint_payload` using `encryption_key`;
/// - **plaintext**: a compressed payload followed by a 32-byte SHA-256 trailer
///   computed over that payload.
///
/// In both cases the resulting payload is zstd-decompressed and deserialized with
/// `rmp_serde` into [`CheckpointData`].
///
/// Returns the rebuilt [`Keyspace`] and the stored catalog, `seq` value and
/// idempotency records. The keyspace has no value store or KV segment store
/// attached, its inline threshold is `usize::MAX`, and its memory accounting is
/// recomputed via `refresh_mem_bytes`.
///
/// NOTE(review): when `encryption_key` is `Some`, a plaintext checkpoint is still
/// accepted. Its SHA-256 trailer is unkeyed, so it detects corruption but not
/// substitution. Confirm this fallback is intended (e.g. for migrating
/// pre-encryption checkpoints).
///
/// # Errors
/// - Errors from reading the file are propagated with `?`.
/// - `AedbError::Validation` if the file is encrypted but no key was supplied, if
///   the plaintext hash trailer does not match, or if decryption fails.
/// - `AedbError::Decode` if the file is too small for its layout or the payload
///   cannot be deserialized.
/// - `AedbError::Io` if zstd decompression fails.
pub fn load_checkpoint_with_key(
    path: &Path,
    encryption_key: Option<&[u8; 32]>,
) -> Result<
    (
        Keyspace,
        Catalog,
        u64,
        HashMap<IdempotencyKey, IdempotencyRecord>,
    ),
    AedbError,
> {
    let bytes = fs::read(path)?;

    // Only the encrypted path needs an owned buffer; the plaintext path borrows
    // the verified payload straight out of `bytes` instead of copying it.
    let decrypted;
    let compressed: &[u8] = if bytes.starts_with(b"AEDBENC1") {
        let key = encryption_key
            .ok_or_else(|| AedbError::Validation("checkpoint requires key".into()))?;
        decrypted = decrypt_checkpoint_payload(&bytes, key)?;
        &decrypted
    } else {
        verify_plaintext_checkpoint_payload(&bytes)?
    };

    // `decode_all` already reports failures as `io::Error`; wrap it directly so
    // the original `ErrorKind` and source are preserved.
    let decompressed = zstd::stream::decode_all(compressed).map_err(AedbError::Io)?;
    let data: CheckpointData =
        rmp_serde::from_slice(&decompressed).map_err(|e| AedbError::Decode(e.to_string()))?;

    let mut keyspace = Keyspace {
        primary_index_backend: data.keyspace.primary_index_backend,
        value_store: None,
        kv_segment_store: None,
        persistent_value_inline_threshold_bytes: usize::MAX,
        namespaces: data.keyspace.namespaces,
        async_indexes: data.keyspace.async_indexes,
        mem_bytes: 0,
    };
    keyspace.refresh_mem_bytes();
    Ok((keyspace, data.catalog, data.seq, data.idempotency))
}

/// Validates a plaintext checkpoint (`payload || sha256(payload)`) and returns
/// the payload slice if the 32-byte trailer matches.
///
/// # Errors
/// - `AedbError::Decode` if `bytes` is shorter than the 32-byte trailer.
/// - `AedbError::Validation` if the trailer does not equal the SHA-256 of the payload.
fn verify_plaintext_checkpoint_payload(bytes: &[u8]) -> Result<&[u8], AedbError> {
    const TRAILER_SIZE_BYTES: usize = 32;
    if bytes.len() < TRAILER_SIZE_BYTES {
        return Err(AedbError::Decode("checkpoint too small".into()));
    }
    let (payload, trailer_hash) = bytes.split_at(bytes.len() - TRAILER_SIZE_BYTES);
    if Sha256::digest(payload).as_slice() != trailer_hash {
        return Err(AedbError::Validation("checkpoint hash mismatch".into()));
    }
    Ok(payload)
}
/// Decrypts the body of an encrypted checkpoint file.
///
/// Expected layout: an 8-byte magic, a 12-byte AES-GCM nonce, then the
/// ciphertext. The magic bytes are skipped, not re-checked; the caller is
/// responsible for recognising the encrypted format before calling this.
///
/// # Errors
/// - `AedbError::Decode` if `bytes` is shorter than the magic plus nonce.
/// - `AedbError::Validation` if the cipher cannot be built from `key`, or if
///   AES-256-GCM decryption fails (e.g. wrong key or modified ciphertext).
fn decrypt_checkpoint_payload(bytes: &[u8], key: &[u8; 32]) -> Result<Vec<u8>, AedbError> {
    const MAGIC_LEN: usize = 8;
    const NONCE_LEN: usize = 12;
    const HEADER_LEN: usize = MAGIC_LEN + NONCE_LEN;

    if bytes.len() < HEADER_LEN {
        return Err(AedbError::Decode("encrypted checkpoint too small".into()));
    }

    // Everything after the fixed-size header is the sealed payload.
    let (header, sealed) = bytes.split_at(HEADER_LEN);
    let nonce_bytes = &header[MAGIC_LEN..];

    Aes256Gcm::new_from_slice(key)
        .map_err(|e| AedbError::Validation(format!("invalid encryption key: {e}")))?
        .decrypt(Nonce::from_slice(nonce_bytes), sealed)
        .map_err(|e| AedbError::Validation(format!("checkpoint decryption failed: {e}")))
}
#[cfg(test)]
mod tests {
    // Round-trip and rejection tests for the checkpoint loader, exercising both
    // the plaintext (SHA-256 trailer) and encrypted (AEDBENC1) layouts.
    use super::{load_checkpoint, load_checkpoint_with_key};
    use crate::catalog::Catalog;
    use crate::catalog::schema::ColumnDef;
    use crate::catalog::types::{ColumnType, Row, Value};
    use crate::checkpoint::writer::{write_checkpoint, write_checkpoint_with_key};
    use crate::storage::keyspace::Keyspace;
    use tempfile::tempdir;

    /// Writing then loading a multi-project, multi-table checkpoint must yield
    /// an equal keyspace, catalog and sequence value.
    #[test]
    fn checkpoint_roundtrip_preserves_state() {
        let dir = tempdir().expect("temp");
        let mut keyspace = Keyspace::default();
        let mut catalog = Catalog::default();
        for p in ["p1", "p2", "p3"] {
            catalog.create_project(p).expect("project");
            for t in ["t1", "t2"] {
                catalog
                    .create_table(
                        p,
                        "app",
                        t,
                        vec![
                            ColumnDef {
                                name: "id".into(),
                                col_type: ColumnType::Integer,
                                nullable: false,
                            },
                            ColumnDef {
                                name: "name".into(),
                                col_type: ColumnType::Text,
                                nullable: false,
                            },
                        ],
                        vec!["id".into()],
                    )
                    .expect("table");
                for i in 0..1000 {
                    keyspace.upsert_row(
                        p,
                        "app",
                        t,
                        vec![Value::Integer(i)],
                        Row {
                            values: vec![
                                Value::Integer(i),
                                Value::Text(format!("{p}-{t}-{i}").into()),
                            ],
                        },
                        i as u64,
                    );
                }
            }
        }
        let snapshot = keyspace.snapshot();
        let meta = write_checkpoint(&snapshot, &catalog, 42, dir.path()).expect("checkpoint");
        let (loaded_ks, loaded_cat, seq, _) =
            load_checkpoint(&dir.path().join(meta.filename)).expect("load");
        assert_eq!(seq, 42);
        assert_eq!(loaded_ks, keyspace);
        assert_eq!(loaded_cat, catalog);
    }

    /// Corrupting a plaintext checkpoint must be caught by the hash trailer.
    #[test]
    fn checkpoint_loader_rejects_hash_mismatch() {
        let dir = tempdir().expect("temp");
        let keyspace = Keyspace::default();
        let catalog = Catalog::default();
        let meta = write_checkpoint(&keyspace.snapshot(), &catalog, 1, dir.path()).expect("write");
        let path = dir.path().join(meta.filename);
        let mut bytes = std::fs::read(&path).expect("read");
        bytes[0] ^= 0xAA;
        std::fs::write(&path, bytes).expect("write");
        assert!(load_checkpoint(&path).is_err());
    }

    /// A plaintext file shorter than the 32-byte SHA-256 trailer must be
    /// rejected with an error rather than panicking.
    #[test]
    fn checkpoint_loader_rejects_truncated_plaintext() {
        let dir = tempdir().expect("temp");
        let path = dir.path().join("truncated");
        std::fs::write(&path, [0u8; 31]).expect("write");
        assert!(load_checkpoint(&path).is_err());
    }

    /// Encrypted checkpoints load only with the key; the keyless loader rejects them.
    #[test]
    fn checkpoint_encryption_roundtrip() {
        let dir = tempdir().expect("temp");
        let mut keyspace = Keyspace::default();
        let mut catalog = Catalog::default();
        catalog.create_project("p").expect("project");
        keyspace.upsert_row(
            "p",
            "app",
            "t",
            vec![Value::Integer(1)],
            Row {
                values: vec![Value::Integer(1)],
            },
            1,
        );
        let key = [9u8; 32];
        let meta = write_checkpoint_with_key(
            &keyspace.snapshot(),
            &catalog,
            7,
            dir.path(),
            Some(&key),
            Some("k1".into()),
            std::collections::HashMap::new(),
            3,
        )
        .expect("write");
        assert_eq!(meta.key_id.as_deref(), Some("k1"));
        let path = dir.path().join(meta.filename);
        assert!(load_checkpoint(&path).is_err());
        let (ks, _, seq, _) = load_checkpoint_with_key(&path, Some(&key)).expect("load");
        assert_eq!(seq, 7);
        assert_eq!(ks, keyspace);
    }

    /// Opening an encrypted checkpoint with a different key must fail
    /// authenticated decryption.
    #[test]
    fn encrypted_checkpoint_rejects_wrong_key() {
        let dir = tempdir().expect("temp");
        let keyspace = Keyspace::default();
        let catalog = Catalog::default();
        let key = [1u8; 32];
        let meta = write_checkpoint_with_key(
            &keyspace.snapshot(),
            &catalog,
            5,
            dir.path(),
            Some(&key),
            Some("k1".into()),
            std::collections::HashMap::new(),
            3,
        )
        .expect("write");
        let path = dir.path().join(meta.filename);
        let wrong_key = [2u8; 32];
        assert!(load_checkpoint_with_key(&path, Some(&wrong_key)).is_err());
    }

    /// Encrypted files that are shorter than magic + nonce, or that carry a
    /// header but no ciphertext, must be rejected with an error.
    #[test]
    fn encrypted_loader_rejects_truncated_payloads() {
        let dir = tempdir().expect("temp");
        let path = dir.path().join("truncated_enc");
        let key = [5u8; 32];

        std::fs::write(&path, b"AEDBENC1short").expect("write");
        assert!(load_checkpoint_with_key(&path, Some(&key)).is_err());

        let mut header_only = b"AEDBENC1".to_vec();
        header_only.extend_from_slice(&[0u8; 12]);
        std::fs::write(&path, &header_only).expect("write");
        assert!(load_checkpoint_with_key(&path, Some(&key)).is_err());
    }

    /// Two encrypted writes of identical state must use different nonces
    /// (bytes 8..20, immediately after the magic).
    #[test]
    fn encrypted_checkpoints_use_distinct_random_nonces() {
        let dir = tempdir().expect("temp");
        let keyspace = Keyspace::default();
        let catalog = Catalog::default();
        let key = [7u8; 32];
        let first = write_checkpoint_with_key(
            &keyspace.snapshot(),
            &catalog,
            42,
            dir.path(),
            Some(&key),
            Some("k1".into()),
            std::collections::HashMap::new(),
            3,
        )
        .expect("write first");
        let first_bytes = std::fs::read(dir.path().join(&first.filename)).expect("read first");
        let second = write_checkpoint_with_key(
            &keyspace.snapshot(),
            &catalog,
            42,
            dir.path(),
            Some(&key),
            Some("k1".into()),
            std::collections::HashMap::new(),
            3,
        )
        .expect("write second");
        let second_bytes = std::fs::read(dir.path().join(second.filename)).expect("read second");
        assert!(first_bytes.starts_with(b"AEDBENC1"));
        assert!(second_bytes.starts_with(b"AEDBENC1"));
        assert_ne!(
            &first_bytes[8..20],
            &second_bytes[8..20],
            "encrypted checkpoints must not reuse nonces"
        );
    }
}