Skip to main content

gam_solve/
warm_start_artifact.rs

1//! Cross-fit warm-start artifact: a descriptor-indexed, function-space
2//! snapshot of a converged fit, designed so a *related* later fit (a
3//! leave-one-subject-out fold, a re-fit on a different row population, a
4//! different reduced width) can warm-start from it even though the exact
5//! response-keyed inner cache (`persistent_warm_start.rs`) misses.
6//!
7//! The artifact is keyed by *structural identity*, not by data bytes. Two
8//! fits of the same term family (same role, same variables, same basis kind
9//! and the same STRUCTURAL basis parameters — degree, #centers, nullspace
10//! order, …) map to the same [`TermIdentityKey`] even when their realized
11//! `centers` / `input_scales` / `length_scale` differ across folds. That is
12//! precisely what lets the smoothing parameter ρ transfer survive a fold:
13//! "same term, different rows" matches; "3 PCs vs 10 PCs" or "different
14//! #centers" deliberately does NOT.
15//!
//! Correctness is unaffected by warm starting: a warm start only sets the *starting iterate*; the
17//! outer REML/BFGS loop and the inner constrained Newton solve still run to
18//! their KKT certificate, so the converged answer is identical to a cold
19//! start within tolerance. Every field that flows back into the solver is
20//! finite-guarded at consume time; any anomaly falls back to cold.
21
22use gam_runtime::warm_start::key::{Fingerprint, Fingerprinter};
23use serde::{Deserialize, Serialize};
24
/// Serialization schema version stamped into every [`FitArtifact`].
///
/// Increment it whenever the persisted layout changes such that payloads
/// written under an older version are no longer safe to consume.
pub const FIT_ARTIFACT_SCHEMA: u32 = 1_u32;
28
/// Magnitude beyond which a persisted ρ coordinate is treated as pinned to
/// the outer optimizer's box bound; such coordinates are NOT transferred.
///
/// Kept in step with the persist-side gate in
/// `families/custom_family/persistent_warm_start.rs` and with the
/// `[CACHE] hit-clamp` policy in `solver/outer_strategy.rs`.
pub(crate) const RHO_SATURATION: f64 = 9.0_f64;
34
35/// Structural role a term plays in the (possibly multi-channel) model.
36///
37/// Derived from the block name / channel at capture time. The role is part
38/// of the term identity so a "mean" smooth never transfers ρ to a
39/// "log-slope" smooth of the same variables.
40#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
41pub enum TermRole {
42    /// Location / mean channel (the default for a single-channel family).
43    Mean,
44    /// Log-scale / dispersion / log-slope channel.
45    LogSlope,
46    /// Any other channel (multinomial categories, frailty, …).
47    Generic,
48}
49
50impl TermRole {
51    /// Stable discriminant byte for hashing.
52    fn discriminant(self) -> u8 {
53        match self {
54            TermRole::Mean => 0,
55            TermRole::LogSlope => 1,
56            TermRole::Generic => 2,
57        }
58    }
59
60    /// Heuristic role from a block / channel name. Names are produced by the
61    /// family construction layer (e.g. `"<scale>"`, `"logslope"`, `"mean"`);
62    /// the classification is structural and deliberately coarse.
63    pub fn from_block_name(name: &str) -> TermRole {
64        let lower = name.to_ascii_lowercase();
65        if lower.contains("logslope")
66            || lower.contains("log_slope")
67            || lower.contains("scale")
68            || lower.contains("sigma")
69            || lower.contains("dispersion")
70            || lower.contains("disp")
71        {
72            TermRole::LogSlope
73        } else if lower.contains("mean") || lower.contains("loc") || lower.contains("marginal") {
74            TermRole::Mean
75        } else {
76            TermRole::Generic
77        }
78    }
79}
80
81/// Stable structural identity of one term, used to match a parent term to a
82/// new-fit term across folds / row populations.
83#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
84pub struct TermIdentityKey(pub Fingerprint);
85
86/// Build a term identity at the *block-spec* layer (`fit_custom_family` and
87/// friends), where the full `BasisMetadata` / variable names are no longer
88/// reachable — the design has already been assembled into a
89/// `gam_problem::ParameterBlockSpec`.
90///
91/// The block `name` (e.g. `"s(x)"`, `"<scale>"`) is produced by the formula /
92/// construction layer and is **fold-invariant**: it encodes the variables and
93/// basis kind and does not change when rows are dropped for an LOSO fold. The
94/// penalty *structure* (count, precision labels, nullspace dimensions) is also
95/// fold-invariant in SHAPE — only the matrix values change across folds, and
96/// we hash only the structure, never the values. So this identity matches
97/// "same model, different rows" while splitting on a genuine structural change
98/// (a different #penalties, a different label set, a different basis size).
99///
100/// `reduced_width` is the realized per-block coefficient dimension
101/// (`spec.design.ncols()`) — the basis column count *after* the
102/// identifiability reduction, which is the load-bearing dimension of the
103/// block's β. It is fold-invariant within one model (LOSO drops rows, never
104/// columns) but DIFFERS across models whose spatial basis collapses to a lower
105/// effective support (e.g. a duchon marginal that realizes p=21 on one disease
106/// and p=45 on another). Folding it into the identity is what makes a p=37 fit
107/// refuse to match a p=85 artifact: without it, two models with the same block
108/// name / penalty-label / nullspace SHAPE but different realized β-width hash to
109/// the SAME [`TermIdentityKey`] (and hence the same [`FitDescriptor`] key),
110/// producing the spurious "cached inner beta has length 85, but blocks require
111/// length 37" lookups. With it, only fits whose per-block β actually live in the
112/// same-dimension coordinate system match — so the gauge β-projection is always
113/// well-posed and same-width folds transfer ρ AND β, while different-width
114/// models never collide.
115///
116/// NOTE (architect-assumption mismatch): the original design routed identity
117/// through `SmoothTerm.metadata`, but at this layer that metadata has already
118/// been compiled away. The block name + penalty structure + realized reduced
119/// width is the honest, fold-invariant identity available here.
120pub fn term_identity_from_block(
121    role: TermRole,
122    block_name: &str,
123    precision_labels: &[Option<String>],
124    nullspace_dims: &[usize],
125    reduced_width: usize,
126) -> TermIdentityKey {
127    let mut fp = Fingerprinter::new();
128    fp.absorb_tag(b"fit-artifact-block-identity-v2");
129    fp.absorb_u64(b"role", u64::from(role.discriminant()));
130    fp.absorb_str(b"block_name", block_name);
131    fp.absorb_u64(b"n_penalties", precision_labels.len() as u64);
132    for label in precision_labels {
133        match label {
134            Some(l) => fp.absorb_str(b"label", l),
135            None => fp.absorb_tag(b"label-none"),
136        }
137    }
138    fp.absorb_u64(b"n_nullspace", nullspace_dims.len() as u64);
139    for d in nullspace_dims {
140        fp.absorb_u64(b"nullspace_dim", *d as u64);
141    }
142    fp.absorb_u64(b"reduced_width", reduced_width as u64);
143    TermIdentityKey(fp.finalize())
144}
145
146/// Signature of the response (family + dimensionality) a fit targeted.
147/// Carried for diagnostics; deliberately NOT part of the descriptor key so
148/// an LOSO fold matches a full-data parent (only the structural term set
149/// keys the descriptor).
150#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
151pub struct ResponseSig {
152    pub family_kind: String,
153    pub n_response_channels: usize,
154}
155
156/// Tag describing which rows a fit saw. Carried for diagnostics only; the
157/// descriptor key excludes it so different row populations (folds) match.
158#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
159pub struct RowPopulationTag {
160    pub n_rows: usize,
161    /// Optional caller-supplied label (fold id, disease, …).
162    pub label: Option<String>,
163}
164
165/// Identity descriptor of a whole fit: which family, which structural terms,
166/// what response, optionally which rows. The descriptor *key*
167/// ([`FitDescriptor::descriptor_key`]) hashes only the family kind and the
168/// SORTED term identities — it excludes row population and response bytes —
169/// so an LOSO fold of the same model matches a prior full-data artifact.
170#[derive(Clone, Debug, Serialize, Deserialize)]
171pub struct FitDescriptor {
172    pub family_kind: String,
173    pub term_identities: Vec<TermIdentityKey>,
174    pub response_signature: ResponseSig,
175    pub row_population: Option<RowPopulationTag>,
176}
177
178impl FitDescriptor {
179    /// Stable descriptor key = hash(family_kind ⊕ sorted term identities),
180    /// EXCLUDING row population and response bytes. This is the keyspace an
181    /// LOSO fold and its full-data parent share.
182    pub fn descriptor_key(&self) -> Fingerprint {
183        let mut fp = Fingerprinter::new();
184        fp.absorb_tag(b"fit-artifact-descriptor-v1");
185        fp.absorb_str(b"family_kind", &self.family_kind);
186        // Sort the term identities so block ORDER does not split the key:
187        // the same model assembled in a different block order is the same
188        // descriptor.
189        let mut keys: Vec<[u8; 32]> = self
190            .term_identities
191            .iter()
192            .map(|k| *k.0.as_bytes())
193            .collect();
194        keys.sort_unstable();
195        fp.absorb_u64(b"n_terms", keys.len() as u64);
196        for k in &keys {
197            fp.absorb_bytes(b"term", k);
198        }
199        fp.finalize()
200    }
201}
202
203/// Per-term captured state. Stores RAW per-term β (lifted from the converged
204/// reduced θ via the fit's [`crate::gauge::Gauge`] at capture time —
205/// the identifiability transform T is fit-specific and meaningless in another
206/// fit, so we persist the gauge-free raw coefficients) plus the term's ρ
207/// slice for transfer.
208#[derive(Clone, Debug, Serialize, Deserialize)]
209pub struct TermArtifact {
210    pub identity: TermIdentityKey,
211    pub role: TermRole,
212    /// Serializable structural subset of the term's basis metadata.
213    /// `BasisMetadata` itself is not `Serialize` (it carries large
214    /// data-derived arrays), so we persist only the fields needed to
215    /// re-derive identity and reason about the basis at consume time.
216    pub basis_meta: SerializableBasisMeta,
217    /// Joint-null absorption rotation captured at fit time, if any. Stored as
218    /// a flat row-major matrix so the function-space β projection (Phase 2)
219    /// can replay it; `None` when the term carried no rotation.
220    pub joint_null_rotation: Option<SerializableMatrix>,
221    /// RAW per-term coefficients (post-gauge-lift, pre-identifiability),
222    /// concatenated in the term's raw column order.
223    pub raw_beta: Vec<f64>,
224    /// Converged ρ (log smoothing parameters) for this term's penalties.
225    pub rho_for_term: Vec<f64>,
226}
227
228impl TermArtifact {
229    /// True iff every persisted numeric field is finite (the consume-side
230    /// finite-guard precondition).
231    pub fn is_finite(&self) -> bool {
232        self.raw_beta.iter().all(|v| v.is_finite())
233            && self.rho_for_term.iter().all(|v| v.is_finite())
234            && self
235                .joint_null_rotation
236                .as_ref()
237                .is_none_or(|m| m.data.iter().all(|v| v.is_finite()))
238    }
239}
240
241/// A serializable row-major dense matrix snapshot.
242#[derive(Clone, Debug, Serialize, Deserialize)]
243pub struct SerializableMatrix {
244    pub nrows: usize,
245    pub ncols: usize,
246    pub data: Vec<f64>,
247}
248
249/// Serializable structural subset of a term's basis metadata. Captures the
250/// basis-kind discriminant and the structural parameters used for identity
251/// and for diagnostics. Data-derived arrays (centers, basis matrices) are
252/// intentionally dropped.
253#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
254pub struct SerializableBasisMeta {
255    pub kind: String,
256    pub degree: Option<u64>,
257    pub num_knots: Option<u64>,
258    pub n_centers: Option<u64>,
259    pub nullspace_order: Option<u64>,
260    pub matern_nu: Option<u64>,
261    pub periodic: bool,
262}
263
264/// Whole-fit summary numbers carried for selection / logging.
265#[derive(Clone, Debug, Serialize, Deserialize)]
266pub struct GlobalFitSummary {
267    pub outer_objective: f64,
268    pub converged: bool,
269    pub n_rows: usize,
270}
271
272/// Provenance of a per-term transfer, for logging and tests.
273#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
274pub enum TransferProvenance {
275    /// β was function-projected from the parent (Phase 2).
276    Projected,
277    /// Only ρ was transferred; β stayed cold (Phase 1).
278    RhoOnly,
279    /// Nothing transferred; both β and ρ are at their cold defaults.
280    Cold,
281}
282
283/// The full descriptor-indexed warm-start artifact.
284#[derive(Clone, Debug, Serialize, Deserialize)]
285pub struct FitArtifact {
286    pub schema: u32,
287    pub created_unix_secs: u64,
288    pub descriptor: FitDescriptor,
289    pub terms: Vec<TermArtifact>,
290    pub global: GlobalFitSummary,
291}
292
293impl FitArtifact {
294    /// True iff the artifact is structurally usable as warm-start material:
295    /// the schema matches, the global summary is finite, and every term's
296    /// numeric payload is finite. A failing artifact must be ignored (cold
297    /// fallback), never error a fit.
298    pub fn is_usable(&self) -> bool {
299        self.schema == FIT_ARTIFACT_SCHEMA
300            && self.global.outer_objective.is_finite()
301            && self.terms.iter().all(TermArtifact::is_finite)
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Build a block-layer term identity (the surviving, fold-invariant
    /// identity API) with one unlabeled penalty, nullspace dim 1, and a fixed
    /// realized reduced width of 10.
    fn block_id(role: TermRole, block_name: &str) -> TermIdentityKey {
        term_identity_from_block(role, block_name, &[None], &[1], 10)
    }

    /// A minimal serializable basis-meta stub, as produced at the block-spec
    /// capture layer.
    fn basis_meta_stub(n_centers: u64) -> SerializableBasisMeta {
        SerializableBasisMeta {
            kind: "block-spec".to_string(),
            degree: None,
            num_knots: None,
            n_centers: Some(n_centers),
            nullspace_order: None,
            matern_nu: None,
            periodic: false,
        }
    }

    /// Single-channel gaussian response signature shared by the fixtures.
    fn gaussian_sig() -> ResponseSig {
        ResponseSig {
            family_kind: "gaussian".to_string(),
            n_response_channels: 1,
        }
    }

    /// A finite, current-schema, one-term artifact for which `is_usable()`
    /// holds.
    fn usable_artifact() -> FitArtifact {
        let id = block_id(TermRole::Mean, "s(x)");
        FitArtifact {
            schema: FIT_ARTIFACT_SCHEMA,
            created_unix_secs: 0,
            descriptor: FitDescriptor {
                family_kind: "gaussian".to_string(),
                term_identities: vec![id],
                response_signature: gaussian_sig(),
                row_population: None,
            },
            terms: vec![TermArtifact {
                identity: id,
                role: TermRole::Mean,
                basis_meta: basis_meta_stub(4),
                joint_null_rotation: None,
                raw_beta: vec![0.1, 0.2, 0.3, 0.4],
                rho_for_term: vec![1.0],
            }],
            global: GlobalFitSummary {
                outer_objective: -123.4,
                converged: true,
                n_rows: 100,
            },
        }
    }

    #[test]
    fn role_from_block_name_classifies_by_marker() {
        assert_eq!(TermRole::from_block_name("<scale>"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("LogSlope"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("log_slope"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("sigma"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("dispersion"), TermRole::LogSlope);
        assert_eq!(TermRole::from_block_name("mean"), TermRole::Mean);
        assert_eq!(TermRole::from_block_name("Location"), TermRole::Mean);
        assert_eq!(TermRole::from_block_name("marginal"), TermRole::Mean);
        assert_eq!(TermRole::from_block_name("s(x)"), TermRole::Generic);
        assert_eq!(TermRole::from_block_name("frailty"), TermRole::Generic);
    }

    #[test]
    fn role_from_block_name_checks_log_slope_before_mean() {
        // Both marker families present: the log-slope check runs first.
        assert_eq!(TermRole::from_block_name("mean_scale"), TermRole::LogSlope);
    }

    #[test]
    fn role_discriminants_are_stable() {
        // These bytes feed persisted fingerprints; renumbering them would
        // silently change every stored identity.
        assert_eq!(TermRole::Mean.discriminant(), 0);
        assert_eq!(TermRole::LogSlope.discriminant(), 1);
        assert_eq!(TermRole::Generic.discriminant(), 2);
    }

    #[test]
    fn block_identity_splits_on_block_name() {
        let ka = block_id(TermRole::Mean, "s(x)");
        let kb = block_id(TermRole::Mean, "s(z)");
        assert_ne!(ka, kb, "different block name must split identity");
    }

    #[test]
    fn block_identity_splits_on_role() {
        let mean = block_id(TermRole::Mean, "s(x)");
        let slope = block_id(TermRole::LogSlope, "s(x)");
        assert_ne!(mean, slope, "different role must split identity");
    }

    #[test]
    fn block_identity_splits_on_penalty_structure() {
        let one = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 10);
        let two = term_identity_from_block(TermRole::Mean, "s(x)", &[None, None], &[1], 10);
        assert_ne!(one, two, "different #penalties must split identity");
    }

    #[test]
    fn block_identity_splits_on_precision_label() {
        let unlabeled = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 10);
        let labeled = term_identity_from_block(
            TermRole::Mean,
            "s(x)",
            &[Some("wiggle".to_string())],
            &[1],
            10,
        );
        assert_ne!(unlabeled, labeled, "a penalty label must split identity");
    }

    #[test]
    fn block_identity_splits_on_nullspace_dims() {
        let dim_one = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 10);
        let dim_two = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[2], 10);
        assert_ne!(dim_one, dim_two, "different nullspace dim must split identity");
    }

    #[test]
    fn block_identity_splits_on_reduced_width() {
        // The biobank LOSO collision: two models with identical block name /
        // penalty / nullspace SHAPE but a different realized per-block β width
        // (p=45 marginal vs the collapsed p=21) MUST hash to distinct
        // identities, so a p=37 fit never matches a p=85 artifact.
        let wide = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 45);
        let narrow = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 21);
        assert_ne!(
            wide, narrow,
            "different realized reduced width must split identity"
        );
    }

    #[test]
    fn block_identity_matches_across_folds_at_equal_width() {
        // The marquee LOSO win: same model, same realized width, different rows
        // -> identical identity, so ρ and the gauge β-projection both transfer.
        let fold_a = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 45);
        let fold_b = term_identity_from_block(TermRole::Mean, "s(x)", &[None], &[1], 45);
        assert_eq!(
            fold_a, fold_b,
            "same model at equal width must share identity across folds"
        );
    }

    #[test]
    fn descriptor_key_excludes_rows_and_response() {
        let id = block_id(TermRole::Mean, "s(x)");
        let full = FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities: vec![id],
            response_signature: gaussian_sig(),
            row_population: Some(RowPopulationTag {
                n_rows: 1000,
                label: Some("full".to_string()),
            }),
        };
        let fold = FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities: vec![id],
            response_signature: gaussian_sig(),
            row_population: Some(RowPopulationTag {
                n_rows: 900, // an LOSO fold dropped 100 rows
                label: Some("fold-3".to_string()),
            }),
        };
        assert_eq!(
            full.descriptor_key(),
            fold.descriptor_key(),
            "LOSO fold must share its full-data parent's descriptor key"
        );
    }

    #[test]
    fn descriptor_key_invariant_to_term_order() {
        let a = block_id(TermRole::Mean, "s(x)");
        let b = block_id(TermRole::Mean, "s(z)");
        let d1 = FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities: vec![a, b],
            response_signature: gaussian_sig(),
            row_population: None,
        };
        let d2 = FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities: vec![b, a],
            response_signature: gaussian_sig(),
            row_population: None,
        };
        assert_eq!(d1.descriptor_key(), d2.descriptor_key());
    }

    #[test]
    fn descriptor_key_splits_on_family_kind() {
        let id = block_id(TermRole::Mean, "s(x)");
        let gaussian = FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities: vec![id],
            response_signature: gaussian_sig(),
            row_population: None,
        };
        let binomial = FitDescriptor {
            family_kind: "binomial".to_string(),
            ..gaussian.clone()
        };
        assert_ne!(
            gaussian.descriptor_key(),
            binomial.descriptor_key(),
            "different family must split the descriptor key"
        );
    }

    #[test]
    fn artifact_usable_guard_rejects_nonfinite() {
        let mut artifact = usable_artifact();
        assert!(artifact.is_usable());
        artifact.terms[0].raw_beta[2] = f64::NAN;
        assert!(
            !artifact.is_usable(),
            "non-finite β must fail the usable guard"
        );

        artifact.terms[0].raw_beta[2] = 0.3;
        artifact.global.outer_objective = f64::INFINITY;
        assert!(
            !artifact.is_usable(),
            "non-finite objective must fail the usable guard"
        );
    }

    #[test]
    fn artifact_usable_guard_rejects_nonfinite_rho() {
        let mut artifact = usable_artifact();
        artifact.terms[0].rho_for_term[0] = f64::NEG_INFINITY;
        assert!(
            !artifact.is_usable(),
            "non-finite ρ must fail the usable guard"
        );
    }

    #[test]
    fn artifact_usable_guard_checks_rotation_finiteness() {
        let mut artifact = usable_artifact();
        artifact.terms[0].joint_null_rotation = Some(SerializableMatrix {
            nrows: 2,
            ncols: 2,
            data: vec![1.0, 0.0, 0.0, 1.0],
        });
        assert!(
            artifact.is_usable(),
            "a finite, well-shaped rotation must not fail the guard"
        );

        artifact.terms[0]
            .joint_null_rotation
            .as_mut()
            .expect("rotation was just set")
            .data[1] = f64::NAN;
        assert!(
            !artifact.is_usable(),
            "a non-finite rotation entry must fail the usable guard"
        );
    }

    #[test]
    fn artifact_usable_guard_rejects_schema_mismatch() {
        let mut artifact = usable_artifact();
        artifact.schema = FIT_ARTIFACT_SCHEMA + 1;
        assert!(
            !artifact.is_usable(),
            "an artifact from another schema must fail the usable guard"
        );
    }

    #[test]
    fn fit_artifact_json_roundtrip_preserves_key_and_usability() {
        let artifact = usable_artifact();
        let bytes = serde_json::to_vec(&artifact).expect("serialize");
        let back: FitArtifact = serde_json::from_slice(&bytes).expect("deserialize");
        assert!(back.is_usable());
        assert_eq!(
            back.descriptor.descriptor_key(),
            artifact.descriptor.descriptor_key()
        );
        assert_eq!(back.terms[0].identity, artifact.terms[0].identity);
        assert_eq!(back.terms[0].role, TermRole::Mean);
        assert_eq!(back.terms[0].raw_beta.len(), 4);
    }

    #[test]
    fn serializable_basis_meta_roundtrips() {
        let meta = basis_meta_stub(7);
        let bytes = serde_json::to_vec(&meta).expect("serialize");
        let back: SerializableBasisMeta = serde_json::from_slice(&bytes).expect("deserialize");
        assert_eq!(meta, back);
        assert_eq!(back.n_centers, Some(7));
        assert_eq!(back.kind, "block-spec");
    }

    #[test]
    fn serializable_matrix_can_carry_rotation() {
        let q = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let m = SerializableMatrix {
            nrows: q.nrows(),
            ncols: q.ncols(),
            data: q.iter().copied().collect(),
        };
        let bytes = serde_json::to_vec(&m).expect("serialize");
        let back: SerializableMatrix = serde_json::from_slice(&bytes).expect("deserialize");
        assert_eq!(back.nrows, 2);
        assert_eq!(back.data, vec![1.0, 0.0, 0.0, 1.0]);
    }
}