1use crate::warm_start::key::Fingerprint;
7use crate::warm_start::store::{EntryKind, WarmStartEntry, WarmStartStore};
8use std::sync::Mutex;
9use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
10
/// Minimum spacing between two checkpoint writes that do not improve on the
/// best objective seen so far.
///
/// `Session::checkpoint` skips a non-improving write when the previous
/// successful write happened less than this long ago; improving writes are
/// never held back by it.
const MIN_CHECKPOINT_INTERVAL: Duration = Duration::from_secs(2);
16
/// Where an entry returned by a [`Session`] load call came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadSource {
    /// The entry was found in the store under the session's own key.
    Exact,
    /// The entry was handed to the session in memory via `Session::preload`.
    Preloaded,
}
22
23#[derive(Debug, Clone)]
24pub struct LoadedEntry {
25 pub entry: WarmStartEntry,
26 pub source: LoadSource,
27}
28
29#[derive(Debug)]
30pub struct Session {
31 store: WarmStartStore,
32 key: Fingerprint,
33 run_id: String,
34 inner: Mutex<Inner>,
35 preloaded: Mutex<Option<WarmStartEntry>>,
43}
44
/// Mutable bookkeeping used to rate-limit checkpoint writes.
#[derive(Debug)]
struct Inner {
    /// When the most recent successful checkpoint write was started, if any.
    last_write: Option<Instant>,
    /// Lowest objective recorded by a successful checkpoint write, if any.
    best_seen: Option<f64>,
}
50
51impl Session {
52 pub fn open(store: WarmStartStore, key: Fingerprint) -> Self {
53 let nanos = SystemTime::now()
54 .duration_since(UNIX_EPOCH)
55 .map(|d| d.as_nanos())
56 .unwrap_or(0);
57 let pid = std::process::id();
58 let run_id = format!("ckpt-r{pid:x}-{nanos:x}");
59 Self {
60 store,
61 key,
62 run_id,
63 inner: Mutex::new(Inner {
64 last_write: None,
65 best_seen: None,
66 }),
67 preloaded: Mutex::new(None),
68 }
69 }
70
71 pub fn preload(&self, entry: WarmStartEntry) {
80 let mut slot = match self.preloaded.lock() {
81 Ok(g) => g,
82 Err(p) => p.into_inner(),
83 };
84 *slot = Some(entry);
85 }
86
87 pub fn key(&self) -> &Fingerprint {
88 &self.key
89 }
90
91 pub fn run_id(&self) -> &str {
92 &self.run_id
93 }
94
95 pub fn store(&self) -> &WarmStartStore {
96 &self.store
97 }
98
99 pub fn try_load(&self) -> Option<WarmStartEntry> {
109 self.try_load_with_source().map(|loaded| loaded.entry)
110 }
111
112 pub fn try_load_with_source(&self) -> Option<LoadedEntry> {
119 if let Ok(mut slot) = self.preloaded.lock()
120 && let Some(entry) = slot.take()
121 {
122 return Some(LoadedEntry {
123 entry,
124 source: LoadSource::Preloaded,
125 });
126 }
127 self.store
128 .lookup(&self.key)
129 .ok()
130 .flatten()
131 .map(|entry| LoadedEntry {
132 entry,
133 source: LoadSource::Exact,
134 })
135 }
136
137 pub fn peek_load(&self) -> Option<WarmStartEntry> {
145 self.peek_load_with_source().map(|loaded| loaded.entry)
146 }
147
148 pub fn peek_load_with_source(&self) -> Option<LoadedEntry> {
151 if let Ok(slot) = self.preloaded.lock()
152 && let Some(entry) = slot.as_ref()
153 {
154 return Some(LoadedEntry {
155 entry: entry.clone(),
156 source: LoadSource::Preloaded,
157 });
158 }
159 self.store
160 .lookup(&self.key)
161 .ok()
162 .flatten()
163 .map(|entry| LoadedEntry {
164 entry,
165 source: LoadSource::Exact,
166 })
167 }
168
169 pub fn checkpoint(
173 &self,
174 payload: &[u8],
175 objective: Option<f64>,
176 iteration: Option<u64>,
177 ) -> bool {
178 let now = Instant::now();
179 let mut guard = match self.inner.lock() {
180 Ok(g) => g,
181 Err(p) => p.into_inner(),
182 };
183 let improves = match (objective, guard.best_seen) {
184 (Some(o), Some(b)) => o < b - 1e-12,
185 (Some(_), None) => true,
186 _ => false,
187 };
188 if !improves
189 && let Some(last) = guard.last_write
190 && now.duration_since(last) < MIN_CHECKPOINT_INTERVAL
191 {
192 return false;
193 }
194 match self.store.save_overwrite(
195 &self.key,
196 &self.run_id,
197 payload,
198 objective,
199 iteration,
200 EntryKind::Checkpoint,
201 ) {
202 Ok(()) => {
203 guard.last_write = Some(now);
204 if let Some(o) = objective {
205 guard.best_seen = Some(match guard.best_seen {
206 Some(b) => b.min(o),
207 None => o,
208 });
209 }
210 true
211 }
212 Err(_) => false,
213 }
214 }
215
216 pub fn finalize(&self, payload: &[u8], objective: Option<f64>, iteration: Option<u64>) -> bool {
219 self.store
220 .save_overwrite(
221 &self.key,
222 &self.run_id,
223 payload,
224 objective,
225 iteration,
226 EntryKind::Final,
227 )
228 .is_ok()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::warm_start::key::Fingerprinter;
    use crate::warm_start::store::StoreOptions;

    /// Opens a store rooted at `root` with the 1 MiB budget and 60 s TTL
    /// shared by every test in this module.
    fn open_store(root: &std::path::Path) -> WarmStartStore {
        let options = StoreOptions {
            size_budget_bytes: 1024 * 1024,
            ttl: Duration::from_secs(60),
        };
        WarmStartStore::open(root.to_path_buf(), options).unwrap()
    }

    /// Builds an in-memory entry suitable for `Session::preload`.
    fn seed_entry(payload: &[u8], objective: f64, iteration: u64, kind: EntryKind) -> WarmStartEntry {
        WarmStartEntry {
            payload: payload.to_vec(),
            objective: Some(objective),
            iteration: Some(iteration),
            kind,
            written_unix_secs: 0,
        }
    }

    /// Fresh session over its own temporary directory, keyed by `label`.
    /// The directory guard is returned so the store outlives the test body.
    fn temp_session(label: &str) -> (tempfile::TempDir, Session) {
        let dir = tempfile::tempdir().unwrap();
        let store = open_store(dir.path());
        let mut fp = Fingerprinter::new();
        fp.absorb_str(b"label", label);
        let session = Session::open(store, fp.finalize());
        (dir, session)
    }

    #[test]
    fn checkpoint_then_load() {
        let (_dir, session) = temp_session("ckpt");
        assert!(session.checkpoint(b"iter-1", Some(2.0), Some(1)));

        let loaded = session.try_load().unwrap();
        assert_eq!(loaded.payload, b"iter-1");
        assert_eq!(loaded.objective, Some(2.0));
        assert_eq!(loaded.kind, EntryKind::Checkpoint);
    }

    #[test]
    fn improving_objective_bypasses_rate_limit() {
        let (_dir, session) = temp_session("improve");
        assert!(session.checkpoint(b"a", Some(5.0), Some(1)));
        // Second write lands immediately because 3.0 beats 5.0.
        assert!(session.checkpoint(b"b", Some(3.0), Some(2)));

        let loaded = session.try_load().unwrap();
        assert_eq!(loaded.payload, b"b");
        assert_eq!(loaded.objective, Some(3.0));
    }

    #[test]
    fn non_improving_writes_are_throttled() {
        let (_dir, session) = temp_session("throttle");
        assert!(session.checkpoint(b"a", Some(2.0), Some(1)));
        // Worse objective inside the rate-limit window: must be skipped.
        assert!(!session.checkpoint(b"b", Some(5.0), Some(2)));

        let loaded = session.try_load().unwrap();
        assert_eq!(loaded.payload, b"a");
    }

    #[test]
    fn finalize_promotes_to_final_kind() {
        let (_dir, session) = temp_session("final");
        session.checkpoint(b"ckpt", Some(2.0), Some(1));
        session.finalize(b"done", Some(1.0), Some(5));

        let loaded = session.try_load().unwrap();
        assert_eq!(loaded.payload, b"done");
        assert_eq!(loaded.kind, EntryKind::Final);
    }

    #[test]
    fn preload_takes_precedence_over_store_lookup() {
        let (_dir, session) = temp_session("preload-empty");
        assert!(session.try_load().is_none(), "fresh key should have no entry");

        session.preload(seed_entry(b"from-prefix", 7.0, 42, EntryKind::Final));

        let loaded = session
            .try_load()
            .expect("preloaded seed should be returned");
        assert_eq!(loaded.payload, b"from-prefix");
        assert_eq!(loaded.objective, Some(7.0));
    }

    #[test]
    fn preload_consumed_on_first_try_load() {
        let (_dir, session) = temp_session("preload-consume");
        session.checkpoint(b"exact", Some(2.0), Some(5));
        session.preload(seed_entry(b"seed", 99.0, 1, EntryKind::Checkpoint));

        let first = session.try_load().expect("first call should return seed");
        assert_eq!(first.payload, b"seed");

        let second = session
            .try_load()
            .expect("second call should read from store");
        assert_eq!(second.payload, b"exact");
    }

    #[test]
    fn peek_load_does_not_consume_preloaded_seed() {
        let (_dir, session) = temp_session("preload-peek");
        session.preload(seed_entry(b"seed", 3.0, 9, EntryKind::Final));

        let peeked = session
            .peek_load_with_source()
            .expect("peek should see preloaded seed");
        assert_eq!(peeked.entry.payload, b"seed");
        assert_eq!(peeked.source, LoadSource::Preloaded);

        let loaded = session
            .try_load()
            .expect("try_load should still receive the preloaded seed");
        assert_eq!(loaded.payload, b"seed");
        assert!(
            session.try_load().is_none(),
            "preloaded seed should be consumed only by try_load"
        );
    }

    #[test]
    fn second_session_reads_first_session_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let mut fp = Fingerprinter::new();
        fp.absorb_str(b"k", "shared");
        let key = fp.finalize();

        // Writer session.
        let writer = Session::open(open_store(dir.path()), key);
        writer.checkpoint(b"from-a", Some(1.0), Some(3));

        // Independent store handle over the same directory and key.
        let reader = Session::open(open_store(dir.path()), key);
        let loaded = reader.try_load().unwrap();
        assert_eq!(loaded.payload, b"from-a");
        assert_eq!(loaded.objective, Some(1.0));
    }
}