Skip to main content

gam_runtime/warm_start/
session.rs

//! A `Session` ties a `WarmStartStore` to a specific `Fingerprint` so callers
//! can resume and checkpoint a single fit without re-passing the key on every
//! call. One session corresponds to one in-flight fit; periodic checkpoints
//! overwrite a single run-id slot so the store does not accumulate one entry
//! per write.
5
6use crate::warm_start::key::Fingerprint;
7use crate::warm_start::store::{EntryKind, WarmStartEntry, WarmStartStore};
8use std::sync::Mutex;
9use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
10
/// Shortest allowed spacing between two rate-limited checkpoint writes.
///
/// The floor keeps a tight optimizer loop from thrashing the disk. A
/// checkpoint whose objective improves on the best value seen so far is
/// exempt from it: losing the best iterate to a hard crash is the failure
/// mode this whole module exists to prevent.
const MIN_CHECKPOINT_INTERVAL: Duration = Duration::from_secs(2);
16
/// Origin of an entry handed back by one of the session's load calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadSource {
    /// The entry was found by looking up the session's own key in the store.
    Exact,
    /// The entry is a seed that had been stashed on the session beforehand
    /// through `Session::preload`.
    Preloaded,
}
22
23#[derive(Debug, Clone)]
24pub struct LoadedEntry {
25    pub entry: WarmStartEntry,
26    pub source: LoadSource,
27}
28
29#[derive(Debug)]
30pub struct Session {
31    store: WarmStartStore,
32    key: Fingerprint,
33    run_id: String,
34    inner: Mutex<Inner>,
35    /// Pre-loaded seed payload from a hierarchical near-match key.
36    ///
37    /// Populated by callers who looked up a related (but not exact-match)
38    /// entry from a different key in the same store. The first call to
39    /// [`Self::try_load`] returns and clears this slot — so the session
40    /// can be used as a unified "load best seed, save under exact key"
41    /// abstraction regardless of where the seed came from.
42    preloaded: Mutex<Option<WarmStartEntry>>,
43}
44
/// Mutable checkpoint bookkeeping, kept behind the session's mutex.
#[derive(Debug)]
struct Inner {
    /// Timestamp associated with the most recent successful checkpoint
    /// write; `None` until the first one succeeds.
    last_write: Option<Instant>,
    /// Smallest objective recorded from a successful checkpoint write, if
    /// any such write carried an objective.
    best_seen: Option<f64>,
}
50
51impl Session {
52    pub fn open(store: WarmStartStore, key: Fingerprint) -> Self {
53        let nanos = SystemTime::now()
54            .duration_since(UNIX_EPOCH)
55            .map(|d| d.as_nanos())
56            .unwrap_or(0);
57        let pid = std::process::id();
58        let run_id = format!("ckpt-r{pid:x}-{nanos:x}");
59        Self {
60            store,
61            key,
62            run_id,
63            inner: Mutex::new(Inner {
64                last_write: None,
65                best_seen: None,
66            }),
67            preloaded: Mutex::new(None),
68        }
69    }
70
71    /// Stash a near-match payload that the next [`Self::try_load`] call
72    /// should return in preference to looking up this session's key.
73    ///
74    /// Used by the workflow dispatcher to seed a fresh fit's outer loop
75    /// from a related but not-exact-fingerprint prior fit (e.g.,
76    /// cross-validation folds of the same model). The exact-key keyspace
77    /// remains untouched by this — checkpoint and finalize writes still
78    /// go to the session's own key.
79    pub fn preload(&self, entry: WarmStartEntry) {
80        let mut slot = match self.preloaded.lock() {
81            Ok(g) => g,
82            Err(p) => p.into_inner(),
83        };
84        *slot = Some(entry);
85    }
86
87    pub fn key(&self) -> &Fingerprint {
88        &self.key
89    }
90
91    pub fn run_id(&self) -> &str {
92        &self.run_id
93    }
94
95    pub fn store(&self) -> &WarmStartStore {
96        &self.store
97    }
98
99    /// Read the best entry currently on disk for this session's key.
100    /// Lookup is read-only against the store and may return entries from
101    /// other runs (the whole point of cross-run resume).
102    ///
103    /// If a near-match seed has been preloaded via [`Self::preload`],
104    /// the seed is returned in preference to the store lookup AND
105    /// consumed (so subsequent calls fall back to the store). This
106    /// makes the session a unified abstraction over "exact-key hit"
107    /// and "hierarchical-prefix seed."
108    pub fn try_load(&self) -> Option<WarmStartEntry> {
109        self.try_load_with_source().map(|loaded| loaded.entry)
110    }
111
112    /// Read the best available warm-start entry and report whether it came
113    /// from this session's exact key or from a preloaded near-match seed.
114    ///
115    /// Callers that only need a seed can use [`Self::try_load`]. Callers that
116    /// may skip expensive validation on a finalized exact hit need this source
117    /// bit so a near-match prefix seed is never mistaken for a completed fit.
118    pub fn try_load_with_source(&self) -> Option<LoadedEntry> {
119        if let Ok(mut slot) = self.preloaded.lock()
120            && let Some(entry) = slot.take()
121        {
122            return Some(LoadedEntry {
123                entry,
124                source: LoadSource::Preloaded,
125            });
126        }
127        self.store
128            .lookup(&self.key)
129            .ok()
130            .flatten()
131            .map(|entry| LoadedEntry {
132                entry,
133                source: LoadSource::Exact,
134            })
135    }
136
137    /// Read the currently available warm-start entry without consuming a
138    /// preloaded near-match seed.
139    ///
140    /// This is intentionally separate from [`Self::try_load`]: callers that
141    /// only need to make a scheduling decision (for example, whether to run an
142    /// expensive cold-start pilot) must not drain the preloaded seed that the
143    /// outer optimizer is about to consume.
144    pub fn peek_load(&self) -> Option<WarmStartEntry> {
145        self.peek_load_with_source().map(|loaded| loaded.entry)
146    }
147
148    /// Read the currently available warm-start entry with source metadata,
149    /// without consuming a preloaded near-match seed.
150    pub fn peek_load_with_source(&self) -> Option<LoadedEntry> {
151        if let Ok(slot) = self.preloaded.lock()
152            && let Some(entry) = slot.as_ref()
153        {
154            return Some(LoadedEntry {
155                entry: entry.clone(),
156                source: LoadSource::Preloaded,
157            });
158        }
159        self.store
160            .lookup(&self.key)
161            .ok()
162            .flatten()
163            .map(|entry| LoadedEntry {
164                entry,
165                source: LoadSource::Exact,
166            })
167    }
168
169    /// Persist a mid-fit checkpoint. Rate-limited; returns true if a write
170    /// actually happened. Always writes when the new objective strictly
171    /// improves on the best-so-far observed in this session.
172    pub fn checkpoint(
173        &self,
174        payload: &[u8],
175        objective: Option<f64>,
176        iteration: Option<u64>,
177    ) -> bool {
178        let now = Instant::now();
179        let mut guard = match self.inner.lock() {
180            Ok(g) => g,
181            Err(p) => p.into_inner(),
182        };
183        let improves = match (objective, guard.best_seen) {
184            (Some(o), Some(b)) => o < b - 1e-12,
185            (Some(_), None) => true,
186            _ => false,
187        };
188        if !improves
189            && let Some(last) = guard.last_write
190            && now.duration_since(last) < MIN_CHECKPOINT_INTERVAL
191        {
192            return false;
193        }
194        match self.store.save_overwrite(
195            &self.key,
196            &self.run_id,
197            payload,
198            objective,
199            iteration,
200            EntryKind::Checkpoint,
201        ) {
202            Ok(()) => {
203                guard.last_write = Some(now);
204                if let Some(o) = objective {
205                    guard.best_seen = Some(match guard.best_seen {
206                        Some(b) => b.min(o),
207                        None => o,
208                    });
209                }
210                true
211            }
212            Err(_) => false,
213        }
214    }
215
216    /// Persist the end-of-fit result, promoting this session's slot to
217    /// `EntryKind::Final`. Bypasses the rate limit.
218    pub fn finalize(&self, payload: &[u8], objective: Option<f64>, iteration: Option<u64>) -> bool {
219        self.store
220            .save_overwrite(
221                &self.key,
222                &self.run_id,
223                payload,
224                objective,
225                iteration,
226                EntryKind::Final,
227            )
228            .is_ok()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::warm_start::key::Fingerprinter;
    use crate::warm_start::store::StoreOptions;

    /// Open a store rooted at `root` with the small budget and TTL shared by
    /// every test in this module.
    fn open_store(root: &std::path::Path) -> WarmStartStore {
        WarmStartStore::open(
            root.to_path_buf(),
            StoreOptions {
                size_budget_bytes: 1024 * 1024,
                ttl: Duration::from_secs(60),
            },
        )
        .unwrap()
    }

    /// Build a session on a fresh temporary store. The `TempDir` is returned
    /// so the directory outlives the session for the duration of the test.
    fn temp_session(label: &str) -> (tempfile::TempDir, Session) {
        let dir = tempfile::tempdir().unwrap();
        let store = open_store(dir.path());
        let mut fp = Fingerprinter::new();
        fp.absorb_str(b"label", label);
        let key = fp.finalize();
        let s = Session::open(store, key);
        (dir, s)
    }

    /// Build an in-memory entry suitable for `Session::preload`.
    fn seed_entry(payload: &[u8], objective: f64, iteration: u64, kind: EntryKind) -> WarmStartEntry {
        WarmStartEntry {
            payload: payload.to_vec(),
            objective: Some(objective),
            iteration: Some(iteration),
            kind,
            written_unix_secs: 0,
        }
    }

    #[test]
    fn checkpoint_then_load() {
        let (_d, s) = temp_session("ckpt");
        assert!(s.checkpoint(b"iter-1", Some(2.0), Some(1)));
        let got = s.try_load().unwrap();
        assert_eq!(got.payload, b"iter-1");
        assert_eq!(got.objective, Some(2.0));
        assert_eq!(got.kind, EntryKind::Checkpoint);
    }

    #[test]
    fn improving_objective_bypasses_rate_limit() {
        let (_d, s) = temp_session("improve");
        assert!(s.checkpoint(b"a", Some(5.0), Some(1)));
        // Immediately better objective — must write even though rate-limit
        // window is open.
        assert!(s.checkpoint(b"b", Some(3.0), Some(2)));
        let got = s.try_load().unwrap();
        assert_eq!(got.payload, b"b");
        assert_eq!(got.objective, Some(3.0));
    }

    #[test]
    fn non_improving_writes_are_throttled() {
        let (_d, s) = temp_session("throttle");
        assert!(s.checkpoint(b"a", Some(2.0), Some(1)));
        // Worse objective inside the rate window — should be suppressed.
        assert!(!s.checkpoint(b"b", Some(5.0), Some(2)));
        // Disk still shows the better iterate.
        let got = s.try_load().unwrap();
        assert_eq!(got.payload, b"a");
    }

    #[test]
    fn finalize_promotes_to_final_kind() {
        let (_d, s) = temp_session("final");
        // Assert the writes themselves so a store failure is reported here
        // rather than as a confusing mismatch further down.
        assert!(s.checkpoint(b"ckpt", Some(2.0), Some(1)));
        assert!(s.finalize(b"done", Some(1.0), Some(5)));
        let got = s.try_load().unwrap();
        assert_eq!(got.payload, b"done");
        assert_eq!(got.kind, EntryKind::Final);
    }

    #[test]
    fn exact_hit_reports_exact_source() {
        let (_d, s) = temp_session("exact-source");
        assert!(s.checkpoint(b"exact", Some(2.0), Some(1)));

        let peeked = s
            .peek_load_with_source()
            .expect("peek should see the stored entry");
        assert_eq!(peeked.source, LoadSource::Exact);
        assert_eq!(peeked.entry.payload, b"exact");

        let loaded = s
            .try_load_with_source()
            .expect("load should see the stored entry");
        assert_eq!(loaded.source, LoadSource::Exact);
        assert_eq!(loaded.entry.payload, b"exact");
    }

    #[test]
    fn preload_takes_precedence_over_store_lookup() {
        // Hierarchical near-match semantics: when a session is opened on
        // a fresh key (no entry) but preloaded with a near-match payload
        // from a different key, try_load returns the preloaded entry.
        let (_d, s) = temp_session("preload-empty");
        assert!(s.try_load().is_none(), "fresh key should have no entry");

        s.preload(seed_entry(b"from-prefix", 7.0, 42, EntryKind::Final));

        let got = s
            .try_load_with_source()
            .expect("preloaded seed should be returned");
        assert_eq!(got.source, LoadSource::Preloaded);
        assert_eq!(got.entry.payload, b"from-prefix");
        assert_eq!(got.entry.objective, Some(7.0));
    }

    #[test]
    fn preload_consumed_on_first_try_load() {
        // The preload slot is consumed after one read so subsequent calls
        // fall back to the store. This makes the session a unified
        // "load best seed, save under exact key" abstraction without
        // duplicating reads.
        let (_d, s) = temp_session("preload-consume");
        assert!(s.checkpoint(b"exact", Some(2.0), Some(5)));

        s.preload(seed_entry(b"seed", 99.0, 1, EntryKind::Checkpoint));

        // First try_load: seed (preferred over store).
        let first = s.try_load().expect("first call should return seed");
        assert_eq!(first.payload, b"seed");

        // Second try_load: store lookup after the seed is consumed.
        let second = s.try_load().expect("second call should read from store");
        assert_eq!(second.payload, b"exact");
    }

    #[test]
    fn peek_load_does_not_consume_preloaded_seed() {
        let (_d, s) = temp_session("preload-peek");
        s.preload(seed_entry(b"seed", 3.0, 9, EntryKind::Final));

        let peeked = s
            .peek_load_with_source()
            .expect("peek should see preloaded seed");
        assert_eq!(peeked.entry.payload, b"seed");
        assert_eq!(peeked.source, LoadSource::Preloaded);

        let loaded = s
            .try_load()
            .expect("try_load should still receive the preloaded seed");
        assert_eq!(loaded.payload, b"seed");
        assert!(
            s.try_load().is_none(),
            "preloaded seed should be consumed only by try_load"
        );
    }

    #[test]
    fn second_session_reads_first_session_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let mut fp = Fingerprinter::new();
        fp.absorb_str(b"k", "shared");
        let key = fp.finalize();

        let s_a = Session::open(open_store(dir.path()), key);
        assert!(s_a.checkpoint(b"from-a", Some(1.0), Some(3)));

        // Simulate a fresh process starting later.
        let s_b = Session::open(open_store(dir.path()), key);
        let got = s_b.try_load().unwrap();
        assert_eq!(got.payload, b"from-a");
        assert_eq!(got.objective, Some(1.0));
    }
}