gam_terms/penalty_spec.rs
1//! Penalty specification for the public estimate API.
2//!
3//! `PenaltySpec` is a *penalty spec* keyed entirely on `gam-terms` penalty
4//! types (`PenaltyStructureHint`, `PenaltyOp`, `BlockwisePenalty`) plus the
5//! neutral `gam_problem::CoefficientPriorMean`. It therefore lives in
6//! `gam-terms` (the layer that owns those penalty primitives); the solver
7//! consumes it from above via `gam_terms::PenaltySpec`.
8//!
9//! Moved byte-identically from the monolith `src/model_types.rs` during the
10//! `#1521` carve, with the only changes being the crate-local module paths
11//! (`crate::terms::smooth::*` -> `crate::smooth::*`,
12//! `crate::terms::analytic_penalties::*` -> `crate::analytic_penalties::*`)
13//! and `EstimationError` sourced from `gam_problem`.
14
15use std::ops::Range;
16
17use ndarray::{Array2, s};
18
19use crate::smooth::{BlockwisePenalty, PenaltyStructureHint};
20
21/// Programmatic prior mean for a coefficient penalty block.
22///
23/// This type lives in the neutral `gam-problem` crate (with its inherent
24/// `evaluate` returning `gam_problem::PriorMeanError`); re-exported here so all
25/// existing `PenaltySpec`-adjacent references keep resolving. Solver-side
26/// callers map `PriorMeanError` into `EstimationError::InvalidInput`.
27pub use gam_problem::CoefficientPriorMean;
28pub use gam_problem::EstimationError;
29
/// A penalty specification for the public estimate API.
///
/// `Block` stores only the active sub-block and its column range, avoiding
/// the O(p^2) cost of embedding into a full penalty matrix.
/// `Dense` stores a full `p x p` penalty matrix for callers that already
/// have one.
/// `DenseWithMean` is a full dense penalty that additionally carries a
/// programmatic prior mean.
///
/// Shapes are not checked when a value is constructed; use
/// `validate_penalty_spec_shape` to check a spec against a coefficient
/// dimension `p` before relying on it.
#[derive(Clone)]
pub enum PenaltySpec {
    /// Block-local penalty: `local` is `block_dim x block_dim`,
    /// applied to columns `col_range` of the coefficient vector.
    Block {
        /// The `block_dim x block_dim` penalty sub-matrix, where
        /// `block_dim == col_range.len()`.
        local: Array2<f64>,
        /// Coefficient indices the block occupies; `local` is placed on this
        /// range along both axes of the global penalty matrix.
        col_range: Range<usize>,
        /// Prior mean associated with this block (see `CoefficientPriorMean`).
        prior_mean: CoefficientPriorMean,
        /// Optional structural hint for fast-path spectral decomposition.
        structure_hint: Option<PenaltyStructureHint>,
        /// Optional operator-form handle bit-equivalent to `local`.
        op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
    },
    /// Full dense penalty matrix (`p x p`).
    Dense(Array2<f64>),
    /// Full dense penalty matrix with a programmatic prior mean in the same
    /// global coefficient basis.
    DenseWithMean {
        /// The full `p x p` penalty matrix.
        matrix: Array2<f64>,
        /// Prior mean expressed in the global coefficient basis.
        prior_mean: CoefficientPriorMean,
    },
}
58
59impl std::fmt::Debug for PenaltySpec {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 PenaltySpec::Block {
63 local,
64 col_range,
65 prior_mean,
66 structure_hint,
67 op,
68 } => f
69 .debug_struct("Block")
70 .field(
71 "local",
72 &format_args!("{}×{}", local.nrows(), local.ncols()),
73 )
74 .field("col_range", col_range)
75 .field("prior_mean", prior_mean)
76 .field("structure_hint", structure_hint)
77 .field("op", &op.as_ref().map(|o| o.dim()))
78 .finish(),
79 PenaltySpec::Dense(m) => f
80 .debug_tuple("Dense")
81 .field(&format_args!("{}×{}", m.nrows(), m.ncols()))
82 .finish(),
83 PenaltySpec::DenseWithMean { matrix, prior_mean } => f
84 .debug_struct("DenseWithMean")
85 .field(
86 "matrix",
87 &format_args!("{}×{}", matrix.nrows(), matrix.ncols()),
88 )
89 .field("prior_mean", prior_mean)
90 .finish(),
91 }
92 }
93}
94
95impl PenaltySpec {
96 /// The column range this penalty covers.
97 /// For `Dense`, this is `0..p` where `p = m.ncols()`.
98 pub fn col_range(&self, p: usize) -> Range<usize> {
99 match self {
100 PenaltySpec::Block { col_range, .. } => col_range.clone(),
101 PenaltySpec::Dense(m) => {
102 assert_eq!(m.ncols(), p);
103 0..p
104 }
105 PenaltySpec::DenseWithMean { matrix, .. } => {
106 assert_eq!(matrix.ncols(), p);
107 0..p
108 }
109 }
110 }
111
112 /// Op-form handle when present (only for `Block`; `Dense` always returns `None`).
113 pub fn op(&self) -> Option<&std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>> {
114 match self {
115 PenaltySpec::Block { op, .. } => op.as_ref(),
116 PenaltySpec::Dense(_) | PenaltySpec::DenseWithMean { .. } => None,
117 }
118 }
119
120 /// Convert from a `BlockwisePenalty`, preserving the structure hint and op.
121 pub fn from_blockwise(bp: BlockwisePenalty) -> Self {
122 PenaltySpec::Block {
123 local: bp.local,
124 col_range: bp.col_range,
125 prior_mean: bp.prior_mean,
126 structure_hint: bp.structure_hint,
127 op: bp.op,
128 }
129 }
130
131 pub fn from_blockwise_ref(bp: &BlockwisePenalty) -> Self {
132 PenaltySpec::Block {
133 local: bp.local.clone(),
134 col_range: bp.col_range.clone(),
135 prior_mean: bp.prior_mean.clone(),
136 structure_hint: bp.structure_hint.clone(),
137 op: bp.op.clone(),
138 }
139 }
140
141 /// Materialize the full `p x p` dense penalty matrix.
142 /// For `Dense`, this is a clone. For `Block`, this embeds `local` into a
143 /// zero matrix at the given `col_range`.
144 pub fn to_dense(&self) -> Array2<f64> {
145 match self {
146 PenaltySpec::Dense(m) => m.clone(),
147 PenaltySpec::DenseWithMean { matrix, .. } => matrix.clone(),
148 PenaltySpec::Block {
149 local, col_range, ..
150 } => {
151 let p = col_range.end.max(local.nrows());
152 // Caller should supply p externally when the total dim is larger;
153 // this is the best we can do without it.
154 let mut out = Array2::zeros((p, p));
155 out.slice_mut(s![col_range.clone(), col_range.clone()])
156 .assign(local);
157 out
158 }
159 }
160 }
161
162 /// Materialize the full `p_total x p_total` dense penalty matrix.
163 /// For `Dense`, this is a clone (asserts that it matches `p_total`).
164 /// For `Block`, this embeds `local` into a `p_total x p_total` zero matrix.
165 pub fn to_global(&self, p_total: usize) -> Array2<f64> {
166 match self {
167 PenaltySpec::Dense(m) => {
168 assert_eq!(m.nrows(), p_total);
169 m.clone()
170 }
171 PenaltySpec::DenseWithMean { matrix, .. } => {
172 assert_eq!(matrix.nrows(), p_total);
173 matrix.clone()
174 }
175 PenaltySpec::Block {
176 local, col_range, ..
177 } => {
178 let mut out = Array2::zeros((p_total, p_total));
179 out.slice_mut(s![col_range.clone(), col_range.clone()])
180 .assign(local);
181 out
182 }
183 }
184 }
185}
186
187/// Pure shape validation for a [`PenaltySpec`] against a coefficient dimension.
188///
189/// Neutral term-side check (no solver state): it only inspects block/dense
190/// dimensions and returns [`EstimationError`]. Moved byte-identically from the
191/// solver's `estimate::external_options` during the `#1521` carve so it can be
192/// single-sourced alongside the type it validates; the solver re-exports it via
193/// `gam_terms::validate_penalty_spec_shape`.
194pub fn validate_penalty_spec_shape(
195 idx: usize,
196 spec: &PenaltySpec,
197 p: usize,
198 context: &str,
199) -> Result<(), EstimationError> {
200 match spec {
201 PenaltySpec::Block {
202 local, col_range, ..
203 } => {
204 let bd = col_range.len();
205 if local.nrows() != bd || local.ncols() != bd {
206 crate::bail_invalid_estim!(
207 "{context}: block penalty {idx} local matrix must be {bd}x{bd}, got {}x{}",
208 local.nrows(),
209 local.ncols()
210 );
211 }
212 if col_range.end > p {
213 crate::bail_invalid_estim!(
214 "{context}: block penalty {idx} col_range {}..{} exceeds p={p}",
215 col_range.start,
216 col_range.end
217 );
218 }
219 }
220 PenaltySpec::Dense(m) => {
221 if m.nrows() != p || m.ncols() != p {
222 crate::bail_invalid_estim!(
223 "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
224 m.nrows(),
225 m.ncols()
226 );
227 }
228 }
229 PenaltySpec::DenseWithMean { matrix, .. } => {
230 if matrix.nrows() != p || matrix.ncols() != p {
231 crate::bail_invalid_estim!(
232 "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
233 matrix.nrows(),
234 matrix.ncols()
235 );
236 }
237 }
238 }
239 Ok(())
240}