use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for the Chebyshev filter: used as the generator of the emitted region
/// and passed to `crate::invalid_output_program` when the arguments are rejected.
pub const OP_ID: &str = "vyre-primitives::graph::chebyshev_filter";
/// Largest `k_steps` that [`chebyshev_filter`] accepts; larger values yield an invalid-output
/// program instead of a kernel.
pub const MAX_K: u32 = 16;
#[must_use]
pub fn chebyshev_filter(
laplacian: &str,
signal: &str,
coeffs: &str,
output: &str,
scratch: &str,
n: u32,
k_steps: u32,
) -> Program {
if n == 0 {
return crate::invalid_output_program(
OP_ID,
output,
DataType::U32,
format!("Fix: chebyshev_filter requires n > 0, got {n}."),
);
}
if k_steps > MAX_K {
return crate::invalid_output_program(
OP_ID,
output,
DataType::U32,
format!("Fix: chebyshev_filter k_steps must be <= MAX_K={MAX_K}, got {k_steps}."),
);
}
let t = Expr::InvocationId { axis: 0 };
let t_prev_at = |i: Expr| Expr::load(scratch, i);
let t_curr_at = |i: Expr| Expr::load(scratch, Expr::add(i, Expr::u32(n)));
let t_prev_store = |i: Expr, v: Expr| Node::store(scratch, i, v);
let t_curr_store = |i: Expr, v: Expr| Node::store(scratch, Expr::add(i, Expr::u32(n)), v);
let row_base = Expr::mul(t.clone(), Expr::u32(n));
let lhat_row_dot_signal = {
Node::loop_for(
"j",
Expr::u32(0),
Expr::u32(n),
vec![Node::assign(
"lapsig",
Expr::add(
Expr::var("lapsig"),
Expr::mul(
Expr::load(laplacian, Expr::add(row_base.clone(), Expr::var("j"))),
Expr::load(signal, Expr::var("j")),
),
),
)],
)
};
let mut body = vec![
Node::let_bind("acc_out", Expr::u32(0)),
Node::assign(
"acc_out",
Expr::add(
Expr::var("acc_out"),
Expr::mul(
Expr::load(coeffs, Expr::u32(0)),
Expr::load(signal, t.clone()),
),
),
),
];
if k_steps >= 1 {
body.push(Node::let_bind("lapsig", Expr::u32(0)));
body.push(lhat_row_dot_signal);
body.push(t_prev_store(t.clone(), Expr::load(signal, t.clone())));
body.push(t_curr_store(t.clone(), Expr::var("lapsig")));
body.push(Node::assign(
"acc_out",
Expr::add(
Expr::var("acc_out"),
Expr::mul(Expr::load(coeffs, Expr::u32(1)), Expr::var("lapsig")),
),
));
}
if k_steps >= 2 {
body.push(Node::loop_for(
"k",
Expr::u32(2),
Expr::add(Expr::u32(k_steps), Expr::u32(1)),
vec![
Node::let_bind("lap_curr", Expr::u32(0)),
Node::loop_for(
"j",
Expr::u32(0),
Expr::u32(n),
vec![Node::assign(
"lap_curr",
Expr::add(
Expr::var("lap_curr"),
Expr::mul(
Expr::load(laplacian, Expr::add(row_base.clone(), Expr::var("j"))),
t_curr_at(Expr::var("j")),
),
),
)],
),
Node::let_bind(
"t_next",
Expr::sub(
Expr::mul(Expr::u32(2), Expr::var("lap_curr")),
t_prev_at(t.clone()),
),
),
Node::assign(
"acc_out",
Expr::add(
Expr::var("acc_out"),
Expr::mul(Expr::load(coeffs, Expr::var("k")), Expr::var("t_next")),
),
),
Node::let_bind("old_curr", t_curr_at(t.clone())),
t_curr_store(t.clone(), Expr::var("t_next")),
t_prev_store(t.clone(), Expr::var("old_curr")),
],
));
}
let body_with_bounds = vec![Node::if_then(Expr::lt(t.clone(), Expr::u32(n)), {
let mut all = body;
all.push(Node::store(output, t, Expr::var("acc_out")));
all
})];
Program::wrapped(
vec![
BufferDecl::storage(laplacian, 0, BufferAccess::ReadOnly, DataType::U32)
.with_count(n * n),
BufferDecl::storage(signal, 1, BufferAccess::ReadOnly, DataType::U32).with_count(n),
BufferDecl::storage(coeffs, 2, BufferAccess::ReadOnly, DataType::U32)
.with_count(k_steps + 1),
BufferDecl::storage(output, 3, BufferAccess::ReadWrite, DataType::U32).with_count(n),
BufferDecl::storage(scratch, 4, BufferAccess::ReadWrite, DataType::U32)
.with_count(2 * n),
],
[256, 1, 1],
vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body_with_bounds),
}],
)
}
/// CPU reference for the Chebyshev filter that allocates its own buffers.
///
/// Creates a fresh output vector and three fresh recurrence vectors, forwards everything to
/// [`chebyshev_filter_cpu_into`], and returns the output vector. Prefer the `_into` variant
/// when buffers should be reused between calls.
///
/// * `laplacian` - row-major `n x n` operator.
/// * `signal` - input vector of length `n`.
/// * `coeffs` - polynomial coefficients, index `k` weighting order `k`.
/// * `n` - dimension of the operator and signal.
/// * `k_steps` - highest polynomial order evaluated.
#[must_use]
pub fn chebyshev_filter_cpu(
    laplacian: &[f32],
    signal: &[f32],
    coeffs: &[f32],
    n: u32,
    k_steps: u32,
) -> Vec<f32> {
    let mut filtered = Vec::new();
    // Throwaway recurrence state; only `filtered` is handed back to the caller.
    let (mut older, mut newer, mut pending) = (Vec::new(), Vec::new(), Vec::new());
    chebyshev_filter_cpu_into(
        laplacian,
        signal,
        coeffs,
        n,
        k_steps,
        &mut filtered,
        &mut older,
        &mut newer,
        &mut pending,
    );
    filtered
}
/// CPU reference for the Chebyshev filter, writing into caller-provided buffers.
///
/// Computes `out[i] = sum over k in 0..=k_steps of coeffs[k] * T_k[i]` with `T_0 = signal`,
/// `T_1 = laplacian * signal` and `T_k = 2 * laplacian * T_{k-1} - T_{k-2}`, where
/// `laplacian` is a row-major `n x n` matrix.
///
/// Entries missing from `laplacian`, `signal` or `coeffs` are treated as `0.0`; the
/// recurrence stops at the last coefficient actually present in `coeffs`.
///
/// * `out` - cleared and refilled with the `n` filtered values.
/// * `t_prev`, `t_curr`, `t_next` - scratch vectors for the recurrence. They are left
///   untouched when `k_steps == 0`; otherwise they are overwritten and swapped among
///   themselves, so their contents after the call are recurrence state, not results.
#[allow(clippy::too_many_arguments)]
pub fn chebyshev_filter_cpu_into(
    laplacian: &[f32],
    signal: &[f32],
    coeffs: &[f32],
    n: u32,
    k_steps: u32,
    out: &mut Vec<f32>,
    t_prev: &mut Vec<f32>,
    t_curr: &mut Vec<f32>,
    t_next: &mut Vec<f32>,
) {
    // Dot product of row `row` of the row-major `dim x dim` matrix with `vector`,
    // accumulated left to right from 0.0; matrix entries past the slice end count as 0.0.
    fn row_dot(matrix: &[f32], dim: usize, row: usize, vector: &[f32]) -> f32 {
        vector.iter().enumerate().fold(0.0_f32, |acc, (col, &value)| {
            acc + matrix.get(row * dim + col).copied().unwrap_or(0.0) * value
        })
    }

    let dim = n as usize;
    let sample = |idx: usize| signal.get(idx).copied().unwrap_or(0.0);

    // Order 0: out = c0 * signal.
    let c0 = coeffs.first().copied().unwrap_or(0.0);
    out.clear();
    out.reserve(dim);
    out.extend((0..dim).map(|idx| c0 * sample(idx)));
    if k_steps == 0 {
        return;
    }

    // Order 1: T_0 = signal, T_1 = L * signal.
    t_prev.clear();
    t_prev.extend((0..dim).map(sample));
    t_curr.clear();
    t_curr.extend((0..dim).map(|row| row_dot(laplacian, dim, row, t_prev.as_slice())));

    let c1 = coeffs.get(1).copied().unwrap_or(0.0);
    for (acc, &term) in out.iter_mut().zip(t_curr.iter()) {
        *acc += c1 * term;
    }

    // Orders 2..=k_steps, limited to the coefficients that exist.
    for &c_k in coeffs.iter().take(k_steps as usize + 1).skip(2) {
        t_next.clear();
        t_next.extend(
            (0..dim)
                .map(|row| 2.0 * row_dot(laplacian, dim, row, t_curr.as_slice()) - t_prev[row]),
        );
        for (acc, &term) in out.iter_mut().zip(t_next.iter()) {
            *acc += c_k * term;
        }
        // Rotate buffers: prev <- curr, curr <- next (next receives the stale prev).
        std::mem::swap(t_prev, t_curr);
        std::mem::swap(t_curr, t_next);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance factor for comparing CPU reference outputs.
    const EPS: f32 = 1e-4;

    /// True when `a` and `b` differ by less than `EPS` scaled by `1 + |a| + |b|`.
    fn approx_eq(a: f32, b: f32) -> bool {
        let scale = 1.0 + a.abs() + b.abs();
        (a - b).abs() < EPS * scale
    }

    /// Asserts that `values[i]` is approximately `expected[i]` for every index of `expected`.
    fn assert_leading_close(values: &[f32], expected: &[f32]) {
        for (idx, &want) in expected.iter().enumerate() {
            assert!(approx_eq(values[idx], want));
        }
    }

    #[test]
    fn cpu_k0_returns_scaled_signal() {
        let laplacian = [0.0_f32; 4];
        let signal = [1.0_f32, 2.0];
        let coeffs = [3.0_f32];
        let filtered = chebyshev_filter_cpu(&laplacian, &signal, &coeffs, 2, 0);
        assert_eq!(filtered, vec![3.0, 6.0]);
    }

    #[test]
    fn cpu_k1_recovers_linear_filter() {
        let laplacian = [0.5_f32, 0.0, 0.0, 0.5];
        let signal = [1.0_f32, 1.0];
        let coeffs = [0.0_f32, 1.0];
        let filtered = chebyshev_filter_cpu(&laplacian, &signal, &coeffs, 2, 1);
        assert_leading_close(&filtered, &[0.5, 0.5]);
    }

    #[test]
    fn cpu_recurrence_t2_matches_definition() {
        let laplacian = [0.5_f32, 0.0, 0.0, 0.5];
        let signal = [1.0_f32, 1.0];
        let coeffs = [0.0_f32, 0.0, 1.0];
        let filtered = chebyshev_filter_cpu(&laplacian, &signal, &coeffs, 2, 2);
        assert_leading_close(&filtered, &[-0.5, -0.5]);
    }

    #[test]
    fn cpu_recurrence_zero_signal_stays_zero() {
        let laplacian = [1.0_f32; 16];
        let signal = [0.0_f32; 4];
        let coeffs = [1.0_f32; 5];
        let filtered = chebyshev_filter_cpu(&laplacian, &signal, &coeffs, 4, 4);
        assert!(filtered.iter().all(|&value| approx_eq(value, 0.0)));
    }

    #[test]
    fn cpu_ref_into_reuses_recurrence_buffers() {
        let laplacian = [0.5_f32, 0.0, 0.0, 0.5];
        let signal = [1.0_f32, 1.0];
        let coeffs = [0.0_f32, 0.0, 1.0];
        let mut out = Vec::<f32>::with_capacity(8);
        let mut t_prev = Vec::<f32>::with_capacity(8);
        let mut t_curr = Vec::<f32>::with_capacity(8);
        let mut t_next = Vec::<f32>::with_capacity(8);
        let before = [
            out.as_ptr(),
            t_prev.as_ptr(),
            t_curr.as_ptr(),
            t_next.as_ptr(),
        ];

        chebyshev_filter_cpu_into(
            &laplacian,
            &signal,
            &coeffs,
            2,
            2,
            &mut out,
            &mut t_prev,
            &mut t_curr,
            &mut t_next,
        );
        assert_leading_close(&out, &[-0.5, -0.5]);

        // The buffers may have been swapped, but every allocation must be one of the originals.
        let after = [
            out.as_ptr(),
            t_prev.as_ptr(),
            t_curr.as_ptr(),
            t_next.as_ptr(),
        ];
        assert!(after.iter().all(|ptr| before.contains(ptr)));
    }

    #[test]
    fn cpu_short_inputs_are_zero_padded() {
        let filtered = chebyshev_filter_cpu(&[1.0], &[2.0], &[], 2, 1);
        assert_eq!(filtered, vec![0.0, 0.0]);
    }

    #[test]
    fn emitted_program_buffer_layout() {
        let program = chebyshev_filter("L", "x", "c", "y", "s", 8, 3);
        assert_eq!(program.workgroup_size, [256, 1, 1]);

        let names: Vec<&str> = program.buffers.iter().map(|buffer| buffer.name()).collect();
        assert_eq!(names, vec!["L", "x", "c", "y", "s"]);

        // Counts: n * n, n, k_steps + 1, n, 2 * n.
        assert_eq!(program.buffers[0].count(), 8 * 8);
        assert_eq!(program.buffers[1].count(), 8);
        assert_eq!(program.buffers[2].count(), 4);
        assert_eq!(program.buffers[3].count(), 8);
        assert_eq!(program.buffers[4].count(), 16);
    }

    #[test]
    fn emitted_program_zero_k_works() {
        let program = chebyshev_filter("L", "x", "c", "y", "s", 4, 0);
        assert_eq!(program.buffers[2].count(), 1);
    }

    #[test]
    fn zero_n_traps() {
        let program = chebyshev_filter("L", "x", "c", "y", "s", 0, 1);
        assert!(program.stats().trap());
    }

    #[test]
    fn k_over_max_traps() {
        let program = chebyshev_filter("L", "x", "c", "y", "s", 4, MAX_K + 1);
        assert!(program.stats().trap());
    }
}