use crate::ir::{BinOp, Expr, Ident, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that tightens the literal upper bound of a guarded loop.
///
/// The pass looks for a `Node::Loop` whose bounds are both `u32` literals and
/// whose body is exactly one `Node::If` of the form `loop_var < N` with an
/// empty else arm. When `N` is smaller than the loop's upper bound, the guard
/// is removed, the loop body becomes the guard's `then` arm, and the loop's
/// upper bound is tightened using `N`.
///
/// The pass is registered under the name `"loop_bound_tighten"` and declares
/// no `requires` / `invalidates` entries.
#[derive(Debug, Default)]
#[vyre_pass(
name = "loop_bound_tighten",
requires = [],
invalidates = []
)]
pub struct LoopBoundTighten;
impl LoopBoundTighten {
    /// Decides whether running [`Self::transform`] can change `program`.
    ///
    /// Returns `PassAnalysis::RUN` as soon as `node_map::any_descendant`
    /// reports a tighten-eligible node under one of the entry nodes, and
    /// `PassAnalysis::SKIP` when no entry node yields a match.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        for root in program.entry().iter() {
            if node_map::any_descendant(root, &mut is_tighten_eligible) {
                return PassAnalysis::RUN;
            }
        }
        PassAnalysis::SKIP
    }

    /// Rewrites every entry node of `program` and reports whether any loop
    /// was tightened.
    ///
    /// The returned `PassResult::changed` flag is `true` only if at least one
    /// loop was rewritten.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        let mut changed = false;
        // Build an entry-less copy first: `into_entry_vec` consumes `program`,
        // so this is the last point at which it can still be borrowed.
        let shell = program.with_rewritten_entry(Vec::new());
        let rewritten = program
            .into_entry_vec()
            .into_iter()
            .map(|node| rewrite_node(node, &mut changed))
            .collect::<Vec<Node>>();
        let program = shell.with_rewritten_entry(rewritten);
        PassResult { program, changed }
    }

    /// Returns the fingerprint of `program` as computed by
    /// `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `node` bottom-up: nested nodes are rewritten first, then the node
/// itself is offered to `tighten_if_eligible`.
///
/// `changed` is set to `true` (never reset) when any loop in the subtree is
/// tightened.
fn rewrite_node(node: Node, changed: &mut bool) -> Node {
    // Recurse through both traversal helpers before looking at `node` itself,
    // so inner loops are already in their final form when the outer loop is
    // examined.
    let node = node_map::map_children(node, &mut |kid| rewrite_node(kid, changed));
    let node = node_map::map_body(node, &mut |stmts| {
        stmts
            .into_iter()
            .map(|stmt| rewrite_node(stmt, changed))
            .collect()
    });
    tighten_if_eligible(node, changed)
}
/// Tightens `node` if it is a loop matching the guarded-loop pattern.
///
/// Non-loop nodes, loops that do not match `match_tighten_pattern`, and loops
/// whose guard constant is not strictly below the current upper bound are
/// returned structurally unchanged and leave `changed` untouched.
///
/// For a match, the guard is dropped, the loop body becomes the guard's `then`
/// arm, the upper bound is replaced, and `*changed` is set to `true`.
///
/// The new upper bound is the guard constant clamped into
/// `[lower bound, old upper bound]`. Without the clamp, a loop such as
/// `for i in 10..64 { if i < 8 { .. } }` would be rewritten to the inverted
/// range `10..8`; with it the result is the empty range `10..10`.
/// NOTE(review): this assumes the loop iterates the half-open range
/// `from..to`, under which both forms execute nothing — confirm against the
/// IR's loop semantics.
fn tighten_if_eligible(node: Node, changed: &mut bool) -> Node {
    let Node::Loop {
        var,
        from,
        to,
        body,
    } = node
    else {
        return node_unchanged_helper(node);
    };

    // Keep only matches that actually shrink the iteration space; everything
    // else falls through to the single "rebuild unchanged" exit below.
    let Some((upper_lit, predicate_lit, real_body)) =
        match_tighten_pattern(&var, &from, &to, &body)
            .filter(|(upper, predicate, _)| predicate < upper)
    else {
        return Node::Loop {
            var,
            from,
            to,
            body,
        };
    };

    // `match_tighten_pattern` only succeeds for a literal lower bound, so the
    // fallback arm is not expected to run; `0` makes the clamp a no-op there.
    let lower_lit = match &from {
        Expr::LitU32(lower) => *lower,
        _ => 0,
    };
    // Never emit `to < from`, and never widen beyond the original bound.
    let tightened = predicate_lit.max(lower_lit).min(upper_lit);

    *changed = true;
    Node::Loop {
        var,
        from,
        to: Expr::u32(tightened),
        body: real_body,
    }
}
/// Identity function: returns `node` exactly as it was passed in.
///
/// Called by `tighten_if_eligible` on the path where the node is not a
/// `Node::Loop` and therefore must be left untouched.
fn node_unchanged_helper(node: Node) -> Node {
node
}
/// Borrowing matcher for the guarded-loop pattern.
///
/// Succeeds when all of the following hold:
/// - `from` and `to` are both `Expr::LitU32` (the lower bound's value is not
///   inspected here);
/// - `body` is exactly one `Node::If` whose else arm is empty;
/// - the condition is `BinOp::Lt` with `Expr::Var(loop_var)` on the left and
///   an `Expr::LitU32` on the right.
///
/// Returns `(upper bound, guard constant, guarded then-arm)`. The then-arm is
/// borrowed from `body`, so callers that only need a yes/no answer do not pay
/// for a deep clone of the loop body.
fn match_tighten_pattern_ref<'a>(
    loop_var: &Ident,
    from: &Expr,
    to: &Expr,
    body: &'a [Node],
) -> Option<(u32, u32, &'a [Node])> {
    let (Expr::LitU32(_), Expr::LitU32(upper)) = (from, to) else {
        return None;
    };
    // A one-element slice pattern replaces the `len() != 1` check plus
    // `body[0]` indexing: any sibling statement makes the match fail.
    let [Node::If {
        cond,
        then,
        otherwise,
    }] = body
    else {
        return None;
    };
    if !otherwise.is_empty() {
        return None;
    }
    let Expr::BinOp {
        op: BinOp::Lt,
        left,
        right,
    } = cond
    else {
        return None;
    };
    let Expr::Var(name) = left.as_ref() else {
        return None;
    };
    if name != loop_var {
        return None;
    }
    let Expr::LitU32(n) = right.as_ref() else {
        return None;
    };
    let guarded: &'a [Node] = then;
    Some((*upper, *n, guarded))
}

/// Owning variant of `match_tighten_pattern_ref`.
///
/// Returns `(upper bound, guard constant, clone of the guarded then-arm)`, or
/// `None` when the loop does not match the pattern. Note that a match does not
/// imply the guard constant is below the upper bound; callers compare the two.
fn match_tighten_pattern(
    loop_var: &Ident,
    from: &Expr,
    to: &Expr,
    body: &[Node],
) -> Option<(u32, u32, Vec<Node>)> {
    match_tighten_pattern_ref(loop_var, from, to, body)
        .map(|(upper, n, guarded)| (upper, n, guarded.to_vec()))
}

/// Returns `true` when `node` is a loop matching the guarded-loop pattern
/// whose guard constant is strictly below its literal upper bound.
///
/// Uses the borrowing matcher, so the check allocates nothing.
fn is_tighten_eligible(node: &Node) -> bool {
    let Node::Loop {
        var,
        from,
        to,
        body,
    } = node
    else {
        return false;
    };
    matches!(
        match_tighten_pattern_ref(var, from, to, body),
        Some((upper, n, _)) if n < upper
    )
}
// Unit tests for `LoopBoundTighten`: one positive rewrite case, one negative
// case per clause of the tighten pattern, and two checks of `analyze`.
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};
// Declares the read-write `u32` storage buffer named "buf" that every test
// program stores into.
// NOTE(review): the `0` and `with_count(8)` arguments are presumably the
// binding slot and element count — confirm against `BufferDecl`.
fn buf() -> BufferDecl {
BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(8)
}
// Builds a test program around `entry` with the single buffer from `buf()`.
// NOTE(review): `[1, 1, 1]` is presumably the workgroup size — confirm.
fn program(entry: Vec<Node>) -> Program {
Program::wrapped(vec![buf()], [1, 1, 1], entry)
}
// Returns the literal upper bound of the first `Node::Loop` found while
// scanning `entry` in order, descending only into `Node::Region` bodies.
// Yields `None` if that loop's bound is not a `LitU32` or if no loop is found.
fn loop_with_to(entry: &[Node]) -> Option<u32> {
for n in entry {
match n {
Node::Loop { to, .. } => match to {
Expr::LitU32(v) => return Some(*v),
_ => return None,
},
Node::Region { body, .. } => {
if let Some(v) = loop_with_to(body) {
return Some(v);
}
}
_ => {}
}
}
None
}
// Positive case: `for i in 0..64 { if i < 8 { store } }` is rewritten so the
// loop's upper bound becomes 8.
#[test]
fn tightens_upper_bound_when_inner_predicate_is_smaller() {
let entry = vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(64),
vec![Node::if_then(
Expr::lt(Expr::var("i"), Expr::u32(8)),
vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
)],
)];
let result = LoopBoundTighten::transform(program(entry));
assert!(result.changed);
assert_eq!(
loop_with_to(result.program.entry()),
Some(8),
"loop's upper bound must shrink from 64 to the predicate constant 8"
);
}
// Guard constant equal to the upper bound (64 vs 64): nothing to gain, so the
// pass must report no change.
#[test]
fn does_not_tighten_when_predicate_meets_or_exceeds_upper() {
let entry = vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(64),
vec![Node::if_then(
Expr::lt(Expr::var("i"), Expr::u32(64)),
vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
)],
)];
let result = LoopBoundTighten::transform(program(entry));
assert!(
!result.changed,
"predicate constant equal to upper bound is not a tighten win"
);
}
// The loop body holds a second, unguarded statement next to the `if`, so the
// "body is exactly one If" clause fails and the loop must be left alone.
#[test]
fn does_not_tighten_when_body_has_unguarded_sibling() {
let entry = vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(64),
vec![
Node::if_then(
Expr::lt(Expr::var("i"), Expr::u32(8)),
vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
),
Node::store("buf", Expr::var("i"), Expr::u32(0)),
],
)];
let result = LoopBoundTighten::transform(program(entry));
assert!(!result.changed, "unguarded sibling Store blocks tightening");
}
// The guard has a non-empty else arm, which still runs for `i >= 8`; the loop
// therefore must keep its full range.
#[test]
fn does_not_tighten_when_inner_if_has_otherwise_arm() {
let entry = vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(64),
vec![Node::if_then_else(
Expr::lt(Expr::var("i"), Expr::u32(8)),
vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
vec![Node::store("buf", Expr::var("i"), Expr::u32(0))],
)],
)];
let result = LoopBoundTighten::transform(program(entry));
assert!(
!result.changed,
"else-arm side-effect must keep firing across full range"
);
}
// The guard compares `j`, not the loop variable `i`, so it says nothing about
// the loop's iteration range.
#[test]
fn does_not_tighten_when_predicate_uses_different_var() {
let entry = vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(64),
vec![Node::if_then(
Expr::lt(Expr::var("j"), Expr::u32(8)),
vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
)],
)];
let result = LoopBoundTighten::transform(program(entry));
assert!(
!result.changed,
"predicate on a different variable is not a tightener"
);
}
// The upper bound is `Expr::buf_len("buf")` rather than a `u32` literal, so the
// literal-bounds clause fails.
#[test]
fn does_not_tighten_when_loop_bound_is_runtime() {
let entry = vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::buf_len("buf"),
vec![Node::if_then(
Expr::lt(Expr::var("i"), Expr::u32(8)),
vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
)],
)];
let result = LoopBoundTighten::transform(program(entry));
assert!(
!result.changed,
"runtime upper bound needs range facts (A16) to tighten"
);
}
// `analyze` must return SKIP for a loop whose body is a bare store (no guard).
#[test]
fn analyze_skips_program_with_no_eligible_loop() {
let entry = vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(8),
vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
)];
assert_eq!(
LoopBoundTighten::analyze(&program(entry)),
PassAnalysis::SKIP
);
}
// `analyze` must return RUN for the same shape the positive transform test
// uses (`0..64` guarded by `i < 8`).
#[test]
fn analyze_runs_when_loop_is_tighten_eligible() {
let entry = vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(64),
vec![Node::if_then(
Expr::lt(Expr::var("i"), Expr::u32(8)),
vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
)],
)];
assert_eq!(
LoopBoundTighten::analyze(&program(entry)),
PassAnalysis::RUN
);
}
}