use crate::ir::{Expr, Node, Program};
use crate::optimizer::ctx::AdapterCaps;
use std::borrow::Cow;
use std::sync::Arc;
const WORKGROUP_SUM_PREFIX: &str = "vyre-primitives::reduce::workgroup_sum_";
const WORKGROUP_MAX_PREFIX: &str = "vyre-primitives::reduce::workgroup_max_";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReductionScope {
EveryWorkgroup,
FirstWorkgroup,
}
#[must_use]
pub fn lower_subgroup_reductions(program: Program, caps: &AdapterCaps) -> Program {
if !caps.supports_subgroup_ops || caps.subgroup_size == 0 {
return program;
}
let workgroup_total = program.workgroup_size()[0]
.saturating_mul(program.workgroup_size()[1])
.saturating_mul(program.workgroup_size()[2]);
if workgroup_total > caps.subgroup_size {
return program;
}
match rewrite_nodes(program.entry()) {
Cow::Borrowed(_) => program,
Cow::Owned(entry) => program.with_rewritten_entry(entry),
}
}
fn rewrite_nodes(nodes: &[Node]) -> Cow<'_, [Node]> {
let mut rewritten: Option<Vec<Node>> = None;
for (index, node) in nodes.iter().enumerate() {
match rewrite_node(node) {
Cow::Borrowed(_) if rewritten.is_none() => {}
Cow::Borrowed(borrowed) => {
if let Some(out) = rewritten.as_mut() {
out.extend_from_slice(borrowed);
}
}
Cow::Owned(owned) => {
let out = rewritten.get_or_insert_with(|| nodes[..index].to_vec());
out.extend(owned);
}
}
}
rewritten.map_or(Cow::Borrowed(nodes), Cow::Owned)
}
fn rewrite_node(node: &Node) -> Cow<'_, [Node]> {
match node {
Node::Region {
generator,
source_region,
body,
} => {
let generator_name = generator.as_str();
if let Some(lowered) = try_lower_workgroup_reduction(generator_name, body) {
return Cow::Owned(vec![Node::Region {
generator: generator.clone(),
source_region: source_region.clone(),
body: Arc::new(lowered),
}]);
}
match rewrite_nodes(body) {
Cow::Borrowed(_) => Cow::Borrowed(std::slice::from_ref(node)),
Cow::Owned(new_body) => Cow::Owned(vec![Node::Region {
generator: generator.clone(),
source_region: source_region.clone(),
body: Arc::new(new_body),
}]),
}
}
Node::If {
cond,
then,
otherwise,
} => {
let t = rewrite_nodes(then);
let o = rewrite_nodes(otherwise);
if matches!((&t, &o), (Cow::Borrowed(_), Cow::Borrowed(_))) {
Cow::Borrowed(std::slice::from_ref(node))
} else {
Cow::Owned(vec![Node::if_then_else(
cond.clone(),
t.into_owned(),
o.into_owned(),
)])
}
}
Node::Loop {
var,
from,
to,
body,
} => {
let b = rewrite_nodes(body);
if matches!(b, Cow::Borrowed(_)) {
Cow::Borrowed(std::slice::from_ref(node))
} else {
Cow::Owned(vec![Node::loop_for(
var.clone(),
from.clone(),
to.clone(),
b.into_owned(),
)])
}
}
Node::Block(body) => match rewrite_nodes(body) {
Cow::Borrowed(_) => Cow::Borrowed(std::slice::from_ref(node)),
Cow::Owned(b) => Cow::Owned(vec![Node::block(b)]),
},
_ => Cow::Borrowed(std::slice::from_ref(node)),
}
}
fn try_lower_workgroup_reduction(generator: &str, body: &[Node]) -> Option<Vec<Node>> {
let scratch = extract_scratch_buffer(body)?;
let scope = detect_scope(body)?;
if generator.starts_with(WORKGROUP_SUM_PREFIX) {
Some(subgroup_sum_body(&scratch, scope))
} else if generator.starts_with(WORKGROUP_MAX_PREFIX) {
None
} else {
None
}
}
fn extract_scratch_buffer(body: &[Node]) -> Option<String> {
for node in body {
if let Node::Store { buffer, .. } = node {
return Some(buffer.as_str().to_owned());
}
if let Node::If { then, .. } = node {
for child in then {
if let Node::Store { buffer, .. } = child {
return Some(buffer.as_str().to_owned());
}
if let Node::If {
then: inner_then, ..
} = child
{
for inner in inner_then {
if let Node::Store { buffer, .. } = inner {
return Some(buffer.as_str().to_owned());
}
}
}
}
}
}
None
}
fn detect_scope(body: &[Node]) -> Option<ReductionScope> {
let first = body.first()?;
let Node::If { cond, .. } = first else {
return Some(ReductionScope::EveryWorkgroup);
};
if contains_workgroup_zero_guard(cond) {
Some(ReductionScope::FirstWorkgroup)
} else {
Some(ReductionScope::EveryWorkgroup)
}
}
fn contains_workgroup_zero_guard(expr: &Expr) -> bool {
match expr {
Expr::BinOp {
op: crate::ir::BinOp::And,
left,
right,
} => contains_workgroup_zero_guard(left) || contains_workgroup_zero_guard(right),
Expr::BinOp {
op: crate::ir::BinOp::Eq,
left,
right,
} => {
matches!(left.as_ref(), Expr::WorkgroupId { axis: 0 })
&& matches!(right.as_ref(), Expr::LitU32(0))
|| matches!(right.as_ref(), Expr::WorkgroupId { axis: 0 })
&& matches!(left.as_ref(), Expr::LitU32(0))
}
_ => false,
}
}
fn subgroup_sum_body(scratch: &str, scope: ReductionScope) -> Vec<Node> {
let load_expr = Expr::load(scratch, Expr::var("local"));
let subgroup_expr = Expr::subgroup_add(load_expr);
let store_node = Node::store(scratch, Expr::var("local"), subgroup_expr);
match scope {
ReductionScope::EveryWorkgroup => vec![store_node, Node::barrier()],
ReductionScope::FirstWorkgroup => vec![
Node::if_then(
Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
vec![store_node],
),
Node::barrier(),
],
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
fn caps_with_subgroup(size: u32) -> AdapterCaps {
AdapterCaps {
supports_subgroup_ops: true,
subgroup_size: size,
..AdapterCaps::default()
}
}
fn workgroup_sum_region(scratch: &str, scope: ReductionScope) -> Node {
let body = if scope == ReductionScope::FirstWorkgroup {
vec![
Node::if_then(
Expr::and(
Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
Expr::lt(Expr::var("local"), Expr::u32(2)),
),
vec![Node::Store {
buffer: scratch.into(),
index: Expr::var("local"),
value: Expr::add(
Expr::load(scratch, Expr::var("local")),
Expr::load(scratch, Expr::add(Expr::var("local"), Expr::u32(2))),
),
}],
),
Node::barrier(),
Node::if_then(
Expr::and(
Expr::eq(Expr::WorkgroupId { axis: 0 }, Expr::u32(0)),
Expr::lt(Expr::var("local"), Expr::u32(1)),
),
vec![Node::Store {
buffer: scratch.into(),
index: Expr::var("local"),
value: Expr::add(
Expr::load(scratch, Expr::var("local")),
Expr::load(scratch, Expr::add(Expr::var("local"), Expr::u32(1))),
),
}],
),
Node::barrier(),
]
} else {
vec![
Node::if_then(
Expr::lt(Expr::var("local"), Expr::u32(2)),
vec![Node::Store {
buffer: scratch.into(),
index: Expr::var("local"),
value: Expr::add(
Expr::load(scratch, Expr::var("local")),
Expr::load(scratch, Expr::add(Expr::var("local"), Expr::u32(2))),
),
}],
),
Node::barrier(),
Node::if_then(
Expr::lt(Expr::var("local"), Expr::u32(1)),
vec![Node::Store {
buffer: scratch.into(),
index: Expr::var("local"),
value: Expr::add(
Expr::load(scratch, Expr::var("local")),
Expr::load(scratch, Expr::add(Expr::var("local"), Expr::u32(1))),
),
}],
),
Node::barrier(),
]
};
Node::Region {
generator: "vyre-primitives::reduce::workgroup_sum_f32".into(),
source_region: None,
body: Arc::new(body),
}
}
#[test]
fn no_change_when_subgroup_not_supported() {
let region = workgroup_sum_region("scratch", ReductionScope::EveryWorkgroup);
let program = Program::wrapped(
vec![BufferDecl::workgroup("scratch", 4, DataType::F32)],
[4, 1, 1],
vec![region],
);
let caps = AdapterCaps::default();
let lowered = lower_subgroup_reductions(Clone::clone(&program), &caps);
assert_eq!(lowered, program);
}
#[test]
fn no_change_when_workgroup_larger_than_subgroup() {
let region = workgroup_sum_region("scratch", ReductionScope::EveryWorkgroup);
let program = Program::wrapped(
vec![BufferDecl::workgroup("scratch", 64, DataType::F32)],
[64, 1, 1],
vec![region],
);
let caps = caps_with_subgroup(32);
let lowered = lower_subgroup_reductions(Clone::clone(&program), &caps);
assert_eq!(lowered, program);
}
#[test]
fn lowers_every_workgroup_sum_to_subgroup_add() {
let region = workgroup_sum_region("scratch", ReductionScope::EveryWorkgroup);
let program = Program::wrapped(
vec![BufferDecl::workgroup("scratch", 4, DataType::F32)],
[4, 1, 1],
vec![region],
);
let caps = caps_with_subgroup(32);
let lowered = lower_subgroup_reductions(program, &caps);
let entry = lowered.entry();
assert_eq!(entry.len(), 1);
let Node::Region { body, .. } = &entry[0] else {
panic!("expected Region");
};
assert_eq!(body.len(), 2);
assert!(
matches!(&body[0], Node::Store { buffer, index, value } if
buffer.as_str() == "scratch" &&
matches!(index, Expr::Var(v) if v.as_str() == "local") &&
matches!(value, Expr::SubgroupAdd { .. })
),
"expected subgroup_add store, got {:?}",
body[0]
);
assert!(matches!(&body[1], Node::Barrier { .. }));
}
#[test]
fn lowers_first_workgroup_sum_with_guard() {
let region = workgroup_sum_region("scratch", ReductionScope::FirstWorkgroup);
let program = Program::wrapped(
vec![BufferDecl::workgroup("scratch", 4, DataType::F32)],
[4, 1, 1],
vec![region],
);
let caps = caps_with_subgroup(32);
let lowered = lower_subgroup_reductions(program, &caps);
let entry = lowered.entry();
assert_eq!(entry.len(), 1);
let Node::Region { body, .. } = &entry[0] else {
panic!("expected Region");
};
assert_eq!(body.len(), 2);
let Node::If { cond, then, .. } = &body[0] else {
panic!("expected If guard");
};
assert!(
matches!(cond, Expr::BinOp { op: crate::ir::BinOp::Eq, left, right } if
matches!(left.as_ref(), Expr::WorkgroupId { axis: 0 }) &&
matches!(right.as_ref(), Expr::LitU32(0))
)
);
assert_eq!(then.len(), 1);
assert!(matches!(&then[0], Node::Store { buffer, .. } if buffer.as_str() == "scratch"));
assert!(matches!(&body[1], Node::Barrier { .. }));
}
#[test]
fn non_reduction_regions_are_unchanged() {
let region = Node::Region {
generator: "vyre-libs::math::dot".into(),
source_region: None,
body: Arc::new(vec![Node::store("out", Expr::u32(0), Expr::u32(1))]),
};
let program = Program::wrapped(
vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)],
[1, 1, 1],
vec![region],
);
let caps = caps_with_subgroup(32);
let lowered = lower_subgroup_reductions(Clone::clone(&program), &caps);
assert_eq!(lowered, program);
}
#[test]
fn stats_flag_subgroup_ops_after_lowering() {
let region = workgroup_sum_region("scratch", ReductionScope::EveryWorkgroup);
let program = Program::wrapped(
vec![BufferDecl::workgroup("scratch", 4, DataType::F32)],
[4, 1, 1],
vec![region],
);
let caps = caps_with_subgroup(32);
let lowered = lower_subgroup_reductions(program, &caps);
assert!(
lowered.stats().subgroup_ops(),
"lowering must set the subgroup_ops capability bit"
);
}
}