gam_problem/dispersion_cov.rs
1//! Newtype wrappers that disambiguate the two coefficient-space second-order
2//! quantities used throughout inference.
3//!
4//! Background ("dispersion ownership"):
5//!
6//! The fitter stores two related matrices for a fitted model.
7//!
8//! * `FitInference::beta_covariance` is the posterior coefficient covariance
9//! `Vb = phi * H^{-1}`, with `H = X' W_H X + S(lambda)` and `phi` the
10//! dispersion parameter. This matrix is *already* multiplied by `phi`
11//! (see `solver/estimate.rs`'s `scaled_covariance` call).
12//! * `FitInference::penalized_hessian` is the raw penalised Hessian `H`,
13//! with NO dispersion scaling.
14//!
15//! Several downstream consumers (HMC whitening, Laplace sampling, smooth
16//! tests, etc.) have to know which of these representations they hold so
17//! they apply `phi` exactly once. Passing both as bare `Array2<f64>` makes
18//! that easy to get wrong: the same matrix shape can mean either thing,
19//! and the compiler will not catch a missing — or duplicated —
20//! `phi` factor.
21//!
22//! The lightweight newtypes below give us a way to label the convention at
23//! API boundaries without changing the storage type of the existing
24//! `FitInference` fields. Storage stays `Array2<f64>` to avoid cascading
25//! changes into modules outside the dispersion-ownership refactor's scope
26//! (pirls, families, GPU paths, main, etc.); callers that want to be
27//! explicit can wrap with `PhiScaledCovariance::wrap` /
28//! `UnscaledPrecision::wrap` at the boundary.
29//!
//! `Dispersion` lives in `gam-problem` as the neutral scale contract. The
//! local `DispersionExt` trait adds `inv_phi()` / `sqrt_phi()` next to
//! `Dispersion`'s own `phi()`, giving terse call-sites for sampling code.
33
34use ndarray::{Array1, Array2};
35use serde::{Deserialize, Serialize};
36use std::ops::{Deref, DerefMut};
37
38pub use crate::Dispersion;
39
40/// Compute standard errors from a covariance matrix (sqrt of diagonal).
41pub fn se_from_covariance(cov: &Array2<f64>) -> Array1<f64> {
42 Array1::from_iter(cov.diag().iter().map(|&v| v.max(0.0).sqrt()))
43}
44
45/// Posterior coefficient covariance `Vb = phi * H^{-1}` — the matrix users
46/// see as `Cov(beta_hat)`. This newtype documents that `phi` has already
47/// been multiplied in.
48///
49/// `#[serde(transparent)]` keeps the on-disk wire format identical to the
50/// pre-newtype `Array2<f64>` storage so saved models round-trip cleanly.
51/// `Deref<Target = Array2<f64>>` lets out-of-scope read sites continue
52/// calling `Array2` methods (`.iter()`, `.nrows()`, `.dim()`, …) on the
53/// wrapper without modification.
54#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)]
55#[serde(transparent)]
56pub struct PhiScaledCovariance(pub Array2<f64>);
57
58impl PhiScaledCovariance {
59 /// Wrap an array that is known to already be on the `phi * H^{-1}`
60 /// scale.
61 #[inline]
62 pub fn wrap(cov: Array2<f64>) -> Self {
63 Self(cov)
64 }
65
66 /// Borrow the underlying `φ · H⁻¹` matrix without taking ownership.
67 #[inline]
68 pub fn as_array(&self) -> &Array2<f64> {
69 &self.0
70 }
71
72 /// Consume the wrapper and return the raw `φ · H⁻¹` matrix.
73 #[inline]
74 pub fn into_array(self) -> Array2<f64> {
75 self.0
76 }
77}
78
79impl From<Array2<f64>> for PhiScaledCovariance {
80 #[inline]
81 fn from(cov: Array2<f64>) -> Self {
82 Self(cov)
83 }
84}
85
86impl From<PhiScaledCovariance> for Array2<f64> {
87 #[inline]
88 fn from(cov: PhiScaledCovariance) -> Self {
89 cov.0
90 }
91}
92
93impl Deref for PhiScaledCovariance {
94 type Target = Array2<f64>;
95 #[inline]
96 fn deref(&self) -> &Array2<f64> {
97 &self.0
98 }
99}
100
101impl DerefMut for PhiScaledCovariance {
102 #[inline]
103 fn deref_mut(&mut self) -> &mut Array2<f64> {
104 &mut self.0
105 }
106}
107
108/// Raw penalised Hessian `H = X' W_H X + S(lambda)` with NO dispersion
109/// scaling. Equivalent to `phi * Vb^{-1}` only when `phi == 1`. Use this
110/// for whitening / precision-matrix paths, and pair it with a
111/// [`Dispersion`] at the boundary if the consumer cares about `phi`.
112#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)]
113#[serde(transparent)]
114pub struct UnscaledPrecision(pub Array2<f64>);
115
116impl UnscaledPrecision {
117 /// Wrap an `Array2` that is already on the unscaled
118 /// `H = XᵀW_H X + S(λ)` scale (no `φ` factor). Caller is responsible
119 /// for ensuring the matrix actually represents the penalised Hessian.
120 #[inline]
121 pub fn wrap(hessian: Array2<f64>) -> Self {
122 Self(hessian)
123 }
124
125 /// Borrow the underlying penalised Hessian `H` without taking ownership.
126 #[inline]
127 pub fn as_array(&self) -> &Array2<f64> {
128 &self.0
129 }
130
131 /// Consume the wrapper and return the raw `H` matrix.
132 #[inline]
133 pub fn into_array(self) -> Array2<f64> {
134 self.0
135 }
136}
137
138impl From<Array2<f64>> for UnscaledPrecision {
139 #[inline]
140 fn from(h: Array2<f64>) -> Self {
141 Self(h)
142 }
143}
144
145impl From<UnscaledPrecision> for Array2<f64> {
146 #[inline]
147 fn from(h: UnscaledPrecision) -> Self {
148 h.0
149 }
150}
151
152impl Deref for UnscaledPrecision {
153 type Target = Array2<f64>;
154 #[inline]
155 fn deref(&self) -> &Array2<f64> {
156 &self.0
157 }
158}
159
160impl DerefMut for UnscaledPrecision {
161 #[inline]
162 fn deref_mut(&mut self) -> &mut Array2<f64> {
163 &mut self.0
164 }
165}
166
/// Extension methods on [`Dispersion`] used by the sampling code, kept in
/// this module so the `Dispersion` definition itself (re-exported from the
/// crate root via `pub use crate::Dispersion`) does not have to change.
///
/// Both helpers guard against pathological dispersion values, but in
/// different ways — see the individual methods. In particular only
/// `inv_phi()` is floored away from zero; `sqrt_phi()` can legitimately
/// return exactly `0.0`.
pub trait DispersionExt {
    /// Reciprocal dispersion `1 / phi`.
    ///
    /// The implementation for [`Dispersion`] floors `phi` at `1e-300`
    /// before dividing, so a zero or negative dispersion yields a very
    /// large finite value rather than `Inf` or a negative number.
    fn inv_phi(self) -> f64;

    /// Square root of the dispersion, `sqrt(phi)`.
    ///
    /// The implementation for [`Dispersion`] clamps `phi` at `0.0` first,
    /// so a negative dispersion yields `0.0` instead of `NaN`. The result
    /// is NOT bounded away from zero: a zero dispersion returns exactly
    /// `0.0`, so callers must not divide by this value without their own
    /// guard.
    fn sqrt_phi(self) -> f64;
}
177
178impl DispersionExt for Dispersion {
179 #[inline]
180 fn inv_phi(self) -> f64 {
181 1.0 / self.phi().max(1e-300)
182 }
183
184 #[inline]
185 fn sqrt_phi(self) -> f64 {
186 self.phi().max(0.0).sqrt()
187 }
188}