gam_terms/analytic_penalties/
op.rs1use std::sync::Arc;
19
20use faer::Side;
21use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
22
23use crate::basis::closed_form_operator::ClosedFormPenaltyOperator;
24use gam_linalg::faer_ndarray::{FaerEigh, fast_av_view_into};
25
26pub trait PenaltyOp: Send + Sync {
33 fn dim(&self) -> usize;
35
36 fn matvec(&self, w: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>);
38
39 fn diag(&self) -> Array1<f64>;
41
42 fn trace(&self) -> f64 {
44 self.diag().sum()
45 }
46
47 fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String>;
51
52 fn eigendecompose(&self) -> Result<(Array1<f64>, Array2<f64>), String> {
59 let dense = self.as_dense();
60 dense
61 .eigh(Side::Lower)
62 .map_err(|e| format!("PenaltyOp::eigendecompose: {e}"))
63 }
64
65 fn as_dense(&self) -> Array2<f64>;
70}
71
72impl PenaltyOp for Array2<f64> {
73 fn dim(&self) -> usize {
74 assert_eq!(
75 self.nrows(),
76 self.ncols(),
77 "PenaltyOp matrix must be square"
78 );
79 self.nrows()
80 }
81
82 fn matvec(&self, w: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
83 fast_av_view_into(self, &w, out);
84 }
85
86 fn diag(&self) -> Array1<f64> {
87 let n = self.nrows();
88 let mut d = Array1::<f64>::zeros(n);
89 for i in 0..n {
90 d[i] = self[[i, i]];
91 }
92 d
93 }
94
95 fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
96 assert!(lambda > 0.0, "log_det_plus_lambda_i requires λ > 0");
97 let n = <Self as PenaltyOp>::dim(self);
98 let mut regularized = self.clone();
99 for i in 0..n {
100 regularized[[i, i]] += lambda;
101 }
102 let (evals, _) = regularized.eigh(Side::Lower).map_err(|e| {
103 format!("PenaltyOp::log_det_plus_lambda_i eigendecomposition failed: {e}")
104 })?;
105 let mut logdet = 0.0;
106 for (idx, &ev) in evals.iter().enumerate() {
107 if !ev.is_finite() || ev <= 0.0 {
108 return Err(format!(
109 "PenaltyOp::log_det_plus_lambda_i expected SPD S+λI, \
110 eigenvalue {idx} is {ev:.3e}"
111 ));
112 }
113 logdet += ev.ln();
114 }
115 Ok(logdet)
116 }
117
118 fn as_dense(&self) -> Array2<f64> {
119 self.clone()
120 }
121}
122
123impl PenaltyOp for ClosedFormPenaltyOperator {
124 fn dim(&self) -> usize {
125 self.dim()
126 }
127
128 fn matvec(&self, w: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
129 self.matvec(w, out)
130 }
131
132 fn diag(&self) -> Array1<f64> {
133 self.diag()
134 }
135
136 fn trace(&self) -> f64 {
137 self.trace()
138 }
139
140 fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
141 self.log_det_plus_lambda_i(lambda)
142 }
143
144 fn as_dense(&self) -> Array2<f64> {
145 self.dense_form()
146 }
147}
148
149pub struct ScaledPenaltyOp {
155 inner: Arc<dyn PenaltyOp>,
156 scale: f64,
157}
158
159impl ScaledPenaltyOp {
160 pub fn new(inner: Arc<dyn PenaltyOp>, scale: f64) -> Self {
161 Self { inner, scale }
162 }
163}
164
165impl PenaltyOp for ScaledPenaltyOp {
166 fn dim(&self) -> usize {
167 self.inner.dim()
168 }
169
170 fn matvec(&self, w: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
171 self.inner.matvec(w, out.view_mut());
172 out.mapv_inplace(|v| v * self.scale);
173 }
174
175 fn diag(&self) -> Array1<f64> {
176 let mut d = self.inner.diag();
177 d.mapv_inplace(|v| v * self.scale);
178 d
179 }
180
181 fn trace(&self) -> f64 {
182 self.inner.trace() * self.scale
183 }
184
185 fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
186 let dense = self.as_dense();
190 <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&dense, lambda)
191 }
192
193 fn as_dense(&self) -> Array2<f64> {
194 let mut m = self.inner.as_dense();
195 m.mapv_inplace(|v| v * self.scale);
196 m
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array;

    /// Builds the 4×4 positive semi-definite matrix `BᵀB` from a fixed 3×4 `B`.
    fn psd_fixture() -> Array2<f64> {
        let b = Array::from_shape_vec(
            (3, 4),
            vec![
                1.0, -0.3, 0.7, 0.1, 0.2, 1.1, -0.5, 0.4, -0.1, 0.6, 0.9, -0.2,
            ],
        )
        .unwrap();
        b.t().dot(&b)
    }

    /// Fixed probe vector shared by the mat-vec tests.
    fn probe_vector() -> Array1<f64> {
        Array1::from_vec(vec![0.7, -0.4, 0.2, 1.1])
    }

    #[test]
    fn array2_impl_matvec_matches_dot() {
        let s = psd_fixture();
        let v = probe_vector();
        let mut out = Array1::<f64>::zeros(4);
        s.matvec(v.view(), out.view_mut());
        let want = s.dot(&v);
        for i in 0..4 {
            assert_abs_diff_eq!(out[i], want[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn array2_impl_diag_and_trace() {
        let s = psd_fixture();
        let d = <Array2<f64> as PenaltyOp>::diag(&s);
        for i in 0..4 {
            assert_abs_diff_eq!(d[i], s[[i, i]], epsilon = 0.0);
        }
        let tr = <Array2<f64> as PenaltyOp>::trace(&s);
        assert_abs_diff_eq!(tr, s.diag().sum(), epsilon = 0.0);
    }

    #[test]
    fn array2_impl_eigendecompose_matches_eigh() {
        let s = psd_fixture();
        let (evals_op, evecs_op) = <Array2<f64> as PenaltyOp>::eigendecompose(&s).expect("eigh");
        let (evals_ref, evecs_ref) = s.eigh(Side::Lower).expect("eigh ref");
        for i in 0..evals_op.len() {
            assert_abs_diff_eq!(evals_op[i], evals_ref[i], epsilon = 1e-12);
        }
        // Compare V·Vᵀ rather than V itself: eigenvector signs are not unique.
        let p_op = evecs_op.dot(&evecs_op.t());
        let p_ref = evecs_ref.dot(&evecs_ref.t());
        for i in 0..p_op.nrows() {
            for j in 0..p_op.ncols() {
                assert_abs_diff_eq!(p_op[[i, j]], p_ref[[i, j]], epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn array2_impl_log_det_matches_shifted_eigenvalues() {
        // The eigenvalues of S + λI are those of S shifted by λ.
        let s = psd_fixture();
        let lambda = 0.5;
        let got = <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&s, lambda).expect("log det");
        let (evals, _) = s.eigh(Side::Lower).expect("eigh ref");
        let want: f64 = evals.iter().map(|&ev| (ev + lambda).ln()).sum();
        assert_abs_diff_eq!(got, want, epsilon = 1e-10);
    }

    #[test]
    fn array2_impl_log_det_rejects_non_spd_shift() {
        // diag(-2, 1) + 1·I = diag(-1, 2) has a negative eigenvalue.
        let s = Array::from_shape_vec((2, 2), vec![-2.0, 0.0, 0.0, 1.0]).unwrap();
        let err = <Array2<f64> as PenaltyOp>::log_det_plus_lambda_i(&s, 1.0)
            .expect_err("shifted matrix is indefinite");
        assert!(err.contains("expected SPD"), "unexpected message: {err}");
    }

    #[test]
    fn scaled_op_matches_scaled_dense_matrix() {
        let s = psd_fixture();
        let scale = 2.5;
        let op = ScaledPenaltyOp::new(Arc::new(s.clone()), scale);
        let want = s.mapv(|v| v * scale);

        assert_eq!(op.dim(), 4);

        let dense = op.as_dense();
        let d = op.diag();
        for i in 0..4 {
            assert_abs_diff_eq!(d[i], want[[i, i]], epsilon = 1e-12);
            for j in 0..4 {
                assert_abs_diff_eq!(dense[[i, j]], want[[i, j]], epsilon = 1e-12);
            }
        }
        assert_abs_diff_eq!(op.trace(), scale * s.diag().sum(), epsilon = 1e-12);

        let v = probe_vector();
        let mut out = Array1::<f64>::zeros(4);
        op.matvec(v.view(), out.view_mut());
        let want_mv = want.dot(&v);
        for i in 0..4 {
            assert_abs_diff_eq!(out[i], want_mv[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn scaled_op_log_det_matches_scaled_eigenvalues() {
        // The eigenvalues of c·S + λI are c·ev + λ for each eigenvalue ev of S.
        let s = psd_fixture();
        let (scale, lambda) = (2.5, 0.5);
        let op = ScaledPenaltyOp::new(Arc::new(s.clone()), scale);
        let got = op.log_det_plus_lambda_i(lambda).expect("log det");
        let (evals, _) = s.eigh(Side::Lower).expect("eigh ref");
        let want: f64 = evals.iter().map(|&ev| (scale * ev + lambda).ln()).sum();
        assert_abs_diff_eq!(got, want, epsilon = 1e-10);
    }
}