Skip to main content

gam_terms/analytic_penalties/
op.rs

//! Operator-form penalty interface.
//!
//! Defines the `PenaltyOp` trait, which abstracts a square, symmetric,
//! positive-semidefinite (PSD) penalty operator. Implementations in this module:
//!   * `Array2<f64>` (via `impl PenaltyOp for Array2<f64>`) — adapts a
//!     materialized dense penalty to the operator interface.
//!   * `ClosedFormPenaltyOperator` — implements the trait with analytic,
//!     streaming pair-kernel matvecs and only materializes when `as_dense()` is
//!     explicitly requested.
//!   * `ScaledPenaltyOp` — wraps any `PenaltyOp` and multiplies it by a scalar.
//!
//! See `matrix_free_penalty_integration_assessment.md` for why the operator
//! has no "true matrix-free" backing implementation in our K range, and why
//! this trait is still worth threading through PIRLS/REML: the wall-clock win
//! lives at the *Hessian* level (PCG against an implicit H). The closed-form
//! Duchon operator is also matrix-free, so large-K paths avoid accidental
//! dense Gram construction in matvec/log-det probes.
17
18use std::sync::Arc;
19
20use faer::Side;
21use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
22
23use crate::basis::closed_form_operator::ClosedFormPenaltyOperator;
24use gam_linalg::faer_ndarray::{FaerEigh, fast_av_view_into};
25
26/// Square symmetric PSD penalty operator.
27///
28/// Implementations may be backed by a materialized `Array2<f64>` or by a
29/// closed-form operator that builds (and caches) its dense form lazily. All
30/// methods must be consistent with the same underlying matrix `S`:
31/// `matvec(w) = S w`, `diag()[i] = S[i,i]`, etc.
32pub trait PenaltyOp: Send + Sync {
33    /// Side length of the (square) operator.
34    fn dim(&self) -> usize;
35
36    /// Apply the operator: `out = S w`.
37    fn matvec(&self, w: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>);
38
39    /// Diagonal entries `S[i,i]`.
40    fn diag(&self) -> Array1<f64>;
41
42    /// Trace `tr(S) = Σ_i S[i,i]`.
43    fn trace(&self) -> f64 {
44        self.diag().sum()
45    }
46
47    /// Exact `log det(S + λI)` for `λ > 0`.
48    /// `S` is allowed to be rank-deficient; the regularization makes the
49    /// regularized operator strictly positive definite.
50    fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String>;
51
52    /// Symmetric eigendecomposition `S = V diag(λ) V^T`. The default
53    /// implementation materializes via `as_dense` and runs the same
54    /// `FaerEigh` path the existing dense pipeline uses, which preserves
55    /// numerical agreement with `analyze_penalty_block`. Implementations
56    /// that have a faster path (Lanczos top-k for very large K) may
57    /// override.
58    fn eigendecompose(&self) -> Result<(Array1<f64>, Array2<f64>), String> {
59        let dense = self.as_dense();
60        dense
61            .eigh(Side::Lower)
62            .map_err(|e| format!("PenaltyOp::eigendecompose: {e}"))
63    }
64
65    /// Materialize the operator as a dense matrix. Required for the
66    /// existing `analyze_penalty_block` path and for callers that need a
67    /// `&Array2` view (Cholesky factorization, etc.). Implementations that
68    /// already hold a dense form should return it cheaply.
69    fn as_dense(&self) -> Array2<f64>;
70}
71
72impl PenaltyOp for Array2<f64> {
73    fn dim(&self) -> usize {
74        assert_eq!(
75            self.nrows(),
76            self.ncols(),
77            "PenaltyOp matrix must be square"
78        );
79        self.nrows()
80    }
81
82    fn matvec(&self, w: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
83        fast_av_view_into(self, &w, out);
84    }
85
86    fn diag(&self) -> Array1<f64> {
87        let n = self.nrows();
88        let mut d = Array1::<f64>::zeros(n);
89        for i in 0..n {
90            d[i] = self[[i, i]];
91        }
92        d
93    }
94
95    fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
96        assert!(lambda > 0.0, "log_det_plus_lambda_i requires λ > 0");
97        let n = <Self as PenaltyOp>::dim(self);
98        let mut regularized = self.clone();
99        for i in 0..n {
100            regularized[[i, i]] += lambda;
101        }
102        let (evals, _) = regularized.eigh(Side::Lower).map_err(|e| {
103            format!("PenaltyOp::log_det_plus_lambda_i eigendecomposition failed: {e}")
104        })?;
105        let mut logdet = 0.0;
106        for (idx, &ev) in evals.iter().enumerate() {
107            if !ev.is_finite() || ev <= 0.0 {
108                return Err(format!(
109                    "PenaltyOp::log_det_plus_lambda_i expected SPD S+λI, \
110                     eigenvalue {idx} is {ev:.3e}"
111                ));
112            }
113            logdet += ev.ln();
114        }
115        Ok(logdet)
116    }
117
118    fn as_dense(&self) -> Array2<f64> {
119        self.clone()
120    }
121}
122
123impl PenaltyOp for ClosedFormPenaltyOperator {
124    fn dim(&self) -> usize {
125        self.dim()
126    }
127
128    fn matvec(&self, w: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
129        self.matvec(w, out)
130    }
131
132    fn diag(&self) -> Array1<f64> {
133        self.diag()
134    }
135
136    fn trace(&self) -> f64 {
137        self.trace()
138    }
139
140    fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
141        self.log_det_plus_lambda_i(lambda)
142    }
143
144    fn as_dense(&self) -> Array2<f64> {
145        self.dense_form()
146    }
147}
148
149/// Wrap any `PenaltyOp` with a scalar multiplier. Useful when the dense
150/// `PenaltyCandidate.matrix` has been normalized by a Frobenius factor `norm`
151/// and we need an operator whose `as_dense()` matches it bit-for-bit. The
152/// adapter divides every matvec / diag / trace result by `norm` (equivalently:
153/// scales by `1/norm`).
154pub struct ScaledPenaltyOp {
155    inner: Arc<dyn PenaltyOp>,
156    scale: f64,
157}
158
159impl ScaledPenaltyOp {
160    pub fn new(inner: Arc<dyn PenaltyOp>, scale: f64) -> Self {
161        Self { inner, scale }
162    }
163}
164
165impl PenaltyOp for ScaledPenaltyOp {
166    fn dim(&self) -> usize {
167        self.inner.dim()
168    }
169
170    fn matvec(&self, w: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
171        self.inner.matvec(w, out.view_mut());
172        out.mapv_inplace(|v| v * self.scale);
173    }
174
175    fn diag(&self) -> Array1<f64> {
176        let mut d = self.inner.diag();
177        d.mapv_inplace(|v| v * self.scale);
178        d
179    }
180
181    fn trace(&self) -> f64 {
182        self.inner.trace() * self.scale
183    }
184
185    fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
186        // log det(scale * S + λ I) cannot be derived from log det(S + (λ/scale) I)
187        // by a uniform shift unless we materialize. Materialize via as_dense and
188        // call the exact Array2 implementation on the scaled dense matrix.
189        let dense = self.as_dense();
190        <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&dense, lambda)
191    }
192
193    fn as_dense(&self) -> Array2<f64> {
194        let mut m = self.inner.as_dense();
195        m.mapv_inplace(|v| v * self.scale);
196        m
197    }
198}
199
#[cfg(test)]
mod tests {
    //! Consistency checks for the dense adapter and the scaling wrapper.
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array;

    /// Symmetric PSD 4×4 fixture `A = Bᵀ B` with `B` of shape 3×4, so `A` is
    /// rank-deficient (rank ≤ 3). This exercises the `+ λI` regularization.
    fn psd_fixture() -> Array2<f64> {
        let b = Array::from_shape_vec(
            (3, 4),
            vec![
                1.0, -0.3, 0.7, 0.1, 0.2, 1.1, -0.5, 0.4, -0.1, 0.6, 0.9, -0.2,
            ],
        )
        .unwrap();
        b.t().dot(&b)
    }

    #[test]
    fn array2_impl_matvec_matches_dot() {
        let s = psd_fixture();
        let v = Array1::from_vec(vec![0.7, -0.4, 0.2, 1.1]);
        let mut out = Array1::<f64>::zeros(4);
        s.matvec(v.view(), out.view_mut());
        let want = s.dot(&v);
        for i in 0..4 {
            assert_abs_diff_eq!(out[i], want[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn array2_impl_diag_and_trace() {
        let s = psd_fixture();
        let d = <Array2<f64> as PenaltyOp>::diag(&s);
        for i in 0..4 {
            assert_abs_diff_eq!(d[i], s[[i, i]], epsilon = 0.0);
        }
        let tr = <Array2<f64> as PenaltyOp>::trace(&s);
        assert_abs_diff_eq!(tr, s.diag().sum(), epsilon = 0.0);
    }

    #[test]
    fn array2_impl_eigendecompose_matches_eigh() {
        let s = psd_fixture();
        let (evals_op, evecs_op) = <Array2<f64> as PenaltyOp>::eigendecompose(&s).expect("eigh");
        let (evals_ref, evecs_ref) = s.eigh(Side::Lower).expect("eigh ref");
        for i in 0..evals_op.len() {
            assert_abs_diff_eq!(evals_op[i], evals_ref[i], epsilon = 1e-12);
        }
        // Sign of eigenvectors is gauge-free; compare V V^T for a stable check.
        let p_op = evecs_op.dot(&evecs_op.t());
        let p_ref = evecs_ref.dot(&evecs_ref.t());
        for i in 0..p_op.nrows() {
            for j in 0..p_op.ncols() {
                assert_abs_diff_eq!(p_op[[i, j]], p_ref[[i, j]], epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn array2_impl_log_det_matches_shifted_eigenvalues() {
        // log det(S + λI) = Σ ln(λ_i(S) + λ).
        let s = psd_fixture();
        let lambda = 0.3_f64;
        let (evals, _) = s.eigh(Side::Lower).expect("eigh ref");
        let want: f64 = evals.iter().map(|&ev| (ev + lambda).ln()).sum();
        let got = <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&s, lambda)
            .expect("log det");
        assert_abs_diff_eq!(got, want, epsilon = 1e-10);
    }

    #[test]
    fn scaled_op_matches_scaled_dense() {
        let s = psd_fixture();
        // A power of two keeps the scaled dense matrix exactly representable.
        let scale = 0.25_f64;
        let inner: Arc<dyn PenaltyOp> = Arc::new(s.clone());
        let op = ScaledPenaltyOp::new(inner, scale);
        let want_dense = s.mapv(|v| v * scale);

        assert_eq!(op.dim(), 4);

        let dense = op.as_dense();
        for i in 0..4 {
            for j in 0..4 {
                assert_abs_diff_eq!(dense[[i, j]], want_dense[[i, j]], epsilon = 0.0);
            }
        }

        let v = Array1::from_vec(vec![0.7, -0.4, 0.2, 1.1]);
        let mut out = Array1::<f64>::zeros(4);
        op.matvec(v.view(), out.view_mut());
        let want_mv = want_dense.dot(&v);
        for i in 0..4 {
            assert_abs_diff_eq!(out[i], want_mv[i], epsilon = 1e-12);
        }

        let d = op.diag();
        for i in 0..4 {
            assert_abs_diff_eq!(d[i], want_dense[[i, i]], epsilon = 1e-15);
        }
        assert_abs_diff_eq!(op.trace(), want_dense.diag().sum(), epsilon = 1e-12);

        for &lambda in &[1e-3, 0.5, 4.0] {
            let got = op.log_det_plus_lambda_i(lambda).expect("scaled log det");
            let want = <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&want_dense, lambda)
                .expect("dense log det");
            assert_abs_diff_eq!(got, want, epsilon = 1e-9);
        }
    }

    #[test]
    fn scaled_op_with_zero_scale_is_lambda_identity() {
        // 0 · S + λI = λI, so log det = n · ln λ.
        let inner: Arc<dyn PenaltyOp> = Arc::new(psd_fixture());
        let op = ScaledPenaltyOp::new(inner, 0.0);
        let lambda = 0.7_f64;
        let got = op.log_det_plus_lambda_i(lambda).expect("log det");
        assert_abs_diff_eq!(got, 4.0 * lambda.ln(), epsilon = 1e-12);
    }
}