use vyre_primitives::graph::sheaf::{sheaf_diffusion_step_cpu, sheaf_diffusion_step_cpu_into};
/// Runs a single sheaf diffusion step over `stalks` and returns the result as an owned vector.
///
/// Before delegating, the call is recorded by passing `sheaf_heterophilic_dispatch_calls`
/// to `crate::observability::bump` (presumably a call counter; confirm in that module).
/// The numeric work is done entirely by `sheaf_diffusion_step_cpu`; how `restriction_diag`
/// and `damping` shape the update is defined by that primitive, not by this wrapper.
#[must_use]
pub fn diffuse_dispatch_stalks(stalks: &[f64], restriction_diag: &[f64], damping: f64) -> Vec<f64> {
use crate::observability::{bump, sheaf_heterophilic_dispatch_calls};
// Record the dispatch before doing the work.
bump(&sheaf_heterophilic_dispatch_calls);
sheaf_diffusion_step_cpu(stalks, restriction_diag, damping)
}
/// Runs a single sheaf diffusion step over `stalks`, handing the caller-owned `out` buffer
/// to the primitive instead of returning a new vector.
///
/// Before delegating, the call is recorded by passing `sheaf_heterophilic_dispatch_calls`
/// to `crate::observability::bump` (presumably a call counter; confirm in that module).
/// The numeric work is done entirely by `sheaf_diffusion_step_cpu_into`.
///
/// NOTE(review): `out` is presumably overwritten with the stepped stalks so its allocation
/// can be reused across calls; confirm against `sheaf_diffusion_step_cpu_into`.
pub fn diffuse_dispatch_stalks_into(
stalks: &[f64],
restriction_diag: &[f64],
damping: f64,
out: &mut Vec<f64>,
) {
use crate::observability::{bump, sheaf_heterophilic_dispatch_calls};
// Record the dispatch before doing the work.
bump(&sheaf_heterophilic_dispatch_calls);
sheaf_diffusion_step_cpu_into(stalks, restriction_diag, damping, out);
}
/// Allocating convenience wrapper around [`diffuse_to_equilibrium_into`].
///
/// Creates two buffers sized for `initial_stalks.len()` elements, lets
/// [`diffuse_to_equilibrium_into`] iterate with them, and returns the final stalk
/// values together with the iteration count that function reports.
///
/// # Returns
///
/// `(stalks, iterations)`: the state held in the output buffer when iteration stopped
/// and the count returned by [`diffuse_to_equilibrium_into`].
#[must_use]
pub fn diffuse_to_equilibrium(
    initial_stalks: &[f64],
    restriction_diag: &[f64],
    damping: f64,
    tol: f64,
    max_iters: u32,
) -> (Vec<f64>, u32) {
    let stalk_count = initial_stalks.len();
    // Both buffers are sized up front so the iteration loop does not need to grow them.
    let (mut settled, mut scratch) = (
        Vec::with_capacity(stalk_count),
        Vec::with_capacity(stalk_count),
    );
    let steps_taken = diffuse_to_equilibrium_into(
        initial_stalks,
        restriction_diag,
        damping,
        tol,
        max_iters,
        &mut settled,
        &mut scratch,
    );
    (settled, steps_taken)
}
/// Repeats diffusion steps until the largest per-stalk change falls below `tol`,
/// working entirely in caller-provided buffers.
///
/// `out` is cleared and seeded with `initial_stalks`. Each iteration asks
/// [`diffuse_dispatch_stalks_into`] to produce the next state in `scratch`, measures the
/// largest absolute element-wise difference between the two buffers, and then swaps them,
/// so `out` always holds the most recent state when the function returns. Elements are
/// paired with `zip`, so only the common prefix of the two buffers is compared.
///
/// # Returns
///
/// * `k` (1-based) when the `k`-th step changed every compared stalk by less than `tol`.
/// * `max_iters` when no step met the tolerance. A step that converges exactly on the
///   last allowed iteration also yields `max_iters`, so that value alone does not
///   distinguish the two outcomes.
/// * `0` when `max_iters == 0`; no step is run and `out` equals `initial_stalks`.
///
/// # Non-finite values
///
/// A NaN difference (for example `inf - inf` once the state has blown up) is never
/// counted as convergence. `f64::max` returns its non-NaN operand, so a plain
/// `fold(0.0, f64::max)` would reduce an all-NaN delta to `0.0` and report a diverged
/// state as converged; the reduction below keeps NaN sticky instead.
pub fn diffuse_to_equilibrium_into(
    initial_stalks: &[f64],
    restriction_diag: &[f64],
    damping: f64,
    tol: f64,
    max_iters: u32,
    out: &mut Vec<f64>,
    scratch: &mut Vec<f64>,
) -> u32 {
    out.clear();
    out.extend_from_slice(initial_stalks);
    for iter in 0..max_iters {
        diffuse_dispatch_stalks_into(out, restriction_diag, damping, scratch);
        let max_change = scratch
            .iter()
            .zip(out.iter())
            .map(|(next, prev)| (next - prev).abs())
            .fold(0.0_f64, |worst, delta| {
                // Once any delta is NaN the whole step is NaN: `f64::max` alone would
                // drop it and make the step look perfectly still.
                if worst.is_nan() || delta.is_nan() {
                    f64::NAN
                } else {
                    worst.max(delta)
                }
            });
        // Swap before the convergence check so `out` holds the newest state on every exit.
        std::mem::swap(out, scratch);
        // Any comparison with NaN is false, so a NaN step can never satisfy the tolerance.
        if max_change < tol {
            return iter + 1;
        }
    }
    max_iters
}
/// Allocating convenience wrapper around [`flag_fusion_incompatible_into`].
///
/// Starts from an empty vector, lets [`flag_fusion_incompatible_into`] fill it with one
/// `0`/`1` flag per compared stalk pair, and returns it.
#[must_use]
pub fn flag_fusion_incompatible(
    initial_stalks: &[f64],
    diffused_stalks: &[f64],
    divergence_threshold: f64,
) -> Vec<u32> {
    let mut flags = Vec::new();
    flag_fusion_incompatible_into(
        initial_stalks,
        diffused_stalks,
        divergence_threshold,
        &mut flags,
    );
    flags
}
/// Writes one divergence flag per stalk pair into `out`, reusing its allocation.
///
/// `out` is cleared first, then capacity for `initial_stalks.len()` flags is reserved.
/// The two slices are walked in lockstep; a pair yields `1` when
/// `|initial - diffused|` is strictly greater than `divergence_threshold` and `0`
/// otherwise. Pairing stops at the end of the shorter slice. A NaN difference compares
/// false against the threshold and therefore yields `0`.
pub fn flag_fusion_incompatible_into(
    initial_stalks: &[f64],
    diffused_stalks: &[f64],
    divergence_threshold: f64,
    out: &mut Vec<u32>,
) {
    out.clear();
    out.reserve(initial_stalks.len());
    let flags = initial_stalks
        .iter()
        .zip(diffused_stalks)
        .map(|(before, after)| u32::from((before - after).abs() > divergence_threshold));
    out.extend(flags);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Closeness check with a tolerance that scales with the magnitude of both operands.
    fn approx_eq(a: f64, b: f64) -> bool {
        let scale = 1.0 + a.abs() + b.abs();
        (a - b).abs() < 1e-9 * scale
    }

    /// With zero damping a single step is expected to leave every stalk unchanged.
    #[test]
    fn zero_damping_holds_initial() {
        let stalks = vec![1.0, 2.0, 3.0];
        let restriction = vec![0.5, 0.5, 0.5];
        let stepped = diffuse_dispatch_stalks(&stalks, &restriction, 0.0);
        let unchanged = stalks
            .iter()
            .zip(&stepped)
            .all(|(&before, &after)| approx_eq(before, after));
        assert!(unchanged);
    }

    /// Strong damping is expected to shrink every stalk and converge before the cap.
    #[test]
    fn high_damping_drives_to_equilibrium() {
        let stalks = vec![1.0; 3];
        let restriction = vec![1.0; 3];
        let (settled, steps) = diffuse_to_equilibrium(&stalks, &restriction, 0.9, 1e-6, 100);
        assert!(settled.iter().all(|&value| value.abs() < 1.0));
        assert!(steps < 100);
    }

    /// A zero threshold flags exactly the stalks whose value moved.
    #[test]
    fn flag_fusion_incompatible_threshold_zero_flags_all_changes() {
        let before = vec![1.0, 2.0, 3.0];
        let after = vec![0.5, 2.0, 2.5];
        let verdicts = flag_fusion_incompatible(&before, &after, 0.0);
        assert_eq!(verdicts, vec![1, 0, 1]);
    }

    /// A threshold far above every difference flags nothing.
    #[test]
    fn high_threshold_flags_nothing() {
        let before = vec![1.0, 2.0];
        let after = vec![1.5, 2.5];
        let verdicts = flag_fusion_incompatible(&before, &after, 100.0);
        assert_eq!(verdicts, vec![0, 0]);
    }

    /// The `_into` variant must fill a pre-sized buffer without reallocating it.
    #[test]
    fn flag_fusion_incompatible_into_reuses_buffer() {
        let before = vec![1.0, 2.0, 3.0];
        let after = vec![0.5, 2.0, 2.5];
        let mut verdicts = Vec::<u32>::with_capacity(8);
        let allocation = verdicts.as_ptr();
        flag_fusion_incompatible_into(&before, &after, 0.0, &mut verdicts);
        assert_eq!(verdicts, vec![1, 0, 1]);
        assert_eq!(verdicts.as_ptr(), allocation);
    }

    /// With no iterations allowed the initial stalks come back untouched.
    #[test]
    fn equilibrium_with_zero_max_iters_returns_initial() {
        let stalks = vec![5.0, 10.0];
        let restriction = vec![1.0, 1.0];
        let (settled, steps) = diffuse_to_equilibrium(&stalks, &restriction, 0.5, 1e-6, 0);
        assert_eq!(settled, stalks);
        assert_eq!(steps, 0);
    }
}