use crate::ir::{Expr, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that removes loops whose trip count is provably zero.
///
/// A loop qualifies when both of its bounds are literals of the same integer
/// kind (`u32` or `i32`) and `from >= to`; such a loop is replaced by an empty
/// `Node::Block`. Loops with any non-literal bound are left untouched.
///
/// Registered under the name `loop_trip_zero_eliminate` and declared to
/// require `const_fold` (presumably so that constant bounds have already been
/// reduced to literals before this pass inspects them — confirm against the
/// pass scheduler).
#[derive(Debug, Default)]
#[vyre_pass(
name = "loop_trip_zero_eliminate",
requires = ["const_fold"],
invalidates = []
)]
pub struct LoopTripZeroEliminatePass;
impl LoopTripZeroEliminatePass {
    /// Decides whether running [`Self::transform`] could change `program`.
    ///
    /// Returns `PassAnalysis::RUN` as soon as `node_map::any_descendant`
    /// reports a zero-trip loop under one of the entry nodes, and
    /// `PassAnalysis::SKIP` when no entry node yields a match.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        for root in program.entry().iter() {
            if node_map::any_descendant(root, &mut is_empty_loop) {
                return PassAnalysis::RUN;
            }
        }
        PassAnalysis::SKIP
    }

    /// Rewrites every entry node, replacing zero-trip loops with empty blocks.
    ///
    /// The returned `PassResult::changed` flag is `true` only if at least one
    /// loop was actually replaced.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // `into_entry_vec` consumes the program, so take an empty-entry copy
        // first; the rewritten entry is attached to that copy at the end.
        let scaffold = program.with_rewritten_entry(Vec::new());

        let mut changed = false;
        let mut rewritten: Vec<Node> = Vec::new();
        for node in program.into_entry_vec() {
            rewritten.push(eliminate_node(node, &mut changed));
        }

        let program = scaffold.with_rewritten_entry(rewritten);
        PassResult { program, changed }
    }

    /// Fingerprint of `program`, delegated to `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `node` bottom-up: children are processed first via
/// `node_map::map_children`, then the node itself is replaced by an empty
/// `Node::Block` if it is a zero-trip loop.
///
/// `changed` is set to `true` whenever a replacement happens; it is never
/// reset to `false` here.
fn eliminate_node(node: Node, changed: &mut bool) -> Node {
    let rewritten = node_map::map_children(node, &mut |child| eliminate_node(child, changed));

    // Anything that is not a provably empty loop passes through unchanged.
    if !is_empty_loop(&rewritten) {
        return rewritten;
    }

    *changed = true;
    Node::Block(Vec::new())
}
/// Returns `true` when `node` is a `Node::Loop` whose bounds are both literals
/// of the same integer kind (`u32`/`u32` or `i32`/`i32`) with `from >= to`.
///
/// Every other node, including loops with mixed-kind or non-literal bounds,
/// yields `false`.
fn is_empty_loop(node: &Node) -> bool {
    match node {
        Node::Loop { from, to, .. } => match (from, to) {
            (Expr::LitU32(start), Expr::LitU32(end)) => *start >= *end,
            (Expr::LitI32(start), Expr::LitI32(end)) => *start >= *end,
            // Bounds that are not a matching literal pair are never treated as empty.
            _ => false,
        },
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Ident, Node};

    /// The single read-write `u32` storage buffer (4 elements) used by every test program.
    fn buf() -> BufferDecl {
        let decl = BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32);
        decl.with_count(4)
    }

    /// Wraps `entry` in a program with one buffer and a `[1, 1, 1]` workgroup size.
    fn program_with_entry(entry: Vec<Node>) -> Program {
        let buffers = vec![buf()];
        Program::wrapped(buffers, [1, 1, 1], entry)
    }

    /// Sums `count_loops` over a sequence of nodes.
    fn count_in<'a>(nodes: impl Iterator<Item = &'a Node>) -> usize {
        nodes.map(count_loops).sum()
    }

    /// Counts the `Node::Loop` nodes in `node`, descending through loop, if,
    /// block and region bodies.
    fn count_loops(node: &Node) -> usize {
        match node {
            Node::Loop { body, .. } => 1 + count_in(body.iter()),
            Node::If {
                then, otherwise, ..
            } => count_in(then.iter()) + count_in(otherwise.iter()),
            Node::Block(body) => count_in(body.iter()),
            Node::Region { body, .. } => count_in(body.iter()),
            _ => 0,
        }
    }

    /// Total number of loops reachable from the entry nodes of `program`.
    fn loops_in(program: &Program) -> usize {
        count_in(program.entry().iter())
    }

    /// Loop body holding one store to `buf`, so the loops under test are not trivially empty.
    fn store_body() -> Vec<Node> {
        vec![Node::store("buf", Expr::u32(0), Expr::u32(7))]
    }

    /// Builds a loop over variable `i` with arbitrary bound expressions.
    fn loop_between(from: Expr, to: Expr, body: Vec<Node>) -> Node {
        Node::Loop {
            var: Ident::from("i"),
            from,
            to,
            body,
        }
    }

    /// Loop with `u32` literal bounds.
    fn make_loop(from: u32, to: u32, body: Vec<Node>) -> Node {
        loop_between(Expr::u32(from), Expr::u32(to), body)
    }

    /// Loop with `i32` literal bounds.
    fn make_loop_i32(from: i32, to: i32, body: Vec<Node>) -> Node {
        loop_between(Expr::i32(from), Expr::i32(to), body)
    }

    /// Runs the pass over a program whose entry is `entry`.
    fn run_pass(entry: Vec<Node>) -> PassResult {
        LoopTripZeroEliminatePass::transform(program_with_entry(entry))
    }

    #[test]
    fn empty_range_loop_dropped() {
        let result = run_pass(vec![make_loop(5, 3, store_body())]);
        assert!(result.changed);
        let total_loops = loops_in(&result.program);
        assert_eq!(
            total_loops, 0,
            "empty-range loop must be dropped; got {total_loops} loops remaining"
        );
    }

    #[test]
    fn equal_bounds_loop_dropped() {
        let result = run_pass(vec![make_loop(5, 5, store_body())]);
        assert!(result.changed);
        let total_loops = loops_in(&result.program);
        assert_eq!(total_loops, 0);
    }

    #[test]
    fn non_empty_range_loop_kept() {
        let result = run_pass(vec![make_loop(0, 10, store_body())]);
        assert!(!result.changed, "non-empty loop must be preserved");
        let total_loops = loops_in(&result.program);
        assert_eq!(total_loops, 1);
    }

    #[test]
    fn non_constant_bounds_loop_kept() {
        let dynamic = loop_between(Expr::var("start"), Expr::var("stop"), store_body());
        let result = run_pass(vec![dynamic]);
        assert!(
            !result.changed,
            "loops with non-literal bounds must be kept because the runtime trip count is unknown"
        );
    }

    #[test]
    fn nested_empty_loop_inside_outer_loop_dropped() {
        let inner_empty = make_loop(0, 0, store_body());
        let outer = make_loop(1, 3, vec![inner_empty]);
        let result = run_pass(vec![outer]);
        assert!(result.changed);
        let total_loops = loops_in(&result.program);
        assert_eq!(
            total_loops, 1,
            "outer non-empty loop kept; inner empty loop dropped; got {total_loops}"
        );
    }

    #[test]
    fn analyze_skips_program_with_no_empty_loops() {
        let program = program_with_entry(vec![make_loop(0, 10, vec![])]);
        let verdict = LoopTripZeroEliminatePass::analyze(&program);
        assert_eq!(
            verdict,
            PassAnalysis::SKIP,
            "analyze must SKIP programs with no compile-time-empty loops"
        );
    }

    #[test]
    fn analyze_runs_for_program_with_one_empty_loop() {
        let program = program_with_entry(vec![make_loop(5, 3, vec![])]);
        let verdict = LoopTripZeroEliminatePass::analyze(&program);
        assert_eq!(
            verdict,
            PassAnalysis::RUN,
            "analyze must RUN when at least one compile-time-empty loop exists"
        );
    }

    #[test]
    fn i32_swapped_bounds_collapses() {
        let result = run_pass(vec![make_loop_i32(5, 3, store_body())]);
        assert!(
            result.changed,
            "i32 swapped bounds must trigger elimination"
        );
        let total_loops = loops_in(&result.program);
        assert_eq!(
            total_loops, 0,
            "i32 swapped-bounds loop must be dropped; got {total_loops} loops"
        );
    }

    #[test]
    fn i32_equal_bounds_collapses() {
        let result = run_pass(vec![make_loop_i32(5, 5, store_body())]);
        assert!(result.changed, "i32 equal bounds must trigger elimination");
        let total_loops = loops_in(&result.program);
        assert_eq!(total_loops, 0);
    }
}