use crate::cache::key::Fingerprint;
use crate::cache::store::{CachedEntry, EntryKind, WarmStartStore};
use std::sync::Mutex;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Minimum spacing enforced between two checkpoint writes when the newer one
/// does not improve on the best objective seen so far.
const MIN_CHECKPOINT_INTERVAL: Duration = Duration::from_millis(2_000);
/// A warm-start / checkpoint session bound to a single cache key.
///
/// Every method takes `&self`; the mutable state (rate-limit bookkeeping and
/// the preloaded seed) lives behind mutexes.
#[derive(Debug)]
pub struct Session {
    /// Backing store that entries are looked up in and written to.
    store: WarmStartStore,
    /// Cache key under which every entry of this session is read and written.
    key: Fingerprint,
    /// Identifier generated in `Session::open` and passed to every store write.
    run_id: String,
    /// Last-write time and best objective, used by `checkpoint` for throttling.
    inner: Mutex<Inner>,
    /// Entry seeded via `preload`; taken (and thus cleared) by the next `try_load`.
    preloaded: Mutex<Option<CachedEntry>>,
}
/// Mutable bookkeeping consulted and updated by `Session::checkpoint`.
#[derive(Debug)]
struct Inner {
    /// Time of the last successful checkpoint write; `None` until the first one.
    last_write: Option<Instant>,
    /// Lowest objective recorded by a successful checkpoint write so far
    /// (lower is treated as better); `None` until one is recorded.
    best_seen: Option<f64>,
}
impl Session {
    /// Objectives must undercut the best one seen by more than this to count
    /// as an improvement (and so bypass the write throttle).
    const IMPROVEMENT_EPSILON: f64 = 1e-12;

    /// Opens a session for `key` backed by `store`.
    ///
    /// A `run_id` of the form `ckpt-r<pid hex>-<unix nanos hex>` is generated
    /// and passed to the store on every write. If the system clock reads
    /// earlier than the Unix epoch, the nanosecond component is `0`.
    pub fn open(store: WarmStartStore, key: Fingerprint) -> Self {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let pid = std::process::id();
        let run_id = format!("ckpt-r{pid:x}-{nanos:x}");
        Self {
            store,
            key,
            run_id,
            inner: Mutex::new(Inner {
                last_write: None,
                best_seen: None,
            }),
            preloaded: Mutex::new(None),
        }
    }

    /// Seeds an entry that the next [`Session::try_load`] returns instead of
    /// consulting the store. Replaces any seed that has not been consumed yet.
    pub fn preload(&self, entry: CachedEntry) {
        // The slot holds plain data, so a panic in another holder cannot leave
        // it in a broken state; recover from poisoning rather than dropping it.
        *self.preloaded.lock().unwrap_or_else(|p| p.into_inner()) = Some(entry);
    }

    /// The cache key this session reads and writes.
    pub fn key(&self) -> &Fingerprint {
        &self.key
    }

    /// The identifier generated for this session in [`Session::open`].
    pub fn run_id(&self) -> &str {
        &self.run_id
    }

    /// The backing store.
    pub fn store(&self) -> &WarmStartStore {
        &self.store
    }

    /// Returns the preloaded seed if one is pending (consuming it), otherwise
    /// the entry stored under this session's key.
    ///
    /// Store lookup errors are treated as a miss and yield `None`.
    pub fn try_load(&self) -> Option<CachedEntry> {
        // Recover from poisoning the same way `preload` does, so a seed written
        // after a poisoning is still delivered. The guard is a temporary and is
        // released before the store lookup below.
        let seeded = self
            .preloaded
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .take();
        seeded.or_else(|| self.store.lookup(&self.key).ok().flatten())
    }

    /// Writes a `Checkpoint` entry unless it is rate-limited.
    ///
    /// A write happens when `objective` improves on the best objective seen so
    /// far (by more than [`Self::IMPROVEMENT_EPSILON`]), when no checkpoint has
    /// been written yet, or when at least `MIN_CHECKPOINT_INTERVAL` has passed
    /// since the last successful write. A NaN objective is stored as given but
    /// never counts as an improvement and is never recorded as the best value.
    ///
    /// Returns `true` if the entry was written, `false` if it was throttled or
    /// the store write failed (the store error itself is not surfaced).
    ///
    /// The session lock is held across the store write, so concurrent
    /// checkpoints on the same session are serialized.
    pub fn checkpoint(
        &self,
        payload: &[u8],
        objective: Option<f64>,
        iteration: Option<u64>,
    ) -> bool {
        let mut guard = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        // Read the clock only once the lock is held, so `last_write` cannot move
        // backwards when several threads race to checkpoint.
        let now = Instant::now();

        // NaN compares false against everything; recording it as `best_seen`
        // would make every later objective look non-improving.
        let comparable = objective.filter(|o| !o.is_nan());
        let improves = match (comparable, guard.best_seen) {
            (Some(o), Some(best)) => o < best - Self::IMPROVEMENT_EPSILON,
            (Some(_), None) => true,
            (None, _) => false,
        };
        let throttled = guard
            .last_write
            .is_some_and(|last| now.duration_since(last) < MIN_CHECKPOINT_INTERVAL);
        if !improves && throttled {
            return false;
        }

        let saved = self.store.save_overwrite(
            &self.key,
            &self.run_id,
            payload,
            objective,
            iteration,
            EntryKind::Checkpoint,
        );
        if saved.is_err() {
            return false;
        }

        guard.last_write = Some(now);
        if let Some(o) = comparable {
            guard.best_seen = Some(guard.best_seen.map_or(o, |best| best.min(o)));
        }
        true
    }

    /// Writes a `Final` entry under this session's key, overwriting any
    /// checkpoint. Returns `true` if the store write succeeded.
    ///
    /// Holds the same lock as [`Session::checkpoint`]. A checkpoint that is
    /// already writing on another thread therefore finishes before the final
    /// entry is written and cannot clobber it. Checkpoints issued *after*
    /// `finalize` returns still overwrite the entry; callers should stop
    /// checkpointing once they finalize.
    pub fn finalize(&self, payload: &[u8], objective: Option<f64>, iteration: Option<u64>) -> bool {
        let _serialize = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        self.store
            .save_overwrite(
                &self.key,
                &self.run_id,
                payload,
                objective,
                iteration,
                EntryKind::Final,
            )
            .is_ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::key::Fingerprinter;
    use crate::cache::store::StoreOptions;

    /// Opens a store rooted at `dir` with a 1 MiB budget and a 60 s TTL.
    fn open_store(dir: &std::path::Path) -> WarmStartStore {
        WarmStartStore::open(
            dir.to_path_buf(),
            StoreOptions {
                size_budget_bytes: 1024 * 1024,
                ttl: Duration::from_secs(60),
            },
        )
        .unwrap()
    }

    /// Creates a session on a fresh temp store, keyed by `label`. The returned
    /// `TempDir` must be kept alive for as long as the session is used.
    fn temp_session(label: &str) -> (tempfile::TempDir, Session) {
        let dir = tempfile::tempdir().unwrap();
        let store = open_store(dir.path());
        let mut fp = Fingerprinter::new();
        fp.absorb_str(b"label", label);
        let key = fp.finalize();
        let s = Session::open(store, key);
        (dir, s)
    }

    #[test]
    fn checkpoint_then_load() {
        let (_d, s) = temp_session("ckpt");
        assert!(s.checkpoint(b"iter-1", Some(2.0), Some(1)));
        let got = s.try_load().unwrap();
        assert_eq!(got.payload, b"iter-1");
        assert_eq!(got.objective, Some(2.0));
        assert_eq!(got.kind, EntryKind::Checkpoint);
    }

    #[test]
    fn improving_objective_bypasses_rate_limit() {
        let (_d, s) = temp_session("improve");
        assert!(s.checkpoint(b"a", Some(5.0), Some(1)));
        assert!(s.checkpoint(b"b", Some(3.0), Some(2)));
        let got = s.try_load().unwrap();
        assert_eq!(got.payload, b"b");
        assert_eq!(got.objective, Some(3.0));
    }

    #[test]
    fn non_improving_writes_are_throttled() {
        let (_d, s) = temp_session("throttle");
        assert!(s.checkpoint(b"a", Some(2.0), Some(1)));
        assert!(!s.checkpoint(b"b", Some(5.0), Some(2)));
        let got = s.try_load().unwrap();
        assert_eq!(got.payload, b"a");
    }

    #[test]
    fn objective_free_checkpoints_are_throttled() {
        let (_d, s) = temp_session("no-objective");
        assert!(s.checkpoint(b"a", None, Some(1)));
        assert!(!s.checkpoint(b"b", None, Some(2)));
        let got = s.try_load().unwrap();
        assert_eq!(got.payload, b"a");
    }

    #[test]
    fn finalize_promotes_to_final_kind() {
        let (_d, s) = temp_session("final");
        assert!(s.checkpoint(b"ckpt", Some(2.0), Some(1)));
        assert!(s.finalize(b"done", Some(1.0), Some(5)));
        let got = s.try_load().unwrap();
        assert_eq!(got.payload, b"done");
        assert_eq!(got.kind, EntryKind::Final);
    }

    #[test]
    fn run_id_has_checkpoint_prefix() {
        let (_d, s) = temp_session("run-id");
        assert!(s.run_id().starts_with("ckpt-r"));
    }

    #[test]
    fn preload_takes_precedence_over_store_lookup() {
        let (_d, s) = temp_session("preload-empty");
        assert!(s.try_load().is_none(), "fresh key should have no entry");
        let seeded = CachedEntry {
            payload: b"from-prefix".to_vec(),
            objective: Some(7.0),
            iteration: Some(42),
            kind: EntryKind::Final,
            written_unix_secs: 0,
        };
        s.preload(seeded);
        let got = s.try_load().expect("preloaded seed should be returned");
        assert_eq!(got.payload, b"from-prefix");
        assert_eq!(got.objective, Some(7.0));
    }

    #[test]
    fn preload_consumed_on_first_try_load() {
        let (_d, s) = temp_session("preload-consume");
        assert!(s.checkpoint(b"exact", Some(2.0), Some(5)));
        let seeded = CachedEntry {
            payload: b"seed".to_vec(),
            objective: Some(99.0),
            iteration: Some(1),
            kind: EntryKind::Checkpoint,
            written_unix_secs: 0,
        };
        s.preload(seeded);
        let first = s.try_load().expect("first call should return seed");
        assert_eq!(first.payload, b"seed");
        let second = s.try_load().expect("second call should fall back to store");
        assert_eq!(second.payload, b"exact");
    }

    #[test]
    fn second_session_reads_first_session_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let mut fp = Fingerprinter::new();
        fp.absorb_str(b"k", "shared");
        let key = fp.finalize();

        let s_a = Session::open(open_store(dir.path()), key);
        assert!(s_a.checkpoint(b"from-a", Some(1.0), Some(3)));

        let s_b = Session::open(open_store(dir.path()), key);
        let got = s_b.try_load().unwrap();
        assert_eq!(got.payload, b"from-a");
        assert_eq!(got.objective, Some(1.0));
    }
}