use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for this primitive: used as the IR region generator name, as the
/// op id passed to `crate::invalid_output_program` for rejected shapes, and as the key of
/// the `inventory-registry` harness entry.
pub const OP_ID: &str = "vyre-primitives::graph::sheaf_diffusion_step";
/// Builds the IR [`Program`] for one damped sheaf-diffusion step over `n_nodes * d` cells.
///
/// Each invocation `i < n_nodes * d` computes, entirely in `u32` arithmetic,
///
/// ```text
/// damped_r       = (damping_scaled[0] * restriction_diag[i]) >> 16
/// stalks_next[i] = stalks[i] - ((damped_r * stalks[i]) >> 16)
/// ```
///
/// This is the same `s - damping * r * s` update as [`sheaf_diffusion_step_cpu`]. The
/// `>> 16` after each multiply treats the operands as fixed-point values with 16
/// fractional bits.
///
/// Buffer bindings (all `DataType::U32`):
/// - `0` `stalks`: read-only, `n_nodes * d` elements
/// - `1` `restriction_diag`: read-only, `n_nodes * d` elements
/// - `2` `damping_scaled`: read-only, 1 element (a scalar shared by all cells)
/// - `3` `stalks_next`: read-write, `n_nodes * d` elements
///
/// The program uses a `[256, 1, 1]` workgroup and wraps its body in a region whose
/// generator is [`OP_ID`].
///
/// No kernel is built if `n_nodes == 0`, `d == 0`, or `n_nodes * d` does not fit in a
/// `u32`. In those cases `crate::invalid_output_program` is returned for `stalks_next`
/// instead, carrying a `Fix:` diagnostic.
///
/// NOTE(review): `damping * r` and `damped_r * s` are 32-bit products. Fixed-point
/// operands whose product exceeds `u32::MAX` are therefore not representable; this
/// includes both operands >= 1.0 (>= 65536 with 16 fractional bits). The subtraction
/// is also unsigned, so a `delta` larger than `stalks[i]` is not representable either.
/// Confirm the intended input ranges, or widen the arithmetic if larger values must be
/// supported.
#[must_use]
pub fn sheaf_diffusion_step(
    stalks: &str,
    restriction_diag: &str,
    damping_scaled: &str,
    stalks_next: &str,
    n_nodes: u32,
    d: u32,
) -> Program {
    if n_nodes == 0 {
        return crate::invalid_output_program(
            OP_ID,
            stalks_next,
            DataType::U32,
            "Fix: sheaf_diffusion_step requires n_nodes > 0, got 0.".to_string(),
        );
    }
    if d == 0 {
        return crate::invalid_output_program(
            OP_ID,
            stalks_next,
            DataType::U32,
            format!("Fix: sheaf_diffusion_step requires d > 0, got {d}."),
        );
    }
    // `n_nodes * d` sizes every per-cell buffer and the bounds guard. An unchecked
    // multiply would panic in debug builds and silently wrap in release builds, which
    // would declare undersized buffers. Reject it like the other invalid shapes.
    let cells = match n_nodes.checked_mul(d) {
        Some(cells) => cells,
        None => {
            return crate::invalid_output_program(
                OP_ID,
                stalks_next,
                DataType::U32,
                format!(
                    "Fix: sheaf_diffusion_step requires n_nodes * d <= {}, got {n_nodes} * {d}.",
                    u32::MAX
                ),
            );
        }
    };
    let t = Expr::InvocationId { axis: 0 };
    let s = Expr::load(stalks, t.clone());
    let r = Expr::load(restriction_diag, t.clone());
    // The damping factor is a single scalar read from element 0 for every cell.
    let d_v = Expr::load(damping_scaled, Expr::u32(0));
    // Each product of two 16-fractional-bit values is shifted back down by 16 bits.
    let damped_r = Expr::shr(Expr::mul(d_v, r), Expr::u32(16));
    let delta = Expr::shr(Expr::mul(damped_r, s.clone()), Expr::u32(16));
    let value = Expr::sub(s, delta);
    // Invocations past the last cell must not write. Presumably these come from
    // rounding the dispatch up to the 256-wide workgroup.
    let body = vec![Node::if_then(
        Expr::lt(t.clone(), Expr::u32(cells)),
        vec![Node::store(stalks_next, t, value)],
    )];
    Program::wrapped(
        vec![
            BufferDecl::storage(stalks, 0, BufferAccess::ReadOnly, DataType::U32).with_count(cells),
            BufferDecl::storage(restriction_diag, 1, BufferAccess::ReadOnly, DataType::U32)
                .with_count(cells),
            BufferDecl::storage(damping_scaled, 2, BufferAccess::ReadOnly, DataType::U32)
                .with_count(1),
            BufferDecl::storage(stalks_next, 3, BufferAccess::ReadWrite, DataType::U32)
                .with_count(cells),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// CPU reference for one diffusion step, returning a freshly allocated vector.
///
/// Element `i` of the result is `stalks[i] - damping * restriction_diag[i] * stalks[i]`.
/// The result is only as long as the shorter of the two input slices. This is a thin
/// allocating wrapper over [`sheaf_diffusion_step_cpu_into`].
#[must_use]
pub fn sheaf_diffusion_step_cpu(
    stalks: &[f64],
    restriction_diag: &[f64],
    damping: f64,
) -> Vec<f64> {
    let mut next_stalks: Vec<f64> = Vec::with_capacity(stalks.len());
    sheaf_diffusion_step_cpu_into(stalks, restriction_diag, damping, &mut next_stalks);
    next_stalks
}
/// CPU reference for one diffusion step, written into a caller-provided buffer.
///
/// `out` is cleared first, so any previous contents are discarded but its allocation is
/// reused. It then receives `s - damping * r * s` for each `(s, r)` pair drawn from
/// `stalks` and `restriction_diag`. Only complete pairs are processed: trailing
/// elements of the longer slice are ignored.
pub fn sheaf_diffusion_step_cpu_into(
    stalks: &[f64],
    restriction_diag: &[f64],
    damping: f64,
    out: &mut Vec<f64>,
) {
    let complete_pairs = stalks.len().min(restriction_diag.len());
    out.clear();
    out.reserve(complete_pairs);
    // `zip` stops at the shorter slice, which enforces the complete-pairs rule.
    for (&stalk, &restriction) in stalks.iter().zip(restriction_diag) {
        // Evaluated as `(damping * restriction) * stalk`, exactly like `damping * r * s`,
        // so results are bit-identical to the reference formula.
        let decay = damping * restriction * stalk;
        out.push(stalk - decay);
    }
}
#[cfg(feature = "inventory-registry")]
inventory::submit! {
    crate::harness::OpEntry::new(
        OP_ID,
        || sheaf_diffusion_step("s", "rd", "dmp", "sn", 2, 2),
        // Presumably the input buffers, listed in binding order:
        // stalks, restriction_diag, damping_scaled, stalks_next.
        // 65536 is 1.0 if 16 fractional bits are assumed (matching the kernel's `>> 16`).
        // A zero restriction makes `delta` zero, so every stalk passes through unchanged.
        Some(|| {
            let le_words = |words: &[u32]| -> Vec<u8> {
                words.iter().flat_map(|word| word.to_le_bytes()).collect()
            };
            vec![vec![
                le_words(&[65536, 65536, 65536, 65536]),
                le_words(&[0, 0, 0, 0]),
                le_words(&[65536]),
                le_words(&[0, 0, 0, 0]),
            ]]
        }),
        // Presumably the expected contents of `stalks_next` after the step.
        Some(|| {
            let le_words = |words: &[u32]| -> Vec<u8> {
                words.iter().flat_map(|word| word.to_le_bytes()).collect()
            };
            vec![vec![le_words(&[65536, 65536, 65536, 65536])]]
        }),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison with a tolerance that scales with the operands' magnitude.
    fn approx_eq(a: f64, b: f64) -> bool {
        let scale = 1.0 + a.abs() + b.abs();
        (a - b).abs() < 1e-10 * scale
    }

    #[test]
    fn cpu_zero_damping_holds_stalks() {
        let stalks = [1.0, 2.0, 3.0];
        let out = sheaf_diffusion_step_cpu(&stalks, &[0.5; 3], 0.0);
        assert_eq!(out, stalks.to_vec());
    }

    #[test]
    fn cpu_unit_restriction_full_damp_zeros() {
        let out = sheaf_diffusion_step_cpu(&[10.0, 20.0], &[1.0, 1.0], 1.0);
        assert_eq!(out.len(), 2);
        assert!(out.iter().all(|&value| approx_eq(value, 0.0)));
    }

    #[test]
    fn cpu_partial_damping_decreases_magnitude() {
        let out = sheaf_diffusion_step_cpu(&[10.0], &[0.5], 0.5);
        assert!(approx_eq(out[0], 7.5));
    }

    #[test]
    fn cpu_mismatched_inputs_truncate_to_complete_pairs() {
        let stalks = [10.0, 4.0];
        let restriction = [0.5];
        assert_eq!(sheaf_diffusion_step_cpu(&stalks, &restriction, 1.0), vec![5.0]);
    }

    #[test]
    fn cpu_iterations_decay_to_zero_under_full_restriction() {
        let restriction = [1.0];
        let settled = (0..100).fold(vec![1.0], |stalks, _| {
            sheaf_diffusion_step_cpu(&stalks, &restriction, 0.1)
        });
        assert!(settled[0].abs() < 1e-3);
    }

    #[test]
    fn ir_program_buffer_layout() {
        let program = sheaf_diffusion_step("s", "rd", "dmp", "sn", 4, 8);
        assert_eq!(program.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = program.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["s", "rd", "dmp", "sn"]);
        let counts: Vec<_> = program.buffers.iter().map(|b| b.count()).collect();
        assert_eq!(counts, vec![32, 32, 1, 32]);
    }

    #[test]
    fn zero_n_nodes_traps() {
        assert!(sheaf_diffusion_step("s", "rd", "dmp", "sn", 0, 1).stats().trap());
    }

    #[test]
    fn zero_d_traps() {
        assert!(sheaf_diffusion_step("s", "rd", "dmp", "sn", 1, 0).stats().trap());
    }
}