Skip to main content

gam_problem/
dispersion_cov.rs

1//! Newtype wrappers that disambiguate the two coefficient-space second-order
2//! quantities used throughout inference.
3//!
4//! Background ("dispersion ownership"):
5//!
6//! The fitter stores two related matrices for a fitted model.
7//!
8//! * `FitInference::beta_covariance` is the posterior coefficient covariance
9//!   `Vb = phi * H^{-1}`, with `H = X' W_H X + S(lambda)` and `phi` the
10//!   dispersion parameter. This matrix is *already* multiplied by `phi`
11//!   (see `solver/estimate.rs`'s `scaled_covariance` call).
12//! * `FitInference::penalized_hessian` is the raw penalised Hessian `H`,
13//!   with NO dispersion scaling.
14//!
15//! Several downstream consumers (HMC whitening, Laplace sampling, smooth
16//! tests, etc.) have to know which of these representations they hold so
17//! they apply `phi` exactly once. Passing both as bare `Array2<f64>` makes
18//! that easy to get wrong: the same matrix shape can mean either thing,
19//! and the compiler will not catch a missing — or duplicated —
20//! `phi` factor.
21//!
22//! The lightweight newtypes below give us a way to label the convention at
23//! API boundaries without changing the storage type of the existing
24//! `FitInference` fields. Storage stays `Array2<f64>` to avoid cascading
25//! changes into modules outside the dispersion-ownership refactor's scope
26//! (pirls, families, GPU paths, main, etc.); callers that want to be
27//! explicit can wrap with `PhiScaledCovariance::wrap` /
28//! `UnscaledPrecision::wrap` at the boundary.
29//!
//! `Dispersion` lives in `gam-problem` as the neutral scale contract and is
//! re-exported from this module. The local `DispersionExt` trait adds the
//! `inv_phi()` / `sqrt_phi()` helpers alongside the existing
//! `Dispersion::phi()`, giving terse call-sites for sampling code.
33
34use ndarray::{Array1, Array2};
35use serde::{Deserialize, Serialize};
36use std::ops::{Deref, DerefMut};
37
38pub use crate::Dispersion;
39
40/// Compute standard errors from a covariance matrix (sqrt of diagonal).
41pub fn se_from_covariance(cov: &Array2<f64>) -> Array1<f64> {
42    Array1::from_iter(cov.diag().iter().map(|&v| v.max(0.0).sqrt()))
43}
44
45/// Posterior coefficient covariance `Vb = phi * H^{-1}` — the matrix users
46/// see as `Cov(beta_hat)`. This newtype documents that `phi` has already
47/// been multiplied in.
48///
49/// `#[serde(transparent)]` keeps the on-disk wire format identical to the
50/// pre-newtype `Array2<f64>` storage so saved models round-trip cleanly.
51/// `Deref<Target = Array2<f64>>` lets out-of-scope read sites continue
52/// calling `Array2` methods (`.iter()`, `.nrows()`, `.dim()`, …) on the
53/// wrapper without modification.
54#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)]
55#[serde(transparent)]
56pub struct PhiScaledCovariance(pub Array2<f64>);
57
58impl PhiScaledCovariance {
59    /// Wrap an array that is known to already be on the `phi * H^{-1}`
60    /// scale.
61    #[inline]
62    pub fn wrap(cov: Array2<f64>) -> Self {
63        Self(cov)
64    }
65
66    /// Borrow the underlying `φ · H⁻¹` matrix without taking ownership.
67    #[inline]
68    pub fn as_array(&self) -> &Array2<f64> {
69        &self.0
70    }
71
72    /// Consume the wrapper and return the raw `φ · H⁻¹` matrix.
73    #[inline]
74    pub fn into_array(self) -> Array2<f64> {
75        self.0
76    }
77}
78
79impl From<Array2<f64>> for PhiScaledCovariance {
80    #[inline]
81    fn from(cov: Array2<f64>) -> Self {
82        Self(cov)
83    }
84}
85
86impl From<PhiScaledCovariance> for Array2<f64> {
87    #[inline]
88    fn from(cov: PhiScaledCovariance) -> Self {
89        cov.0
90    }
91}
92
93impl Deref for PhiScaledCovariance {
94    type Target = Array2<f64>;
95    #[inline]
96    fn deref(&self) -> &Array2<f64> {
97        &self.0
98    }
99}
100
101impl DerefMut for PhiScaledCovariance {
102    #[inline]
103    fn deref_mut(&mut self) -> &mut Array2<f64> {
104        &mut self.0
105    }
106}
107
/// Raw penalised Hessian `H = X' W_H X + S(lambda)` with NO dispersion
/// scaling. With `Vb = phi * H^{-1}`, the identity `H = phi * Vb^{-1}` holds
/// for every `phi`; `H` coincides with the posterior precision `Vb^{-1}`
/// only when `phi == 1`. Use this for whitening / precision-matrix paths,
/// and pair it with a [`Dispersion`] at the boundary if the consumer cares
/// about `phi`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct UnscaledPrecision(pub Array2<f64>);
115
116impl UnscaledPrecision {
117    /// Wrap an `Array2` that is already on the unscaled
118    /// `H = XᵀW_H X + S(λ)` scale (no `φ` factor).  Caller is responsible
119    /// for ensuring the matrix actually represents the penalised Hessian.
120    #[inline]
121    pub fn wrap(hessian: Array2<f64>) -> Self {
122        Self(hessian)
123    }
124
125    /// Borrow the underlying penalised Hessian `H` without taking ownership.
126    #[inline]
127    pub fn as_array(&self) -> &Array2<f64> {
128        &self.0
129    }
130
131    /// Consume the wrapper and return the raw `H` matrix.
132    #[inline]
133    pub fn into_array(self) -> Array2<f64> {
134        self.0
135    }
136}
137
138impl From<Array2<f64>> for UnscaledPrecision {
139    #[inline]
140    fn from(h: Array2<f64>) -> Self {
141        Self(h)
142    }
143}
144
145impl From<UnscaledPrecision> for Array2<f64> {
146    #[inline]
147    fn from(h: UnscaledPrecision) -> Self {
148        h.0
149    }
150}
151
152impl Deref for UnscaledPrecision {
153    type Target = Array2<f64>;
154    #[inline]
155    fn deref(&self) -> &Array2<f64> {
156        &self.0
157    }
158}
159
160impl DerefMut for UnscaledPrecision {
161    #[inline]
162    fn deref_mut(&mut self) -> &mut Array2<f64> {
163        &mut self.0
164    }
165}
166
/// Extension methods on [`Dispersion`] used by the sampling code, kept in
/// this module so the canonical `Dispersion` definition (imported in this
/// file as `crate::Dispersion`) does not need to change.
///
/// The `Dispersion` implementation in this file clamps `phi` before use, but
/// with different floors: `inv_phi()` floors `phi` at `1e-300`, so the
/// reciprocal stays finite on a pathological zero dispersion, whereas
/// `sqrt_phi()` only clamps negative values to `0.0` and can therefore
/// return exactly zero. Callers that divide by `sqrt_phi()` must handle
/// that case themselves.
pub trait DispersionExt {
    /// Reciprocal of the dispersion, `1 / phi`.
    fn inv_phi(self) -> f64;
    /// Square root of the dispersion, `sqrt(phi)`.
    fn sqrt_phi(self) -> f64;
}
177
178impl DispersionExt for Dispersion {
179    #[inline]
180    fn inv_phi(self) -> f64 {
181        1.0 / self.phi().max(1e-300)
182    }
183
184    #[inline]
185    fn sqrt_phi(self) -> f64 {
186        self.phi().max(0.0).sqrt()
187    }
188}