gam_solve/warm_start_artifact.rs
1//! Cross-fit warm-start artifact: a descriptor-indexed, function-space
2//! snapshot of a converged fit, designed so a *related* later fit (a
3//! leave-one-subject-out fold, a re-fit on a different row population, a
4//! different reduced width) can warm-start from it even though the exact
5//! response-keyed inner cache (`persistent_warm_start.rs`) misses.
6//!
7//! The artifact is keyed by *structural identity*, not by data bytes. Two
8//! fits of the same term family (same role, same variables, same basis kind
9//! and the same STRUCTURAL basis parameters — degree, #centers, nullspace
10//! order, …) map to the same [`TermIdentityKey`] even when their realized
11//! `centers` / `input_scales` / `length_scale` differ across folds. That is
12//! precisely what lets the smoothing parameter ρ transfer survive a fold:
13//! "same term, different rows" matches; "3 PCs vs 10 PCs" or "different
14//! #centers" deliberately does NOT.
15//!
16//! Correctness is free. A warm start only sets the *starting iterate*; the
17//! outer REML/BFGS loop and the inner constrained Newton solve still run to
18//! their KKT certificate, so the converged answer is identical to a cold
19//! start within tolerance. Every field that flows back into the solver is
20//! finite-guarded at consume time; any anomaly falls back to cold.
21
22use gam_runtime::warm_start::key::{Fingerprint, Fingerprinter};
23use serde::{Deserialize, Serialize};
24
/// Version stamp written into every serialized [`FitArtifact`].
///
/// Increment it whenever the serialized layout changes in a way that would
/// make payloads written under an earlier version unsafe to consume.
pub const FIT_ARTIFACT_SCHEMA: u32 = 1;
28
/// Magnitude beyond which a copied ρ coordinate counts as saturated.
///
/// A saturated coordinate is treated as pinned at the outer optimizer's box
/// and is NOT transferred. The threshold mirrors the persist-side gate in
/// `families/custom_family/persistent_warm_start.rs` and the
/// `[CACHE] hit-clamp` policy in `solver/outer_strategy.rs`.
pub(crate) const RHO_SATURATION: f64 = 9.0;
34
35/// Structural role a term plays in the (possibly multi-channel) model.
36///
37/// Derived from the block name / channel at capture time. The role is part
38/// of the term identity so a "mean" smooth never transfers ρ to a
39/// "log-slope" smooth of the same variables.
40#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
41pub enum TermRole {
42 /// Location / mean channel (the default for a single-channel family).
43 Mean,
44 /// Log-scale / dispersion / log-slope channel.
45 LogSlope,
46 /// Any other channel (multinomial categories, frailty, …).
47 Generic,
48}
49
50impl TermRole {
51 /// Stable discriminant byte for hashing.
52 fn discriminant(self) -> u8 {
53 match self {
54 TermRole::Mean => 0,
55 TermRole::LogSlope => 1,
56 TermRole::Generic => 2,
57 }
58 }
59
60 /// Heuristic role from a block / channel name. Names are produced by the
61 /// family construction layer (e.g. `"<scale>"`, `"logslope"`, `"mean"`);
62 /// the classification is structural and deliberately coarse.
63 pub fn from_block_name(name: &str) -> TermRole {
64 let lower = name.to_ascii_lowercase();
65 if lower.contains("logslope")
66 || lower.contains("log_slope")
67 || lower.contains("scale")
68 || lower.contains("sigma")
69 || lower.contains("dispersion")
70 || lower.contains("disp")
71 {
72 TermRole::LogSlope
73 } else if lower.contains("mean") || lower.contains("loc") || lower.contains("marginal") {
74 TermRole::Mean
75 } else {
76 TermRole::Generic
77 }
78 }
79}
80
81/// Stable structural identity of one term, used to match a parent term to a
82/// new-fit term across folds / row populations.
83#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
84pub struct TermIdentityKey(pub Fingerprint);
85
86/// Build a term identity at the *block-spec* layer (`fit_custom_family` and
87/// friends), where the full `BasisMetadata` / variable names are no longer
88/// reachable — the design has already been assembled into a
89/// `gam_problem::ParameterBlockSpec`.
90///
91/// The block `name` (e.g. `"s(x)"`, `"<scale>"`) is produced by the formula /
92/// construction layer and is **fold-invariant**: it encodes the variables and
93/// basis kind and does not change when rows are dropped for an LOSO fold. The
94/// penalty *structure* (count, precision labels, nullspace dimensions) is also
95/// fold-invariant in SHAPE — only the matrix values change across folds, and
96/// we hash only the structure, never the values. So this identity matches
97/// "same model, different rows" while splitting on a genuine structural change
98/// (a different #penalties, a different label set, a different basis size).
99///
100/// `reduced_width` is the realized per-block coefficient dimension
101/// (`spec.design.ncols()`) — the basis column count *after* the
102/// identifiability reduction, which is the load-bearing dimension of the
103/// block's β. It is fold-invariant within one model (LOSO drops rows, never
104/// columns) but DIFFERS across models whose spatial basis collapses to a lower
105/// effective support (e.g. a duchon marginal that realizes p=21 on one disease
106/// and p=45 on another). Folding it into the identity is what makes a p=37 fit
107/// refuse to match a p=85 artifact: without it, two models with the same block
108/// name / penalty-label / nullspace SHAPE but different realized β-width hash to
109/// the SAME [`TermIdentityKey`] (and hence the same [`FitDescriptor`] key),
110/// producing the spurious "cached inner beta has length 85, but blocks require
111/// length 37" lookups. With it, only fits whose per-block β actually live in the
112/// same-dimension coordinate system match — so the gauge β-projection is always
113/// well-posed and same-width folds transfer ρ AND β, while different-width
114/// models never collide.
115///
116/// NOTE (architect-assumption mismatch): the original design routed identity
117/// through `SmoothTerm.metadata`, but at this layer that metadata has already
118/// been compiled away. The block name + penalty structure + realized reduced
119/// width is the honest, fold-invariant identity available here.
120pub fn term_identity_from_block(
121 role: TermRole,
122 block_name: &str,
123 precision_labels: &[Option<String>],
124 nullspace_dims: &[usize],
125 reduced_width: usize,
126) -> TermIdentityKey {
127 let mut fp = Fingerprinter::new();
128 fp.absorb_tag(b"fit-artifact-block-identity-v2");
129 fp.absorb_u64(b"role", u64::from(role.discriminant()));
130 fp.absorb_str(b"block_name", block_name);
131 fp.absorb_u64(b"n_penalties", precision_labels.len() as u64);
132 for label in precision_labels {
133 match label {
134 Some(l) => fp.absorb_str(b"label", l),
135 None => fp.absorb_tag(b"label-none"),
136 }
137 }
138 fp.absorb_u64(b"n_nullspace", nullspace_dims.len() as u64);
139 for d in nullspace_dims {
140 fp.absorb_u64(b"nullspace_dim", *d as u64);
141 }
142 fp.absorb_u64(b"reduced_width", reduced_width as u64);
143 TermIdentityKey(fp.finalize())
144}
145
146/// Signature of the response (family + dimensionality) a fit targeted.
147/// Carried for diagnostics; deliberately NOT part of the descriptor key so
148/// an LOSO fold matches a full-data parent (only the structural term set
149/// keys the descriptor).
150#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
151pub struct ResponseSig {
152 pub family_kind: String,
153 pub n_response_channels: usize,
154}
155
156/// Tag describing which rows a fit saw. Carried for diagnostics only; the
157/// descriptor key excludes it so different row populations (folds) match.
158#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
159pub struct RowPopulationTag {
160 pub n_rows: usize,
161 /// Optional caller-supplied label (fold id, disease, …).
162 pub label: Option<String>,
163}
164
165/// Identity descriptor of a whole fit: which family, which structural terms,
166/// what response, optionally which rows. The descriptor *key*
167/// ([`FitDescriptor::descriptor_key`]) hashes only the family kind and the
168/// SORTED term identities — it excludes row population and response bytes —
169/// so an LOSO fold of the same model matches a prior full-data artifact.
170#[derive(Clone, Debug, Serialize, Deserialize)]
171pub struct FitDescriptor {
172 pub family_kind: String,
173 pub term_identities: Vec<TermIdentityKey>,
174 pub response_signature: ResponseSig,
175 pub row_population: Option<RowPopulationTag>,
176}
177
178impl FitDescriptor {
179 /// Stable descriptor key = hash(family_kind ⊕ sorted term identities),
180 /// EXCLUDING row population and response bytes. This is the keyspace an
181 /// LOSO fold and its full-data parent share.
182 pub fn descriptor_key(&self) -> Fingerprint {
183 let mut fp = Fingerprinter::new();
184 fp.absorb_tag(b"fit-artifact-descriptor-v1");
185 fp.absorb_str(b"family_kind", &self.family_kind);
186 // Sort the term identities so block ORDER does not split the key:
187 // the same model assembled in a different block order is the same
188 // descriptor.
189 let mut keys: Vec<[u8; 32]> = self
190 .term_identities
191 .iter()
192 .map(|k| *k.0.as_bytes())
193 .collect();
194 keys.sort_unstable();
195 fp.absorb_u64(b"n_terms", keys.len() as u64);
196 for k in &keys {
197 fp.absorb_bytes(b"term", k);
198 }
199 fp.finalize()
200 }
201}
202
203/// Per-term captured state. Stores RAW per-term β (lifted from the converged
204/// reduced θ via the fit's [`crate::gauge::Gauge`] at capture time —
205/// the identifiability transform T is fit-specific and meaningless in another
206/// fit, so we persist the gauge-free raw coefficients) plus the term's ρ
207/// slice for transfer.
208#[derive(Clone, Debug, Serialize, Deserialize)]
209pub struct TermArtifact {
210 pub identity: TermIdentityKey,
211 pub role: TermRole,
212 /// Serializable structural subset of the term's basis metadata.
213 /// `BasisMetadata` itself is not `Serialize` (it carries large
214 /// data-derived arrays), so we persist only the fields needed to
215 /// re-derive identity and reason about the basis at consume time.
216 pub basis_meta: SerializableBasisMeta,
217 /// Joint-null absorption rotation captured at fit time, if any. Stored as
218 /// a flat row-major matrix so the function-space β projection (Phase 2)
219 /// can replay it; `None` when the term carried no rotation.
220 pub joint_null_rotation: Option<SerializableMatrix>,
221 /// RAW per-term coefficients (post-gauge-lift, pre-identifiability),
222 /// concatenated in the term's raw column order.
223 pub raw_beta: Vec<f64>,
224 /// Converged ρ (log smoothing parameters) for this term's penalties.
225 pub rho_for_term: Vec<f64>,
226}
227
228impl TermArtifact {
229 /// True iff every persisted numeric field is finite (the consume-side
230 /// finite-guard precondition).
231 pub fn is_finite(&self) -> bool {
232 self.raw_beta.iter().all(|v| v.is_finite())
233 && self.rho_for_term.iter().all(|v| v.is_finite())
234 && self
235 .joint_null_rotation
236 .as_ref()
237 .is_none_or(|m| m.data.iter().all(|v| v.is_finite()))
238 }
239}
240
241/// A serializable row-major dense matrix snapshot.
242#[derive(Clone, Debug, Serialize, Deserialize)]
243pub struct SerializableMatrix {
244 pub nrows: usize,
245 pub ncols: usize,
246 pub data: Vec<f64>,
247}
248
249/// Serializable structural subset of a term's basis metadata. Captures the
250/// basis-kind discriminant and the structural parameters used for identity
251/// and for diagnostics. Data-derived arrays (centers, basis matrices) are
252/// intentionally dropped.
253#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
254pub struct SerializableBasisMeta {
255 pub kind: String,
256 pub degree: Option<u64>,
257 pub num_knots: Option<u64>,
258 pub n_centers: Option<u64>,
259 pub nullspace_order: Option<u64>,
260 pub matern_nu: Option<u64>,
261 pub periodic: bool,
262}
263
264/// Whole-fit summary numbers carried for selection / logging.
265#[derive(Clone, Debug, Serialize, Deserialize)]
266pub struct GlobalFitSummary {
267 pub outer_objective: f64,
268 pub converged: bool,
269 pub n_rows: usize,
270}
271
272/// Provenance of a per-term transfer, for logging and tests.
273#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
274pub enum TransferProvenance {
275 /// β was function-projected from the parent (Phase 2).
276 Projected,
277 /// Only ρ was transferred; β stayed cold (Phase 1).
278 RhoOnly,
279 /// Nothing transferred; both β and ρ are at their cold defaults.
280 Cold,
281}
282
283/// The full descriptor-indexed warm-start artifact.
284#[derive(Clone, Debug, Serialize, Deserialize)]
285pub struct FitArtifact {
286 pub schema: u32,
287 pub created_unix_secs: u64,
288 pub descriptor: FitDescriptor,
289 pub terms: Vec<TermArtifact>,
290 pub global: GlobalFitSummary,
291}
292
293impl FitArtifact {
294 /// True iff the artifact is structurally usable as warm-start material:
295 /// the schema matches, the global summary is finite, and every term's
296 /// numeric payload is finite. A failing artifact must be ignored (cold
297 /// fallback), never error a fit.
298 pub fn is_usable(&self) -> bool {
299 self.schema == FIT_ARTIFACT_SCHEMA
300 && self.global.outer_objective.is_finite()
301 && self.terms.iter().all(TermArtifact::is_finite)
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Block-layer term identity (the surviving, fold-invariant identity
    /// API) with one unlabeled penalty, nullspace dimension 1 and a fixed
    /// realized reduced width of 10.
    fn block_id(role: TermRole, block_name: &str) -> TermIdentityKey {
        term_identity_from_block(role, block_name, &[None], &[1], 10)
    }

    /// Minimal serializable basis-meta stub, as produced at the block-spec
    /// capture layer.
    fn basis_meta_stub(n_centers: u64) -> SerializableBasisMeta {
        SerializableBasisMeta {
            kind: "block-spec".to_string(),
            degree: None,
            num_knots: None,
            n_centers: Some(n_centers),
            nullspace_order: None,
            matern_nu: None,
            periodic: false,
        }
    }

    /// Single-channel gaussian descriptor over `terms`, optionally tagged
    /// with a row population.
    fn gaussian_descriptor(
        terms: Vec<TermIdentityKey>,
        row_population: Option<RowPopulationTag>,
    ) -> FitDescriptor {
        FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities: terms,
            response_signature: ResponseSig {
                family_kind: "gaussian".to_string(),
                n_response_channels: 1,
            },
            row_population,
        }
    }

    /// Schema-current, fully finite artifact with one four-coefficient term
    /// and no rotation.
    fn usable_artifact() -> FitArtifact {
        let id = block_id(TermRole::Mean, "s(x)");
        FitArtifact {
            schema: FIT_ARTIFACT_SCHEMA,
            created_unix_secs: 0,
            descriptor: gaussian_descriptor(vec![id], None),
            terms: vec![TermArtifact {
                identity: id,
                role: TermRole::Mean,
                basis_meta: basis_meta_stub(4),
                joint_null_rotation: None,
                raw_beta: vec![0.1, 0.2, 0.3, 0.4],
                rho_for_term: vec![1.0],
            }],
            global: GlobalFitSummary {
                outer_objective: -123.4,
                converged: true,
                n_rows: 100,
            },
        }
    }

    #[test]
    fn role_from_block_name_classifies_known_channels() {
        assert_eq!(TermRole::from_block_name("<scale>"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("logslope"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("LogSlope"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("mean"), TermRole::Mean);
        assert_eq!(TermRole::from_block_name("marginal"), TermRole::Mean);
        assert_eq!(TermRole::from_block_name("s(x)"), TermRole::Generic);
        // Log-slope markers are checked first, so they win over mean markers.
        assert_eq!(TermRole::from_block_name("mean_scale"), TermRole::LogSlope);
    }

    #[test]
    fn block_identity_splits_on_block_name() {
        let ka = block_id(TermRole::Mean, "s(x)");
        let kb = block_id(TermRole::Mean, "s(z)");
        assert_ne!(ka, kb, "different block name must split identity");
    }

    #[test]
    fn block_identity_splits_on_role() {
        let mean = block_id(TermRole::Mean, "s(x)");
        let slope = block_id(TermRole::LogSlope, "s(x)");
        assert_ne!(mean, slope, "different role must split identity");
    }

    #[test]
    fn block_identity_splits_on_penalty_structure() {
        let one = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 10);
        let two = term_identity_from_block(TermRole::Mean, "s(x)", &[None, None], &[1], 10);
        assert_ne!(one, two, "different #penalties must split identity");
    }

    #[test]
    fn block_identity_splits_on_labels_and_nullspace() {
        let base = block_id(TermRole::Mean, "s(x)");
        let labeled = term_identity_from_block(
            TermRole::Mean,
            "s(x)",
            &[Some("tau".to_string())],
            &[1],
            10,
        );
        assert_ne!(base, labeled, "a labeled penalty must split identity");

        let other_nullspace =
            term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[2], 10);
        assert_ne!(
            base, other_nullspace,
            "different nullspace dimension must split identity"
        );
    }

    #[test]
    fn block_identity_splits_on_reduced_width() {
        // The biobank LOSO collision: two models with identical block name /
        // penalty / nullspace SHAPE but a different realized per-block β width
        // (p=45 marginal vs the collapsed p=21) MUST hash to distinct
        // identities, so a p=37 fit never matches a p=85 artifact.
        let wide = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 45);
        let narrow = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 21);
        assert_ne!(
            wide, narrow,
            "different realized reduced width must split identity"
        );
    }

    #[test]
    fn block_identity_matches_across_folds_at_equal_width() {
        // The marquee LOSO win: same model, same realized width, different rows
        // -> identical identity, so ρ and the gauge β-projection both transfer.
        let fold_a = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 45);
        let fold_b = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 45);
        assert_eq!(
            fold_a, fold_b,
            "same model at equal width must share identity across folds"
        );
    }

    #[test]
    fn descriptor_key_excludes_rows_and_response() {
        let id = block_id(TermRole::Mean, "s(x)");
        let full = gaussian_descriptor(
            vec![id],
            Some(RowPopulationTag {
                n_rows: 1000,
                label: Some("full".to_string()),
            }),
        );
        let fold = gaussian_descriptor(
            vec![id],
            Some(RowPopulationTag {
                n_rows: 900, // an LOSO fold dropped 100 rows
                label: Some("fold-3".to_string()),
            }),
        );
        assert_eq!(
            full.descriptor_key(),
            fold.descriptor_key(),
            "LOSO fold must share its full-data parent's descriptor key"
        );
    }

    #[test]
    fn descriptor_key_invariant_to_term_order() {
        let a = block_id(TermRole::Mean, "s(x)");
        let b = block_id(TermRole::Mean, "s(z)");
        let d1 = gaussian_descriptor(vec![a, b], None);
        let d2 = gaussian_descriptor(vec![b, a], None);
        assert_eq!(d1.descriptor_key(), d2.descriptor_key());
    }

    #[test]
    fn artifact_usable_guard_rejects_nonfinite() {
        let mut artifact = usable_artifact();
        assert!(artifact.is_usable());

        artifact.terms[0].raw_beta[2] = f64::NAN;
        assert!(
            !artifact.is_usable(),
            "non-finite β must fail the usable guard"
        );

        artifact.terms[0].raw_beta[2] = 0.3;
        artifact.global.outer_objective = f64::INFINITY;
        assert!(
            !artifact.is_usable(),
            "non-finite objective must fail the usable guard"
        );
    }

    #[test]
    fn artifact_usable_guard_rejects_nonfinite_rho_and_rotation() {
        let mut bad_rho = usable_artifact();
        bad_rho.terms[0].rho_for_term[0] = f64::NEG_INFINITY;
        assert!(
            !bad_rho.is_usable(),
            "non-finite ρ must fail the usable guard"
        );

        let mut good_rotation = usable_artifact();
        good_rotation.terms[0].joint_null_rotation = Some(SerializableMatrix {
            nrows: 2,
            ncols: 2,
            data: vec![1.0, 0.0, 0.0, 1.0],
        });
        assert!(
            good_rotation.is_usable(),
            "a finite rotation must pass the usable guard"
        );

        let mut bad_rotation = usable_artifact();
        bad_rotation.terms[0].joint_null_rotation = Some(SerializableMatrix {
            nrows: 1,
            ncols: 1,
            data: vec![f64::NAN],
        });
        assert!(
            !bad_rotation.is_usable(),
            "non-finite rotation entry must fail the usable guard"
        );
    }

    #[test]
    fn artifact_usable_guard_rejects_schema_mismatch() {
        let mut artifact = usable_artifact();
        artifact.schema = FIT_ARTIFACT_SCHEMA + 1;
        assert!(
            !artifact.is_usable(),
            "a payload from another schema version must fail the usable guard"
        );
    }

    #[test]
    fn serializable_basis_meta_roundtrips() {
        let meta = basis_meta_stub(7);
        let bytes = serde_json::to_vec(&meta).expect("serialize");
        let back: SerializableBasisMeta = serde_json::from_slice(&bytes).expect("deserialize");
        assert_eq!(meta, back);
        assert_eq!(back.n_centers, Some(7));
        assert_eq!(back.kind, "block-spec");
    }

    #[test]
    fn serializable_matrix_can_carry_rotation() {
        let q = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0])
            .expect("a 2x2 shape matches four entries");
        let m = SerializableMatrix {
            nrows: q.nrows(),
            ncols: q.ncols(),
            data: q.iter().copied().collect(),
        };
        let bytes = serde_json::to_vec(&m).expect("serialize");
        let back: SerializableMatrix = serde_json::from_slice(&bytes).expect("deserialize");
        assert_eq!(back.nrows, 2);
        assert_eq!(back.ncols, 2);
        assert_eq!(back.data, vec![1.0, 0.0, 0.0, 1.0]);
    }
}