1use faer::Side;
39use ndarray::{Array1, Array2, ArrayView1};
40
41use crate::analytic_penalties::{AnalyticPenalty, PenaltyTier};
42use gam_linalg::faer_ndarray::FaerEigh;
43use gam_linalg::lanczos::{SymmetricLanczosOptions, symmetric_lanczos_eigenpairs};
44
/// Largest total stalk dimension (2^12 = 4096) for which `harmonic_modes` builds the
/// dense Laplacian and diagonalises it directly; larger problems take another path.
const DENSE_EIGH_DIM_THRESHOLD: usize = 1 << 12;
50
51#[derive(Debug, Clone)]
57pub struct EdgeRestriction {
58 pub r_uv: Array2<f64>,
59 pub r_vu: Option<Array2<f64>>,
60}
61
62impl EdgeRestriction {
63 #[must_use]
65 pub fn paired(r_uv: Array2<f64>, r_vu: Array2<f64>) -> Self {
66 Self {
67 r_uv,
68 r_vu: Some(r_vu),
69 }
70 }
71
72 #[must_use]
74 pub fn single(r_uv: Array2<f64>) -> Self {
75 Self { r_uv, r_vu: None }
76 }
77
78 pub fn edge_dim(&self) -> usize {
80 self.r_uv.nrows()
81 }
82}
83
/// Quadratic sheaf-consistency penalty `0.5 * weight * sum_e ||delta_e(s)||^2`, where
/// `s` stacks the per-vertex stalk vectors and
/// `delta_e(s) = r_uv * s_u - r_vu * s_v` for edge `e = (u, v)`.
#[derive(Debug, Clone)]
pub struct SheafConsistencyPenalty {
    // Graph edges as `(u, v)` vertex-index pairs; edge `e` uses `restrictions[e]`.
    edges: Vec<(usize, usize)>,
    // One restriction pair per edge, shape-checked against `stalk_dims` in `new`.
    restrictions: Vec<EdgeRestriction>,
    // Finite, strictly positive scale applied to value, gradient, Hessian diagonal and HVP.
    weight: f64,
    // Prefix sums of `stalk_dims` (length K + 1): vertex `v` occupies
    // `stalk_offsets[v]..stalk_offsets[v + 1]` of the stacked vector.
    stalk_offsets: Vec<usize>,
    // Stalk dimension at each vertex; `new` rejects zero entries.
    stalk_dims: Vec<usize>,
}
95
96impl SheafConsistencyPenalty {
97 #[must_use = "build error must be handled"]
107 pub fn new(
108 edges: Vec<(usize, usize)>,
109 restrictions: Vec<EdgeRestriction>,
110 weight: f64,
111 stalk_dims: Vec<usize>,
112 ) -> Result<Self, String> {
113 if !(weight.is_finite() && weight > 0.0) {
114 return Err(format!(
115 "SheafConsistencyPenalty::new requires finite weight > 0, got {weight}"
116 ));
117 }
118 if edges.len() != restrictions.len() {
119 return Err(format!(
120 "SheafConsistencyPenalty::new edge count {} != restriction count {}",
121 edges.len(),
122 restrictions.len()
123 ));
124 }
125 if stalk_dims.is_empty() {
126 return Err("SheafConsistencyPenalty::new requires at least one vertex".into());
127 }
128 for (v, &d) in stalk_dims.iter().enumerate() {
129 if d == 0 {
130 return Err(format!(
131 "SheafConsistencyPenalty::new stalk dim at vertex {v} is zero"
132 ));
133 }
134 }
135 for (e, ((u, v), restriction)) in edges.iter().zip(restrictions.iter()).enumerate() {
136 if *u >= stalk_dims.len() || *v >= stalk_dims.len() {
137 return Err(format!(
138 "SheafConsistencyPenalty::new edge {e} = ({u}, {v}) references vertex \
139 out of range (K = {})",
140 stalk_dims.len()
141 ));
142 }
143 let d_u = stalk_dims[*u];
144 let d_v = stalk_dims[*v];
145 let d_e = restriction.r_uv.nrows();
146 if restriction.r_uv.ncols() != d_u {
147 return Err(format!(
148 "SheafConsistencyPenalty::new edge {e}: r_uv has {} cols, expected d_u = {d_u}",
149 restriction.r_uv.ncols()
150 ));
151 }
152 match &restriction.r_vu {
153 Some(r_vu) => {
154 if r_vu.ncols() != d_v {
155 return Err(format!(
156 "SheafConsistencyPenalty::new edge {e}: r_vu has {} cols, \
157 expected d_v = {d_v}",
158 r_vu.ncols()
159 ));
160 }
161 if r_vu.nrows() != d_e {
162 return Err(format!(
163 "SheafConsistencyPenalty::new edge {e}: r_vu has {} rows, \
164 expected d_e = {d_e}",
165 r_vu.nrows()
166 ));
167 }
168 }
169 None => {
170 if d_e != d_v {
171 return Err(format!(
172 "SheafConsistencyPenalty::new edge {e}: r_vu is identity but \
173 d_e ({d_e}) != d_v ({d_v})"
174 ));
175 }
176 }
177 }
178 if !restriction.r_uv.iter().all(|x| x.is_finite()) {
179 return Err(format!(
180 "SheafConsistencyPenalty::new edge {e}: r_uv contains non-finite entries"
181 ));
182 }
183 if let Some(r_vu) = &restriction.r_vu
184 && !r_vu.iter().all(|x| x.is_finite())
185 {
186 return Err(format!(
187 "SheafConsistencyPenalty::new edge {e}: r_vu contains non-finite entries"
188 ));
189 }
190 }
191 let mut stalk_offsets = Vec::with_capacity(stalk_dims.len() + 1);
192 let mut acc = 0usize;
193 for &d in &stalk_dims {
194 stalk_offsets.push(acc);
195 acc = acc.checked_add(d).ok_or_else(|| {
196 "SheafConsistencyPenalty::new stalk offsets overflow usize".to_string()
197 })?;
198 }
199 stalk_offsets.push(acc);
200 Ok(Self {
201 edges,
202 restrictions,
203 weight,
204 stalk_offsets,
205 stalk_dims,
206 })
207 }
208
209 pub fn total_dim(&self) -> usize {
211 *self.stalk_offsets.last().expect("offsets non-empty")
212 }
213
214 pub fn num_edges(&self) -> usize {
216 self.edges.len()
217 }
218
219 pub fn num_vertices(&self) -> usize {
221 self.stalk_dims.len()
222 }
223
224 pub fn stalk_dims(&self) -> &[usize] {
226 &self.stalk_dims
227 }
228
229 pub fn weight(&self) -> f64 {
231 self.weight
232 }
233
234 fn vertex_slice<'a>(&self, s: ArrayView1<'a, f64>, v: usize) -> ArrayView1<'a, f64> {
235 let start = self.stalk_offsets[v];
236 let end = self.stalk_offsets[v + 1];
237 s.slice_move(ndarray::s![start..end])
238 }
239
240 fn delta(&self, s: ArrayView1<'_, f64>) -> Vec<Array1<f64>> {
243 assert_eq!(
244 s.len(),
245 self.total_dim(),
246 "stacked stalk vector has wrong length",
247 );
248 let mut out = Vec::with_capacity(self.edges.len());
249 for (e, &(u, v)) in self.edges.iter().enumerate() {
250 let s_u = self.vertex_slice(s, u);
251 let s_v = self.vertex_slice(s, v);
252 let restriction = &self.restrictions[e];
253 let mut delta_e = restriction.r_uv.dot(&s_u);
255 match &restriction.r_vu {
257 Some(r_vu) => {
258 let r_vu_s_v = r_vu.dot(&s_v);
259 delta_e.scaled_add(-1.0, &r_vu_s_v);
260 }
261 None => {
262 delta_e.scaled_add(-1.0, &s_v);
263 }
264 }
265 out.push(delta_e);
266 }
267 out
268 }
269
270 fn delta_transpose(&self, y: &[Array1<f64>]) -> Array1<f64> {
273 assert_eq!(
274 y.len(),
275 self.edges.len(),
276 "delta_transpose edge count mismatch"
277 );
278 let mut out = Array1::<f64>::zeros(self.total_dim());
279 for (e, &(u, v)) in self.edges.iter().enumerate() {
280 let restriction = &self.restrictions[e];
281 let y_e = &y[e];
282 assert_eq!(y_e.len(), restriction.edge_dim(), "edge dim mismatch");
283 let contrib_u = restriction.r_uv.t().dot(y_e);
285 let u_start = self.stalk_offsets[u];
286 let u_end = self.stalk_offsets[u + 1];
287 {
288 let mut out_u = out.slice_mut(ndarray::s![u_start..u_end]);
289 out_u.scaled_add(1.0, &contrib_u);
290 }
291 let v_start = self.stalk_offsets[v];
293 let v_end = self.stalk_offsets[v + 1];
294 match &restriction.r_vu {
295 Some(r_vu) => {
296 let contrib_v = r_vu.t().dot(y_e);
297 let mut out_v = out.slice_mut(ndarray::s![v_start..v_end]);
298 out_v.scaled_add(-1.0, &contrib_v);
299 }
300 None => {
301 let mut out_v = out.slice_mut(ndarray::s![v_start..v_end]);
302 out_v.scaled_add(-1.0, y_e);
303 }
304 }
305 }
306 out
307 }
308
309 pub fn laplacian_apply(&self, s: ArrayView1<'_, f64>) -> Array1<f64> {
312 let ds = self.delta(s);
313 self.delta_transpose(&ds)
314 }
315
316 pub fn value(&self, s: ArrayView1<'_, f64>) -> f64 {
318 let ds = self.delta(s);
319 let mut sq = 0.0;
320 for de in &ds {
321 for &x in de.iter() {
322 sq += x * x;
323 }
324 }
325 0.5 * self.weight * sq
326 }
327
328 pub fn gradient(&self, s: ArrayView1<'_, f64>) -> Array1<f64> {
330 let mut g = self.laplacian_apply(s);
331 g *= self.weight;
332 g
333 }
334
335 pub fn hessian_diag(&self, s: ArrayView1<'_, f64>) -> Array1<f64> {
345 assert_eq!(
346 s.len(),
347 self.total_dim(),
348 "stacked stalk vector has wrong length",
349 );
350 let mut diag = Array1::<f64>::zeros(self.total_dim());
354 for (e, &(u, v)) in self.edges.iter().enumerate() {
355 let restriction = &self.restrictions[e];
356 let u_start = self.stalk_offsets[u];
357 let v_start = self.stalk_offsets[v];
358 let r_uv = &restriction.r_uv;
359
360 if u == v {
361 match &restriction.r_vu {
367 Some(r_vu) => {
368 for col in 0..r_uv.ncols() {
369 let mut s2 = 0.0;
370 for row in 0..r_uv.nrows() {
371 let diff = r_uv[[row, col]] - r_vu[[row, col]];
372 s2 += diff * diff;
373 }
374 diag[u_start + col] += s2;
375 }
376 }
377 None => {
378 let d = self.stalk_dims[u];
380 for col in 0..d {
381 let mut s2 = 0.0;
382 for row in 0..r_uv.nrows() {
383 let identity_entry = if row == col { 1.0 } else { 0.0 };
384 let diff = r_uv[[row, col]] - identity_entry;
385 s2 += diff * diff;
386 }
387 diag[u_start + col] += s2;
388 }
389 }
390 }
391 } else {
392 for col in 0..r_uv.ncols() {
396 let mut s2 = 0.0;
397 for row in 0..r_uv.nrows() {
398 let a = r_uv[[row, col]];
399 s2 += a * a;
400 }
401 diag[u_start + col] += s2;
402 }
403 match &restriction.r_vu {
404 Some(r_vu) => {
405 for col in 0..r_vu.ncols() {
406 let mut s2 = 0.0;
407 for row in 0..r_vu.nrows() {
408 let a = r_vu[[row, col]];
409 s2 += a * a;
410 }
411 diag[v_start + col] += s2;
412 }
413 }
414 None => {
415 let d_v = self.stalk_dims[v];
416 for col in 0..d_v {
417 diag[v_start + col] += 1.0;
418 }
419 }
420 }
421 }
422 }
423 diag *= self.weight;
424 diag
425 }
426
427 pub fn hvp(&self, s: ArrayView1<'_, f64>, v: ArrayView1<'_, f64>) -> Array1<f64> {
431 assert_eq!(
432 s.len(),
433 self.total_dim(),
434 "stacked stalk vector has wrong length",
435 );
436 assert_eq!(v.len(), self.total_dim(), "hvp direction has wrong length");
437 let mut hv = self.laplacian_apply(v);
438 hv *= self.weight;
439 hv
440 }
441
442 fn dense_laplacian(&self) -> Array2<f64> {
449 let n = self.total_dim();
450 let mut l = Array2::<f64>::zeros((n, n));
451 let mut e = Array1::<f64>::zeros(n);
452 for j in 0..n {
453 e[j] = 1.0;
454 let col = self.laplacian_apply(e.view());
455 for i in 0..n {
456 l[[i, j]] = col[i];
457 }
458 e[j] = 0.0;
459 }
460 l
461 }
462
463 pub fn harmonic_modes(&self, tol: f64) -> usize {
471 assert!(
472 tol.is_finite() && tol >= 0.0,
473 "harmonic_modes requires finite non-negative tol, got {tol}",
474 );
475 let n = self.total_dim();
476 if n == 0 {
477 return 0;
478 }
479 if n <= DENSE_EIGH_DIM_THRESHOLD {
480 let l = self.dense_laplacian();
481 match l.eigh(Side::Lower) {
482 Ok((evals, _)) => evals.iter().filter(|&&e| e < tol).count(),
483 Err(err) => {
487 panic!("SheafConsistencyPenalty::harmonic_modes faer eigh failed: {err:?}")
488 }
489 }
490 } else {
491 self.harmonic_modes_lanczos(tol)
492 }
493 }
494
495 fn harmonic_modes_lanczos(&self, tol: f64) -> usize {
504 let n = self.total_dim();
505 let k = n.min(64).max(1);
506 let mut q0 = vec![0.0_f64; n];
508 for i in 0..n {
509 let mut state = (i as u64)
513 .wrapping_mul(0x9E37_79B9_7F4A_7C15)
514 .wrapping_sub(0x9E37_79B9_7F4A_7C15);
515 let z = gam_linalg::utils::splitmix64(&mut state);
516 q0[i] = (z as f64 / u64::MAX as f64) - 0.5;
517 }
518 match symmetric_lanczos_eigenpairs(
519 n,
520 &q0,
521 SymmetricLanczosOptions {
522 max_steps: k,
523 residual_tol: 1e-12,
524 local_reorthogonalize: true,
525 full_reorthogonalize: false,
526 },
527 |q, out| {
528 let w = self.laplacian_apply(ArrayView1::from(q));
529 out.copy_from_slice(w.as_slice().ok_or_else(|| {
530 "SheafConsistencyPenalty::harmonic_modes Lanczos matvec produced non-contiguous output"
531 .to_string()
532 })?);
533 Ok(())
534 },
535 ) {
536 Ok(eigen) => eigen.eigenvalues.iter().filter(|&&e| e < tol).count(),
537 Err(err) => {
538 panic!("SheafConsistencyPenalty::harmonic_modes Lanczos failed: {err}")
543 }
544 }
545 }
546}
547
548impl AnalyticPenalty for SheafConsistencyPenalty {
562 fn tier(&self) -> PenaltyTier {
563 PenaltyTier::Psi
564 }
565
566 fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
567 assert!(
568 rho.iter().all(|x| x.is_finite()),
569 "SheafConsistencyPenalty: rho must be finite (got {rho:?})",
570 );
571 SheafConsistencyPenalty::value(self, target)
572 }
573
574 fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
575 assert!(
576 rho.iter().all(|x| x.is_finite()),
577 "SheafConsistencyPenalty: rho must be finite (got {rho:?})",
578 );
579 SheafConsistencyPenalty::gradient(self, target)
580 }
581
582 fn hessian_diag(
583 &self,
584 target: ArrayView1<'_, f64>,
585 rho: ArrayView1<'_, f64>,
586 ) -> Option<Array1<f64>> {
587 assert!(
588 rho.iter().all(|x| x.is_finite()),
589 "SheafConsistencyPenalty: rho must be finite (got {rho:?})",
590 );
591 Some(SheafConsistencyPenalty::hessian_diag(self, target))
592 }
593
594 fn hvp(
595 &self,
596 target: ArrayView1<'_, f64>,
597 rho: ArrayView1<'_, f64>,
598 v: ArrayView1<'_, f64>,
599 ) -> Array1<f64> {
600 assert!(
601 rho.iter().all(|x| x.is_finite()),
602 "SheafConsistencyPenalty: rho must be finite (got {rho:?})",
603 );
604 SheafConsistencyPenalty::hvp(self, target, v)
605 }
606
607 fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
608 assert_eq!(
610 rho.len(),
611 0,
612 "SheafConsistencyPenalty: rho_count is 0 but rho has length {}",
613 rho.len(),
614 );
615 assert_eq!(
616 target.len(),
617 self.total_dim(),
618 "SheafConsistencyPenalty: target length {} != total stalk dim {}",
619 target.len(),
620 self.total_dim(),
621 );
622 Array1::<f64>::zeros(0)
623 }
624
625 fn rho_count(&self) -> usize {
626 0
627 }
628
629 fn name(&self) -> &str {
630 "SheafConsistencyPenalty"
631 }
632}
633
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Dense `d x d` identity matrix.
    fn identity(d: usize) -> Array2<f64> {
        let mut m = Array2::<f64>::zeros((d, d));
        for i in 0..d {
            m[[i, i]] = 1.0;
        }
        m
    }

    #[test]
    fn single_edge_identity_restriction_value() {
        // delta = s_u - s_v = [1, -1, 0], so value = 0.5 * 1 * 2 = 1.
        let edges = vec![(0usize, 1usize)];
        let restrictions = vec![EdgeRestriction::paired(identity(3), identity(3))];
        let pen =
            SheafConsistencyPenalty::new(edges, restrictions, 1.0, vec![3, 3]).expect("build");
        let s = array![1.0_f64, 0.0, 0.0, 0.0, 1.0, 0.0];
        let v = pen.value(s.view());
        assert_abs_diff_eq!(v, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn gradient_matches_finite_difference_k2_random() {
        // Rectangular r_uv (2 x 3) maps a 3-dim stalk into a 2-dim edge stalk.
        let r_uv = array![[0.7_f64, -0.1, 0.3], [0.2, 0.9, -0.4]];
        let r_vu = array![[1.0_f64, 0.5], [-0.3, 0.8]];
        let edges = vec![(0usize, 1usize)];
        let restrictions = vec![EdgeRestriction::paired(r_uv, r_vu)];
        let pen =
            SheafConsistencyPenalty::new(edges, restrictions, 0.3, vec![3, 2]).expect("build");
        let s = array![0.4_f64, -1.1, 0.2, 0.6, -0.7];
        let g = pen.gradient(s.view());
        let eps = 1e-7;
        for i in 0..s.len() {
            let mut sp = s.clone();
            let mut sm = s.clone();
            sp[i] += eps;
            sm[i] -= eps;
            let fd = (pen.value(sp.view()) - pen.value(sm.view())) / (2.0 * eps);
            assert_abs_diff_eq!(g[i], fd, epsilon = 1e-6);
        }
    }

    #[test]
    fn hvp_matches_reconstructed_laplacian_chain_k3() {
        // Path graph 0 - 1 - 2 with 2-dim stalks; HVP must equal weight * L v.
        let r01_uv = array![[0.9_f64, 0.1], [-0.2, 0.7]];
        let r01_vu = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let r12_uv = array![[0.5_f64, -0.3], [0.4, 0.8]];
        let r12_vu = array![[0.6_f64, 0.0], [0.1, 1.1]];
        let edges = vec![(0usize, 1usize), (1usize, 2usize)];
        let restrictions = vec![
            EdgeRestriction::paired(r01_uv, r01_vu),
            EdgeRestriction::paired(r12_uv, r12_vu),
        ];
        let pen =
            SheafConsistencyPenalty::new(edges, restrictions, 1.0, vec![2, 2, 2]).expect("build");
        let l_dense = pen.dense_laplacian();
        let n = pen.total_dim();
        let s = array![0.1_f64, -0.2, 0.3, 0.4, -0.5, 0.6];
        let v = array![0.7_f64, 0.2, -0.1, 0.5, 0.3, -0.4];
        let hv = pen.hvp(s.view(), v.view());
        let mut lv = Array1::<f64>::zeros(n);
        for i in 0..n {
            let mut acc = 0.0;
            for j in 0..n {
                acc += l_dense[[i, j]] * v[j];
            }
            lv[i] = acc;
        }
        for i in 0..n {
            assert_abs_diff_eq!(hv[i], lv[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn harmonic_modes_two_components_identity_restrictions() {
        // No edges: L = 0, so every one of the 6 coordinates is harmonic.
        let pen = SheafConsistencyPenalty::new(vec![], vec![], 1.0, vec![3, 3]).expect("build");
        let h = pen.harmonic_modes(1e-10);
        assert_eq!(h, 6);

        // Two disjoint identity edges: each component has a 2-dim kernel.
        let edges = vec![(0usize, 1usize), (2usize, 3usize)];
        let restrictions = vec![
            EdgeRestriction::paired(identity(2), identity(2)),
            EdgeRestriction::paired(identity(2), identity(2)),
        ];
        let pen2 = SheafConsistencyPenalty::new(edges, restrictions, 1.0, vec![2, 2, 2, 2])
            .expect("build");
        let h2 = pen2.harmonic_modes(1e-10);
        assert_eq!(h2, 4);
    }

    #[test]
    fn value_psd_and_zero_iff_kernel() {
        let r01_uv = array![[0.9_f64, 0.1], [-0.2, 0.7]];
        let r01_vu = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let edges = vec![(0usize, 1usize)];
        let restrictions = vec![EdgeRestriction::paired(r01_uv.clone(), r01_vu.clone())];
        let pen =
            SheafConsistencyPenalty::new(edges, restrictions, 0.5, vec![2, 2]).expect("build");

        let samples = [
            array![0.0_f64, 0.0, 0.0, 0.0],
            array![1.0_f64, 2.0, -0.5, 0.3],
            array![-1.3_f64, 0.7, 0.2, -0.9],
        ];
        for s in &samples {
            let v = pen.value(s.view());
            assert!(v >= 0.0, "value must be non-negative, got {v}");
        }
        let z = Array1::<f64>::zeros(4);
        assert_abs_diff_eq!(pen.value(z.view()), 0.0, epsilon = 1e-15);
        // With r_vu = I, choosing s_v = r_uv * s_u makes delta vanish exactly.
        let s0 = array![0.3_f64, -1.1];
        let s1 = r01_uv.dot(&s0);
        let mut s = Array1::<f64>::zeros(4);
        s[0] = s0[0];
        s[1] = s0[1];
        s[2] = s1[0];
        s[3] = s1[1];
        assert_abs_diff_eq!(pen.value(s.view()), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn hessian_diag_matches_diag_of_dense_laplacian() {
        let r_uv = array![[0.7_f64, -0.1, 0.3], [0.2, 0.9, -0.4]];
        let r_vu = array![[1.0_f64, 0.5], [-0.3, 0.8]];
        let edges = vec![(0usize, 1usize)];
        let restrictions = vec![EdgeRestriction::paired(r_uv, r_vu)];
        let pen =
            SheafConsistencyPenalty::new(edges, restrictions, 0.3, vec![3, 2]).expect("build");
        let n = pen.total_dim();
        let s = Array1::<f64>::zeros(n);
        let diag = pen.hessian_diag(s.view());
        let l = pen.dense_laplacian();
        for i in 0..n {
            assert_abs_diff_eq!(diag[i], 0.3 * l[[i, i]], epsilon = 1e-12);
        }
    }

    #[test]
    fn hessian_diag_matches_dense_laplacian_on_self_loop_paired() {
        // Self-loop: the diagonal comes from the columns of (r_uv - r_vu).
        let r_uv = array![[0.9_f64, 0.1], [-0.2, 0.7]];
        let r_vu = array![[1.0_f64, 0.5], [-0.3, 0.8]];
        let edges = vec![(0usize, 0usize)];
        let restrictions = vec![EdgeRestriction::paired(r_uv, r_vu)];
        let pen = SheafConsistencyPenalty::new(edges, restrictions, 0.7, vec![2]).expect("build");
        let n = pen.total_dim();
        let s = Array1::<f64>::zeros(n);
        let diag = pen.hessian_diag(s.view());
        let l = pen.dense_laplacian();
        for i in 0..n {
            assert_abs_diff_eq!(diag[i], 0.7 * l[[i, i]], epsilon = 1e-12);
        }
    }

    #[test]
    fn hessian_diag_matches_dense_laplacian_on_self_loop_single() {
        // r_uv - I = [[0, 2], [3, 3]]: squared column norms are 9 and 13.
        let r_uv = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let edges = vec![(0usize, 0usize)];
        let restrictions = vec![EdgeRestriction::single(r_uv)];
        let pen = SheafConsistencyPenalty::new(edges, restrictions, 1.3, vec![2]).expect("build");
        let n = pen.total_dim();
        let s = Array1::<f64>::zeros(n);
        let diag = pen.hessian_diag(s.view());
        let l = pen.dense_laplacian();
        for i in 0..n {
            assert_abs_diff_eq!(diag[i], 1.3 * l[[i, i]], epsilon = 1e-12);
        }
        assert_abs_diff_eq!(diag[0], 1.3 * 9.0, epsilon = 1e-12);
        assert_abs_diff_eq!(diag[1], 1.3 * 13.0, epsilon = 1e-12);
    }

    #[test]
    fn hessian_diag_matches_dense_laplacian_mixed_self_loop_and_cross_edge() {
        // Vertex 0 receives contributions from both a self-loop and a cross edge.
        let r0_uv = array![[0.5_f64, -0.4], [0.3, 0.9]];
        let r0_vu = array![[0.2_f64, 0.1], [-0.6, 0.7]];
        let r1_uv = array![[1.1_f64, 0.2], [0.0, -0.5]];
        let r1_vu = array![[0.8_f64, -0.1], [0.4, 1.0]];
        let edges = vec![(0usize, 0usize), (0usize, 1usize)];
        let restrictions = vec![
            EdgeRestriction::paired(r0_uv, r0_vu),
            EdgeRestriction::paired(r1_uv, r1_vu),
        ];
        let pen =
            SheafConsistencyPenalty::new(edges, restrictions, 0.5, vec![2, 2]).expect("build");
        let n = pen.total_dim();
        let s = Array1::<f64>::zeros(n);
        let diag = pen.hessian_diag(s.view());
        let l = pen.dense_laplacian();
        for i in 0..n {
            assert_abs_diff_eq!(diag[i], 0.5 * l[[i, i]], epsilon = 1e-12);
        }
    }

    #[test]
    fn hessian_diag_and_gradient_single_restriction_cross_edge() {
        // Edge (0, 1) with an implicit identity on the v side exercises the `None` arm
        // of the cross-edge branch in `hessian_diag` and in `delta_transpose`.
        let r = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let edges = vec![(0usize, 1usize)];
        let restrictions = vec![EdgeRestriction::single(r)];
        let pen =
            SheafConsistencyPenalty::new(edges, restrictions, 1.3, vec![2, 2]).expect("build");
        let n = pen.total_dim();
        let s = array![0.4_f64, -1.1, 0.2, 0.6];
        let diag = pen.hessian_diag(s.view());
        let l = pen.dense_laplacian();
        for i in 0..n {
            assert_abs_diff_eq!(diag[i], 1.3 * l[[i, i]], epsilon = 1e-12);
        }
        // Squared column norms of r: [1, 3] -> 10 and [2, 4] -> 20; identity columns -> 1.
        assert_abs_diff_eq!(diag[0], 1.3 * 10.0, epsilon = 1e-12);
        assert_abs_diff_eq!(diag[1], 1.3 * 20.0, epsilon = 1e-12);
        assert_abs_diff_eq!(diag[2], 1.3, epsilon = 1e-12);
        assert_abs_diff_eq!(diag[3], 1.3, epsilon = 1e-12);

        let g = pen.gradient(s.view());
        let eps = 1e-7;
        for i in 0..n {
            let mut sp = s.clone();
            let mut sm = s.clone();
            sp[i] += eps;
            sm[i] -= eps;
            let fd = (pen.value(sp.view()) - pen.value(sm.view())) / (2.0 * eps);
            assert_abs_diff_eq!(g[i], fd, epsilon = 1e-6);
        }
    }

    #[test]
    fn single_restriction_edge_form() {
        let r = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let edges = vec![(0usize, 1usize)];
        let restrictions = vec![EdgeRestriction::single(r.clone())];
        let pen =
            SheafConsistencyPenalty::new(edges, restrictions, 2.0, vec![2, 2]).expect("build");
        // r * [1, 0] = [1, 3] equals s_v, so the edge is consistent.
        let s = array![1.0_f64, 0.0, 1.0, 3.0];
        assert_abs_diff_eq!(pen.value(s.view()), 0.0, epsilon = 1e-12);
        // s_v = 0 leaves delta = [1, 3]: 0.5 * 2 * (1 + 9) = 10.
        let s2 = array![1.0_f64, 0.0, 0.0, 0.0];
        assert_abs_diff_eq!(pen.value(s2.view()), 10.0, epsilon = 1e-12);
    }
}