use crate::ir::{Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass, registered under the name `tail_duplication`, that rewrites
/// `If` nodes whose two branches end in the same node.
///
/// Despite the registered name, the transform in this file does not duplicate
/// code: it removes the shared trailing node from both branches and emits a
/// single copy directly after the `If`. Only tails built solely from `Let`
/// nodes (optionally wrapped in `Block`s) are ever moved.
///
/// The pass declares empty `requires` and `invalidates` lists.
#[derive(Debug, Default)]
#[vyre_pass(
name = "tail_duplication",
requires = [],
invalidates = []
)]
pub struct TailDuplicationPass;
impl TailDuplicationPass {
    /// Decides whether running [`Self::transform`] could change `program`.
    ///
    /// Returns `PassAnalysis::RUN` as soon as any entry node, or any node
    /// below it, is an `If` whose branches share a hoistable tail; otherwise
    /// returns `PassAnalysis::SKIP`.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        for root in program.entry().iter() {
            if node_map::any_descendant(root, &mut is_tail_candidate) {
                return PassAnalysis::RUN;
            }
        }
        PassAnalysis::SKIP
    }

    /// Rewrites every entry node with `hoist_tail` and rebuilds the program.
    ///
    /// `changed` in the returned `PassResult` is `true` when at least one
    /// common tail was moved out of an `If`.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Capture everything except the entry body before `program` is
        // consumed; the real entry is swapped back in at the end.
        let scaffold = program.with_rewritten_entry(Vec::new());
        let mut changed = false;
        let mut rewritten: Vec<Node> = Vec::new();
        for node in program.into_entry_vec() {
            // One input node may expand to two (the trimmed `If` plus its tail).
            rewritten.extend(hoist_tail(node, &mut changed));
        }
        PassResult {
            program: scaffold.with_rewritten_entry(rewritten),
            changed,
        }
    }

    /// Returns the fingerprint of `program` as computed by `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `node` bottom-up, pulling a shared trailing node out of `If`
/// branches.
///
/// Children are rewritten first. Because a child slot holds exactly one node,
/// a child that expands into several nodes is wrapped in a `Node::Block`.
///
/// Returns a single-element vector when nothing was hoisted at this level, or
/// `[trimmed_if, tail]` when the branches of an `If` ended in the same
/// hoistable node. `changed` is set to `true` whenever a hoist happens, at
/// this level or below.
fn hoist_tail(node: Node, changed: &mut bool) -> Vec<Node> {
    let rebuilt = node_map::map_children(node, &mut |child| {
        let mut pieces = hoist_tail(child, changed);
        match pieces.len() {
            // Exactly one node: hand it back unwrapped.
            1 => pieces
                .pop()
                .unwrap_or_else(|| Node::Block(Vec::new())),
            // Anything else must be packed into one node to fit the slot.
            _ => Node::Block(pieces),
        }
    });

    match rebuilt {
        Node::If {
            cond,
            then,
            otherwise,
        } => match try_extract_tail(&then, &otherwise) {
            Some((new_then, new_otherwise, tail)) => {
                *changed = true;
                let trimmed = Node::If {
                    cond,
                    then: new_then,
                    otherwise: new_otherwise,
                };
                vec![trimmed, tail]
            }
            None => vec![Node::If {
                cond,
                then,
                otherwise,
            }],
        },
        other => vec![other],
    }
}
fn try_extract_tail(then: &[Node], otherwise: &[Node]) -> Option<(Vec<Node>, Vec<Node>, Node)> {
if then.is_empty() || otherwise.is_empty() {
return None;
}
let then_tail = then.last()?;
let otherwise_tail = otherwise.last()?;
if then_tail != otherwise_tail {
return None;
}
if !node_is_observably_free(then_tail) {
return None;
}
let tail = then_tail.clone();
let new_then = then[..then.len() - 1].to_vec();
let new_otherwise = otherwise[..otherwise.len() - 1].to_vec();
Some((new_then, new_otherwise, tail))
}
/// Reports whether `node` is considered safe to move out of an `If` branch.
///
/// Only `Let` nodes, and `Block`s whose every element is itself accepted by
/// this function, return `true`. Every other node kind is rejected. The match
/// lists each variant explicitly so that adding a variant to `Node` forces a
/// decision here instead of silently falling into a default.
fn node_is_observably_free(node: &Node) -> bool {
    match node {
        Node::Store { .. }
        | Node::Assign { .. }
        | Node::If { .. }
        | Node::Loop { .. }
        | Node::Region { .. }
        | Node::Return
        | Node::Barrier { .. }
        | Node::IndirectDispatch { .. }
        | Node::AsyncLoad { .. }
        | Node::AsyncStore { .. }
        | Node::AsyncWait { .. }
        | Node::Trap { .. }
        | Node::Resume { .. }
        | Node::Opaque(_) => false,
        Node::Block(children) => {
            // An empty block is accepted; one rejected child rejects the block.
            for child in children.iter() {
                if !node_is_observably_free(child) {
                    return false;
                }
            }
            true
        }
        Node::Let { .. } => true,
    }
}
/// Returns `true` when `node` is an `If` from which `try_extract_tail` would
/// hoist a shared tail; any other node kind yields `false`.
fn is_tail_candidate(node: &Node) -> bool {
    match node {
        Node::If {
            then, otherwise, ..
        } => try_extract_tail(then, otherwise).is_some(),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// Storage buffer declaration shared by every fixture program.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Wraps `entry` in a program that declares only `buf`.
    fn program_with_entry(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Entry body consisting of a single `If` on variable `c`.
    fn branch_on_c(then: Vec<Node>, otherwise: Vec<Node>) -> Vec<Node> {
        vec![Node::If {
            cond: Expr::var("c"),
            then,
            otherwise,
        }]
    }

    /// Store of `value` into element 0 of `buf`.
    fn store_at_zero(value: u32) -> Node {
        Node::store("buf", Expr::u32(0), Expr::u32(value))
    }

    /// Runs the transform on `entry` and reports whether it changed anything.
    fn transform_changed(entry: Vec<Node>) -> bool {
        TailDuplicationPass::transform(program_with_entry(entry)).changed
    }

    /// `If` whose branches store different values but end in the same `Let`.
    fn entry_with_common_let_tail() -> Vec<Node> {
        let common = Node::let_bind("x", Expr::u32(42));
        branch_on_c(
            vec![store_at_zero(1), common.clone()],
            vec![store_at_zero(2), common],
        )
    }

    /// `If` whose branches bind different values and then end in `common`.
    fn entry_with_let_prefix_and_tail(common: Node) -> Vec<Node> {
        branch_on_c(
            vec![Node::let_bind("x", Expr::u32(1)), common.clone()],
            vec![Node::let_bind("x", Expr::u32(2)), common],
        )
    }

    #[test]
    fn hoists_common_let_tail() {
        let changed = transform_changed(entry_with_common_let_tail());
        assert!(changed, "common tail must be hoisted");
    }

    #[test]
    fn keeps_when_tails_differ() {
        let entry = branch_on_c(
            vec![Node::let_bind("x", Expr::u32(1))],
            vec![Node::let_bind("x", Expr::u32(2))],
        );
        let changed = transform_changed(entry);
        assert!(!changed, "must not hoist when tails are different");
    }

    #[test]
    fn keeps_when_tail_has_side_effects() {
        let entry = entry_with_let_prefix_and_tail(store_at_zero(7));
        let changed = transform_changed(entry);
        assert!(
            !changed,
            "must not hoist tail with Store (side effects)"
        );
    }

    #[test]
    fn keeps_when_tail_is_loop() {
        let common = Node::Loop {
            var: crate::ir::Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(5),
            body: vec![],
        };
        let changed = transform_changed(entry_with_let_prefix_and_tail(common));
        assert!(!changed, "must not hoist Loop as tail");
    }

    #[test]
    fn analyze_skips_program_with_no_tail_candidates() {
        let program = program_with_entry(vec![store_at_zero(7)]);
        assert_eq!(TailDuplicationPass::analyze(&program), PassAnalysis::SKIP);
    }

    #[test]
    fn analyze_runs_when_tail_candidate_present() {
        let program = program_with_entry(entry_with_common_let_tail());
        assert_eq!(TailDuplicationPass::analyze(&program), PassAnalysis::RUN);
    }
}