use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use vyre_foundation::MemoryOrdering;
use crate::math::prefix_scan::{prefix_scan_with_op_id, ScanKind};
/// Operation id handed to `prefix_scan_with_op_id` when the whole scan fits in
/// a single block (`n <= BLOCK_LANES`).
pub const OP_ID_INCLUSIVE_SUM: &str =
"vyre-primitives::reduce::multi_block_prefix_scan_inclusive_sum";
/// Number of elements covered by one block of the multi-block scan.
///
/// It is the length of the per-block scratch buffers, the multiplier that turns
/// a block index into a global element index, and the first component of the
/// size triple passed to `Program::wrapped` (presumably the workgroup size —
/// confirm against `Program::wrapped`).
pub const BLOCK_LANES: u32 = 1024;
/// `BLOCK_LANES` squared.
///
/// For `n <= SOFT_MAX_N` the number of block totals (`n.div_ceil(BLOCK_LANES)`)
/// is at most `BLOCK_LANES`, so the recursive scan of the block totals takes the
/// single-block path. Nothing in this file rejects a larger `n`.
pub const SOFT_MAX_N: u32 = BLOCK_LANES * BLOCK_LANES;
/// Converts a count of 32-bit words into a byte length.
///
/// `context` is only used as the prefix of the panic message.
///
/// # Panics
///
/// Panics when `words` does not fit in `usize` or when `words * 4` overflows
/// `usize`.
fn output_byte_range(words: u32, context: &str) -> usize {
    const BYTES_PER_WORD: usize = 4;

    let byte_len = match usize::try_from(words) {
        Ok(word_count) => word_count.checked_mul(BYTES_PER_WORD),
        Err(_) => None,
    };

    match byte_len {
        Some(bytes) => bytes,
        None => panic!(
            "{context} words={words} overflows output byte range. Fix: shard the scan before GPU dispatch."
        ),
    }
}
/// Builds a `Program` computing the inclusive `u32` prefix sum of the first `n`
/// elements of `input` into `output`.
///
/// * `n == 0` yields `Program::empty()`.
/// * `n <= BLOCK_LANES` delegates to `prefix_scan_with_op_id` with
///   `ScanKind::InclusiveSum` and [`OP_ID_INCLUSIVE_SUM`].
/// * Larger `n` builds three programs and fuses them with `fuse_programs`:
///   1. [`pass_a_local_scan`] scans each block and records one total per block.
///   2. A recursive call to this function scans the block totals.
///   3. [`pass_c_broadcast_offsets`] adds each block's preceding total to its
///      partial results.
///
/// Intermediate buffers are named `__{output}_mbps_partials`,
/// `__{output}_mbps_block_totals` and `__{output}_mbps_block_totals_scanned`.
///
/// # Panics
///
/// Panics if fusing the three passes fails, or if any pass builder panics on a
/// buffer-size overflow.
#[must_use]
pub fn multi_block_prefix_scan_sum_u32(input: &str, output: &str, n: u32) -> Program {
    match n {
        0 => Program::empty(),
        single_block if single_block <= BLOCK_LANES => prefix_scan_with_op_id(
            input,
            output,
            single_block,
            ScanKind::InclusiveSum,
            OP_ID_INCLUSIVE_SUM,
        ),
        _ => {
            let num_blocks = n.div_ceil(BLOCK_LANES);

            let partials = format!("__{output}_mbps_partials");
            let totals = format!("__{output}_mbps_block_totals");
            let totals_scanned = format!("__{output}_mbps_block_totals_scanned");

            // Array elements are evaluated in order: A, then B, then C.
            let passes = [
                pass_a_local_scan(input, &partials, &totals, n, num_blocks),
                multi_block_prefix_scan_sum_u32(&totals, &totals_scanned, num_blocks),
                pass_c_broadcast_offsets(&partials, &totals_scanned, output, n, num_blocks),
            ];

            vyre_foundation::execution_plan::fusion::fuse_programs(&passes).unwrap_or_else(|_| {
                panic!(
                    "vyre multi_block_prefix_scan fusion failed for n={n}, num_blocks={num_blocks}. Fix: repair grid-sync fusion for the three-pass GPU scan; do not substitute an empty Program."
                )
            })
        }
    }
}
/// Pass A of the multi-block scan: an inclusive prefix sum inside every block.
///
/// Each lane copies its element of `input` (or `0` when its global index is
/// `>= n`) into a workgroup scratch buffer. The block then runs a
/// double-buffered doubling-stride scan (stride 1, 2, 4, … below
/// `BLOCK_LANES`) with `SeqCst` barrier nodes between the phases. Afterwards:
///
/// * in-range lanes write their scanned value to `partials[global]`;
/// * lane `BLOCK_LANES - 1` writes the block's sum to `block_totals[block]`.
///
/// `partials` is declared with `num_blocks * BLOCK_LANES` elements and
/// `block_totals` with `num_blocks` elements. The two scratch buffers are named
/// after `partials` (`__{partials}_pass_a_scratch_a` / `_b`).
///
/// # Panics
///
/// Panics when `num_blocks * BLOCK_LANES` overflows `u32`, or when a buffer's
/// byte length overflows `usize`.
#[must_use]
pub fn pass_a_local_scan(
    input: &str,
    partials: &str,
    block_totals: &str,
    n: u32,
    num_blocks: u32,
) -> Program {
    let lane = Expr::var("lane");
    let block = Expr::var("block");
    let global = Expr::var("global");
    let front = format!("__{partials}_pass_a_scratch_a");
    let back = format!("__{partials}_pass_a_scratch_b");

    let barrier = || Node::Barrier {
        ordering: MemoryOrdering::SeqCst,
    };
    let in_bounds = || Expr::lt(global.clone(), Expr::u32(n));

    // Prologue: bind indices and stage the (zero-padded) input in `front`.
    let mut nodes = vec![
        Node::let_bind("lane", Expr::LocalId { axis: 0 }),
        Node::let_bind("block", Expr::WorkgroupId { axis: 0 }),
        Node::let_bind(
            "global",
            Expr::add(
                Expr::mul(block.clone(), Expr::u32(BLOCK_LANES)),
                lane.clone(),
            ),
        ),
        Node::store(&front, lane.clone(), Expr::u32(0)),
        Node::if_then(
            in_bounds(),
            vec![Node::store(
                &front,
                lane.clone(),
                Expr::load(input, global.clone()),
            )],
        ),
        barrier(),
    ];

    // One round per power-of-two stride below BLOCK_LANES.
    let strides =
        std::iter::successors(Some(1_u32), |&s| Some(s * 2)).take_while(|&s| s < BLOCK_LANES);
    for stride in strides {
        // Adding the two's-complement negation of `stride` yields `lane - stride`.
        let partner = Expr::add(lane.clone(), Expr::u32(stride.wrapping_neg()));
        nodes.extend([
            // Default: carry the current value over unchanged.
            Node::store(&back, lane.clone(), Expr::load(&front, lane.clone())),
            // Lanes at or beyond `stride` also add their partner's value.
            Node::if_then(
                Expr::lt(Expr::u32(stride - 1), lane.clone()),
                vec![Node::store(
                    &back,
                    lane.clone(),
                    Expr::add(
                        Expr::load(&front, lane.clone()),
                        Expr::load(&front, partner),
                    ),
                )],
            ),
            barrier(),
            // Copy the round's result back so the next round reads `front`.
            Node::store(&front, lane.clone(), Expr::load(&back, lane.clone())),
            barrier(),
        ]);
    }

    // Epilogue: publish per-element partials and the per-block total.
    nodes.push(Node::if_then(
        in_bounds(),
        vec![Node::store(
            partials,
            global.clone(),
            Expr::load(&front, lane.clone()),
        )],
    ));
    nodes.push(Node::if_then(
        Expr::eq(lane.clone(), Expr::u32(BLOCK_LANES - 1)),
        vec![Node::store(
            block_totals,
            block.clone(),
            Expr::load(&front, lane.clone()),
        )],
    ));

    let total_partials = match num_blocks.checked_mul(BLOCK_LANES) {
        Some(count) => count,
        None => panic!(
            "vyre multi_block_prefix_scan Pass A num_blocks={num_blocks} overflows partial buffer count. Fix: shard the scan before GPU dispatch."
        ),
    };
    let partial_bytes = output_byte_range(
        total_partials,
        "vyre multi_block_prefix_scan Pass A partials",
    );
    let total_bytes = output_byte_range(
        num_blocks,
        "vyre multi_block_prefix_scan Pass A block_totals",
    );

    let input_decl =
        BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::U32).with_count(n);
    let partials_decl = BufferDecl::output(partials, 1, DataType::U32)
        .with_count(total_partials)
        .with_output_byte_range(0..partial_bytes);
    let totals_decl = BufferDecl::storage(block_totals, 2, BufferAccess::ReadWrite, DataType::U32)
        .with_count(num_blocks)
        .with_pipeline_live_out(true)
        .with_output_byte_range(0..total_bytes);
    let front_decl = BufferDecl::workgroup(&front, BLOCK_LANES, DataType::U32);
    let back_decl = BufferDecl::workgroup(&back, BLOCK_LANES, DataType::U32);

    Program::wrapped(
        vec![input_decl, partials_decl, totals_decl, front_decl, back_decl],
        [BLOCK_LANES, 1, 1],
        vec![Node::Region {
            generator: Ident::from("vyre-primitives::reduce::multi_block_prefix_scan::pass_a"),
            source_region: None,
            body: Arc::new(nodes),
        }],
    )
}
/// Pass C of the multi-block scan: adds each block's carry-in to its partials.
///
/// For every in-range element (`global < n`) the program stores
/// `partials[global] + offset` into `output[global]`, where `offset` is `0`
/// for block 0 and `block_totals_scanned[block - 1]` for every later block.
///
/// `partials` is declared with `num_blocks * BLOCK_LANES` elements,
/// `block_totals_scanned` with `num_blocks` elements and `output` with `n`
/// elements.
///
/// # Panics
///
/// Panics when `num_blocks * BLOCK_LANES` overflows `u32`, or when the output
/// byte length overflows `usize`.
#[must_use]
pub fn pass_c_broadcast_offsets(
    partials: &str,
    block_totals_scanned: &str,
    output: &str,
    n: u32,
    num_blocks: u32,
) -> Program {
    let lane = Expr::var("lane");
    let block = Expr::var("block");
    let global = Expr::var("global");
    let offset = Expr::var("offset");

    let bind_lane = Node::let_bind("lane", Expr::LocalId { axis: 0 });
    let bind_block = Node::let_bind("block", Expr::WorkgroupId { axis: 0 });
    let bind_global = Node::let_bind(
        "global",
        Expr::add(
            Expr::mul(block.clone(), Expr::u32(BLOCK_LANES)),
            lane.clone(),
        ),
    );
    let init_offset = Node::let_bind("offset", Expr::u32(0));

    // Blocks after the first pick up the scanned total of the block before them.
    let not_first_block = Expr::lt(Expr::u32(0), block.clone());
    // Adding `u32::MAX` is the wrapping form of `block - 1`.
    let previous_block = Expr::add(block.clone(), Expr::u32(u32::MAX));
    let fetch_offset = Node::if_then(
        not_first_block,
        vec![Node::assign(
            "offset",
            Expr::load(block_totals_scanned, previous_block),
        )],
    );

    let write_output = Node::if_then(
        Expr::lt(global.clone(), Expr::u32(n)),
        vec![Node::store(
            output,
            global.clone(),
            Expr::add(Expr::load(partials, global.clone()), offset),
        )],
    );

    let body = vec![
        bind_lane,
        bind_block,
        bind_global,
        init_offset,
        fetch_offset,
        write_output,
    ];

    let total_partials = match num_blocks.checked_mul(BLOCK_LANES) {
        Some(count) => count,
        None => panic!(
            "vyre multi_block_prefix_scan Pass C num_blocks={num_blocks} overflows partial buffer count. Fix: shard the scan before GPU dispatch."
        ),
    };
    let output_bytes = output_byte_range(n, "vyre multi_block_prefix_scan Pass C output");

    let partials_decl = BufferDecl::storage(partials, 0, BufferAccess::ReadOnly, DataType::U32)
        .with_count(total_partials);
    let scanned_decl = BufferDecl::storage(
        block_totals_scanned,
        1,
        BufferAccess::ReadOnly,
        DataType::U32,
    )
    .with_count(num_blocks);
    let output_decl = BufferDecl::output(output, 2, DataType::U32)
        .with_count(n)
        .with_output_byte_range(0..output_bytes);

    Program::wrapped(
        vec![partials_decl, scanned_decl, output_decl],
        [BLOCK_LANES, 1, 1],
        vec![Node::Region {
            generator: Ident::from("vyre-primitives::reduce::multi_block_prefix_scan::pass_c"),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// CPU reference: returns the wrapping inclusive prefix sum of `input`.
///
/// Never panics on allocation failure: if [`try_cpu_ref_into`] reports an
/// error, the error is written to stderr and an empty vector is returned.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn cpu_ref(input: &[u32]) -> Vec<u32> {
    let mut scanned = Vec::new();
    if let Err(error) = try_cpu_ref_into(input, &mut scanned) {
        eprintln!("vyre-primitives multi-block prefix-scan CPU reference failed: {error}");
        return Vec::new();
    }
    scanned
}
/// CPU reference that writes the wrapping inclusive prefix sum of `input` into
/// the caller-supplied `out`, reusing its allocation where possible.
///
/// Never panics on allocation failure: if [`try_cpu_ref_into`] reports an
/// error, the error is written to stderr and `out` is left empty.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn cpu_ref_into(input: &[u32], out: &mut Vec<u32>) {
    match try_cpu_ref_into(input, out) {
        Ok(()) => {}
        Err(error) => {
            eprintln!("vyre-primitives multi-block prefix-scan CPU reference failed: {error}");
            out.clear();
        }
    }
}
/// Fallible CPU reference: replaces the contents of `out` with the wrapping
/// inclusive prefix sum of `input`.
///
/// All capacity needed for the result is reserved up front with a fallible
/// reservation, so the element pushes that follow never reallocate. Existing
/// capacity in `out` is reused when it is already large enough.
///
/// # Errors
///
/// Returns a message when the output capacity cannot be reserved. In that case
/// `out` still holds its previous contents.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn try_cpu_ref_into(input: &[u32], out: &mut Vec<u32>) -> Result<(), String> {
    // `try_reserve_exact(additional)` guarantees `capacity >= len + additional`,
    // so the shortfall must be measured against `len`, not `capacity`.
    // Measuring against `capacity` under-reserves whenever `out` has spare
    // capacity, and the pushes below would then grow the vector infallibly.
    let additional = input.len().saturating_sub(out.len());
    out.try_reserve_exact(additional).map_err(|err| {
        format!(
            "multi-block prefix-scan CPU reference could not reserve {} output words: {err}",
            input.len()
        )
    })?;

    out.clear();
    let mut acc: u32 = 0;
    for &x in input {
        acc = acc.wrapping_add(x);
        out.push(acc);
    }
    Ok(())
}
#[cfg(test)]
mod tests {
// Unit tests for the CPU reference and for the buffer declarations of the
// programs built by `multi_block_prefix_scan_sum_u32`. No test here executes a
// built program; the builder tests only inspect the declared buffer names.
use super::*;
// `cpu_ref` yields running sums, including for empty and one-element input.
#[test]
fn cpu_ref_matches_simple_inclusive_sum() {
assert_eq!(cpu_ref(&[1, 2, 3, 4]), vec![1, 3, 6, 10]);
assert_eq!(cpu_ref(&[]), Vec::<u32>::new());
assert_eq!(cpu_ref(&[7]), vec![7]);
}
// `cpu_ref_into` overwrites stale contents, wraps on overflow
// (`u32::MAX + 1 == 0`), and keeps the existing capacity when it suffices.
#[test]
fn cpu_ref_into_reuses_output_and_truncates_stale_tail() {
let mut out = Vec::with_capacity(8);
out.extend_from_slice(&[99, 98, 97, 96]);
let capacity = out.capacity();
cpu_ref_into(&[u32::MAX, 1, 2], &mut out);
assert_eq!(out, vec![u32::MAX, 0, 2]);
assert_eq!(out.capacity(), capacity);
cpu_ref_into(&[7], &mut out);
assert_eq!(out, vec![7]);
assert_eq!(out.capacity(), capacity);
}
// Same as above for the fallible variant; the unchanged data pointer shows
// that no reallocation happened.
#[test]
fn try_cpu_ref_into_reuses_output_and_clears_stale_tail() {
let mut out = Vec::with_capacity(8);
out.extend_from_slice(&[99, 98, 97, 96]);
let ptr = out.as_ptr();
try_cpu_ref_into(&[u32::MAX, 1, 2], &mut out).unwrap();
assert_eq!(out, vec![u32::MAX, 0, 2]);
assert_eq!(out.as_ptr(), ptr);
}
// The two infallible wrappers must agree with `try_cpu_ref_into`.
#[test]
fn compatibility_wrappers_match_fallible_reference() {
let input = &[u32::MAX, 1, 2];
let mut compat = Vec::with_capacity(8);
let mut fallible = Vec::with_capacity(8);
cpu_ref_into(input, &mut compat);
try_cpu_ref_into(input, &mut fallible)
.expect("Fix: small multi-block prefix-scan CPU reference must reserve");
assert_eq!(cpu_ref(input), fallible);
assert_eq!(compat, fallible);
}
// Source-text guard: reads this file, keeps everything before the first
// test-cfg attribute, and fails if that part contains a raw expect or unwrap
// call. The check is purely textual, so comments and string literals in the
// production section count as well.
#[test]
fn production_wrappers_have_no_raw_panic_path() {
let production = include_str!("multi_block_prefix_scan.rs")
.split("#[cfg(test)]")
.next()
.expect("Fix: multi_block_prefix_scan.rs must contain production section");
assert!(
!production.contains(".expect(") && !production.contains(".unwrap("),
"Fix: multi-block prefix-scan CPU reference wrappers must not panic in production."
);
}
// For `n <= BLOCK_LANES` the built program declares the caller's input and
// output buffers. Only the buffer names are checked.
#[test]
fn small_n_falls_through_to_single_block_path() {
for &n in &[1u32, 2, 64, 1023, 1024] {
let prog = multi_block_prefix_scan_sum_u32("in_buf", "out_buf", n);
let names: Vec<&str> = prog.buffers().iter().map(BufferDecl::name).collect();
assert!(names.contains(&"in_buf"), "n={n} must declare in_buf, got {names:?}");
assert!(names.contains(&"out_buf"), "n={n} must declare out_buf, got {names:?}");
}
}
// For `n = 2 * BLOCK_LANES` the fused program still declares the caller's
// input and output buffers. The intermediate buffers are not asserted on.
#[test]
fn large_n_emits_three_pass_chain() {
let prog = multi_block_prefix_scan_sum_u32("in_buf", "out_buf", 2 * BLOCK_LANES);
let names: Vec<&str> = prog.buffers().iter().map(BufferDecl::name).collect();
assert!(
names.contains(&"in_buf"),
"input must be declared, got {names:?}"
);
assert!(
names.contains(&"out_buf"),
"output must be declared, got {names:?}"
);
}
// `n == 0` produces a program with no buffers.
#[test]
fn empty_input_returns_empty_program() {
let prog = multi_block_prefix_scan_sum_u32("in_buf", "out_buf", 0);
assert!(prog.buffers().is_empty());
}
// Building at `n = SOFT_MAX_N` (`BLOCK_LANES` squared) completes and the
// result declares the caller's buffers.
#[test]
fn recursion_handles_million_elements() {
let prog = multi_block_prefix_scan_sum_u32("in_buf", "out_buf", SOFT_MAX_N);
let names: Vec<&str> = prog.buffers().iter().map(BufferDecl::name).collect();
assert!(names.contains(&"in_buf"));
assert!(names.contains(&"out_buf"));
}
}