use std::fs::File;
use std::io::Read;
use crate::snapshot::mapping::{validate_snapshot, SnapshotMappingError};
use crate::snapshot::model::Snapshot;
/// A source from which a persisted [`Snapshot`] can be loaded.
pub trait SnapshotLoader {
    /// Attempts to load a snapshot.
    ///
    /// `Ok(None)` signals that no snapshot is available; the filesystem
    /// implementation in this module returns it for a missing or zero-byte file.
    ///
    /// # Errors
    /// Returns a [`SnapshotLoaderError`] describing why an existing snapshot
    /// could not be read, verified, or decoded.
    fn load(&mut self) -> Result<Option<Snapshot>, SnapshotLoaderError>;
}
/// Failure modes of loading a snapshot.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SnapshotLoaderError {
    /// An underlying I/O operation failed; holds the rendered `std::io::Error` text.
    IoError(String),
    /// The file could not be decoded, e.g. a truncated version/length/CRC-32 frame,
    /// an oversized declared payload length, or an invalid JSON payload.
    DecodeError(String),
    /// The decoded snapshot was rejected by `validate_snapshot`.
    MappingError(SnapshotMappingError),
    /// The file's version frame differs from the version the loader expects.
    IncompatibleVersion { expected: u32, found: u32 },
    /// The stored CRC-32 (`expected`) differs from the CRC-32 computed over the payload (`actual`).
    CrcMismatch { expected: u32, actual: u32 },
    /// The snapshot file does not exist (produced by the `From<std::io::Error>` conversion).
    NotFound,
}
impl std::fmt::Display for SnapshotLoaderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SnapshotLoaderError::IoError(e) => write!(f, "I/O error: {e}"),
SnapshotLoaderError::DecodeError(e) => write!(f, "Decode error: {e}"),
SnapshotLoaderError::MappingError(e) => write!(f, "Mapping error: {e}"),
SnapshotLoaderError::IncompatibleVersion { expected, found } => {
write!(f, "Incompatible snapshot version: expected {expected}, found {found}")
}
SnapshotLoaderError::CrcMismatch { expected, actual } => {
write!(
f,
"Snapshot CRC-32 mismatch: expected {expected:#010x}, actual {actual:#010x}"
)
}
SnapshotLoaderError::NotFound => write!(f, "Snapshot file not found"),
}
}
}
impl std::error::Error for SnapshotLoaderError {}
impl std::convert::From<std::io::Error> for SnapshotLoaderError {
fn from(err: std::io::Error) -> Self {
match err.kind() {
std::io::ErrorKind::NotFound => SnapshotLoaderError::NotFound,
_ => SnapshotLoaderError::IoError(err.to_string()),
}
}
}
/// [`SnapshotLoader`] that reads a single framed snapshot file from the filesystem.
pub struct SnapshotFsLoader {
    /// Location of the snapshot file.
    path: std::path::PathBuf,
    /// Format version the file's version frame must equal.
    version: u32,
}
impl SnapshotFsLoader {
    /// Format version expected by loaders built with [`SnapshotFsLoader::new`].
    const DEFAULT_VERSION: u32 = 4;

    /// Creates a loader for `path` that expects format version 4.
    pub fn new(path: std::path::PathBuf) -> Self {
        Self::with_version(path, Self::DEFAULT_VERSION)
    }

    /// Creates a loader for `path` that expects the given format `version`.
    pub fn with_version(path: std::path::PathBuf, version: u32) -> Self {
        SnapshotFsLoader { path, version }
    }

    /// Opens the snapshot file, mapping a missing file to `Ok(None)`.
    ///
    /// # Errors
    /// Any other open failure becomes [`SnapshotLoaderError::IoError`].
    fn open_file(&self) -> Result<Option<File>, SnapshotLoaderError> {
        match File::open(&self.path) {
            Ok(file) => Ok(Some(file)),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
            Err(e) => Err(SnapshotLoaderError::IoError(e.to_string())),
        }
    }

    /// Fills `buf` completely from `reader`, retrying on `ErrorKind::Interrupted`.
    ///
    /// Returns `Ok(true)` once `buf` is full, and `Ok(false)` if the reader was at
    /// EOF before any byte was read. An empty `buf` always yields `Ok(true)`.
    ///
    /// # Errors
    /// - [`SnapshotLoaderError::DecodeError`] if EOF is reached after a partial read
    ///   (the frame named `frame_name` is truncated).
    /// - [`SnapshotLoaderError::IoError`] for any other read failure.
    fn read_exact_or_eof<R: Read>(
        reader: &mut R,
        buf: &mut [u8],
        frame_name: &str,
    ) -> Result<bool, SnapshotLoaderError> {
        let mut total_read = 0;
        while total_read < buf.len() {
            match reader.read(&mut buf[total_read..]) {
                Ok(0) => {
                    return if total_read == 0 {
                        Ok(false)
                    } else {
                        Err(SnapshotLoaderError::DecodeError(format!(
                            "truncated snapshot {frame_name}: read {total_read} of {} bytes",
                            buf.len()
                        )))
                    };
                }
                Ok(n) => total_read += n,
                Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(SnapshotLoaderError::IoError(e.to_string())),
            }
        }
        Ok(true)
    }

    /// Reads one little-endian `u32` frame; `Ok(None)` means EOF before the frame began.
    fn read_u32_le<R: Read>(
        reader: &mut R,
        frame_name: &str,
    ) -> Result<Option<u32>, SnapshotLoaderError> {
        let mut buf = [0u8; 4];
        let filled = Self::read_exact_or_eof(reader, &mut buf, frame_name)?;
        Ok(filled.then(|| u32::from_le_bytes(buf)))
    }

    /// Reads the version frame; `Ok(None)` means the input is empty.
    fn read_version<R: Read>(reader: &mut R) -> Result<Option<u32>, SnapshotLoaderError> {
        Self::read_u32_le(reader, "version frame")
    }

    /// Reads the payload-length frame that follows the version frame.
    ///
    /// # Errors
    /// [`SnapshotLoaderError::DecodeError`] if the frame is missing or truncated.
    fn read_length<R: Read>(reader: &mut R) -> Result<usize, SnapshotLoaderError> {
        let length = Self::read_u32_le(reader, "length frame")?.ok_or_else(|| {
            SnapshotLoaderError::DecodeError(
                "truncated snapshot: EOF after version frame, expected length frame".to_string(),
            )
        })?;
        // u32 -> usize is lossless on 32-bit and wider targets.
        Ok(length as usize)
    }

    /// Reads the CRC-32 frame that follows the length frame.
    ///
    /// # Errors
    /// [`SnapshotLoaderError::DecodeError`] if the frame is missing or truncated.
    fn read_crc<R: Read>(reader: &mut R) -> Result<u32, SnapshotLoaderError> {
        Self::read_u32_le(reader, "CRC-32 frame")?.ok_or_else(|| {
            SnapshotLoaderError::DecodeError(
                "truncated snapshot: EOF after length frame, expected CRC-32 frame".to_string(),
            )
        })
    }

    /// Reads exactly `length` payload bytes.
    ///
    /// # Errors
    /// - [`SnapshotLoaderError::DecodeError`] if `length` exceeds 256 MiB (checked
    ///   before allocating, so a corrupt length frame cannot force a huge allocation),
    ///   or if the input ends before `length` bytes were read.
    /// - [`SnapshotLoaderError::IoError`] for other read failures.
    fn read_payload<R: Read>(reader: &mut R, length: usize) -> Result<Vec<u8>, SnapshotLoaderError> {
        const MAX_REASONABLE_PAYLOAD: usize = 256 * 1024 * 1024;
        if length > MAX_REASONABLE_PAYLOAD {
            return Err(SnapshotLoaderError::DecodeError(format!(
                "snapshot payload length {length} exceeds maximum {MAX_REASONABLE_PAYLOAD}"
            )));
        }
        let mut payload = vec![0u8; length];
        // Same truncation handling as the header frames: a short file is corruption
        // (DecodeError), not an opaque "failed to fill whole buffer" IoError.
        if !Self::read_exact_or_eof(reader, &mut payload, "payload frame")? {
            return Err(SnapshotLoaderError::DecodeError(format!(
                "truncated snapshot: EOF after CRC-32 frame, expected {length}-byte payload frame"
            )));
        }
        Ok(payload)
    }

    /// Deserializes the JSON payload into a [`Snapshot`].
    ///
    /// # Errors
    /// [`SnapshotLoaderError::DecodeError`] wrapping the `serde_json` error text.
    fn decode_snapshot(payload: &[u8]) -> Result<Snapshot, SnapshotLoaderError> {
        serde_json::from_slice(payload).map_err(|e| {
            SnapshotLoaderError::DecodeError(format!("Failed to decode snapshot: {e}"))
        })
    }
}
impl SnapshotLoader for SnapshotFsLoader {
fn load(&mut self) -> Result<Option<Snapshot>, SnapshotLoaderError> {
let mut file = match self.open_file()? {
Some(f) => f,
None => return Ok(None),
};
let version = match Self::read_version(&mut file)? {
Some(v) => v,
None => return Ok(None),
};
if version != self.version {
tracing::warn!(
expected = self.version,
found = version,
"snapshot version incompatible"
);
return Err(SnapshotLoaderError::IncompatibleVersion {
expected: self.version,
found: version,
});
}
let length = Self::read_length(&mut file)?;
let expected_crc = Self::read_crc(&mut file)?;
let payload = Self::read_payload(&mut file, length)?;
let actual_crc = crc32fast::hash(&payload);
if expected_crc != actual_crc {
tracing::warn!(
expected = format_args!("{expected_crc:#010x}"),
actual = format_args!("{actual_crc:#010x}"),
"snapshot CRC-32 mismatch"
);
return Err(SnapshotLoaderError::CrcMismatch {
expected: expected_crc,
actual: actual_crc,
});
}
let snapshot = match Self::decode_snapshot(&payload) {
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, "snapshot decode failed");
return Err(e);
}
};
if let Err(e) = validate_snapshot(&snapshot) {
tracing::warn!(error = %e, "snapshot validation failed");
return Err(SnapshotLoaderError::MappingError(e));
}
Ok(Some(snapshot))
}
}
#[cfg(test)]
mod tests {
    use std::fs;
    use std::io::Write;

    use super::*;
    use crate::snapshot::mapping::{SnapshotMappingError, SNAPSHOT_SCHEMA_VERSION};
    use crate::snapshot::model::{SnapshotEngineControl, SnapshotMetadata};
    use crate::snapshot::writer::{SnapshotFsWriter, SnapshotWriter};

    static TEST_COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

    /// Returns a fresh, not-yet-existing temp path unique to this process and call.
    fn temp_snapshot_path() -> std::path::PathBuf {
        let dir = std::env::temp_dir();
        let count = TEST_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        // The counter restarts at 0 in every process. Runners that execute each test
        // in its own process (e.g. cargo-nextest), or concurrent runs of this binary,
        // would otherwise share, and delete, each other's files. The pid keeps
        // paths distinct across processes.
        let pid = std::process::id();
        let path = dir.join(format!("actionqueue_snapshot_loader_test_{pid}_{count}.tmp"));
        let _ = fs::remove_file(&path);
        path
    }

    fn open_snapshot_writer(path: std::path::PathBuf) -> SnapshotFsWriter {
        SnapshotFsWriter::new(path).expect("Failed to open snapshot writer for loader test")
    }

    /// Builds a minimal, internally consistent version-4 snapshot.
    fn create_test_snapshot() -> Snapshot {
        Snapshot {
            version: 4,
            timestamp: 1234567890,
            metadata: SnapshotMetadata {
                schema_version: SNAPSHOT_SCHEMA_VERSION,
                wal_sequence: 0,
                task_count: 0,
                run_count: 0,
            },
            tasks: Vec::new(),
            runs: Vec::new(),
            engine: SnapshotEngineControl::default(),
            dependency_declarations: Vec::new(),
            budgets: Vec::new(),
            subscriptions: Vec::new(),
            actors: Vec::new(),
            tenants: Vec::new(),
            role_assignments: Vec::new(),
            capability_grants: Vec::new(),
            ledger_entries: Vec::new(),
        }
    }

    #[test]
    fn test_new_loader_on_nonexistent_file() {
        let path = temp_snapshot_path();
        let _ = fs::remove_file(&path);
        let mut loader = SnapshotFsLoader::new(path.clone());
        let result = loader.load();
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_load_missing_file_via_with_version() {
        let path = temp_snapshot_path();
        let _ = fs::remove_file(&path);
        let mut loader = SnapshotFsLoader::with_version(path.clone(), 3);
        let result = loader.load();
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_load_returns_snapshot() {
        let path = temp_snapshot_path();
        let snapshot = create_test_snapshot();
        let mut writer = open_snapshot_writer(path.clone());
        writer.write(&snapshot).expect("Write should succeed");
        writer.close().expect("Close should succeed");
        let mut loader = SnapshotFsLoader::new(path.clone());
        let loaded = loader.load().expect("Load should succeed");
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().version, 4);
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_load_incompatible_version() {
        let path = temp_snapshot_path();
        {
            let mut file = File::create(&path).expect("Failed to create test file");
            file.write_all(&1u32.to_le_bytes()).unwrap();
            let payload = b"{}";
            file.write_all(&(payload.len() as u32).to_le_bytes()).unwrap();
            file.write_all(&crc32fast::hash(payload).to_le_bytes()).unwrap();
            file.write_all(payload).unwrap();
            file.flush().unwrap();
        }
        let mut loader = SnapshotFsLoader::new(path.clone());
        let result = loader.load();
        assert!(matches!(result, Err(SnapshotLoaderError::IncompatibleVersion { .. })));
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_load_with_custom_version() {
        let path = temp_snapshot_path();
        {
            let mut writer = open_snapshot_writer(path.clone());
            let mut snapshot = create_test_snapshot();
            snapshot.version = 4;
            writer.write(&snapshot).expect("Write should succeed");
            writer.close().expect("Close should succeed");
        }
        let mut loader = SnapshotFsLoader::with_version(path.clone(), 4);
        let loaded = loader.load().expect("Load should succeed");
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().version, 4);
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_load_invalid_json() {
        let path = temp_snapshot_path();
        {
            let mut file = File::create(&path).expect("Failed to create test file");
            file.write_all(&4u32.to_le_bytes()).unwrap();
            let payload = b"{invalid json}";
            file.write_all(&(payload.len() as u32).to_le_bytes()).unwrap();
            file.write_all(&crc32fast::hash(payload).to_le_bytes()).unwrap();
            file.write_all(payload).unwrap();
            file.flush().unwrap();
        }
        let mut loader = SnapshotFsLoader::new(path.clone());
        let result = loader.load();
        assert!(matches!(result, Err(SnapshotLoaderError::DecodeError(_))));
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_load_rejects_snapshot_mapping_violation() {
        let path = temp_snapshot_path();
        let snapshot = Snapshot {
            version: 4,
            timestamp: 1234567890,
            metadata: SnapshotMetadata {
                schema_version: SNAPSHOT_SCHEMA_VERSION,
                wal_sequence: 0,
                task_count: 1,
                run_count: 0,
            },
            tasks: Vec::new(),
            runs: Vec::new(),
            engine: SnapshotEngineControl::default(),
            dependency_declarations: Vec::new(),
            budgets: Vec::new(),
            subscriptions: Vec::new(),
            actors: Vec::new(),
            tenants: Vec::new(),
            role_assignments: Vec::new(),
            capability_grants: Vec::new(),
            ledger_entries: Vec::new(),
        };
        let payload = serde_json::to_vec(&snapshot).expect("snapshot should serialize");
        {
            let mut file = File::create(&path).expect("Failed to create test file");
            file.write_all(&4u32.to_le_bytes()).expect("version frame write should succeed");
            file.write_all(&(payload.len() as u32).to_le_bytes())
                .expect("length frame write should succeed");
            let crc = crc32fast::hash(&payload);
            file.write_all(&crc.to_le_bytes()).expect("crc frame write should succeed");
            file.write_all(&payload).expect("payload frame write should succeed");
            file.flush().expect("flush should succeed");
        }
        let mut loader = SnapshotFsLoader::new(path.clone());
        let result = loader.load();
        assert!(matches!(
            result,
            Err(SnapshotLoaderError::MappingError(SnapshotMappingError::TaskCountMismatch {
                declared: 1,
                actual: 0
            }))
        ));
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_error_display() {
        assert_eq!(SnapshotLoaderError::NotFound.to_string(), "Snapshot file not found");
        assert_eq!(
            SnapshotLoaderError::IoError("test error".to_string()).to_string(),
            "I/O error: test error"
        );
        assert_eq!(
            SnapshotLoaderError::DecodeError("test error".to_string()).to_string(),
            "Decode error: test error"
        );
        assert_eq!(
            SnapshotLoaderError::MappingError(SnapshotMappingError::TaskCountMismatch {
                declared: 1,
                actual: 0
            })
            .to_string(),
            "Mapping error: snapshot task_count mismatch: declared 1, actual 0"
        );
        assert_eq!(
            SnapshotLoaderError::IncompatibleVersion { expected: 4, found: 1 }.to_string(),
            "Incompatible snapshot version: expected 4, found 1"
        );
        assert_eq!(
            SnapshotLoaderError::CrcMismatch { expected: 0xDEADBEEF, actual: 0x12345678 }
                .to_string(),
            "Snapshot CRC-32 mismatch: expected 0xdeadbeef, actual 0x12345678"
        );
    }

    #[test]
    fn test_load_rejects_old_version_3_format() {
        let path = temp_snapshot_path();
        {
            let mut file = File::create(&path).expect("Failed to create test file");
            file.write_all(&3u32.to_le_bytes()).unwrap();
            let payload = b"{}";
            file.write_all(&(payload.len() as u32).to_le_bytes()).unwrap();
            file.write_all(payload).unwrap();
            file.flush().unwrap();
        }
        let mut loader = SnapshotFsLoader::new(path.clone());
        let result = loader.load();
        assert!(matches!(
            result,
            Err(SnapshotLoaderError::IncompatibleVersion { expected: 4, found: 3 })
        ));
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_load_detects_corrupted_payload() {
        let path = temp_snapshot_path();
        let snapshot = create_test_snapshot();
        let mut writer = open_snapshot_writer(path.clone());
        writer.write(&snapshot).expect("Write should succeed");
        writer.close().expect("Close should succeed");
        {
            let mut bytes = fs::read(&path).expect("snapshot file should be readable");
            assert!(bytes.len() > 12, "snapshot file should have header + payload");
            bytes[12] ^= 0xFF;
            fs::write(&path, &bytes).expect("corrupted snapshot should be writable");
        }
        let mut loader = SnapshotFsLoader::new(path.clone());
        let result = loader.load();
        assert!(
            matches!(result, Err(SnapshotLoaderError::CrcMismatch { .. })),
            "corrupted payload should produce CrcMismatch, got: {result:?}"
        );
        let _ = fs::remove_file(path);
    }
}