Skip to main content

gam_terms/
penalty_spec.rs

//! Penalty specification for the public estimate API.
//!
//! `PenaltySpec` describes a single penalty and is built entirely from
//! `gam-terms` penalty types (`PenaltyStructureHint`, `PenaltyOp`,
//! `BlockwisePenalty`) plus the neutral `gam_problem::CoefficientPriorMean`.
//! It therefore lives in `gam-terms` (the layer that owns those penalty
//! primitives); the solver consumes it from above via `gam_terms::PenaltySpec`.
//!
//! Originally moved byte-identically from the monolith `src/model_types.rs`
//! during the `#1521` carve; the only changes made at that time were the
//! crate-local module paths (`crate::terms::smooth::*` -> `crate::smooth::*`,
//! `crate::terms::analytic_penalties::*` -> `crate::analytic_penalties::*`)
//! and sourcing `EstimationError` from `gam_problem`.
14
15use std::ops::Range;
16
17use ndarray::{Array2, s};
18
19use crate::smooth::{BlockwisePenalty, PenaltyStructureHint};
20
/// Programmatic prior mean for a coefficient penalty block.
///
/// This type lives in the neutral `gam-problem` crate (its inherent `evaluate`
/// returns `gam_problem::PriorMeanError`). It is re-exported here so that all
/// existing `PenaltySpec`-adjacent references keep resolving. Solver-side
/// callers map `PriorMeanError` into `EstimationError::InvalidInput`.
27pub use gam_problem::CoefficientPriorMean;
28pub use gam_problem::EstimationError;
29
/// A penalty specification for the public estimate API.
///
/// Three storage forms are supported:
///
/// - `Block` stores only the active sub-block and its column range, avoiding
///   the O(p^2) cost of embedding into a full penalty matrix.
/// - `Dense` stores a full `p x p` penalty matrix for callers that already
///   have one.
/// - `DenseWithMean` stores a full `p x p` penalty matrix together with a
///   programmatic prior mean in the same global coefficient basis.
///
/// `Debug` is implemented by hand and prints matrix shapes, not contents.
#[derive(Clone)]
pub enum PenaltySpec {
    /// Block-local penalty: `local` is `block_dim x block_dim`,
    /// applied to columns `col_range` of the coefficient vector.
    Block {
        /// Local penalty matrix. It is expected to be
        /// `col_range.len() x col_range.len()`; `validate_penalty_spec_shape`
        /// checks this.
        local: Array2<f64>,
        /// Coefficient indices (rows and columns of the global penalty)
        /// that `local` occupies when embedded.
        col_range: Range<usize>,
        /// Programmatic prior mean attached to this penalty block.
        prior_mean: CoefficientPriorMean,
        /// Optional structural hint for fast-path spectral decomposition.
        structure_hint: Option<PenaltyStructureHint>,
        /// Optional operator-form handle bit-equivalent to `local`.
        op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
    },
    /// Full dense penalty matrix (`p x p`).
    Dense(Array2<f64>),
    /// Full dense penalty matrix with a programmatic prior mean in the same
    /// global coefficient basis.
    DenseWithMean {
        /// Full `p x p` penalty matrix.
        matrix: Array2<f64>,
        /// Prior mean expressed in the global coefficient basis.
        prior_mean: CoefficientPriorMean,
    },
}
58
59impl std::fmt::Debug for PenaltySpec {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            PenaltySpec::Block {
63                local,
64                col_range,
65                prior_mean,
66                structure_hint,
67                op,
68            } => f
69                .debug_struct("Block")
70                .field(
71                    "local",
72                    &format_args!("{}×{}", local.nrows(), local.ncols()),
73                )
74                .field("col_range", col_range)
75                .field("prior_mean", prior_mean)
76                .field("structure_hint", structure_hint)
77                .field("op", &op.as_ref().map(|o| o.dim()))
78                .finish(),
79            PenaltySpec::Dense(m) => f
80                .debug_tuple("Dense")
81                .field(&format_args!("{}×{}", m.nrows(), m.ncols()))
82                .finish(),
83            PenaltySpec::DenseWithMean { matrix, prior_mean } => f
84                .debug_struct("DenseWithMean")
85                .field(
86                    "matrix",
87                    &format_args!("{}×{}", matrix.nrows(), matrix.ncols()),
88                )
89                .field("prior_mean", prior_mean)
90                .finish(),
91        }
92    }
93}
94
95impl PenaltySpec {
96    /// The column range this penalty covers.
97    /// For `Dense`, this is `0..p` where `p = m.ncols()`.
98    pub fn col_range(&self, p: usize) -> Range<usize> {
99        match self {
100            PenaltySpec::Block { col_range, .. } => col_range.clone(),
101            PenaltySpec::Dense(m) => {
102                assert_eq!(m.ncols(), p);
103                0..p
104            }
105            PenaltySpec::DenseWithMean { matrix, .. } => {
106                assert_eq!(matrix.ncols(), p);
107                0..p
108            }
109        }
110    }
111
112    /// Op-form handle when present (only for `Block`; `Dense` always returns `None`).
113    pub fn op(&self) -> Option<&std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>> {
114        match self {
115            PenaltySpec::Block { op, .. } => op.as_ref(),
116            PenaltySpec::Dense(_) | PenaltySpec::DenseWithMean { .. } => None,
117        }
118    }
119
120    /// Convert from a `BlockwisePenalty`, preserving the structure hint and op.
121    pub fn from_blockwise(bp: BlockwisePenalty) -> Self {
122        PenaltySpec::Block {
123            local: bp.local,
124            col_range: bp.col_range,
125            prior_mean: bp.prior_mean,
126            structure_hint: bp.structure_hint,
127            op: bp.op,
128        }
129    }
130
131    pub fn from_blockwise_ref(bp: &BlockwisePenalty) -> Self {
132        PenaltySpec::Block {
133            local: bp.local.clone(),
134            col_range: bp.col_range.clone(),
135            prior_mean: bp.prior_mean.clone(),
136            structure_hint: bp.structure_hint.clone(),
137            op: bp.op.clone(),
138        }
139    }
140
141    /// Materialize the full `p x p` dense penalty matrix.
142    /// For `Dense`, this is a clone.  For `Block`, this embeds `local` into a
143    /// zero matrix at the given `col_range`.
144    pub fn to_dense(&self) -> Array2<f64> {
145        match self {
146            PenaltySpec::Dense(m) => m.clone(),
147            PenaltySpec::DenseWithMean { matrix, .. } => matrix.clone(),
148            PenaltySpec::Block {
149                local, col_range, ..
150            } => {
151                let p = col_range.end.max(local.nrows());
152                // Caller should supply p externally when the total dim is larger;
153                // this is the best we can do without it.
154                let mut out = Array2::zeros((p, p));
155                out.slice_mut(s![col_range.clone(), col_range.clone()])
156                    .assign(local);
157                out
158            }
159        }
160    }
161
162    /// Materialize the full `p_total x p_total` dense penalty matrix.
163    /// For `Dense`, this is a clone (asserts that it matches `p_total`).
164    /// For `Block`, this embeds `local` into a `p_total x p_total` zero matrix.
165    pub fn to_global(&self, p_total: usize) -> Array2<f64> {
166        match self {
167            PenaltySpec::Dense(m) => {
168                assert_eq!(m.nrows(), p_total);
169                m.clone()
170            }
171            PenaltySpec::DenseWithMean { matrix, .. } => {
172                assert_eq!(matrix.nrows(), p_total);
173                matrix.clone()
174            }
175            PenaltySpec::Block {
176                local, col_range, ..
177            } => {
178                let mut out = Array2::zeros((p_total, p_total));
179                out.slice_mut(s![col_range.clone(), col_range.clone()])
180                    .assign(local);
181                out
182            }
183        }
184    }
185}
186
187/// Pure shape validation for a [`PenaltySpec`] against a coefficient dimension.
188///
189/// Neutral term-side check (no solver state): it only inspects block/dense
190/// dimensions and returns [`EstimationError`]. Moved byte-identically from the
191/// solver's `estimate::external_options` during the `#1521` carve so it can be
192/// single-sourced alongside the type it validates; the solver re-exports it via
193/// `gam_terms::validate_penalty_spec_shape`.
194pub fn validate_penalty_spec_shape(
195    idx: usize,
196    spec: &PenaltySpec,
197    p: usize,
198    context: &str,
199) -> Result<(), EstimationError> {
200    match spec {
201        PenaltySpec::Block {
202            local, col_range, ..
203        } => {
204            let bd = col_range.len();
205            if local.nrows() != bd || local.ncols() != bd {
206                crate::bail_invalid_estim!(
207                    "{context}: block penalty {idx} local matrix must be {bd}x{bd}, got {}x{}",
208                    local.nrows(),
209                    local.ncols()
210                );
211            }
212            if col_range.end > p {
213                crate::bail_invalid_estim!(
214                    "{context}: block penalty {idx} col_range {}..{} exceeds p={p}",
215                    col_range.start,
216                    col_range.end
217                );
218            }
219        }
220        PenaltySpec::Dense(m) => {
221            if m.nrows() != p || m.ncols() != p {
222                crate::bail_invalid_estim!(
223                    "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
224                    m.nrows(),
225                    m.ncols()
226                );
227            }
228        }
229        PenaltySpec::DenseWithMean { matrix, .. } => {
230            if matrix.nrows() != p || matrix.ncols() != p {
231                crate::bail_invalid_estim!(
232                    "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
233                    matrix.nrows(),
234                    matrix.ncols()
235                );
236            }
237        }
238    }
239    Ok(())
240}