Skip to main content

gam_problem/
coefficient_prior_mean.rs

1//! Neutral programmatic prior-mean type for a coefficient penalty block.
2//!
3//! Lives in `gam-problem` so the penalty contract can carry a centering vector
4//! without depending on `solver`'s `EstimationError`. Evaluation failures are
5//! reported through the neutral [`PriorMeanError`]; callers map this into their
6//! own error flow (e.g. `EstimationError::InvalidInput`).
7
8use std::sync::Arc;
9
10use ndarray::Array1;
11
/// Neutral error for prior-mean evaluation failures.
///
/// Carries the human-readable message; callers in the solver crate map this
/// into `EstimationError::InvalidInput` to preserve end-to-end behavior.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly instead of only comparing their rendered strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PriorMeanError(pub String);

impl std::fmt::Display for PriorMeanError {
    /// Writes the wrapped message.
    ///
    /// Uses `Formatter::pad` rather than `write_str` so width, alignment and
    /// precision flags (e.g. `{:>40}` in tabular diagnostics) are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}

impl std::error::Error for PriorMeanError {}
26
27/// Programmatic prior mean for a coefficient penalty block.
28///
29/// The mean is evaluated once during penalty canonicalization and then enters
30/// the solver as the centering vector in `(beta - mean)' S (beta - mean)`.
31#[derive(Clone, Default)]
32pub enum CoefficientPriorMean {
33    #[default]
34    Zero,
35    Scalar(f64),
36    Constant(Array1<f64>),
37    Functional {
38        metadata: Array1<f64>,
39        evaluator: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
40    },
41    /// Covariate-functional mean `mu(a) = amplitude * K(a)` for a coefficient block.
42    ///
43    /// Formula-level coefficient groups pass their row/covariate metadata as
44    /// `covariates`; the user-supplied kernel returns the block-sized basis
45    /// vector `K(a)` and the scalar amplitude supplies `alpha`.
46    KernelBasis {
47        covariates: Array1<f64>,
48        amplitude: f64,
49        kernel: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
50    },
51}
52
53impl std::fmt::Debug for CoefficientPriorMean {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            Self::Zero => f.write_str("Zero"),
57            Self::Scalar(value) => f.debug_tuple("Scalar").field(value).finish(),
58            Self::Constant(values) => f
59                .debug_tuple("Constant")
60                .field(&format_args!("len={}", values.len()))
61                .finish(),
62            Self::Functional { metadata, .. } => f
63                .debug_struct("Functional")
64                .field("metadata_len", &metadata.len())
65                .finish_non_exhaustive(),
66            Self::KernelBasis {
67                covariates,
68                amplitude,
69                ..
70            } => f
71                .debug_struct("KernelBasis")
72                .field("covariate_len", &covariates.len())
73                .field("amplitude", amplitude)
74                .finish_non_exhaustive(),
75        }
76    }
77}
78
79impl CoefficientPriorMean {
80    pub const fn scalar(value: f64) -> Self {
81        Self::Scalar(value)
82    }
83
84    pub fn constant(values: Array1<f64>) -> Self {
85        Self::Constant(values)
86    }
87
88    pub fn functional(
89        metadata: Array1<f64>,
90        evaluator: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
91    ) -> Self {
92        Self::Functional {
93            metadata,
94            evaluator,
95        }
96    }
97
98    pub fn kernel_basis(
99        covariates: Array1<f64>,
100        amplitude: f64,
101        kernel: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
102    ) -> Self {
103        Self::KernelBasis {
104            covariates,
105            amplitude,
106            kernel,
107        }
108    }
109
110    pub fn evaluate(&self, block_dim: usize, context: &str) -> Result<Array1<f64>, PriorMeanError> {
111        let values = match self {
112            Self::Zero => Array1::zeros(block_dim),
113            Self::Scalar(value) => {
114                if !value.is_finite() {
115                    return Err(PriorMeanError(format!(
116                        "{context}: coefficient prior mean scalar must be finite, got {value}"
117                    )));
118                }
119                Array1::from_elem(block_dim, *value)
120            }
121            Self::Constant(values) => values.clone(),
122            Self::Functional {
123                metadata,
124                evaluator,
125            } => evaluator(metadata),
126            Self::KernelBasis {
127                covariates,
128                amplitude,
129                kernel,
130            } => {
131                if !amplitude.is_finite() {
132                    return Err(PriorMeanError(format!(
133                        "{context}: coefficient prior mean amplitude must be finite, got {amplitude}"
134                    )));
135                }
136                let mut values = kernel(covariates);
137                values *= *amplitude;
138                values
139            }
140        };
141        if values.len() != block_dim {
142            return Err(PriorMeanError(format!(
143                "{context}: coefficient prior mean length must be {block_dim}, got {}",
144                values.len()
145            )));
146        }
147        if values.iter().any(|&value| !value.is_finite()) {
148            return Err(PriorMeanError(format!(
149                "{context}: coefficient prior mean contains non-finite values"
150            )));
151        }
152        Ok(values)
153    }
154}