gam_problem/
coefficient_prior_mean.rs1use std::sync::Arc;
9
10use ndarray::Array1;
11
/// Error produced when a coefficient prior mean is invalid or cannot be evaluated.
///
/// The wrapped `String` is the complete human-readable message; the `Display`
/// implementation prints it unchanged.
#[derive(Clone, Debug)]
pub struct PriorMeanError(pub String);

impl std::fmt::Display for PriorMeanError {
    /// Writes the stored message verbatim.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        // `write_str` bypasses width/fill/precision flags, so the message is
        // never padded or truncated by the caller's format spec.
        formatter.write_str(message.as_str())
    }
}

// No underlying cause is tracked, so the default `source()` (None) is correct.
impl std::error::Error for PriorMeanError {}
26
27#[derive(Clone, Default)]
32pub enum CoefficientPriorMean {
33 #[default]
34 Zero,
35 Scalar(f64),
36 Constant(Array1<f64>),
37 Functional {
38 metadata: Array1<f64>,
39 evaluator: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
40 },
41 KernelBasis {
47 covariates: Array1<f64>,
48 amplitude: f64,
49 kernel: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
50 },
51}
52
53impl std::fmt::Debug for CoefficientPriorMean {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 Self::Zero => f.write_str("Zero"),
57 Self::Scalar(value) => f.debug_tuple("Scalar").field(value).finish(),
58 Self::Constant(values) => f
59 .debug_tuple("Constant")
60 .field(&format_args!("len={}", values.len()))
61 .finish(),
62 Self::Functional { metadata, .. } => f
63 .debug_struct("Functional")
64 .field("metadata_len", &metadata.len())
65 .finish_non_exhaustive(),
66 Self::KernelBasis {
67 covariates,
68 amplitude,
69 ..
70 } => f
71 .debug_struct("KernelBasis")
72 .field("covariate_len", &covariates.len())
73 .field("amplitude", amplitude)
74 .finish_non_exhaustive(),
75 }
76 }
77}
78
79impl CoefficientPriorMean {
80 pub const fn scalar(value: f64) -> Self {
81 Self::Scalar(value)
82 }
83
84 pub fn constant(values: Array1<f64>) -> Self {
85 Self::Constant(values)
86 }
87
88 pub fn functional(
89 metadata: Array1<f64>,
90 evaluator: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
91 ) -> Self {
92 Self::Functional {
93 metadata,
94 evaluator,
95 }
96 }
97
98 pub fn kernel_basis(
99 covariates: Array1<f64>,
100 amplitude: f64,
101 kernel: Arc<dyn Fn(&Array1<f64>) -> Array1<f64> + Send + Sync>,
102 ) -> Self {
103 Self::KernelBasis {
104 covariates,
105 amplitude,
106 kernel,
107 }
108 }
109
110 pub fn evaluate(&self, block_dim: usize, context: &str) -> Result<Array1<f64>, PriorMeanError> {
111 let values = match self {
112 Self::Zero => Array1::zeros(block_dim),
113 Self::Scalar(value) => {
114 if !value.is_finite() {
115 return Err(PriorMeanError(format!(
116 "{context}: coefficient prior mean scalar must be finite, got {value}"
117 )));
118 }
119 Array1::from_elem(block_dim, *value)
120 }
121 Self::Constant(values) => values.clone(),
122 Self::Functional {
123 metadata,
124 evaluator,
125 } => evaluator(metadata),
126 Self::KernelBasis {
127 covariates,
128 amplitude,
129 kernel,
130 } => {
131 if !amplitude.is_finite() {
132 return Err(PriorMeanError(format!(
133 "{context}: coefficient prior mean amplitude must be finite, got {amplitude}"
134 )));
135 }
136 let mut values = kernel(covariates);
137 values *= *amplitude;
138 values
139 }
140 };
141 if values.len() != block_dim {
142 return Err(PriorMeanError(format!(
143 "{context}: coefficient prior mean length must be {block_dim}, got {}",
144 values.len()
145 )));
146 }
147 if values.iter().any(|&value| !value.is_finite()) {
148 return Err(PriorMeanError(format!(
149 "{context}: coefficient prior mean contains non-finite values"
150 )));
151 }
152 Ok(values)
153 }
154}