use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier attached (as the region generator) to inclusive-sum
/// prefix-scan programs built by this module.
pub const OP_ID_INCLUSIVE_SUM: &str = "vyre-primitives::math::prefix_scan_inclusive_sum";
/// Operation identifier attached (as the region generator) to exclusive-sum
/// prefix-scan programs built by this module.
pub const OP_ID_EXCLUSIVE_SUM: &str = "vyre-primitives::math::prefix_scan_exclusive_sum";
/// Selects which flavour of running sum a prefix scan produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScanKind {
/// Output element `i` is the wrapping sum of input elements `0..=i`.
InclusiveSum,
/// Output element `i` is the wrapping sum of input elements `0..i`
/// (so element 0 is always zero).
ExclusiveSum,
}
/// Builds a prefix-sum [`Program`] over `n` `u32` elements read from `in_buf`
/// and written to `out_buf`.
///
/// The operation id is chosen from `kind` ([`OP_ID_INCLUSIVE_SUM`] or
/// [`OP_ID_EXCLUSIVE_SUM`]); everything else is delegated to
/// [`prefix_scan_with_op_id`], including the handling of an out-of-range `n`.
#[must_use]
pub fn prefix_scan(in_buf: &str, out_buf: &str, n: u32, kind: ScanKind) -> Program {
    prefix_scan_with_op_id(
        in_buf,
        out_buf,
        n,
        kind,
        match kind {
            ScanKind::InclusiveSum => OP_ID_INCLUSIVE_SUM,
            ScanKind::ExclusiveSum => OP_ID_EXCLUSIVE_SUM,
        },
    )
}
/// Builds a prefix-sum [`Program`] for `n` `u32` values and tags its region
/// with `op_id`.
///
/// # Layout of the emitted program
///
/// * The workgroup size is `[n.next_power_of_two(), 1, 1]`; the invocation id
///   on axis 0 is used as the lane index.
/// * Two workgroup scratch buffers are declared, named `__{out_buf}_scan_a`
///   and `__{out_buf}_scan_b`, each with one slot per lane.
/// * Seeding: every lane zeroes its slot in the first scratch buffer. For
///   [`ScanKind::InclusiveSum`], lanes below `n` then load `in_buf[lane]`.
///   For [`ScanKind::ExclusiveSum`], lanes in `1..n` load `in_buf[lane - 1]`,
///   leaving slot 0 at zero.
/// * Scan passes: for `stride = 1, 2, 4, ...` while `stride < lanes`, each
///   lane copies its slot into the second scratch buffer, lanes with
///   `lane >= stride` overwrite that copy with `slot[lane] + slot[lane - stride]`,
///   and after a barrier the result is copied back, followed by another barrier.
/// * Finally lanes below `n` store their slot to `out_buf[lane]`.
///
/// `lane - k` is emitted as an addition of `2^32 - k`.
/// NOTE(review): this relies on `Expr::add` wrapping on `u32`; confirm against
/// the IR semantics.
///
/// # Invalid sizes
///
/// When `n` is outside `1..=1024` the function returns whatever
/// `crate::invalid_output_program` produces for `op_id`, `out_buf` and a
/// `"Fix: ..."` diagnostic message, instead of a scan program.
#[must_use]
pub fn prefix_scan_with_op_id(
    in_buf: &str,
    out_buf: &str,
    n: u32,
    kind: ScanKind,
    op_id: &'static str,
) -> Program {
    if !(1..=1024).contains(&n) {
        return crate::invalid_output_program(
            op_id,
            out_buf,
            DataType::U32,
            format!("Fix: prefix_scan requires n in 1..=1024, got {n}."),
        );
    }

    let lanes = n.next_power_of_two();
    let tid = Expr::InvocationId { axis: 0 };
    let ping = format!("__{out_buf}_scan_a");
    let pong = format!("__{out_buf}_scan_b");
    let barrier = || Node::Barrier {
        ordering: vyre_foundation::MemoryOrdering::SeqCst,
    };

    // Padding lanes (>= n) keep this zero so they never contribute to a sum.
    let mut body: Vec<Node> = vec![Node::store(&ping, tid.clone(), Expr::u32(0))];

    body.push(match kind {
        ScanKind::InclusiveSum => Node::if_then(
            Expr::lt(tid.clone(), Expr::u32(n)),
            vec![Node::store(
                &ping,
                tid.clone(),
                Expr::load(in_buf, tid.clone()),
            )],
        ),
        ScanKind::ExclusiveSum => {
            // 0 < lane < n: lane 0 keeps its zero seed.
            let shifted_lane_in_range = Expr::and(
                Expr::lt(Expr::u32(0), tid.clone()),
                Expr::lt(tid.clone(), Expr::u32(n)),
            );
            // lane + (2^32 - 1) is lane - 1 under wrapping addition.
            let left_neighbour = Expr::add(tid.clone(), Expr::u32(u32::MAX));
            Node::if_then(
                shifted_lane_in_range,
                vec![Node::store(
                    &ping,
                    tid.clone(),
                    Expr::load(in_buf, left_neighbour),
                )],
            )
        }
    });
    body.push(barrier());

    // `lanes` is a power of two, so the strides 1, 2, 4, ... below `lanes`
    // are exactly `1 << shift` for `shift` in `0..log2(lanes)`.
    for shift in 0..lanes.trailing_zeros() {
        let stride = 1_u32 << shift;
        // lane + (2^32 - stride) is lane - stride under wrapping addition.
        let partner = Expr::add(tid.clone(), Expr::u32(stride.wrapping_neg()));
        body.extend([
            Node::store(&pong, tid.clone(), Expr::load(&ping, tid.clone())),
            // `stride - 1 < lane` is `lane >= stride`: only lanes that have a
            // partner `stride` slots to the left accumulate.
            Node::if_then(
                Expr::lt(Expr::u32(stride - 1), tid.clone()),
                vec![Node::store(
                    &pong,
                    tid.clone(),
                    Expr::add(
                        Expr::load(&ping, tid.clone()),
                        Expr::load(&ping, partner),
                    ),
                )],
            ),
            barrier(),
            Node::store(&ping, tid.clone(), Expr::load(&pong, tid.clone())),
            barrier(),
        ]);
    }

    body.push(Node::if_then(
        Expr::lt(tid.clone(), Expr::u32(n)),
        vec![Node::store(
            out_buf,
            tid.clone(),
            Expr::load(&ping, tid.clone()),
        )],
    ));

    let input_decl =
        BufferDecl::storage(in_buf, 0, BufferAccess::ReadOnly, DataType::U32).with_count(n);
    // Four bytes per `u32` output element.
    let output_decl = BufferDecl::output(out_buf, 1, DataType::U32)
        .with_count(n)
        .with_output_byte_range(0..(n as usize) * 4);
    let ping_decl = BufferDecl::workgroup(&ping, lanes, DataType::U32);
    let pong_decl = BufferDecl::workgroup(&pong, lanes, DataType::U32);

    Program::wrapped(
        vec![input_decl, output_decl, ping_decl, pong_decl],
        [lanes, 1, 1],
        vec![Node::Region {
            generator: Ident::from(op_id),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}
/// Builds the serial inclusive prefix-sum [`Program`] for `n` `u32` elements.
///
/// This is [`prefix_scan_large_with_op_id`] tagged with
/// [`OP_ID_INCLUSIVE_SUM`]; there is no exclusive variant of this entry point.
#[must_use]
pub fn prefix_scan_large(in_buf: &str, out_buf: &str, n: u32) -> Program {
    let op_id: &'static str = OP_ID_INCLUSIVE_SUM;
    prefix_scan_large_with_op_id(in_buf, out_buf, n, op_id)
}
/// Builds a serial inclusive prefix-sum [`Program`] and tags its region with
/// `op_id`.
///
/// The program has workgroup size `[1, 1, 1]`. When `n > 0`, the invocation
/// whose axis-0 id equals 0 binds an accumulator `acc` to zero and runs a
/// `loop_for` over `i` with bounds `0` and `n`; each iteration adds
/// `in_buf[i]` to `acc` and stores `acc` to `out_buf[i]`.
///
/// When `n == 0` the region body is empty, the input buffer is declared
/// without an element count, and the output buffer is declared with a count
/// of 1 and an empty output byte range.
#[must_use]
pub fn prefix_scan_large_with_op_id(
    in_buf: &str,
    out_buf: &str,
    n: u32,
    op_id: &'static str,
) -> Program {
    // Variable names used inside the emitted IR.
    const ACC: &str = "acc";
    const IDX: &str = "i";

    let unsized_input = BufferDecl::storage(in_buf, 0, BufferAccess::ReadOnly, DataType::U32);
    let input_decl = match n {
        0 => unsized_input,
        _ => unsized_input.with_count(n),
    };

    // Four bytes per `u32`; the count is clamped to at least one element
    // while the byte range stays empty for `n == 0`.
    let out_count = n.max(1);
    let out_bytes = (n as usize).saturating_mul(4);
    let output_decl = BufferDecl::output(out_buf, 1, DataType::U32)
        .with_count(out_count)
        .with_output_byte_range(0..out_bytes);

    let body = match n {
        0 => Vec::new(),
        _ => {
            let only_first_invocation = Expr::eq(Expr::InvocationId { axis: 0 }, Expr::u32(0));
            let init_acc = Node::let_bind(ACC, Expr::u32(0));
            let accumulate = Node::assign(
                ACC,
                Expr::add(Expr::var(ACC), Expr::load(in_buf, Expr::var(IDX))),
            );
            let emit = Node::store(out_buf, Expr::var(IDX), Expr::var(ACC));
            let sweep = Node::loop_for(IDX, Expr::u32(0), Expr::u32(n), vec![accumulate, emit]);
            vec![Node::if_then(only_first_invocation, vec![init_acc, sweep])]
        }
    };

    let region = Node::Region {
        generator: Ident::from(op_id),
        source_region: None,
        body: Arc::new(body),
    };
    Program::wrapped(vec![input_decl, output_decl], [1, 1, 1], vec![region])
}
/// CPU reference implementation of the prefix sum.
///
/// Returns a freshly allocated vector with one output per element of `input`,
/// computed by [`cpu_ref_into`] for the requested `kind`.
#[must_use]
pub fn cpu_ref(input: &[u32], kind: ScanKind) -> Vec<u32> {
    let mut scanned = Vec::new();
    cpu_ref_into(input, kind, &mut scanned);
    scanned
}
/// CPU reference prefix sum that writes into a caller-provided buffer.
///
/// `out` is cleared first and then filled with `input.len()` values, so any
/// previous contents are discarded while the existing allocation is reused
/// when it is large enough. Sums use wrapping `u32` addition.
///
/// * [`ScanKind::InclusiveSum`]: `out[i]` is the sum of `input[0..=i]`.
/// * [`ScanKind::ExclusiveSum`]: `out[i]` is the sum of `input[0..i]`.
pub fn cpu_ref_into(input: &[u32], kind: ScanKind, out: &mut Vec<u32>) {
    out.clear();
    out.reserve(input.len());

    let mut running = 0_u32;
    match kind {
        ScanKind::InclusiveSum => out.extend(input.iter().map(|&value| {
            running = running.wrapping_add(value);
            running
        })),
        ScanKind::ExclusiveSum => out.extend(input.iter().map(|&value| {
            // Emit the total seen so far, then fold the current value in.
            let before = running;
            running = running.wrapping_add(value);
            before
        })),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- CPU reference -------------------------------------------------

    #[test]
    fn inclusive_cpu_ref_matches_textbook() {
        let scanned = cpu_ref(&[1, 2, 3, 4], ScanKind::InclusiveSum);
        assert_eq!(scanned, vec![1, 3, 6, 10]);
    }

    #[test]
    fn exclusive_cpu_ref_matches_textbook() {
        let scanned = cpu_ref(&[1, 2, 3, 4], ScanKind::ExclusiveSum);
        assert_eq!(scanned, vec![0, 1, 3, 6]);
    }

    #[test]
    fn empty_cpu_ref_returns_empty() {
        for kind in [ScanKind::InclusiveSum, ScanKind::ExclusiveSum] {
            assert_eq!(cpu_ref(&[], kind), Vec::<u32>::new());
        }
    }

    #[test]
    fn wrap_on_overflow() {
        // u32::MAX + 1 wraps around to zero.
        let scanned = cpu_ref(&[u32::MAX, 1], ScanKind::InclusiveSum);
        assert_eq!(scanned, vec![u32::MAX, 0]);
    }

    #[test]
    fn cpu_ref_into_reuses_output_buffer() {
        let mut out: Vec<u32> = Vec::with_capacity(16);
        let allocation = out.as_ptr();
        cpu_ref_into(&[1, 2, 3, 4], ScanKind::ExclusiveSum, &mut out);
        assert_eq!(out, vec![0, 1, 3, 6]);
        // Same pointer means the pre-reserved allocation was kept.
        assert_eq!(out.as_ptr(), allocation);
    }

    // ---- Emitted programs ----------------------------------------------

    #[test]
    fn emitted_inclusive_program_has_expected_buffers() {
        let program = prefix_scan("in", "out", 32, ScanKind::InclusiveSum);
        assert_eq!(program.workgroup_size, [32, 1, 1]);
        let names: Vec<&str> = program.buffers.iter().map(|decl| decl.name()).collect();
        assert_eq!(names, vec!["in", "out", "__out_scan_a", "__out_scan_b"]);
    }

    #[test]
    fn emitted_exclusive_program_has_expected_buffers() {
        let program = prefix_scan("in", "out", 64, ScanKind::ExclusiveSum);
        assert_eq!(program.workgroup_size, [64, 1, 1]);
    }

    #[test]
    fn non_power_of_two_n_pads_to_next_power_of_two() {
        let program = prefix_scan("in", "out", 5, ScanKind::InclusiveSum);
        assert_eq!(program.workgroup_size, [8, 1, 1]);
    }

    #[test]
    fn zero_n_traps() {
        let program = prefix_scan("in", "out", 0, ScanKind::InclusiveSum);
        assert!(program.stats().trap());
    }

    #[test]
    fn over_limit_n_traps() {
        let program = prefix_scan("in", "out", 2048, ScanKind::InclusiveSum);
        assert!(program.stats().trap());
    }

    #[test]
    fn binary_power_of_two_sizes_accepted() {
        // 1, 2, 4, ..., 1024.
        for shift in 0..=10_u32 {
            let n = 1_u32 << shift;
            let _ = prefix_scan("in", "out", n, ScanKind::InclusiveSum);
        }
    }
}