1use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1};
12use rayon::prelude::*;
13use smallvec::SmallVec;
14
15use crate::basis::{
16 closed_form_anisotropic_pair_block, closed_form_anisotropic_pair_value_with_powers,
17 closed_form_penalty, pure_duchon_diagonal_epsilon,
18};
19use gam_linalg::faer_ndarray::{fast_ab, fast_atb};
20
/// Penalty operator assembled from closed-form anisotropic pair values between
/// a fixed set of centers.
///
/// The operator can be applied to a vector without materializing a matrix
/// (`matvec`), or materialized on demand as a dense matrix that is cached after
/// its first build (`dense_form`).
pub struct ClosedFormPenaltyOperator {
    /// Kernel parameter forwarded unchanged to the closed-form pair routines in
    /// `crate::basis`.
    // NOTE(review): the mathematical meaning of `q`, `m`, `s` and `kappa` is
    // defined by those routines, not here — confirm against `crate::basis`.
    q: usize,
    /// Kernel parameter forwarded unchanged to the pair routines (see note on `q`).
    m: usize,
    /// Kernel parameter forwarded unchanged to the pair routines (see note on `q`).
    s: usize,
    /// Kernel parameter forwarded unchanged to the pair routines (see note on `q`).
    kappa: f64,
    /// Center coordinates: one row per center, one column per axis.
    centers: Array2<f64>,
    /// Per-axis anisotropy values (the `aniso_log_scales` given to `new`), one
    /// per column of `centers`; all zeros when none were supplied.
    eta_raw: Vec<f64>,
    /// Built from `eta_raw` by `AnisoMetricPowers::new` at construction and
    /// passed together with it to the pair-value routine.
    eta_metric_powers: closed_form_penalty::AnisoMetricPowers,
    /// Optional matrix `Z`. When present the raw pair matrix `G` is used as
    /// `Zᵀ G Z`, and `Z.ncols()` is the size of the kernel block.
    kernel_nullspace: Option<Array2<f64>>,
    /// Number of extra coordinates appended after the kernel block; the
    /// operator is zero on them.
    polynomial_block_cols: usize,
    /// Optional matrix `T` applied last as `Tᵀ (·) T`. When present,
    /// `T.ncols()` is the dimension reported by `dim`.
    outer_identifiability: Option<Array2<f64>>,
    /// Final argument handed to the pair-value routine: `0.0` when
    /// `analytic_self_pair_bundle` returns `Some`, otherwise the value returned
    /// by `pure_duchon_diagonal_epsilon`.
    diagonal_epsilon: f64,
    /// Lazily built dense form of the operator. `Clone` gives the copy a fresh,
    /// empty cache instead of duplicating this one.
    cached_dense: gam_runtime::resource::RayonSafeOnce<Array2<f64>>,
}
65
66impl Clone for ClosedFormPenaltyOperator {
70 fn clone(&self) -> Self {
71 Self {
72 q: self.q,
73 m: self.m,
74 s: self.s,
75 kappa: self.kappa,
76 centers: self.centers.clone(),
77 eta_raw: self.eta_raw.clone(),
78 eta_metric_powers: self.eta_metric_powers.clone(),
79 kernel_nullspace: self.kernel_nullspace.clone(),
80 polynomial_block_cols: self.polynomial_block_cols,
81 outer_identifiability: self.outer_identifiability.clone(),
82 diagonal_epsilon: self.diagonal_epsilon,
83 cached_dense: gam_runtime::resource::RayonSafeOnce::new(),
84 }
85 }
86}
87
88impl ClosedFormPenaltyOperator {
89 pub fn new(
92 centers: ArrayView2<'_, f64>,
93 q: usize,
94 m: usize,
95 s: usize,
96 kappa: f64,
97 aniso_log_scales: Option<&[f64]>,
98 kernel_nullspace: Option<&Array2<f64>>,
99 polynomial_block_cols: usize,
100 outer_identifiability: Option<&Array2<f64>>,
101 ) -> Self {
102 let d = centers.ncols();
103 let eta_raw: Vec<f64> = if let Some(eta) = aniso_log_scales {
104 assert_eq!(
105 eta.len(),
106 d,
107 "ClosedFormPenaltyOperator::new: eta dimension mismatch"
108 );
109 eta.to_vec()
110 } else {
111 vec![0.0_f64; d]
112 };
113 let diagonal_epsilon =
114 if closed_form_penalty::analytic_self_pair_bundle(q, m, s, kappa, &eta_raw).is_some() {
115 0.0
116 } else {
117 pure_duchon_diagonal_epsilon(centers, &eta_raw)
118 };
119 Self {
120 q,
121 m,
122 s,
123 kappa,
124 centers: centers.to_owned(),
125 eta_metric_powers: closed_form_penalty::AnisoMetricPowers::new(&eta_raw),
126 eta_raw,
127 kernel_nullspace: kernel_nullspace.cloned(),
128 polynomial_block_cols,
129 outer_identifiability: outer_identifiability.cloned(),
130 diagonal_epsilon,
131 cached_dense: gam_runtime::resource::RayonSafeOnce::new(),
132 }
133 }
134
135 fn ensure_dense(&self) -> &Array2<f64> {
137 self.cached_dense.get_or_compute(|| self.build_dense())
138 }
139
140 pub fn dim(&self) -> usize {
143 let kernel_cols = self
144 .kernel_nullspace
145 .as_ref()
146 .map(|z| z.ncols())
147 .unwrap_or_else(|| self.centers.nrows());
148 let total_pre = kernel_cols + self.polynomial_block_cols;
149 match &self.outer_identifiability {
150 Some(t) => t.ncols(),
151 None => total_pre,
152 }
153 }
154
155 #[inline]
156 fn is_raw_layout(&self) -> bool {
157 self.kernel_nullspace.is_none()
158 && self.polynomial_block_cols == 0
159 && self.outer_identifiability.is_none()
160 }
161
162 fn raw_diagonal_value(&self) -> f64 {
163 let mut r0: SmallVec<[f64; 16]> = SmallVec::with_capacity(self.centers.ncols());
164 r0.resize(self.centers.ncols(), 0.0);
165 closed_form_anisotropic_pair_value_with_powers(
166 self.q,
167 self.m,
168 self.s,
169 self.kappa,
170 &self.eta_raw,
171 &self.eta_metric_powers,
172 r0.as_slice(),
173 self.diagonal_epsilon,
174 )
175 }
176
177 pub fn matvec(&self, w: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
188 assert_eq!(
189 w.len(),
190 self.dim(),
191 "ClosedFormPenaltyOperator::matvec: input dim mismatch"
192 );
193 assert_eq!(
194 out.len(),
195 self.dim(),
196 "ClosedFormPenaltyOperator::matvec: output dim mismatch"
197 );
198
199 let pre = match &self.outer_identifiability {
200 Some(t) => t.dot(&w),
201 None => w.to_owned(),
202 };
203 let kernel_cols = self
204 .kernel_nullspace
205 .as_ref()
206 .map(|z| z.ncols())
207 .unwrap_or_else(|| self.centers.nrows());
208 let pre_kernel = pre.slice(ndarray::s![0..kernel_cols]);
209 let raw_input = match &self.kernel_nullspace {
210 Some(z) => z.dot(&pre_kernel),
211 None => pre_kernel.to_owned(),
212 };
213 let raw_output = self.raw_pair_matvec(raw_input.view());
214 let kernel_output = match &self.kernel_nullspace {
215 Some(z) => z.t().dot(&raw_output),
216 None => raw_output,
217 };
218 let total_pre = kernel_cols + self.polynomial_block_cols;
219 let mut projected = Array1::<f64>::zeros(total_pre);
220 projected
221 .slice_mut(ndarray::s![0..kernel_cols])
222 .assign(&kernel_output);
223 let final_output = match &self.outer_identifiability {
224 Some(t) => t.t().dot(&projected),
225 None => projected,
226 };
227 out.assign(&final_output);
228 }
229
230 pub fn diag(&self) -> Array1<f64> {
235 let n = self.dim();
236 if self.is_raw_layout() {
237 return Array1::from_elem(n, self.raw_diagonal_value());
238 }
239 let dense = self.ensure_dense();
244 Array1::<f64>::from_iter((0..n).map(|i| dense[[i, i]]))
245 }
246
247 pub fn trace(&self) -> f64 {
250 if self.is_raw_layout() {
251 return self.raw_diagonal_value() * self.dim() as f64;
252 }
253 self.diag().sum()
254 }
255
256 pub fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
260 assert!(lambda > 0.0, "log_det_plus_lambda_i requires λ > 0");
261 let n = self.dim();
262 let mut dense = self.dense_form();
263 for i in 0..n {
264 dense[[i, i]] += lambda;
265 }
266 let (evals, _) = gam_linalg::faer_ndarray::FaerEigh::eigh(&dense, faer::Side::Lower)
267 .map_err(|e| {
268 format!("ClosedFormPenaltyOperator logdet eigendecomposition failed: {e}")
269 })?;
270 let mut logdet = 0.0;
271 for (idx, &ev) in evals.iter().enumerate() {
272 if !ev.is_finite() || ev <= 0.0 {
273 return Err(format!(
274 "ClosedFormPenaltyOperator expected SPD S+λI, eigenvalue {idx} is {ev:.3e}"
275 ));
276 }
277 logdet += ev.ln();
278 }
279 Ok(logdet)
280 }
281
282 pub fn dense_form(&self) -> Array2<f64> {
287 self.ensure_dense().clone()
288 }
289
290 fn build_dense(&self) -> Array2<f64> {
293 let g_raw = closed_form_anisotropic_pair_block(
296 self.centers.view(),
297 self.q,
298 self.m,
299 self.s,
300 self.kappa,
301 if self.eta_raw.iter().all(|&e| e == 0.0) {
302 None
303 } else {
304 Some(self.eta_raw.as_slice())
305 },
306 );
307 let kernel_cols = self
308 .kernel_nullspace
309 .as_ref()
310 .map(|z| z.ncols())
311 .unwrap_or_else(|| self.centers.nrows());
312 let g_kernel = match &self.kernel_nullspace {
313 Some(z) => {
314 let zt_g = fast_atb(z, &g_raw);
315 fast_ab(&zt_g, z)
316 }
317 None => g_raw,
318 };
319 let total_pre = kernel_cols + self.polynomial_block_cols;
320 let g_padded = if self.polynomial_block_cols == 0 {
321 g_kernel
322 } else {
323 let mut padded = Array2::<f64>::zeros((total_pre, total_pre));
324 padded
325 .slice_mut(ndarray::s![0..kernel_cols, 0..kernel_cols])
326 .assign(&g_kernel);
327 padded
328 };
329 match &self.outer_identifiability {
330 Some(t) => {
331 let tt_g = fast_atb(t, &g_padded);
332 fast_ab(&tt_g, t)
333 }
334 None => g_padded,
335 }
336 }
337
338 fn raw_pair_matvec(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
339 assert_eq!(
340 v.len(),
341 self.centers.nrows(),
342 "ClosedFormPenaltyOperator::raw_pair_matvec: input dim mismatch"
343 );
344 let k = self.centers.nrows();
345 let d = self.centers.ncols();
346 let rows: Vec<f64> = (0..k)
347 .into_par_iter()
348 .map(|i| {
349 let mut r: SmallVec<[f64; 16]> = SmallVec::with_capacity(d);
350 r.resize(d, 0.0);
351 let mut sum = 0.0_f64;
352 let mut correction = 0.0_f64;
353 for j in 0..k {
354 for axis in 0..d {
355 r[axis] = self.centers[[i, axis]] - self.centers[[j, axis]];
356 }
357 let gij = closed_form_anisotropic_pair_value_with_powers(
358 self.q,
359 self.m,
360 self.s,
361 self.kappa,
362 &self.eta_raw,
363 &self.eta_metric_powers,
364 r.as_slice(),
365 self.diagonal_epsilon,
366 );
367 let y = gij * v[j] - correction;
368 let next = sum + y;
369 correction = (next - sum) - y;
370 sum = next;
371 }
372 sum
373 })
374 .collect();
375 Array1::from_vec(rows)
376 }
377}
378
379#[cfg(test)]
380mod tests {
381 use super::*;
382 use approx::assert_abs_diff_eq;
383 use ndarray::Array;
384
385 fn small_centers() -> Array2<f64> {
386 Array::from_shape_vec(
387 (5, 2),
388 vec![
389 0.10, 0.20, 0.40, 0.15, 0.55, 0.65, 0.80, 0.30, 0.25, 0.85, ],
395 )
396 .unwrap()
397 }
398
399 #[test]
400 fn test_operator_dense_agrees_unconstrained() {
401 let centers = small_centers();
402 let op = ClosedFormPenaltyOperator::new(
403 centers.view(),
404 1, 2,
406 1,
407 1.0,
408 None,
409 None,
410 0,
411 None,
412 );
413 let dense = op.dense_form();
414 let n = op.dim();
415 let mut e = Array1::<f64>::zeros(n);
416 let mut col = Array1::<f64>::zeros(n);
417 for i in 0..n {
418 e.fill(0.0);
419 e[i] = 1.0;
420 op.matvec(e.view(), col.view_mut());
421 for j in 0..n {
422 let scale = dense[[j, i]].abs().max(1.0);
423 assert_abs_diff_eq!(col[j], dense[[j, i]], epsilon = 1e-9 * scale);
424 }
425 }
426 }
427
428 #[test]
429 fn test_operator_diag_agrees() {
430 let centers = small_centers();
431 let op = ClosedFormPenaltyOperator::new(
432 centers.view(),
433 2, 2,
435 1,
436 0.5,
437 Some(&[0.10, -0.10]),
438 None,
439 0,
440 None,
441 );
442 let dense = op.dense_form();
443 let diag_op = op.diag();
444 for i in 0..op.dim() {
445 assert_abs_diff_eq!(diag_op[i], dense[[i, i]], epsilon = 1e-9);
446 }
447 }
448
449 #[test]
450 fn test_operator_matvec_random_vector() {
451 let centers = small_centers();
452 let op = ClosedFormPenaltyOperator::new(
453 centers.view(),
454 0, 2,
456 1,
457 1.5,
458 None,
459 None,
460 0,
461 None,
462 );
463 let dense = op.dense_form();
464 let n = op.dim();
465 let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
468 let mut v = Array1::<f64>::zeros(n);
469 for vi in v.iter_mut() {
470 state = state
471 .wrapping_mul(6364136223846793005)
472 .wrapping_add(1442695040888963407);
473 *vi = ((state >> 11) as f64 / (1u64 << 53) as f64) - 0.5;
474 }
475 let mut got = Array1::<f64>::zeros(n);
476 op.matvec(v.view(), got.view_mut());
477 let want = dense.dot(&v);
478 for i in 0..n {
479 assert_abs_diff_eq!(got[i], want[i], epsilon = 1e-9);
480 }
481 }
482
483 #[test]
484 fn test_operator_matvec_stays_matrix_free_until_dense_requested() {
485 let centers = small_centers();
486 let op = ClosedFormPenaltyOperator::new(
487 centers.view(),
488 1,
489 2,
490 1,
491 1.0,
492 Some(&[0.35, 0.10]),
493 None,
494 0,
495 None,
496 );
497 let v = Array1::from_vec(vec![0.2, -0.1, 0.4, -0.3, 0.7]);
498 let mut out = Array1::<f64>::zeros(op.dim());
499 op.matvec(v.view(), out.view_mut());
500 assert!(
501 op.cached_dense.get().is_none(),
502 "matvec must not populate the dense KxK cache"
503 );
504 let dense = op.dense_form();
505 assert!(
506 op.cached_dense.get().is_some(),
507 "dense_form should be the only path that populates the dense cache"
508 );
509 let expected = dense.dot(&v);
510 for i in 0..op.dim() {
511 assert_abs_diff_eq!(out[i], expected[i], epsilon = 1e-8);
512 }
513 }
514
515 #[test]
516 fn test_operator_preserves_raw_anisotropy_coordinates() {
517 let centers = small_centers();
518 let eta = [0.35, 0.10];
519 let op =
520 ClosedFormPenaltyOperator::new(centers.view(), 1, 2, 1, 1.0, Some(&eta), None, 0, None);
521 let dense = op.dense_form();
522 let reference = crate::basis::closed_form_operator_penalty_in_total_basis(
523 centers.view(),
524 1,
525 2,
526 1,
527 1.0,
528 Some(&eta),
529 None,
530 0,
531 None,
532 );
533 for i in 0..op.dim() {
534 for j in 0..op.dim() {
535 let scale = reference[[i, j]].abs().max(1.0);
536 assert_abs_diff_eq!(dense[[i, j]], reference[[i, j]], epsilon = 1e-12 * scale);
537 }
538 }
539 }
540
541 #[test]
542 fn test_operator_with_kernel_nullspace_constraint() {
543 let centers = small_centers();
544 let k = centers.nrows();
545 let mut z = Array2::<f64>::zeros((k, k - 1));
547 let inv_sqrt_k = 1.0 / (k as f64).sqrt();
548 let constant: Vec<f64> = (0..k).map(|_| inv_sqrt_k).collect();
549 for c in 0..(k - 1) {
551 let mut col = vec![0.0; k];
552 col[c + 1] = 1.0;
553 let inner: f64 = col.iter().zip(constant.iter()).map(|(a, b)| a * b).sum();
554 for i in 0..k {
555 col[i] -= inner * constant[i];
556 }
557 let norm = col.iter().map(|v| v * v).sum::<f64>().sqrt();
558 for i in 0..k {
559 z[[i, c]] = col[i] / norm;
560 }
561 }
562
563 let op = ClosedFormPenaltyOperator::new(
564 centers.view(),
565 1,
566 2,
567 1,
568 1.0,
569 Some(&[0.05, -0.05]),
570 Some(&z),
571 0,
572 None,
573 );
574 let dense = op.dense_form();
575 let n = op.dim();
576 assert_eq!(n, k - 1);
577 let mut e = Array1::<f64>::zeros(n);
578 let mut col = Array1::<f64>::zeros(n);
579 for i in 0..n {
580 e.fill(0.0);
581 e[i] = 1.0;
582 op.matvec(e.view(), col.view_mut());
583 for j in 0..n {
584 let scale = dense[[j, i]].abs().max(1.0);
585 assert_abs_diff_eq!(col[j], dense[[j, i]], epsilon = 1e-9 * scale);
586 }
587 }
588 }
589
590 #[test]
591 fn test_log_det_plus_lambda_matches_dense() {
592 let centers = small_centers();
593 let op = ClosedFormPenaltyOperator::new(centers.view(), 1, 2, 1, 1.0, None, None, 0, None);
594 let dense = op.dense_form();
595 let n = op.dim();
596 let lambda = 10.0_f64;
597 let mut reg = dense.clone();
602 for i in 0..n {
603 reg[[i, i]] += lambda;
604 }
605 for i in 0..n {
606 for j in (i + 1)..n {
607 let avg = 0.5 * (reg[[i, j]] + reg[[j, i]]);
608 reg[[i, j]] = avg;
609 reg[[j, i]] = avg;
610 }
611 }
612 let est = op.log_det_plus_lambda_i(lambda).expect("exact logdet");
613 use faer::Side;
614 use gam_linalg::faer_ndarray::FaerEigh;
615 let (evals, _) = FaerEigh::eigh(®, Side::Lower).expect("eigh");
616 let mut reference = 0.0_f64;
617 for (idx, &lam) in evals.iter().enumerate() {
618 assert!(lam > 0.0, "reference eigenvalue {idx} is {lam:.3e}");
619 reference += lam.ln();
620 }
621 assert_abs_diff_eq!(est, reference, epsilon = 1e-10);
622 }
623}