Skip to main content

gam_solve/
persistent_warm_start.rs

1use gam_runtime::warm_start::{EntryKind, Fingerprinter, StoreOptions, WarmStartStore};
2use serde::{Deserialize, Serialize};
3use std::sync::OnceLock;
4use std::time::Duration;
5use std::time::{SystemTime, UNIX_EPOCH};
6
/// Record-layout version. The `new` constructors stamp it into every record
/// and the `is_compatible` checks reject any record carrying another value.
const CACHE_VERSION: u32 = 1;
/// Largest serialized payload (16 MiB) this module will write to, or accept
/// back from, the persistent store.
const MAX_ENTRY_BYTES: u64 = 16 << 20;
/// Size budget (256 MiB) handed to the store as `size_budget_bytes`.
const MAX_TOTAL_BYTES: u64 = 256 << 20;
/// Time-to-live handed to the store: ten 365-day years, expressed in seconds.
const CACHE_TTL_SECS: u64 = 10 * 365 * 24 * 60 * 60;
11
/// Returns the string tag that identifies the on-disk cache schema; the tag
/// is embedded directly in cache keys.
///
/// The tag has the shape `schemaN-<name>-vN`. Both the `schemaN-` prefix and
/// the `-vN` suffix are bumped **manually**, and only when the serialized
/// cache layout changes in a way that makes prior entries unsafe to consume
/// (struct fields added/removed, optimization invariants altered, payload
/// semantics shift). The tag is **deliberately separate** from
/// `CARGO_PKG_VERSION`, so a routine library version bump does NOT invalidate
/// every user's warm-start cache.
pub fn cache_schema_tag() -> String {
    // Bump history (oldest first):
    //
    // * `schema2-` → `schema3-`: the three hand-written hashers
    //   (`Fingerprinter`, `StableHasher`, `CacheDigestBuilder`) were unified
    //   onto `Fingerprinter`. Prior on-disk warm-start entries stay walled
    //   off in the `schema2-` keyspace and cold-start once; that is the
    //   intentional consequence of the unification, documented in the commit
    //   that performs it. See `src/warm_start/key.rs` for the canonical
    //   hasher API.
    // * `-v2`: the persistent warm-start key stopped hashing the
    //   θ-dependent, lazily-refreshed isometry Jacobian cache slots
    //   (`jacobian_cache` / `jacobian_second_cache` /
    //   `third_decoder_derivative`). Those snapshots made the key
    //   non-reproducible across identical repeat fits, so the outer
    //   `skip-outer-validation` warm hit was lost (#1048). Entries written
    //   under the old, drifting keys are simply never matched (and are
    //   TTL-evicted) rather than aliasing.
    // * `-v3`: the descriptor-indexed cross-fit `FitArtifact` keyspace
    //   ("fit-artifact-key") was introduced alongside the existing
    //   inner/outer warm-start records. The bump walls off entries written
    //   under the old layouts so a mixed-schema store never aliases a legacy
    //   payload into the new artifact reader (and vice versa).
    String::from("schema3-unified-fingerprinter-v3")
}
43
44#[derive(Clone, Debug, Serialize, Deserialize)]
45pub struct PersistentWarmStartRecord {
46    pub version: u32,
47    pub key: String,
48    pub package_version: String,
49    pub created_unix_secs: u64,
50    pub updated_unix_secs: u64,
51    pub n_rows: usize,
52    pub n_cols: usize,
53    pub rho: Vec<f64>,
54    pub beta: Vec<f64>,
55    pub prev_rho: Option<Vec<f64>>,
56    pub prev_beta: Option<Vec<f64>>,
57    pub last_inner_iters: usize,
58    pub last_inner_converged: bool,
59    pub last_pirls_lm_lambda: Option<f64>,
60    pub last_ift_prediction_residual: Option<f64>,
61    pub last_pirls_accept_rho: Option<f64>,
62}
63
64#[derive(Clone, Debug, Serialize, Deserialize)]
65pub struct PersistentBlockInnerSummary {
66    pub log_likelihood: f64,
67    pub penalty_value: f64,
68    pub cycles: usize,
69    pub converged: bool,
70    pub block_logdet_h: f64,
71    pub block_logdet_s: f64,
72}
73
74impl PersistentBlockInnerSummary {
75    fn is_valid(&self) -> bool {
76        self.log_likelihood.is_finite()
77            && self.penalty_value.is_finite()
78            && self.block_logdet_h.is_finite()
79            && self.block_logdet_s.is_finite()
80    }
81}
82
83#[derive(Clone, Debug, Serialize, Deserialize)]
84pub struct PersistentBlockWarmStartRecord {
85    pub version: u32,
86    pub key: String,
87    pub package_version: String,
88    pub created_unix_secs: u64,
89    pub updated_unix_secs: u64,
90    pub n_rows: usize,
91    pub block_names: Vec<String>,
92    pub block_dims: Vec<usize>,
93    pub rho: Vec<f64>,
94    pub block_beta: Vec<Vec<f64>>,
95    pub active_sets: Vec<Option<Vec<usize>>>,
96    #[serde(default)]
97    pub inner: Option<PersistentBlockInnerSummary>,
98}
99
100impl PersistentBlockWarmStartRecord {
101    pub fn new(
102        key: String,
103        n_rows: usize,
104        block_names: Vec<String>,
105        block_dims: Vec<usize>,
106    ) -> Self {
107        let now = unix_secs_now();
108        Self {
109            version: CACHE_VERSION,
110            key,
111            package_version: env!("CARGO_PKG_VERSION").to_string(),
112            created_unix_secs: now,
113            updated_unix_secs: now,
114            n_rows,
115            block_names,
116            block_dims,
117            rho: Vec::new(),
118            block_beta: Vec::new(),
119            active_sets: Vec::new(),
120            inner: None,
121        }
122    }
123
124    pub fn is_compatible(
125        &self,
126        key: &str,
127        n_rows: usize,
128        block_names: &[String],
129        block_dims: &[usize],
130        rho_len: usize,
131    ) -> bool {
132        self.version == CACHE_VERSION
133            && self.key == key
134            // Note: `package_version` is no longer required to match. A
135            // library version bump that doesn't change the cache schema
136            // (the common case for patch / minor releases) should NOT
137            // invalidate users' on-disk warm-start caches. Schema-breaking
138            // changes bump the `schemaN-` prefix in `cache_schema_tag()`,
139            // which is encoded in the cache key itself.
140            && self.n_rows == n_rows
141            && self.block_names == block_names
142            && self.block_dims == block_dims
143            && self.rho.len() == rho_len
144            && self.rho.iter().all(|v| v.is_finite())
145            && self.block_beta.len() == block_dims.len()
146            && self
147                .block_beta
148                .iter()
149                .zip(block_dims.iter())
150                .all(|(beta, dim)| beta.len() == *dim && beta.iter().all(|v| v.is_finite()))
151            && self.active_sets.len() == block_dims.len()
152            && self.inner.as_ref().is_none_or(|inner| inner.is_valid())
153    }
154}
155
156impl PersistentWarmStartRecord {
157    pub fn new(key: String, n_rows: usize, n_cols: usize) -> Self {
158        let now = unix_secs_now();
159        Self {
160            version: CACHE_VERSION,
161            key,
162            package_version: env!("CARGO_PKG_VERSION").to_string(),
163            created_unix_secs: now,
164            updated_unix_secs: now,
165            n_rows,
166            n_cols,
167            rho: Vec::new(),
168            beta: Vec::new(),
169            prev_rho: None,
170            prev_beta: None,
171            last_inner_iters: 0,
172            last_inner_converged: false,
173            last_pirls_lm_lambda: None,
174            last_ift_prediction_residual: None,
175            last_pirls_accept_rho: None,
176        }
177    }
178
179    pub fn is_compatible(&self, key: &str, n_rows: usize, n_cols: usize) -> bool {
180        self.version == CACHE_VERSION
181            && self.key == key
182            // Note: `package_version` is no longer required to match. A
183            // library version bump that doesn't change the cache schema
184            // (the common case for patch / minor releases) should NOT
185            // invalidate users' on-disk warm-start caches. Schema-breaking
186            // changes bump the `schemaN-` prefix in `cache_schema_tag()`,
187            // which is encoded in the cache key itself.
188            && self.n_rows == n_rows
189            && self.n_cols == n_cols
190            && self.rho.iter().all(|v| v.is_finite())
191            && self.beta.len() == n_cols
192            && self.beta.iter().all(|v| v.is_finite())
193            && self
194                .prev_rho
195                .as_ref()
196                .is_none_or(|rho| rho.len() == self.rho.len() && rho.iter().all(|v| v.is_finite()))
197            && self
198                .prev_beta
199                .as_ref()
200                .is_none_or(|beta| beta.len() == n_cols && beta.iter().all(|v| v.is_finite()))
201    }
202}
203
204pub fn load_record(key: &str) -> Option<PersistentWarmStartRecord> {
205    load_json_record(key)
206}
207
208pub fn load_block_record(key: &str) -> Option<PersistentBlockWarmStartRecord> {
209    load_json_record(key)
210}
211
212pub fn store_record(record: &PersistentWarmStartRecord) -> Result<(), String> {
213    store_json_record(&record.key, record)
214}
215
216pub fn store_block_record(record: &PersistentBlockWarmStartRecord) -> Result<(), String> {
217    store_json_record(&record.key, record)
218}
219
220fn store_json_record<T: Serialize>(key: &str, record: &T) -> Result<(), String> {
221    let bytes = serde_json::to_vec(record)
222        .map_err(|e| format!("failed to encode warm-start cache record: {e}"))?;
223    if bytes.len() as u64 > MAX_ENTRY_BYTES {
224        return Ok(());
225    }
226    let Some(store) = persistent_store() else {
227        return Ok(());
228    };
229    let mut fp = Fingerprinter::new();
230    fp.absorb_str(b"warm-start-key", key);
231    store
232        .save(&fp.finalize(), &bytes, None, None, EntryKind::Checkpoint)
233        .map(|_| ())
234        .map_err(|e| format!("failed to persist warm-start cache record: {e}"))
235}
236
237fn load_json_record<T: for<'de> Deserialize<'de>>(key: &str) -> Option<T> {
238    let store = persistent_store()?;
239    let mut fp = Fingerprinter::new();
240    fp.absorb_str(b"warm-start-key", key);
241    let entry = store.lookup(&fp.finalize()).ok().flatten()?;
242    if entry.payload.len() as u64 > MAX_ENTRY_BYTES {
243        return None;
244    }
245    serde_json::from_slice(&entry.payload).ok()
246}
247
248/// Anchor the warm-start cache under the platform temp directory.
249///
250/// Reading `XDG_CACHE_HOME` / `HOME` / `LOCALAPPDATA` (the canonical
251/// `dirs::cache_dir()` fallbacks) requires `env::var_os`, which is banned
252/// in this crate (see the build-script tripwire scan and the
253/// `feedback_no_env_vars` policy memo). `std::env::temp_dir()` resolves
254/// platform-conventional locations through OS-level primitives without
255/// going through `env::var`, so we route the persistent warm-start
256/// checkpoint root there instead. The directory is durable across
257/// processes within a single boot, and `WarmStartStore::open` falls back
258/// to `None` if the path is unwritable.
259fn persistent_store() -> Option<WarmStartStore> {
260    // Memoize the store process-wide. The root (`temp_dir()/gam/warm/v1`) is
261    // constant within a process, so a single instance suffices — and reusing
262    // it is essential, not just an optimization: `WarmStartStore` carries the
263    // per-store directory-scan / metadata cache and the eviction-throttle
264    // counters that #1114 added. Reconstructing the store on every save/lookup
265    // (as this used to) handed each fit an empty cache and a zeroed throttle,
266    // so every operation re-walked the cache root and re-read every metadata
267    // JSON from disk — the syscall storm that made several quality tests look
268    // hung. Clones returned here share the cache and throttle via `Arc`.
269    static STORE: OnceLock<Option<WarmStartStore>> = OnceLock::new();
270    STORE
271        .get_or_init(|| {
272            let root = std::env::temp_dir().join("gam").join("warm").join("v1");
273            WarmStartStore::open(
274                root,
275                StoreOptions {
276                    size_budget_bytes: MAX_TOTAL_BYTES,
277                    ttl: Duration::from_secs(CACHE_TTL_SECS),
278                },
279            )
280            .ok()
281        })
282        .clone()
283}
284
285/// Open a [`gam_runtime::warm_start::Session`] for outer-iterate (rho-axis) checkpoints.
286///
287/// Uses a different fingerprint tag than the inner `warm-start-key`
288/// absorption (see [`load_json_record`]) so the outer-iterate keyspace
289/// is disjoint from the inner beta-record keyspace —
290/// the two layers persist different payload shapes and must not alias.
291pub(crate) fn open_outer_session(
292    key: &str,
293) -> Option<std::sync::Arc<gam_runtime::warm_start::Session>> {
294    let store = persistent_store()?;
295    let mut fp = Fingerprinter::new();
296    fp.absorb_str(b"outer-iterate-key", key);
297    let fp = fp.finalize();
298    Some(std::sync::Arc::new(gam_runtime::warm_start::Session::open(
299        store, fp,
300    )))
301}
302
303/// Persist a descriptor-indexed cross-fit [`FitArtifact`] under the
304/// `fit-artifact-key` keyspace, keyed by the descriptor's structural key (so
305/// an LOSO fold of the same model retrieves a prior full-data fit). The
306/// schema tag is folded into the key so legacy layouts are walled off.
307///
308/// Best-effort: encoding / store failures are swallowed (a warm-start
309/// artifact is never required), oversize payloads are dropped.
310pub fn store_fit_artifact(
311    artifact: &crate::warm_start_artifact::FitArtifact,
312) -> Result<(), String> {
313    if !artifact.is_usable() {
314        // Never persist a non-finite / wrong-schema artifact: it could only
315        // ever be rejected on load anyway.
316        return Ok(());
317    }
318    let bytes = serde_json::to_vec(artifact)
319        .map_err(|e| format!("failed to encode fit-artifact record: {e}"))?;
320    if bytes.len() as u64 > MAX_ENTRY_BYTES {
321        return Ok(());
322    }
323    let Some(store) = persistent_store() else {
324        return Ok(());
325    };
326    let key = artifact.descriptor.descriptor_key().to_hex();
327    let mut fp = Fingerprinter::new();
328    fp.absorb_str(b"fit-artifact-key", &cache_schema_tag());
329    fp.absorb_str(b"fit-artifact-descriptor", &key);
330    store
331        .save(&fp.finalize(), &bytes, None, None, EntryKind::Checkpoint)
332        .map(|_| ())
333        .map_err(|e| format!("failed to persist fit-artifact record: {e}"))
334}
335
336/// Load the newest valid cross-fit [`FitArtifact`] whose descriptor key
337/// matches `descriptor_key_hex` (the hex of [`crate::warm_start_artifact::FitDescriptor::descriptor_key`]).
338///
339/// Uses `lookup_latest` (newest-valid) rather than objective-ranked lookup:
340/// descriptor-key matches can be different folds / row sets whose objectives
341/// are not on a common scale, so "lowest objective" is the wrong rule.
342/// Returns `None` (cold fallback) on any miss or non-finite payload.
343pub fn load_fit_artifact_by_descriptor(
344    descriptor_key_hex: &str,
345) -> Option<crate::warm_start_artifact::FitArtifact> {
346    let store = persistent_store()?;
347    let mut fp = Fingerprinter::new();
348    fp.absorb_str(b"fit-artifact-key", &cache_schema_tag());
349    fp.absorb_str(b"fit-artifact-descriptor", descriptor_key_hex);
350    let entry = store.lookup_latest(&fp.finalize()).ok().flatten()?;
351    if entry.payload.len() as u64 > MAX_ENTRY_BYTES {
352        return None;
353    }
354    let artifact: crate::warm_start_artifact::FitArtifact =
355        serde_json::from_slice(&entry.payload).ok()?;
356    // Finite-guard on the way out: a corrupt payload must cold-fallback,
357    // never poison a fit.
358    artifact.is_usable().then_some(artifact)
359}
360
/// Returns the current wall-clock time as whole seconds since the Unix
/// epoch, or `0` if the system clock reports a time before the epoch.
fn unix_secs_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
367
#[cfg(test)]
mod warm_start_artifact_tests {
    use super::*;
    use crate::warm_start_artifact::{
        FIT_ARTIFACT_SCHEMA, FitArtifact, FitDescriptor, GlobalFitSummary, ResponseSig,
        RowPopulationTag, SerializableBasisMeta, TermArtifact, TermRole, term_identity_from_block,
    };

    /// Builds a one-term artifact for `family` whose single mean-role term is
    /// the block `s(<var>)` with smoothing vector `rho`.
    fn sample_artifact(family: &str, var: &str, rho: Vec<f64>) -> FitArtifact {
        // Block-layer identity (the surviving, fold-invariant identity API):
        // the block name carries the variable, with one unlabeled penalty.
        let block = format!("s({var})");
        let identity = term_identity_from_block(TermRole::Mean, &block, &[None], &[1], 10);

        let descriptor = FitDescriptor {
            family_kind: family.to_string(),
            term_identities: vec![identity],
            response_signature: ResponseSig {
                family_kind: family.to_string(),
                n_response_channels: 1,
            },
            row_population: None,
        };
        let basis_meta = SerializableBasisMeta {
            kind: "block-spec".to_string(),
            degree: None,
            num_knots: None,
            n_centers: Some(8),
            nullspace_order: None,
            matern_nu: None,
            periodic: false,
        };
        let term = TermArtifact {
            identity,
            role: TermRole::Mean,
            basis_meta,
            joint_null_rotation: None,
            raw_beta: vec![0.1, -0.2, 0.3, 0.4, -0.5, 0.6, -0.7, 0.8],
            rho_for_term: rho,
        };
        let global = GlobalFitSummary {
            outer_objective: -42.0,
            converged: true,
            n_rows: 500,
        };

        FitArtifact {
            schema: FIT_ARTIFACT_SCHEMA,
            created_unix_secs: unix_secs_now(),
            descriptor,
            terms: vec![term],
            global,
        }
    }

    /// A stored artifact can be read back through its descriptor key with its
    /// schema, term identity, `rho` and `beta` intact.
    #[test]
    fn artifact_round_trips_on_disk_by_descriptor() {
        // A time-stamped family-kind tag keeps this test's descriptor key
        // disjoint from any other run's keyspace (the store is process-shared
        // under the temp dir).
        let family = format!("test-roundtrip-{}", unix_secs_now());
        let original = sample_artifact(&family, "x", vec![2.5]);
        let descriptor_hex = original.descriptor.descriptor_key().to_hex();

        // With an unwritable platform temp dir the store is `None` and the
        // round-trip is a no-op; only assert when persistence is available.
        if persistent_store().is_none() {
            return;
        }

        store_fit_artifact(&original).expect("store fit artifact");
        let restored = load_fit_artifact_by_descriptor(&descriptor_hex)
            .expect("artifact must be retrievable by descriptor key");

        assert_eq!(restored.schema, original.schema);
        assert_eq!(restored.terms.len(), 1);
        let restored_term = &restored.terms[0];
        let original_term = &original.terms[0];
        assert_eq!(restored_term.identity, original_term.identity);
        assert_eq!(restored_term.rho_for_term, vec![2.5]);
        assert_eq!(restored_term.raw_beta, original_term.raw_beta);
        assert_eq!(
            restored.descriptor.descriptor_key(),
            original.descriptor.descriptor_key()
        );
    }

    /// A fold with the same term identities but a different row population
    /// shares the full-data descriptor key and therefore retrieves the
    /// full-data artifact.
    #[test]
    fn loso_fold_descriptor_matches_full_data_artifact() {
        let family = format!("test-loso-{}", unix_secs_now());

        // Full-data fit on 1000 rows.
        let mut full = sample_artifact(&family, "x", vec![1.7]);
        full.descriptor.row_population = Some(RowPopulationTag {
            n_rows: 1000,
            label: Some("full".to_string()),
        });
        full.global.n_rows = 1000;
        let full_key = full.descriptor.descriptor_key().to_hex();

        // LOSO fold: same term identities, fewer rows. Its descriptor key
        // must equal the full-data key, so the load hits the stored artifact.
        let fold_key = sample_artifact(&family, "x", vec![1.7])
            .descriptor
            .descriptor_key()
            .to_hex();
        assert_eq!(
            full_key, fold_key,
            "fold and full descriptor keys must match"
        );

        if persistent_store().is_none() {
            return;
        }

        store_fit_artifact(&full).expect("store full-data artifact");
        let restored = load_fit_artifact_by_descriptor(&fold_key)
            .expect("LOSO fold must retrieve the full-data artifact");
        assert_eq!(restored.terms[0].rho_for_term, vec![1.7]);
    }
}