Skip to main content

gam_problem/
coefficient_prior_mean.rs

1//! Neutral programmatic prior-mean type for a coefficient penalty block.
2//!
3//! Lives in `gam-problem` so the penalty contract can carry a centering vector
4//! without depending on `solver`'s `EstimationError`. Evaluation failures are
5//! reported through the neutral [`PriorMeanError`]; callers map this into their
6//! own error flow (e.g. `EstimationError::InvalidInput`).
7
8use std::sync::Arc;
9
10use ndarray::Array1;
11
/// Message-only error raised when a coefficient prior mean cannot be evaluated.
///
/// The type is deliberately neutral: it wraps nothing but the human-readable
/// description, so callers in the solver crate can convert it into
/// `EstimationError::InvalidInput` and keep end-to-end behavior unchanged.
#[derive(Debug, Clone)]
pub struct PriorMeanError(pub String);

impl std::fmt::Display for PriorMeanError {
    /// Writes the wrapped message verbatim, with no prefix or decoration.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        formatter.write_str(message)
    }
}

// No underlying cause is tracked, so the default `Error` methods are enough.
impl std::error::Error for PriorMeanError {}
26
27/// Programmatic prior mean for a coefficient penalty block.
28///
29/// The mean is evaluated once during penalty canonicalization and then enters
30/// the solver as the centering vector in `(beta - mean)' S (beta - mean)`.
31#[derive(Clone, Default)]
32pub enum CoefficientPriorMean {
33    #[default]
34    Zero,
35    Scalar(f64),
36    Constant(Array1<f64>),
37    Functional {
38        metadata: Array1<f64>,
39        evaluator: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
40    },
41    /// Covariate-functional mean `mu(a) = amplitude * K(a)` for a coefficient block.
42    ///
43    /// Formula-level coefficient groups pass their row/covariate metadata as
44    /// `covariates`; the user-supplied kernel returns the block-sized basis
45    /// vector `K(a)` and the scalar amplitude supplies `alpha`.
46    KernelBasis {
47        covariates: Array1<f64>,
48        amplitude: f64,
49        kernel: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
50    },
51}
52
53impl std::fmt::Debug for CoefficientPriorMean {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            Self::Zero => f.write_str("Zero"),
57            Self::Scalar(value) => f.debug_tuple("Scalar").field(value).finish(),
58            Self::Constant(values) => f
59                .debug_tuple("Constant")
60                .field(&format_args!("len={}", values.len()))
61                .finish(),
62            Self::Functional { metadata, .. } => f
63                .debug_struct("Functional")
64                .field("metadata_len", &metadata.len())
65                .finish_non_exhaustive(),
66            Self::KernelBasis {
67                covariates,
68                amplitude,
69                ..
70            } => f
71                .debug_struct("KernelBasis")
72                .field("covariate_len", &covariates.len())
73                .field("amplitude", amplitude)
74                .finish_non_exhaustive(),
75        }
76    }
77}
78
79impl CoefficientPriorMean {
80    pub const fn scalar(value: f64) -> Self {
81        Self::Scalar(value)
82    }
83
84    pub fn constant(values: Array1<f64>) -> Self {
85        Self::Constant(values)
86    }
87
88    pub fn functional(
89        metadata: Array1<f64>,
90        evaluator: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
91    ) -> Self {
92        Self::Functional {
93            metadata,
94            evaluator,
95        }
96    }
97
98    pub fn kernel_basis(
99        covariates: Array1<f64>,
100        amplitude: f64,
101        kernel: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
102    ) -> Self {
103        Self::KernelBasis {
104            covariates,
105            amplitude,
106            kernel,
107        }
108    }
109
110    pub fn evaluate(&self, block_dim: usize, context: &str) -> Result<Array1<f64>, PriorMeanError> {
111        let values = match self {
112            Self::Zero => Array1::zeros(block_dim),
113            Self::Scalar(value) => {
114                if !value.is_finite() {
115                    return Err(PriorMeanError(format!(
116                        "{context}: coefficient prior mean scalar must be finite, got {value}"
117                    )));
118                }
119                Array1::from_elem(block_dim, *value)
120            }
121            Self::Constant(values) => values.clone(),
122            Self::Functional {
123                metadata,
124                evaluator,
125            } => evaluator(metadata),
126            Self::KernelBasis {
127                covariates,
128                amplitude,
129                kernel,
130            } => {
131                if !amplitude.is_finite() {
132                    return Err(PriorMeanError(format!(
133                        "{context}: coefficient prior mean amplitude must be finite, got {amplitude}"
134                    )));
135                }
136                let mut values = kernel(covariates);
137                values *= *amplitude;
138                values
139            }
140        };
141        if values.len() != block_dim {
142            return Err(PriorMeanError(format!(
143                "{context}: coefficient prior mean length must be {block_dim}, got {}",
144                values.len()
145            )));
146        }
147        if values.iter().any(|&value| !value.is_finite()) {
148            return Err(PriorMeanError(format!(
149                "{context}: coefficient prior mean contains non-finite values"
150            )));
151        }
152        Ok(values)
153    }
154}
155
/// Unit tests for [`CoefficientPriorMean`] and [`PriorMeanError`].
///
/// Besides the happy path of every variant, these cover the validation
/// performed by `evaluate` (length and finiteness of closure-produced
/// vectors), the hand-written `Debug` output, and error formatting.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn zero_variant_returns_zeros_of_requested_len() {
        let m = CoefficientPriorMean::Zero;
        let v = m.evaluate(4, "ctx").unwrap();
        assert_eq!(v.len(), 4);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn zero_variant_supports_empty_block() {
        let v = CoefficientPriorMean::Zero.evaluate(0, "ctx").unwrap();
        assert_eq!(v.len(), 0);
    }

    #[test]
    fn scalar_fills_vector_with_constant() {
        let m = CoefficientPriorMean::scalar(3.0);
        let v = m.evaluate(3, "ctx").unwrap();
        assert_eq!(v.len(), 3);
        assert!(v.iter().all(|&x| x == 3.0));
    }

    #[test]
    fn scalar_nan_returns_error() {
        let m = CoefficientPriorMean::scalar(f64::NAN);
        assert!(m.evaluate(2, "ctx").is_err());
    }

    #[test]
    fn scalar_infinite_returns_error() {
        let m = CoefficientPriorMean::scalar(f64::INFINITY);
        assert!(m.evaluate(2, "ctx").is_err());
    }

    #[test]
    fn constant_variant_clones_vector() {
        let arr = array![1.0_f64, 2.0, 3.0];
        let m = CoefficientPriorMean::constant(arr.clone());
        let v = m.evaluate(3, "ctx").unwrap();
        assert_eq!(v, arr);
    }

    #[test]
    fn constant_dimension_mismatch_returns_error() {
        let arr = array![1.0_f64, 2.0];
        let m = CoefficientPriorMean::constant(arr);
        assert!(m.evaluate(5, "ctx").is_err());
    }

    #[test]
    fn constant_non_finite_entry_returns_error() {
        let m = CoefficientPriorMean::constant(array![1.0_f64, f64::INFINITY]);
        let msg = m.evaluate(2, "ctx").unwrap_err().to_string();
        assert!(msg.contains("non-finite"), "unexpected message: {msg}");
    }

    #[test]
    fn functional_variant_calls_evaluator() {
        let meta = array![0.0_f64];
        let m = CoefficientPriorMean::functional(
            meta,
            Arc::new(|_| array![7.0_f64, 8.0]),
        );
        let v = m.evaluate(2, "ctx").unwrap();
        assert_eq!(v[0], 7.0);
        assert_eq!(v[1], 8.0);
    }

    #[test]
    fn functional_evaluator_receives_metadata() {
        let m = CoefficientPriorMean::functional(
            array![1.0_f64, 2.0],
            Arc::new(|meta: &Array1<f64>| meta.mapv(|x| x * 2.0)),
        );
        let v = m.evaluate(2, "ctx").unwrap();
        assert_eq!(v, array![2.0_f64, 4.0]);
    }

    #[test]
    fn functional_length_mismatch_returns_error() {
        let m = CoefficientPriorMean::functional(
            array![0.0_f64],
            Arc::new(|_| array![7.0_f64, 8.0]),
        );
        let msg = m.evaluate(3, "ctx").unwrap_err().to_string();
        assert!(
            msg.contains("length must be 3, got 2"),
            "unexpected message: {msg}"
        );
    }

    #[test]
    fn functional_non_finite_output_returns_error() {
        let m = CoefficientPriorMean::functional(
            array![0.0_f64],
            Arc::new(|_| array![1.0_f64, f64::NAN]),
        );
        let msg = m.evaluate(2, "ctx").unwrap_err().to_string();
        assert!(msg.contains("non-finite"), "unexpected message: {msg}");
    }

    #[test]
    fn kernel_basis_scales_kernel_output() {
        let covs = array![0.0_f64];
        let m = CoefficientPriorMean::kernel_basis(
            covs,
            2.0,
            Arc::new(|_| array![1.0_f64, 3.0]),
        );
        let v = m.evaluate(2, "ctx").unwrap();
        assert!((v[0] - 2.0).abs() < 1e-14);
        assert!((v[1] - 6.0).abs() < 1e-14);
    }

    #[test]
    fn kernel_basis_kernel_receives_covariates() {
        let m = CoefficientPriorMean::kernel_basis(
            array![1.0_f64, 2.0],
            0.5,
            Arc::new(|covs: &Array1<f64>| covs.mapv(|x| x + 1.0)),
        );
        let v = m.evaluate(2, "ctx").unwrap();
        // (1 + 1) * 0.5 and (2 + 1) * 0.5 are exactly representable.
        assert_eq!(v, array![1.0_f64, 1.5]);
    }

    #[test]
    fn kernel_basis_nan_amplitude_returns_error() {
        let covs = array![0.0_f64];
        let m = CoefficientPriorMean::kernel_basis(
            covs,
            f64::NAN,
            Arc::new(|_| array![1.0_f64]),
        );
        assert!(m.evaluate(1, "ctx").is_err());
    }

    #[test]
    fn kernel_basis_infinite_amplitude_returns_error() {
        let m = CoefficientPriorMean::kernel_basis(
            array![0.0_f64],
            f64::NEG_INFINITY,
            Arc::new(|_| array![1.0_f64]),
        );
        let msg = m.evaluate(1, "ctx").unwrap_err().to_string();
        assert!(msg.contains("amplitude"), "unexpected message: {msg}");
    }

    #[test]
    fn kernel_basis_overflow_after_scaling_returns_error() {
        // Finite amplitude and finite kernel output can still overflow to
        // infinity once multiplied; the final finiteness check must catch it.
        let m = CoefficientPriorMean::kernel_basis(
            array![0.0_f64],
            2.0,
            Arc::new(|_| array![f64::MAX]),
        );
        let msg = m.evaluate(1, "ctx").unwrap_err().to_string();
        assert!(msg.contains("non-finite"), "unexpected message: {msg}");
    }

    #[test]
    fn default_is_zero_variant() {
        let m = CoefficientPriorMean::default();
        let v = m.evaluate(5, "ctx").unwrap();
        assert_eq!(v.len(), 5);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn error_message_includes_context() {
        let m = CoefficientPriorMean::scalar(f64::NAN);
        let err = m.evaluate(1, "myctx").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("myctx"), "error should mention context: {msg}");
    }

    #[test]
    fn prior_mean_error_displays_wrapped_message() {
        let err = PriorMeanError(String::from("plain message"));
        assert_eq!(err.to_string(), "plain message");
    }

    #[test]
    fn debug_output_summarizes_each_variant() {
        assert_eq!(format!("{:?}", CoefficientPriorMean::Zero), "Zero");
        assert_eq!(
            format!("{:?}", CoefficientPriorMean::scalar(1.5)),
            "Scalar(1.5)"
        );
        assert_eq!(
            format!(
                "{:?}",
                CoefficientPriorMean::constant(array![1.0_f64, 2.0, 3.0])
            ),
            "Constant(len=3)"
        );

        let functional = CoefficientPriorMean::functional(
            array![0.0_f64],
            Arc::new(|_| array![0.0_f64]),
        );
        assert_eq!(
            format!("{functional:?}"),
            "Functional { metadata_len: 1, .. }"
        );

        let kernel_basis = CoefficientPriorMean::kernel_basis(
            array![0.0_f64],
            2.0,
            Arc::new(|_| array![0.0_f64]),
        );
        assert_eq!(
            format!("{kernel_basis:?}"),
            "KernelBasis { covariate_len: 1, amplitude: 2.0, .. }"
        );
    }
}