use faer::Side;
use ndarray::{Array1, Array2, ArrayView1};
use crate::linalg::faer_ndarray::FaerEigh;
use crate::terms::analytic_penalties::{AnalyticPenalty, PenaltyTier};
/// Largest total stalk dimension for which `harmonic_modes` materialises the dense
/// Laplacian and runs a dense symmetric eigensolver; larger problems are routed to
/// `harmonic_modes_lanczos` instead. (`1 << 12 == 4096`.)
const DENSE_EIGH_DIM_THRESHOLD: usize = 1 << 12;
/// Restriction maps attached to one edge `(u, v)` of the sheaf.
///
/// `r_uv` maps the stalk at `u` into the edge stalk, whose dimension is
/// `r_uv.nrows()`. `r_vu` maps the stalk at `v` into the same edge stalk; `None` means
/// the `v`-side map is the identity (which `SheafConsistencyPenalty::new` only accepts
/// when the edge stalk and the `v` stalk have the same dimension).
#[derive(Debug, Clone)]
pub struct EdgeRestriction {
    /// `u`-side restriction, shape `d_e × d_u`.
    pub r_uv: Array2<f64>,
    /// `v`-side restriction, shape `d_e × d_v`, or `None` for the identity.
    pub r_vu: Option<Array2<f64>>,
}
impl EdgeRestriction {
    /// Builds an edge with explicit restriction maps on both endpoints.
    #[must_use]
    pub fn paired(r_uv: Array2<f64>, r_vu: Array2<f64>) -> Self {
        let r_vu = Some(r_vu);
        Self { r_uv, r_vu }
    }
    /// Builds an edge whose `v`-side restriction is the identity.
    #[must_use]
    pub fn single(r_uv: Array2<f64>) -> Self {
        Self { r_vu: None, r_uv }
    }
    /// Dimension of the edge stalk (the row count of `r_uv`).
    pub fn edge_dim(&self) -> usize {
        let (rows, _cols) = self.r_uv.dim();
        rows
    }
}
/// Quadratic sheaf-consistency penalty `0.5 * weight * ||δ s||²`.
///
/// `s` is the concatenation of all vertex stalks, and `δ` (the coboundary) maps it to one
/// residual per edge, `R_uv s_u − R_vu s_v`. Build instances with
/// `SheafConsistencyPenalty::new`, which validates every shape below.
#[derive(Debug, Clone)]
pub struct SheafConsistencyPenalty {
    /// Edge list as `(u, v)` vertex-index pairs; `new` checks both are `< stalk_dims.len()`.
    edges: Vec<(usize, usize)>,
    /// One restriction per edge, parallel to `edges`.
    restrictions: Vec<EdgeRestriction>,
    /// Finite, strictly positive scale applied to value, gradient and Hessian.
    weight: f64,
    /// Prefix sums of `stalk_dims`: vertex `v` occupies `stalk_offsets[v]..stalk_offsets[v + 1]`
    /// of the stacked vector; the last entry is the total dimension.
    stalk_offsets: Vec<usize>,
    /// Dimension of each vertex stalk (all nonzero).
    stalk_dims: Vec<usize>,
}
impl SheafConsistencyPenalty {
#[must_use = "build error must be handled"]
pub fn new(
edges: Vec<(usize, usize)>,
restrictions: Vec<EdgeRestriction>,
weight: f64,
stalk_dims: Vec<usize>,
) -> Result<Self, String> {
if !(weight.is_finite() && weight > 0.0) {
return Err(format!(
"SheafConsistencyPenalty::new requires finite weight > 0, got {weight}"
));
}
if edges.len() != restrictions.len() {
return Err(format!(
"SheafConsistencyPenalty::new edge count {} != restriction count {}",
edges.len(),
restrictions.len()
));
}
if stalk_dims.is_empty() {
return Err("SheafConsistencyPenalty::new requires at least one vertex".into());
}
for (v, &d) in stalk_dims.iter().enumerate() {
if d == 0 {
return Err(format!(
"SheafConsistencyPenalty::new stalk dim at vertex {v} is zero"
));
}
}
for (e, ((u, v), restriction)) in edges.iter().zip(restrictions.iter()).enumerate() {
if *u >= stalk_dims.len() || *v >= stalk_dims.len() {
return Err(format!(
"SheafConsistencyPenalty::new edge {e} = ({u}, {v}) references vertex \
out of range (K = {})",
stalk_dims.len()
));
}
let d_u = stalk_dims[*u];
let d_v = stalk_dims[*v];
let d_e = restriction.r_uv.nrows();
if restriction.r_uv.ncols() != d_u {
return Err(format!(
"SheafConsistencyPenalty::new edge {e}: r_uv has {} cols, expected d_u = {d_u}",
restriction.r_uv.ncols()
));
}
match &restriction.r_vu {
Some(r_vu) => {
if r_vu.ncols() != d_v {
return Err(format!(
"SheafConsistencyPenalty::new edge {e}: r_vu has {} cols, \
expected d_v = {d_v}",
r_vu.ncols()
));
}
if r_vu.nrows() != d_e {
return Err(format!(
"SheafConsistencyPenalty::new edge {e}: r_vu has {} rows, \
expected d_e = {d_e}",
r_vu.nrows()
));
}
}
None => {
if d_e != d_v {
return Err(format!(
"SheafConsistencyPenalty::new edge {e}: r_vu is identity but \
d_e ({d_e}) != d_v ({d_v})"
));
}
}
}
if !restriction.r_uv.iter().all(|x| x.is_finite()) {
return Err(format!(
"SheafConsistencyPenalty::new edge {e}: r_uv contains non-finite entries"
));
}
if let Some(r_vu) = &restriction.r_vu
&& !r_vu.iter().all(|x| x.is_finite())
{
return Err(format!(
"SheafConsistencyPenalty::new edge {e}: r_vu contains non-finite entries"
));
}
}
let mut stalk_offsets = Vec::with_capacity(stalk_dims.len() + 1);
let mut acc = 0usize;
for &d in &stalk_dims {
stalk_offsets.push(acc);
acc = acc.checked_add(d).ok_or_else(|| {
"SheafConsistencyPenalty::new stalk offsets overflow usize".to_string()
})?;
}
stalk_offsets.push(acc);
Ok(Self {
edges,
restrictions,
weight,
stalk_offsets,
stalk_dims,
})
}
/// Length of the stacked stalk vector, i.e. the sum of all stalk dimensions.
pub fn total_dim(&self) -> usize {
    self.stalk_offsets
        .last()
        .copied()
        .expect("offsets non-empty")
}
/// Number of edges.
pub fn num_edges(&self) -> usize {
    self.edges.len()
}
/// Number of vertices (one stalk each).
pub fn num_vertices(&self) -> usize {
    self.stalk_dims.len()
}
/// Per-vertex stalk dimensions, in vertex order.
pub fn stalk_dims(&self) -> &[usize] {
    self.stalk_dims.as_slice()
}
/// Positive scale factor applied to the penalty and its derivatives.
pub fn weight(&self) -> f64 {
    self.weight
}
/// Sub-view of the stacked vector `s` that holds vertex `v`'s stalk.
fn vertex_slice<'a>(&self, s: ArrayView1<'a, f64>, v: usize) -> ArrayView1<'a, f64> {
    let (lo, hi) = (self.stalk_offsets[v], self.stalk_offsets[v + 1]);
    s.slice_move(ndarray::s![lo..hi])
}
/// Applies the coboundary `δ`: for every edge `(u, v)`, in edge order, returns the
/// residual `R_uv s_u − R_vu s_v` (with `R_vu` the identity when absent).
///
/// # Panics
/// Panics if `s.len() != self.total_dim()`.
fn delta(&self, s: ArrayView1<'_, f64>) -> Vec<Array1<f64>> {
    assert_eq!(
        s.len(),
        self.total_dim(),
        "stacked stalk vector has wrong length",
    );
    self.edges
        .iter()
        .zip(&self.restrictions)
        .map(|(&(u, v), restriction)| {
            let mut edge_residual = restriction.r_uv.dot(&self.vertex_slice(s, u));
            let s_v = self.vertex_slice(s, v);
            if let Some(r_vu) = &restriction.r_vu {
                edge_residual.scaled_add(-1.0, &r_vu.dot(&s_v));
            } else {
                edge_residual.scaled_add(-1.0, &s_v);
            }
            edge_residual
        })
        .collect()
}
/// Applies `δᵀ`: scatters each edge vector `y[e]` back onto its endpoint stalks, adding
/// `R_uvᵀ y_e` at `u` and subtracting `R_vuᵀ y_e` (or `y_e` for the identity) at `v`.
///
/// # Panics
/// Panics if `y.len()` differs from the edge count or any `y[e]` has the wrong length.
fn delta_transpose(&self, y: &[Array1<f64>]) -> Array1<f64> {
    assert_eq!(
        y.len(),
        self.edges.len(),
        "delta_transpose edge count mismatch"
    );
    let mut out = Array1::<f64>::zeros(self.total_dim());
    let per_edge = self.edges.iter().zip(&self.restrictions).zip(y);
    for ((&(u, v), restriction), y_e) in per_edge {
        assert_eq!(y_e.len(), restriction.edge_dim(), "edge dim mismatch");
        let (u_lo, u_hi) = (self.stalk_offsets[u], self.stalk_offsets[u + 1]);
        let (v_lo, v_hi) = (self.stalk_offsets[v], self.stalk_offsets[v + 1]);
        let pulled_back_u = restriction.r_uv.t().dot(y_e);
        out.slice_mut(ndarray::s![u_lo..u_hi])
            .scaled_add(1.0, &pulled_back_u);
        // Borrowed after the u-side update so self-loops (u == v) accumulate correctly.
        let mut out_v = out.slice_mut(ndarray::s![v_lo..v_hi]);
        if let Some(r_vu) = &restriction.r_vu {
            out_v.scaled_add(-1.0, &r_vu.t().dot(y_e));
        } else {
            out_v.scaled_add(-1.0, y_e);
        }
    }
    out
}
/// Applies the unweighted sheaf Laplacian `δᵀ δ` to the stacked vector `s`.
///
/// # Panics
/// Panics if `s.len() != self.total_dim()`.
pub fn laplacian_apply(&self, s: ArrayView1<'_, f64>) -> Array1<f64> {
    self.delta_transpose(&self.delta(s))
}
/// Penalty value `0.5 * weight * Σ_e ||δ_e s||²`.
///
/// # Panics
/// Panics if `s.len() != self.total_dim()`.
pub fn value(&self, s: ArrayView1<'_, f64>) -> f64 {
    // Same left-to-right accumulation order as summing edge by edge, entry by entry.
    let squared_norm = self
        .delta(s)
        .iter()
        .flat_map(|edge_residual| edge_residual.iter())
        .fold(0.0, |acc, &x| acc + x * x);
    0.5 * self.weight * squared_norm
}
/// Gradient of the penalty, `weight * δᵀ δ s`.
///
/// # Panics
/// Panics if `s.len() != self.total_dim()`.
pub fn gradient(&self, s: ArrayView1<'_, f64>) -> Array1<f64> {
    let weight = self.weight;
    self.laplacian_apply(s).mapv_into(|g| g * weight)
}
/// Diagonal of the penalty Hessian, `weight * diag(δᵀ δ)`.
///
/// The penalty is quadratic, so the result does not depend on the values in `s`; `s` is
/// only length-checked.
///
/// For an edge `(u, v)` with `u != v`, the edge adds the squared column norms of `R_uv` to
/// vertex `u`, and those of `R_vu` (or `1` per column for the identity) to vertex `v`.
/// For a self-loop `(u, u)` the edge residual is `(R_uv − R_vu) s_u`, so the contribution
/// is the squared column norms of the difference `R_uv − R_vu`.
///
/// # Panics
/// Panics if `s.len() != self.total_dim()`.
pub fn hessian_diag(&self, s: ArrayView1<'_, f64>) -> Array1<f64> {
    assert_eq!(
        s.len(),
        self.total_dim(),
        "stacked stalk vector has wrong length",
    );
    let mut diag = Array1::<f64>::zeros(self.total_dim());
    for (e, &(u, v)) in self.edges.iter().enumerate() {
        let restriction = &self.restrictions[e];
        let r_uv = &restriction.r_uv;
        let u_start = self.stalk_offsets[u];
        if u == v {
            // Self-loop: both endpoint terms hit the same block. Adding the two squared
            // norms separately (as in the general case below) would miss the
            // −2·R_uv·R_vu cross term, so take norms of the difference instead.
            for col in 0..r_uv.ncols() {
                let mut s2 = 0.0;
                for row in 0..r_uv.nrows() {
                    let r_vu_entry = match &restriction.r_vu {
                        Some(r_vu) => r_vu[[row, col]],
                        // Identity map; `new` guarantees it is square (d_e == d_v == d_u).
                        None if row == col => 1.0,
                        None => 0.0,
                    };
                    let a = r_uv[[row, col]] - r_vu_entry;
                    s2 += a * a;
                }
                diag[u_start + col] += s2;
            }
            continue;
        }
        // u side: squared column norms of R_uv.
        for col in 0..r_uv.ncols() {
            let mut s2 = 0.0;
            for row in 0..r_uv.nrows() {
                let a = r_uv[[row, col]];
                s2 += a * a;
            }
            diag[u_start + col] += s2;
        }
        // v side: squared column norms of R_vu, or 1 per column for the identity.
        let v_start = self.stalk_offsets[v];
        match &restriction.r_vu {
            Some(r_vu) => {
                for col in 0..r_vu.ncols() {
                    let mut s2 = 0.0;
                    for row in 0..r_vu.nrows() {
                        let a = r_vu[[row, col]];
                        s2 += a * a;
                    }
                    diag[v_start + col] += s2;
                }
            }
            None => {
                let d_v = self.stalk_dims[v];
                for col in 0..d_v {
                    diag[v_start + col] += 1.0;
                }
            }
        }
    }
    diag *= self.weight;
    diag
}
/// Hessian-vector product `weight * δᵀ δ v`. The Hessian is constant, so `s` is only
/// length-checked.
///
/// # Panics
/// Panics if `s` or `v` does not have length `self.total_dim()`.
pub fn hvp(&self, s: ArrayView1<'_, f64>, v: ArrayView1<'_, f64>) -> Array1<f64> {
    let n = self.total_dim();
    assert_eq!(s.len(), n, "stacked stalk vector has wrong length",);
    assert_eq!(v.len(), n, "hvp direction has wrong length");
    let weight = self.weight;
    self.laplacian_apply(v).mapv_into(|x| x * weight)
}
/// Materialises the unweighted Laplacian `δᵀ δ` as a dense `n × n` matrix by applying it
/// to each standard basis vector in turn (column `j` is `L e_j`).
fn dense_laplacian(&self) -> Array2<f64> {
    let n = self.total_dim();
    let mut laplacian = Array2::<f64>::zeros((n, n));
    let mut basis = Array1::<f64>::zeros(n);
    for j in 0..n {
        basis[j] = 1.0;
        let image = self.laplacian_apply(basis.view());
        laplacian.column_mut(j).assign(&image);
        basis[j] = 0.0;
    }
    laplacian
}
/// Counts eigenvalues of the unweighted Laplacian `δᵀ δ` strictly below `tol`, a
/// numerical estimate of `dim ker δ`.
///
/// Problems with total dimension up to `DENSE_EIGH_DIM_THRESHOLD` use a dense symmetric
/// eigendecomposition; larger ones are delegated to `harmonic_modes_lanczos`.
///
/// # Panics
/// Panics if `tol` is negative or non-finite, or if the eigensolver fails.
pub fn harmonic_modes(&self, tol: f64) -> usize {
    assert!(
        tol.is_finite() && tol >= 0.0,
        "harmonic_modes requires finite non-negative tol, got {tol}",
    );
    let n = self.total_dim();
    if n == 0 {
        return 0;
    }
    if n > DENSE_EIGH_DIM_THRESHOLD {
        return self.harmonic_modes_lanczos(tol);
    }
    let (evals, _) = self
        .dense_laplacian()
        .eigh(Side::Lower)
        .unwrap_or_else(|err| {
            panic!("SheafConsistencyPenalty::harmonic_modes faer eigh failed: {err:?}")
        });
    evals.iter().filter(|&&ev| ev < tol).count()
}
fn harmonic_modes_lanczos(&self, tol: f64) -> usize {
let n = self.total_dim();
let k = n.min(64).max(1);
let mut q_prev = Array1::<f64>::zeros(n);
let mut q_curr = Array1::<f64>::zeros(n);
for i in 0..n {
let mut z = (i as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15);
z ^= z >> 30;
z = z.wrapping_mul(0xBF58_476D_1CE4_E5B9);
z ^= z >> 27;
z = z.wrapping_mul(0x94D0_49BB_1331_11EB);
z ^= z >> 31;
q_curr[i] = (z as f64 / u64::MAX as f64) - 0.5;
}
let nrm = (q_curr.iter().map(|x| x * x).sum::<f64>())
.sqrt()
.max(1e-300);
q_curr.mapv_inplace(|x| x / nrm);
let mut alphas = Vec::with_capacity(k);
let mut betas = Vec::with_capacity(k);
let mut beta_prev: f64 = 0.0;
for _ in 0..k {
let mut w = self.laplacian_apply(q_curr.view());
let alpha = q_curr.iter().zip(w.iter()).map(|(a, b)| a * b).sum::<f64>();
w.scaled_add(-alpha, &q_curr);
if beta_prev != 0.0 {
w.scaled_add(-beta_prev, &q_prev);
}
let proj_curr = q_curr.iter().zip(w.iter()).map(|(a, b)| a * b).sum::<f64>();
w.scaled_add(-proj_curr, &q_curr);
let proj_prev = q_prev.iter().zip(w.iter()).map(|(a, b)| a * b).sum::<f64>();
w.scaled_add(-proj_prev, &q_prev);
let beta = (w.iter().map(|x| x * x).sum::<f64>()).sqrt();
alphas.push(alpha);
betas.push(beta);
if beta < 1e-12 {
break;
}
q_prev = q_curr.clone();
q_curr = w.mapv(|x| x / beta);
beta_prev = beta;
}
let m = alphas.len();
let mut t = Array2::<f64>::zeros((m, m));
for i in 0..m {
t[[i, i]] = alphas[i];
if i + 1 < m {
t[[i, i + 1]] = betas[i];
t[[i + 1, i]] = betas[i];
}
}
match t.eigh(Side::Lower) {
Ok((evals, _)) => evals.iter().filter(|&&e| e < tol).count(),
Err(err) => {
panic!("SheafConsistencyPenalty::harmonic_modes Lanczos eigh failed: {err:?}")
}
}
}
}
// `AnalyticPenalty` adapter. The sheaf penalty has no hyper-parameters
// (`rho_count() == 0`), so `rho` is only checked for finiteness (and, in `grad_rho`, for
// being empty) before delegating to the inherent methods.
impl AnalyticPenalty for SheafConsistencyPenalty {
    fn tier(&self) -> PenaltyTier {
        PenaltyTier::Psi
    }
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
        assert_sheaf_rho_finite(rho);
        SheafConsistencyPenalty::value(self, target)
    }
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        assert_sheaf_rho_finite(rho);
        SheafConsistencyPenalty::gradient(self, target)
    }
    fn hessian_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        assert_sheaf_rho_finite(rho);
        Some(SheafConsistencyPenalty::hessian_diag(self, target))
    }
    fn hvp(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_sheaf_rho_finite(rho);
        SheafConsistencyPenalty::hvp(self, target, v)
    }
    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let n = self.total_dim();
        assert_eq!(
            rho.len(),
            0,
            "SheafConsistencyPenalty: rho_count is 0 but rho has length {}",
            rho.len(),
        );
        assert_eq!(
            target.len(),
            n,
            "SheafConsistencyPenalty: target length {} != total stalk dim {}",
            target.len(),
            n,
        );
        // No hyper-parameters, so the rho-gradient is the empty vector.
        Array1::<f64>::zeros(0)
    }
    fn rho_count(&self) -> usize {
        0
    }
    fn name(&self) -> &str {
        "SheafConsistencyPenalty"
    }
}
/// Panics unless every entry of `rho` is finite; the panic is reported at the caller.
#[track_caller]
fn assert_sheaf_rho_finite(rho: ArrayView1<'_, f64>) {
    assert!(
        rho.iter().all(|x| x.is_finite()),
        "SheafConsistencyPenalty: rho must be finite (got {rho:?})",
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Dense `d × d` identity matrix.
    fn identity(d: usize) -> Array2<f64> {
        Array2::from_shape_fn((d, d), |(i, j)| if i == j { 1.0 } else { 0.0 })
    }

    #[test]
    fn single_edge_identity_restriction_value() {
        // Two 3-d stalks glued by identities: value = 0.5 * ||s_0 - s_1||^2.
        let pen = SheafConsistencyPenalty::new(
            vec![(0, 1)],
            vec![EdgeRestriction::paired(identity(3), identity(3))],
            1.0,
            vec![3, 3],
        )
        .expect("build");
        let s = array![1.0_f64, 0.0, 0.0, 0.0, 1.0, 0.0];
        assert_abs_diff_eq!(pen.value(s.view()), 1.0, epsilon = 1e-12);
    }

    #[test]
    fn gradient_matches_finite_difference_k2_random() {
        let r_uv = array![[0.7_f64, -0.1, 0.3], [0.2, 0.9, -0.4]];
        let r_vu = array![[1.0_f64, 0.5], [-0.3, 0.8]];
        let pen = SheafConsistencyPenalty::new(
            vec![(0, 1)],
            vec![EdgeRestriction::paired(r_uv, r_vu)],
            0.3,
            vec![3, 2],
        )
        .expect("build");
        let s = array![0.4_f64, -1.1, 0.2, 0.6, -0.7];
        let g = pen.gradient(s.view());
        let eps = 1e-7;
        // Penalty value with coordinate `i` shifted by `delta`.
        let shifted = |i: usize, delta: f64| {
            let mut t = s.clone();
            t[i] += delta;
            pen.value(t.view())
        };
        for (i, &gi) in g.iter().enumerate() {
            let fd = (shifted(i, eps) - shifted(i, -eps)) / (2.0 * eps);
            assert_abs_diff_eq!(gi, fd, epsilon = 1e-6);
        }
    }

    #[test]
    fn hvp_matches_reconstructed_laplacian_chain_k3() {
        let restrictions = vec![
            EdgeRestriction::paired(
                array![[0.9_f64, 0.1], [-0.2, 0.7]],
                array![[1.0_f64, 0.0], [0.0, 1.0]],
            ),
            EdgeRestriction::paired(
                array![[0.5_f64, -0.3], [0.4, 0.8]],
                array![[0.6_f64, 0.0], [0.1, 1.1]],
            ),
        ];
        let pen = SheafConsistencyPenalty::new(
            vec![(0, 1), (1, 2)],
            restrictions,
            1.0,
            vec![2, 2, 2],
        )
        .expect("build");
        let l_dense = pen.dense_laplacian();
        let s = array![0.1_f64, -0.2, 0.3, 0.4, -0.5, 0.6];
        let v = array![0.7_f64, 0.2, -0.1, 0.5, 0.3, -0.4];
        let hv = pen.hvp(s.view(), v.view());
        for i in 0..pen.total_dim() {
            let expected = l_dense
                .row(i)
                .iter()
                .zip(v.iter())
                .fold(0.0, |acc, (a, b)| acc + a * b);
            assert_abs_diff_eq!(hv[i], expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn harmonic_modes_two_components_identity_restrictions() {
        // No edges: every coordinate is harmonic.
        let isolated =
            SheafConsistencyPenalty::new(vec![], vec![], 1.0, vec![3, 3]).expect("build");
        assert_eq!(isolated.harmonic_modes(1e-10), 6);
        // Two identity-glued pairs of 2-d stalks: one 2-d kernel per pair.
        let pairs = SheafConsistencyPenalty::new(
            vec![(0, 1), (2, 3)],
            vec![
                EdgeRestriction::paired(identity(2), identity(2)),
                EdgeRestriction::paired(identity(2), identity(2)),
            ],
            1.0,
            vec![2, 2, 2, 2],
        )
        .expect("build");
        assert_eq!(pairs.harmonic_modes(1e-10), 4);
    }

    #[test]
    fn value_psd_and_zero_iff_kernel() {
        let r01_uv = array![[0.9_f64, 0.1], [-0.2, 0.7]];
        let r01_vu = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let pen = SheafConsistencyPenalty::new(
            vec![(0, 1)],
            vec![EdgeRestriction::paired(r01_uv.clone(), r01_vu.clone())],
            0.5,
            vec![2, 2],
        )
        .expect("build");
        let samples = [
            array![0.0_f64, 0.0, 0.0, 0.0],
            array![1.0_f64, 2.0, -0.5, 0.3],
            array![-1.3_f64, 0.7, 0.2, -0.9],
        ];
        for sample in &samples {
            let v = pen.value(sample.view());
            assert!(v >= 0.0, "value must be non-negative, got {v}");
        }
        assert_abs_diff_eq!(
            pen.value(Array1::<f64>::zeros(4).view()),
            0.0,
            epsilon = 1e-15
        );
        // With R_vu = I, choosing s_1 = R_uv s_0 makes the edge residual vanish.
        let s0 = array![0.3_f64, -1.1];
        let s1 = r01_uv.dot(&s0);
        let s = array![s0[0], s0[1], s1[0], s1[1]];
        assert_abs_diff_eq!(pen.value(s.view()), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn hessian_diag_matches_diag_of_dense_laplacian() {
        let r_uv = array![[0.7_f64, -0.1, 0.3], [0.2, 0.9, -0.4]];
        let r_vu = array![[1.0_f64, 0.5], [-0.3, 0.8]];
        let pen = SheafConsistencyPenalty::new(
            vec![(0, 1)],
            vec![EdgeRestriction::paired(r_uv, r_vu)],
            0.3,
            vec![3, 2],
        )
        .expect("build");
        let n = pen.total_dim();
        let diag = pen.hessian_diag(Array1::<f64>::zeros(n).view());
        let l = pen.dense_laplacian();
        for (i, &d) in diag.iter().enumerate() {
            assert_abs_diff_eq!(d, 0.3 * l[[i, i]], epsilon = 1e-12);
        }
    }

    #[test]
    fn single_restriction_edge_form() {
        // Identity on the v side: residual is R s_u - s_v.
        let r = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let pen = SheafConsistencyPenalty::new(
            vec![(0, 1)],
            vec![EdgeRestriction::single(r.clone())],
            2.0,
            vec![2, 2],
        )
        .expect("build");
        let consistent = array![1.0_f64, 0.0, 1.0, 3.0];
        assert_abs_diff_eq!(pen.value(consistent.view()), 0.0, epsilon = 1e-12);
        let inconsistent = array![1.0_f64, 0.0, 0.0, 0.0];
        assert_abs_diff_eq!(pen.value(inconsistent.view()), 10.0, epsilon = 1e-12);
    }
}