use crate::ir::{BinOp, Expr, Ident, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that peels the first iteration off counted loops whose body
/// begins with an `if var == 0 { ... }` guard (see `try_peel` for the exact shape).
///
/// Registered through `vyre_pass` under the name `"loop_peel"`. It declares
/// `const_fold` as a requirement and invalidates no other passes.
/// NOTE(review): presumably `const_fold` is required so that loop bounds arrive as
/// literals, which `try_peel` needs; confirm against the pass scheduler.
#[derive(Debug, Default)]
#[vyre_pass(
name = "loop_peel",
requires = ["const_fold"],
invalidates = []
)]
pub struct LoopPeelPass;
impl LoopPeelPass {
    /// Decides whether the pass has any work to do on `program`.
    ///
    /// Returns [`PassAnalysis::RUN`] as soon as some entry node, or any of its
    /// descendants, is a loop that `is_peelable_loop` accepts. Otherwise returns
    /// [`PassAnalysis::SKIP`].
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let has_candidate = program
            .entry()
            .iter()
            .any(|top| node_map::any_descendant(top, &mut is_peelable_loop));
        if has_candidate {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Peels every matching loop in the entry of `program`.
    ///
    /// Each top-level entry node is handed to `peel_node`. Its (possibly multi-node)
    /// replacement is spliced into the new entry in order. `changed` in the result
    /// reports whether any loop was peeled.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Capture everything except the entry before `program` is consumed below.
        let scaffold = program.with_rewritten_entry(Vec::new());
        let mut changed = false;
        let mut rewritten: Vec<Node> = Vec::new();
        for top in program.into_entry_vec() {
            rewritten.extend(peel_node(top, &mut changed));
        }
        PassResult {
            program: scaffold.with_rewritten_entry(rewritten),
            changed,
        }
    }

    /// Structural fingerprint of `program`; delegates to `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Recursively peels every matching loop in `node`, innermost loops first.
///
/// A loop `for var in 0..n { if var == 0 { A } B... }` accepted by `try_peel` is
/// replaced by two sibling loops:
///
/// ```text
/// for var in 0..1 { { A } B... }   // the peeled first iteration
/// for var in 1..n { B... }         // the guard is statically false here
/// ```
///
/// The peeled iteration must run `B...` as well as `A`, because the original first
/// iteration executes both. It is kept inside a loop so that `var` is still bound
/// (to `0`) for any use of it inside `A` or `B...`.
///
/// Returns the replacement node(s) for `node` and sets `*changed` when any loop in
/// the subtree was peeled.
fn peel_node(node: Node, changed: &mut bool) -> Vec<Node> {
    let recursed = node_map::map_children(node, &mut |child| {
        let mut peeled = peel_node(child, changed);
        if peeled.len() == 1 {
            // A single replacement is spliced in directly instead of being boxed in a Block.
            peeled.pop().unwrap_or_else(|| Node::Block(Vec::new()))
        } else {
            Node::Block(peeled)
        }
    });
    if let Node::Loop {
        ref var,
        ref from,
        ref to,
        ref body,
    } = recursed
    {
        if let Some((guarded, rest)) = try_peel(var, from, to, body) {
            *changed = true;
            // Iteration 0: the guard's statements (kept in their own Block, as before)
            // followed by the remainder of the body, which the original loop also ran.
            let mut first_body = Vec::with_capacity(rest.len() + 1);
            first_body.push(Node::Block(guarded));
            first_body.extend(rest.iter().cloned());
            // NOTE(review): assumes `to` is an exclusive bound, so `0..1` runs exactly
            // once with `var == 0`. `try_peel` rejecting `to <= 1` points the same way.
            let first_iteration = Node::Loop {
                var: var.clone(),
                from: Expr::u32(0),
                to: Expr::u32(1),
                body: first_body,
            };
            let remaining = Node::Loop {
                var: var.clone(),
                from: Expr::u32(1),
                to: to.clone(),
                body: rest,
            };
            return vec![first_iteration, remaining];
        }
    }
    vec![recursed]
}
/// Recognises the peelable loop shape `for var in 0..n { if var == 0 { A } B... }`.
///
/// Requirements:
/// - `from` is the literal `0` and `to` is a literal `n > 1`;
/// - the first body node is an `If` with an empty `otherwise`;
/// - the `If` condition is `var == 0` or `0 == var`.
///
/// Returns `Some((A, B...))`, i.e. the guarded statements and the statements after the
/// guard, or `None` if the loop does not match. The caller must rebuild a first
/// iteration that runs both `A` and `B...`.
///
/// Peeling is also refused when any statement of the body (guard included) assigns
/// to the loop variable. Such a write could make the guard fire again on a later
/// iteration or shift the iteration sequence. Neither survives splitting the loop
/// into a peeled iteration plus a `1..n` remainder.
fn try_peel(var: &Ident, from: &Expr, to: &Expr, body: &[Node]) -> Option<(Vec<Node>, Vec<Node>)> {
    let Expr::LitU32(0) = from else { return None };
    let Expr::LitU32(n) = to else { return None };
    // Tiny trip counts leave no steady-state loop worth peeling for.
    if *n <= 1 {
        return None;
    }
    let (first, rest) = body.split_first()?;
    let Node::If {
        cond,
        then,
        otherwise,
    } = first
    else {
        return None;
    };
    if !otherwise.is_empty() {
        return None;
    }
    let Expr::BinOp {
        op: BinOp::Eq,
        left,
        right,
    } = cond
    else {
        return None;
    };
    let guards_first_iteration = match (left.as_ref(), right.as_ref()) {
        (Expr::Var(name), Expr::LitU32(0)) | (Expr::LitU32(0), Expr::Var(name)) => name == var,
        _ => false,
    };
    if !guards_first_iteration || assigns_to_name(body, var) {
        return None;
    }
    Some((then.clone(), rest.to_vec()))
}
/// Reports whether any node in `nodes`, searched recursively, is an `Assign` whose
/// target is `name`.
///
/// Descends into both branches of `If` and into the bodies of `Loop`, `Block` and
/// `Region`. Every other node kind counts as not assigning. Stops at the first hit.
fn assigns_to_name(nodes: &[Node], name: &Ident) -> bool {
    nodes.iter().any(|node| match node {
        Node::Assign { name: target, .. } => target == name,
        Node::If {
            then, otherwise, ..
        } => assigns_to_name(then, name) || assigns_to_name(otherwise, name),
        Node::Loop { body, .. } | Node::Block(body) => assigns_to_name(body, name),
        Node::Region { body, .. } => assigns_to_name(body, name),
        _ => false,
    })
}
/// Predicate used by `LoopPeelPass::analyze`. It is true only for a `Loop` node that
/// `try_peel` accepts.
fn is_peelable_loop(node: &Node) -> bool {
    match node {
        Node::Loop {
            var,
            from,
            to,
            body,
        } => try_peel(var, from, to, body).is_some(),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Ident, Node};

    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    fn program_with_entry(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// `if i == 0 { buf[0] = 99 }`, the canonical first-iteration guard.
    fn first_iteration_guard() -> Node {
        Node::If {
            cond: Expr::eq(Expr::var("i"), Expr::u32(0)),
            then: vec![Node::store("buf", Expr::u32(0), Expr::u32(99))],
            otherwise: vec![],
        }
    }

    fn count_loops(node: &Node) -> usize {
        match node {
            Node::Loop { body, .. } => 1 + body.iter().map(count_loops).sum::<usize>(),
            Node::If {
                then, otherwise, ..
            } => {
                then.iter().map(count_loops).sum::<usize>()
                    + otherwise.iter().map(count_loops).sum::<usize>()
            }
            Node::Block(body) => body.iter().map(count_loops).sum(),
            Node::Region { body, .. } => body.iter().map(count_loops).sum(),
            _ => 0,
        }
    }

    #[test]
    fn peel_fires_for_guarded_first_iteration() {
        let guard = Node::If {
            cond: Expr::eq(Expr::var("i"), Expr::u32(0)),
            then: vec![Node::store("buf", Expr::u32(0), Expr::u32(99))],
            otherwise: vec![],
        };
        let rest = Node::store("buf", Expr::var("i"), Expr::u32(7));
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(10),
            body: vec![guard, rest],
        }];
        let program = program_with_entry(entry);
        let result = LoopPeelPass::transform(program);
        assert!(result.changed, "peeling must fire");
        let loops: usize = result.program.entry().iter().map(count_loops).sum();
        assert!(loops >= 1, "remaining loop must exist");
    }

    #[test]
    fn peel_skips_when_from_is_not_zero() {
        let guard = Node::If {
            cond: Expr::eq(Expr::var("i"), Expr::u32(0)),
            then: vec![Node::store("buf", Expr::u32(0), Expr::u32(99))],
            otherwise: vec![],
        };
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(1),
            to: Expr::u32(10),
            body: vec![guard],
        }];
        let program = program_with_entry(entry);
        let result = LoopPeelPass::transform(program);
        assert!(!result.changed, "peeling must not fire when from != 0");
    }

    #[test]
    fn peel_skips_when_to_is_not_literal() {
        let guard = Node::If {
            cond: Expr::eq(Expr::var("i"), Expr::u32(0)),
            then: vec![Node::store("buf", Expr::u32(0), Expr::u32(99))],
            otherwise: vec![],
        };
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::var("n"),
            body: vec![guard],
        }];
        let program = program_with_entry(entry);
        let result = LoopPeelPass::transform(program);
        assert!(!result.changed, "peeling must not fire when to is Var");
    }

    #[test]
    fn peel_skips_when_first_node_is_not_matching_if() {
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(10),
            body: vec![Node::store("buf", Expr::var("i"), Expr::u32(7))],
        }];
        let program = program_with_entry(entry);
        let result = LoopPeelPass::transform(program);
        assert!(!result.changed, "peeling must not fire without matching If");
    }

    #[test]
    fn peel_skips_when_peeled_body_assigns_loop_var() {
        let guard = Node::If {
            cond: Expr::eq(Expr::var("i"), Expr::u32(0)),
            then: vec![Node::assign("i", Expr::u32(42))],
            otherwise: vec![],
        };
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(10),
            body: vec![guard],
        }];
        let program = program_with_entry(entry);
        let result = LoopPeelPass::transform(program);
        assert!(
            !result.changed,
            "peeling must not fire when peeled body assigns to loop var"
        );
    }

    #[test]
    fn peel_skips_when_trip_count_is_one() {
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(1),
            body: vec![first_iteration_guard()],
        }];
        let result = LoopPeelPass::transform(program_with_entry(entry));
        assert!(
            !result.changed,
            "peeling must not fire when to <= 1"
        );
    }

    #[test]
    fn peel_output_is_not_peeled_again() {
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(10),
            body: vec![
                first_iteration_guard(),
                Node::store("buf", Expr::var("i"), Expr::u32(7)),
            ],
        }];
        let first = LoopPeelPass::transform(program_with_entry(entry));
        assert!(first.changed, "first run must peel");
        let second = LoopPeelPass::transform(first.program);
        assert!(
            !second.changed,
            "peeled output must not be peeled again"
        );
    }
}