use vyre_foundation::ir::{Node, Program};
use vyre_primitives::reduce::{
multi_block_prefix_scan::{
multi_block_prefix_scan_sum_u32, pass_a_local_scan, pass_c_broadcast_offsets, BLOCK_LANES,
},
radix_sort::radix_sort,
range_counts::{range_counts_u32, range_counts_u32_body, range_counts_u32_child},
workgroup_any::{
workgroup_any_u32, workgroup_any_u32_body, workgroup_any_u32_body_prefixed,
workgroup_any_u32_child, workgroup_any_u32_child_prefixed,
},
workgroup_tree::{
max_f32_child, sum_f32_child, sum_u32_child, workgroup_max_f32, workgroup_sum_f32,
workgroup_sum_u32, WorkgroupReductionScope,
},
};
#[cfg(any(test, feature = "cpu-parity"))]
use vyre_primitives::reduce::{
multi_block_prefix_scan::cpu_ref as primitive_prefix_sum,
radix_sort::cpu_ref as primitive_radix_sort, range_counts::cpu_ref as primitive_range_count,
workgroup_any::cpu_ref as primitive_workgroup_any,
};
/// Builds the workgroup `f32` sum reduction [`Program`].
///
/// Forwards all four arguments unchanged to `workgroup_tree::workgroup_sum_f32` and
/// returns its result.
///
/// NOTE(review): `values`/`out` look like buffer names and `count`/`tile` like an element
/// count and tile width; confirm against the primitive's own documentation.
#[must_use]
pub fn dispatch_workgroup_sum_f32(values: &str, out: &str, count: u32, tile: u32) -> Program {
workgroup_sum_f32(values, out, count, tile)
}
/// Builds the workgroup `u32` sum reduction [`Program`].
///
/// Forwards all four arguments unchanged to `workgroup_tree::workgroup_sum_u32` and
/// returns its result.
///
/// NOTE(review): `values`/`out` look like buffer names and `count`/`tile` like an element
/// count and tile width; confirm against the primitive's own documentation.
#[must_use]
pub fn dispatch_workgroup_sum_u32(values: &str, out: &str, count: u32, tile: u32) -> Program {
workgroup_sum_u32(values, out, count, tile)
}
/// Builds the workgroup `f32` max reduction [`Program`].
///
/// Forwards all four arguments unchanged to `workgroup_tree::workgroup_max_f32` and
/// returns its result.
///
/// NOTE(review): `values`/`out` look like buffer names and `count`/`tile` like an element
/// count and tile width; confirm against the primitive's own documentation.
#[must_use]
pub fn dispatch_workgroup_max_f32(values: &str, out: &str, count: u32, tile: u32) -> Program {
workgroup_max_f32(values, out, count, tile)
}
/// Builds the child-stage [`Node`] for the `f32` sum reduction.
///
/// Delegates to `sum_f32_child`, always selecting
/// [`WorkgroupReductionScope::EveryWorkgroup`]. `parent_op_id`, `tile` and `scratch`
/// are passed through unchanged.
#[must_use]
pub fn child_sum_f32_stage(parent_op_id: &str, tile: u32, scratch: &'static str) -> Node {
    // This facade fixes the reduction scope; callers cannot choose another one here.
    let scope = WorkgroupReductionScope::EveryWorkgroup;
    sum_f32_child(parent_op_id, tile, scratch, scope)
}
/// Builds the child-stage [`Node`] for the `u32` sum reduction.
///
/// Delegates to `sum_u32_child`, always selecting
/// [`WorkgroupReductionScope::EveryWorkgroup`]. `parent_op_id`, `tile` and `scratch`
/// are passed through unchanged.
#[must_use]
pub fn child_sum_u32_stage(parent_op_id: &str, tile: u32, scratch: &'static str) -> Node {
    // This facade fixes the reduction scope; callers cannot choose another one here.
    let scope = WorkgroupReductionScope::EveryWorkgroup;
    sum_u32_child(parent_op_id, tile, scratch, scope)
}
/// Builds the child-stage [`Node`] for the `f32` max reduction.
///
/// Delegates to `max_f32_child`, always selecting
/// [`WorkgroupReductionScope::EveryWorkgroup`]. `parent_op_id`, `tile` and `scratch`
/// are passed through unchanged.
#[must_use]
pub fn child_max_f32_stage(parent_op_id: &str, tile: u32, scratch: &'static str) -> Node {
    // This facade fixes the reduction scope; callers cannot choose another one here.
    let scope = WorkgroupReductionScope::EveryWorkgroup;
    max_f32_child(parent_op_id, tile, scratch, scope)
}
/// Returns the IR node list produced by `range_counts_u32_body`.
///
/// All arguments are forwarded unchanged, and the result is a bare `Vec<Node>` rather
/// than a wrapped `Program` or single `Node`.
///
/// NOTE(review): `start`/`end` presumably bound the histogram index range and `out_var`
/// names the accumulator variable; confirm against the primitive.
#[must_use]
pub fn range_count_accumulator_body(
histogram: &str,
out_var: &str,
start: u32,
end: u32,
) -> Vec<Node> {
range_counts_u32_body(histogram, out_var, start, end)
}
/// Builds the child-stage [`Node`] for the `u32` range count.
///
/// Forwards `parent_op_id`, `histogram`, `out_var`, `start` and `end` unchanged to
/// `range_counts_u32_child` and returns its result.
#[must_use]
pub fn child_range_count_stage(
parent_op_id: &str,
histogram: &str,
out_var: &str,
start: u32,
end: u32,
) -> Node {
range_counts_u32_child(parent_op_id, histogram, out_var, start, end)
}
/// Builds the standalone `u32` range-count [`Program`].
///
/// Forwards `histogram`, `out`, `start` and `end` unchanged to `range_counts_u32` and
/// returns its result.
#[must_use]
pub fn dispatch_range_count_u32(histogram: &str, out: &str, start: u32, end: u32) -> Program {
range_counts_u32(histogram, out, start, end)
}
/// Returns the IR node list produced by `workgroup_any_u32_body`.
///
/// `values`, `out_var` and `count` are forwarded unchanged, and the result is a bare
/// `Vec<Node>` rather than a wrapped `Program` or single `Node`.
#[must_use]
pub fn any_accumulator_body(values: &str, out_var: &str, count: u32) -> Vec<Node> {
workgroup_any_u32_body(values, out_var, count)
}
/// Returns the IR node list produced by `workgroup_any_u32_body_prefixed`.
///
/// Same argument forwarding as the non-prefixed body builder, with the extra `iter_var`
/// passed through as the last argument.
///
/// NOTE(review): `iter_var` presumably names the loop variable so callers can avoid
/// identifier clashes; confirm against the primitive.
#[must_use]
pub fn prefixed_any_accumulator_body(
values: &str,
out_var: &str,
count: u32,
iter_var: &str,
) -> Vec<Node> {
workgroup_any_u32_body_prefixed(values, out_var, count, iter_var)
}
/// Builds the child-stage [`Node`] for the `u32` workgroup-any reduction.
///
/// Forwards `parent_op_id`, `values`, `out_var` and `count` unchanged to
/// `workgroup_any_u32_child` and returns its result.
#[must_use]
pub fn child_any_stage(parent_op_id: &str, values: &str, out_var: &str, count: u32) -> Node {
workgroup_any_u32_child(parent_op_id, values, out_var, count)
}
/// Builds the child-stage [`Node`] for the prefixed `u32` workgroup-any reduction.
///
/// Forwards every argument, including `iter_var`, unchanged to
/// `workgroup_any_u32_child_prefixed` and returns its result.
///
/// NOTE(review): `iter_var` presumably names the loop variable so callers can avoid
/// identifier clashes; confirm against the primitive.
#[must_use]
pub fn prefixed_child_any_stage(
parent_op_id: &str,
values: &str,
out_var: &str,
count: u32,
iter_var: &str,
) -> Node {
workgroup_any_u32_child_prefixed(parent_op_id, values, out_var, count, iter_var)
}
/// Builds the standalone `u32` workgroup-any [`Program`].
///
/// Forwards `values`, `out` and `count` unchanged to `workgroup_any_u32` and returns
/// its result.
#[must_use]
pub fn dispatch_workgroup_any_u32(values: &str, out: &str, count: u32) -> Program {
workgroup_any_u32(values, out, count)
}
/// Builds the multi-block `u32` prefix-sum [`Program`].
///
/// Forwards `input`, `output` and `n` unchanged to `multi_block_prefix_scan_sum_u32`
/// and returns its result.
#[must_use]
pub fn dispatch_multi_block_prefix_sum(input: &str, output: &str, n: u32) -> Program {
multi_block_prefix_scan_sum_u32(input, output, n)
}
/// Builds pass A of the multi-block prefix scan via `pass_a_local_scan`.
///
/// The block count handed to the primitive is `n` divided by `BLOCK_LANES`, rounded up,
/// and clamped to at least 1, so `n == 0` still requests a single block. The remaining
/// arguments are forwarded unchanged.
#[must_use]
pub fn prefix_pass_a(input: &str, partials: &str, block_totals: &str, n: u32) -> Program {
    let block_count = u32::max(n.div_ceil(BLOCK_LANES), 1);
    pass_a_local_scan(input, partials, block_totals, n, block_count)
}
/// Builds pass C of the multi-block prefix scan via `pass_c_broadcast_offsets`.
///
/// The block count handed to the primitive is `n` divided by `BLOCK_LANES`, rounded up,
/// and clamped to at least 1, so `n == 0` still requests a single block. The remaining
/// arguments are forwarded unchanged.
#[must_use]
pub fn prefix_pass_c(partials: &str, block_totals_scanned: &str, output: &str, n: u32) -> Program {
    let block_count = u32::max(n.div_ceil(BLOCK_LANES), 1);
    pass_c_broadcast_offsets(partials, block_totals_scanned, output, n, block_count)
}
/// Builds the radix-sort [`Program`].
///
/// Forwards `input`, `output`, `count` and `bits` unchanged to `radix_sort` and returns
/// its result.
///
/// NOTE(review): `bits` presumably limits how many key bits are sorted; confirm against
/// the primitive.
#[must_use]
pub fn dispatch_radix_sort(input: &str, output: &str, count: u32, bits: u32) -> Program {
radix_sort(input, output, count, bits)
}
/// CPU reference for the range count; delegates to `range_counts::cpu_ref`.
///
/// Only compiled for tests or when the `cpu-parity` feature is enabled.
/// The unit test in this file expects `10` for `([9, 2, 3, 5, 11], 1, 4)`, which is
/// consistent with summing indices `start..end` with `end` exclusive.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn reference_range_count(histogram: &[u32], start: u32, end: u32) -> u32 {
primitive_range_count(histogram, start, end)
}
/// CPU reference for the workgroup-any reduction; delegates to `workgroup_any::cpu_ref`.
///
/// Only compiled for tests or when the `cpu-parity` feature is enabled.
/// The unit test in this file expects `6` for `[0, 2, 4, 0]`, so the result is a combined
/// `u32` value rather than a 0/1 flag.
///
/// NOTE(review): that expectation fits both a bitwise OR and a sum; the exact combine
/// operation is defined by the primitive.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn reference_workgroup_any(values: &[u32]) -> u32 {
primitive_workgroup_any(values)
}
/// CPU reference for the prefix sum; delegates to `multi_block_prefix_scan::cpu_ref`.
///
/// Only compiled for tests or when the `cpu-parity` feature is enabled.
/// The unit test in this file expects `[1, 3, 6, 10]` for `[1, 2, 3, 4]`, i.e. an
/// inclusive scan.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn reference_prefix_sum(values: &[u32]) -> Vec<u32> {
primitive_prefix_sum(values)
}
/// CPU reference for the radix sort; delegates to `radix_sort::cpu_ref`.
///
/// Only compiled for tests or when the `cpu-parity` feature is enabled.
/// The unit test in this file expects `[1, 2, 3, 4]` for `([3, 1, 4, 2], 32)`.
///
/// NOTE(review): behavior for `bits < 32` is defined by the primitive and not visible here.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn reference_radix_sort(values: &[u32], bits: u32) -> Vec<u32> {
primitive_radix_sort(values, bits)
}
/// CPU reference for the `f32` sum reduction.
///
/// Adds the elements front to back using the standard library's `Sum` implementation
/// for `f32`. An empty slice yields that implementation's neutral element, which
/// compares equal to `0.0`.
///
/// Only compiled for tests or when the `cpu-parity` feature is enabled.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn reference_sum_f32(values: &[f32]) -> f32 {
    // `f32: Sum<&f32>` shares its accumulation with the by-value impl, so summing
    // references gives the same result as copying each element first.
    values.iter().sum()
}
/// CPU reference for the `u32` sum reduction, computed modulo 2^32.
///
/// Elements are accumulated with `wrapping_add`, so the result is the same in every
/// build profile. `Iterator::sum` would instead panic on overflow whenever overflow
/// checks are enabled (the default for debug and test builds) and wrap silently
/// otherwise, which makes a parity oracle depend on how it was compiled.
///
/// Returns `0` for an empty slice.
///
/// Only compiled for tests or when the `cpu-parity` feature is enabled.
///
/// NOTE(review): wrapping is presumably what the device-side `u32` add does; confirm
/// against the `workgroup_sum_u32` primitive.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn reference_sum_u32(values: &[u32]) -> u32 {
    values.iter().fold(0_u32, |acc, &v| acc.wrapping_add(v))
}
/// CPU reference for the `f32` max reduction.
///
/// Folds the slice with `f32::max`, seeded with `f32::MIN` (the most negative finite
/// value). Consequently an empty slice returns `f32::MIN`, and NaN elements are skipped
/// because `f32::max` returns the non-NaN operand.
///
/// Only compiled for tests or when the `cpu-parity` feature is enabled.
///
/// NOTE(review): because the seed is finite, a slice holding only `-inf` also returns
/// `f32::MIN`. Confirm this matches the identity used by the `workgroup_max_f32` kernel.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn reference_max_f32(values: &[f32]) -> f32 {
    values.iter().fold(f32::MIN, |best, &v| best.max(v))
}
// Unit tests for the reduce facade: they check that each wrapper reaches the expected
// primitive (identified by the `generator` string of the emitted Region) and that the
// CPU reference helpers return the expected values on small inputs.
#[cfg(test)]
mod tests {
use super::*;
// Extracts the `generator` string from a node, panicking if it is not a `Node::Region`.
fn region_generator(node: &Node) -> &str {
let Node::Region { generator, .. } = node else {
panic!("Fix: reduction helper must emit a Region child.");
};
generator.as_str()
}
// Extracts the `generator` string from the first entry node of a program, panicking
// if the program has no entry nodes or the first one is not a `Node::Region`.
fn program_generator(program: &Program) -> &str {
let Some(Node::Region { generator, .. }) = program.entry.first() else {
panic!("Fix: reduction Program must start with a Region.");
};
generator.as_str()
}
// Each `dispatch_*` builder must yield a Program whose leading Region was generated by
// the matching `vyre-primitives::reduce` primitive.
#[test]
fn program_builders_emit_expected_reduce_primitives() {
assert_eq!(
program_generator(&dispatch_workgroup_sum_f32("values", "out", 8, 4)),
"vyre-primitives::reduce::workgroup_sum_f32"
);
assert_eq!(
program_generator(&dispatch_workgroup_sum_u32("values", "out", 8, 4)),
"vyre-primitives::reduce::workgroup_sum_u32"
);
assert_eq!(
program_generator(&dispatch_workgroup_max_f32("values", "out", 8, 4)),
"vyre-primitives::reduce::workgroup_max_f32"
);
assert_eq!(
program_generator(&dispatch_range_count_u32("histogram", "out", 2, 5)),
"vyre-primitives::reduce::range_counts_u32"
);
assert_eq!(
program_generator(&dispatch_workgroup_any_u32("values", "out", 4)),
"vyre-primitives::reduce::workgroup_any_u32"
);
assert_eq!(
program_generator(&dispatch_radix_sort("keys", "sorted", 8, 16)),
"vyre-primitives::reduce::radix_sort"
);
}
// Each child-stage builder must return a Region node generated by the matching
// primitive; the prefixed and non-prefixed "any" variants share one generator string.
#[test]
fn child_builders_keep_parent_region_context() {
let parent = "vyre-self-substrate::data::reduce_dispatch_pipeline";
assert_eq!(
region_generator(&child_sum_f32_stage(parent, 8, "scratch")),
"vyre-primitives::reduce::workgroup_sum_f32"
);
assert_eq!(
region_generator(&child_sum_u32_stage(parent, 8, "scratch")),
"vyre-primitives::reduce::workgroup_sum_u32"
);
assert_eq!(
region_generator(&child_max_f32_stage(parent, 8, "scratch")),
"vyre-primitives::reduce::workgroup_max_f32"
);
assert_eq!(
region_generator(&child_range_count_stage(parent, "hist", "sum", 1, 4)),
"vyre-primitives::reduce::range_counts_u32"
);
assert_eq!(
region_generator(&child_any_stage(parent, "changed", "any", 8)),
"vyre-primitives::reduce::workgroup_any_u32"
);
assert_eq!(
region_generator(&prefixed_child_any_stage(
parent, "changed", "any", 8, "any_i"
)),
"vyre-primitives::reduce::workgroup_any_u32"
);
}
// The body builders return raw node lists; only non-emptiness is checked here.
#[test]
fn body_builders_emit_non_empty_composable_ir() {
assert_ne!(range_count_accumulator_body("hist", "sum", 0, 8).len(), 0);
assert_ne!(any_accumulator_body("changed", "any", 8).len(), 0);
assert_ne!(
prefixed_any_accumulator_body("changed", "any", 8, "changed_i").len(),
0
);
}
// Exercises the prefix-sum builder below and above one block's worth of lanes, and
// checks that passes A and C report their own generator strings.
#[test]
fn prefix_pipeline_exposes_fused_and_resident_passes() {
let small = dispatch_multi_block_prefix_sum("input", "output", 64);
assert!(!small.buffers.is_empty());
let large = dispatch_multi_block_prefix_sum("input", "output", BLOCK_LANES + 17);
assert!(!large.buffers.is_empty());
let pass_a = prefix_pass_a("input", "partials", "totals", BLOCK_LANES + 1);
assert_eq!(
program_generator(&pass_a),
"vyre-primitives::reduce::multi_block_prefix_scan::pass_a"
);
let pass_c = prefix_pass_c("partials", "totals_scanned", "output", BLOCK_LANES + 1);
assert_eq!(
program_generator(&pass_c),
"vyre-primitives::reduce::multi_block_prefix_scan::pass_c"
);
}
// Spot-checks every CPU reference helper on a small input. The f32 inputs are chosen
// so the expected results are exactly representable and `assert_eq!` is safe.
#[test]
fn cpu_reference_wrappers_match_primitive_contracts() {
assert_eq!(reference_sum_f32(&[1.25, -2.0, 5.5]), 4.75);
assert_eq!(reference_sum_u32(&[1, 2, 3, 4]), 10);
assert_eq!(reference_max_f32(&[-7.0, 3.5, 2.0]), 3.5);
assert_eq!(reference_range_count(&[9, 2, 3, 5, 11], 1, 4), 10);
assert_eq!(reference_workgroup_any(&[0, 2, 4, 0]), 6);
assert_eq!(reference_prefix_sum(&[1, 2, 3, 4]), vec![1, 3, 6, 10]);
assert_eq!(reference_radix_sort(&[3, 1, 4, 2], 32), vec![1, 2, 3, 4]);
}
}