gam_problem/types.rs
1use ndarray::{Array1, ArrayView1};
2use serde::{Deserialize, Serialize};
3use std::ops::{Deref, DerefMut};
4
5pub use gam_linalg::{RidgeDeterminantMode, RidgePolicy};
6
/// Smallest positive working weight admitted by the fitting code.
///
/// Both the likelihood families and PIRLS row assembly share this floor on
/// positive working weights, which keeps the weighted normal equations
/// numerically well posed.
pub const MIN_WEIGHT: f64 = 1.0e-12;
10
11pub use gam_spec::*;
12
13/// Storage form of the ridge penalty matrix.
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15pub enum RidgeMatrixForm {
16 /// Ridge matrix is `delta * I`.
17 ScaledIdentity,
18}
19
20/// Concrete ridge metadata stamped into a fitted PIRLS result.
21#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
22pub struct RidgePassport {
23 /// Stabilization magnitude for matrix form `delta * I`.
24 pub delta: f64,
25 pub matrix_form: RidgeMatrixForm,
26 pub policy: RidgePolicy,
27}
28
29impl RidgePassport {
30 pub const fn scaled_identity(delta: f64, policy: RidgePolicy) -> Self {
31 Self {
32 delta,
33 matrix_form: RidgeMatrixForm::ScaledIdentity,
34 policy,
35 }
36 }
37
38 #[inline]
39 pub const fn penalty_logdet_ridge(self) -> f64 {
40 if self.policy.include_penalty_logdet {
41 self.delta
42 } else {
43 0.0
44 }
45 }
46
47 #[inline]
48 pub const fn laplacehessianridge(self) -> f64 {
49 if self.policy.include_laplacehessian {
50 self.delta
51 } else {
52 0.0
53 }
54 }
55}
56
57// ============================================================================
58// StabilizationLedger: canonical accounting for every fixed/heuristic ridge
59// added anywhere in the solver, linear-algebra, or family code paths.
60//
61// Three semantically distinct ridge uses must NEVER be conflated:
62// 1. SolverDampingOnly — Levenberg/trust-region damping; never enters
63// objective, gradient, logdet, Hessian, or any
64// saved/serialized model artifact.
65// 2. NumericalPerturbation — added strictly so a linear solve is well-
66// posed (e.g. Cholesky of a near-singular
67// matrix). Carries an optional backward-error
68// bound. Does NOT change the objective.
69// 3. ExplicitPrior — model-level `delta * I` (or block-diagonal)
70// prior. Appears in quadratic, log normalizer,
71// Laplace Hessian, serialization, diagnostics.
72//
73// `RidgePassport` above already encodes the inclusion-flag matrix for the
74// PIRLS Laplace ridge specifically; this ledger is the broader sibling that
75// every other call site (RidgePlanner, matrix_inverse_with_regularization,
76// LAML rho-Hessian inversion, survival stabilization, custom-family
77// `ridge_floor`) routes through, so a downstream consumer can ask
78// `ledger.quadratic_delta()` rather than rediscovering the policy. The three
79// inclusion bits were lifted into the `StabilizationKind` discriminant so the
80// (kind, inclusion-flags) invariant is enforced statically — heterogeneous
81// combinations like "ExplicitPrior with quadratic excluded" no longer typecheck.
82// ============================================================================
83
84/// Inertia of a symmetric matrix (count of positive / zero / negative
85/// eigenvalues). Used by `bump_with_matrix` and other indefinite-aware
86/// stabilization rules to drive δ from spectral evidence rather than a
87/// condition-number heuristic.
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
89pub struct Inertia {
90 pub positive: usize,
91 pub zero: usize,
92 pub negative: usize,
93}
94
95/// Why a stabilization δ was chosen at this site.
96#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
97pub enum StabilizationRule {
98 /// δ is a hard-coded constant in the source.
99 FixedConstant,
100 /// δ chosen so the SPD floor τ is met: δ = max(0, τ - λ_min(H)).
101 InertiaTarget { spd_floor: f64 },
102 /// δ chosen via a condition-number / sqrt-ratio heuristic.
103 Heuristic,
104 /// User- or family-specified prior precision.
105 UserSpecified,
106 /// δ derived from a back-off escalation after a factorization failure.
107 BackoffEscalation { attempts: usize },
108}
109
110/// Three semantically distinct flavours a ridge δ can have.
111#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
112pub enum StabilizationKind {
113 None,
114 /// LM/TR damping. NEVER enters the objective, gradient, logdet, Hessian,
115 /// or any saved model artifact. Lives only inside the trust-region step.
116 SolverDampingOnly,
117 /// Added strictly so a linear solve succeeds. The objective/Hessian the
118 /// caller sees is unchanged; the perturbation is a property of the
119 /// solver, not the model. `backward_error_bound` is the max change to
120 /// the solution norm imputable to the perturbation, when known.
121 NumericalPerturbation {
122 backward_error_bound: Option<f64>,
123 },
124 /// Part of the model. Enters quadratic, log normalizer, Hessian,
125 /// serialization, and user-visible summaries.
126 ExplicitPrior,
127}
128
129/// Canonical record of a single stabilization δ applied at a single site.
130///
131/// Construct via the helper constructors (`solver_damping`,
132/// `numerical_perturbation`, `explicit_prior`) so the `included_in_*`
133/// invariants are guaranteed to match `kind`. Direct field construction is
134/// public for serialization round-trips only.
135#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
136pub struct StabilizationLedger {
137 pub kind: StabilizationKind,
138 pub delta: f64,
139 pub matrix_form: RidgeMatrixForm,
140 pub chosen_by: StabilizationRule,
141 pub inertia_before: Option<Inertia>,
142 pub inertia_after: Option<Inertia>,
143}
144
145impl StabilizationLedger {
146 /// "No stabilization applied at this site" sentinel.
147 pub const fn none() -> Self {
148 Self {
149 kind: StabilizationKind::None,
150 delta: 0.0,
151 matrix_form: RidgeMatrixForm::ScaledIdentity,
152 chosen_by: StabilizationRule::FixedConstant,
153 inertia_before: None,
154 inertia_after: None,
155 }
156 }
157
158 /// LM/TR damping. δ is invisible to the objective, gradient, and any
159 /// saved artifact. Asserting this invariant at every read site is the
160 /// whole reason the ledger exists.
161 pub const fn solver_damping(delta: f64, chosen_by: StabilizationRule) -> Self {
162 Self {
163 kind: StabilizationKind::SolverDampingOnly,
164 delta,
165 matrix_form: RidgeMatrixForm::ScaledIdentity,
166 chosen_by,
167 inertia_before: None,
168 inertia_after: None,
169 }
170 }
171
172 /// Solver-only perturbation that leaves the objective unchanged. The
173 /// caller may attach a backward-error bound when one is available
174 /// (e.g. from iterative refinement / Wilkinson-style analysis).
175 pub const fn numerical_perturbation(
176 delta: f64,
177 chosen_by: StabilizationRule,
178 backward_error_bound: Option<f64>,
179 ) -> Self {
180 Self {
181 kind: StabilizationKind::NumericalPerturbation {
182 backward_error_bound,
183 },
184 delta,
185 matrix_form: RidgeMatrixForm::ScaledIdentity,
186 chosen_by,
187 inertia_before: None,
188 inertia_after: None,
189 }
190 }
191
192 /// Model-level explicit prior. δ enters every accounting pass: the
193 /// quadratic penalty, the Laplace Hessian, the penalty log-determinant,
194 /// and serialization.
195 pub const fn explicit_prior(delta: f64, matrix_form: RidgeMatrixForm) -> Self {
196 Self {
197 kind: StabilizationKind::ExplicitPrior,
198 delta,
199 matrix_form,
200 chosen_by: StabilizationRule::UserSpecified,
201 inertia_before: None,
202 inertia_after: None,
203 }
204 }
205
206 /// Bridge from the existing `RidgePassport` so PIRLS-side code (which
207 /// already passes a `RidgePassport` through every call) can hand a
208 /// ledger to anything that wants the new uniform view.
209 ///
210 /// `RidgePolicy` is homogeneous-by-construction: every constructor sets
211 /// the three inclusion flags identically. A passport whose policy
212 /// excludes every accounting term is morally a numerical perturbation
213 /// (the ridge is there to make the solve work but the objective ignores
214 /// it); a passport whose policy includes every accounting term is an
215 /// explicit prior. Heterogeneous flag combinations cannot be produced
216 /// by the public `RidgePolicy` API and have no inhabitants downstream.
217 pub const fn from_passport(passport: RidgePassport) -> Self {
218 let any_included = passport.policy.include_quadratic_penalty
219 || passport.policy.include_laplacehessian
220 || passport.policy.include_penalty_logdet;
221 let kind = if any_included {
222 StabilizationKind::ExplicitPrior
223 } else {
224 StabilizationKind::NumericalPerturbation {
225 backward_error_bound: None,
226 }
227 };
228 Self {
229 kind,
230 delta: passport.delta,
231 matrix_form: passport.matrix_form,
232 chosen_by: StabilizationRule::FixedConstant,
233 inertia_before: None,
234 inertia_after: None,
235 }
236 }
237
238 /// δ value to fold into the quadratic penalty term, or 0.0 if this
239 /// ledger entry is not part of the model. Derived from `kind`: only
240 /// [`StabilizationKind::ExplicitPrior`] contributes.
241 #[inline]
242 pub const fn quadratic_delta(&self) -> f64 {
243 match self.kind {
244 StabilizationKind::ExplicitPrior => self.delta,
245 StabilizationKind::None
246 | StabilizationKind::SolverDampingOnly
247 | StabilizationKind::NumericalPerturbation { .. } => 0.0,
248 }
249 }
250
251 /// δ value to add to the Laplace Hessian, or 0.0 if not included.
252 /// Derived from `kind`: only [`StabilizationKind::ExplicitPrior`]
253 /// contributes.
254 #[inline]
255 pub const fn laplace_hessian_delta(&self) -> f64 {
256 match self.kind {
257 StabilizationKind::ExplicitPrior => self.delta,
258 StabilizationKind::None
259 | StabilizationKind::SolverDampingOnly
260 | StabilizationKind::NumericalPerturbation { .. } => 0.0,
261 }
262 }
263
264 /// δ value to add inside log|S + δ I|, or 0.0 if not included.
265 /// Derived from `kind`: only [`StabilizationKind::ExplicitPrior`]
266 /// contributes.
267 #[inline]
268 pub const fn penalty_logdet_delta(&self) -> f64 {
269 match self.kind {
270 StabilizationKind::ExplicitPrior => self.delta,
271 StabilizationKind::None
272 | StabilizationKind::SolverDampingOnly
273 | StabilizationKind::NumericalPerturbation { .. } => 0.0,
274 }
275 }
276}
277/// Generate a `#[repr(transparent)]` `Array1<f64>` newtype with the
278/// `new`/`Deref`/`DerefMut`/`AsRef`/`From` boilerplate every wrapper in this
279/// module needs. Keeping the three semantic types behind one macro both
280/// removes ~100 lines of duplication and guarantees they cannot drift apart.
281macro_rules! array1_f64_newtype {
282 ($name:ident $(, $extra:ident)*) => {
283 #[repr(transparent)]
284 #[derive(Clone, Debug, PartialEq)]
285 pub struct $name(pub Array1<f64>);
286
287 impl $name {
288 #[inline]
289 pub fn new(values: Array1<f64>) -> Self {
290 Self(values)
291 }
292
293 #[inline]
294 pub fn zeros(len: usize) -> Self {
295 Self(Array1::zeros(len))
296 }
297 }
298
299 impl Deref for $name {
300 type Target = Array1<f64>;
301 #[inline]
302 fn deref(&self) -> &Self::Target { &self.0 }
303 }
304
305 impl DerefMut for $name {
306 #[inline]
307 fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
308 }
309
310 impl AsRef<Array1<f64>> for $name {
311 #[inline]
312 fn as_ref(&self) -> &Array1<f64> { &self.0 }
313 }
314
315 impl From<Array1<f64>> for $name {
316 #[inline]
317 fn from(values: Array1<f64>) -> Self { Self(values) }
318 }
319
320 impl From<$name> for Array1<f64> {
321 #[inline]
322 fn from(values: $name) -> Self { values.0 }
323 }
324
325 $( array1_f64_newtype!(@extra $name $extra); )*
326 };
327 (@extra $name:ident exp) => {
328 impl $name {
329 #[inline]
330 pub fn exp(&self) -> Array1<f64> { self.0.mapv(f64::exp) }
331 }
332 };
333}
334
335array1_f64_newtype!(Coefficients);
336array1_f64_newtype!(LinearPredictor);
337array1_f64_newtype!(LogSmoothingParams, exp);
338
339/// Index into `TermCollectionSpec::smooth_terms` (and the parallel
340/// `TermCollectionDesign::smooth.terms` slice produced from it).
341///
342/// This is **not** a penalty/ρ index, **not** a column index, and **not** a
343/// coefficient-offset index. Keeping it behind a `#[repr(transparent)]`
344/// newtype makes those confusables a compile error: a `SmoothTermIdx` cannot
345/// be silently used to index `rho`, `beta`, or a design column.
346#[repr(transparent)]
347#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
348pub struct SmoothTermIdx(usize);
349
350impl SmoothTermIdx {
351 #[inline]
352 pub const fn new(idx: usize) -> Self {
353 Self(idx)
354 }
355
356 /// Sentinel used by transient builders that must allocate a coord config
357 /// before the smooth term it references has been positioned in the spec.
358 /// Every code path that constructs a sentinel must overwrite it before
359 /// the value escapes the builder.
360 #[inline]
361 pub const fn placeholder() -> Self {
362 Self(usize::MAX)
363 }
364
365 #[inline]
366 pub const fn get(self) -> usize {
367 self.0
368 }
369
370 #[inline]
371 pub const fn is_placeholder(self) -> bool {
372 self.0 == usize::MAX
373 }
374}
375
376impl std::fmt::Display for SmoothTermIdx {
377 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378 write!(f, "{}", self.0)
379 }
380}
381
382/// Index into the canonical penalty list `&[CanonicalPenalty]` — equivalently,
383/// the position of a smoothing parameter in the ρ / λ vector.
384///
385/// Penalty/ρ indices are not interchangeable with `SmoothTermIdx` (a smooth
386/// term can carry multiple canonical penalties — e.g. tensor-product double
387/// penalties — and structural penalties don't correspond to any smooth term).
388/// Keeping them as separate newtypes makes the historical bug pattern
389/// "indexed `rho` with a smooth-term ordinal" impossible to express.
390#[repr(transparent)]
391#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
392pub struct PenaltyIdx(usize);
393
394impl PenaltyIdx {
395 #[inline]
396 pub const fn new(idx: usize) -> Self {
397 Self(idx)
398 }
399
400 #[inline]
401 pub const fn get(self) -> usize {
402 self.0
403 }
404}
405
406impl std::fmt::Display for PenaltyIdx {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 write!(f, "{}", self.0)
409 }
410}
411
412/// Index into a single smooth term's set of basis functions — i.e. the `k`
413/// in "the k-th basis function `B_k(x)` of this term".
414///
415/// Distinct from:
416/// * [`SmoothTermIdx`] — selects *which* smooth term in the spec.
417/// * [`PenaltyIdx`] — selects *which* ρ/λ entry / canonical penalty.
418/// * A design-matrix column index — which lives in the *combined* layout
419/// after intercept/parametric blocks and per-term offsets are applied;
420/// a `BasisIdx` is term-local, a column index is model-global.
421///
422/// Keeping this as its own `#[repr(transparent)]` newtype makes the
423/// historically-easy confusion "indexed a global column slice with a
424/// term-local basis ordinal" (or vice versa) a compile error.
425#[repr(transparent)]
426#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
427pub struct BasisIdx(usize);
428
429impl BasisIdx {
430 #[inline]
431 pub const fn new(idx: usize) -> Self {
432 Self(idx)
433 }
434
435 #[inline]
436 pub const fn get(self) -> usize {
437 self.0
438 }
439}
440
441impl std::fmt::Display for BasisIdx {
442 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443 write!(f, "{}", self.0)
444 }
445}
446
447/// Index into the user-facing design matrix `data: Array2<f64>` — i.e. the
448/// position of a covariate column in the raw input frame, *before* any
449/// per-family basis expansion or intercept/parametric layout is applied.
450///
451/// Distinct from:
452/// * [`BasisIdx`] — term-local basis-function ordinal `k` of `B_k(x)`.
453/// * [`SmoothTermIdx`] — position in `TermCollectionSpec::smooth_terms`.
454/// * A coefficient-vector offset `β[i]` — spans the combined design after
455/// expansion, which is much wider than the user-facing data matrix.
456///
457/// Keeping this as its own `#[repr(transparent)]` newtype rules out the easy
458/// confusion of indexing the raw data frame with an expanded-column offset.
459#[repr(transparent)]
460#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
461pub struct ColIdx(usize);
462
463impl ColIdx {
464 #[inline]
465 pub const fn new(idx: usize) -> Self {
466 Self(idx)
467 }
468
469 #[inline]
470 pub const fn get(self) -> usize {
471 self.0
472 }
473}
474
475impl std::fmt::Display for ColIdx {
476 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477 write!(f, "{}", self.0)
478 }
479}
480
481/// Index of an observation (row) in the user-facing data frame / design
482/// matrix — i.e. the `i` in "the i-th observation".
483///
484/// Distinct from every column-type index in this module ([`ColIdx`],
485/// [`BasisIdx`], [`SmoothTermIdx`], [`PenaltyIdx`]) and from coefficient
486/// offsets. Keeping rows behind their own `#[repr(transparent)]` newtype
487/// makes the classic `data[[col, row]]` transposition a compile error.
488#[repr(transparent)]
489#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
490pub struct RowIdx(usize);
491
492impl RowIdx {
493 #[inline]
494 pub const fn new(idx: usize) -> Self {
495 Self(idx)
496 }
497
498 #[inline]
499 pub const fn get(self) -> usize {
500 self.0
501 }
502}
503
504impl std::fmt::Display for RowIdx {
505 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
506 write!(f, "{}", self.0)
507 }
508}
509
510#[repr(transparent)]
511#[derive(Clone, Copy, Debug)]
512pub struct LogSmoothingParamsView<'a>(pub ArrayView1<'a, f64>);
513
514impl<'a> LogSmoothingParamsView<'a> {
515 pub fn new(values: ArrayView1<'a, f64>) -> Self {
516 Self(values)
517 }
518
519 pub fn exp(&self) -> Array1<f64> {
520 self.0.mapv(f64::exp)
521 }
522}
523
524impl<'a> Deref for LogSmoothingParamsView<'a> {
525 type Target = ArrayView1<'a, f64>;
526
527 fn deref(&self) -> &Self::Target {
528 &self.0
529 }
530}
531
#[cfg(test)]
mod ridge_policy_tests {
    use super::{
        RidgeMatrixForm, RidgePassport, RidgePolicy, StabilizationKind, StabilizationLedger,
        StabilizationRule,
    };

    #[test]
    fn solver_only_ridge_policy_stays_off_objective_accounting() {
        let passport = RidgePassport::scaled_identity(1.0e-4, RidgePolicy::solver_only());

        assert!(
            !passport.policy.include_quadratic_penalty,
            "solver-only ridge must not add a quadratic prior"
        );
        assert_eq!(
            passport.penalty_logdet_ridge(),
            0.0,
            "solver-only ridge must not shift the penalty logdet"
        );
        assert_eq!(
            passport.laplacehessianridge(),
            0.0,
            "solver-only ridge must not shift the Laplace Hessian"
        );

        let ledger = StabilizationLedger::from_passport(passport);
        assert!(
            matches!(
                ledger.kind,
                StabilizationKind::NumericalPerturbation {
                    backward_error_bound: None
                }
            ),
            "solver-only ridge is a numerical perturbation, not an explicit prior"
        );
        assert_eq!(
            ledger.quadratic_delta(),
            0.0,
            "solver-only ridge must not contribute to the optimized objective"
        );
        assert_eq!(
            ledger.laplace_hessian_delta(),
            0.0,
            "solver-only ridge must not contribute to REML curvature accounting"
        );
        assert_eq!(
            ledger.penalty_logdet_delta(),
            0.0,
            "solver-only ridge must not contribute to determinant accounting"
        );
    }

    /// An explicit prior is part of the model, so δ must reach every
    /// accounting term and be attributed to the user.
    #[test]
    fn explicit_prior_enters_every_accounting_term() {
        let delta = 2.5e-3;
        let ledger = StabilizationLedger::explicit_prior(delta, RidgeMatrixForm::ScaledIdentity);

        assert_eq!(ledger.kind, StabilizationKind::ExplicitPrior);
        assert_eq!(ledger.chosen_by, StabilizationRule::UserSpecified);
        assert_eq!(ledger.matrix_form, RidgeMatrixForm::ScaledIdentity);
        assert_eq!(
            ledger.quadratic_delta(),
            delta,
            "explicit prior must enter the quadratic penalty"
        );
        assert_eq!(
            ledger.laplace_hessian_delta(),
            delta,
            "explicit prior must enter the Laplace Hessian"
        );
        assert_eq!(
            ledger.penalty_logdet_delta(),
            delta,
            "explicit prior must enter the penalty log-determinant"
        );
    }

    /// Every non-model kind records its δ but never exposes it to the
    /// objective, curvature, or determinant accounting.
    #[test]
    fn solver_side_entries_never_reach_model_accounting() {
        let entries = [
            StabilizationLedger::none(),
            StabilizationLedger::solver_damping(1.0, StabilizationRule::Heuristic),
            StabilizationLedger::numerical_perturbation(
                1.0e-6,
                StabilizationRule::BackoffEscalation { attempts: 3 },
                Some(1.0e-9),
            ),
        ];

        assert_eq!(entries[0].delta, 0.0, "the `none` sentinel carries no ridge");
        assert_eq!(entries[1].delta, 1.0, "damping δ is still recorded");

        for ledger in entries.iter() {
            assert_eq!(
                ledger.quadratic_delta(),
                0.0,
                "{:?} must not enter the quadratic penalty",
                ledger.kind
            );
            assert_eq!(
                ledger.laplace_hessian_delta(),
                0.0,
                "{:?} must not enter the Laplace Hessian",
                ledger.kind
            );
            assert_eq!(
                ledger.penalty_logdet_delta(),
                0.0,
                "{:?} must not enter the penalty log-determinant",
                ledger.kind
            );
        }
    }
}