Skip to main content

gam_problem/
joint_penalty.rs

1//! Joint (cross-block) penalty specifications.
2//!
3//! After the `T^T S_j T` pullback used by the V+M / SMGS-exact compile path,
4//! a single penalty `S_j` no longer has its nonzero region confined to one
5//! `ParameterBlockSpec`: the pullback by the inter-block coupling matrix `T`
6//! distributes weight across the *entire* compiled parameter vector. The
7//! existing `ParameterBlockSpec.penalties: Vec<PenaltyMatrix>` model encodes
8//! a per-block-local penalty (its dim equals the owning block's column count),
9//! so it cannot represent these full-width operators.
10//!
11//! [`JointPenaltySpec`] is the carrier for that case: one dense
12//! `total_compiled × total_compiled` matrix with its own initial smoothing
13//! parameter and structural nullspace dimension. It lives *alongside*, not
14//! *inside*, the per-block specs.
15//!
16//! ## Inner-solve integration
17//!
18//! `inner_blockwise_fit` and the joint-Newton kernels in `custom_family`
19//! consume ordinary block-local penalties as a `&[Array2<f64>]` paired with
20//! per-block `(start, end)` ranges:
21//!
22//! * `apply_joint_block_penalty_into(ranges, s_lambdas, …)` (≈ line 19960)
23//! * `joint_penalty_preconditioner_diag(…)` (≈ line 20067)
24//! * `add_joint_penalty_to_matrix(matrix, ranges, s_lambdas, …)` (≈ line 20132)
25//!
26//! A cross-block dense `S` has no single owning block range, so the solver also
27//! threads a `JointPenaltyBundle` through those helpers as a full-width path
28//! that:
29//!
30//! 1. computes `S · v` as a full `total × total` mat-vec (cf. `fast_av`),
31//! 2. accumulates `diag(S)` into the Jacobi preconditioner over the full
32//!    parameter vector, and
33//! 3. adds `λ · S` to the dense joint Hessian without slicing.
34//!
35//! The remaining construction-site work is to produce the correct
36//! `JointPenaltySpec` instances for each coupled-family compile path; once a
37//! bundle is supplied through `BlockwiseFitOptions::joint_penalties`, the inner
38//! solve consumes its objective, mat-vec, preconditioner, and dense-Hessian
39//! contributions.
40
41use ndarray::{Array2, ArrayView1};
42
43/// A penalty whose support spans the entire compiled parameter vector.
44///
45/// Unlike [`crate::families::custom_family::PenaltyMatrix`], this carries a
46/// single dense `total_compiled × total_compiled` quadratic form — the
47/// shape produced by `T^T S_j T` pullback after the V+M / SMGS-exact
48/// compile. The `nullspace_dim` is the structural dimension of `ker(S)`
49/// as reported by the construction site (rank-revealing on the *pulled-back*
50/// operator, not the pre-pullback `S_j`), so the REML pseudo-logdet can
51/// avoid numerical rank thresholds.
52#[derive(Debug, Clone)]
53pub struct JointPenaltySpec {
54    /// Optional user-visible precision label. Joint penalties that share a
55    /// label share one smoothing parameter (same convention as
56    /// [`crate::families::custom_family::PenaltyMatrix::Labeled`]).
57    pub label: Option<String>,
58    /// Dense symmetric PSD matrix of shape `(total_compiled, total_compiled)`.
59    pub matrix: Array2<f64>,
60    /// Initial value of `log λ` for this penalty.
61    pub initial_log_lambda: f64,
62    /// Structural nullspace dimension of `matrix` (i.e. `total_compiled - rank`).
63    pub nullspace_dim: usize,
64}
65
/// Reason a [`JointPenaltySpec`] failed validation.
///
/// Each variant carries the offending values so the `Display` rendering can
/// report exactly what was wrong.
#[derive(Clone, Debug, PartialEq)]
pub enum JointPenaltyError {
    /// The penalty matrix has a different number of rows and columns.
    NotSquare { nrows: usize, ncols: usize },
    /// The entry at `(row, col)` is NaN or infinite.
    NonFiniteEntry { row: usize, col: usize, value: f64 },
    /// The initial `log λ` is NaN or infinite.
    NonFiniteInitialLogLambda { value: f64 },
    /// `|S[row, col] - S[col, row]|` exceeded the symmetry tolerance.
    NotSymmetric {
        row: usize,
        col: usize,
        asymmetry: f64,
    },
    /// The declared nullspace dimension is larger than the matrix dimension.
    NullspaceTooLarge { total: usize, nullspace_dim: usize },
}

impl std::fmt::Display for JointPenaltyError {
    /// Render a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` binds plain
        // values rather than references.
        match *self {
            JointPenaltyError::NotSquare { nrows, ncols } => {
                write!(f, "joint penalty matrix is not square: {nrows}x{ncols}")
            }
            JointPenaltyError::NonFiniteEntry { row, col, value } => {
                write!(
                    f,
                    "joint penalty matrix has non-finite entry at ({row},{col}): {value}"
                )
            }
            JointPenaltyError::NonFiniteInitialLogLambda { value } => {
                write!(f, "joint penalty initial_log_lambda is non-finite: {value}")
            }
            JointPenaltyError::NotSymmetric {
                row,
                col,
                asymmetry,
            } => {
                write!(
                    f,
                    "joint penalty matrix is not symmetric at ({row},{col}): |S - Sᵀ|={asymmetry:.3e}"
                )
            }
            JointPenaltyError::NullspaceTooLarge {
                total,
                nullspace_dim,
            } => {
                write!(
                    f,
                    "joint penalty nullspace_dim={nullspace_dim} exceeds dim={total}"
                )
            }
        }
    }
}

impl std::error::Error for JointPenaltyError {}
125
126impl JointPenaltySpec {
127    /// Symmetry tolerance for [`validate`]. Cross-block pullbacks via `T`
128    /// accumulate roundoff, so an exact symmetric requirement is too tight;
129    /// this matches the floor used by the surrounding penalty code paths.
130    const SYMMETRY_TOL: f64 = 1e-10;
131
132    /// Total compiled parameter count this penalty acts on.
133    #[inline]
134    pub fn dim(&self) -> usize {
135        self.matrix.nrows()
136    }
137
138    /// Trace of the penalty matrix (`Σ_i S[i,i]`).
139    pub fn trace(&self) -> f64 {
140        self.matrix.diag().iter().copied().sum()
141    }
142
143    /// Structural pseudo-rank, derived from the declared `nullspace_dim`.
144    /// This is the rank used by the REML pseudo-logdet under the
145    /// no-numerical-thresholds policy in the surrounding code.
146    #[inline]
147    pub fn pseudo_rank(&self) -> usize {
148        self.dim().saturating_sub(self.nullspace_dim)
149    }
150
151    /// Quadratic form `βᵀ S β`. Mirrors
152    /// [`crate::families::custom_family::PenaltyMatrix::quadratic_form`] for
153    /// the full-width case.
154    pub fn quadratic_form(&self, beta: ArrayView1<'_, f64>) -> f64 {
155        assert_eq!(
156            beta.len(),
157            self.dim(),
158            "joint penalty quadratic form: beta length {} != dim {}",
159            beta.len(),
160            self.dim()
161        );
162        beta.dot(&self.matrix.dot(&beta))
163    }
164
165    /// Validate shape, finiteness, symmetry, and nullspace bookkeeping.
166    pub fn validate(&self) -> Result<(), JointPenaltyError> {
167        let (nrows, ncols) = self.matrix.dim();
168        if nrows != ncols {
169            return Err(JointPenaltyError::NotSquare { nrows, ncols });
170        }
171        if !self.initial_log_lambda.is_finite() {
172            return Err(JointPenaltyError::NonFiniteInitialLogLambda {
173                value: self.initial_log_lambda,
174            });
175        }
176        if self.nullspace_dim > nrows {
177            return Err(JointPenaltyError::NullspaceTooLarge {
178                total: nrows,
179                nullspace_dim: self.nullspace_dim,
180            });
181        }
182        for ((row, col), &value) in self.matrix.indexed_iter() {
183            if !value.is_finite() {
184                return Err(JointPenaltyError::NonFiniteEntry { row, col, value });
185            }
186        }
187        for row in 0..nrows {
188            for col in (row + 1)..ncols {
189                let asymmetry = (self.matrix[[row, col]] - self.matrix[[col, row]]).abs();
190                if asymmetry > Self::SYMMETRY_TOL {
191                    return Err(JointPenaltyError::NotSymmetric {
192                        row,
193                        col,
194                        asymmetry,
195                    });
196                }
197            }
198        }
199        Ok(())
200    }
201}
202
203/// Per-evaluation bundle of cross-block penalties paired with their current
204/// log-smoothing parameters.
205///
206/// The outer optimizer concatenates joint penalty `log λ` values onto the
207/// per-block ρ vector; the inner solver receives this bundle via
208/// [`crate::families::custom_family::BlockwiseFitOptions::joint_penalties`]
209/// and adds the full-width quadratic / matvec / preconditioner / Hessian
210/// contributions to the joint-Newton primitives.
211#[derive(Clone, Debug)]
212pub struct JointPenaltyBundle {
213    pub specs: std::sync::Arc<Vec<JointPenaltySpec>>,
214    pub log_lambdas: Vec<f64>,
215}
216
217impl JointPenaltyBundle {
218    /// Build a bundle, validating the per-penalty `log λ` count and dimension
219    /// agreement against `total_compiled`.
220    pub fn new(
221        specs: std::sync::Arc<Vec<JointPenaltySpec>>,
222        log_lambdas: Vec<f64>,
223        total_compiled: usize,
224    ) -> Result<Self, String> {
225        if specs.len() != log_lambdas.len() {
226            return Err(format!(
227                "joint penalty bundle: {} specs vs {} log_lambdas",
228                specs.len(),
229                log_lambdas.len(),
230            ));
231        }
232        for (i, spec) in specs.iter().enumerate() {
233            if spec.dim() != total_compiled {
234                return Err(format!(
235                    "joint penalty {i}: dim {} != total_compiled {}",
236                    spec.dim(),
237                    total_compiled,
238                ));
239            }
240        }
241        Ok(Self { specs, log_lambdas })
242    }
243
244    #[inline]
245    pub fn len(&self) -> usize {
246        self.specs.len()
247    }
248
249    #[inline]
250    pub fn is_empty(&self) -> bool {
251        self.specs.is_empty()
252    }
253
254    /// Total joint-penalty contribution to the objective:
255    ///   `½ Σ_j exp(ρ_j) · βᵀ S_j β`.
256    pub fn quadratic(&self, beta: ArrayView1<'_, f64>) -> f64 {
257        let mut total = 0.0;
258        for (spec, &log_lambda) in self.specs.iter().zip(self.log_lambdas.iter()) {
259            let lam = log_lambda.exp();
260            total += 0.5 * lam * spec.quadratic_form(beta);
261        }
262        total
263    }
264
265    /// Accumulate `Σ_j exp(ρ_j) · S_j · v` into `out` (additive).
266    pub fn add_apply_into(&self, vector: ArrayView1<'_, f64>, out: &mut ndarray::Array1<f64>) {
267        assert_eq!(out.len(), vector.len());
268        for (spec, &log_lambda) in self.specs.iter().zip(self.log_lambdas.iter()) {
269            let lam = log_lambda.exp();
270            let sv = spec.matrix.dot(&vector);
271            out.scaled_add(lam, &sv);
272        }
273    }
274
275    /// Accumulate `Σ_j exp(ρ_j) · diag(S_j)` into `diag` (additive).
276    pub fn add_diag(&self, diag: &mut ndarray::Array1<f64>) {
277        for (spec, &log_lambda) in self.specs.iter().zip(self.log_lambdas.iter()) {
278            let lam = log_lambda.exp();
279            for (i, value) in spec.matrix.diag().iter().enumerate() {
280                diag[i] += lam * *value;
281            }
282        }
283    }
284
285    /// Accumulate `Σ_j exp(ρ_j) · S_j` into the full `matrix` (additive).
286    pub fn add_to_matrix(&self, matrix: &mut Array2<f64>) {
287        assert_eq!(matrix.nrows(), matrix.ncols());
288        for (spec, &log_lambda) in self.specs.iter().zip(self.log_lambdas.iter()) {
289            let lam = log_lambda.exp();
290            matrix.scaled_add(lam, &spec.matrix);
291        }
292    }
293
294    /// Per-penalty ρ-gradient contribution to the outer objective term:
295    ///   `∂/∂ρ_j [½ exp(ρ_j) βᵀ S_j β] = exp(ρ_j) · ½ βᵀ S_j β`.
296    pub fn rho_objective_gradient(&self, beta: ArrayView1<'_, f64>, out: &mut [f64]) {
297        assert_eq!(out.len(), self.specs.len());
298        for (i, (spec, &log_lambda)) in self.specs.iter().zip(self.log_lambdas.iter()).enumerate() {
299            let lam = log_lambda.exp();
300            out[i] = 0.5 * lam * spec.quadratic_form(beta);
301        }
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    /// 4-dim cross-block dense penalty: a rank-2 operator that couples
    /// indices {0,1} to {2,3} (i.e. nonzero off the 2×2 block diagonal),
    /// which is exactly the shape that defeats a per-block `PenaltyMatrix`.
    fn cross_block_spec() -> JointPenaltySpec {
        // Build S = v vᵀ + w wᵀ where v and w span across both 2-blocks.
        let v: Array1<f64> = array![1.0, 0.0, -1.0, 0.0];
        let w: Array1<f64> = array![0.0, 1.0, 0.0, -1.0];
        let matrix: Array2<f64> = Array2::from_shape_fn((4, 4), |(i, j)| v[i] * v[j] + w[i] * w[j]);
        JointPenaltySpec {
            label: Some("cross_block_pullback".to_string()),
            matrix,
            initial_log_lambda: -1.5,
            nullspace_dim: 2,
        }
    }

    /// Hand-expanded `βᵀ S β` for the toy matrix `S = [[2, 1], [1, 2]]`,
    /// used as an expectation that does not go through the code under test.
    fn toy_quadratic(beta: &Array1<f64>) -> f64 {
        2.0 * beta[0] * beta[0] + 2.0 * beta[0] * beta[1] + 2.0 * beta[1] * beta[1]
    }

    #[test]
    fn cross_block_dense_validates() {
        let result = cross_block_spec().validate();
        assert!(
            result.is_ok(),
            "valid cross-block spec rejected: {result:?}"
        );
    }

    #[test]
    fn trace_matches_diagonal_sum() {
        let spec = cross_block_spec();
        // diag(S) = [v0^2+w0^2, v1^2+w1^2, v2^2+w2^2, v3^2+w3^2] = [1,1,1,1]
        assert!((spec.trace() - 4.0).abs() < 1e-12);
    }

    #[test]
    fn pseudo_rank_uses_declared_nullspace() {
        let spec = cross_block_spec();
        assert_eq!(spec.dim(), 4);
        assert_eq!(spec.pseudo_rank(), 2);
    }

    #[test]
    fn quadratic_form_matches_explicit_mat_vec() {
        let spec = cross_block_spec();
        // Pick a beta that has support in both 2-blocks.
        let beta: Array1<f64> = array![0.5, -0.25, 1.0, 0.75];
        // v·β = 0.5 - 1.0 = -0.5; w·β = -0.25 - 0.75 = -1.0
        // βᵀSβ = (v·β)^2 + (w·β)^2 = 0.25 + 1.0 = 1.25
        let q = spec.quadratic_form(beta.view());
        assert!((q - 1.25).abs() < 1e-12, "got {q}");
    }

    #[test]
    fn determinant_zero_for_rank_deficient_matches_nullspace() {
        use gam_linalg::faer_ndarray::FaerEigh;
        let spec = cross_block_spec();
        // Symmetric eigendecomposition; expect exactly nullspace_dim
        // zeros (up to floating-point), matching the declared rank.
        let (eigvals, _) =
            FaerEigh::eigh(&spec.matrix, faer::Side::Lower).expect("symmetric eigh succeeds");
        let mut sorted: Vec<f64> = eigvals.iter().copied().collect();
        // `total_cmp` is a total order, so no unwrap on a partial comparison.
        sorted.sort_by(f64::total_cmp);
        // Count near-zeros without relying on where they land in the order.
        let zeros = sorted.iter().filter(|v| v.abs() < 1e-10).count();
        assert_eq!(
            zeros, spec.nullspace_dim,
            "spectrum {sorted:?} should have {} near-zeros",
            spec.nullspace_dim
        );
        // Determinant = product of eigenvalues; with a real nullspace
        // it is exactly zero modulo roundoff.
        let det: f64 = sorted.iter().product();
        assert!(det.abs() < 1e-10, "expected ~0 determinant, got {det}");
    }

    #[test]
    fn validate_rejects_non_square() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::zeros((3, 4)),
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        assert!(matches!(
            spec.validate(),
            Err(JointPenaltyError::NotSquare { nrows: 3, ncols: 4 })
        ));
    }

    #[test]
    fn validate_rejects_non_symmetric() {
        let mut matrix = Array2::<f64>::zeros((3, 3));
        matrix[[0, 1]] = 1.0;
        matrix[[1, 0]] = -1.0;
        let spec = JointPenaltySpec {
            label: None,
            matrix,
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        assert!(matches!(
            spec.validate(),
            Err(JointPenaltyError::NotSymmetric { .. })
        ));
    }

    #[test]
    fn validate_rejects_non_finite_entry() {
        let mut matrix = Array2::<f64>::zeros((2, 2));
        matrix[[0, 1]] = f64::INFINITY;
        let spec = JointPenaltySpec {
            label: None,
            matrix,
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        // The finiteness scan runs before the symmetry check, so the
        // infinite entry is reported rather than the asymmetry it causes.
        assert!(matches!(
            spec.validate(),
            Err(JointPenaltyError::NonFiniteEntry { row: 0, col: 1, .. })
        ));
    }

    #[test]
    fn validate_rejects_oversized_nullspace() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::zeros((3, 3)),
            initial_log_lambda: 0.0,
            nullspace_dim: 4,
        };
        assert!(matches!(
            spec.validate(),
            Err(JointPenaltyError::NullspaceTooLarge {
                total: 3,
                nullspace_dim: 4
            })
        ));
    }

    #[test]
    fn validate_rejects_non_finite_initial_log_lambda() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::zeros((2, 2)),
            initial_log_lambda: f64::NAN,
            nullspace_dim: 0,
        };
        assert!(matches!(
            spec.validate(),
            Err(JointPenaltyError::NonFiniteInitialLogLambda { .. })
        ));
    }

    /// 2-block toy with one full-width SPD joint penalty:
    ///
    /// Two scalar blocks (`p = 1 + 1 = 2`). The unpenalised "log-likelihood"
    /// is a quadratic with optimum at `b`:
    ///     ℓ(β) = −½ (β − b)ᵀ I (β − b)
    /// so `−∇ℓ = (β − b)` and `−∇²ℓ = I`. We add ONE joint penalty
    /// `S = [[2, 1], [1, 2]]` (SPD, full rank, cross-block coupling
    /// off-diagonal). With `λ = exp(ρ)` the penalised objective is
    ///     F(β) = ½ (β − b)ᵀ (β − b) + ½ λ βᵀ S β
    /// whose minimiser solves `(I + λ S) β̂ = b`. We verify the bundle's
    /// `add_to_matrix` builds the right LHS and the `add_apply_into` /
    /// `quadratic` helpers agree with the analytic gradient / objective at
    /// `β̂`.
    #[test]
    fn bundle_two_block_minimiser_matches_analytic_solution() {
        use gam_linalg::faer_ndarray::FaerCholesky;

        let spec = JointPenaltySpec {
            label: Some("toy_cross_block".to_string()),
            matrix: array![[2.0_f64, 1.0], [1.0, 2.0]],
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        let log_lambda = -0.4_f64;
        let lam = log_lambda.exp();
        let bundle = JointPenaltyBundle::new(std::sync::Arc::new(vec![spec]), vec![log_lambda], 2)
            .expect("valid bundle");

        // Build LHS = I + λ S via add_to_matrix (the exact path the inner
        // Newton uses to assemble the penalised joint Hessian).
        let mut lhs = Array2::<f64>::eye(2);
        bundle.add_to_matrix(&mut lhs);
        // Verify add_to_matrix produced I + λ S.
        let expected_lhs = array![[1.0 + lam * 2.0, lam], [lam, 1.0 + lam * 2.0]];
        for r in 0..2 {
            for c in 0..2 {
                assert!(
                    (lhs[[r, c]] - expected_lhs[[r, c]]).abs() < 1e-12,
                    "lhs[{r}, {c}] = {} expected {}",
                    lhs[[r, c]],
                    expected_lhs[[r, c]]
                );
            }
        }

        // Solve (I + λ S) β̂ = b for b = [1.0, -0.5].
        let b: Array1<f64> = array![1.0, -0.5];
        let chol = lhs.cholesky(faer::Side::Lower).expect("SPD");
        let mut rhs_mat = Array2::<f64>::zeros((2, 1));
        rhs_mat[[0, 0]] = b[0];
        rhs_mat[[1, 0]] = b[1];
        let mut beta_mat = rhs_mat.clone();
        chol.solve_mat_in_place(&mut beta_mat);
        let beta_hat: Array1<f64> = array![beta_mat[[0, 0]], beta_mat[[1, 0]]];

        // Gradient at β̂: (β̂ − b) + λ S β̂ should be ~0.
        let mut grad = &beta_hat - &b;
        bundle.add_apply_into(beta_hat.view(), &mut grad);
        let grad_inf = grad.iter().map(|v: &f64| v.abs()).fold(0.0_f64, f64::max);
        assert!(
            grad_inf < 1e-12,
            "penalised gradient at analytic minimiser must vanish: {grad_inf:.3e}"
        );

        // Objective ½(β̂−b)·(β̂−b) + bundle.quadratic(β̂) reproduces the
        // closed-form minimum value F(β̂) = ½(β̂−b)ᵀ(β̂−b) + ½λ β̂ᵀ S β̂.
        // The penalty term of the expectation is expanded by hand so it is
        // independent of the mat-vec inside `quadratic`.
        let resid = &beta_hat - &b;
        let unpen = 0.5 * resid.dot(&resid);
        let pen = bundle.quadratic(beta_hat.view());
        let expected_pen = 0.5 * lam * toy_quadratic(&beta_hat);
        assert!(
            (pen - expected_pen).abs() < 1e-12,
            "penalty {pen} mismatched expected {expected_pen}"
        );
        let expected_obj = unpen + expected_pen;
        assert!(
            (unpen + pen - expected_obj).abs() < 1e-12,
            "objective sum {} mismatched expected {}",
            unpen + pen,
            expected_obj
        );

        // Preconditioner diag accumulator: diag(I) + λ diag(S) = [1+2λ, 1+2λ].
        let mut diag = ndarray::Array1::<f64>::from_elem(2, 1.0);
        bundle.add_diag(&mut diag);
        assert!((diag[0] - (1.0 + lam * 2.0)).abs() < 1e-12);
        assert!((diag[1] - (1.0 + lam * 2.0)).abs() < 1e-12);

        // rho-objective-gradient: ½ λ β̂ᵀ S β̂.
        let mut rho_grad = vec![0.0_f64];
        bundle.rho_objective_gradient(beta_hat.view(), &mut rho_grad);
        let expected_rho_grad = 0.5 * lam * toy_quadratic(&beta_hat);
        assert!(
            (rho_grad[0] - expected_rho_grad).abs() < 1e-12,
            "rho-grad {} expected {}",
            rho_grad[0],
            expected_rho_grad
        );
    }

    #[test]
    fn empty_bundle_contributes_nothing() {
        let bundle = JointPenaltyBundle::new(std::sync::Arc::new(Vec::new()), Vec::new(), 3)
            .expect("empty bundle is valid");
        assert!(bundle.is_empty());
        assert_eq!(bundle.len(), 0);

        let beta: Array1<f64> = array![1.0, 2.0, 3.0];
        assert_eq!(bundle.quadratic(beta.view()), 0.0);

        // With no penalties the accumulators must leave their targets as-is.
        let mut out: Array1<f64> = array![0.5, -0.5, 0.25];
        bundle.add_apply_into(beta.view(), &mut out);
        assert_eq!(out, array![0.5, -0.5, 0.25]);
    }

    #[test]
    fn bundle_rejects_dim_mismatch() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::<f64>::eye(3),
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        let err = JointPenaltyBundle::new(std::sync::Arc::new(vec![spec]), vec![0.0], 4)
            .expect_err("dim mismatch must reject");
        assert!(err.contains("total_compiled"));
    }

    #[test]
    fn bundle_rejects_lambda_count_mismatch() {
        let spec = JointPenaltySpec {
            label: None,
            matrix: Array2::<f64>::eye(2),
            initial_log_lambda: 0.0,
            nullspace_dim: 0,
        };
        let err = JointPenaltyBundle::new(std::sync::Arc::new(vec![spec]), vec![], 2)
            .expect_err("count mismatch must reject");
        assert!(err.contains("specs vs"));
    }
}