use crate::ir::{Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that deletes `Node::Block` statements whose body is empty.
///
/// Removal is applied bottom-up: children are collapsed before their parent's
/// statement list is filtered. As a result, a block whose only contents were
/// empty blocks becomes empty itself and is then removed as well.
///
/// Declared to `vyre_pass` with empty `requires` and `invalidates` lists.
#[derive(Debug, Default)]
#[vyre_pass(
name = "empty_block_collapse",
requires = [],
invalidates = []
)]
pub struct EmptyBlockCollapsePass;
impl EmptyBlockCollapsePass {
#[must_use]
pub fn analyze(program: &Program) -> PassAnalysis {
if program.entry().iter().any(|n| {
node_map::any_descendant(
n,
&mut |child| matches!(child, Node::Block(b) if b.is_empty()),
)
}) {
PassAnalysis::RUN
} else {
PassAnalysis::SKIP
}
}
#[must_use]
pub fn transform(program: Program) -> PassResult {
let scaffold = program.with_rewritten_entry(Vec::new());
let mut changed = false;
let entry = drop_empty_blocks(
program
.into_entry_vec()
.into_iter()
.map(|n| collapse_node(n, &mut changed))
.collect(),
&mut changed,
);
PassResult {
program: scaffold.with_rewritten_entry(entry),
changed,
}
}
#[must_use]
pub fn fingerprint(program: &Program) -> u64 {
fingerprint_program(program)
}
}
/// Rewrites `node` bottom-up so that none of its statement lists keeps an empty
/// `Node::Block`.
///
/// Every child is collapsed recursively first. Only afterwards are the bodies
/// of `node` itself filtered with `drop_empty_blocks`. This ordering lets a
/// block that became empty during recursion be removed by its parent.
/// `changed` is set to `true` whenever anything is dropped.
fn collapse_node(node: Node, changed: &mut bool) -> Node {
    let with_children_collapsed = {
        let mut recurse = |child: Node| collapse_node(child, changed);
        node_map::map_children(node, &mut recurse)
    };

    let mut prune = |stmts: Vec<Node>| drop_empty_blocks(stmts, changed);
    node_map::map_body(with_children_collapsed, &mut prune)
}
/// Removes every direct `Node::Block` with an empty body from `body`.
///
/// The relative order of the remaining statements is preserved. Only the top
/// level of `body` is inspected; nested statement lists are reached through
/// `collapse_node`. Sets `*changed = true` if at least one element was removed,
/// and leaves it untouched otherwise.
fn drop_empty_blocks(mut body: Vec<Node>, changed: &mut bool) -> Vec<Node> {
    let before = body.len();
    // Filter in place: reuses the incoming allocation instead of building a
    // second `Vec` of the same capacity and moving every survivor into it.
    body.retain(|node| !matches!(node, Node::Block(inner) if inner.is_empty()));
    if body.len() != before {
        *changed = true;
    }
    body
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    fn program_with_entry(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Counts every empty `Node::Block` in the tree rooted at `node`, including
    /// `node` itself.
    ///
    /// Each node contributes at most 1 for itself, plus the counts of its
    /// children. That way an empty block is counted exactly once no matter
    /// which kind of parent (`Block`, `If`, `Loop`, `Region`) holds it.
    fn count_empty_blocks(node: &Node) -> usize {
        let own = usize::from(matches!(node, Node::Block(body) if body.is_empty()));
        let nested: usize = match node {
            Node::Block(body) => body.iter().map(count_empty_blocks).sum(),
            Node::If {
                then, otherwise, ..
            } => then
                .iter()
                .chain(otherwise.iter())
                .map(count_empty_blocks)
                .sum(),
            Node::Loop { body, .. } => body.iter().map(count_empty_blocks).sum(),
            Node::Region { body, .. } => body.iter().map(count_empty_blocks).sum(),
            _ => 0,
        };
        own + nested
    }

    #[test]
    fn count_empty_blocks_counts_each_block_once() {
        assert_eq!(count_empty_blocks(&Node::Block(Vec::new())), 1);
        assert_eq!(
            count_empty_blocks(&Node::Block(vec![Node::Block(Vec::new())])),
            1
        );
        let region = Node::Region {
            body: std::sync::Arc::new(vec![Node::Block(Vec::new())]),
            generator: "test".into(),
            source_region: None,
        };
        assert_eq!(
            count_empty_blocks(&region),
            1,
            "an empty Block directly inside a Region must be counted once"
        );
    }

    #[test]
    fn drops_empty_block_from_top_level_sequence() {
        let entry = vec![
            Node::store("buf", Expr::u32(0), Expr::u32(7)),
            Node::Block(Vec::new()),
            Node::store("buf", Expr::u32(1), Expr::u32(8)),
        ];
        let program = program_with_entry(entry);
        let result = EmptyBlockCollapsePass::transform(program);
        assert!(result.changed);
        let total: usize = result.program.entry().iter().map(count_empty_blocks).sum();
        assert_eq!(
            total, 0,
            "no empty Blocks must remain after collapse; got {total}"
        );
    }

    #[test]
    fn drops_multiple_empty_blocks_in_sequence() {
        let entry = vec![
            Node::Block(Vec::new()),
            Node::Block(Vec::new()),
            Node::Block(Vec::new()),
            Node::store("buf", Expr::u32(0), Expr::u32(7)),
        ];
        let program = program_with_entry(entry);
        let result = EmptyBlockCollapsePass::transform(program);
        assert!(result.changed);
        let total: usize = result.program.entry().iter().map(count_empty_blocks).sum();
        assert_eq!(total, 0);
    }

    #[test]
    fn keeps_non_empty_blocks() {
        let entry = vec![Node::Block(vec![Node::store(
            "buf",
            Expr::u32(0),
            Expr::u32(7),
        )])];
        let program = program_with_entry(entry);
        let result = EmptyBlockCollapsePass::transform(program);
        assert!(
            !result.changed,
            "Block with content must not be touched by empty_block_collapse"
        );
    }

    #[test]
    fn drops_empty_block_inside_if_branch() {
        let entry = vec![Node::if_then(
            Expr::bool(true),
            vec![
                Node::store("buf", Expr::u32(0), Expr::u32(7)),
                Node::Block(Vec::new()),
            ],
        )];
        let program = program_with_entry(entry);
        let result = EmptyBlockCollapsePass::transform(program);
        assert!(result.changed, "must recurse into If branches");
        let total: usize = result.program.entry().iter().map(count_empty_blocks).sum();
        assert_eq!(total, 0);
    }

    #[test]
    fn analyze_skips_program_with_no_empty_blocks() {
        let entry = vec![Node::store("buf", Expr::u32(0), Expr::u32(7))];
        let program = program_with_entry(entry);
        assert_eq!(
            EmptyBlockCollapsePass::analyze(&program),
            PassAnalysis::SKIP
        );
    }

    #[test]
    fn analyze_runs_when_empty_block_present() {
        let entry = vec![Node::Block(Vec::new())];
        let program = program_with_entry(entry);
        assert_eq!(EmptyBlockCollapsePass::analyze(&program), PassAnalysis::RUN);
    }

    #[test]
    fn nested_empty_block_collapses() {
        let entry = vec![Node::Block(vec![Node::Block(Vec::new())])];
        let program = program_with_entry(entry);
        let result = EmptyBlockCollapsePass::transform(program);
        assert!(result.changed, "nested empty Block must trigger collapse");
        let total: usize = result.program.entry().iter().map(count_empty_blocks).sum();
        assert_eq!(
            total, 0,
            "nested empty block must collapse; got {total} empty blocks"
        );
    }

    #[test]
    fn block_with_store_is_preserved() {
        let entry = vec![Node::Block(vec![Node::store(
            "buf",
            Expr::u32(0),
            Expr::u32(7),
        )])];
        let program = program_with_entry(entry);
        let result = EmptyBlockCollapsePass::transform(program);
        assert!(
            !result.changed,
            "Block with Store content must not be removed by empty_block_collapse"
        );
    }

    #[test]
    fn adversarial_empty_region_inside_block() {
        let entry = vec![Node::Block(vec![Node::Region {
            body: std::sync::Arc::new(vec![]),
            generator: "test".into(),
            source_region: None,
        }])];
        let program = program_with_entry(entry);
        let result = EmptyBlockCollapsePass::transform(program);
        assert!(
            !result.changed,
            "empty Region is not an empty Block, should not collapse"
        );
    }

    #[test]
    fn adversarial_three_levels_nested_empty_blocks() {
        let entry = vec![Node::Block(vec![Node::Block(vec![
            Node::Block(Vec::new()),
        ])])];
        let program = program_with_entry(entry);
        let result = EmptyBlockCollapsePass::transform(program);
        assert!(result.changed);
        let total: usize = result.program.entry().iter().map(count_empty_blocks).sum();
        assert_eq!(total, 0, "all 3 levels must collapse bottom-up");
    }

    #[test]
    fn adversarial_empty_block_sibling_alongside_store() {
        let entry = vec![
            Node::Block(Vec::new()),
            Node::store("buf", Expr::u32(0), Expr::u32(7)),
            Node::Block(Vec::new()),
        ];
        let program = program_with_entry(entry);
        let result = EmptyBlockCollapsePass::transform(program);
        assert!(result.changed);
        let total: usize = result.program.entry().iter().map(count_empty_blocks).sum();
        assert_eq!(total, 0);
        assert!(!result.program.entry().is_empty());
    }
}