use crate::ir::{BinOp, Expr, Ident, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
use rustc_hash::FxHashSet;
/// Optimizer pass that rebases literal-bounded loops onto a zero lower bound.
///
/// A `Node::Loop` whose `from` and `to` are both `Expr::LitU32` with
/// `0 < from <= to` is rewritten to run from `0` to `to - from` over a freshly
/// named variable, and every read of the original loop variable inside the body
/// becomes `fresh + from`. Loops whose body rebinds the loop variable are left
/// untouched.
#[derive(Debug, Default)]
#[vyre_pass(
name = "loop_lower_bound_normalize",
requires = ["const_fold"],
invalidates = ["loop_unroll", "loop_strip_mine"]
)]
pub struct LoopLowerBoundNormalize;
impl LoopLowerBoundNormalize {
    /// Reports whether `program` contains a loop this pass would rewrite.
    ///
    /// Each entry node is searched with `node_map::any_descendant`, using
    /// `is_normalizable_loop` as the predicate. Returns `PassAnalysis::RUN` as
    /// soon as one entry node reports a match and `PassAnalysis::SKIP` when
    /// none does.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let has_candidate = program
            .entry()
            .iter()
            .any(|root| node_map::any_descendant(root, &mut is_normalizable_loop));

        if has_candidate {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Applies the normalization to every entry node of `program`.
    ///
    /// Returns the rebuilt program together with a `changed` flag that is
    /// `true` when at least one loop was rewritten.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Derive the empty-entry shell first: `into_entry_vec` consumes
        // `program`, so nothing can be read from it afterwards.
        let scaffold = program.with_rewritten_entry(Vec::new());

        let mut changed = false;
        let mut entry: Vec<Node> = Vec::new();
        for node in program.into_entry_vec() {
            entry.push(recurse(node, &mut changed));
        }

        let program = scaffold.with_rewritten_entry(entry);
        PassResult { program, changed }
    }

    /// Fingerprints `program` by delegating to `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `node` and everything below it, normalizing eligible loops.
///
/// Children are handed to `recurse` through `node_map::map_children` before the
/// node itself is inspected. A loop is eligible when both bounds are `u32`
/// literals with `0 < from <= to` and its body does not rebind the loop
/// variable. An eligible loop is replaced by one running from `0` to
/// `to - from` over a fresh variable, with body reads of the old variable
/// rewritten to `fresh + from`. `changed` is set to `true` whenever such a
/// rewrite happens; any other node is returned as produced by `map_children`.
fn recurse(node: Node, changed: &mut bool) -> Node {
    let rewritten = node_map::map_children(node, &mut |child| recurse(child, changed));

    // Decide eligibility by reference so that a non-eligible node can be handed
    // back whole instead of being taken apart and reassembled.
    let bounds = match &rewritten {
        Node::Loop {
            var,
            from: Expr::LitU32(lo),
            to: Expr::LitU32(hi),
            body,
        } if *lo > 0 && *hi >= *lo && !body_rebinds_var(body, var) => Some((*lo, *hi)),
        _ => None,
    };

    match (rewritten, bounds) {
        (Node::Loop { var, body, .. }, Some((lo, hi))) => {
            let fresh = freshen(&var, &body);
            let offset = Expr::u32(lo);

            let mut new_body: Vec<Node> = Vec::with_capacity(body.len());
            for stmt in body {
                new_body.push(substitute_var_in_node(stmt, &var, &fresh, &offset));
            }

            *changed = true;
            Node::Loop {
                var: fresh,
                from: Expr::u32(0),
                // Cannot underflow: eligibility requires `hi >= lo`.
                to: Expr::u32(hi - lo),
                body: new_body,
            }
        }
        (other, _) => other,
    }
}
/// Returns `true` when `node` is a loop that `recurse` would rewrite.
///
/// That requires both bounds to be `Expr::LitU32` literals with
/// `0 < from <= to`, and a body that does not rebind the loop variable.
fn is_normalizable_loop(node: &Node) -> bool {
    match node {
        Node::Loop {
            var,
            from: Expr::LitU32(lo),
            to: Expr::LitU32(hi),
            body,
        } => *lo > 0 && *hi >= *lo && !body_rebinds_var(body, var),
        _ => false,
    }
}
/// Returns `true` when any node in `body` (searched recursively) rebinds `var`.
///
/// A rebinding is a `Let` or `Assign` whose target is `var`, or a nested `Loop`
/// that uses `var` as its own loop variable. The search descends into `Loop`,
/// `If`, `Block` and `Region` bodies; every other node kind counts as not
/// rebinding.
fn body_rebinds_var(body: &[Node], var: &Ident) -> bool {
    body.iter().any(|node| match node {
        Node::Assign { name, .. } | Node::Let { name, .. } => name == var,
        Node::Loop {
            var: inner, body, ..
        } => inner == var || body_rebinds_var(body, var),
        Node::If {
            then, otherwise, ..
        } => body_rebinds_var(then, var) || body_rebinds_var(otherwise, var),
        Node::Block(body) => body_rebinds_var(body, var),
        Node::Region { body, .. } => body_rebinds_var(body, var),
        _ => false,
    })
}
/// Produces an identifier derived from `base` that is not already taken.
///
/// The taken set is `base` itself plus every name `collect_all_names` reports
/// for `body`. Candidates have the form `<base>__norm_<n>` and are tried for
/// `n = 0, 1, 2, …` until one is free.
fn freshen(base: &Ident, body: &[Node]) -> Ident {
    let mut taken: FxHashSet<Ident> = FxHashSet::default();
    taken.insert(base.clone());
    collect_all_names(body, &mut taken);

    let mut counter = 0u32;
    loop {
        let candidate = Ident::from(format!("{}__norm_{counter}", base.as_str()));
        if taken.contains(&candidate) {
            counter += 1;
            continue;
        }
        break candidate;
    }
}
fn collect_all_names(body: &[Node], out: &mut FxHashSet<Ident>) {
for node in body {
match node {
Node::Let { name, .. } | Node::Assign { name, .. } => {
out.insert(name.clone());
}
Node::Loop { var, body, .. } => {
out.insert(var.clone());
collect_all_names(body, out);
}
Node::If {
then, otherwise, ..
} => {
collect_all_names(then, out);
collect_all_names(otherwise, out);
}
Node::Block(body) => collect_all_names(body, out),
Node::Region { body, .. } => collect_all_names(body, out),
_ => {}
}
}
}
/// Replaces every read of variable `from` inside `node` with `to + offset`.
///
/// The rewrite descends into the bodies of `If`, `Loop`, `Block` and `Region`
/// and into the expressions carried by `Let`, `Assign`, `Store`, `If`, `Loop`
/// bounds, `AsyncLoad`, `AsyncStore` and `Trap`. Names in binding position
/// (`Let`/`Assign` targets, `Loop` variables) are never renamed, and node kinds
/// not listed here are returned unchanged.
fn substitute_var_in_node(node: Node, from: &Ident, to: &Ident, offset: &Expr) -> Node {
    // The closures capture only shared references, so they are `Copy` and can
    // be passed to several `map` calls.
    let on_expr = |e: Expr| -> Expr { substitute_var_in_expr(e, from, to, offset) };
    let on_boxed = |e: Box<Expr>| -> Box<Expr> {
        Box::new(substitute_var_in_expr(*e, from, to, offset))
    };
    let on_node = |n: Node| -> Node { substitute_var_in_node(n, from, to, offset) };

    match node {
        Node::Let { name, value } => Node::Let {
            name,
            value: on_expr(value),
        },
        Node::Assign { name, value } => Node::Assign {
            name,
            value: on_expr(value),
        },
        Node::Store {
            buffer,
            index,
            value,
        } => Node::Store {
            buffer,
            index: on_expr(index),
            value: on_expr(value),
        },
        Node::If {
            cond,
            then,
            otherwise,
        } => Node::If {
            cond: on_expr(cond),
            then: then.into_iter().map(on_node).collect(),
            otherwise: otherwise.into_iter().map(on_node).collect(),
        },
        Node::Loop {
            var,
            from: lo,
            to: hi,
            body,
        } => Node::Loop {
            var,
            from: on_expr(lo),
            to: on_expr(hi),
            body: body.into_iter().map(on_node).collect(),
        },
        Node::Block(body) => Node::Block(body.into_iter().map(on_node).collect()),
        Node::Region {
            generator,
            source_region,
            body,
        } => {
            // Take the body out of the `Arc` when this is the only handle;
            // otherwise work on a clone of the shared vector.
            let owned: Vec<Node> =
                std::sync::Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
            Node::Region {
                generator,
                source_region,
                body: std::sync::Arc::new(owned.into_iter().map(on_node).collect()),
            }
        }
        Node::AsyncLoad {
            source,
            destination,
            offset: o,
            size,
            tag,
        } => Node::AsyncLoad {
            source,
            destination,
            tag,
            offset: on_boxed(o),
            size: on_boxed(size),
        },
        Node::AsyncStore {
            source,
            destination,
            offset: o,
            size,
            tag,
        } => Node::AsyncStore {
            source,
            destination,
            tag,
            offset: on_boxed(o),
            size: on_boxed(size),
        },
        Node::Trap { address, tag } => Node::Trap {
            address: on_boxed(address),
            tag,
        },
        other => other,
    }
}
/// Replaces every `Expr::Var(from)` inside `expr` with `to + offset`.
///
/// A matching variable becomes `Expr::BinOp { op: Add, left: Var(to), right:
/// offset }`. Composite expressions are rebuilt with their sub-expressions
/// rewritten the same way; other variables, literals, built-in id queries,
/// `BufLen` and `Opaque` are returned as they are.
fn substitute_var_in_expr(expr: Expr, from: &Ident, to: &Ident, offset: &Expr) -> Expr {
    if matches!(&expr, Expr::Var(name) if name == from) {
        return Expr::BinOp {
            op: BinOp::Add,
            left: Box::new(Expr::Var(to.clone())),
            right: Box::new(offset.clone()),
        };
    }

    let sub = |e: Expr| -> Expr { substitute_var_in_expr(e, from, to, offset) };
    let sub_box = |e: Box<Expr>| -> Box<Expr> {
        Box::new(substitute_var_in_expr(*e, from, to, offset))
    };

    match expr {
        // Leaves: nothing inside to rewrite.
        Expr::Var(_)
        | Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::Opaque(_) => expr,
        Expr::Load { buffer, index } => Expr::Load {
            buffer,
            index: sub_box(index),
        },
        Expr::BinOp { op, left, right } => Expr::BinOp {
            op,
            left: sub_box(left),
            right: sub_box(right),
        },
        Expr::UnOp { op, operand } => Expr::UnOp {
            op,
            operand: sub_box(operand),
        },
        Expr::Call { op_id, args } => Expr::Call {
            op_id,
            args: args.into_iter().map(sub).collect(),
        },
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => Expr::Select {
            cond: sub_box(cond),
            true_val: sub_box(true_val),
            false_val: sub_box(false_val),
        },
        Expr::Cast { target, value } => Expr::Cast {
            target,
            value: sub_box(value),
        },
        Expr::Fma { a, b, c } => Expr::Fma {
            a: sub_box(a),
            b: sub_box(b),
            c: sub_box(c),
        },
        Expr::Atomic {
            op,
            buffer,
            index,
            expected,
            value,
            ordering,
        } => Expr::Atomic {
            op,
            buffer,
            index: sub_box(index),
            expected: expected.map(&sub_box),
            value: sub_box(value),
            ordering,
        },
        Expr::SubgroupBallot { cond } => Expr::SubgroupBallot {
            cond: sub_box(cond),
        },
        Expr::SubgroupShuffle { value, lane } => Expr::SubgroupShuffle {
            value: sub_box(value),
            lane: sub_box(lane),
        },
        Expr::SubgroupAdd { value } => Expr::SubgroupAdd {
            value: sub_box(value),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// The single storage buffer every test program declares.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(16)
    }

    /// Wraps `entry` in a program that declares `buf`.
    fn program(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Runs the pass over a program whose entry is `entry`.
    fn run(entry: Vec<Node>) -> PassResult {
        LoopLowerBoundNormalize::transform(program(entry))
    }

    /// Builds a loop over `var` with the given bounds and body.
    fn counted_loop(var: &'static str, from: Expr, to: Expr, body: Vec<Node>) -> Node {
        Node::Loop {
            var: Ident::from(var),
            from,
            to,
            body,
        }
    }

    /// `buf[var] = 1`, the statement most tests use as a loop body.
    fn store_at(var: &'static str) -> Node {
        Node::store("buf", Expr::var(var), Expr::u32(1))
    }

    /// Depth-first search for the first `Loop` in `nodes`, looking through
    /// `Block`, `Region` and `If` wrappers.
    fn find_loop(nodes: &[Node]) -> Option<&Node> {
        nodes.iter().find_map(|node| match node {
            Node::Loop { .. } => Some(node),
            Node::Block(body) => find_loop(body),
            Node::Region { body, .. } => find_loop(body.as_ref()),
            Node::If {
                then, otherwise, ..
            } => find_loop(then).or_else(|| find_loop(otherwise)),
            _ => None,
        })
    }

    #[test]
    fn rewrites_positive_lower_bound_to_zero() {
        let result = run(vec![counted_loop(
            "i",
            Expr::u32(4),
            Expr::u32(12),
            vec![store_at("i")],
        )]);
        assert!(result.changed, "loop with from=4 must normalize");

        let loop_node = find_loop(result.program.entry()).expect("loop present");
        let (var, from, to, body) = match loop_node {
            Node::Loop {
                var,
                from,
                to,
                body,
            } => (var, from, to, body),
            other => panic!("expected Loop, got {other:?}"),
        };
        assert_ne!(var.as_str(), "i", "var must be freshened");
        assert_eq!(*from, Expr::LitU32(0), "from must be 0");
        assert_eq!(*to, Expr::LitU32(8), "to must be original (12) - lower (4)");

        let index = match &body[0] {
            Node::Store { index, .. } => index,
            other => panic!("expected Store, got {other:?}"),
        };
        match index {
            Expr::BinOp { op, left, right } => {
                assert_eq!(*op, BinOp::Add);
                assert!(
                    matches!(left.as_ref(), Expr::Var(name) if name.as_str() == var.as_str())
                );
                assert_eq!(*right.as_ref(), Expr::LitU32(4));
            }
            other => panic!("expected Var(i') + 4, got {other:?}"),
        }
    }

    #[test]
    fn keeps_loop_with_zero_lower_bound() {
        let result = run(vec![counted_loop(
            "i",
            Expr::u32(0),
            Expr::u32(8),
            vec![store_at("i")],
        )]);
        assert!(!result.changed, "from=0 is already canonical");
    }

    #[test]
    fn keeps_loop_with_runtime_lower_bound() {
        let result = run(vec![counted_loop(
            "i",
            Expr::var("k"),
            Expr::u32(10),
            vec![store_at("i")],
        )]);
        assert!(!result.changed, "runtime from must skip");
    }

    #[test]
    fn keeps_loop_with_runtime_upper_bound() {
        let result = run(vec![counted_loop(
            "i",
            Expr::u32(2),
            Expr::var("n"),
            vec![store_at("i")],
        )]);
        assert!(!result.changed, "runtime to must skip");
    }

    #[test]
    fn keeps_loop_with_inverted_bounds() {
        let result = run(vec![counted_loop(
            "i",
            Expr::u32(10),
            Expr::u32(4),
            vec![store_at("i")],
        )]);
        assert!(
            !result.changed,
            "inverted bounds must be left for trip-zero pass"
        );
    }

    #[test]
    fn keeps_loop_when_body_assigns_loop_var() {
        let overwrite = Node::Assign {
            name: Ident::from("i"),
            value: Expr::u32(99),
        };
        let result = run(vec![counted_loop(
            "i",
            Expr::u32(2),
            Expr::u32(10),
            vec![overwrite, store_at("i")],
        )]);
        assert!(!result.changed, "Assign to loop var must block rewrite");
    }

    #[test]
    fn keeps_loop_when_nested_loop_shadows_var() {
        let shadowing = counted_loop("i", Expr::u32(0), Expr::u32(4), vec![]);
        let result = run(vec![counted_loop(
            "i",
            Expr::u32(2),
            Expr::u32(10),
            vec![shadowing],
        )]);
        assert!(!result.changed, "shadowing nested Loop must block rewrite");
    }

    #[test]
    fn normalizes_nested_loop_independently() {
        let inner = counted_loop("j", Expr::u32(5), Expr::u32(10), vec![store_at("j")]);
        let result = run(vec![counted_loop(
            "i",
            Expr::u32(0),
            Expr::u32(4),
            vec![inner],
        )]);
        assert!(result.changed, "inner loop must normalize");
    }

    #[test]
    fn analyze_skips_program_with_only_canonical_loops() {
        let entry = vec![counted_loop("i", Expr::u32(0), Expr::u32(8), vec![])];
        let verdict = LoopLowerBoundNormalize::analyze(&program(entry));
        match verdict {
            PassAnalysis::SKIP => {}
            other => panic!("expected SKIP, got {other:?}"),
        }
    }
}