use crate::cache::key::Fingerprint;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::fs;
use std::io::{self, Write as _};
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Version stamped into every metadata file; `lookup` skips entries whose
/// stored version differs.
pub(crate) const SCHEMA_VERSION: u32 = 1;
/// Default combined size budget for all stored entries: 1 GiB.
pub(crate) const DEFAULT_SIZE_BUDGET_BYTES: u64 = 1 << 30;
/// Default entry time-to-live, in seconds: 30 days.
pub(crate) const DEFAULT_TTL_SECS: u64 = 30 * 24 * 60 * 60;
/// Errors returned by [`WarmStartStore`] operations.
#[derive(Debug, thiserror::Error)]
pub enum StoreError {
/// A filesystem operation (create, read, write, rename, list) failed.
#[error("io: {0}")]
Io(#[from] io::Error),
/// Entry metadata could not be serialized to, or parsed from, JSON.
#[error("json: {0}")]
Json(#[from] serde_json::Error),
}
/// A cache hit returned by [`WarmStartStore::lookup`]; the payload has already
/// been checked against the SHA-256 checksum stored in its metadata.
#[derive(Debug, Clone)]
pub struct CachedEntry {
/// Raw payload bytes exactly as passed to `save`/`save_overwrite`.
pub payload: Vec<u8>,
/// Objective recorded at save time, if any; `lookup` prefers lower values.
pub objective: Option<f64>,
/// Iteration number recorded at save time, if any.
pub iteration: Option<u64>,
/// Whole seconds since the Unix epoch at which the entry was written.
pub written_unix_secs: u64,
/// Kind recorded at save time; on (near-)equal objectives `Final` is preferred.
pub kind: EntryKind,
}
/// Kind of a stored entry.
///
/// Serialized by variant name in the metadata JSON, so renaming a variant makes
/// existing metadata unparseable (such files are deleted as corrupt).
/// When two entries for the same key have objectives within `1e-12`, `lookup`
/// prefers `Final` over `Checkpoint`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EntryKind {
// NOTE(review): presumably an intermediate snapshot saved mid-run — confirm with callers.
Checkpoint,
// NOTE(review): presumably the result of a completed run — confirm with callers.
Final,
}
/// Contents of `<run_id>.json`, the metadata file stored next to each
/// `<run_id>.bin` payload. Field order here is the order written to disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OnDiskMeta {
/// Must equal [`SCHEMA_VERSION`] for `lookup` to consider the entry.
schema_version: u32,
/// Whole seconds since the Unix epoch at write time.
written_unix_secs: u64,
/// Sub-second nanoseconds at write time. `serde(default)` lets metadata that
/// lacks this field still parse, with value 0.
#[serde(default)]
written_nanos: u32,
/// Optional objective; lower ranks better in `lookup`.
objective: Option<f64>,
/// Optional iteration number, passed through to [`CachedEntry`].
iteration: Option<u64>,
/// Entry kind, used as a tie-break in `lookup`.
kind: EntryKind,
/// Lowercase hex SHA-256 of the payload file, verified on lookup.
checksum_hex: String,
/// Payload length in bytes at write time.
payload_bytes: u64,
}
/// Tuning knobs for a [`WarmStartStore`]; both are enforced by `evict_overflow`.
#[derive(Debug, Clone)]
pub struct StoreOptions {
/// Upper bound on the combined size of all entries' metadata and payload
/// files; when exceeded, the oldest-written entries are deleted first.
pub size_budget_bytes: u64,
/// Entries whose recorded write time is at least this old are deleted.
/// A zero duration disables expiry.
pub ttl: Duration,
}
impl Default for StoreOptions {
fn default() -> Self {
Self {
size_budget_bytes: DEFAULT_SIZE_BUDGET_BYTES,
ttl: Duration::from_secs(DEFAULT_TTL_SECS),
}
}
}
/// File-backed warm-start cache.
///
/// On-disk layout: `<root>/<key hex>/<run_id>.json` holds the metadata for the
/// payload stored in the sibling `<run_id>.bin`. Clones share the same `root`
/// directory and options.
#[derive(Debug, Clone)]
pub struct WarmStartStore {
root: PathBuf,
opts: StoreOptions,
}
impl WarmStartStore {
    /// Minimum age before another process's staging (`.tmp.`) file may be swept
    /// by [`WarmStartStore::evict_overflow`].
    ///
    /// Staging files are named `<run_id>.{bin,json}.tmp.<pid>.<nanos>`, where
    /// `<nanos>` is the writer's start time. A younger file may belong to a live
    /// writer that is about to rename it into place. Deleting it would make that
    /// writer's save fail.
    const TMP_SWEEP_GRACE: Duration = Duration::from_secs(60 * 60);

    /// Opens a store rooted at `root`, creating the directory if needed.
    ///
    /// # Errors
    /// [`StoreError::Io`] if `root` cannot be created.
    pub fn open(root: PathBuf, opts: StoreOptions) -> Result<Self, StoreError> {
        fs::create_dir_all(&root)?;
        Ok(Self { root, opts })
    }

    /// Opens the store at `<dirs::cache_dir()>/gam/warm/v1` with
    /// [`StoreOptions::default`].
    ///
    /// Returns `None` if `dirs::cache_dir()` yields nothing or the directory
    /// cannot be created.
    pub fn open_default() -> Option<Self> {
        let base = dirs::cache_dir()?;
        let root = base.join("gam").join("warm").join("v1");
        Self::open(root, StoreOptions::default()).ok()
    }

    /// Root directory of the store.
    pub fn root(&self) -> &Path {
        &self.root
    }

    /// Options the store was opened with.
    pub fn options(&self) -> &StoreOptions {
        &self.opts
    }

    /// Directory holding every entry for `key`: `<root>/<key hex>`.
    fn key_dir(&self, key: &Fingerprint) -> PathBuf {
        self.root.join(key.to_hex())
    }

    /// Returns the best usable cached entry for `key`, or `Ok(None)`.
    ///
    /// Candidates are metadata files with the current [`SCHEMA_VERSION`] whose
    /// payload file exists. They are ranked by `entry_better`: an objective
    /// beats none, lower objectives win, near-equal objectives prefer `Final`,
    /// and then the most recent write wins. If the top candidate's payload
    /// cannot be read or fails its checksum, the next-best candidate is tried.
    /// One damaged entry therefore does not hide older valid ones.
    ///
    /// Side effects (best-effort): unparseable metadata is removed together with
    /// its payload, metadata whose payload is missing is removed, and both files
    /// of an entry with a checksum mismatch are removed.
    ///
    /// # Errors
    /// [`StoreError::Io`] if the key directory exists but cannot be listed.
    pub fn lookup(&self, key: &Fingerprint) -> Result<Option<CachedEntry>, StoreError> {
        let dir = self.key_dir(key);
        // Treat NotFound as "no entries" instead of pre-checking `exists()`. A
        // concurrent `evict_overflow` may remove an empty key directory between
        // such a check and the listing.
        let listing = match fs::read_dir(&dir) {
            Ok(rd) => rd,
            Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(None),
            Err(e) => return Err(e.into()),
        };
        let mut candidates = Self::collect_candidates(listing);
        while let Some(idx) = Self::best_candidate_index(&candidates) {
            let (meta, meta_path) = candidates.swap_remove(idx);
            if let Some(entry) = Self::load_verified(meta, &meta_path) {
                return Ok(Some(entry));
            }
        }
        Ok(None)
    }

    /// Reads every current-schema metadata file in a key directory listing and
    /// pairs it with its path.
    ///
    /// Skips staging files and foreign schema versions. Removes (best-effort)
    /// unparseable metadata and its payload, and metadata whose payload is gone.
    fn collect_candidates(listing: fs::ReadDir) -> Vec<(OnDiskMeta, PathBuf)> {
        let mut candidates = Vec::new();
        for entry in listing.flatten() {
            let path = entry.path();
            if path.extension().and_then(|s| s.to_str()) != Some("json") {
                continue;
            }
            let is_staging = path
                .file_name()
                .and_then(|s| s.to_str())
                .map_or(false, |n| n.contains(".tmp."));
            if is_staging {
                continue;
            }
            let bin = path.with_extension("bin");
            let meta = match read_meta(&path) {
                Ok(m) => m,
                Err(_) => {
                    // Without its metadata the payload is unreachable, and
                    // `evict_overflow` only walks `.json` files, so it would
                    // never reclaim the `.bin`. Drop both, as `evict_overflow`
                    // does for unparseable metadata.
                    let _ = fs::remove_file(&path);
                    let _ = fs::remove_file(&bin);
                    continue;
                }
            };
            if meta.schema_version != SCHEMA_VERSION {
                continue;
            }
            if !bin.exists() {
                let _ = fs::remove_file(&path);
                continue;
            }
            candidates.push((meta, path));
        }
        candidates
    }

    /// Index of the highest-ranked candidate according to `entry_better`, or
    /// `None` when `candidates` is empty.
    fn best_candidate_index(candidates: &[(OnDiskMeta, PathBuf)]) -> Option<usize> {
        let mut best: Option<usize> = None;
        for (i, (meta, _)) in candidates.iter().enumerate() {
            best = match best {
                Some(b) if !entry_better(meta, &candidates[b].0) => Some(b),
                _ => Some(i),
            };
        }
        best
    }

    /// Reads the payload that belongs to `meta_path` and verifies its checksum.
    ///
    /// Returns `None` if the payload cannot be read. Also returns `None` on a
    /// checksum mismatch, in which case both files are removed best-effort.
    fn load_verified(meta: OnDiskMeta, meta_path: &Path) -> Option<CachedEntry> {
        let bin_path = meta_path.with_extension("bin");
        let payload = fs::read(&bin_path).ok()?;
        if checksum_hex(&payload) != meta.checksum_hex {
            // NOTE(review): `save_overwrite` renames the payload before the
            // metadata. A lookup racing an overwrite of the same run id can
            // briefly see the new payload with the old metadata and discard it.
            let _ = fs::remove_file(meta_path);
            let _ = fs::remove_file(&bin_path);
            return None;
        }
        Some(CachedEntry {
            payload,
            objective: meta.objective,
            iteration: meta.iteration,
            written_unix_secs: meta.written_unix_secs,
            kind: meta.kind,
        })
    }

    /// Saves a new entry for `key` under a freshly generated run id and returns
    /// that id. The id can be passed to [`Self::save_overwrite`] to replace the
    /// entry later.
    ///
    /// # Errors
    /// As for [`Self::save_overwrite`].
    pub fn save(
        &self,
        key: &Fingerprint,
        payload: &[u8],
        objective: Option<f64>,
        iteration: Option<u64>,
        kind: EntryKind,
    ) -> Result<String, StoreError> {
        let run_id = fresh_run_id();
        self.save_overwrite(key, &run_id, payload, objective, iteration, kind)?;
        Ok(run_id)
    }

    /// Writes the entry for `key` under `run_id`, replacing any existing entry
    /// with that run id, then runs [`Self::evict_overflow`] best-effort.
    ///
    /// Payload and metadata are first written to staging files
    /// `<run_id>.{bin,json}.tmp.<pid>.<nanos>` and then renamed into place,
    /// payload first. If any step fails, the staging files are removed.
    ///
    /// # Errors
    /// - [`StoreError::Io`] with kind `InvalidInput` in either of these cases:
    ///   - `run_id` is empty or contains a path separator;
    ///   - `run_id` would yield a file name containing `.tmp.`, which `lookup`
    ///     would ignore and eviction would never reclaim.
    /// - [`StoreError::Io`] / [`StoreError::Json`] if writing or renaming fails.
    pub fn save_overwrite(
        &self,
        key: &Fingerprint,
        run_id: &str,
        payload: &[u8],
        objective: Option<f64>,
        iteration: Option<u64>,
        kind: EntryKind,
    ) -> Result<(), StoreError> {
        Self::validate_run_id(run_id)?;
        let dir = self.key_dir(key);
        fs::create_dir_all(&dir)?;
        let pid = std::process::id();
        let nonce = nanos_now();
        let bin_tmp = dir.join(format!("{run_id}.bin.tmp.{pid}.{nonce}"));
        let meta_tmp = dir.join(format!("{run_id}.json.tmp.{pid}.{nonce}"));
        let bin_final = dir.join(format!("{run_id}.bin"));
        let meta_final = dir.join(format!("{run_id}.json"));
        if let Err(e) =
            Self::write_staging_files(&bin_tmp, &meta_tmp, payload, objective, iteration, kind)
        {
            // These names carry this process's pid, so this process's
            // `evict_overflow` never sweeps them. Clean up here.
            let _ = fs::remove_file(&bin_tmp);
            let _ = fs::remove_file(&meta_tmp);
            return Err(e);
        }
        // Payload first: once metadata for a fresh run id is visible, its
        // payload is already in place.
        if let Err(e) = fs::rename(&bin_tmp, &bin_final) {
            let _ = fs::remove_file(&bin_tmp);
            let _ = fs::remove_file(&meta_tmp);
            return Err(StoreError::Io(e));
        }
        if let Err(e) = fs::rename(&meta_tmp, &meta_final) {
            let _ = fs::remove_file(&bin_final);
            let _ = fs::remove_file(&meta_tmp);
            return Err(StoreError::Io(e));
        }
        // Housekeeping is best-effort and must not fail an otherwise good save.
        let _ = self.evict_overflow();
        Ok(())
    }

    /// Rejects run ids that would escape the key directory or yield entries the
    /// store cannot see.
    ///
    /// An empty id gives `.json`, which has no extension per `Path::extension`.
    /// A `.tmp.` in the final name makes the entry look like a staging file.
    fn validate_run_id(run_id: &str) -> Result<(), StoreError> {
        let valid = !run_id.is_empty()
            && !run_id.chars().any(std::path::is_separator)
            // Final names are `<run_id>.bin` / `<run_id>.json`; checking
            // `<run_id>.` also catches ids ending in `.tmp`.
            && !format!("{run_id}.").contains(".tmp.");
        if valid {
            Ok(())
        } else {
            Err(StoreError::Io(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid warm-start run id: {run_id:?}"),
            )))
        }
    }

    /// Writes the payload to `bin_tmp` and its freshly built metadata to
    /// `meta_tmp`.
    fn write_staging_files(
        bin_tmp: &Path,
        meta_tmp: &Path,
        payload: &[u8],
        objective: Option<f64>,
        iteration: Option<u64>,
        kind: EntryKind,
    ) -> Result<(), StoreError> {
        Self::write_synced(bin_tmp, payload)?;
        let checksum = checksum_hex(payload);
        let (secs, subsec_nanos) = unix_now_parts();
        let meta = OnDiskMeta {
            schema_version: SCHEMA_VERSION,
            written_unix_secs: secs,
            written_nanos: subsec_nanos,
            objective,
            iteration,
            kind,
            checksum_hex: checksum,
            payload_bytes: payload.len() as u64,
        };
        let json = serde_json::to_vec_pretty(&meta)?;
        Self::write_synced(meta_tmp, &json)?;
        Ok(())
    }

    /// Creates (truncating) `path`, writes `bytes`, and attempts an fsync.
    fn write_synced(path: &Path, bytes: &[u8]) -> io::Result<()> {
        let mut f = fs::File::create(path)?;
        f.write_all(bytes)?;
        // Durability is best-effort: a failed fsync does not fail the save.
        let _ = f.sync_all();
        Ok(())
    }

    /// Housekeeping pass over the whole store. Individual filesystem failures
    /// are skipped.
    ///
    /// For every key directory this pass:
    /// - sweeps staging (`.tmp.`) files that other processes abandoned
    ///   (see [`Self::TMP_SWEEP_GRACE`]);
    /// - removes metadata whose payload is missing, and both files of entries
    ///   whose metadata fails to parse;
    /// - removes entries whose recorded write time is at least `ttl` old
    ///   (a zero `ttl` disables this);
    /// - removes the key directory if it is left empty.
    ///
    /// Then, if the remaining entries' metadata and payload files together
    /// exceed `size_budget_bytes`, entries are deleted oldest write time first
    /// until the total fits. `lookup` does not update write times, so the
    /// order is oldest-written rather than least-recently-used.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`; an unreadable root counts as empty.
    pub fn evict_overflow(&self) -> Result<(), StoreError> {
        let read_dir = match fs::read_dir(&self.root) {
            Ok(rd) => rd,
            Err(_) => return Ok(()),
        };
        // (metadata path, payload path, combined bytes, write time in ns since epoch)
        let mut all: Vec<(PathBuf, PathBuf, u64, u128)> = Vec::new();
        let now_nanos = nanos_now();
        let ttl_nanos = self.opts.ttl.as_nanos();
        for key_dir_entry in read_dir {
            let key_dir = match key_dir_entry {
                Ok(e) => e.path(),
                Err(_) => continue,
            };
            if !key_dir.is_dir() {
                continue;
            }
            let inner = match fs::read_dir(&key_dir) {
                Ok(rd) => rd,
                Err(_) => continue,
            };
            for f in inner {
                let p = match f {
                    Ok(e) => e.path(),
                    Err(_) => continue,
                };
                let name = match p.file_name().and_then(|s| s.to_str()) {
                    Some(s) => s.to_string(),
                    None => continue,
                };
                if name.contains(".tmp.") {
                    if Self::is_abandoned_staging_file(&name, now_nanos) {
                        let _ = fs::remove_file(&p);
                    }
                    continue;
                }
                if p.extension().and_then(|s| s.to_str()) != Some("json") {
                    continue;
                }
                let meta_md = match fs::metadata(&p) {
                    Ok(m) => m,
                    Err(_) => continue,
                };
                let bin = p.with_extension("bin");
                let bin_md = match fs::metadata(&bin) {
                    Ok(m) => m,
                    Err(_) => {
                        let _ = fs::remove_file(&p);
                        continue;
                    }
                };
                let meta = match read_meta(&p) {
                    Ok(m) => m,
                    Err(_) => {
                        let _ = fs::remove_file(&p);
                        let _ = fs::remove_file(&bin);
                        continue;
                    }
                };
                let write_nanos = (meta.written_unix_secs as u128) * 1_000_000_000u128
                    + meta.written_nanos as u128;
                if ttl_nanos > 0 && now_nanos.saturating_sub(write_nanos) >= ttl_nanos {
                    let _ = fs::remove_file(&p);
                    let _ = fs::remove_file(&bin);
                    continue;
                }
                let total_bytes = meta_md.len() + bin_md.len();
                all.push((p, bin, total_bytes, write_nanos));
            }
            if fs::read_dir(&key_dir)
                .map(|mut it| it.next().is_none())
                .unwrap_or(false)
            {
                let _ = fs::remove_dir(&key_dir);
            }
        }
        let total: u64 = all.iter().map(|e| e.2).sum();
        if total <= self.opts.size_budget_bytes {
            return Ok(());
        }
        all.sort_by_key(|e| e.3);
        let mut remaining = total;
        for (meta, bin, bytes, _) in all.into_iter() {
            if remaining <= self.opts.size_budget_bytes {
                break;
            }
            let _ = fs::remove_file(&meta);
            let _ = fs::remove_file(&bin);
            remaining = remaining.saturating_sub(bytes);
        }
        Ok(())
    }

    /// Decides whether a staging file named `….tmp.<pid>.<nanos>` may be deleted.
    ///
    /// Names whose pid does not parse, and files written by this process, are
    /// never swept. Other files are swept once the start time in their last
    /// dot-separated segment is at least [`Self::TMP_SWEEP_GRACE`] old. This
    /// spares concurrent writers in other processes. A file whose start time
    /// does not parse is swept.
    fn is_abandoned_staging_file(name: &str, now_nanos: u128) -> bool {
        match parse_tmp_pid(name) {
            Some(pid) if pid != std::process::id() => {}
            _ => return false,
        }
        match name.rsplit('.').next().and_then(|s| s.parse::<u128>().ok()) {
            Some(started) => {
                now_nanos.saturating_sub(started) >= Self::TMP_SWEEP_GRACE.as_nanos()
            }
            None => true,
        }
    }
}
/// Extracts the writer's process id from a staging file name of the form
/// `<stem>.tmp.<pid>.<nonce>`.
///
/// Returns `None` if the name has no `.tmp.` marker, or if the text between the
/// first marker and the next `.` is not a valid `u32`.
fn parse_tmp_pid(name: &str) -> Option<u32> {
    let (_, after_marker) = name.split_once(".tmp.")?;
    let pid_field = match after_marker.find('.') {
        Some(dot) => &after_marker[..dot],
        None => after_marker,
    };
    pid_field.parse().ok()
}
fn read_meta(path: &Path) -> Result<OnDiskMeta, StoreError> {
let bytes = fs::read(path)?;
let parsed: OnDiskMeta = serde_json::from_slice(&bytes)?;
Ok(parsed)
}
/// Returns `true` if `candidate` should be preferred over `current` when
/// choosing which cached entry `lookup` returns for a key.
///
/// Ranking, highest priority first:
/// 1. An entry with an objective beats one without.
/// 2. The lower objective wins; objectives within `1e-12` count as tied.
/// 3. On a tie, `Final` beats `Checkpoint`.
/// 4. Otherwise the more recent write wins. Whole seconds and the stored
///    sub-second nanoseconds are both compared, so entries written within
///    the same second are still ordered deterministically.
fn entry_better(candidate: &OnDiskMeta, current: &OnDiskMeta) -> bool {
    // Tuple comparison: seconds first, then nanoseconds within the second.
    let newer = (candidate.written_unix_secs, candidate.written_nanos)
        > (current.written_unix_secs, current.written_nanos);
    match (candidate.objective, current.objective) {
        (Some(c), Some(d)) => {
            if (c - d).abs() < 1e-12 {
                match (candidate.kind, current.kind) {
                    (EntryKind::Final, EntryKind::Checkpoint) => true,
                    (EntryKind::Checkpoint, EntryKind::Final) => false,
                    _ => newer,
                }
            } else {
                c < d
            }
        }
        (Some(_), None) => true,
        (None, Some(_)) => false,
        (None, None) => newer,
    }
}
/// Lowercase hexadecimal SHA-256 digest of `payload` (64 characters).
fn checksum_hex(payload: &[u8]) -> String {
    use std::fmt::Write;
    let digest = Sha256::digest(payload);
    digest
        .iter()
        .fold(String::with_capacity(digest.len() * 2), |mut hex, byte| {
            // Writing into a String cannot fail; the Result is discarded.
            let _ = write!(hex, "{byte:02x}");
            hex
        })
}
/// Current wall-clock time as `(whole seconds, sub-second nanoseconds)` since
/// the Unix epoch, or `(0, 0)` if the system clock reads before the epoch.
fn unix_now_parts() -> (u64, u32) {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => (since_epoch.as_secs(), since_epoch.subsec_nanos()),
        Err(_) => (0, 0),
    }
}
/// Nanoseconds since the Unix epoch according to the system clock, or `0` if
/// the clock reads before the epoch.
fn nanos_now() -> u128 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_nanos()
}
/// Builds a run id `r<pid>-<nanos>` from the current process id and
/// `nanos_now()`, both rendered in lowercase hexadecimal.
fn fresh_run_id() -> String {
    format!("r{:x}-{:x}", std::process::id(), nanos_now())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::key::Fingerprinter;
    use std::thread;

    /// Store in a fresh temp dir with a 1 MiB budget and a 60 s TTL.
    fn temp_store() -> (tempfile::TempDir, WarmStartStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = WarmStartStore::open(
            dir.path().to_path_buf(),
            StoreOptions {
                size_budget_bytes: 1024 * 1024,
                ttl: Duration::from_secs(60),
            },
        )
        .unwrap();
        (dir, store)
    }

    fn key_for(s: &str) -> Fingerprint {
        let mut fp = Fingerprinter::new();
        fp.absorb_str(b"test", s);
        fp.finalize()
    }

    /// Number of files in `dir` whose extension is exactly `ext`.
    fn count_with_ext(dir: &Path, ext: &str) -> usize {
        fs::read_dir(dir)
            .unwrap()
            .filter(|e| {
                e.as_ref().unwrap().path().extension().and_then(|s| s.to_str()) == Some(ext)
            })
            .count()
    }

    #[test]
    fn roundtrip_save_then_lookup() {
        let (_d, store) = temp_store();
        let key = key_for("roundtrip");
        let _id = store
            .save(
                &key,
                b"hello-warm",
                Some(1.5),
                Some(7),
                EntryKind::Checkpoint,
            )
            .unwrap();
        let got = store.lookup(&key).unwrap().unwrap();
        assert_eq!(got.payload, b"hello-warm");
        assert_eq!(got.objective, Some(1.5));
        assert_eq!(got.iteration, Some(7));
        assert_eq!(got.kind, EntryKind::Checkpoint);
    }

    #[test]
    fn lookup_picks_lowest_objective() {
        let (_d, store) = temp_store();
        let key = key_for("multi");
        store
            .save(&key, b"worse", Some(3.0), Some(1), EntryKind::Checkpoint)
            .unwrap();
        store
            .save(&key, b"better", Some(1.0), Some(2), EntryKind::Checkpoint)
            .unwrap();
        store
            .save(&key, b"mid", Some(2.0), Some(3), EntryKind::Checkpoint)
            .unwrap();
        let got = store.lookup(&key).unwrap().unwrap();
        assert_eq!(got.payload, b"better");
        assert_eq!(got.objective, Some(1.0));
    }

    #[test]
    fn tiebreak_final_beats_checkpoint() {
        let (_d, store) = temp_store();
        let key = key_for("tie");
        store
            .save(&key, b"ckpt", Some(1.0), None, EntryKind::Checkpoint)
            .unwrap();
        store
            .save(&key, b"final", Some(1.0), None, EntryKind::Final)
            .unwrap();
        let got = store.lookup(&key).unwrap().unwrap();
        assert_eq!(got.payload, b"final");
        assert_eq!(got.kind, EntryKind::Final);
    }

    #[test]
    fn tiebreak_latest_mtime_when_no_objective() {
        let (_d, store) = temp_store();
        let key = key_for("latest");
        store
            .save(&key, b"first", None, None, EntryKind::Checkpoint)
            .unwrap();
        // Cross a whole-second boundary so the two writes differ in
        // `written_unix_secs`, independent of any sub-second tie-breaking.
        thread::sleep(Duration::from_millis(1100));
        store
            .save(&key, b"second", None, None, EntryKind::Checkpoint)
            .unwrap();
        let got = store.lookup(&key).unwrap().unwrap();
        assert_eq!(got.payload, b"second");
    }

    #[test]
    fn corrupt_payload_is_cleaned_up() {
        let (_d, store) = temp_store();
        let key = key_for("corrupt");
        store
            .save(&key, b"original", Some(0.0), None, EntryKind::Checkpoint)
            .unwrap();
        let dir = store.key_dir(&key);
        for entry in fs::read_dir(&dir).unwrap() {
            let p = entry.unwrap().path();
            if p.extension().and_then(|s| s.to_str()) == Some("bin") {
                fs::write(&p, b"tampered!").unwrap();
            }
        }
        let got = store.lookup(&key).unwrap();
        assert!(got.is_none(), "tampered entry must be rejected");
        let remaining: Vec<_> = fs::read_dir(&dir).unwrap().collect();
        assert!(remaining.is_empty(), "corrupt entry should be removed");
    }

    #[test]
    fn corrupt_meta_json_is_cleaned_up() {
        let (_d, store) = temp_store();
        let key = key_for("badjson");
        store
            .save(&key, b"x", None, None, EntryKind::Checkpoint)
            .unwrap();
        let dir = store.key_dir(&key);
        for entry in fs::read_dir(&dir).unwrap() {
            let p = entry.unwrap().path();
            if p.extension().and_then(|s| s.to_str()) == Some("json") {
                fs::write(&p, b"{not valid json").unwrap();
            }
        }
        let got = store.lookup(&key).unwrap();
        assert!(got.is_none());
        assert_eq!(
            count_with_ext(&dir, "json"),
            0,
            "unparseable metadata should be removed"
        );
    }

    #[test]
    fn schema_mismatch_is_ignored() {
        let (_d, store) = temp_store();
        let key = key_for("schema");
        store
            .save(&key, b"x", None, None, EntryKind::Checkpoint)
            .unwrap();
        let dir = store.key_dir(&key);
        for entry in fs::read_dir(&dir).unwrap() {
            let p = entry.unwrap().path();
            if p.extension().and_then(|s| s.to_str()) == Some("json") {
                let raw = fs::read(&p).unwrap();
                let mut parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
                parsed["schema_version"] = serde_json::json!(SCHEMA_VERSION + 99);
                fs::write(&p, serde_json::to_vec_pretty(&parsed).unwrap()).unwrap();
            }
        }
        assert!(store.lookup(&key).unwrap().is_none());
    }

    #[test]
    fn missing_bin_treated_as_missing() {
        let (_d, store) = temp_store();
        let key = key_for("nobin");
        store
            .save(&key, b"x", None, None, EntryKind::Checkpoint)
            .unwrap();
        let dir = store.key_dir(&key);
        for entry in fs::read_dir(&dir).unwrap() {
            let p = entry.unwrap().path();
            if p.extension().and_then(|s| s.to_str()) == Some("bin") {
                fs::remove_file(&p).unwrap();
            }
        }
        assert!(store.lookup(&key).unwrap().is_none());
        assert_eq!(
            count_with_ext(&dir, "json"),
            0,
            "metadata without a payload should be removed"
        );
    }

    #[test]
    fn missing_key_returns_none() {
        let (_d, store) = temp_store();
        let key = key_for("absent");
        assert!(store.lookup(&key).unwrap().is_none());
    }

    #[test]
    fn lru_eviction_under_size_budget() {
        let dir = tempfile::tempdir().unwrap();
        let store = WarmStartStore::open(
            dir.path().to_path_buf(),
            StoreOptions {
                size_budget_bytes: 4 * 1024,
                ttl: Duration::from_secs(3600),
            },
        )
        .unwrap();
        let mut keys = Vec::new();
        for i in 0..20 {
            let mut fp = Fingerprinter::new();
            fp.absorb_u64(b"i", i);
            let key = fp.finalize();
            keys.push(key);
            let payload = vec![0u8; 256];
            store
                .save(&key, &payload, Some(i as f64), None, EntryKind::Checkpoint)
                .unwrap();
        }
        let mut total = 0u64;
        for kd in fs::read_dir(store.root()).unwrap() {
            let kd = kd.unwrap().path();
            if kd.is_dir() {
                for f in fs::read_dir(&kd).unwrap() {
                    total += fs::metadata(f.unwrap().path()).unwrap().len();
                }
            }
        }
        assert!(
            total <= 8 * 1024,
            "eviction failed to bound size (got {total})"
        );
        assert!(store.lookup(&keys[0]).unwrap().is_none());
        assert!(store.lookup(keys.last().unwrap()).unwrap().is_some());
    }

    #[test]
    fn ttl_drops_old_entries() {
        let dir = tempfile::tempdir().unwrap();
        let store = WarmStartStore::open(
            dir.path().to_path_buf(),
            StoreOptions {
                size_budget_bytes: 1024 * 1024,
                ttl: Duration::from_secs(1),
            },
        )
        .unwrap();
        let key = key_for("ttl");
        store
            .save(&key, b"x", None, None, EntryKind::Checkpoint)
            .unwrap();
        assert!(store.lookup(&key).unwrap().is_some());
        thread::sleep(Duration::from_millis(1500));
        let other = key_for("ttl-other");
        store
            .save(&other, b"y", None, None, EntryKind::Checkpoint)
            .unwrap();
        assert!(store.lookup(&key).unwrap().is_none());
        assert!(store.lookup(&other).unwrap().is_some());
    }

    #[test]
    fn orphan_temp_files_from_dead_processes_are_swept() {
        let (_d, store) = temp_store();
        let key = key_for("tmp");
        let dir = store.key_dir(&key);
        fs::create_dir_all(&dir).unwrap();
        // Trailing nonce 0 marks both files as written at the epoch, i.e. long abandoned.
        let orphan_other = dir.join("r0-0.json.tmp.1.0");
        let mine = dir.join(format!("r0-0.bin.tmp.{}.0", std::process::id()));
        fs::write(&orphan_other, b"orphan").unwrap();
        fs::write(&mine, b"mine").unwrap();
        store.evict_overflow().unwrap();
        assert!(!orphan_other.exists(), "other-PID tmp file should be swept");
        assert!(mine.exists(), "same-PID tmp file must be left alone");
    }

    #[test]
    fn tmp_filenames_without_pid_are_skipped() {
        let (_d, store) = temp_store();
        let key = key_for("malformed");
        let dir = store.key_dir(&key);
        fs::create_dir_all(&dir).unwrap();
        let weird = dir.join("garbage.tmp.notapid.suffix");
        fs::write(&weird, b"x").unwrap();
        store.evict_overflow().unwrap();
        assert!(weird.exists());
    }

    #[test]
    fn save_overwrite_keeps_single_entry() {
        let (_d, store) = temp_store();
        let key = key_for("overwrite");
        let id = store
            .save(&key, b"v1", Some(2.0), Some(1), EntryKind::Checkpoint)
            .unwrap();
        store
            .save_overwrite(&key, &id, b"v2", Some(1.0), Some(2), EntryKind::Checkpoint)
            .unwrap();
        let dir = store.key_dir(&key);
        let files: Vec<_> = fs::read_dir(&dir).unwrap().collect();
        assert_eq!(files.len(), 2, "overwrite should not create a new run-id");
        let got = store.lookup(&key).unwrap().unwrap();
        assert_eq!(got.payload, b"v2");
        assert_eq!(got.objective, Some(1.0));
    }

    #[test]
    fn keys_are_isolated() {
        let (_d, store) = temp_store();
        let a = key_for("a");
        let b = key_for("b");
        store
            .save(&a, b"AAA", Some(1.0), None, EntryKind::Final)
            .unwrap();
        store
            .save(&b, b"BBB", Some(1.0), None, EntryKind::Final)
            .unwrap();
        assert_eq!(store.lookup(&a).unwrap().unwrap().payload, b"AAA");
        assert_eq!(store.lookup(&b).unwrap().unwrap().payload, b"BBB");
    }
}