Skip to main content

gam_terms/basis/
closed_form_operator.rs

//! Operator form of the closed-form Duchon penalty.
//!
//! ## Status
//!
//! `matvec` is analytic and streaming: it applies the constraint transforms,
//! evaluates the raw pair kernel row-by-row with Kahan summation, and never
//! allocates the raw `K×K` Gram. Paths that need a materialized matrix —
//! `dense_form()`, `log_det_plus_lambda_i()`, and `diag()`/`trace()` when
//! constraint transforms are composed — share one lazily built, cached dense
//! form; `matvec` never populates that cache.
10
11use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1};
12use rayon::prelude::*;
13use smallvec::SmallVec;
14
15use crate::basis::{
16    closed_form_anisotropic_pair_block, closed_form_anisotropic_pair_value_with_powers,
17    closed_form_penalty, pure_duchon_diagonal_epsilon,
18};
19use gam_linalg::faer_ndarray::{fast_ab, fast_atb};
20
21/// Matrix-free closed-form anisotropic Duchon penalty operator.
22///
23/// Stores the parameters of the closed-form pair-block (`q, m, s, κ, η`, knot
24/// centers, and optional constraint factors). Gauge ownership is upstream: this
25/// operator only applies the already selected section in matrix-free matvecs.
26/// The hot `matvec` path stays matrix-free; `cached_dense` is populated only by
27/// `dense_form()`.
28pub struct ClosedFormPenaltyOperator {
29    /// Derivative order (0 = mass, 1 = tension, 2 = stiffness).
30    q: usize,
31    /// Outer kernel order parameter.
32    m: usize,
33    /// Inner Matérn order parameter.
34    s: usize,
35    /// Inverse length scale (κ ≥ 0).
36    kappa: f64,
37    /// Knot centers in the un-anisotropized coordinate, shape (K, d).
38    centers: Array2<f64>,
39    /// Per-axis raw anisotropy log-scales (length d). The pair-block builder
40    /// consumes these directly and applies the `J = exp(Σ η_k)` Jacobian
41    /// internally.
42    eta_raw: Vec<f64>,
43    /// Cached powers of `B = diag(exp(-2η))` for the analytic pair kernel.
44    eta_metric_powers: closed_form_penalty::AnisoMetricPowers,
45    /// Optional kernel-nullspace transform Z (K × kernel_cols).
46    kernel_nullspace: Option<Array2<f64>>,
47    /// Number of polynomial-block columns padded after the kernel block.
48    polynomial_block_cols: usize,
49    /// Optional outer spatial identifiability transform T (total_pre × total).
50    outer_identifiability: Option<Array2<f64>>,
51    /// Diagonal epsilon convention for regimes without an exact analytic
52    /// self-pair. In the convergent closed-form regimes this is zero and is
53    /// never read by pair evaluation.
54    diagonal_epsilon: f64,
55    /// Lazily-populated dense form. Populated only by `dense_form`; the
56    /// matvec/diag/trace/log-det paths stay matrix-free.
57    ///
58    /// `RayonSafeOnce` (not `OnceLock`) because `build_dense` runs faer
59    /// GEMMs (`fast_ab`, `fast_atb`) which dispatch nested rayon work — a
60    /// plain `OnceLock` here would deadlock if `dense_form` is first hit
61    /// concurrently from inside an outer par_iter. See
62    /// `feedback_oncelock_rayon_deadlock`.
63    cached_dense: gam_runtime::resource::RayonSafeOnce<Array2<f64>>,
64}
65
66// Cloning the operator resets its cache so the new instance rebuilds on first
67// use. This matches the legacy `derive(Clone)` behavior (which also produced a
68// fresh dense build per matvec — the cache is strictly an addition).
69impl Clone for ClosedFormPenaltyOperator {
70    fn clone(&self) -> Self {
71        Self {
72            q: self.q,
73            m: self.m,
74            s: self.s,
75            kappa: self.kappa,
76            centers: self.centers.clone(),
77            eta_raw: self.eta_raw.clone(),
78            eta_metric_powers: self.eta_metric_powers.clone(),
79            kernel_nullspace: self.kernel_nullspace.clone(),
80            polynomial_block_cols: self.polynomial_block_cols,
81            outer_identifiability: self.outer_identifiability.clone(),
82            diagonal_epsilon: self.diagonal_epsilon,
83            cached_dense: gam_runtime::resource::RayonSafeOnce::new(),
84        }
85    }
86}
87
88impl ClosedFormPenaltyOperator {
89    /// Build an operator with the same closed-form parameters that
90    /// `basis::closed_form_operator_penalty_in_total_basis` consumes.
91    pub fn new(
92        centers: ArrayView2<'_, f64>,
93        q: usize,
94        m: usize,
95        s: usize,
96        kappa: f64,
97        aniso_log_scales: Option<&[f64]>,
98        kernel_nullspace: Option<&Array2<f64>>,
99        polynomial_block_cols: usize,
100        outer_identifiability: Option<&Array2<f64>>,
101    ) -> Self {
102        let d = centers.ncols();
103        let eta_raw: Vec<f64> = if let Some(eta) = aniso_log_scales {
104            assert_eq!(
105                eta.len(),
106                d,
107                "ClosedFormPenaltyOperator::new: eta dimension mismatch"
108            );
109            eta.to_vec()
110        } else {
111            vec![0.0_f64; d]
112        };
113        let diagonal_epsilon =
114            if closed_form_penalty::analytic_self_pair_bundle(q, m, s, kappa, &eta_raw).is_some() {
115                0.0
116            } else {
117                pure_duchon_diagonal_epsilon(centers, &eta_raw)
118            };
119        Self {
120            q,
121            m,
122            s,
123            kappa,
124            centers: centers.to_owned(),
125            eta_metric_powers: closed_form_penalty::AnisoMetricPowers::new(&eta_raw),
126            eta_raw,
127            kernel_nullspace: kernel_nullspace.cloned(),
128            polynomial_block_cols,
129            outer_identifiability: outer_identifiability.cloned(),
130            diagonal_epsilon,
131            cached_dense: gam_runtime::resource::RayonSafeOnce::new(),
132        }
133    }
134
135    /// Return the cached dense form, building it on first call.
136    fn ensure_dense(&self) -> &Array2<f64> {
137        self.cached_dense.get_or_compute(|| self.build_dense())
138    }
139
140    /// Number of columns *after* applying constraint composition: the
141    /// dimension that callers see when invoking matvec/dense_form.
142    pub fn dim(&self) -> usize {
143        let kernel_cols = self
144            .kernel_nullspace
145            .as_ref()
146            .map(|z| z.ncols())
147            .unwrap_or_else(|| self.centers.nrows());
148        let total_pre = kernel_cols + self.polynomial_block_cols;
149        match &self.outer_identifiability {
150            Some(t) => t.ncols(),
151            None => total_pre,
152        }
153    }
154
155    #[inline]
156    fn is_raw_layout(&self) -> bool {
157        self.kernel_nullspace.is_none()
158            && self.polynomial_block_cols == 0
159            && self.outer_identifiability.is_none()
160    }
161
162    fn raw_diagonal_value(&self) -> f64 {
163        let mut r0: SmallVec<[f64; 16]> = SmallVec::with_capacity(self.centers.ncols());
164        r0.resize(self.centers.ncols(), 0.0);
165        closed_form_anisotropic_pair_value_with_powers(
166            self.q,
167            self.m,
168            self.s,
169            self.kappa,
170            &self.eta_raw,
171            &self.eta_metric_powers,
172            r0.as_slice(),
173            self.diagonal_epsilon,
174        )
175    }
176
177    /// Evaluate `(S w)` writing the result into `out`.
178    ///
179    /// With constraints composed, `S' = T^T diag(Z, I_poly)^T S_raw diag(Z, I_poly) T`.
180    /// We apply the chain right-to-left:
181    ///   1. `u = T w`            (dim → total_pre)
182    ///   2. `u_kernel = u[..kernel_cols]; u_poly = u[kernel_cols..]`
183    ///   3. `v = Z u_kernel`     (kernel_cols → K)
184    ///   4. `y = S_raw v`        (K → K), via on-the-fly pair-block evaluation
185    ///   5. `y_kernel = Z^T y`   (K → kernel_cols)
186    ///   6. compose with zero polynomial block, then `out = T^T [y_kernel; 0]`
187    pub fn matvec(&self, w: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
188        assert_eq!(
189            w.len(),
190            self.dim(),
191            "ClosedFormPenaltyOperator::matvec: input dim mismatch"
192        );
193        assert_eq!(
194            out.len(),
195            self.dim(),
196            "ClosedFormPenaltyOperator::matvec: output dim mismatch"
197        );
198
199        let pre = match &self.outer_identifiability {
200            Some(t) => t.dot(&w),
201            None => w.to_owned(),
202        };
203        let kernel_cols = self
204            .kernel_nullspace
205            .as_ref()
206            .map(|z| z.ncols())
207            .unwrap_or_else(|| self.centers.nrows());
208        let pre_kernel = pre.slice(ndarray::s![0..kernel_cols]);
209        let raw_input = match &self.kernel_nullspace {
210            Some(z) => z.dot(&pre_kernel),
211            None => pre_kernel.to_owned(),
212        };
213        let raw_output = self.raw_pair_matvec(raw_input.view());
214        let kernel_output = match &self.kernel_nullspace {
215            Some(z) => z.t().dot(&raw_output),
216            None => raw_output,
217        };
218        let total_pre = kernel_cols + self.polynomial_block_cols;
219        let mut projected = Array1::<f64>::zeros(total_pre);
220        projected
221            .slice_mut(ndarray::s![0..kernel_cols])
222            .assign(&kernel_output);
223        let final_output = match &self.outer_identifiability {
224            Some(t) => t.t().dot(&projected),
225            None => projected,
226        };
227        out.assign(&final_output);
228    }
229
230    /// Diagonal `S[i,i]` for i in 0..dim. In the raw layout this is the
231    /// analytic self-pair repeated K times. With constraint composition the
232    /// diagonal is *not* the K-space diagonal; we extract it via
233    /// `e_i^T S' e_i = matvec(e_i)[i]`.
234    pub fn diag(&self) -> Array1<f64> {
235        let n = self.dim();
236        if self.is_raw_layout() {
237            return Array1::from_elem(n, self.raw_diagonal_value());
238        }
239        // Read the diagonal directly off the cached dense form. This is the
240        // same matrix `matvec(e_i)` would reconstruct via the chain
241        // T^T diag(Z, I)^T S_raw diag(Z, I) T, but extracting the diagonal
242        // here is O(n) instead of O(n * K^2).
243        let dense = self.ensure_dense();
244        Array1::<f64>::from_iter((0..n).map(|i| dense[[i, i]]))
245    }
246
247    /// Trace `tr(S')`. In raw layout this is K times the analytic self-pair;
248    /// otherwise it uses the composed-basis diagonal.
249    pub fn trace(&self) -> f64 {
250        if self.is_raw_layout() {
251            return self.raw_diagonal_value() * self.dim() as f64;
252        }
253        self.diag().sum()
254    }
255
256    /// Exact `log det(S' + λI)`.
257    /// `S'` is rank-deficient under typical constraints (kernel/polynomial
258    /// nullspace), so the regularization `λ > 0` is mandatory.
259    pub fn log_det_plus_lambda_i(&self, lambda: f64) -> Result<f64, String> {
260        assert!(lambda > 0.0, "log_det_plus_lambda_i requires λ > 0");
261        let n = self.dim();
262        let mut dense = self.dense_form();
263        for i in 0..n {
264            dense[[i, i]] += lambda;
265        }
266        let (evals, _) = gam_linalg::faer_ndarray::FaerEigh::eigh(&dense, faer::Side::Lower)
267            .map_err(|e| {
268                format!("ClosedFormPenaltyOperator logdet eigendecomposition failed: {e}")
269            })?;
270        let mut logdet = 0.0;
271        for (idx, &ev) in evals.iter().enumerate() {
272            if !ev.is_finite() || ev <= 0.0 {
273                return Err(format!(
274                    "ClosedFormPenaltyOperator expected SPD S+λI, eigenvalue {idx} is {ev:.3e}"
275                ));
276            }
277            logdet += ev.ln();
278        }
279        Ok(logdet)
280    }
281
282    /// Materialize the full constrained operator as a dense `Array2` for
283    /// callers that explicitly request a matrix or for validation against
284    /// `closed_form_operator_penalty_in_total_basis`. Uses the internal
285    /// cache: the first call builds, subsequent calls clone from the cache.
286    pub fn dense_form(&self) -> Array2<f64> {
287        self.ensure_dense().clone()
288    }
289
290    /// Internal builder. Do not call directly — go through `ensure_dense` so
291    /// the result is cached.
292    fn build_dense(&self) -> Array2<f64> {
293        // Build raw K×K kernel block via the existing dense path so we share
294        // its cancellation-detector logic for small κ.
295        let g_raw = closed_form_anisotropic_pair_block(
296            self.centers.view(),
297            self.q,
298            self.m,
299            self.s,
300            self.kappa,
301            if self.eta_raw.iter().all(|&e| e == 0.0) {
302                None
303            } else {
304                Some(self.eta_raw.as_slice())
305            },
306        );
307        let kernel_cols = self
308            .kernel_nullspace
309            .as_ref()
310            .map(|z| z.ncols())
311            .unwrap_or_else(|| self.centers.nrows());
312        let g_kernel = match &self.kernel_nullspace {
313            Some(z) => {
314                let zt_g = fast_atb(z, &g_raw);
315                fast_ab(&zt_g, z)
316            }
317            None => g_raw,
318        };
319        let total_pre = kernel_cols + self.polynomial_block_cols;
320        let g_padded = if self.polynomial_block_cols == 0 {
321            g_kernel
322        } else {
323            let mut padded = Array2::<f64>::zeros((total_pre, total_pre));
324            padded
325                .slice_mut(ndarray::s![0..kernel_cols, 0..kernel_cols])
326                .assign(&g_kernel);
327            padded
328        };
329        match &self.outer_identifiability {
330            Some(t) => {
331                let tt_g = fast_atb(t, &g_padded);
332                fast_ab(&tt_g, t)
333            }
334            None => g_padded,
335        }
336    }
337
338    fn raw_pair_matvec(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
339        assert_eq!(
340            v.len(),
341            self.centers.nrows(),
342            "ClosedFormPenaltyOperator::raw_pair_matvec: input dim mismatch"
343        );
344        let k = self.centers.nrows();
345        let d = self.centers.ncols();
346        let rows: Vec<f64> = (0..k)
347            .into_par_iter()
348            .map(|i| {
349                let mut r: SmallVec<[f64; 16]> = SmallVec::with_capacity(d);
350                r.resize(d, 0.0);
351                let mut sum = 0.0_f64;
352                let mut correction = 0.0_f64;
353                for j in 0..k {
354                    for axis in 0..d {
355                        r[axis] = self.centers[[i, axis]] - self.centers[[j, axis]];
356                    }
357                    let gij = closed_form_anisotropic_pair_value_with_powers(
358                        self.q,
359                        self.m,
360                        self.s,
361                        self.kappa,
362                        &self.eta_raw,
363                        &self.eta_metric_powers,
364                        r.as_slice(),
365                        self.diagonal_epsilon,
366                    );
367                    let y = gij * v[j] - correction;
368                    let next = sum + y;
369                    correction = (next - sum) - y;
370                    sum = next;
371                }
372                sum
373            })
374            .collect();
375        Array1::from_vec(rows)
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array;

    fn small_centers() -> Array2<f64> {
        Array::from_shape_vec(
            (5, 2),
            vec![
                0.10, 0.20, //
                0.40, 0.15, //
                0.55, 0.65, //
                0.80, 0.30, //
                0.25, 0.85, //
            ],
        )
        .unwrap()
    }

    /// `k - 1` unit-norm columns, each orthogonal to the constant vector in
    /// `R^k`: column `c` is `e_{c+1}` with its constant component removed.
    /// The columns are not mutually orthogonal, only orthogonal to the
    /// constant direction.
    fn constant_complement(k: usize) -> Array2<f64> {
        let mut z = Array2::<f64>::zeros((k, k - 1));
        let inv_sqrt_k = 1.0 / (k as f64).sqrt();
        let constant: Vec<f64> = (0..k).map(|_| inv_sqrt_k).collect();
        for c in 0..(k - 1) {
            let mut col = vec![0.0; k];
            col[c + 1] = 1.0;
            let inner: f64 = col.iter().zip(constant.iter()).map(|(a, b)| a * b).sum();
            for i in 0..k {
                col[i] -= inner * constant[i];
            }
            let norm = col.iter().map(|v| v * v).sum::<f64>().sqrt();
            for i in 0..k {
                z[[i, c]] = col[i] / norm;
            }
        }
        z
    }

    /// Deterministic pseudo-random vector in `[-0.5, 0.5)` from an LCG, so
    /// tests are reproducible without an rng dependency.
    fn lcg_vector(n: usize, seed: u64) -> Array1<f64> {
        let mut state = seed;
        Array1::from_shape_fn(n, |_| {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 11) as f64 / (1u64 << 53) as f64) - 0.5
        })
    }

    /// Assert that `matvec(e_i)` reproduces column `i` of `dense`.
    fn assert_matvec_columns_match_dense(op: &ClosedFormPenaltyOperator, dense: &Array2<f64>) {
        let n = op.dim();
        let mut e = Array1::<f64>::zeros(n);
        let mut col = Array1::<f64>::zeros(n);
        for i in 0..n {
            e.fill(0.0);
            e[i] = 1.0;
            op.matvec(e.view(), col.view_mut());
            for j in 0..n {
                let scale = dense[[j, i]].abs().max(1.0);
                assert_abs_diff_eq!(col[j], dense[[j, i]], epsilon = 1e-9 * scale);
            }
        }
    }

    #[test]
    fn test_operator_dense_agrees_unconstrained() {
        let centers = small_centers();
        let op = ClosedFormPenaltyOperator::new(
            centers.view(),
            1, // tension
            2,
            1,
            1.0,
            None,
            None,
            0,
            None,
        );
        let dense = op.dense_form();
        assert_matvec_columns_match_dense(&op, &dense);
    }

    #[test]
    fn test_operator_diag_agrees() {
        let centers = small_centers();
        let op = ClosedFormPenaltyOperator::new(
            centers.view(),
            2, // stiffness
            2,
            1,
            0.5,
            Some(&[0.10, -0.10]),
            None,
            0,
            None,
        );
        let dense = op.dense_form();
        let diag_op = op.diag();
        for i in 0..op.dim() {
            assert_abs_diff_eq!(diag_op[i], dense[[i, i]], epsilon = 1e-9);
        }
    }

    #[test]
    fn test_operator_matvec_random_vector() {
        let centers = small_centers();
        let op = ClosedFormPenaltyOperator::new(
            centers.view(),
            0, // mass
            2,
            1,
            1.5,
            None,
            None,
            0,
            None,
        );
        let dense = op.dense_form();
        let n = op.dim();
        let v = lcg_vector(n, 0x9E37_79B9_7F4A_7C15);
        let mut got = Array1::<f64>::zeros(n);
        op.matvec(v.view(), got.view_mut());
        let want = dense.dot(&v);
        for i in 0..n {
            assert_abs_diff_eq!(got[i], want[i], epsilon = 1e-9);
        }
    }

    #[test]
    fn test_operator_matvec_stays_matrix_free_until_dense_requested() {
        let centers = small_centers();
        let op = ClosedFormPenaltyOperator::new(
            centers.view(),
            1,
            2,
            1,
            1.0,
            Some(&[0.35, 0.10]),
            None,
            0,
            None,
        );
        let v = Array1::from_vec(vec![0.2, -0.1, 0.4, -0.3, 0.7]);
        let mut out = Array1::<f64>::zeros(op.dim());
        op.matvec(v.view(), out.view_mut());
        assert!(
            op.cached_dense.get().is_none(),
            "matvec must not populate the dense KxK cache"
        );
        let dense = op.dense_form();
        assert!(
            op.cached_dense.get().is_some(),
            "dense_form must populate the dense cache"
        );
        let expected = dense.dot(&v);
        for i in 0..op.dim() {
            assert_abs_diff_eq!(out[i], expected[i], epsilon = 1e-8);
        }
    }

    #[test]
    fn test_operator_preserves_raw_anisotropy_coordinates() {
        let centers = small_centers();
        let eta = [0.35, 0.10];
        let op =
            ClosedFormPenaltyOperator::new(centers.view(), 1, 2, 1, 1.0, Some(&eta), None, 0, None);
        let dense = op.dense_form();
        let reference = crate::basis::closed_form_operator_penalty_in_total_basis(
            centers.view(),
            1,
            2,
            1,
            1.0,
            Some(&eta),
            None,
            0,
            None,
        );
        for i in 0..op.dim() {
            for j in 0..op.dim() {
                let scale = reference[[i, j]].abs().max(1.0);
                assert_abs_diff_eq!(dense[[i, j]], reference[[i, j]], epsilon = 1e-12 * scale);
            }
        }
    }

    #[test]
    fn test_operator_with_kernel_nullspace_constraint() {
        let centers = small_centers();
        let k = centers.nrows();
        let z = constant_complement(k);
        let op = ClosedFormPenaltyOperator::new(
            centers.view(),
            1,
            2,
            1,
            1.0,
            Some(&[0.05, -0.05]),
            Some(&z),
            0,
            None,
        );
        let dense = op.dense_form();
        assert_eq!(op.dim(), k - 1);
        assert_matvec_columns_match_dense(&op, &dense);
    }

    #[test]
    fn test_operator_with_polynomial_padding_and_outer_transform() {
        let centers = small_centers();
        let k = centers.nrows();
        let z = constant_complement(k);
        let poly_cols = 2;
        let total_pre = (k - 1) + poly_cols;
        let total = total_pre - 1;
        // Deterministic, well-conditioned, non-orthogonal outer transform T.
        let t = Array2::from_shape_fn((total_pre, total), |(i, j)| {
            let base = ((i * 7 + j * 3) % 5) as f64 * 0.2 - 0.4;
            if i == j {
                base + 1.0
            } else {
                base
            }
        });
        let op = ClosedFormPenaltyOperator::new(
            centers.view(),
            1,
            2,
            1,
            1.0,
            Some(&[0.05, -0.05]),
            Some(&z),
            poly_cols,
            Some(&t),
        );
        assert_eq!(op.dim(), total);

        let v = lcg_vector(total, 0x2545_F491_4F6C_DD1D);
        let mut got = Array1::<f64>::zeros(total);
        op.matvec(v.view(), got.view_mut());
        assert!(
            op.cached_dense.get().is_none(),
            "matvec must not populate the dense cache in the constrained layout"
        );

        let dense = op.dense_form();
        let want = dense.dot(&v);
        for i in 0..total {
            let scale = want[i].abs().max(1.0);
            assert_abs_diff_eq!(got[i], want[i], epsilon = 1e-8 * scale);
        }

        let diag = op.diag();
        for i in 0..total {
            assert_abs_diff_eq!(diag[i], dense[[i, i]], epsilon = 1e-12);
        }
        let scale = diag.sum().abs().max(1.0);
        assert_abs_diff_eq!(op.trace(), diag.sum(), epsilon = 1e-12 * scale);
    }

    #[test]
    fn test_clone_starts_with_empty_dense_cache() {
        let centers = small_centers();
        let op = ClosedFormPenaltyOperator::new(centers.view(), 1, 2, 1, 1.0, None, None, 0, None);
        let dense = op.dense_form();
        assert!(op.cached_dense.get().is_some());
        let cloned = op.clone();
        assert!(
            cloned.cached_dense.get().is_none(),
            "a clone must rebuild its own dense cache"
        );
        assert_eq!(cloned.dim(), op.dim());
        let rebuilt = cloned.dense_form();
        for (a, b) in rebuilt.iter().zip(dense.iter()) {
            let scale = b.abs().max(1.0);
            assert_abs_diff_eq!(*a, *b, epsilon = 1e-12 * scale);
        }
    }

    #[test]
    fn test_log_det_plus_lambda_matches_dense() {
        let centers = small_centers();
        let op = ClosedFormPenaltyOperator::new(centers.view(), 1, 2, 1, 1.0, None, None, 0, None);
        let dense = op.dense_form();
        let n = op.dim();
        let lambda = 10.0_f64;
        // Dense reference: exact log det(S + λI) via symmetric
        // eigendecomposition. λ is intentionally large enough that the
        // regularized matrix is strictly SPD; if it is not, the operator method
        // should error rather than flooring non-positive eigenvalues.
        let mut reg = dense.clone();
        for i in 0..n {
            reg[[i, i]] += lambda;
        }
        for i in 0..n {
            for j in (i + 1)..n {
                let avg = 0.5 * (reg[[i, j]] + reg[[j, i]]);
                reg[[i, j]] = avg;
                reg[[j, i]] = avg;
            }
        }
        let est = op.log_det_plus_lambda_i(lambda).expect("exact logdet");
        use faer::Side;
        use gam_linalg::faer_ndarray::FaerEigh;
        let (evals, _) = FaerEigh::eigh(&reg, Side::Lower).expect("eigh");
        let mut reference = 0.0_f64;
        for (idx, &lam) in evals.iter().enumerate() {
            assert!(lam > 0.0, "reference eigenvalue {idx} is {lam:.3e}");
            reference += lam.ln();
        }
        assert_abs_diff_eq!(est, reference, epsilon = 1e-10);
    }
}