Skip to main content

gam_terms/analytic_penalties/
sheaf.rs

1//! Cellular-sheaf consistency penalty.
2//!
3//! Given a directed graph `G = (V, E)` with per-vertex stalk vectors
4//! `s_v ∈ R^{d_v}` and per-edge linear restriction maps
5//! `R_e^{(u→e)}: R^{d_u} → R^{d_e}`, the **coboundary**
6//!
7//!     δs[e] = R_e^{(u→e)}(s_{u_e}) − R_e^{(v→e)}(s_{v_e})
8//!
9//! lifts each edge into a per-edge "discrepancy" vector. The
10//! **sheaf Laplacian** `L = δᵀ δ` is sparse PSD on the stacked stalk
11//! space `R^{Σ d_v}`. Globally consistent sections live in `ker L`;
12//! `dim ker L` ("number of harmonic modes") generalises the
13//! connected-component count of a graph Laplacian to sheaves.
14//!
15//! The penalty value is
16//!
17//!     P(s) = ½ · weight · sᵀ L s = ½ · weight · ∑_e ‖δs[e]‖².
18//!
19//! References:
20//!   * Hansen & Ghrist, "Toward a Spectral Theory of Cellular Sheaves",
21//!     J. Appl. Comput. Topol. 3 (2019).
22//!   * Bodnar, Di Giovanni, Chamberlain, Lió, Bronstein,
23//!     "Neural Sheaf Diffusion" (NeurIPS 2022).
24//!
25//! Design choices in this module:
26//!   * The Laplacian is **never materialised**. All operations route through
27//!     two matvecs (`δ` and `δᵀ`).
28//!   * Restriction maps are `(R_uv, Option<R_vu>)` pairs. If the second is
29//!     `None` it defaults to the identity (`δs[e] = R_uv·s_u − s_v`), which
30//!     is the "single-restriction edge" convention common in sheaf-diffusion
31//!     networks.
//!   * `harmonic_modes(tol)` auto-routes through faer's self-adjoint
//!     eigendecomposition (`gam_linalg::faer_ndarray::FaerEigh`) while
//!     `Σ d_v ≤ 4096`. Beyond that, a dense `L = δᵀ δ` would exceed 128 MiB,
//!     so we stay matrix-free and instead count small Ritz values from a
//!     single-start Lanczos tridiagonalisation. That count is only a lower
//!     bound on `dim ker L` (see `harmonic_modes_lanczos`).
37
38use faer::Side;
39use ndarray::{Array1, Array2, ArrayView1};
40
41use crate::analytic_penalties::{AnalyticPenalty, PenaltyTier};
42use gam_linalg::faer_ndarray::FaerEigh;
43use gam_linalg::lanczos::{SymmetricLanczosOptions, symmetric_lanczos_eigenpairs};
44
/// Largest stacked-stalk dimension `n = Σ d_v` for which `harmonic_modes`
/// still materialises `L` densely and calls faer's `eigh`. A single dense
/// `n × n` `f64` matrix at `n = 4096` occupies `4096² · 8 B = 128 MiB`, the
/// ceiling accepted for that route; larger problems go matrix-free (Lanczos).
const DENSE_EIGH_DIM_THRESHOLD: usize = 1 << 12;
50
51/// A single edge's pair of restriction operators.
52///
53/// * `r_uv` maps the tail stalk into the edge stalk.
54/// * `r_vu` maps the head stalk into the edge stalk; `None` means "identity"
55///   (which forces `d_e == d_v`).
56#[derive(Debug, Clone)]
57pub struct EdgeRestriction {
58    pub r_uv: Array2<f64>,
59    pub r_vu: Option<Array2<f64>>,
60}
61
62impl EdgeRestriction {
63    /// Both endpoints have an explicit restriction map.
64    #[must_use]
65    pub fn paired(r_uv: Array2<f64>, r_vu: Array2<f64>) -> Self {
66        Self {
67            r_uv,
68            r_vu: Some(r_vu),
69        }
70    }
71
72    /// Tail-only restriction; the head side is implicitly identity.
73    #[must_use]
74    pub fn single(r_uv: Array2<f64>) -> Self {
75        Self { r_uv, r_vu: None }
76    }
77
78    /// Output (edge-stalk) dimension `d_e` for this edge.
79    pub fn edge_dim(&self) -> usize {
80        self.r_uv.nrows()
81    }
82}
83
84/// Cellular-sheaf consistency penalty over a fixed directed graph + restriction
85/// maps. The stacked stalk space layout is row-major over vertices:
86/// vertex `v` occupies `stalk_offsets[v] .. stalk_offsets[v] + stalk_dims[v]`.
87#[derive(Debug, Clone)]
88pub struct SheafConsistencyPenalty {
89    edges: Vec<(usize, usize)>,
90    restrictions: Vec<EdgeRestriction>,
91    weight: f64,
92    stalk_offsets: Vec<usize>,
93    stalk_dims: Vec<usize>,
94}
95
96impl SheafConsistencyPenalty {
97    /// Construct a sheaf-consistency penalty.
98    ///
99    /// * `edges` — directed edges as `(u, v)` pairs, vertex indices `0..K`.
100    /// * `restrictions` — same length as `edges`; per-edge `EdgeRestriction`.
101    /// * `weight` — finite, positive scalar penalty weight.
102    /// * `stalk_dims` — per-vertex stalk dimensions `d_v`.
103    ///
104    /// Validates: dim agreement (`r_uv.ncols == d_u`, `r_vu.ncols == d_v`,
105    /// `r_uv.nrows == r_vu.nrows`), vertex indices in range, finite entries.
106    #[must_use = "build error must be handled"]
107    pub fn new(
108        edges: Vec<(usize, usize)>,
109        restrictions: Vec<EdgeRestriction>,
110        weight: f64,
111        stalk_dims: Vec<usize>,
112    ) -> Result<Self, String> {
113        if !(weight.is_finite() && weight > 0.0) {
114            return Err(format!(
115                "SheafConsistencyPenalty::new requires finite weight > 0, got {weight}"
116            ));
117        }
118        if edges.len() != restrictions.len() {
119            return Err(format!(
120                "SheafConsistencyPenalty::new edge count {} != restriction count {}",
121                edges.len(),
122                restrictions.len()
123            ));
124        }
125        if stalk_dims.is_empty() {
126            return Err("SheafConsistencyPenalty::new requires at least one vertex".into());
127        }
128        for (v, &d) in stalk_dims.iter().enumerate() {
129            if d == 0 {
130                return Err(format!(
131                    "SheafConsistencyPenalty::new stalk dim at vertex {v} is zero"
132                ));
133            }
134        }
135        for (e, ((u, v), restriction)) in edges.iter().zip(restrictions.iter()).enumerate() {
136            if *u >= stalk_dims.len() || *v >= stalk_dims.len() {
137                return Err(format!(
138                    "SheafConsistencyPenalty::new edge {e} = ({u}, {v}) references vertex \
139                     out of range (K = {})",
140                    stalk_dims.len()
141                ));
142            }
143            let d_u = stalk_dims[*u];
144            let d_v = stalk_dims[*v];
145            let d_e = restriction.r_uv.nrows();
146            if restriction.r_uv.ncols() != d_u {
147                return Err(format!(
148                    "SheafConsistencyPenalty::new edge {e}: r_uv has {} cols, expected d_u = {d_u}",
149                    restriction.r_uv.ncols()
150                ));
151            }
152            match &restriction.r_vu {
153                Some(r_vu) => {
154                    if r_vu.ncols() != d_v {
155                        return Err(format!(
156                            "SheafConsistencyPenalty::new edge {e}: r_vu has {} cols, \
157                             expected d_v = {d_v}",
158                            r_vu.ncols()
159                        ));
160                    }
161                    if r_vu.nrows() != d_e {
162                        return Err(format!(
163                            "SheafConsistencyPenalty::new edge {e}: r_vu has {} rows, \
164                             expected d_e = {d_e}",
165                            r_vu.nrows()
166                        ));
167                    }
168                }
169                None => {
170                    if d_e != d_v {
171                        return Err(format!(
172                            "SheafConsistencyPenalty::new edge {e}: r_vu is identity but \
173                             d_e ({d_e}) != d_v ({d_v})"
174                        ));
175                    }
176                }
177            }
178            if !restriction.r_uv.iter().all(|x| x.is_finite()) {
179                return Err(format!(
180                    "SheafConsistencyPenalty::new edge {e}: r_uv contains non-finite entries"
181                ));
182            }
183            if let Some(r_vu) = &restriction.r_vu
184                && !r_vu.iter().all(|x| x.is_finite())
185            {
186                return Err(format!(
187                    "SheafConsistencyPenalty::new edge {e}: r_vu contains non-finite entries"
188                ));
189            }
190        }
191        let mut stalk_offsets = Vec::with_capacity(stalk_dims.len() + 1);
192        let mut acc = 0usize;
193        for &d in &stalk_dims {
194            stalk_offsets.push(acc);
195            acc = acc.checked_add(d).ok_or_else(|| {
196                "SheafConsistencyPenalty::new stalk offsets overflow usize".to_string()
197            })?;
198        }
199        stalk_offsets.push(acc);
200        Ok(Self {
201            edges,
202            restrictions,
203            weight,
204            stalk_offsets,
205            stalk_dims,
206        })
207    }
208
209    /// Total dimension of the stacked stalk space `Σ d_v`.
210    pub fn total_dim(&self) -> usize {
211        *self.stalk_offsets.last().expect("offsets non-empty")
212    }
213
214    /// Number of edges.
215    pub fn num_edges(&self) -> usize {
216        self.edges.len()
217    }
218
219    /// Number of vertices `K`.
220    pub fn num_vertices(&self) -> usize {
221        self.stalk_dims.len()
222    }
223
224    /// Per-vertex stalk dimensions (clone of internal vector).
225    pub fn stalk_dims(&self) -> &[usize] {
226        &self.stalk_dims
227    }
228
229    /// Penalty weight.
230    pub fn weight(&self) -> f64 {
231        self.weight
232    }
233
234    fn vertex_slice<'a>(&self, s: ArrayView1<'a, f64>, v: usize) -> ArrayView1<'a, f64> {
235        let start = self.stalk_offsets[v];
236        let end = self.stalk_offsets[v + 1];
237        s.slice_move(ndarray::s![start..end])
238    }
239
240    /// Apply `δ` to a stacked-stalk vector `s`. Returns a `Vec<Array1<f64>>`
241    /// with one entry per edge containing `δs[e] ∈ R^{d_e}`.
242    fn delta(&self, s: ArrayView1<'_, f64>) -> Vec<Array1<f64>> {
243        assert_eq!(
244            s.len(),
245            self.total_dim(),
246            "stacked stalk vector has wrong length",
247        );
248        let mut out = Vec::with_capacity(self.edges.len());
249        for (e, &(u, v)) in self.edges.iter().enumerate() {
250            let s_u = self.vertex_slice(s, u);
251            let s_v = self.vertex_slice(s, v);
252            let restriction = &self.restrictions[e];
253            // R_uv · s_u
254            let mut delta_e = restriction.r_uv.dot(&s_u);
255            // − R_vu · s_v   (identity if r_vu is None)
256            match &restriction.r_vu {
257                Some(r_vu) => {
258                    let r_vu_s_v = r_vu.dot(&s_v);
259                    delta_e.scaled_add(-1.0, &r_vu_s_v);
260                }
261                None => {
262                    delta_e.scaled_add(-1.0, &s_v);
263                }
264            }
265            out.push(delta_e);
266        }
267        out
268    }
269
270    /// Apply `δᵀ` to per-edge discrepancies `y`. Returns the stacked-stalk
271    /// vector `δᵀ y ∈ R^{Σ d_v}`.
272    fn delta_transpose(&self, y: &[Array1<f64>]) -> Array1<f64> {
273        assert_eq!(
274            y.len(),
275            self.edges.len(),
276            "delta_transpose edge count mismatch"
277        );
278        let mut out = Array1::<f64>::zeros(self.total_dim());
279        for (e, &(u, v)) in self.edges.iter().enumerate() {
280            let restriction = &self.restrictions[e];
281            let y_e = &y[e];
282            assert_eq!(y_e.len(), restriction.edge_dim(), "edge dim mismatch");
283            // R_uvᵀ · y_e → vertex u
284            let contrib_u = restriction.r_uv.t().dot(y_e);
285            let u_start = self.stalk_offsets[u];
286            let u_end = self.stalk_offsets[u + 1];
287            {
288                let mut out_u = out.slice_mut(ndarray::s![u_start..u_end]);
289                out_u.scaled_add(1.0, &contrib_u);
290            }
291            // −R_vuᵀ · y_e → vertex v   (identity if r_vu is None)
292            let v_start = self.stalk_offsets[v];
293            let v_end = self.stalk_offsets[v + 1];
294            match &restriction.r_vu {
295                Some(r_vu) => {
296                    let contrib_v = r_vu.t().dot(y_e);
297                    let mut out_v = out.slice_mut(ndarray::s![v_start..v_end]);
298                    out_v.scaled_add(-1.0, &contrib_v);
299                }
300                None => {
301                    let mut out_v = out.slice_mut(ndarray::s![v_start..v_end]);
302                    out_v.scaled_add(-1.0, y_e);
303                }
304            }
305        }
306        out
307    }
308
309    /// Apply the sheaf Laplacian `L = δᵀ δ` to a stacked-stalk vector `s`.
310    /// Cost: two matvecs per edge; never materialises `L`.
311    pub fn laplacian_apply(&self, s: ArrayView1<'_, f64>) -> Array1<f64> {
312        let ds = self.delta(s);
313        self.delta_transpose(&ds)
314    }
315
316    /// Penalty value `½ · weight · ‖δs‖²`. Quadratic in `s`.
317    pub fn value(&self, s: ArrayView1<'_, f64>) -> f64 {
318        let ds = self.delta(s);
319        let mut sq = 0.0;
320        for de in &ds {
321            for &x in de.iter() {
322                sq += x * x;
323            }
324        }
325        0.5 * self.weight * sq
326    }
327
328    /// Gradient `∂P/∂s = weight · L s`. Length `Σ d_v`.
329    pub fn gradient(&self, s: ArrayView1<'_, f64>) -> Array1<f64> {
330        let mut g = self.laplacian_apply(s);
331        g *= self.weight;
332        g
333    }
334
335    /// Hessian diagonal `diag(weight · L)`. Independent of `s` because `L` is
336    /// constant. For a **distinct-vertex** edge `(u, v)` (`u ≠ v`) the coboundary
337    /// `C = [R_uv | −R_vu]` acts on disjoint stalk blocks, so
338    ///   * tail (u-side): `Σ_j R_uv[j, i_local]²`
339    ///   * head (v-side, single-restriction): `1.0` per incident edge
340    ///   * head (v-side, paired): `Σ_j R_vu[j, i_local]²`
341    /// For a **self-loop** edge `(u, u)` both sides share one block and the
342    /// coboundary collapses to `(R_uv − R_vu)·s_u`, so the correct contribution
343    /// is `colnorm²(R_uv − R_vu)` — NOT the sum of the two separate norms.
344    pub fn hessian_diag(&self, s: ArrayView1<'_, f64>) -> Array1<f64> {
345        assert_eq!(
346            s.len(),
347            self.total_dim(),
348            "stacked stalk vector has wrong length",
349        );
350        // L is constant in s; the argument is retained only for trait-style symmetry
351        // (other penalties take target as the first arg). The shape assertion above
352        // exercises that input.
353        let mut diag = Array1::<f64>::zeros(self.total_dim());
354        for (e, &(u, v)) in self.edges.iter().enumerate() {
355            let restriction = &self.restrictions[e];
356            let u_start = self.stalk_offsets[u];
357            let v_start = self.stalk_offsets[v];
358            let r_uv = &restriction.r_uv;
359
360            if u == v {
361                // Self-loop: δ(s)[e] = (R_uv − R_vu)·s_u, so the edge's
362                // contribution to diag(L) is colnorm²(R_uv − R_vu), NOT the
363                // sum of the two separate squared column norms (which would
364                // double-count on the shared stalk block). The distinct-vertex
365                // path below is not reached for self-loops.
366                match &restriction.r_vu {
367                    Some(r_vu) => {
368                        for col in 0..r_uv.ncols() {
369                            let mut s2 = 0.0;
370                            for row in 0..r_uv.nrows() {
371                                let diff = r_uv[[row, col]] - r_vu[[row, col]];
372                                s2 += diff * diff;
373                            }
374                            diag[u_start + col] += s2;
375                        }
376                    }
377                    None => {
378                        // r_vu = I; contribution is colnorm²(R_uv − I).
379                        let d = self.stalk_dims[u];
380                        for col in 0..d {
381                            let mut s2 = 0.0;
382                            for row in 0..r_uv.nrows() {
383                                let identity_entry = if row == col { 1.0 } else { 0.0 };
384                                let diff = r_uv[[row, col]] - identity_entry;
385                                s2 += diff * diff;
386                            }
387                            diag[u_start + col] += s2;
388                        }
389                    }
390                }
391            } else {
392                // Distinct-vertex path: u_start ≠ v_start, so u-side and v-side
393                // accumulations land on disjoint index ranges. Diagonal of Cᵀ C
394                // with C = [R_uv | −R_vu] decomposes cleanly into the two blocks.
395                for col in 0..r_uv.ncols() {
396                    let mut s2 = 0.0;
397                    for row in 0..r_uv.nrows() {
398                        let a = r_uv[[row, col]];
399                        s2 += a * a;
400                    }
401                    diag[u_start + col] += s2;
402                }
403                match &restriction.r_vu {
404                    Some(r_vu) => {
405                        for col in 0..r_vu.ncols() {
406                            let mut s2 = 0.0;
407                            for row in 0..r_vu.nrows() {
408                                let a = r_vu[[row, col]];
409                                s2 += a * a;
410                            }
411                            diag[v_start + col] += s2;
412                        }
413                    }
414                    None => {
415                        let d_v = self.stalk_dims[v];
416                        for col in 0..d_v {
417                            diag[v_start + col] += 1.0;
418                        }
419                    }
420                }
421            }
422        }
423        diag *= self.weight;
424        diag
425    }
426
427    /// Hessian-vector product `H v = weight · L v`. Two matvecs, no
428    /// materialisation. The `_s` argument is unused (L is constant); it
429    /// matches the trait-style `(target, v)` signature other penalties use.
430    pub fn hvp(&self, s: ArrayView1<'_, f64>, v: ArrayView1<'_, f64>) -> Array1<f64> {
431        assert_eq!(
432            s.len(),
433            self.total_dim(),
434            "stacked stalk vector has wrong length",
435        );
436        assert_eq!(v.len(), self.total_dim(), "hvp direction has wrong length");
437        let mut hv = self.laplacian_apply(v);
438        hv *= self.weight;
439        hv
440    }
441
442    /// Materialise the dense Laplacian `L` (no weight applied).
443    ///
444    /// Used by [`Self::harmonic_modes`] when `total_dim() ≤
445    /// DENSE_EIGH_DIM_THRESHOLD`. Cost is `O(n²)` memory and `O(n · |E| · max d_e)`
446    /// flops via `n` independent matvecs against the standard basis.
447    /// **Not** called on the inner-loop hot path.
448    fn dense_laplacian(&self) -> Array2<f64> {
449        let n = self.total_dim();
450        let mut l = Array2::<f64>::zeros((n, n));
451        let mut e = Array1::<f64>::zeros(n);
452        for j in 0..n {
453            e[j] = 1.0;
454            let col = self.laplacian_apply(e.view());
455            for i in 0..n {
456                l[[i, j]] = col[i];
457            }
458            e[j] = 0.0;
459        }
460        l
461    }
462
463    /// Count eigenvalues of the unweighted Laplacian `L` strictly below
464    /// `tol`. Equals the number of harmonic modes (global sections, mod the
465    /// `tol`-tolerance). The penalty weight is **not** folded in: harmonic
466    /// modes are an intrinsic property of `δ`.
467    ///
468    /// Auto-routing: dense faer eigh when `total_dim ≤ DENSE_EIGH_DIM_THRESHOLD`;
469    /// matrix-free Lanczos null-space count otherwise.
470    pub fn harmonic_modes(&self, tol: f64) -> usize {
471        assert!(
472            tol.is_finite() && tol >= 0.0,
473            "harmonic_modes requires finite non-negative tol, got {tol}",
474        );
475        let n = self.total_dim();
476        if n == 0 {
477            return 0;
478        }
479        if n <= DENSE_EIGH_DIM_THRESHOLD {
480            let l = self.dense_laplacian();
481            match l.eigh(Side::Lower) {
482                Ok((evals, _)) => evals.iter().filter(|&&e| e < tol).count(),
483                // SAFETY: dense Laplacian above is symmetric positive semidefinite by construction
484                // (graph Laplacian of an undirected weighted graph), so eigh on the lower triangle
485                // must succeed; any err indicates a corrupted matrix and bailing here is correct.
486                Err(err) => {
487                    panic!("SheafConsistencyPenalty::harmonic_modes faer eigh failed: {err:?}")
488                }
489            }
490        } else {
491            self.harmonic_modes_lanczos(tol)
492        }
493    }
494
495    /// Matrix-free null-space-dim estimate via Lanczos tridiagonalisation +
496    /// Sturm-style sign count. We build a `k`-step Lanczos tridiagonal `T`
497    /// for `L` against a random start vector, eigendecompose `T` densely
498    /// (`k ≪ n`), and count Ritz values below `tol`. This **lower-bounds**
499    /// the harmonic-mode count for generic starts; for sheaf Laplacians the
500    /// kernel direction is reached within `k = min(n, 64)` iterations in
501    /// practice, but we expose the result as a tight bound rather than an
502    /// exact count.
503    fn harmonic_modes_lanczos(&self, tol: f64) -> usize {
504        let n = self.total_dim();
505        let k = n.min(64).max(1);
506        // Deterministic pseudo-random start to keep the bound reproducible.
507        let mut q0 = vec![0.0_f64; n];
508        for i in 0..n {
509            // Splitmix-style scrambling of i: deterministic, dependency-free.
510            // The canonical stateful step adds G internally, so seed it with
511            // `i·G − G` to finalize the same `i·G` input and stay bit-identical.
512            let mut state = (i as u64)
513                .wrapping_mul(0x9E37_79B9_7F4A_7C15)
514                .wrapping_sub(0x9E37_79B9_7F4A_7C15);
515            let z = gam_linalg::utils::splitmix64(&mut state);
516            q0[i] = (z as f64 / u64::MAX as f64) - 0.5;
517        }
518        match symmetric_lanczos_eigenpairs(
519            n,
520            &q0,
521            SymmetricLanczosOptions {
522                max_steps: k,
523                residual_tol: 1e-12,
524                local_reorthogonalize: true,
525                full_reorthogonalize: false,
526            },
527            |q, out| {
528                let w = self.laplacian_apply(ArrayView1::from(q));
529                out.copy_from_slice(w.as_slice().ok_or_else(|| {
530                    "SheafConsistencyPenalty::harmonic_modes Lanczos matvec produced non-contiguous output"
531                        .to_string()
532                })?);
533                Ok(())
534            },
535        ) {
536            Ok(eigen) => eigen.eigenvalues.iter().filter(|&&e| e < tol).count(),
537            Err(err) => {
538                // SAFETY: A Lanczos breakdown here is a non-recoverable numerical
539                // failure of the harmonic-mode decomposition (e.g. a malformed or
540                // non-symmetric operator); there is no meaningful count to return,
541                // so the error must surface rather than be silently swallowed.
542                panic!("SheafConsistencyPenalty::harmonic_modes Lanczos failed: {err}")
543            }
544        }
545    }
546}
547
548// ---------------------------------------------------------------------------
549// AnalyticPenalty trait bridge.
550// ---------------------------------------------------------------------------
551//
552// Wires `SheafConsistencyPenalty` into the analytic-penalty registry so it is
553// reachable from REML / PIRLS / CLI callers exactly like ARDPenalty,
554// BlockOrthogonalityPenalty, etc. `target` is the stacked-stalk vector
555// (treated as a ψ-tier flat block); `rho` is unused — this penalty is
556// quadratic with a fixed scalar weight set at construction. The
557// `harmonic_modes` query and the per-vertex layout helpers remain available
558// as inherent methods for callers that want the cellular-sheaf-specific
559// diagnostics.
560
561impl AnalyticPenalty for SheafConsistencyPenalty {
562    fn tier(&self) -> PenaltyTier {
563        PenaltyTier::Psi
564    }
565
566    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
567        assert!(
568            rho.iter().all(|x| x.is_finite()),
569            "SheafConsistencyPenalty: rho must be finite (got {rho:?})",
570        );
571        SheafConsistencyPenalty::value(self, target)
572    }
573
574    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
575        assert!(
576            rho.iter().all(|x| x.is_finite()),
577            "SheafConsistencyPenalty: rho must be finite (got {rho:?})",
578        );
579        SheafConsistencyPenalty::gradient(self, target)
580    }
581
582    fn hessian_diag(
583        &self,
584        target: ArrayView1<'_, f64>,
585        rho: ArrayView1<'_, f64>,
586    ) -> Option<Array1<f64>> {
587        assert!(
588            rho.iter().all(|x| x.is_finite()),
589            "SheafConsistencyPenalty: rho must be finite (got {rho:?})",
590        );
591        Some(SheafConsistencyPenalty::hessian_diag(self, target))
592    }
593
594    fn hvp(
595        &self,
596        target: ArrayView1<'_, f64>,
597        rho: ArrayView1<'_, f64>,
598        v: ArrayView1<'_, f64>,
599    ) -> Array1<f64> {
600        assert!(
601            rho.iter().all(|x| x.is_finite()),
602            "SheafConsistencyPenalty: rho must be finite (got {rho:?})",
603        );
604        SheafConsistencyPenalty::hvp(self, target, v)
605    }
606
607    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
608        // No learnable hyperparameter axes: rho_count == 0.
609        assert_eq!(
610            rho.len(),
611            0,
612            "SheafConsistencyPenalty: rho_count is 0 but rho has length {}",
613            rho.len(),
614        );
615        assert_eq!(
616            target.len(),
617            self.total_dim(),
618            "SheafConsistencyPenalty: target length {} != total stalk dim {}",
619            target.len(),
620            self.total_dim(),
621        );
622        Array1::<f64>::zeros(0)
623    }
624
625    fn rho_count(&self) -> usize {
626        0
627    }
628
629    fn name(&self) -> &str {
630        "SheafConsistencyPenalty"
631    }
632}
633
634#[cfg(test)]
635mod tests {
636    use super::*;
637    use approx::assert_abs_diff_eq;
638    use ndarray::array;
639
640    fn identity(d: usize) -> Array2<f64> {
641        let mut m = Array2::<f64>::zeros((d, d));
642        for i in 0..d {
643            m[[i, i]] = 1.0;
644        }
645        m
646    }
647
648    #[test]
649    fn single_edge_identity_restriction_value() {
650        // K=2, d_0 = d_1 = 3, R_uv = R_vu = I.
651        // s_0 = (1,0,0), s_1 = (0,1,0). δs = (1,-1,0). ‖·‖² = 2. Value = ½·1·2 = 1.
652        let edges = vec![(0usize, 1usize)];
653        let restrictions = vec![EdgeRestriction::paired(identity(3), identity(3))];
654        let pen =
655            SheafConsistencyPenalty::new(edges, restrictions, 1.0, vec![3, 3]).expect("build");
656        let s = array![1.0_f64, 0.0, 0.0, 0.0, 1.0, 0.0];
657        let v = pen.value(s.view());
658        assert_abs_diff_eq!(v, 1.0, epsilon = 1e-12);
659    }
660
661    #[test]
662    fn gradient_matches_finite_difference_k2_random() {
663        // K=2 with arbitrary restrictions; FD-check the gradient.
664        let r_uv = array![[0.7_f64, -0.1, 0.3], [0.2, 0.9, -0.4]];
665        let r_vu = array![[1.0_f64, 0.5], [-0.3, 0.8]];
666        let edges = vec![(0usize, 1usize)];
667        let restrictions = vec![EdgeRestriction::paired(r_uv, r_vu)];
668        let pen =
669            SheafConsistencyPenalty::new(edges, restrictions, 0.3, vec![3, 2]).expect("build");
670        let s = array![0.4_f64, -1.1, 0.2, 0.6, -0.7];
671        let g = pen.gradient(s.view());
672        let eps = 1e-7;
673        for i in 0..s.len() {
674            let mut sp = s.clone();
675            let mut sm = s.clone();
676            sp[i] += eps;
677            sm[i] -= eps;
678            let fd = (pen.value(sp.view()) - pen.value(sm.view())) / (2.0 * eps);
679            assert_abs_diff_eq!(g[i], fd, epsilon = 1e-6);
680        }
681    }
682
683    #[test]
684    fn hvp_matches_reconstructed_laplacian_chain_k3() {
685        // K=3 chain: edges (0,1) and (1,2), each with explicit 2x2 restrictions.
686        // d_0 = d_1 = d_2 = 2.
687        let r01_uv = array![[0.9_f64, 0.1], [-0.2, 0.7]];
688        let r01_vu = array![[1.0_f64, 0.0], [0.0, 1.0]];
689        let r12_uv = array![[0.5_f64, -0.3], [0.4, 0.8]];
690        let r12_vu = array![[0.6_f64, 0.0], [0.1, 1.1]];
691        let edges = vec![(0usize, 1usize), (1usize, 2usize)];
692        let restrictions = vec![
693            EdgeRestriction::paired(r01_uv, r01_vu),
694            EdgeRestriction::paired(r12_uv, r12_vu),
695        ];
696        let pen =
697            SheafConsistencyPenalty::new(edges, restrictions, 1.0, vec![2, 2, 2]).expect("build");
698        // Reconstruct L densely via 6 matvecs.
699        let l_dense = pen.dense_laplacian();
700        let n = pen.total_dim();
701        let s = array![0.1_f64, -0.2, 0.3, 0.4, -0.5, 0.6];
702        let v = array![0.7_f64, 0.2, -0.1, 0.5, 0.3, -0.4];
703        let hv = pen.hvp(s.view(), v.view());
704        // Reference: L · v (weight = 1)
705        let mut lv = Array1::<f64>::zeros(n);
706        for i in 0..n {
707            let mut acc = 0.0;
708            for j in 0..n {
709                acc += l_dense[[i, j]] * v[j];
710            }
711            lv[i] = acc;
712        }
713        for i in 0..n {
714            assert_abs_diff_eq!(hv[i], lv[i], epsilon = 1e-10);
715        }
716    }
717
718    #[test]
719    fn harmonic_modes_two_components_identity_restrictions() {
720        // Two disconnected vertices (no edges), d = 3 each → ker L = R^{6}, all 6 modes.
721        let pen = SheafConsistencyPenalty::new(vec![], vec![], 1.0, vec![3, 3]).expect("build");
722        let h = pen.harmonic_modes(1e-10);
723        assert_eq!(h, 6);
724
725        // K=4, two connected components: (0-1) and (2-3) with identity restrictions, d=2 each.
726        // Each component's sheaf-Laplacian kernel has dim d (the "constant sections").
727        // Total ker dim = 2·d = 4.
728        let edges = vec![(0usize, 1usize), (2usize, 3usize)];
729        let restrictions = vec![
730            EdgeRestriction::paired(identity(2), identity(2)),
731            EdgeRestriction::paired(identity(2), identity(2)),
732        ];
733        let pen2 = SheafConsistencyPenalty::new(edges, restrictions, 1.0, vec![2, 2, 2, 2])
734            .expect("build");
735        let h2 = pen2.harmonic_modes(1e-10);
736        assert_eq!(h2, 4);
737    }
738
739    #[test]
740    fn value_psd_and_zero_iff_kernel() {
741        // Random s on a non-trivial sheaf: value ≥ 0.
742        let r01_uv = array![[0.9_f64, 0.1], [-0.2, 0.7]];
743        let r01_vu = array![[1.0_f64, 0.0], [0.0, 1.0]];
744        let edges = vec![(0usize, 1usize)];
745        let restrictions = vec![EdgeRestriction::paired(r01_uv.clone(), r01_vu.clone())];
746        let pen =
747            SheafConsistencyPenalty::new(edges, restrictions, 0.5, vec![2, 2]).expect("build");
748
749        // Several random-ish stalks: non-negative value.
750        let samples = [
751            array![0.0_f64, 0.0, 0.0, 0.0],
752            array![1.0_f64, 2.0, -0.5, 0.3],
753            array![-1.3_f64, 0.7, 0.2, -0.9],
754        ];
755        for s in &samples {
756            let v = pen.value(s.view());
757            assert!(v >= 0.0, "value must be non-negative, got {v}");
758        }
759        // Zero stalk → zero value.
760        let z = Array1::<f64>::zeros(4);
761        assert_abs_diff_eq!(pen.value(z.view()), 0.0, epsilon = 1e-15);
762        // A kernel direction: pick s_0 arbitrary then set s_1 = r_vu⁻¹ · r_uv · s_0.
763        // r_vu = I, so s_1 = r_uv · s_0.
764        let s0 = array![0.3_f64, -1.1];
765        let s1 = r01_uv.dot(&s0);
766        let mut s = Array1::<f64>::zeros(4);
767        s[0] = s0[0];
768        s[1] = s0[1];
769        s[2] = s1[0];
770        s[3] = s1[1];
771        assert_abs_diff_eq!(pen.value(s.view()), 0.0, epsilon = 1e-12);
772    }
773
774    #[test]
775    fn hessian_diag_matches_diag_of_dense_laplacian() {
776        let r_uv = array![[0.7_f64, -0.1, 0.3], [0.2, 0.9, -0.4]];
777        let r_vu = array![[1.0_f64, 0.5], [-0.3, 0.8]];
778        let edges = vec![(0usize, 1usize)];
779        let restrictions = vec![EdgeRestriction::paired(r_uv, r_vu)];
780        let pen =
781            SheafConsistencyPenalty::new(edges, restrictions, 0.3, vec![3, 2]).expect("build");
782        let n = pen.total_dim();
783        let s = Array1::<f64>::zeros(n);
784        let diag = pen.hessian_diag(s.view());
785        let l = pen.dense_laplacian();
786        for i in 0..n {
787            assert_abs_diff_eq!(diag[i], 0.3 * l[[i, i]], epsilon = 1e-12);
788        }
789    }
790
791    #[test]
792    fn hessian_diag_matches_dense_laplacian_on_self_loop_paired() {
793        // Self-loop (0,0) with two distinct paired restrictions: the diagonal
794        // must equal diag(weight·L) built from the matrix-free operator, i.e.
795        // colnorm²(R_uv − R_vu), NOT colnorm²(R_uv) + colnorm²(R_vu).
796        let r_uv = array![[0.9_f64, 0.1], [-0.2, 0.7]];
797        let r_vu = array![[1.0_f64, 0.5], [-0.3, 0.8]];
798        let edges = vec![(0usize, 0usize)];
799        let restrictions = vec![EdgeRestriction::paired(r_uv, r_vu)];
800        let pen = SheafConsistencyPenalty::new(edges, restrictions, 0.7, vec![2]).expect("build");
801        let n = pen.total_dim();
802        let s = Array1::<f64>::zeros(n);
803        let diag = pen.hessian_diag(s.view());
804        let l = pen.dense_laplacian();
805        for i in 0..n {
806            assert_abs_diff_eq!(diag[i], 0.7 * l[[i, i]], epsilon = 1e-12);
807        }
808    }
809
810    #[test]
811    fn hessian_diag_matches_dense_laplacian_on_self_loop_single() {
812        // Self-loop (0,0) with a single-restriction edge: R_vu is the identity,
813        // so the coboundary is (R_uv − I)·s_0. The drop-cross-term bug would
814        // report colnorm²(R_uv) + 1 per column; the correct diagonal is
815        // colnorm²(R_uv − I). d_e == d_v == d_u = 2 (single-edge requirement).
816        // This single-restriction self-loop path is exercised by neither the
817        // committed repro (paired only) nor the landing fix's tests.
818        let r_uv = array![[1.0_f64, 2.0], [3.0, 4.0]];
819        let edges = vec![(0usize, 0usize)];
820        let restrictions = vec![EdgeRestriction::single(r_uv)];
821        let pen = SheafConsistencyPenalty::new(edges, restrictions, 1.3, vec![2]).expect("build");
822        let n = pen.total_dim();
823        let s = Array1::<f64>::zeros(n);
824        let diag = pen.hessian_diag(s.view());
825        let l = pen.dense_laplacian();
826        for i in 0..n {
827            assert_abs_diff_eq!(diag[i], 1.3 * l[[i, i]], epsilon = 1e-12);
828        }
829        // Spot the closed form: C = R_uv − I = [[0,2],[3,3]].
830        // col 0: 0² + 3² = 9; col 1: 2² + 3² = 13. ×weight 1.3 → [11.7, 16.9].
831        assert_abs_diff_eq!(diag[0], 1.3 * 9.0, epsilon = 1e-12);
832        assert_abs_diff_eq!(diag[1], 1.3 * 13.0, epsilon = 1e-12);
833    }
834
835    #[test]
836    fn hessian_diag_matches_dense_laplacian_mixed_self_loop_and_cross_edge() {
837        // A self-loop on vertex 0 AND a distinct edge (0,1) both touch vertex 0.
838        // The two edges' diagonal contributions must accumulate independently:
839        // the self-loop contributes colnorm²(R0 − R0b) on block 0, while the
840        // cross edge contributes colnorm²(R1_uv) on block 0 and colnorm²(R1_vu)
841        // on block 1. Checked against the operator-built dense Laplacian.
842        let r0_uv = array![[0.5_f64, -0.4], [0.3, 0.9]];
843        let r0_vu = array![[0.2_f64, 0.1], [-0.6, 0.7]];
844        let r1_uv = array![[1.1_f64, 0.2], [0.0, -0.5]];
845        let r1_vu = array![[0.8_f64, -0.1], [0.4, 1.0]];
846        let edges = vec![(0usize, 0usize), (0usize, 1usize)];
847        let restrictions = vec![
848            EdgeRestriction::paired(r0_uv, r0_vu),
849            EdgeRestriction::paired(r1_uv, r1_vu),
850        ];
851        let pen =
852            SheafConsistencyPenalty::new(edges, restrictions, 0.5, vec![2, 2]).expect("build");
853        let n = pen.total_dim();
854        let s = Array1::<f64>::zeros(n);
855        let diag = pen.hessian_diag(s.view());
856        let l = pen.dense_laplacian();
857        for i in 0..n {
858            assert_abs_diff_eq!(diag[i], 0.5 * l[[i, i]], epsilon = 1e-12);
859        }
860    }
861
862    #[test]
863    fn single_restriction_edge_form() {
864        // δs = R·s_0 − s_1 (single-restriction form). d_0 = 2, d_e = d_1 = 2.
865        let r = array![[1.0_f64, 2.0], [3.0, 4.0]];
866        let edges = vec![(0usize, 1usize)];
867        let restrictions = vec![EdgeRestriction::single(r.clone())];
868        let pen =
869            SheafConsistencyPenalty::new(edges, restrictions, 2.0, vec![2, 2]).expect("build");
870        // s_0 = (1, 0) → R·s_0 = (1, 3). s_1 = (1, 3) → δs = 0. Value = 0.
871        let s = array![1.0_f64, 0.0, 1.0, 3.0];
872        assert_abs_diff_eq!(pen.value(s.view()), 0.0, epsilon = 1e-12);
873        // Now break consistency: s_1 = (0, 0). δs = (1, 3). Value = ½·2·(1+9) = 10.
874        let s2 = array![1.0_f64, 0.0, 0.0, 0.0];
875        assert_abs_diff_eq!(pen.value(s2.view()), 10.0, epsilon = 1e-12);
876    }
877}