use crate::ir::{BinOp, Expr, Ident, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that folds `If` nodes inside counted loops when the
/// condition compares the loop variable against a `u32` literal and has the
/// same outcome on every iteration.
///
/// A loop qualifies when both of its bounds are `u32` literals and its body
/// never assigns, re-declares (`Let`), or reuses as a nested loop variable
/// the loop's own variable. A folded `If` is replaced by the branch it
/// always takes.
#[derive(Debug, Default)]
#[vyre_pass(
name = "loop_var_range_fold",
requires = ["const_fold"],
invalidates = []
)]
pub struct LoopVarRangeFoldPass;
impl LoopVarRangeFoldPass {
    /// Decides whether the pass is worth running on `program`.
    ///
    /// Returns [`PassAnalysis::RUN`] if `has_foldable_if` holds for any node
    /// that `node_map::any_descendant` reaches from the entry nodes, and
    /// [`PassAnalysis::SKIP`] otherwise.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let has_candidate = program
            .entry()
            .iter()
            .any(|node| node_map::any_descendant(node, &mut has_foldable_if));
        if !has_candidate {
            return PassAnalysis::SKIP;
        }
        PassAnalysis::RUN
    }
    /// Folds every loop-range-decidable `If` reachable from the entry.
    ///
    /// Each entry node is rewritten by `recurse` starting with no enclosing
    /// loop range. The returned `changed` flag is `true` iff at least one
    /// `If` was folded.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // `with_rewritten_entry` only borrows `program` here, so the shell is
        // built first; `into_entry_vec` then consumes `program`, and the
        // rewritten nodes are attached to the shell at the end.
        let shell = program.with_rewritten_entry(Vec::new());
        let mut changed = false;
        let mut rewritten = Vec::new();
        for node in program.into_entry_vec() {
            rewritten.push(recurse(node, None, &mut changed));
        }
        PassResult {
            changed,
            program: shell.with_rewritten_entry(rewritten),
        }
    }
    /// Fingerprint of `program`, delegated to `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Values a loop variable can take: `var` ranges over the half-open interval
/// `[lo, hi)` (the last value seen by the body is `hi - 1`).
///
/// Only built for loops whose `from`/`to` are `u32` literals and whose body
/// never rebinds `var`, so comparisons against `var` can be decided
/// statically.
#[derive(Clone, Copy)]
struct LoopRange<'a> {
/// The loop's induction variable.
var: &'a Ident,
/// Inclusive lower bound (the loop's `from` literal).
lo: u32,
/// Exclusive upper bound (the loop's `to` literal).
hi: u32,
}
/// Rewrites `node`, folding every `If` whose condition is decided by the
/// innermost qualifying enclosing loop.
///
/// * `range` — range of the innermost enclosing loop that qualified for
///   folding (literal `u32` bounds, loop variable never rebound in its
///   body), or `None` when no such loop encloses `node`.
/// * `changed` — set to `true` whenever an `If` is folded; never reset.
///
/// A folded `If` becomes its surviving branch: the node itself when the
/// branch holds exactly one node, otherwise a `Node::Block` of the branch
/// (possibly empty). Loop, block and region bodies are flattened, so such
/// blocks are spliced into their parent sequence. `If` branches that are
/// kept are rewritten but not flattened. Other node kinds are returned
/// unchanged.
fn recurse(node: Node, range: Option<LoopRange<'_>>, changed: &mut bool) -> Node {
    match node {
        Node::Loop {
            var,
            from,
            to,
            body,
        } => {
            // A loop with literal bounds whose variable is never rebound
            // supplies its own range to the body; otherwise the enclosing
            // range (if any) stays in force. Borrowing `var` is fine: the
            // borrow ends before `var` is moved back into the rebuilt node.
            let own_range = match (&from, &to) {
                (Expr::LitU32(lo), Expr::LitU32(hi)) if !body_rebinds_var(&body, &var) => {
                    Some(LoopRange {
                        var: &var,
                        lo: *lo,
                        hi: *hi,
                    })
                }
                _ => None,
            };
            let new_body = recurse_flattened(body, own_range.or(range), changed);
            Node::Loop {
                var,
                from,
                to,
                body: new_body,
            }
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            // Only a qualifying enclosing loop can decide the condition.
            let verdict = range.and_then(|active| condition_verdict(&cond, &active));
            if let Some(take_then) = verdict {
                *changed = true;
                let branch = if take_then { then } else { otherwise };
                let mut kept: Vec<Node> = branch
                    .into_iter()
                    .map(|n| recurse(n, range, changed))
                    .collect();
                // A lone survivor replaces the `If` directly; anything else
                // is wrapped in a Block that a parent body can splice in.
                return match kept.len() {
                    1 => kept.swap_remove(0),
                    _ => Node::Block(kept),
                };
            }
            Node::If {
                cond,
                then: then
                    .into_iter()
                    .map(|n| recurse(n, range, changed))
                    .collect(),
                otherwise: otherwise
                    .into_iter()
                    .map(|n| recurse(n, range, changed))
                    .collect(),
            }
        }
        Node::Block(body) => Node::Block(recurse_flattened(body, range, changed)),
        Node::Region {
            generator,
            source_region,
            body,
        } => {
            // Take the body without copying when this is the only handle to
            // it; otherwise fall back to cloning the shared contents.
            let owned = std::sync::Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
            Node::Region {
                generator,
                source_region,
                body: std::sync::Arc::new(recurse_flattened(owned, range, changed)),
            }
        }
        other => other,
    }
}
/// Rewrites each node of a body sequence with [`recurse`] and splices any
/// resulting `Node::Block` into the surrounding sequence.
fn recurse_flattened(
    nodes: Vec<Node>,
    range: Option<LoopRange<'_>>,
    changed: &mut bool,
) -> Vec<Node> {
    nodes
        .into_iter()
        .flat_map(|n| flatten_block(recurse(n, range, changed)))
        .collect()
}
/// Unwraps one level of `Node::Block`: a block yields its children, and any
/// other node yields a one-element vector containing itself.
fn flatten_block(node: Node) -> Vec<Node> {
    if let Node::Block(children) = node {
        children
    } else {
        vec![node]
    }
}
/// Decides whether `cond` has the same truth value on every iteration of the
/// loop described by `range`.
///
/// Only comparisons (`Lt`, `Le`, `Gt`, `Ge`, `Eq`, `Ne`) between the loop
/// variable and a `u32` literal are recognised, with the variable on either
/// side. The loop variable is taken to range over `[range.lo, range.hi)`.
///
/// Returns `Some(verdict)` when the comparison is constant over that range.
/// Returns `None` when it may vary between iterations, when the range is
/// empty, or when `cond` is not a recognised comparison.
fn condition_verdict(cond: &Expr, range: &LoopRange<'_>) -> Option<bool> {
    let Expr::BinOp { op, left, right } = cond else {
        return None;
    };
    // `var_on_left` records operand order so `lit < i` and `i < lit` are
    // evaluated correctly.
    let (lit, var_on_left) = match (left.as_ref(), right.as_ref()) {
        (Expr::Var(name), Expr::LitU32(lit)) if name == range.var => (*lit, true),
        (Expr::LitU32(lit), Expr::Var(name)) if name == range.var => (*lit, false),
        _ => return None,
    };
    // An empty loop never runs its body; also guards `hi - 1` below.
    if range.hi <= range.lo {
        return None;
    }
    let lo = range.lo;
    let max_inclusive = range.hi - 1;
    // Evaluates the comparison for one concrete value of the loop variable.
    let eval_at = |i: u32| -> Option<bool> {
        let (a, b) = if var_on_left { (i, lit) } else { (lit, i) };
        Some(match op {
            BinOp::Lt => a < b,
            BinOp::Le => a <= b,
            BinOp::Gt => a > b,
            BinOp::Ge => a >= b,
            BinOp::Eq => a == b,
            BinOp::Ne => a != b,
            _ => return None,
        })
    };
    match op {
        // Ordering comparisons are monotone in the loop variable, so they are
        // constant over the range exactly when both endpoints agree.
        BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge => {
            let at_lo = eval_at(lo)?;
            (at_lo == eval_at(max_inclusive)?).then_some(at_lo)
        }
        // (In)equality is constant when the literal lies outside the range,
        // or when the range holds a single value.
        BinOp::Eq | BinOp::Ne => {
            let constant = lit < lo || lit > max_inclusive || lo == max_inclusive;
            if constant {
                eval_at(lo)
            } else {
                None
            }
        }
        _ => None,
    }
}
/// Returns `true` if any node in `body`, searched recursively with
/// `node_rebinds_var`, could change the binding of `var`.
fn body_rebinds_var(body: &[Node], var: &Ident) -> bool {
    for node in body {
        if node_rebinds_var(node, var) {
            return true;
        }
    }
    false
}
fn node_rebinds_var(node: &Node, var: &Ident) -> bool {
match node {
Node::Assign { name, .. } => name == var,
Node::Let { name, .. } => name == var,
Node::Loop {
var: inner, body, ..
} => {
if inner == var {
return true;
}
body.iter().any(|n| node_rebinds_var(n, var))
}
Node::If {
then, otherwise, ..
} => {
then.iter().any(|n| node_rebinds_var(n, var))
|| otherwise.iter().any(|n| node_rebinds_var(n, var))
}
Node::Block(body) => body.iter().any(|n| node_rebinds_var(n, var)),
Node::Region { body, .. } => body.iter().any(|n| node_rebinds_var(n, var)),
_ => false,
}
}
/// Node predicate used by `LoopVarRangeFoldPass::analyze`.
///
/// Returns `true` only for a `Loop` with `u32` literal bounds `from < to`
/// whose body never rebinds the loop variable, and which contains an `If`
/// that `body_has_foldable_if` reports as foldable under that loop's range.
fn has_foldable_if(node: &Node) -> bool {
    let Node::Loop {
        var,
        from,
        to,
        body,
    } = node
    else {
        return false;
    };
    let (Expr::LitU32(lo), Expr::LitU32(hi)) = (from, to) else {
        return false;
    };
    if hi <= lo || body_rebinds_var(body, var) {
        return false;
    }
    let loop_range = LoopRange {
        var,
        lo: *lo,
        hi: *hi,
    };
    body.iter()
        .any(|child| body_has_foldable_if(child, &loop_range))
}
fn body_has_foldable_if(node: &Node, range: &LoopRange<'_>) -> bool {
match node {
Node::If { cond, .. } => condition_verdict(cond, range).is_some(),
Node::Block(body) => body.iter().any(|n| body_has_foldable_if(n, range)),
Node::Loop { body, .. } => body.iter().any(|n| body_has_foldable_if(n, range)),
Node::Region { body, .. } => body.iter().any(|n| body_has_foldable_if(n, range)),
_ => false,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Ident, Node};
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(8)
    }
    fn program(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }
    /// Builds a single loop over `i` with literal bounds `lo`/`hi` whose body
    /// is one `If` with the given condition and branches.
    fn loop_with_if(
        cond: Expr,
        then: Vec<Node>,
        otherwise: Vec<Node>,
        lo: u32,
        hi: u32,
    ) -> Vec<Node> {
        vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(lo),
            to: Expr::u32(hi),
            body: vec![Node::If {
                cond,
                then,
                otherwise,
            }],
        }]
    }
    fn store(name: &str, idx: Expr, val: Expr) -> Node {
        Node::store(name, idx, val)
    }
    /// Counts `If` nodes anywhere in `nodes`, including nested bodies.
    fn count_ifs(nodes: &[Node]) -> usize {
        let mut total = 0;
        for n in nodes {
            match n {
                Node::If {
                    then, otherwise, ..
                } => {
                    total += 1;
                    total += count_ifs(then);
                    total += count_ifs(otherwise);
                }
                Node::Loop { body, .. } => total += count_ifs(body),
                Node::Block(body) => total += count_ifs(body),
                Node::Region { body, .. } => total += count_ifs(body),
                _ => {}
            }
        }
        total
    }
    #[test]
    fn folds_lt_when_lit_at_least_hi() {
        let entry = loop_with_if(
            Expr::lt(Expr::var("i"), Expr::u32(8)),
            vec![store("buf", Expr::var("i"), Expr::u32(1))],
            vec![store("buf", Expr::var("i"), Expr::u32(99))],
            0,
            8,
        );
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(result.changed, "Lt(i, hi) is always true");
        assert_eq!(
            count_ifs(result.program.entry()),
            0,
            "If must be folded out"
        );
    }
    #[test]
    fn folds_lt_when_lit_at_most_lo() {
        let entry = loop_with_if(
            Expr::lt(Expr::var("i"), Expr::u32(0)),
            vec![store("buf", Expr::var("i"), Expr::u32(1))],
            vec![store("buf", Expr::var("i"), Expr::u32(99))],
            0,
            8,
        );
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(result.changed, "Lt(i, 0) is always false for i in [0,8)");
        assert_eq!(count_ifs(result.program.entry()), 0);
    }
    #[test]
    fn folds_reversed_lt_when_lit_at_max_inclusive() {
        let entry = loop_with_if(
            Expr::lt(Expr::u32(7), Expr::var("i")),
            vec![store("buf", Expr::var("i"), Expr::u32(1))],
            vec![store("buf", Expr::var("i"), Expr::u32(2))],
            0,
            8,
        );
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(result.changed, "Lt(7, i) is always false for i in [0,8)");
        assert_eq!(count_ifs(result.program.entry()), 0);
    }
    #[test]
    fn keeps_reversed_lt_inside_range() {
        let entry = loop_with_if(
            Expr::lt(Expr::u32(3), Expr::var("i")),
            vec![store("buf", Expr::var("i"), Expr::u32(1))],
            vec![store("buf", Expr::var("i"), Expr::u32(2))],
            0,
            8,
        );
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(!result.changed, "Lt(3, i) varies for i in [0,8)");
        assert_eq!(count_ifs(result.program.entry()), 1);
    }
    #[test]
    fn folds_eq_outside_range() {
        let entry = loop_with_if(
            Expr::eq(Expr::var("i"), Expr::u32(99)),
            vec![store("buf", Expr::var("i"), Expr::u32(1))],
            vec![store("buf", Expr::var("i"), Expr::u32(2))],
            0,
            8,
        );
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(result.changed, "Eq(i, 99) is always false for i in [0,8)");
        assert_eq!(count_ifs(result.program.entry()), 0);
    }
    #[test]
    fn folds_ne_outside_range() {
        let entry = loop_with_if(
            Expr::ne(Expr::var("i"), Expr::u32(99)),
            vec![store("buf", Expr::var("i"), Expr::u32(1))],
            vec![store("buf", Expr::var("i"), Expr::u32(2))],
            0,
            8,
        );
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(result.changed, "Ne(i, 99) is always true for i in [0,8)");
        assert_eq!(count_ifs(result.program.entry()), 0);
    }
    #[test]
    fn keeps_lt_inside_range() {
        let entry = loop_with_if(
            Expr::lt(Expr::var("i"), Expr::u32(4)),
            vec![store("buf", Expr::var("i"), Expr::u32(1))],
            vec![store("buf", Expr::var("i"), Expr::u32(2))],
            0,
            8,
        );
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(!result.changed);
        assert_eq!(count_ifs(result.program.entry()), 1);
    }
    #[test]
    fn keeps_var_lt_var() {
        let entry = loop_with_if(
            Expr::lt(Expr::var("i"), Expr::var("k")),
            vec![store("buf", Expr::var("i"), Expr::u32(1))],
            vec![store("buf", Expr::var("i"), Expr::u32(2))],
            0,
            8,
        );
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(!result.changed);
    }
    #[test]
    fn keeps_empty_loop() {
        let entry = loop_with_if(
            Expr::lt(Expr::var("i"), Expr::u32(99)),
            vec![store("buf", Expr::u32(0), Expr::u32(1))],
            vec![],
            4,
            4,
        );
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(!result.changed, "an empty loop range is never folded");
        assert_eq!(count_ifs(result.program.entry()), 1);
    }
    #[test]
    fn keeps_when_body_assigns_loop_var() {
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(8),
            body: vec![
                Node::Assign {
                    name: Ident::from("i"),
                    value: Expr::u32(99),
                },
                Node::If {
                    cond: Expr::lt(Expr::var("i"), Expr::u32(8)),
                    then: vec![store("buf", Expr::u32(0), Expr::u32(1))],
                    otherwise: vec![],
                },
            ],
        }];
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(!result.changed);
    }
    #[test]
    fn keeps_when_nested_loop_shadows_var() {
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(8),
            body: vec![
                Node::Loop {
                    var: Ident::from("i"),
                    from: Expr::u32(0),
                    to: Expr::u32(4),
                    body: vec![],
                },
                Node::If {
                    cond: Expr::lt(Expr::var("i"), Expr::u32(8)),
                    then: vec![store("buf", Expr::u32(0), Expr::u32(1))],
                    otherwise: vec![],
                },
            ],
        }];
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(
            !result.changed,
            "a nested loop reusing `i` disables folding in the outer body"
        );
        assert_eq!(count_ifs(result.program.entry()), 1);
    }
    #[test]
    fn keeps_runtime_bound_loop() {
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::var("n"),
            body: vec![Node::If {
                cond: Expr::lt(Expr::var("i"), Expr::u32(99)),
                then: vec![store("buf", Expr::u32(0), Expr::u32(1))],
                otherwise: vec![],
            }],
        }];
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(!result.changed);
    }
    #[test]
    fn analyze_skips_program_without_loop() {
        let entry = vec![store("buf", Expr::u32(0), Expr::u32(1))];
        match LoopVarRangeFoldPass::analyze(&program(entry)) {
            PassAnalysis::SKIP => {}
            other => panic!("expected SKIP, got {other:?}"),
        }
    }
    #[test]
    fn analyze_runs_on_foldable_loop() {
        let entry = loop_with_if(
            Expr::lt(Expr::var("i"), Expr::u32(8)),
            vec![store("buf", Expr::var("i"), Expr::u32(1))],
            vec![],
            0,
            8,
        );
        match LoopVarRangeFoldPass::analyze(&program(entry)) {
            PassAnalysis::RUN => {}
            other => panic!("expected RUN, got {other:?}"),
        }
    }
    #[test]
    fn folds_inside_nested_loop() {
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(4),
            body: vec![Node::Loop {
                var: Ident::from("j"),
                from: Expr::u32(0),
                to: Expr::u32(4),
                body: vec![Node::If {
                    cond: Expr::lt(Expr::var("j"), Expr::u32(4)),
                    then: vec![store("buf", Expr::var("j"), Expr::u32(1))],
                    otherwise: vec![],
                }],
            }],
        }];
        let result = LoopVarRangeFoldPass::transform(program(entry));
        assert!(
            result.changed,
            "inner Lt(j, 4) is always true for j in [0,4)"
        );
        assert_eq!(count_ifs(result.program.entry()), 0);
    }
}