1use gam_runtime::warm_start::{EntryKind, Fingerprinter, StoreOptions, WarmStartStore};
2use serde::{Deserialize, Serialize};
3use std::sync::OnceLock;
4use std::time::Duration;
5use std::time::{SystemTime, UNIX_EPOCH};
6
/// Record format version. `new` stamps it into every record and the
/// `is_compatible` checks reject records carrying a different value.
const CACHE_VERSION: u32 = 1;
/// Largest serialized payload that is written or accepted on read: 16 MiB.
const MAX_ENTRY_BYTES: u64 = 16 * 1024 * 1024;
/// Value passed to the store as `StoreOptions::size_budget_bytes`: 256 MiB.
const MAX_TOTAL_BYTES: u64 = 256 * 1024 * 1024;
/// Value passed to the store as `StoreOptions::ttl`: ten 365-day years, in seconds.
const CACHE_TTL_SECS: u64 = 60 * 60 * 24 * 365 * 10;
11
/// Returns the fixed schema tag that is absorbed into fit-artifact
/// fingerprints by `store_fit_artifact` and `load_fit_artifact_by_descriptor`.
///
/// A fresh `String` is allocated on every call.
pub fn cache_schema_tag() -> String {
    const TAG: &str = "schema3-unified-fingerprinter-v3";
    String::from(TAG)
}
43
/// Serializable warm-start record for a single (non-block) fit.
///
/// Written as JSON by [`store_record`] and read back by [`load_record`].
/// Loading performs no validation beyond JSON decoding; callers decide whether
/// a record is reusable via [`PersistentWarmStartRecord::is_compatible`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PersistentWarmStartRecord {
    /// Record format version; `new` sets it to `CACHE_VERSION`.
    pub version: u32,
    /// Cache key. `store_record` fingerprints this string to address the entry.
    pub key: String,
    /// `CARGO_PKG_VERSION` of the crate build that created the record.
    pub package_version: String,
    /// Creation time, in seconds since the Unix epoch.
    pub created_unix_secs: u64,
    /// Last-update time, in seconds since the Unix epoch (`new` sets it equal
    /// to the creation time).
    pub updated_unix_secs: u64,
    /// Row count the record was created for; must match to be compatible.
    pub n_rows: usize,
    /// Column count the record was created for; must match to be compatible.
    pub n_cols: usize,
    /// Must contain only finite values to be compatible. Presumably the
    /// smoothing parameters on the log scale — TODO confirm.
    pub rho: Vec<f64>,
    /// Must hold exactly `n_cols` finite values to be compatible. Presumably
    /// the fitted coefficient vector — TODO confirm.
    pub beta: Vec<f64>,
    /// Optional earlier `rho`; when present it must have the same length as
    /// `rho` and be all-finite to be compatible.
    pub prev_rho: Option<Vec<f64>>,
    /// Optional earlier `beta`; when present it must hold `n_cols` finite
    /// values to be compatible.
    pub prev_beta: Option<Vec<f64>>,
    /// Not inspected by `is_compatible`. Presumably the inner-solver iteration
    /// count of the last fit — TODO confirm.
    pub last_inner_iters: usize,
    /// Not inspected by `is_compatible`. Presumably whether the last inner
    /// solve converged — TODO confirm.
    pub last_inner_converged: bool,
    /// Not inspected by `is_compatible`. NOTE(review): meaning inferred from
    /// the name only (a PIRLS Levenberg-Marquardt damping value?) — confirm.
    pub last_pirls_lm_lambda: Option<f64>,
    /// Not inspected by `is_compatible`. NOTE(review): meaning inferred from
    /// the name only — confirm against the producer.
    pub last_ift_prediction_residual: Option<f64>,
    /// Not inspected by `is_compatible`. NOTE(review): meaning inferred from
    /// the name only — confirm against the producer.
    pub last_pirls_accept_rho: Option<f64>,
}
63
/// Optional summary stored alongside a block warm-start record (see the
/// `inner` field of `PersistentBlockWarmStartRecord`).
///
/// `is_valid` requires all four floating-point fields to be finite.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PersistentBlockInnerSummary {
    /// Must be finite for the summary to be valid.
    pub log_likelihood: f64,
    /// Must be finite for the summary to be valid.
    pub penalty_value: f64,
    /// Not inspected by `is_valid`. Presumably the number of inner cycles
    /// performed — TODO confirm.
    pub cycles: usize,
    /// Not inspected by `is_valid`. Presumably whether the inner solve
    /// converged — TODO confirm.
    pub converged: bool,
    /// Must be finite for the summary to be valid. NOTE(review): presumably a
    /// log-determinant of a Hessian-like matrix — confirm.
    pub block_logdet_h: f64,
    /// Must be finite for the summary to be valid. NOTE(review): presumably a
    /// log-determinant of a penalty matrix — confirm.
    pub block_logdet_s: f64,
}
73
74impl PersistentBlockInnerSummary {
75 fn is_valid(&self) -> bool {
76 self.log_likelihood.is_finite()
77 && self.penalty_value.is_finite()
78 && self.block_logdet_h.is_finite()
79 && self.block_logdet_s.is_finite()
80 }
81}
82
/// Serializable warm-start record for a fit made of several named blocks.
///
/// Written as JSON by [`store_block_record`] and read back by
/// [`load_block_record`]. Loading performs no validation beyond JSON decoding;
/// callers decide whether a record is reusable via
/// [`PersistentBlockWarmStartRecord::is_compatible`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PersistentBlockWarmStartRecord {
    /// Record format version; `new` sets it to `CACHE_VERSION`.
    pub version: u32,
    /// Cache key. `store_block_record` fingerprints this string to address the entry.
    pub key: String,
    /// `CARGO_PKG_VERSION` of the crate build that created the record.
    pub package_version: String,
    /// Creation time, in seconds since the Unix epoch.
    pub created_unix_secs: u64,
    /// Last-update time, in seconds since the Unix epoch (`new` sets it equal
    /// to the creation time).
    pub updated_unix_secs: u64,
    /// Row count the record was created for; must match to be compatible.
    pub n_rows: usize,
    /// Block names; must equal the caller's list, in order, to be compatible.
    pub block_names: Vec<String>,
    /// Per-block dimensions; must equal the caller's list to be compatible.
    pub block_dims: Vec<usize>,
    /// Must have the caller's expected length and contain only finite values
    /// to be compatible. Presumably log-scale smoothing parameters — TODO confirm.
    pub rho: Vec<f64>,
    /// One vector per block; vector `i` must hold `block_dims[i]` finite values
    /// to be compatible.
    pub block_beta: Vec<Vec<f64>>,
    /// One entry per block. `is_compatible` checks only the outer length; the
    /// indices themselves are not validated here.
    pub active_sets: Vec<Option<Vec<usize>>>,
    /// Optional inner summary. `#[serde(default)]` makes a payload without this
    /// field deserialize with `inner == None`.
    #[serde(default)]
    pub inner: Option<PersistentBlockInnerSummary>,
}
99
100impl PersistentBlockWarmStartRecord {
101 pub fn new(
102 key: String,
103 n_rows: usize,
104 block_names: Vec<String>,
105 block_dims: Vec<usize>,
106 ) -> Self {
107 let now = unix_secs_now();
108 Self {
109 version: CACHE_VERSION,
110 key,
111 package_version: env!("CARGO_PKG_VERSION").to_string(),
112 created_unix_secs: now,
113 updated_unix_secs: now,
114 n_rows,
115 block_names,
116 block_dims,
117 rho: Vec::new(),
118 block_beta: Vec::new(),
119 active_sets: Vec::new(),
120 inner: None,
121 }
122 }
123
124 pub fn is_compatible(
125 &self,
126 key: &str,
127 n_rows: usize,
128 block_names: &[String],
129 block_dims: &[usize],
130 rho_len: usize,
131 ) -> bool {
132 self.version == CACHE_VERSION
133 && self.key == key
134 && self.n_rows == n_rows
141 && self.block_names == block_names
142 && self.block_dims == block_dims
143 && self.rho.len() == rho_len
144 && self.rho.iter().all(|v| v.is_finite())
145 && self.block_beta.len() == block_dims.len()
146 && self
147 .block_beta
148 .iter()
149 .zip(block_dims.iter())
150 .all(|(beta, dim)| beta.len() == *dim && beta.iter().all(|v| v.is_finite()))
151 && self.active_sets.len() == block_dims.len()
152 && self.inner.as_ref().is_none_or(|inner| inner.is_valid())
153 }
154}
155
156impl PersistentWarmStartRecord {
157 pub fn new(key: String, n_rows: usize, n_cols: usize) -> Self {
158 let now = unix_secs_now();
159 Self {
160 version: CACHE_VERSION,
161 key,
162 package_version: env!("CARGO_PKG_VERSION").to_string(),
163 created_unix_secs: now,
164 updated_unix_secs: now,
165 n_rows,
166 n_cols,
167 rho: Vec::new(),
168 beta: Vec::new(),
169 prev_rho: None,
170 prev_beta: None,
171 last_inner_iters: 0,
172 last_inner_converged: false,
173 last_pirls_lm_lambda: None,
174 last_ift_prediction_residual: None,
175 last_pirls_accept_rho: None,
176 }
177 }
178
179 pub fn is_compatible(&self, key: &str, n_rows: usize, n_cols: usize) -> bool {
180 self.version == CACHE_VERSION
181 && self.key == key
182 && self.n_rows == n_rows
189 && self.n_cols == n_cols
190 && self.rho.iter().all(|v| v.is_finite())
191 && self.beta.len() == n_cols
192 && self.beta.iter().all(|v| v.is_finite())
193 && self
194 .prev_rho
195 .as_ref()
196 .is_none_or(|rho| rho.len() == self.rho.len() && rho.iter().all(|v| v.is_finite()))
197 && self
198 .prev_beta
199 .as_ref()
200 .is_none_or(|beta| beta.len() == n_cols && beta.iter().all(|v| v.is_finite()))
201 }
202}
203
204pub fn load_record(key: &str) -> Option<PersistentWarmStartRecord> {
205 load_json_record(key)
206}
207
208pub fn load_block_record(key: &str) -> Option<PersistentBlockWarmStartRecord> {
209 load_json_record(key)
210}
211
212pub fn store_record(record: &PersistentWarmStartRecord) -> Result<(), String> {
213 store_json_record(&record.key, record)
214}
215
216pub fn store_block_record(record: &PersistentBlockWarmStartRecord) -> Result<(), String> {
217 store_json_record(&record.key, record)
218}
219
220fn store_json_record<T: Serialize>(key: &str, record: &T) -> Result<(), String> {
221 let bytes = serde_json::to_vec(record)
222 .map_err(|e| format!("failed to encode warm-start cache record: {e}"))?;
223 if bytes.len() as u64 > MAX_ENTRY_BYTES {
224 return Ok(());
225 }
226 let Some(store) = persistent_store() else {
227 return Ok(());
228 };
229 let mut fp = Fingerprinter::new();
230 fp.absorb_str(b"warm-start-key", key);
231 store
232 .save(&fp.finalize(), &bytes, None, None, EntryKind::Checkpoint)
233 .map(|_| ())
234 .map_err(|e| format!("failed to persist warm-start cache record: {e}"))
235}
236
237fn load_json_record<T: for<'de> Deserialize<'de>>(key: &str) -> Option<T> {
238 let store = persistent_store()?;
239 let mut fp = Fingerprinter::new();
240 fp.absorb_str(b"warm-start-key", key);
241 let entry = store.lookup(&fp.finalize()).ok().flatten()?;
242 if entry.payload.len() as u64 > MAX_ENTRY_BYTES {
243 return None;
244 }
245 serde_json::from_slice(&entry.payload).ok()
246}
247
248fn persistent_store() -> Option<WarmStartStore> {
260 static STORE: OnceLock<Option<WarmStartStore>> = OnceLock::new();
270 STORE
271 .get_or_init(|| {
272 let root = std::env::temp_dir().join("gam").join("warm").join("v1");
273 WarmStartStore::open(
274 root,
275 StoreOptions {
276 size_budget_bytes: MAX_TOTAL_BYTES,
277 ttl: Duration::from_secs(CACHE_TTL_SECS),
278 },
279 )
280 .ok()
281 })
282 .clone()
283}
284
285pub(crate) fn open_outer_session(
292 key: &str,
293) -> Option<std::sync::Arc<gam_runtime::warm_start::Session>> {
294 let store = persistent_store()?;
295 let mut fp = Fingerprinter::new();
296 fp.absorb_str(b"outer-iterate-key", key);
297 let fp = fp.finalize();
298 Some(std::sync::Arc::new(gam_runtime::warm_start::Session::open(
299 store, fp,
300 )))
301}
302
303pub fn store_fit_artifact(
311 artifact: &crate::warm_start_artifact::FitArtifact,
312) -> Result<(), String> {
313 if !artifact.is_usable() {
314 return Ok(());
317 }
318 let bytes = serde_json::to_vec(artifact)
319 .map_err(|e| format!("failed to encode fit-artifact record: {e}"))?;
320 if bytes.len() as u64 > MAX_ENTRY_BYTES {
321 return Ok(());
322 }
323 let Some(store) = persistent_store() else {
324 return Ok(());
325 };
326 let key = artifact.descriptor.descriptor_key().to_hex();
327 let mut fp = Fingerprinter::new();
328 fp.absorb_str(b"fit-artifact-key", &cache_schema_tag());
329 fp.absorb_str(b"fit-artifact-descriptor", &key);
330 store
331 .save(&fp.finalize(), &bytes, None, None, EntryKind::Checkpoint)
332 .map(|_| ())
333 .map_err(|e| format!("failed to persist fit-artifact record: {e}"))
334}
335
336pub fn load_fit_artifact_by_descriptor(
344 descriptor_key_hex: &str,
345) -> Option<crate::warm_start_artifact::FitArtifact> {
346 let store = persistent_store()?;
347 let mut fp = Fingerprinter::new();
348 fp.absorb_str(b"fit-artifact-key", &cache_schema_tag());
349 fp.absorb_str(b"fit-artifact-descriptor", descriptor_key_hex);
350 let entry = store.lookup_latest(&fp.finalize()).ok().flatten()?;
351 if entry.payload.len() as u64 > MAX_ENTRY_BYTES {
352 return None;
353 }
354 let artifact: crate::warm_start_artifact::FitArtifact =
355 serde_json::from_slice(&entry.payload).ok()?;
356 artifact.is_usable().then_some(artifact)
359}
360
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn unix_secs_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
367
#[cfg(test)]
mod warm_start_artifact_tests {
    use super::*;
    use crate::warm_start_artifact::{
        FIT_ARTIFACT_SCHEMA, FitArtifact, FitDescriptor, GlobalFitSummary, ResponseSig,
        SerializableBasisMeta, TermArtifact, TermRole, term_identity_from_block,
    };

    /// Builds a one-term artifact for `family` whose only term is the block
    /// `s(<var>)` with eight coefficients and the given `rho`.
    fn sample_artifact(family: &str, var: &str, rho: Vec<f64>) -> FitArtifact {
        let identity =
            term_identity_from_block(TermRole::Mean, &format!("s({var})"), &[None], &[1], 10);

        let descriptor = FitDescriptor {
            family_kind: family.to_owned(),
            term_identities: vec![identity],
            response_signature: ResponseSig {
                family_kind: family.to_owned(),
                n_response_channels: 1,
            },
            row_population: None,
        };

        let basis_meta = SerializableBasisMeta {
            kind: String::from("block-spec"),
            degree: None,
            num_knots: None,
            n_centers: Some(8),
            nullspace_order: None,
            matern_nu: None,
            periodic: false,
        };

        let term = TermArtifact {
            identity,
            role: TermRole::Mean,
            basis_meta,
            joint_null_rotation: None,
            raw_beta: vec![0.1, -0.2, 0.3, 0.4, -0.5, 0.6, -0.7, 0.8],
            rho_for_term: rho,
        };

        let global = GlobalFitSummary {
            outer_objective: -42.0,
            converged: true,
            n_rows: 500,
        };

        FitArtifact {
            schema: FIT_ARTIFACT_SCHEMA,
            created_unix_secs: unix_secs_now(),
            descriptor,
            terms: vec![term],
            global,
        }
    }

    /// Storing an artifact and loading it back by descriptor key returns the
    /// same schema, term identity, coefficients and descriptor key.
    #[test]
    fn artifact_round_trips_on_disk_by_descriptor() {
        // The timestamp keeps the family name distinct between runs that are
        // at least a second apart.
        let family = format!("test-roundtrip-{}", unix_secs_now());
        let stored = sample_artifact(&family, "x", vec![2.5]);
        let descriptor_hex = stored.descriptor.descriptor_key().to_hex();

        // Without a usable store there is nothing to verify.
        if persistent_store().is_none() {
            return;
        }

        store_fit_artifact(&stored).expect("store fit artifact");
        let fetched = load_fit_artifact_by_descriptor(&descriptor_hex)
            .expect("artifact must be retrievable by descriptor key");

        assert_eq!(fetched.schema, stored.schema);
        assert_eq!(fetched.terms.len(), 1);

        let (got, want) = (&fetched.terms[0], &stored.terms[0]);
        assert_eq!(got.identity, want.identity);
        assert_eq!(got.rho_for_term, vec![2.5]);
        assert_eq!(got.raw_beta, want.raw_beta);
        assert_eq!(
            fetched.descriptor.descriptor_key(),
            stored.descriptor.descriptor_key()
        );
    }

    /// A fold artifact without a row-population tag must share its descriptor
    /// key with the tagged full-data artifact, and must therefore be able to
    /// retrieve the full-data artifact from the store.
    #[test]
    fn loso_fold_descriptor_matches_full_data_artifact() {
        use crate::warm_start_artifact::RowPopulationTag;

        let family = format!("test-loso-{}", unix_secs_now());

        let mut full_data = sample_artifact(&family, "x", vec![1.7]);
        full_data.descriptor.row_population = Some(RowPopulationTag {
            n_rows: 1000,
            label: Some("full".to_string()),
        });
        full_data.global.n_rows = 1000;
        let full_key = full_data.descriptor.descriptor_key().to_hex();

        let fold_only = sample_artifact(&family, "x", vec![1.7]);
        let fold_key = fold_only.descriptor.descriptor_key().to_hex();
        assert_eq!(
            full_key, fold_key,
            "fold and full descriptor keys must match"
        );

        // Without a usable store only the key equality above can be verified.
        if persistent_store().is_none() {
            return;
        }

        store_fit_artifact(&full_data).expect("store full-data artifact");
        let fetched = load_fit_artifact_by_descriptor(&fold_key)
            .expect("LOSO fold must retrieve the full-data artifact");
        assert_eq!(fetched.terms[0].rho_for_term, vec![1.7]);
    }
}