use crate::error::{PersistenceError, PersistenceResult};
use crate::formats::{CHECKPOINT_MAGIC, FORMAT_VERSION};
use crate::storage::{self, Directory};
use std::io::{Read, Write};
use std::sync::Arc;
/// Upper bound, in bytes, on a serialized checkpoint payload (256 MiB).
///
/// `CheckpointFile::write_postcard` refuses to write a larger payload, and
/// `CheckpointFile::read_postcard` rejects a header that declares a larger
/// one before allocating the read buffer.
pub const MAX_CHECKPOINT_PAYLOAD_BYTES: usize = 256 * 1024 * 1024;
/// Fixed-size header that precedes the payload in a checkpoint file.
///
/// `CheckpointHeader::write` emits the fields in declaration order, with every
/// integer in little-endian byte order; `CheckpointHeader::SIZE` is the total
/// encoded length.
#[doc(hidden)]
#[derive(Debug, Clone, Copy)]
pub struct CheckpointHeader {
/// File-type marker; `CheckpointHeader::read` rejects anything other than
/// `CHECKPOINT_MAGIC`.
magic: [u8; 4],
/// Format version; `CheckpointHeader::read` rejects anything other than
/// `FORMAT_VERSION`.
version: u32,
/// Identifier supplied by the writer and handed back unchanged by
/// `CheckpointFile::read_postcard`.
/// NOTE(review): presumably the id of the last entry folded into this
/// checkpoint — confirm against callers.
pub last_applied_id: u64,
/// Number of payload bytes that follow the header.
pub payload_len: u64,
/// CRC32 (`crc32fast::hash`) of the payload bytes.
pub checksum: u32,
}
impl CheckpointHeader {
    /// Encoded size of the header in bytes: magic (4) + version (4) +
    /// last applied id (8) + payload length (8) + checksum (4).
    pub const SIZE: usize = 4 + 4 + 8 + 8 + 4;

    /// Builds a header stamped with `CHECKPOINT_MAGIC` and `FORMAT_VERSION`.
    ///
    /// `payload_len` and `checksum` are stored as given; the caller is
    /// responsible for making them describe the payload that will follow.
    pub(crate) fn new(last_applied_id: u64, payload_len: u64, checksum: u32) -> Self {
        Self {
            magic: CHECKPOINT_MAGIC,
            version: FORMAT_VERSION,
            last_applied_id,
            payload_len,
            checksum,
        }
    }

    /// Writes the header to `w`: the raw magic bytes followed by `version`,
    /// `last_applied_id`, `payload_len` and `checksum`, each little-endian.
    ///
    /// The writer may be unsized (for example `dyn Write`), mirroring the
    /// `?Sized` bound already accepted by [`Self::read`].
    ///
    /// # Errors
    ///
    /// Propagates the first I/O error reported by `w`, converted into the
    /// crate's error type.
    pub fn write<W: Write + ?Sized>(&self, w: &mut W) -> PersistenceResult<()> {
        w.write_all(&self.magic)?;
        w.write_all(&self.version.to_le_bytes())?;
        w.write_all(&self.last_applied_id.to_le_bytes())?;
        w.write_all(&self.payload_len.to_le_bytes())?;
        w.write_all(&self.checksum.to_le_bytes())?;
        Ok(())
    }

    /// Reads and validates a header from `r`, consuming the header bytes.
    ///
    /// Validation is incremental: the magic is checked as soon as its four
    /// bytes are read, and the version as soon as its four bytes are read, so
    /// a foreign file is rejected without consuming the remaining fields.
    ///
    /// # Errors
    ///
    /// * `PersistenceError::Format` if the magic is not `CHECKPOINT_MAGIC` or
    ///   the version is not `FORMAT_VERSION`.
    /// * Any I/O error from `r` (including a short read), converted into the
    ///   crate's error type.
    pub fn read<R: Read + ?Sized>(r: &mut R) -> PersistenceResult<Self> {
        let mut magic = [0u8; 4];
        r.read_exact(&mut magic)?;
        if magic != CHECKPOINT_MAGIC {
            return Err(PersistenceError::Format("invalid checkpoint magic".into()));
        }

        // Scratch buffers reused for every 32-bit / 64-bit field.
        let mut buf4 = [0u8; 4];
        let mut buf8 = [0u8; 8];

        r.read_exact(&mut buf4)?;
        let version = u32::from_le_bytes(buf4);
        if version != FORMAT_VERSION {
            return Err(PersistenceError::Format(
                "checkpoint version mismatch".into(),
            ));
        }

        r.read_exact(&mut buf8)?;
        let last_applied_id = u64::from_le_bytes(buf8);

        r.read_exact(&mut buf8)?;
        let payload_len = u64::from_le_bytes(buf8);

        r.read_exact(&mut buf4)?;
        let checksum = u32::from_le_bytes(buf4);

        Ok(Self {
            magic,
            version,
            last_applied_id,
            payload_len,
            checksum,
        })
    }
}
/// Reads and writes checkpoint files — a `CheckpointHeader` followed by a
/// postcard-encoded payload — through a `Directory` backend.
pub struct CheckpointFile {
/// Storage backend that every `path` argument is resolved against.
dir: Arc<dyn Directory>,
}
impl CheckpointFile {
    /// Creates a checkpoint accessor on top of the given storage backend.
    pub fn new(dir: impl Into<Arc<dyn Directory>>) -> Self {
        let dir: Arc<dyn Directory> = dir.into();
        Self { dir }
    }

    /// Serializes `value` with postcard and stores it at `path` as
    /// header + payload via `Directory::atomic_write`.
    ///
    /// The header records `last_applied_id`, the payload length and the
    /// CRC32 of the payload.
    ///
    /// # Errors
    ///
    /// * `PersistenceError::Encode` if postcard serialization fails.
    /// * `PersistenceError::Format` if the payload exceeds
    ///   `MAX_CHECKPOINT_PAYLOAD_BYTES`.
    /// * Whatever error the backend reports from the write.
    pub fn write_postcard<T: serde::Serialize>(
        &self,
        path: &str,
        last_applied_id: u64,
        value: &T,
    ) -> PersistenceResult<()> {
        let body = match postcard::to_allocvec(value) {
            Ok(bytes) => bytes,
            Err(e) => return Err(PersistenceError::Encode(e.to_string())),
        };

        let body_len = body.len();
        if body_len > MAX_CHECKPOINT_PAYLOAD_BYTES {
            return Err(PersistenceError::Format(format!(
                "checkpoint payload too large: {} bytes (max {})",
                body_len, MAX_CHECKPOINT_PAYLOAD_BYTES
            )));
        }

        let header =
            CheckpointHeader::new(last_applied_id, body_len as u64, crc32fast::hash(&body));

        // Assemble the complete file image first so the backend receives it
        // in a single call.
        let mut image = Vec::with_capacity(CheckpointHeader::SIZE + body_len);
        header.write(&mut image)?;
        image.extend_from_slice(&body);

        self.dir.atomic_write(path, &image)?;
        Ok(())
    }

    /// Like [`Self::write_postcard`], followed by `storage::sync_file` on the
    /// written file and `storage::sync_parent_dir` on its parent directory.
    ///
    /// The backend is checked up front: when `Directory::file_path(path)`
    /// yields nothing, the call fails before anything is written.
    ///
    /// # Errors
    ///
    /// * `PersistenceError::NotSupported` if the backend has no file path for
    ///   `path`.
    /// * Any error from [`Self::write_postcard`] or from the two sync calls.
    pub fn write_postcard_durable<T: serde::Serialize>(
        &self,
        path: &str,
        last_applied_id: u64,
        value: &T,
    ) -> PersistenceResult<()> {
        let unsupported = self.dir.file_path(path).is_none();
        if unsupported {
            return Err(PersistenceError::NotSupported(
                "write_postcard_durable requires Directory::file_path()".into(),
            ));
        }

        self.write_postcard(path, last_applied_id, value)?;

        let backend: &dyn Directory = &*self.dir;
        storage::sync_file(backend, path)?;
        storage::sync_parent_dir(backend, path)?;
        Ok(())
    }

    /// Loads the checkpoint at `path` and returns `(last_applied_id, value)`.
    ///
    /// The header is validated, exactly `payload_len` bytes are read, and
    /// their CRC32 is compared with the header before postcard decoding is
    /// attempted. Bytes after the payload are not inspected.
    ///
    /// # Errors
    ///
    /// * Header errors from `CheckpointHeader::read`.
    /// * `PersistenceError::Format` if `payload_len` does not fit in `usize`
    ///   or exceeds `MAX_CHECKPOINT_PAYLOAD_BYTES`.
    /// * `PersistenceError::CrcMismatch` if the payload checksum differs.
    /// * `PersistenceError::Decode` if postcard cannot decode the payload.
    /// * Any error from opening or reading the file.
    pub fn read_postcard<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
    ) -> PersistenceResult<(u64, T)> {
        let mut file = self.dir.open_file(path)?;
        let header = CheckpointHeader::read(&mut *file)?;

        let body_len = match usize::try_from(header.payload_len) {
            Ok(n) => n,
            Err(_) => return Err(PersistenceError::Format("payload_len overflow".into())),
        };
        // Reject oversized declarations before allocating the buffer.
        if body_len > MAX_CHECKPOINT_PAYLOAD_BYTES {
            return Err(PersistenceError::Format(format!(
                "checkpoint payload too large: {} bytes (max {})",
                body_len, MAX_CHECKPOINT_PAYLOAD_BYTES
            )));
        }

        let mut body = vec![0u8; body_len];
        file.read_exact(&mut body)?;

        let actual = crc32fast::hash(&body);
        if header.checksum != actual {
            return Err(PersistenceError::CrcMismatch {
                expected: header.checksum,
                actual,
            });
        }

        let decoded: Result<T, _> = postcard::from_bytes(&body);
        match decoded {
            Ok(value) => Ok((header.last_applied_id, value)),
            Err(e) => Err(PersistenceError::Decode(e.to_string())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::{FsDirectory, MemoryDirectory};

    /// A value written to an in-memory backend is read back unchanged,
    /// together with the id stored in the header.
    #[test]
    fn checkpoint_roundtrip_postcard() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct S {
            n: u64,
            city: String,
        }

        let written = S {
            n: 7,
            city: "東京".into(),
        };
        let backend: Arc<dyn Directory> = Arc::new(MemoryDirectory::new());
        let checkpoint = CheckpointFile::new(Arc::clone(&backend));

        checkpoint.write_postcard("c.bin", 42, &written).unwrap();

        let (last_id, restored) = checkpoint.read_postcard::<S>("c.bin").unwrap();
        assert_eq!(last_id, 42);
        assert_eq!(restored, written);
    }

    /// The durable write is refused on the in-memory backend, and nothing is
    /// left behind at the target path.
    #[test]
    fn durable_checkpoint_requires_fs_backend() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct S {
            n: u64,
        }

        let backend: Arc<dyn Directory> = Arc::new(MemoryDirectory::new());
        let checkpoint = CheckpointFile::new(Arc::clone(&backend));

        let outcome = checkpoint.write_postcard_durable("c.bin", 1, &S { n: 7 });

        assert!(matches!(outcome, Err(PersistenceError::NotSupported(_))));
        assert!(!backend.exists("c.bin"));
    }

    /// The durable write succeeds on a filesystem backend (including a nested
    /// path) and the value reads back unchanged.
    #[test]
    fn durable_checkpoint_roundtrip_fs() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct S {
            city: String,
        }

        let scratch = tempfile::tempdir().unwrap();
        let backend: Arc<dyn Directory> = Arc::new(FsDirectory::new(scratch.path()).unwrap());
        let checkpoint = CheckpointFile::new(Arc::clone(&backend));
        let target = "checkpoints/c1.chk";
        let written = S {
            city: "東京".into(),
        };

        checkpoint
            .write_postcard_durable(target, 7, &written)
            .unwrap();

        let (last_id, restored) = checkpoint.read_postcard::<S>(target).unwrap();
        assert_eq!(last_id, 7);
        assert_eq!(restored, written);
    }
}