#![allow(clippy::cast_possible_truncation)]
use std::fs::File;
use std::io::{self, Read, Write};
use std::path::Path;
/// Four-byte signature written at the start of every snapshot and checked when loading.
pub const SNAPSHOT_MAGIC: &[u8; 4] = b"VAMM";
/// Format version byte written right after [`SNAPSHOT_MAGIC`]; loading rejects any other value.
pub const SNAPSHOT_VERSION: u8 = 1;
/// Computes the CRC-32 of `data` using the reflected polynomial `0xEDB8_8320`,
/// an all-ones initial register and a final bitwise inversion.
///
/// This is the common CRC-32 variant used by zlib/PNG; for example, the
/// check value for `b"123456789"` is `0xCBF4_3926`.
#[inline]
fn crc32_hash(data: &[u8]) -> u32 {
    /// Builds the 256-entry byte lookup table at compile time.
    const fn build_table() -> [u32; 256] {
        let mut table = [0u32; 256];
        let mut n = 0usize;
        while n < 256 {
            let mut value = n as u32;
            let mut bit = 0;
            while bit < 8 {
                // All ones when the low bit is set, zero otherwise: XOR the
                // polynomial in only when a 1 is shifted out.
                let mask = (value & 1).wrapping_neg();
                value = (value >> 1) ^ (0xEDB8_8320 & mask);
                bit += 1;
            }
            table[n] = value;
            n += 1;
        }
        table
    }

    const TABLE: [u32; 256] = build_table();

    let register = data.iter().fold(u32::MAX, |acc, &byte| {
        let slot = ((acc ^ u32::from(byte)) & 0xFF) as usize;
        TABLE[slot] ^ (acc >> 8)
    });
    !register
}
/// In-memory contents of a snapshot: four independent byte sections.
///
/// The snapshot encoder stores each section verbatim as a length-prefixed
/// blob and the decoder restores it unchanged; this module treats the bytes as
/// opaque.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryState {
    /// Bytes of the semantic section (first section in the encoding).
    pub semantic: Vec<u8>,
    /// Bytes of the episodic section (second section in the encoding).
    pub episodic: Vec<u8>,
    /// Bytes of the procedural section (third section in the encoding).
    pub procedural: Vec<u8>,
    /// Bytes of the TTL section (fourth and last section in the encoding).
    pub ttl: Vec<u8>,
}
/// Summary information about an encoded snapshot.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SnapshotMetadata {
    /// Snapshot format version byte.
    pub version: u8,
    /// Presumably the total encoded length in bytes — confirm against the producer.
    pub total_size: usize,
    /// Presumably the CRC-32 trailer value — confirm against the producer.
    pub checksum: u32,
}
/// Errors produced while encoding, decoding, or persisting snapshots.
#[derive(Debug)]
pub enum SnapshotError {
    /// An underlying I/O operation failed.
    Io(io::Error),
    /// The leading bytes did not match the snapshot magic signature.
    InvalidMagic,
    /// The version byte was not the supported one; carries the byte that was found.
    UnsupportedVersion(u8),
    /// The stored checksum did not match the one recomputed from the data.
    ChecksumMismatch {
        /// Checksum read from the snapshot.
        expected: u32,
        /// Checksum computed from the snapshot bytes.
        actual: u32,
    },
    /// Catch-all carrying a human-readable description. Used for malformed
    /// snapshot data and also when no snapshot files are found.
    CorruptedData(String),
}

impl std::fmt::Display for SnapshotError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(err) => write!(f, "IO error: {err}"),
            Self::InvalidMagic => f.write_str("Invalid snapshot magic bytes"),
            Self::UnsupportedVersion(found) => {
                write!(f, "Unsupported snapshot version: {found}")
            }
            Self::ChecksumMismatch { expected, actual } => write!(
                f,
                "Checksum mismatch: expected {expected:08x}, got {actual:08x}"
            ),
            Self::CorruptedData(detail) => write!(f, "Corrupted data: {detail}"),
        }
    }
}

impl std::error::Error for SnapshotError {
    /// Only the I/O variant wraps another error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(inner) => Some(inner),
            Self::InvalidMagic
            | Self::UnsupportedVersion(_)
            | Self::ChecksumMismatch { .. }
            | Self::CorruptedData(_) => None,
        }
    }
}

impl From<io::Error> for SnapshotError {
    /// Wraps an I/O error so `?` can be used on filesystem calls.
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}
/// Serializes `state` into the snapshot wire format.
///
/// Layout, all integers little-endian:
/// `SNAPSHOT_MAGIC` (4 bytes) · `SNAPSHOT_VERSION` (1 byte) · for each of the
/// semantic, episodic, procedural and TTL sections in that order: a `u64`
/// length followed by that many bytes · a CRC-32 (`u32`) computed over
/// everything before it.
#[must_use]
pub fn create_snapshot(state: &MemoryState) -> Vec<u8> {
    let sections: [&[u8]; 4] = [
        &state.semantic,
        &state.episodic,
        &state.procedural,
        &state.ttl,
    ];

    // Header, one 8-byte length prefix plus body per section, and the 4-byte CRC.
    let body_size: usize = sections.iter().map(|section| 8 + section.len()).sum();
    let mut out = Vec::with_capacity(4 + 1 + body_size + 4);

    out.extend_from_slice(SNAPSHOT_MAGIC);
    out.push(SNAPSHOT_VERSION);
    for &section in &sections {
        out.extend_from_slice(&(section.len() as u64).to_le_bytes());
        out.extend_from_slice(section);
    }

    let checksum = crc32_hash(&out);
    out.extend_from_slice(&checksum.to_le_bytes());
    out
}
pub fn load_snapshot(data: &[u8]) -> Result<MemoryState, SnapshotError> {
validate_snapshot_header(data)?;
let mut offset = 5; let payload_end = data.len() - 4;
let semantic = read_section(data, &mut offset, payload_end, "Semantic")?;
let episodic = read_section(data, &mut offset, payload_end, "Episodic")?;
let procedural = read_section(data, &mut offset, payload_end, "Procedural")?;
let ttl = read_section(data, &mut offset, payload_end, "TTL")?;
Ok(MemoryState {
semantic,
episodic,
procedural,
ttl,
})
}
/// Checks the fixed header and the CRC-32 trailer of an encoded snapshot.
///
/// Checks run in this order: minimum length, magic bytes, version byte, then
/// the checksum over everything except the final four bytes.
///
/// # Errors
///
/// - [`SnapshotError::CorruptedData`] if `data` is shorter than the smallest
///   possible snapshot (header, four empty sections and trailer).
/// - [`SnapshotError::InvalidMagic`] if the first four bytes are not `SNAPSHOT_MAGIC`.
/// - [`SnapshotError::UnsupportedVersion`] if the version byte differs from `SNAPSHOT_VERSION`.
/// - [`SnapshotError::ChecksumMismatch`] if the stored CRC (`expected`) differs
///   from the recomputed one (`actual`).
fn validate_snapshot_header(data: &[u8]) -> Result<(), SnapshotError> {
    // magic + version + four u64 length prefixes + CRC-32 trailer
    const MIN_SIZE: usize = 4 + 1 + 4 * 8 + 4;

    if data.len() < MIN_SIZE {
        return Err(SnapshotError::CorruptedData(
            "Snapshot too small".to_string(),
        ));
    }

    let (body, trailer) = data.split_at(data.len() - 4);

    if !body.starts_with(SNAPSHOT_MAGIC) {
        return Err(SnapshotError::InvalidMagic);
    }

    match body[4] {
        SNAPSHOT_VERSION => {}
        other => return Err(SnapshotError::UnsupportedVersion(other)),
    }

    let trailer_bytes: [u8; 4] = trailer
        .try_into()
        .map_err(|_| SnapshotError::CorruptedData("Invalid CRC bytes".to_string()))?;
    let stored_crc = u32::from_le_bytes(trailer_bytes);
    let computed_crc = crc32_hash(body);

    if stored_crc == computed_crc {
        Ok(())
    } else {
        Err(SnapshotError::ChecksumMismatch {
            expected: stored_crc,
            actual: computed_crc,
        })
    }
}
/// Reads one length-prefixed section starting at `*offset`. On success it
/// advances `*offset` to the first byte after the section.
///
/// A section is a little-endian `u64` byte count followed by that many bytes.
/// Both the length prefix and the section body must end at or before
/// `payload_end`, the index where the CRC-32 trailer begins. `label` names the
/// section in error messages.
///
/// # Errors
///
/// Returns [`SnapshotError::CorruptedData`] in these cases; `*offset` is left
/// unchanged on error:
/// - fewer than eight payload bytes remain for the length prefix;
/// - the declared length does not fit in `usize`;
/// - the declared length extends past `payload_end`.
fn read_section(
    data: &[u8],
    offset: &mut usize,
    payload_end: usize,
    label: &str,
) -> Result<Vec<u8>, SnapshotError> {
    let truncated = || SnapshotError::CorruptedData(format!("{label} data truncated"));

    // Confine every read to the payload, for two reasons: a length prefix can
    // never be assembled from CRC trailer bytes, and an out-of-range `*offset`
    // becomes an error instead of a slice-index panic.
    let payload = data.get(..payload_end).ok_or_else(truncated)?;
    let remaining = payload.get(*offset..).ok_or_else(truncated)?;
    let declared_len = read_u64(remaining)?;

    // `read_u64` succeeded, so at least 8 bytes remain and this cannot overflow.
    let start = *offset + 8;

    // The declared length is untrusted input, since CRC-32 only detects
    // accidental corruption. Reject values that would be truncated on 32-bit
    // targets or wrap the end index, instead of panicking on the slice below.
    let end = usize::try_from(declared_len)
        .ok()
        .and_then(|len| start.checked_add(len))
        .filter(|&end| end <= payload_end)
        .ok_or_else(truncated)?;

    *offset = end;
    Ok(payload[start..end].to_vec())
}
pub fn save_snapshot_to_file<P: AsRef<Path>>(
path: P,
state: &MemoryState,
) -> Result<(), SnapshotError> {
let path = path.as_ref();
let snapshot_data = create_snapshot(state);
let temp_path = path.with_extension("tmp");
let mut file = File::create(&temp_path)?;
file.write_all(&snapshot_data)?;
file.sync_all()?;
drop(file);
std::fs::rename(&temp_path, path)?;
Ok(())
}
/// Reads the entire file at `path` and decodes it with `load_snapshot`.
///
/// # Errors
///
/// Returns [`SnapshotError::Io`] if the file cannot be opened or read, and
/// otherwise any error produced by `load_snapshot`.
pub fn load_snapshot_from_file<P: AsRef<Path>>(path: P) -> Result<MemoryState, SnapshotError> {
    let mut bytes = Vec::new();
    File::open(path)?.read_to_end(&mut bytes)?;
    load_snapshot(&bytes)
}
/// Decodes a little-endian `u64` from the first eight bytes of `data`.
///
/// Any bytes after the first eight are ignored.
///
/// # Errors
///
/// Returns [`SnapshotError::CorruptedData`] if `data` is shorter than eight bytes.
fn read_u64(data: &[u8]) -> Result<u64, SnapshotError> {
    // A successful `get(..8)` always yields exactly eight bytes, so the array
    // conversion cannot fail. The original separate "Invalid u64 bytes" error
    // path was unreachable.
    data.get(..8)
        .and_then(|bytes| <[u8; 8]>::try_from(bytes).ok())
        .map(u64::from_le_bytes)
        .ok_or_else(|| SnapshotError::CorruptedData("Not enough bytes for u64".to_string()))
}
/// Manages a directory of numbered snapshot files.
///
/// Files are named `snapshot_<version>.vamm`, with `<version>` zero-padded to
/// at least eight digits (e.g. `snapshot_00000042.vamm`). Versions start at 1,
/// and each new snapshot gets one more than the highest version on disk.
#[derive(Debug, Clone)]
pub struct SnapshotManager {
    base_path: std::path::PathBuf,
    max_snapshots: usize,
}

impl SnapshotManager {
    /// Creates a manager for the directory `base_path` that keeps the newest
    /// `max_snapshots` files. A value of 0 is treated as 1, so the snapshot
    /// just written is never pruned.
    ///
    /// No filesystem access happens here. The directory is created by
    /// [`SnapshotManager::create_versioned_snapshot`].
    #[must_use]
    pub fn new<P: AsRef<Path>>(base_path: P, max_snapshots: usize) -> Self {
        Self {
            base_path: base_path.as_ref().to_path_buf(),
            max_snapshots,
        }
    }

    /// Writes `state` as a new snapshot, prunes old snapshots, and returns the
    /// version number assigned to the new file.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the directory cannot be created or listed;
    /// - the version counter would overflow;
    /// - the snapshot cannot be written.
    ///
    /// Failures to delete individual old files during pruning are ignored.
    pub fn create_versioned_snapshot(&self, state: &MemoryState) -> Result<u64, SnapshotError> {
        std::fs::create_dir_all(&self.base_path)?;
        let version = self.next_version()?;
        save_snapshot_to_file(self.snapshot_path(version), state)?;
        self.cleanup_old_snapshots()?;
        Ok(version)
    }

    /// Loads the snapshot with the highest version and returns it with its version.
    ///
    /// # Errors
    ///
    /// Returns [`SnapshotError::CorruptedData`] with "No snapshots found" if
    /// the directory holds no snapshot files. Otherwise returns any error from
    /// listing the directory or loading the file.
    pub fn load_latest(&self) -> Result<(u64, MemoryState), SnapshotError> {
        let version = self
            .latest_version()?
            .ok_or_else(|| SnapshotError::CorruptedData("No snapshots found".to_string()))?;
        let state = self.load_version(version)?;
        Ok((version, state))
    }

    /// Loads the snapshot file for `version`.
    ///
    /// # Errors
    ///
    /// Returns [`SnapshotError::Io`] if the file is missing or unreadable, or
    /// any decoding error for its contents.
    pub fn load_version(&self, version: u64) -> Result<MemoryState, SnapshotError> {
        load_snapshot_from_file(self.snapshot_path(version))
    }

    /// Returns the versions of all snapshot files in ascending order.
    ///
    /// Only canonically named files are counted: the names this manager itself
    /// writes. Names such as `snapshot_5.vamm` or `snapshot_+0000005.vamm`
    /// parse as a number but do not map back to a loadable path, so they are
    /// ignored. A missing directory yields an empty list.
    ///
    /// # Errors
    ///
    /// Returns [`SnapshotError::Io`] if the directory cannot be read.
    pub fn list_versions(&self) -> Result<Vec<u64>, SnapshotError> {
        if !self.base_path.exists() {
            return Ok(Vec::new());
        }
        let mut versions = Vec::new();
        for entry in std::fs::read_dir(&self.base_path)? {
            let name = entry?.file_name();
            if let Some(version) = name.to_str().and_then(Self::parse_snapshot_file_name) {
                versions.push(version);
            }
        }
        versions.sort_unstable();
        Ok(versions)
    }

    /// File name used for `version`, e.g. `snapshot_00000042.vamm`.
    fn snapshot_file_name(version: u64) -> String {
        format!("snapshot_{version:08}.vamm")
    }

    /// Full path of the file for `version` inside `base_path`.
    fn snapshot_path(&self, version: u64) -> std::path::PathBuf {
        self.base_path.join(Self::snapshot_file_name(version))
    }

    /// Extracts the version from `name`, accepting only the exact spelling
    /// produced by `snapshot_file_name` so every listed version round-trips to
    /// an existing path.
    fn parse_snapshot_file_name(name: &str) -> Option<u64> {
        let digits = name.strip_prefix("snapshot_")?.strip_suffix(".vamm")?;
        let version = digits.parse::<u64>().ok()?;
        // `u64::from_str` also accepts a leading `+` and extra leading zeros.
        if Self::snapshot_file_name(version) == name {
            Some(version)
        } else {
            None
        }
    }

    /// Highest version currently on disk, if any.
    fn latest_version(&self) -> Result<Option<u64>, SnapshotError> {
        Ok(self.list_versions()?.into_iter().max())
    }

    /// Version to assign to the next snapshot: 1 for an empty directory,
    /// otherwise one more than the highest existing version.
    fn next_version(&self) -> Result<u64, SnapshotError> {
        match self.latest_version()? {
            None => Ok(1),
            Some(latest) => latest.checked_add(1).ok_or_else(|| {
                SnapshotError::CorruptedData(format!(
                    "Snapshot version {latest} cannot be incremented"
                ))
            }),
        }
    }

    /// Deletes the oldest snapshot files until at most `max_snapshots` remain,
    /// always keeping at least one.
    fn cleanup_old_snapshots(&self) -> Result<(), SnapshotError> {
        // With `max_snapshots == 0`, the file just written by
        // `create_versioned_snapshot` would otherwise be deleted immediately,
        // making its returned version unloadable.
        let keep = self.max_snapshots.max(1);
        let versions = self.list_versions()?;
        let excess = versions.len().saturating_sub(keep);
        for version in versions.into_iter().take(excess) {
            // Best effort: a file that fails to delete stays listed and is
            // attempted again on the next cleanup.
            let _ = std::fs::remove_file(self.snapshot_path(version));
        }
        Ok(())
    }
}