use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::{AlgebraicLaw, OpSpec, U32_INPUTS, U32_OUTPUTS};
/// Number of invocations per workgroup for the scan program; it is also the
/// length of the program's `shared` scratch buffer.
///
/// NOTE: the scan's up-/down-sweep schedule depends on this value (and needs a
/// power of two); review `PrefixSumInclusiveU32::program` before changing it.
pub const WORKGROUP_SIZE: u32 = 1 << 8;
/// Algebraic laws declared for this op: associativity, and an identity element
/// of `0`.
///
/// NOTE(review): presumably these describe the underlying `+` combiner (for
/// which `0` is the identity, matching the padding value used by the program);
/// confirm against how `OpSpec` consumes them.
pub const LAWS: &[AlgebraicLaw] = &[
    AlgebraicLaw::Associative,
    AlgebraicLaw::Identity { element: 0 },
];
/// Zero-sized handle for the inclusive `u32` prefix-sum (scan) operation.
///
/// The IR comes from [`PrefixSumInclusiveU32::program`] and is registered
/// through [`PrefixSumInclusiveU32::SPEC`].
#[derive(Clone, Copy, Debug, Default)]
pub struct PrefixSumInclusiveU32;
impl PrefixSumInclusiveU32 {
    /// Op specification for `scan.prefix_sum_inclusive`: a composition built from
    /// [`Self::program`] with the u32 input/output signatures and [`LAWS`].
    pub const SPEC: OpSpec = OpSpec::composition(
        "scan.prefix_sum_inclusive",
        U32_INPUTS,
        U32_OUTPUTS,
        LAWS,
        Self::program,
    );

    /// Builds the IR program computing an inclusive prefix sum of `input` into `out`.
    ///
    /// Each invocation `lid` loads `input[lid]` into the `shared` workgroup buffer,
    /// or `0` when `lid >= n` (`0` is the additive identity). The buffer is then
    /// scanned with the work-efficient Blelloch algorithm:
    /// - an up-sweep builds partial sums in place;
    /// - the last slot (the tree root) is cleared;
    /// - a down-sweep turns the tree into an exclusive scan.
    ///
    /// Finally every in-range invocation writes `exclusive + original`, i.e. the
    /// inclusive prefix sum, to `out[lid]`.
    ///
    /// The element index is `local_x` only, so at most the first [`WORKGROUP_SIZE`]
    /// elements are scanned; elements beyond that are never written to `out`.
    /// NOTE(review): if several workgroups are dispatched, each appears to redo the
    /// same work on the same indices; confirm the op is dispatched as one workgroup.
    #[must_use]
    pub fn program() -> Program {
        // The level schedules below are derived from WORKGROUP_SIZE and only form a
        // complete binary tree when it is a power of two.
        const _: () = assert!(
            WORKGROUP_SIZE.is_power_of_two(),
            "WORKGROUP_SIZE must be a power of two"
        );

        let mut entry = Vec::with_capacity(64);
        entry.push(Node::let_bind("lid", Expr::local_x()));
        entry.push(Node::let_bind("n", Expr::buf_len("input")));
        // Pad out-of-range slots with 0 so they do not perturb the sums.
        // NOTE(review): if `Expr::select` evaluates both arms (as WGSL's `select`
        // does), `input[lid]` is read even when `lid >= n`. WGSL's robust buffer
        // access makes that defined; confirm for any other backend.
        entry.push(Node::let_bind(
            "original",
            Expr::select(
                Expr::lt(Expr::var("lid"), Expr::var("n")),
                Expr::load("input", Expr::var("lid")),
                Expr::u32(0),
            ),
        ));
        entry.push(Node::store("shared", Expr::var("lid"), Expr::var("original")));
        entry.push(Node::barrier());

        // Up-sweep: at stride `s`, every slot with index `2s*k - 1` adds in the slot
        // `s` positions to its left. Those partner slots are never written in the
        // same step, so one barrier per level suffices. Afterwards the last slot
        // holds the total.
        for stride in Self::tree_strides() {
            entry.push(Node::if_then(
                Expr::eq(
                    Expr::rem(Expr::add(Expr::var("lid"), Expr::u32(1)), Expr::u32(stride * 2)),
                    Expr::u32(0),
                ),
                vec![Node::store(
                    "shared",
                    Expr::var("lid"),
                    Expr::add(
                        Expr::load("shared", Expr::var("lid")),
                        Expr::load("shared", Expr::sub(Expr::var("lid"), Expr::u32(stride))),
                    ),
                )],
            ));
            entry.push(Node::barrier());
        }

        // Clear the root so the down-sweep produces an exclusive scan.
        entry.push(Node::if_then(
            Expr::eq(Expr::var("lid"), Expr::u32(WORKGROUP_SIZE - 1)),
            vec![Node::store("shared", Expr::var("lid"), Expr::u32(0))],
        ));
        entry.push(Node::barrier());

        // Down-sweep: `d` invocations are active per level (1, 2, 4, ...). Each
        // owns the slot pair (`ai`, `bi`) at distance
        // `offset = WORKGROUP_SIZE / (2 * d)`. The left child receives the parent's
        // prefix, and the right child adds the left child's old value to it.
        for d in Self::tree_strides() {
            let offset = WORKGROUP_SIZE / (2 * d);
            entry.push(Node::if_then(
                Expr::lt(Expr::var("lid"), Expr::u32(d)),
                vec![
                    Node::let_bind(
                        "ai",
                        Expr::sub(
                            Expr::mul(
                                Expr::u32(offset),
                                Expr::add(Expr::mul(Expr::u32(2), Expr::var("lid")), Expr::u32(1)),
                            ),
                            Expr::u32(1),
                        ),
                    ),
                    Node::let_bind(
                        "bi",
                        Expr::sub(
                            Expr::mul(
                                Expr::u32(offset),
                                Expr::add(Expr::mul(Expr::u32(2), Expr::var("lid")), Expr::u32(2)),
                            ),
                            Expr::u32(1),
                        ),
                    ),
                    Node::let_bind("t", Expr::load("shared", Expr::var("ai"))),
                    Node::store("shared", Expr::var("ai"), Expr::load("shared", Expr::var("bi"))),
                    Node::store(
                        "shared",
                        Expr::var("bi"),
                        Expr::add(Expr::load("shared", Expr::var("bi")), Expr::var("t")),
                    ),
                ],
            ));
            entry.push(Node::barrier());
        }

        // shared[lid] now holds the sum of all elements strictly before `lid`.
        entry.push(Node::let_bind("exclusive", Expr::load("shared", Expr::var("lid"))));
        entry.push(Node::if_then(
            Expr::lt(Expr::var("lid"), Expr::var("n")),
            vec![Node::store(
                "out",
                Expr::var("lid"),
                Expr::add(Expr::var("exclusive"), Expr::var("original")),
            )],
        ));
        Program::new(
            vec![
                BufferDecl::read("input", 0, DataType::U32),
                BufferDecl::output("out", 1, DataType::U32),
                BufferDecl::workgroup("shared", WORKGROUP_SIZE, DataType::U32),
            ],
            [WORKGROUP_SIZE, 1, 1],
            entry,
        )
    }

    /// Ascending power-of-two strides `1, 2, 4, ..., WORKGROUP_SIZE / 2`, one per
    /// level of the scan tree (empty when `WORKGROUP_SIZE == 1`).
    fn tree_strides() -> impl Iterator<Item = u32> {
        (0..WORKGROUP_SIZE.trailing_zeros()).map(|level| 1u32 << level)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir;
    use crate::lower::wgsl;

    /// The builder advertises a 1-D workgroup of `WORKGROUP_SIZE` invocations and
    /// emits a non-empty body.
    #[test]
    pub(crate) fn spec_builds_non_empty_program() {
        let prog = PrefixSumInclusiveU32::program();
        assert_eq!(prog.workgroup_size(), [WORKGROUP_SIZE, 1, 1]);
        assert!(!prog.entry().is_empty());
    }

    /// The IR validator reports no errors for the generated program.
    #[test]
    pub(crate) fn program_validates_cleanly() {
        let prog = PrefixSumInclusiveU32::program();
        let errors = ir::validate(&prog);
        assert!(errors.is_empty(), "validation failed: {errors:?}");
    }

    /// Lowering to WGSL succeeds, and the shader declares the shared scratch
    /// buffer and synchronizes with workgroup barriers.
    #[test]
    pub(crate) fn program_lowers_to_wgsl() {
        let prog = PrefixSumInclusiveU32::program();
        let shader = wgsl::lower_anonymous(&prog).expect("WGSL lowering must succeed");
        for needle in ["workgroupBarrier", "var<workgroup> shared"] {
            assert!(shader.contains(needle));
        }
    }
}