use crate::ir::{Node, Program};
use crate::memory_model::MemoryOrdering;
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
#[derive(Debug, Default)]
#[vyre_pass(
name = "barrier_coalesce",
requires = [],
invalidates = []
)]
pub struct BarrierCoalescePass;
impl BarrierCoalescePass {
#[must_use]
pub fn analyze(program: &Program) -> PassAnalysis {
if program.entry().iter().any(has_consecutive_barriers) {
PassAnalysis::RUN
} else {
PassAnalysis::SKIP
}
}
#[must_use]
pub fn transform(program: Program) -> PassResult {
let scaffold = program.with_rewritten_entry(Vec::new());
let mut changed = false;
let entry = coalesce_nodes(program.into_entry_vec(), &mut changed);
PassResult {
program: scaffold.with_rewritten_entry(entry),
changed,
}
}
#[must_use]
pub fn fingerprint(program: &Program) -> u64 {
fingerprint_program(program)
}
}
fn coalesce_node(node: Node, changed: &mut bool) -> Node {
match node {
Node::If {
cond,
then,
otherwise,
} => {
let then = coalesce_nodes(then, changed);
let otherwise = coalesce_nodes(otherwise, changed);
Node::If {
cond,
then,
otherwise,
}
}
Node::Loop {
var,
from,
to,
body,
} => {
let body = coalesce_nodes(body, changed);
Node::Loop {
var,
from,
to,
body,
}
}
Node::Block(body) => Node::Block(coalesce_nodes(body, changed)),
Node::Region {
generator,
source_region,
body,
} => {
let body_vec: Vec<Node> = match std::sync::Arc::try_unwrap(body) {
Ok(v) => v,
Err(arc) => (*arc).clone(),
};
let body_vec = coalesce_nodes(body_vec, changed);
Node::Region {
generator,
source_region,
body: std::sync::Arc::new(body_vec),
}
}
other => other,
}
}
fn coalesce_nodes(body: Vec<Node>, changed: &mut bool) -> Vec<Node> {
let mut out = Vec::with_capacity(body.len());
for node in body {
push_coalesced(&mut out, coalesce_node(node, changed), changed);
}
out
}
fn push_coalesced(out: &mut Vec<Node>, node: Node, changed: &mut bool) {
match (out.last(), &node) {
(Some(Node::Barrier { ordering: prev }), Node::Barrier { ordering: curr }) => {
let joined = join_ordering(*prev, *curr);
let new_last = Node::Barrier { ordering: joined };
*out.last_mut().expect("non-empty by match arm") = new_last;
*changed = true;
}
_ => out.push(node),
}
}
fn has_consecutive_barriers(node: &Node) -> bool {
match node {
Node::If {
then, otherwise, ..
} => {
sequence_has_consecutive_barriers(then)
|| sequence_has_consecutive_barriers(otherwise)
|| then.iter().any(has_consecutive_barriers)
|| otherwise.iter().any(has_consecutive_barriers)
}
Node::Loop { body, .. } => {
sequence_has_consecutive_barriers(body) || body.iter().any(has_consecutive_barriers)
}
Node::Block(body) => {
sequence_has_consecutive_barriers(body) || body.iter().any(has_consecutive_barriers)
}
Node::Region { body, .. } => {
sequence_has_consecutive_barriers(body) || body.iter().any(has_consecutive_barriers)
}
_ => false,
}
}
fn sequence_has_consecutive_barriers(body: &[Node]) -> bool {
body.windows(2).any(|pair| {
matches!(
(&pair[0], &pair[1]),
(Node::Barrier { .. }, Node::Barrier { .. })
)
})
}
fn join_ordering(a: MemoryOrdering, b: MemoryOrdering) -> MemoryOrdering {
use MemoryOrdering::{AcqRel, Acquire, GridSync, Relaxed, Release, SeqCst};
match (a, b) {
(GridSync, _) | (_, GridSync) => GridSync,
(SeqCst, _) | (_, SeqCst) => SeqCst,
(AcqRel, _) | (_, AcqRel) => AcqRel,
(Acquire, Release) | (Release, Acquire) => AcqRel,
(Acquire, Acquire) => Acquire,
(Release, Release) => Release,
(Relaxed, x) | (x, Relaxed) => x,
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};
fn buf() -> BufferDecl {
BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
}
fn program_with_entry(entry: Vec<Node>) -> Program {
Program::wrapped(vec![buf()], [1, 1, 1], entry)
}
fn count_barriers(node: &Node) -> usize {
match node {
Node::Barrier { .. } => 1,
Node::If {
then, otherwise, ..
} => {
then.iter().map(count_barriers).sum::<usize>()
+ otherwise.iter().map(count_barriers).sum::<usize>()
}
Node::Loop { body, .. } | Node::Block(body) => body.iter().map(count_barriers).sum(),
Node::Region { body, .. } => body.iter().map(count_barriers).sum(),
_ => 0,
}
}
fn first_barrier_ordering(node: &Node) -> Option<MemoryOrdering> {
match node {
Node::Barrier { ordering } => Some(*ordering),
Node::If {
then, otherwise, ..
} => then
.iter()
.find_map(first_barrier_ordering)
.or_else(|| otherwise.iter().find_map(first_barrier_ordering)),
Node::Loop { body, .. } | Node::Block(body) => {
body.iter().find_map(first_barrier_ordering)
}
Node::Region { body, .. } => body.iter().find_map(first_barrier_ordering),
_ => None,
}
}
#[test]
fn join_workgroup_acqrel_workgroup_acqrel_yields_acqrel() {
assert_eq!(
join_ordering(MemoryOrdering::AcqRel, MemoryOrdering::AcqRel),
MemoryOrdering::AcqRel
);
}
#[test]
fn join_acquire_release_yields_acqrel() {
assert_eq!(
join_ordering(MemoryOrdering::Acquire, MemoryOrdering::Release),
MemoryOrdering::AcqRel,
"acquire ⊔ release must escalate to AcqRel"
);
}
#[test]
fn join_grid_sync_dominates_everything() {
for other in [
MemoryOrdering::Relaxed,
MemoryOrdering::Acquire,
MemoryOrdering::Release,
MemoryOrdering::AcqRel,
MemoryOrdering::SeqCst,
] {
assert_eq!(
join_ordering(MemoryOrdering::GridSync, other),
MemoryOrdering::GridSync,
"GridSync ⊔ {other:?} must stay GridSync — losing GridSync would silently \
downgrade cross-block synchronization"
);
assert_eq!(
join_ordering(other, MemoryOrdering::GridSync),
MemoryOrdering::GridSync
);
}
}
#[test]
fn coalesces_two_seqcst_barriers_into_one() {
let entry = vec![
Node::store("buf", Expr::u32(0), Expr::u32(7)),
Node::Barrier {
ordering: MemoryOrdering::SeqCst,
},
Node::Barrier {
ordering: MemoryOrdering::SeqCst,
},
Node::store("buf", Expr::u32(1), Expr::u32(8)),
];
let program = program_with_entry(entry);
let result = BarrierCoalescePass::transform(program);
assert!(
result.changed,
"consecutive barriers must mark the program as changed"
);
let barrier_count: usize = result.program.entry().iter().map(count_barriers).sum();
assert_eq!(
barrier_count, 1,
"two consecutive SeqCst barriers must coalesce into one; got {barrier_count}"
);
}
#[test]
fn coalesces_three_consecutive_barriers_to_one_with_join() {
let entry = vec![
Node::Barrier {
ordering: MemoryOrdering::Acquire,
},
Node::Barrier {
ordering: MemoryOrdering::Release,
},
Node::Barrier {
ordering: MemoryOrdering::AcqRel,
},
];
let program = program_with_entry(entry);
let result = BarrierCoalescePass::transform(program);
assert!(result.changed);
let total_barriers: usize = result.program.entry().iter().map(count_barriers).sum();
assert_eq!(
total_barriers, 1,
"three consecutive barriers must coalesce to one; got {total_barriers}"
);
let ordering = result
.program
.entry()
.iter()
.find_map(first_barrier_ordering)
.expect("a barrier must exist after coalesce");
assert_eq!(
ordering,
MemoryOrdering::AcqRel,
"Acquire ⊔ Release ⊔ AcqRel must collapse to AcqRel"
);
}
#[test]
fn does_not_coalesce_barriers_separated_by_a_store() {
let entry = vec![
Node::Barrier {
ordering: MemoryOrdering::SeqCst,
},
Node::store("buf", Expr::u32(0), Expr::u32(7)),
Node::Barrier {
ordering: MemoryOrdering::SeqCst,
},
];
let program = program_with_entry(entry);
let result = BarrierCoalescePass::transform(program);
assert!(
!result.changed,
"barriers separated by a store must NOT coalesce"
);
}
#[test]
fn analyze_skips_program_with_no_consecutive_barriers() {
let entry = vec![
Node::Barrier {
ordering: MemoryOrdering::SeqCst,
},
Node::store("buf", Expr::u32(0), Expr::u32(7)),
Node::Barrier {
ordering: MemoryOrdering::SeqCst,
},
];
let program = program_with_entry(entry);
assert_eq!(
BarrierCoalescePass::analyze(&program),
PassAnalysis::SKIP,
"analyze must SKIP when no consecutive barriers exist; otherwise the optimizer \
pays a full transform pass for nothing"
);
}
#[test]
fn coalesces_grid_sync_with_workgroup_to_grid_sync() {
let entry = vec![
Node::Barrier {
ordering: MemoryOrdering::GridSync,
},
Node::Barrier {
ordering: MemoryOrdering::SeqCst,
},
];
let program = program_with_entry(entry);
let result = BarrierCoalescePass::transform(program);
assert!(result.changed);
let ordering = result
.program
.entry()
.iter()
.find_map(first_barrier_ordering)
.expect("a barrier must exist");
assert_eq!(
ordering,
MemoryOrdering::GridSync,
"GridSync ⊔ SeqCst must stay GridSync"
);
}
#[test]
fn coalesces_inside_if_then_branch() {
let entry = vec![Node::if_then(
Expr::bool(true),
vec![
Node::Barrier {
ordering: MemoryOrdering::SeqCst,
},
Node::Barrier {
ordering: MemoryOrdering::SeqCst,
},
],
)];
let program = program_with_entry(entry);
let result = BarrierCoalescePass::transform(program);
assert!(result.changed, "coalesce must recurse into If branches");
}
}