use crate::ir::{BinOp, Expr, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Optimizer pass that removes loop bound checks made redundant by the
/// enclosing loop's literal upper bound.
///
/// Inside `Node::Loop { var, to: <u32 literal N>, .. }`, a guard of the form
/// `if var < N { … }` whose else branch is empty is replaced by a
/// `Node::Block` holding its (recursively rewritten) `then` body.
///
/// NOTE(review): this is only sound if `Node::Loop` iterates `var` over the
/// half-open range `[from, to)` and nothing in the loop body reassigns or
/// shadows `var` before the guard; the pass does not check either — confirm
/// against the IR semantics.
///
/// Registered through `#[vyre_pass]` as `loop_redundant_bound_check_elide`
/// with empty `requires` and `invalidates` lists.
#[derive(Debug, Default)]
#[vyre_pass(
name = "loop_redundant_bound_check_elide",
requires = [],
invalidates = []
)]
pub struct LoopRedundantBoundCheckElidePass;
impl LoopRedundantBoundCheckElidePass {
    /// Decides whether this pass has any work to do on `program`.
    ///
    /// Returns [`PassAnalysis::RUN`] when at least one entry node contains a
    /// guard that [`Self::transform`] would elide, and [`PassAnalysis::SKIP`]
    /// otherwise.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let has_work = program.entry().iter().any(node_has_redundant_guard);
        if has_work {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Replaces every redundant loop bound guard in the entry with a plain
    /// `Node::Block` containing the guard's body.
    ///
    /// The returned [`PassResult::changed`] is `true` only if at least one
    /// guard was elided.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // `with_rewritten_entry` only borrows `program`, so build the output
        // shell first; `into_entry_vec` consumes `program` right after.
        let shell = program.with_rewritten_entry(Vec::new());
        let original_entry = program.into_entry_vec();

        let mut changed = false;
        let mut rewritten = Vec::new();
        for node in original_entry {
            rewritten.push(elide_in_node(node, &mut changed));
        }

        PassResult {
            program: shell.with_rewritten_entry(rewritten),
            changed,
        }
    }

    /// Fingerprint of `program`, delegated to the shared [`fingerprint_program`].
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Returns the value of `expr` when it is a `u32` literal (`Expr::LitU32`),
/// and `None` for every other expression shape.
fn lit_u32_value(expr: &Expr) -> Option<u32> {
    if let Expr::LitU32(value) = expr {
        Some(*value)
    } else {
        None
    }
}
/// Recognizes a guard condition of the exact shape `var < <u32 literal>`.
///
/// Returns `Some(n)` when `cond` is `Expr::BinOp { op: BinOp::Lt, left:
/// Expr::Var(var), right: Expr::LitU32(n) }`. Operand order matters: a swapped
/// comparison (`n < var`), any other operator, a different variable name, or a
/// non-literal right-hand side all yield `None`.
fn cond_matches_loop_var_lt_lit(cond: &Expr, var: &str) -> Option<u32> {
    match cond {
        Expr::BinOp {
            op: BinOp::Lt,
            left,
            right,
        } => match left.as_ref() {
            Expr::Var(name) if name.as_str() == var => lit_u32_value(right),
            _ => None,
        },
        _ => None,
    }
}
/// Rewrites each node of `body`, in order, under the same `loop_ctx`, OR-ing
/// any elision into `changed`.
fn elide_in_sequence(
    body: Vec<Node>,
    loop_ctx: Option<(&str, u32)>,
    changed: &mut bool,
) -> Vec<Node> {
    let rewrite_one = |node: Node| elide_in_node_with_ctx(node, loop_ctx, changed);
    body.into_iter().map(rewrite_one).collect()
}
/// Recursively rewrites `node`, turning redundant loop bound guards into
/// `Node::Block`s.
///
/// `loop_ctx` is `Some((loop_var, bound))` while walking the body of a
/// `Node::Loop` whose `to` is the `u32` literal `bound`. `If` and `Block`
/// children inherit the context. A nested `Loop` replaces it, with `None` if
/// that loop's bound is not a literal. A `Region` clears it.
///
/// An `if loop_var < bound { … }` with an empty else branch is then treated as
/// always true. Its `then` body is kept (recursively rewritten) inside a
/// `Node::Block`, and `*changed` is set to `true`.
///
/// NOTE(review): eliding the guard is only sound if `Node::Loop` iterates over
/// the half-open range `[from, to)` and nothing in the body reassigns or
/// shadows the loop variable before the guard; neither is checked here —
/// confirm against the IR semantics.
fn elide_in_node_with_ctx(node: Node, loop_ctx: Option<(&str, u32)>, changed: &mut bool) -> Node {
    match node {
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            let always_true = otherwise.is_empty()
                && matches!(
                    loop_ctx,
                    Some((loop_var, loop_to))
                        if cond_matches_loop_var_lt_lit(&cond, loop_var) == Some(loop_to)
                );
            if always_true {
                *changed = true;
                Node::Block(elide_in_sequence(then, loop_ctx, changed))
            } else {
                Node::If {
                    cond,
                    then: elide_in_sequence(then, loop_ctx, changed),
                    otherwise: elide_in_sequence(otherwise, loop_ctx, changed),
                }
            }
        }
        Node::Loop {
            var,
            from,
            to,
            body,
        } => {
            // Each loop opens its own context, replacing the outer one.
            let inner_ctx = lit_u32_value(&to).map(|bound| (var.as_str(), bound));
            let body = elide_in_sequence(body, inner_ctx, changed);
            Node::Loop {
                var,
                from,
                to,
                body,
            }
        }
        Node::Block(children) => Node::Block(elide_in_sequence(children, loop_ctx, changed)),
        Node::Region {
            generator,
            source_region,
            body,
        } => {
            // Take the body by value when this is the only reference to it;
            // otherwise rewrite a private copy so other holders are untouched.
            let owned: Vec<Node> =
                std::sync::Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
            Node::Region {
                generator,
                source_region,
                body: std::sync::Arc::new(elide_in_sequence(owned, None, changed)),
            }
        }
        untouched => untouched,
    }
}
/// Rewrites a top-level entry node, starting outside of any loop (no known bound).
fn elide_in_node(node: Node, changed: &mut bool) -> Node {
    let outside_any_loop: Option<(&str, u32)> = None;
    elide_in_node_with_ctx(node, outside_any_loop, changed)
}
/// Reports whether a top-level entry node holds any elidable guard, starting
/// outside of any loop (no known bound).
fn node_has_redundant_guard(node: &Node) -> bool {
    let outside_any_loop: Option<(&str, u32)> = None;
    has_redundant_guard_with_ctx(node, outside_any_loop)
}
/// Read-only counterpart of `elide_in_node_with_ctx`. It reports whether
/// `node`, or anything beneath it, holds a guard that the rewrite would elide.
///
/// Context handling mirrors the rewrite:
/// - `If` and `Block` children keep `loop_ctx`.
/// - A `Loop` replaces it, with `None` when the loop's bound is not a `u32` literal.
/// - A `Region` clears it.
/// - Any other node kind reports `false`.
fn has_redundant_guard_with_ctx(node: &Node, loop_ctx: Option<(&str, u32)>) -> bool {
    /// True if any node in `children` holds an elidable guard under `ctx`.
    fn any_child(children: &[Node], ctx: Option<(&str, u32)>) -> bool {
        children
            .iter()
            .any(|child| has_redundant_guard_with_ctx(child, ctx))
    }

    match node {
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            let guard_here = otherwise.is_empty()
                && matches!(
                    loop_ctx,
                    Some((loop_var, loop_to))
                        if cond_matches_loop_var_lt_lit(cond, loop_var) == Some(loop_to)
                );
            guard_here || any_child(then, loop_ctx) || any_child(otherwise, loop_ctx)
        }
        Node::Loop { var, to, body, .. } => {
            let inner_ctx = lit_u32_value(to).map(|bound| (var.as_str(), bound));
            any_child(body, inner_ctx)
        }
        Node::Block(children) => any_child(children, loop_ctx),
        Node::Region { body, .. } => any_child(body, None),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the loop redundant bound-check elision pass: analysis,
    // transform, context scoping, and the guard-condition matcher.
    use super::*;
    use crate::ir::model::expr::Ident;
    use crate::ir::{BufferAccess, BufferDecl, DataType};

    /// The single `u32` read-write storage buffer (`count = 4`) used by every test program.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Builds a test program around `entry` via `Program::wrapped`.
    fn program_with_entry(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Builds `Node::Loop { var, from: 0, to, body }` with literal `u32` bounds.
    fn loop_with_body(var: &str, to: u32, body: Vec<Node>) -> Node {
        Node::Loop {
            var: Ident::from(var),
            from: Expr::u32(0),
            to: Expr::u32(to),
            body,
        }
    }

    /// Test-side oracle counting the guards the pass should elide. It follows
    /// the pass's context rules:
    /// - `If` and `Block` inherit the context.
    /// - `Loop` replaces it.
    /// - `Region` clears it.
    fn count_redundant_if_guards(node: &Node, loop_ctx: Option<(&str, u32)>) -> usize {
        let mut count = 0;
        match node {
            Node::If {
                cond,
                then,
                otherwise,
            } => {
                if otherwise.is_empty() {
                    if let Some((loop_var, loop_to)) = loop_ctx {
                        if let Some(rhs_lit) = cond_matches_loop_var_lt_lit(cond, loop_var) {
                            if rhs_lit == loop_to {
                                count += 1;
                            }
                        }
                    }
                }
                for n in then {
                    count += count_redundant_if_guards(n, loop_ctx);
                }
                for n in otherwise {
                    count += count_redundant_if_guards(n, loop_ctx);
                }
            }
            Node::Loop { var, to, body, .. } => {
                let new_ctx = lit_u32_value(to).map(|to_lit| (var.as_str(), to_lit));
                for n in body {
                    count += count_redundant_if_guards(n, new_ctx);
                }
            }
            Node::Block(body) => {
                for n in body {
                    count += count_redundant_if_guards(n, loop_ctx);
                }
            }
            Node::Region { body, .. } => {
                for n in body.iter() {
                    count += count_redundant_if_guards(n, None);
                }
            }
            _ => {}
        }
        count
    }

    /// Total number of elidable guards still present anywhere in `program`'s entry.
    fn remaining_redundant_guards(program: &Program) -> usize {
        program
            .entry()
            .iter()
            .map(|n| count_redundant_if_guards(n, None))
            .sum()
    }

    #[test]
    fn skip_analysis_on_program_without_loop() {
        let entry = vec![Node::store("buf", Expr::u32(0), Expr::u32(1))];
        let program = program_with_entry(entry);
        assert_eq!(
            LoopRedundantBoundCheckElidePass::analyze(&program),
            PassAnalysis::SKIP
        );
    }

    #[test]
    fn skip_analysis_on_loop_without_redundant_guard() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
        )];
        let program = program_with_entry(entry);
        assert_eq!(
            LoopRedundantBoundCheckElidePass::analyze(&program),
            PassAnalysis::SKIP
        );
    }

    #[test]
    fn run_analysis_on_loop_with_redundant_guard() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(10)),
                vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
            )],
        )];
        let program = program_with_entry(entry);
        assert_eq!(
            LoopRedundantBoundCheckElidePass::analyze(&program),
            PassAnalysis::RUN
        );
    }

    #[test]
    fn transform_elides_simple_redundant_guard() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(10)),
                vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
            )],
        )];
        let program = program_with_entry(entry);
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(result.changed);
        assert_eq!(
            remaining_redundant_guards(&result.program),
            0,
            "no redundant guards must remain after the elision pass"
        );
        // Analysis and transform must agree: nothing left to do afterwards.
        assert_eq!(
            LoopRedundantBoundCheckElidePass::analyze(&result.program),
            PassAnalysis::SKIP
        );
    }

    #[test]
    fn transform_elides_guard_nested_in_block() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::Block(vec![Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(10)),
                vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
            )])],
        )];
        let program = program_with_entry(entry);
        assert_eq!(
            LoopRedundantBoundCheckElidePass::analyze(&program),
            PassAnalysis::RUN
        );
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(result.changed, "a Block must not hide the loop context");
        assert_eq!(remaining_redundant_guards(&result.program), 0);
    }

    #[test]
    fn transform_elides_guard_nested_inside_unrelated_if() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::if_then(
                Expr::lt(Expr::var("j"), Expr::u32(3)),
                vec![Node::if_then(
                    Expr::lt(Expr::var("i"), Expr::u32(10)),
                    vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
                )],
            )],
        )];
        let program = program_with_entry(entry);
        assert_eq!(
            LoopRedundantBoundCheckElidePass::analyze(&program),
            PassAnalysis::RUN
        );
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(
            result.changed,
            "inner guard must be elided even though the outer if is kept"
        );
        assert_eq!(remaining_redundant_guards(&result.program), 0);
    }

    #[test]
    fn transform_does_not_elide_when_lit_does_not_match_loop_to() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(8)),
                vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
            )],
        )];
        let program = program_with_entry(entry);
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(
            !result.changed,
            "non-matching literal must not trigger elision"
        );
    }

    #[test]
    fn transform_does_not_elide_when_var_name_does_not_match() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::if_then(
                Expr::lt(Expr::var("j"), Expr::u32(10)),
                vec![Node::store("buf", Expr::u32(0), Expr::u32(7))],
            )],
        )];
        let program = program_with_entry(entry);
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(
            !result.changed,
            "different var name must not trigger elision"
        );
    }

    #[test]
    fn transform_does_not_elide_when_loop_to_is_not_literal() {
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::var("n"),
            body: vec![Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(10)),
                vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
            )],
        }];
        let program = program_with_entry(entry);
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(
            !result.changed,
            "non-literal loop bound must not trigger elision"
        );
    }

    #[test]
    fn transform_does_not_elide_when_else_branch_is_nonempty() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::if_then_else(
                Expr::lt(Expr::var("i"), Expr::u32(10)),
                vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
                vec![Node::store("buf", Expr::var("i"), Expr::u32(0))],
            )],
        )];
        let program = program_with_entry(entry);
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(
            !result.changed,
            "if-then-else (with else body) must not be elided"
        );
    }

    #[test]
    fn transform_handles_nested_loops_independently() {
        let inner = loop_with_body(
            "j",
            5,
            vec![Node::if_then(
                Expr::lt(Expr::var("j"), Expr::u32(5)),
                vec![Node::store("buf", Expr::var("j"), Expr::u32(7))],
            )],
        );
        let entry = vec![loop_with_body("i", 10, vec![inner])];
        let program = program_with_entry(entry);
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(result.changed, "inner-loop guard must be elided");
        assert_eq!(remaining_redundant_guards(&result.program), 0);
    }

    #[test]
    fn transform_does_not_elide_inner_guard_against_outer_loop_var() {
        let inner = Node::Loop {
            var: Ident::from("j"),
            from: Expr::u32(0),
            to: Expr::u32(5),
            body: vec![Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(10)),
                vec![Node::store("buf", Expr::var("j"), Expr::u32(7))],
            )],
        };
        let entry = vec![loop_with_body("i", 10, vec![inner])];
        let program = program_with_entry(entry);
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(
            !result.changed,
            "inner-loop scope must shadow outer; if-guard against outer var stays"
        );
    }

    #[test]
    fn transform_is_idempotent() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(10)),
                vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
            )],
        )];
        let program = program_with_entry(entry);
        let once = LoopRedundantBoundCheckElidePass::transform(program);
        let twice = LoopRedundantBoundCheckElidePass::transform(once.program.clone());
        assert!(once.changed);
        assert!(!twice.changed, "second run must report no change");
    }

    #[test]
    fn transform_handles_empty_program() {
        let program = Program::wrapped(vec![buf()], [1, 1, 1], vec![]);
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(!result.changed);
    }

    #[test]
    fn transform_does_not_drop_then_body_during_elision() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(10)),
                vec![
                    Node::store("buf", Expr::var("i"), Expr::u32(7)),
                    Node::store("buf", Expr::u32(0), Expr::u32(8)),
                ],
            )],
        )];
        let program = program_with_entry(entry);
        let result = LoopRedundantBoundCheckElidePass::transform(program);
        assert!(result.changed);

        /// Counts `Node::Store`s reachable through Loop/Block/Region/If bodies.
        fn count_stores(node: &Node) -> usize {
            let mut count = 0;
            match node {
                Node::Store { .. } => count += 1,
                Node::Loop { body, .. } => {
                    for n in body {
                        count += count_stores(n);
                    }
                }
                Node::Block(body) => {
                    for n in body {
                        count += count_stores(n);
                    }
                }
                Node::Region { body, .. } => {
                    for n in body.iter() {
                        count += count_stores(n);
                    }
                }
                Node::If {
                    then, otherwise, ..
                } => {
                    for n in then {
                        count += count_stores(n);
                    }
                    for n in otherwise {
                        count += count_stores(n);
                    }
                }
                _ => {}
            }
            count
        }

        let total: usize = result.program.entry().iter().map(count_stores).sum();
        assert_eq!(
            total, 2,
            "both stores from the elided then-body must still be present"
        );
    }

    #[test]
    fn fingerprint_returns_stable_value() {
        let entry = vec![loop_with_body(
            "i",
            10,
            vec![Node::if_then(
                Expr::lt(Expr::var("i"), Expr::u32(10)),
                vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
            )],
        )];
        let program = program_with_entry(entry);
        let fp1 = LoopRedundantBoundCheckElidePass::fingerprint(&program);
        let fp2 = LoopRedundantBoundCheckElidePass::fingerprint(&program);
        assert_eq!(fp1, fp2);
    }

    #[test]
    fn cond_matches_loop_var_lt_lit_extracts_correct_literal() {
        let cond = Expr::lt(Expr::var("i"), Expr::u32(42));
        assert_eq!(cond_matches_loop_var_lt_lit(&cond, "i"), Some(42));
    }

    #[test]
    fn cond_matches_loop_var_lt_lit_rejects_wrong_var_name() {
        let cond = Expr::lt(Expr::var("j"), Expr::u32(42));
        assert_eq!(cond_matches_loop_var_lt_lit(&cond, "i"), None);
    }

    #[test]
    fn cond_matches_loop_var_lt_lit_rejects_non_lt_op() {
        let cond = Expr::eq(Expr::var("i"), Expr::u32(42));
        assert_eq!(cond_matches_loop_var_lt_lit(&cond, "i"), None);
    }

    #[test]
    fn cond_matches_loop_var_lt_lit_rejects_swapped_operands() {
        // `42 < i` is a different predicate from `i < 42` and must not match.
        let cond = Expr::lt(Expr::u32(42), Expr::var("i"));
        assert_eq!(cond_matches_loop_var_lt_lit(&cond, "i"), None);
    }

    #[test]
    fn cond_matches_loop_var_lt_lit_rejects_non_literal_rhs() {
        let cond = Expr::lt(Expr::var("i"), Expr::var("n"));
        assert_eq!(cond_matches_loop_var_lt_lit(&cond, "i"), None);
    }
}