Skip to main content

gam_problem/
types.rs

1use ndarray::{Array1, ArrayView1};
2use serde::{Deserialize, Serialize};
3use std::ops::{Deref, DerefMut};
4
5pub use gam_linalg::{RidgeDeterminantMode, RidgePolicy};
6
/// Lower floor on positive working weights.
///
/// Shared by the likelihood families and PIRLS row assembly so that the
/// weighted normal equations stay numerically well posed.
pub const MIN_WEIGHT: f64 = 1e-12;
10
11pub use gam_spec::*;
12
13/// Storage form of the ridge penalty matrix.
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15pub enum RidgeMatrixForm {
16    /// Ridge matrix is `delta * I`.
17    ScaledIdentity,
18}
19
20/// Concrete ridge metadata stamped into a fitted PIRLS result.
21#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
22pub struct RidgePassport {
23    /// Stabilization magnitude for matrix form `delta * I`.
24    pub delta: f64,
25    pub matrix_form: RidgeMatrixForm,
26    pub policy: RidgePolicy,
27}
28
29impl RidgePassport {
30    pub const fn scaled_identity(delta: f64, policy: RidgePolicy) -> Self {
31        Self {
32            delta,
33            matrix_form: RidgeMatrixForm::ScaledIdentity,
34            policy,
35        }
36    }
37
38    #[inline]
39    pub const fn penalty_logdet_ridge(self) -> f64 {
40        if self.policy.include_penalty_logdet {
41            self.delta
42        } else {
43            0.0
44        }
45    }
46
47    #[inline]
48    pub const fn laplacehessianridge(self) -> f64 {
49        if self.policy.include_laplacehessian {
50            self.delta
51        } else {
52            0.0
53        }
54    }
55}
56
// ============================================================================
// StabilizationLedger: canonical accounting for every fixed/heuristic ridge
// added anywhere in the solver, linear-algebra, or family code paths.
//
// Three semantically distinct ridge uses must NEVER be conflated:
//   1. SolverDampingOnly      — Levenberg/trust-region damping; never enters
//                               objective, gradient, logdet, Hessian, or any
//                               saved/serialized model artifact.
//   2. NumericalPerturbation  — added strictly so a linear solve is well-
//                               posed (e.g. Cholesky of a near-singular
//                               matrix). Carries an optional backward-error
//                               bound. Does NOT change the objective.
//   3. ExplicitPrior          — model-level `delta * I` (or block-diagonal)
//                               prior. Appears in quadratic, log normalizer,
//                               Laplace Hessian, serialization, diagnostics.
//
// `RidgePassport` above already encodes the inclusion-flag matrix for the
// PIRLS Laplace ridge specifically; this ledger is the broader sibling that
// every other call site (RidgePlanner, matrix_inverse_with_regularization,
// LAML rho-Hessian inversion, survival stabilization, custom-family
// `ridge_floor`) routes through, so a downstream consumer can ask
// `ledger.quadratic_delta()` rather than rediscovering the policy. The three
// inclusion bits were lifted into the `StabilizationKind` discriminant so the
// (kind, inclusion-flags) invariant is enforced statically — heterogeneous
// combinations like "ExplicitPrior with quadratic excluded" no longer typecheck.
// ============================================================================
83
84/// Inertia of a symmetric matrix (count of positive / zero / negative
85/// eigenvalues). Used by `bump_with_matrix` and other indefinite-aware
86/// stabilization rules to drive δ from spectral evidence rather than a
87/// condition-number heuristic.
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
89pub struct Inertia {
90    pub positive: usize,
91    pub zero: usize,
92    pub negative: usize,
93}
94
95/// Why a stabilization δ was chosen at this site.
96#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
97pub enum StabilizationRule {
98    /// δ is a hard-coded constant in the source.
99    FixedConstant,
100    /// δ chosen so the SPD floor τ is met: δ = max(0, τ - λ_min(H)).
101    InertiaTarget { spd_floor: f64 },
102    /// δ chosen via a condition-number / sqrt-ratio heuristic.
103    Heuristic,
104    /// User- or family-specified prior precision.
105    UserSpecified,
106    /// δ derived from a back-off escalation after a factorization failure.
107    BackoffEscalation { attempts: usize },
108}
109
110/// Three semantically distinct flavours a ridge δ can have.
111#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
112pub enum StabilizationKind {
113    None,
114    /// LM/TR damping. NEVER enters the objective, gradient, logdet, Hessian,
115    /// or any saved model artifact. Lives only inside the trust-region step.
116    SolverDampingOnly,
117    /// Added strictly so a linear solve succeeds. The objective/Hessian the
118    /// caller sees is unchanged; the perturbation is a property of the
119    /// solver, not the model. `backward_error_bound` is the max change to
120    /// the solution norm imputable to the perturbation, when known.
121    NumericalPerturbation {
122        backward_error_bound: Option<f64>,
123    },
124    /// Part of the model. Enters quadratic, log normalizer, Hessian,
125    /// serialization, and user-visible summaries.
126    ExplicitPrior,
127}
128
129/// Canonical record of a single stabilization δ applied at a single site.
130///
131/// Construct via the helper constructors (`solver_damping`,
132/// `numerical_perturbation`, `explicit_prior`) so the `included_in_*`
133/// invariants are guaranteed to match `kind`. Direct field construction is
134/// public for serialization round-trips only.
135#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
136pub struct StabilizationLedger {
137    pub kind: StabilizationKind,
138    pub delta: f64,
139    pub matrix_form: RidgeMatrixForm,
140    pub chosen_by: StabilizationRule,
141    pub inertia_before: Option<Inertia>,
142    pub inertia_after: Option<Inertia>,
143}
144
145impl StabilizationLedger {
146    /// "No stabilization applied at this site" sentinel.
147    pub const fn none() -> Self {
148        Self {
149            kind: StabilizationKind::None,
150            delta: 0.0,
151            matrix_form: RidgeMatrixForm::ScaledIdentity,
152            chosen_by: StabilizationRule::FixedConstant,
153            inertia_before: None,
154            inertia_after: None,
155        }
156    }
157
158    /// LM/TR damping. δ is invisible to the objective, gradient, and any
159    /// saved artifact. Asserting this invariant at every read site is the
160    /// whole reason the ledger exists.
161    pub const fn solver_damping(delta: f64, chosen_by: StabilizationRule) -> Self {
162        Self {
163            kind: StabilizationKind::SolverDampingOnly,
164            delta,
165            matrix_form: RidgeMatrixForm::ScaledIdentity,
166            chosen_by,
167            inertia_before: None,
168            inertia_after: None,
169        }
170    }
171
172    /// Solver-only perturbation that leaves the objective unchanged. The
173    /// caller may attach a backward-error bound when one is available
174    /// (e.g. from iterative refinement / Wilkinson-style analysis).
175    pub const fn numerical_perturbation(
176        delta: f64,
177        chosen_by: StabilizationRule,
178        backward_error_bound: Option<f64>,
179    ) -> Self {
180        Self {
181            kind: StabilizationKind::NumericalPerturbation {
182                backward_error_bound,
183            },
184            delta,
185            matrix_form: RidgeMatrixForm::ScaledIdentity,
186            chosen_by,
187            inertia_before: None,
188            inertia_after: None,
189        }
190    }
191
192    /// Model-level explicit prior. δ enters every accounting pass: the
193    /// quadratic penalty, the Laplace Hessian, the penalty log-determinant,
194    /// and serialization.
195    pub const fn explicit_prior(delta: f64, matrix_form: RidgeMatrixForm) -> Self {
196        Self {
197            kind: StabilizationKind::ExplicitPrior,
198            delta,
199            matrix_form,
200            chosen_by: StabilizationRule::UserSpecified,
201            inertia_before: None,
202            inertia_after: None,
203        }
204    }
205
206    /// Bridge from the existing `RidgePassport` so PIRLS-side code (which
207    /// already passes a `RidgePassport` through every call) can hand a
208    /// ledger to anything that wants the new uniform view.
209    ///
210    /// `RidgePolicy` is homogeneous-by-construction: every constructor sets
211    /// the three inclusion flags identically. A passport whose policy
212    /// excludes every accounting term is morally a numerical perturbation
213    /// (the ridge is there to make the solve work but the objective ignores
214    /// it); a passport whose policy includes every accounting term is an
215    /// explicit prior. Heterogeneous flag combinations cannot be produced
216    /// by the public `RidgePolicy` API and have no inhabitants downstream.
217    pub const fn from_passport(passport: RidgePassport) -> Self {
218        let any_included = passport.policy.include_quadratic_penalty
219            || passport.policy.include_laplacehessian
220            || passport.policy.include_penalty_logdet;
221        let kind = if any_included {
222            StabilizationKind::ExplicitPrior
223        } else {
224            StabilizationKind::NumericalPerturbation {
225                backward_error_bound: None,
226            }
227        };
228        Self {
229            kind,
230            delta: passport.delta,
231            matrix_form: passport.matrix_form,
232            chosen_by: StabilizationRule::FixedConstant,
233            inertia_before: None,
234            inertia_after: None,
235        }
236    }
237
238    /// δ value to fold into the quadratic penalty term, or 0.0 if this
239    /// ledger entry is not part of the model. Derived from `kind`: only
240    /// [`StabilizationKind::ExplicitPrior`] contributes.
241    #[inline]
242    pub const fn quadratic_delta(&self) -> f64 {
243        match self.kind {
244            StabilizationKind::ExplicitPrior => self.delta,
245            StabilizationKind::None
246            | StabilizationKind::SolverDampingOnly
247            | StabilizationKind::NumericalPerturbation { .. } => 0.0,
248        }
249    }
250
251    /// δ value to add to the Laplace Hessian, or 0.0 if not included.
252    /// Derived from `kind`: only [`StabilizationKind::ExplicitPrior`]
253    /// contributes.
254    #[inline]
255    pub const fn laplace_hessian_delta(&self) -> f64 {
256        match self.kind {
257            StabilizationKind::ExplicitPrior => self.delta,
258            StabilizationKind::None
259            | StabilizationKind::SolverDampingOnly
260            | StabilizationKind::NumericalPerturbation { .. } => 0.0,
261        }
262    }
263
264    /// δ value to add inside log|S + δ I|, or 0.0 if not included.
265    /// Derived from `kind`: only [`StabilizationKind::ExplicitPrior`]
266    /// contributes.
267    #[inline]
268    pub const fn penalty_logdet_delta(&self) -> f64 {
269        match self.kind {
270            StabilizationKind::ExplicitPrior => self.delta,
271            StabilizationKind::None
272            | StabilizationKind::SolverDampingOnly
273            | StabilizationKind::NumericalPerturbation { .. } => 0.0,
274        }
275    }
276}
277/// Generate a `#[repr(transparent)]` `Array1<f64>` newtype with the
278/// `new`/`Deref`/`DerefMut`/`AsRef`/`From` boilerplate every wrapper in this
279/// module needs. Keeping the three semantic types behind one macro both
280/// removes ~100 lines of duplication and guarantees they cannot drift apart.
281macro_rules! array1_f64_newtype {
282    ($name:ident $(, $extra:ident)*) => {
283        #[repr(transparent)]
284        #[derive(Clone, Debug, PartialEq)]
285        pub struct $name(pub Array1<f64>);
286
287        impl $name {
288            #[inline]
289            pub fn new(values: Array1<f64>) -> Self {
290                Self(values)
291            }
292
293            #[inline]
294            pub fn zeros(len: usize) -> Self {
295                Self(Array1::zeros(len))
296            }
297        }
298
299        impl Deref for $name {
300            type Target = Array1<f64>;
301            #[inline]
302            fn deref(&self) -> &Self::Target { &self.0 }
303        }
304
305        impl DerefMut for $name {
306            #[inline]
307            fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
308        }
309
310        impl AsRef<Array1<f64>> for $name {
311            #[inline]
312            fn as_ref(&self) -> &Array1<f64> { &self.0 }
313        }
314
315        impl From<Array1<f64>> for $name {
316            #[inline]
317            fn from(values: Array1<f64>) -> Self { Self(values) }
318        }
319
320        impl From<$name> for Array1<f64> {
321            #[inline]
322            fn from(values: $name) -> Self { values.0 }
323        }
324
325        $( array1_f64_newtype!(@extra $name $extra); )*
326    };
327    (@extra $name:ident exp) => {
328        impl $name {
329            #[inline]
330            pub fn exp(&self) -> Array1<f64> { self.0.mapv(f64::exp) }
331        }
332    };
333}
334
335array1_f64_newtype!(Coefficients);
336array1_f64_newtype!(LinearPredictor);
337array1_f64_newtype!(LogSmoothingParams, exp);
338
339/// Index into `TermCollectionSpec::smooth_terms` (and the parallel
340/// `TermCollectionDesign::smooth.terms` slice produced from it).
341///
342/// This is **not** a penalty/ρ index, **not** a column index, and **not** a
343/// coefficient-offset index. Keeping it behind a `#[repr(transparent)]`
344/// newtype makes those confusables a compile error: a `SmoothTermIdx` cannot
345/// be silently used to index `rho`, `beta`, or a design column.
346#[repr(transparent)]
347#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
348pub struct SmoothTermIdx(usize);
349
350impl SmoothTermIdx {
351    #[inline]
352    pub const fn new(idx: usize) -> Self {
353        Self(idx)
354    }
355
356    /// Sentinel used by transient builders that must allocate a coord config
357    /// before the smooth term it references has been positioned in the spec.
358    /// Every code path that constructs a sentinel must overwrite it before
359    /// the value escapes the builder.
360    #[inline]
361    pub const fn placeholder() -> Self {
362        Self(usize::MAX)
363    }
364
365    #[inline]
366    pub const fn get(self) -> usize {
367        self.0
368    }
369
370    #[inline]
371    pub const fn is_placeholder(self) -> bool {
372        self.0 == usize::MAX
373    }
374}
375
376impl std::fmt::Display for SmoothTermIdx {
377    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378        write!(f, "{}", self.0)
379    }
380}
381
382/// Index into the canonical penalty list `&[CanonicalPenalty]` — equivalently,
383/// the position of a smoothing parameter in the ρ / λ vector.
384///
385/// Penalty/ρ indices are not interchangeable with `SmoothTermIdx` (a smooth
386/// term can carry multiple canonical penalties — e.g. tensor-product double
387/// penalties — and structural penalties don't correspond to any smooth term).
388/// Keeping them as separate newtypes makes the historical bug pattern
389/// "indexed `rho` with a smooth-term ordinal" impossible to express.
390#[repr(transparent)]
391#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
392pub struct PenaltyIdx(usize);
393
394impl PenaltyIdx {
395    #[inline]
396    pub const fn new(idx: usize) -> Self {
397        Self(idx)
398    }
399
400    #[inline]
401    pub const fn get(self) -> usize {
402        self.0
403    }
404}
405
406impl std::fmt::Display for PenaltyIdx {
407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        write!(f, "{}", self.0)
409    }
410}
411
412/// Index into a single smooth term's set of basis functions — i.e. the `k`
413/// in "the k-th basis function `B_k(x)` of this term".
414///
415/// Distinct from:
416///   * [`SmoothTermIdx`] — selects *which* smooth term in the spec.
417///   * [`PenaltyIdx`]    — selects *which* ρ/λ entry / canonical penalty.
418///   * A design-matrix column index — which lives in the *combined* layout
419///     after intercept/parametric blocks and per-term offsets are applied;
420///     a `BasisIdx` is term-local, a column index is model-global.
421///
422/// Keeping this as its own `#[repr(transparent)]` newtype makes the
423/// historically-easy confusion "indexed a global column slice with a
424/// term-local basis ordinal" (or vice versa) a compile error.
425#[repr(transparent)]
426#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
427pub struct BasisIdx(usize);
428
429impl BasisIdx {
430    #[inline]
431    pub const fn new(idx: usize) -> Self {
432        Self(idx)
433    }
434
435    #[inline]
436    pub const fn get(self) -> usize {
437        self.0
438    }
439}
440
441impl std::fmt::Display for BasisIdx {
442    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443        write!(f, "{}", self.0)
444    }
445}
446
447/// Index into the user-facing design matrix `data: Array2<f64>` — i.e. the
448/// position of a covariate column in the raw input frame, *before* any
449/// per-family basis expansion or intercept/parametric layout is applied.
450///
451/// Distinct from:
452///   * [`BasisIdx`] — term-local basis-function ordinal `k` of `B_k(x)`.
453///   * [`SmoothTermIdx`] — position in `TermCollectionSpec::smooth_terms`.
454///   * A coefficient-vector offset `β[i]` — spans the combined design after
455///     expansion, which is much wider than the user-facing data matrix.
456///
457/// Keeping this as its own `#[repr(transparent)]` newtype rules out the easy
458/// confusion of indexing the raw data frame with an expanded-column offset.
459#[repr(transparent)]
460#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
461pub struct ColIdx(usize);
462
463impl ColIdx {
464    #[inline]
465    pub const fn new(idx: usize) -> Self {
466        Self(idx)
467    }
468
469    #[inline]
470    pub const fn get(self) -> usize {
471        self.0
472    }
473}
474
475impl std::fmt::Display for ColIdx {
476    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477        write!(f, "{}", self.0)
478    }
479}
480
481/// Index of an observation (row) in the user-facing data frame / design
482/// matrix — i.e. the `i` in "the i-th observation".
483///
484/// Distinct from every column-type index in this module ([`ColIdx`],
485/// [`BasisIdx`], [`SmoothTermIdx`], [`PenaltyIdx`]) and from coefficient
486/// offsets. Keeping rows behind their own `#[repr(transparent)]` newtype
487/// makes the classic `data[[col, row]]` transposition a compile error.
488#[repr(transparent)]
489#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
490pub struct RowIdx(usize);
491
492impl RowIdx {
493    #[inline]
494    pub const fn new(idx: usize) -> Self {
495        Self(idx)
496    }
497
498    #[inline]
499    pub const fn get(self) -> usize {
500        self.0
501    }
502}
503
504impl std::fmt::Display for RowIdx {
505    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
506        write!(f, "{}", self.0)
507    }
508}
509
510#[repr(transparent)]
511#[derive(Clone, Copy, Debug)]
512pub struct LogSmoothingParamsView<'a>(pub ArrayView1<'a, f64>);
513
514impl<'a> LogSmoothingParamsView<'a> {
515    pub fn new(values: ArrayView1<'a, f64>) -> Self {
516        Self(values)
517    }
518
519    pub fn exp(&self) -> Array1<f64> {
520        self.0.mapv(f64::exp)
521    }
522}
523
524impl<'a> Deref for LogSmoothingParamsView<'a> {
525    type Target = ArrayView1<'a, f64>;
526
527    fn deref(&self) -> &Self::Target {
528        &self.0
529    }
530}
531
#[cfg(test)]
mod ridge_policy_tests {
    use super::{RidgePassport, RidgePolicy, StabilizationKind, StabilizationLedger};

    /// A solver-only ridge must stay out of every objective-accounting term,
    /// both when read through the passport and after bridging to a ledger.
    #[test]
    fn solver_only_ridge_policy_stays_off_objective_accounting() {
        const DELTA: f64 = 1.0e-4;
        let passport = RidgePassport::scaled_identity(DELTA, RidgePolicy::solver_only());

        // Passport-level view: no inclusion flag lets δ through.
        assert!(
            !passport.policy.include_quadratic_penalty,
            "solver-only ridge must not add a quadratic prior"
        );
        assert_eq!(
            passport.penalty_logdet_ridge(),
            0.0,
            "solver-only ridge must not shift the penalty logdet"
        );
        assert_eq!(
            passport.laplacehessianridge(),
            0.0,
            "solver-only ridge must not shift the Laplace Hessian"
        );

        // Ledger-level view: classified as a perturbation, δ retained as
        // metadata but reported as zero by every accounting accessor.
        let ledger = StabilizationLedger::from_passport(passport);
        assert!(
            matches!(
                ledger.kind,
                StabilizationKind::NumericalPerturbation {
                    backward_error_bound: None
                }
            ),
            "solver-only ridge is a numerical perturbation, not an explicit prior"
        );
        assert_eq!(
            ledger.delta, DELTA,
            "bridging must preserve the recorded ridge magnitude"
        );
        assert_eq!(
            ledger.quadratic_delta(),
            0.0,
            "solver-only ridge must not contribute to the optimized objective"
        );
        assert_eq!(
            ledger.laplace_hessian_delta(),
            0.0,
            "solver-only ridge must not contribute to REML curvature accounting"
        );
        assert_eq!(
            ledger.penalty_logdet_delta(),
            0.0,
            "solver-only ridge must not contribute to determinant accounting"
        );
    }
}