use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Generator identifier stamped on the `Node::Region` that [`level_wave_program`] emits.
pub const OP_ID: &str = "vyre-primitives::graph::level_wave";
/// Builds a "level wave" program that runs `step_body` once per lane, at the loop
/// iteration matching that lane's depth.
///
/// The emitted body is a single loop over the variable `__lw_depth__`, constructed with
/// the bounds `0` and `max_depth`. On each iteration an invocation executes `step_body`
/// only when both of the following hold:
///
/// * its axis-0 invocation id is less than `lane_count`, and
/// * the value loaded from `depth_buf` at that invocation id equals `__lw_depth__`.
///
/// A `SeqCst` barrier node follows the guarded step on every iteration. It sits outside
/// the guard, so it is emitted for every invocation regardless of whether the step ran.
///
/// The loop is wrapped in a `Node::Region` whose generator is [`OP_ID`].
///
/// # Parameters
///
/// * `step_body` - nodes executed by a lane when the loop variable equals its depth.
/// * `depth_buf` - name of the buffer holding one `u32` depth per lane; it is declared
///   here as a read-only storage buffer with `lane_count` elements.
/// * `max_depth` - upper bound handed to the depth loop (presumably exclusive, matching
///   the `0..max_depth` range used by [`cpu_ref`] - confirm against `Node::loop_for`).
/// * `lane_count` - number of lanes; invocations at or beyond this index skip the step.
///
/// Only `depth_buf` appears in the buffer list passed to `Program::wrapped`; buffers that
/// `step_body` touches are not declared by this function.
#[must_use]
pub fn level_wave_program(
    step_body: Vec<Node>,
    depth_buf: &str,
    max_depth: u32,
    lane_count: u32,
) -> Program {
    let lane = Expr::InvocationId { axis: 0 };
    let depth_for_lane = Expr::load(depth_buf, lane.clone());

    // NOTE(review): the bounds check and the depth load share one `Expr::and`. If `and`
    // does not short-circuit, lanes >= `lane_count` still evaluate the load - confirm the
    // backend tolerates that or split this into nested conditionals.
    //
    // `lane`, `depth_for_lane` and `step_body` are each at their last use here, so they
    // are moved into the IR rather than cloned.
    let lane_is_active = Expr::and(
        Expr::lt(lane, Expr::u32(lane_count)),
        Expr::eq(depth_for_lane, Expr::var("__lw_depth__")),
    );

    let body = vec![Node::loop_for(
        "__lw_depth__",
        Expr::u32(0),
        Expr::u32(max_depth),
        vec![
            Node::if_then(lane_is_active, step_body),
            // Emitted unconditionally after the guarded step on every iteration.
            Node::Barrier {
                ordering: vyre_foundation::MemoryOrdering::SeqCst,
            },
        ],
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(depth_buf, 0, BufferAccess::ReadOnly, DataType::U32)
                .with_count(lane_count),
        ],
        // Presumably the workgroup size (256 invocations along axis 0) - confirm against
        // `Program::wrapped`.
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// CPU reference for the level-wave schedule.
///
/// Walks depths `0..max_depth` in ascending order. For each depth, every lane whose entry
/// in `depths` equals that depth is reported through `step_for_lane(lane_index, depth)`,
/// in ascending lane order. Lanes whose depth is `>= max_depth` are never reported.
///
/// Lane indices are narrowed to `u32` with an `as` cast, so indices above `u32::MAX`
/// would wrap.
pub fn cpu_ref<F>(depths: &[u32], max_depth: u32, mut step_for_lane: F)
where
    F: FnMut(u32, u32),
{
    for wave in 0..max_depth {
        let lanes_in_wave = depths
            .iter()
            .enumerate()
            .filter(|&(_, &lane_depth)| lane_depth == wave)
            .map(|(lane_idx, _)| lane_idx as u32);

        for lane in lanes_in_wave {
            step_for_lane(lane, wave);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every lane is reported exactly once, at its own depth, and the reported depths
    /// never decrease from one visit to the next.
    #[test]
    fn cpu_ref_visits_each_lane_at_its_depth() {
        let depths = [0u32, 1, 2, 1, 0];
        let mut seen: Vec<(u32, u32)> = Vec::with_capacity(depths.len());
        cpu_ref(&depths, 3, |lane, depth| seen.push((lane, depth)));

        assert_eq!(seen.len(), depths.len());
        for &(lane, depth) in &seen {
            assert_eq!(depth, depths[lane as usize]);
        }
        // Adjacent visits must be ordered by non-decreasing depth.
        assert!(seen.windows(2).all(|pair| pair[1].1 >= pair[0].1));
    }

    /// The built program must declare the depth buffer it reads from.
    #[test]
    fn program_shape_matches_contract() {
        let step_body = vec![Node::store("out", Expr::u32(0), Expr::u32(1))];
        let built = level_wave_program(step_body, "depths", 8, 64);
        let declares_depths = built.buffers.iter().any(|decl| decl.name() == "depths");
        assert!(declares_depths, "depth buffer must be declared");
    }
}