use vyre::ir::Program;
use vyre_primitives::math::prefix_scan::{prefix_scan_with_op_id, ScanKind};
use vyre_primitives::reduce::multi_block_prefix_scan::multi_block_prefix_scan_sum_u32;
/// Operation id: used for harness registration and passed to every builder path below.
const OP_ID: &str = "vyre-libs::math::scan_prefix_sum";

/// Largest element count routed to `prefix_scan_with_op_id`; larger counts go
/// through `multi_block_prefix_scan_sum_u32`.
const SMALL_SCAN_MAX_N: u32 = 1024;

/// Builds a program computing the inclusive prefix sum of `n` `u32` elements:
/// `output[i] = input[0] + input[1] + ... + input[i]`.
///
/// Routing by `n`:
/// * `0` — returns `crate::builder::invalid_output_program` carrying an
///   actionable `Fix:` message instead of a scan.
/// * `1..=1024` — delegates to `prefix_scan_with_op_id` with
///   `ScanKind::InclusiveSum`.
/// * `> 1024` — delegates to `multi_block_prefix_scan_sum_u32`, then re-wraps
///   the result so its entry is a single region tagged with this op's id.
#[must_use]
pub fn scan_prefix_sum(input: &str, output: &str, n: u32) -> Program {
    match n {
        0 => crate::builder::invalid_output_program(
            OP_ID,
            output,
            vyre::ir::DataType::U32,
            "Fix: scan_prefix_sum requires n > 0.".to_string(),
        ),
        1..=SMALL_SCAN_MAX_N => {
            prefix_scan_with_op_id(input, output, n, ScanKind::InclusiveSum, OP_ID)
        }
        _ => wrap_large_scan_program(multi_block_prefix_scan_sum_u32(input, output, n)),
    }
}
/// Re-packages a multi-block scan `program` so that its whole entry sequence
/// sits inside one anonymous region tagged with `OP_ID`. Buffers and
/// workgroup size are carried over unchanged.
fn wrap_large_scan_program(program: Program) -> Program {
    let buffers = program.buffers().to_vec();
    let workgroup = program.workgroup_size();
    let tagged_region = crate::region::wrap_anonymous(OP_ID, program.entry().to_vec());
    Program::wrapped(buffers, workgroup, vec![tagged_region])
}
// Registers this op with the crate harness using a 4-element smoke vector.
// With n = 4, `scan_prefix_sum` takes its small (1..=1024) path.
// NOTE(review): presumably the harness calls `build`, feeds `test_inputs`, and
// compares against `expected_output` — confirm in `crate::harness`.
inventory::submit! {
crate::harness::OpEntry {
id: OP_ID,
build: || scan_prefix_sum("input", "output", 4),
// One case with a single packed buffer: [1, 2, 3, 4].
test_inputs: Some(|| vec![vec![
vyre_primitives::wire::pack_u32_slice(&[1u32, 2, 3, 4]),
]]),
// Inclusive running sums of the input above.
expected_output: Some(|| vec![vec![
vyre_primitives::wire::pack_u32_slice(&[1u32, 3, 6, 10]),
]]),
category: Some("math"),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::byte_pack::{bytes_to_u32 as decode_u32_words, u32_bytes};
    use vyre::ir::{Expr, Node};
    use vyre_reference::value::Value;

    /// Builds `scan_prefix_sum` for `n` elements and runs it on the reference
    /// interpreter with `input` plus a zero-filled output buffer, returning the
    /// first output decoded as `u32` words.
    fn run_scan(n: u32, input: &[u32]) -> Vec<u32> {
        let program = scan_prefix_sum("input", "output", n);
        let outputs = vyre_reference::reference_eval(
            &program,
            &[
                Value::from(u32_bytes(input)),
                Value::from(vec![0u8; (n as usize).saturating_mul(4)]),
            ],
        )
        .expect("Fix: prefix sum must execute");
        decode_u32_words(&outputs[0].to_bytes())
    }

    /// Host-side oracle: inclusive running sum with `u32` wrap-around.
    fn serial_inclusive_sum(input: &[u32]) -> Vec<u32> {
        input
            .iter()
            .scan(0u32, |acc, &x| {
                *acc = acc.wrapping_add(x);
                Some(*acc)
            })
            .collect()
    }

    #[test]
    fn prefix_sum_single_element() {
        let input = [42u32];
        let actual = run_scan(1, &input);
        assert_eq!(actual, vec![42u32]);
    }

    #[test]
    fn prefix_sum_matches_registered_harness_vector() {
        // Keeps the inventory entry's expected output honest against the
        // reference interpreter.
        let actual = run_scan(4, &[1, 2, 3, 4]);
        assert_eq!(actual, vec![1u32, 3, 6, 10]);
    }

    #[test]
    fn prefix_sum_empty_n_zero_should_trap() {
        let program = scan_prefix_sum("input", "output", 0);
        let error = vyre_reference::reference_eval(
            &program,
            &[Value::from(vec![0u8; 0]), Value::from(vec![0u8; 0])],
        )
        .expect_err("n=0 prefix_sum must trap instead of returning empty");
        let msg = error.to_string();
        assert!(
            msg.contains("trap") || msg.contains("Fix:"),
            "n=0 prefix_sum error must be actionable: {msg}"
        );
    }

    #[test]
    fn prefix_sum_small_path_non_power_of_two_sweep() {
        for n in [2u32, 5, 31, 32, 33, 255, 1023] {
            // Large multiplier so running sums wrap past u32::MAX.
            let input: Vec<u32> = (0..n).map(|i| i.wrapping_mul(0x9E37_79B9)).collect();
            let actual = run_scan(n, &input);
            assert_eq!(actual, serial_inclusive_sum(&input), "n={n}");
        }
    }

    #[test]
    fn prefix_sum_boundary_small_path() {
        let input: Vec<u32> = (1..=1024).collect();
        let actual = run_scan(1024, &input);
        assert_eq!(actual, serial_inclusive_sum(&input));
    }

    #[test]
    fn prefix_sum_boundary_large_path_is_parallel_multi_block() {
        let program = scan_prefix_sum("input", "output", 1025);
        assert_top_region_generator(&program, OP_ID);
        assert_eq!(program.workgroup_size(), [1024, 1, 1]);
        assert!(
            !contains_loop(&program),
            "large scan_prefix_sum must not route through prefix_scan_large's serial loop"
        );
        assert!(
            !contains_invocation_zero_gate(&program),
            "large scan_prefix_sum must not gate useful work behind InvocationId.x == 0"
        );
        assert!(program
            .buffers()
            .iter()
            .any(|buffer| buffer.name() == "output" && buffer.is_output()));
    }

    #[test]
    fn prefix_sum_large_path_parallel_shape_sweep() {
        for n in 1025..=4097 {
            let program = scan_prefix_sum("input", "output", n);
            assert_top_region_generator(&program, OP_ID);
            assert_eq!(program.workgroup_size(), [1024, 1, 1], "n={n}");
            assert!(
                !contains_loop(&program),
                "n={n}: large scan_prefix_sum must not emit a serial loop"
            );
            assert!(
                !contains_invocation_zero_gate(&program),
                "n={n}: large scan_prefix_sum must not gate useful work behind InvocationId.x == 0"
            );
            assert!(
                program
                    .buffers()
                    .iter()
                    .any(|buffer| buffer.name() == "output"
                        && buffer.is_output()
                        && buffer.count() == n),
                "n={n}: final output buffer must be declared with the requested element count"
            );
        }
    }

    #[test]
    fn prefix_sum_overflow_wraps() {
        let input = [u32::MAX, 1u32, 1u32];
        let actual = run_scan(3, &input);
        assert_eq!(actual[0], u32::MAX);
        assert_eq!(actual[1], 0u32, "u32::MAX + 1 must wrap to 0");
        assert_eq!(actual[2], 1u32, "0 + 1 must be 1");
    }

    /// Asserts the program entry is exactly one `Region` whose generator is `expected`.
    fn assert_top_region_generator(program: &Program, expected: &str) {
        match program.entry() {
            [Node::Region { generator, .. }] => assert_eq!(generator.as_str(), expected),
            other => panic!("expected single top-level Region, got {other:?}"),
        }
    }

    /// True if any node reachable through Block/If/Region bodies is a `Loop`.
    fn contains_loop(program: &Program) -> bool {
        program.entry().iter().any(node_contains_loop)
    }

    fn node_contains_loop(node: &Node) -> bool {
        match node {
            Node::Loop { .. } => true,
            Node::Block(children) => children.iter().any(node_contains_loop),
            Node::If {
                then, otherwise, ..
            } => then.iter().any(node_contains_loop) || otherwise.iter().any(node_contains_loop),
            Node::Region { body, .. } => body.iter().any(node_contains_loop),
            _ => false,
        }
    }

    /// True if any `If` condition reachable from the entry mentions
    /// `InvocationId.x == 0`.
    fn contains_invocation_zero_gate(program: &Program) -> bool {
        program
            .entry()
            .iter()
            .any(node_contains_invocation_zero_gate)
    }

    fn node_contains_invocation_zero_gate(node: &Node) -> bool {
        match node {
            Node::If {
                cond,
                then,
                otherwise,
            } => {
                expr_is_invocation_zero(cond)
                    || then.iter().any(node_contains_invocation_zero_gate)
                    || otherwise.iter().any(node_contains_invocation_zero_gate)
            }
            Node::Block(children) => children.iter().any(node_contains_invocation_zero_gate),
            Node::Loop { body, .. } => body.iter().any(node_contains_invocation_zero_gate),
            Node::Region { body, .. } => body.iter().any(node_contains_invocation_zero_gate),
            _ => false,
        }
    }

    /// True if `expr` is, or contains, an `InvocationId { axis: 0 } == 0`
    /// comparison (either operand order).
    fn expr_is_invocation_zero(expr: &Expr) -> bool {
        match expr {
            Expr::BinOp { op, left, right } => {
                let is_zero_gate = *op == vyre::ir::BinOp::Eq
                    && matches!(
                        (&**left, &**right),
                        (Expr::InvocationId { axis: 0 }, Expr::LitU32(0))
                            | (Expr::LitU32(0), Expr::InvocationId { axis: 0 })
                    );
                // Always recurse: an `Eq` that is not itself the gate may still
                // wrap one, e.g. `(InvocationId.x == 0) == true`.
                is_zero_gate || expr_is_invocation_zero(left) || expr_is_invocation_zero(right)
            }
            Expr::UnOp { operand, .. } => expr_is_invocation_zero(operand),
            Expr::Load { index, .. } => expr_is_invocation_zero(index),
            Expr::Select {
                cond,
                true_val,
                false_val,
            } => {
                expr_is_invocation_zero(cond)
                    || expr_is_invocation_zero(true_val)
                    || expr_is_invocation_zero(false_val)
            }
            Expr::Atomic {
                index,
                expected,
                value,
                ..
            } => {
                expr_is_invocation_zero(index)
                    || expected
                        .as_ref()
                        .is_some_and(|expr| expr_is_invocation_zero(expr))
                    || expr_is_invocation_zero(value)
            }
            Expr::Cast { value, .. } => expr_is_invocation_zero(value),
            Expr::Call { args, .. } => args.iter().any(expr_is_invocation_zero),
            Expr::Fma { a, b, c } => {
                expr_is_invocation_zero(a)
                    || expr_is_invocation_zero(b)
                    || expr_is_invocation_zero(c)
            }
            _ => false,
        }
    }
}