use std::path::Path;
use talea_core::types::Seq;
use crate::state::BookState;
/// Size in bytes of the little-endian CRC32 header that prefixes every snapshot file.
const CRC_LEN: usize = std::mem::size_of::<u32>();
/// Returns the on-disk file name for the snapshot taken at `seq`.
///
/// The sequence number is zero-padded to 20 characters, so for non-negative
/// values a lexical sort of the names matches their numeric order.
fn snap_name(seq: Seq) -> String {
    format!("snapshot-{seq:020}.snap")
}
fn parse_snap_seq(name: &str) -> Option<Seq> {
let s = name.strip_prefix("snapshot-")?.strip_suffix(".snap")?;
s.parse().ok()
}
/// Opens the directory `dir` and fsyncs it, so that renames and creations of
/// entries inside it are flushed to stable storage.
fn fsync_dir(dir: &Path) -> std::io::Result<()> {
    let handle = std::fs::File::open(dir)?;
    handle.sync_all()
}
/// Durably writes `state` as the snapshot for `last_seq`, then prunes old snapshots.
///
/// File format: a `CRC_LEN`-byte little-endian CRC32 of the payload, followed by the
/// payload itself, which is `state` serialized as JSON.
///
/// Steps:
/// 1. Write the bytes to `snapshot-<seq>.snap.tmp` and fsync that file.
/// 2. Rename it to `snapshot-<seq>.snap`.
/// 3. Fsync `dir` so the rename itself is persisted.
/// 4. Call [`prune_old`].
///
/// If step 1 or 2 fails, the temporary file is removed on a best-effort basis
/// before the error is returned.
///
/// # Errors
/// Returns an error if serialization, any file operation, the directory fsync,
/// or [`prune_old`] fails.
pub async fn write_snapshot(dir: &Path, state: &BookState, last_seq: Seq) -> std::io::Result<()> {
    let payload = serde_json::to_vec(state)
        .map_err(|e| std::io::Error::other(format!("snapshot serialize: {e}")))?;
    let crc = crc32fast::hash(&payload);
    let mut snap_bytes = Vec::with_capacity(CRC_LEN + payload.len());
    snap_bytes.extend_from_slice(&crc.to_le_bytes());
    snap_bytes.extend_from_slice(&payload);

    let final_name = snap_name(last_seq);
    let tmp_path = dir.join(format!("{final_name}.tmp"));
    let final_path = dir.join(&final_name);

    if let Err(e) = stage_snapshot_file(&tmp_path, &final_path, &snap_bytes).await {
        // Best-effort cleanup so a failed attempt does not leave a stray `.tmp` behind;
        // the original error is what the caller needs to see.
        let _ = tokio::fs::remove_file(&tmp_path).await;
        return Err(e);
    }

    // Directory fsync uses blocking std I/O, so keep it off the async executor.
    let dir_owned = dir.to_path_buf();
    tokio::task::spawn_blocking(move || fsync_dir(&dir_owned))
        .await
        .map_err(|e| std::io::Error::other(format!("spawn_blocking join: {e}")))??;
    prune_old(dir).await
}

/// Writes `bytes` to `tmp_path`, fsyncs the file, closes it, and renames it to
/// `final_path`. All I/O goes through `tokio::fs`, so the executor is never blocked.
async fn stage_snapshot_file(
    tmp_path: &Path,
    final_path: &Path,
    bytes: &[u8],
) -> std::io::Result<()> {
    tokio::fs::write(tmp_path, bytes).await?;
    // Re-open with write access: on Windows, flushing file buffers requires a
    // writable handle.
    let tmp = tokio::fs::OpenOptions::new()
        .write(true)
        .open(tmp_path)
        .await
        .map_err(|e| std::io::Error::new(e.kind(), format!("open tmp for sync: {e}")))?;
    tmp.sync_all().await?;
    // Close the handle before renaming; some platforms refuse to rename open files.
    drop(tmp);
    tokio::fs::rename(tmp_path, final_path).await
}
pub async fn load_latest(dir: &Path) -> std::io::Result<Option<(BookState, Seq)>> {
let mut entries: Vec<(Seq, std::path::PathBuf)> = Vec::new();
let mut rd = tokio::fs::read_dir(dir).await?;
while let Some(entry) = rd.next_entry().await? {
let name = entry
.file_name()
.into_string()
.map_err(|_| std::io::Error::other("non-UTF-8 snapshot filename"))?;
if let Some(seq) = parse_snap_seq(&name) {
entries.push((seq, entry.path()));
}
}
entries.sort_by_key(|b| std::cmp::Reverse(b.0));
for (seq, path) in &entries {
match try_load_snapshot(path).await {
Ok(state) => return Ok(Some((state, *seq))),
Err(reason) => {
tracing::warn!(
?path,
%reason,
"skipping invalid snapshot (CRC or parse failure); trying older"
);
}
}
}
Ok(None)
}
/// Reads and validates a single snapshot file.
///
/// The file must start with a `CRC_LEN`-byte little-endian CRC32 of the remaining
/// bytes. Those remaining bytes must then deserialize from JSON into a [`BookState`].
///
/// Any failure (read error, short file, CRC mismatch, JSON error) is returned as a
/// human-readable reason string, so the caller can log it and move on.
async fn try_load_snapshot(path: &Path) -> Result<BookState, String> {
    let bytes = match tokio::fs::read(path).await {
        Ok(contents) => contents,
        Err(e) => return Err(format!("read: {e}")),
    };
    let (crc_bytes, payload) = match bytes.split_first_chunk::<CRC_LEN>() {
        Some(parts) => parts,
        None => return Err(format!("file too short: {} bytes", bytes.len())),
    };
    let stored_crc = u32::from_le_bytes(*crc_bytes);
    let actual_crc = crc32fast::hash(payload);
    if stored_crc == actual_crc {
        serde_json::from_slice(payload).map_err(|e| format!("JSON parse: {e}"))
    } else {
        Err(format!(
            "CRC mismatch: stored={stored_crc:#010x} actual={actual_crc:#010x}"
        ))
    }
}
pub async fn prune_old(dir: &Path) -> std::io::Result<()> {
let mut snaps: Vec<(Seq, std::path::PathBuf)> = Vec::new();
let mut rd = tokio::fs::read_dir(dir).await?;
while let Some(entry) = rd.next_entry().await? {
let name = entry
.file_name()
.into_string()
.map_err(|_| std::io::Error::other("non-UTF-8 snap filename in prune"))?;
if let Some(seq) = parse_snap_seq(&name) {
snaps.push((seq, entry.path()));
}
}
snaps.sort_by_key(|b| std::cmp::Reverse(b.0));
for (_, path) in snaps.into_iter().skip(2) {
if let Err(e) = tokio::fs::remove_file(&path).await {
tracing::warn!(?path, error = %e, "prune_old: failed to remove old snapshot");
}
}
Ok(())
}
// Unit tests for the snapshot write/load round-trip, fallback past corrupt
// snapshots, retention pruning, and handling of stray `.snap.tmp` files.
#[cfg(test)]
mod tests {
use super::*;
use crate::state::{AccountState, BookState, CommittedRec, PostingIndex};
use talea_core::store::AccountCfg;
use talea_core::types::*;
// Builds a small `BookState` so a round-trip has populated fields to compare:
// - one "cash" account with raw_balance 42,
// - one idempotency record under "idem-key",
// - one txid entry,
// - non-zero USD sums,
// - next_seq = 2 and last_at set.
fn make_state() -> BookState {
let mut st = BookState::default();
let id = AccountId {
book: Book("b".into()),
path: "cash".into(),
};
let asset = AssetId::new("USD");
st.accounts.insert(
id.to_key(),
AccountState {
def: AccountDef {
id: id.clone(),
asset: asset.clone(),
kind: AccountKind::Asset,
},
cfg: AccountCfg {
normal_side: Some(Direction::Debit),
min_balance: Some(0),
},
raw_balance: 42,
updated_seq: 1,
postings: PostingIndex::default(),
},
);
let txid = uuid::Uuid::now_v7();
let at = talea_core::store::ledger_now();
st.idem.insert(
"idem-key".into(),
CommittedRec {
txid: TxId(txid),
seq: 1,
at,
},
);
st.txids.insert(txid, (1, (1, 0)));
*st.sums.entry(asset).or_insert((0, 0)) = (42, 0);
st.next_seq = 2;
st.last_at = Some(at);
st
}
// Writes one snapshot, then checks that load_latest returns it with the same
// seq and with matching accounts, idem, txids, and sums.
#[tokio::test]
async fn snapshot_round_trips_book_state() {
let dir = tempfile::tempdir().unwrap();
let st = make_state();
let seq: Seq = 42;
write_snapshot(dir.path(), &st, seq).await.unwrap();
let loaded = load_latest(dir.path()).await.unwrap();
assert!(
loaded.is_some(),
"load_latest must return Some after writing a snapshot"
);
let (got, got_seq) = loaded.unwrap();
assert_eq!(got_seq, seq, "returned seq must match written seq");
assert_eq!(
got.next_seq, st.next_seq,
"next_seq must survive round-trip"
);
assert_eq!(
got.accounts.len(),
st.accounts.len(),
"accounts must survive"
);
let id = AccountId {
book: Book("b".into()),
path: "cash".into(),
};
assert_eq!(
got.accounts[&id.to_key()].raw_balance,
st.accounts[&id.to_key()].raw_balance,
"raw_balance must survive round-trip"
);
assert!(
got.idem.hot.contains_key("idem-key"),
"idem must survive round-trip"
);
let orig_rec = &st.idem.hot["idem-key"];
let got_rec = &got.idem.hot["idem-key"];
assert_eq!(got_rec.seq, orig_rec.seq, "idem record seq must match");
assert_eq!(
got.txids.len(),
st.txids.len(),
"txids must survive round-trip"
);
assert_eq!(got.sums, st.sums, "sums must survive round-trip");
// A freshly deserialized state is expected to report writer_attached = false.
assert!(
!got.writer_attached
.load(std::sync::atomic::Ordering::Acquire),
"deserialized writer_attached must be false (fresh unattached flag)"
);
}
// Flips one byte in the middle of the newest snapshot (seq 20) so it fails
// validation; load_latest must then fall back to the seq-10 snapshot.
#[tokio::test]
async fn corrupt_snapshot_is_skipped_in_favor_of_older() {
let dir = tempfile::tempdir().unwrap();
let st = make_state();
write_snapshot(dir.path(), &st, 10).await.unwrap();
write_snapshot(dir.path(), &st, 20).await.unwrap();
let snap_20_path = dir.path().join(snap_name(20));
let mut bytes = std::fs::read(&snap_20_path).unwrap();
let mid = bytes.len() / 2;
bytes[mid] ^= 0xFF;
std::fs::write(&snap_20_path, &bytes).unwrap();
let loaded = load_latest(dir.path())
.await
.expect("load_latest must return Ok even when newest snapshot is corrupt");
assert!(
loaded.is_some(),
"load_latest must return Some (the seq-10 snapshot)"
);
let (_, got_seq) = loaded.unwrap();
assert_eq!(
got_seq, 10,
"must fall back to seq-10 snapshot, got seq={got_seq}"
);
}
// After three writes, the prune_old call made by write_snapshot must leave
// only the two newest snapshots (seqs 20 and 30) on disk.
#[tokio::test]
async fn only_latest_two_snapshots_are_retained() {
let dir = tempfile::tempdir().unwrap();
let st = make_state();
write_snapshot(dir.path(), &st, 10).await.unwrap();
write_snapshot(dir.path(), &st, 20).await.unwrap();
write_snapshot(dir.path(), &st, 30).await.unwrap();
let mut snaps: Vec<Seq> = std::fs::read_dir(dir.path())
.unwrap()
.filter_map(|e| {
let e = e.ok()?;
let name = e.file_name().into_string().ok()?;
parse_snap_seq(&name)
})
.collect();
snaps.sort();
assert_eq!(
snaps,
vec![20, 30],
"only the two newest snapshots must be retained; found seqs: {snaps:?}"
);
}
// A leftover `.snap.tmp` file with a higher seq must not be treated as a
// snapshot, because parse_snap_seq rejects the `.tmp` suffix.
#[tokio::test]
async fn stray_snap_tmp_is_ignored_by_load_latest() {
let dir = tempfile::tempdir().unwrap();
let st = make_state();
write_snapshot(dir.path(), &st, 5).await.unwrap();
let stray = dir.path().join("snapshot-00000000000000000099.snap.tmp");
std::fs::write(&stray, b"garbage").unwrap();
let loaded = load_latest(dir.path()).await.unwrap();
assert!(
loaded.is_some(),
"load_latest must return Some (seq-5 snapshot)"
);
let (_, got_seq) = loaded.unwrap();
assert_eq!(
got_seq, 5,
"must return the seq-5 snapshot, ignoring the .tmp file"
);
}
// Zeroes every byte after the 4-byte CRC header in both snapshots, so each one
// fails validation. load_latest must return Ok(None) rather than an error.
#[tokio::test]
async fn all_snapshots_corrupt_returns_ok_none() {
let dir = tempfile::tempdir().unwrap();
let st = make_state();
write_snapshot(dir.path(), &st, 10).await.unwrap();
write_snapshot(dir.path(), &st, 20).await.unwrap();
for seq in [10i64, 20] {
let path = dir.path().join(snap_name(seq));
let mut bytes = std::fs::read(&path).unwrap();
for b in bytes.iter_mut().skip(4) {
*b = 0;
}
std::fs::write(&path, &bytes).unwrap();
}
let result = load_latest(dir.path()).await;
assert!(
result.is_ok(),
"load_latest must return Ok even with all snapshots corrupt"
);
assert!(
result.unwrap().is_none(),
"load_latest must return None when all snapshots are corrupt"
);
}
}