use crate::ir::{Expr, Ident, Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
use std::sync::Arc;
/// Tile width used by the strip-mining pass in this file: the generated inner
/// "lane" loop runs from `0` to this value, and only loops whose literal trip
/// count is at least twice this value are rewritten.
pub const DEFAULT_STRIP_MINE_TILE: u32 = 8;
/// Optimizer pass that splits a loop with literal bounds into an outer tile
/// loop and an inner lane loop of `DEFAULT_STRIP_MINE_TILE` iterations.
///
/// The `vyre_pass` attribute registers it under the name `loop_strip_mine`,
/// declares `const_fold` as a requirement, and lists `loop_unroll` and
/// `vectorization` as invalidated.
#[derive(Debug, Default)]
#[vyre_pass(
name = "loop_strip_mine",
requires = ["const_fold"],
invalidates = ["loop_unroll", "vectorization"]
)]
pub struct LoopStripMine;
impl LoopStripMine {
    /// Decides whether the pass has any work to do on `program`.
    ///
    /// Returns `PassAnalysis::RUN` as soon as `node_map::any_descendant`
    /// reports, for one of the entry nodes, a node accepted by
    /// `is_strip_mine_eligible`; returns `PassAnalysis::SKIP` otherwise.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        for node in program.entry().iter() {
            if node_map::any_descendant(node, &mut is_strip_mine_eligible) {
                return PassAnalysis::RUN;
            }
        }
        PassAnalysis::SKIP
    }

    /// Runs the rewrite over every entry node of `program`.
    ///
    /// The returned `PassResult` carries the program rebuilt around the
    /// rewritten entry nodes and a `changed` flag that is `true` only when at
    /// least one loop was actually strip-mined.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Derive a program value with an empty entry list first, because
        // `into_entry_vec` consumes `program` below.
        let shell = program.with_rewritten_entry(Vec::new());
        let mut changed = false;
        let mut rewritten = Vec::new();
        for node in program.into_entry_vec() {
            rewritten.push(rewrite_node(node, &mut changed));
        }
        PassResult {
            program: shell.with_rewritten_entry(rewritten),
            changed,
        }
    }

    /// Returns the fingerprint of `program` as computed by
    /// `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `node` bottom-up: nested nodes are processed first, then the node
/// itself is offered to `strip_mine_if_eligible`.
///
/// `changed` is set to `true` by the callee whenever a loop is rewritten.
///
/// NOTE(review): this relies on `node_map::map_children` and
/// `node_map::map_body` covering different parts of the node; confirm against
/// `crate::visit::node_map`, since an overlap would rewrite a child twice.
fn rewrite_node(node: Node, changed: &mut bool) -> Node {
    let mut rewrite_child = |child: Node| rewrite_node(child, changed);
    let with_children = node_map::map_children(node, &mut rewrite_child);
    let with_bodies = node_map::map_body(with_children, &mut |body| {
        body.into_iter()
            .map(|child| rewrite_node(child, changed))
            .collect()
    });
    strip_mine_if_eligible(with_bodies, changed)
}
fn strip_mine_if_eligible(node: Node, changed: &mut bool) -> Node {
let Node::Loop {
var,
from,
to,
body,
} = node
else {
return node;
};
let Some((from_lit, to_lit)) = literal_bounds(&from, &to) else {
return Node::Loop {
var,
from,
to,
body,
};
};
let Some(trip_count) = to_lit.checked_sub(from_lit) else {
return Node::Loop {
var,
from,
to,
body,
};
};
if trip_count < DEFAULT_STRIP_MINE_TILE.saturating_mul(2) || body_writes_loop_var(&body, &var) {
return Node::Loop {
var,
from,
to,
body,
};
}
let names = names_in_nodes(&body);
let outer_var = fresh_ident(&var, "tile", &names);
let lane_var = fresh_ident(&var, "lane", &names);
let tile_count = trip_count.div_ceil(DEFAULT_STRIP_MINE_TILE);
let original_index = Expr::add(
Expr::u32(from_lit),
Expr::add(
Expr::mul(
Expr::var(outer_var.as_str()),
Expr::u32(DEFAULT_STRIP_MINE_TILE),
),
Expr::var(lane_var.as_str()),
),
);
let tiled_body = substitute_nodes(&body, &var, &original_index);
let guarded_body = vec![Node::if_then(
Expr::lt(original_index, Expr::u32(to_lit)),
tiled_body,
)];
*changed = true;
Node::loop_for(
outer_var,
Expr::u32(0),
Expr::u32(tile_count),
vec![Node::loop_for(
lane_var,
Expr::u32(0),
Expr::u32(DEFAULT_STRIP_MINE_TILE),
guarded_body,
)],
)
}
/// Returns `(from, to)` as `u32` values when both bounds are literals accepted
/// by `literal_u32`, and `None` if either one is not.
fn literal_bounds(from: &Expr, to: &Expr) -> Option<(u32, u32)> {
    match (literal_u32(from), literal_u32(to)) {
        (Some(lower), Some(upper)) => Some((lower, upper)),
        _ => None,
    }
}
/// Extracts a `u32` from a literal expression.
///
/// `LitU32` yields its value directly; `LitI32` yields its value only when it
/// converts to `u32` without loss (negative values give `None`). Every other
/// expression gives `None`.
fn literal_u32(expr: &Expr) -> Option<u32> {
    if let Expr::LitU32(unsigned) = expr {
        return Some(*unsigned);
    }
    if let Expr::LitI32(signed) = expr {
        return u32::try_from(*signed).ok();
    }
    None
}
/// Reports whether `node` is a loop this pass should strip-mine.
///
/// The node must be a `Node::Loop` whose bounds are literals accepted by
/// `literal_bounds`, with `to >= from`, a trip count of at least twice
/// `DEFAULT_STRIP_MINE_TILE`, and a body that does not write the loop variable.
///
/// Loops whose tile-padded iteration space would step past `u32::MAX` are
/// rejected as well: a strip-mined nest evaluates its index expression for
/// every lane of the last tile, and an index that does not fit in `u32` could
/// wrap and slip past the tail guard.
fn is_strip_mine_eligible(node: &Node) -> bool {
    let Node::Loop {
        var,
        from,
        to,
        body,
    } = node
    else {
        return false;
    };
    let Some((from_lit, to_lit)) = literal_bounds(from, to) else {
        return false;
    };
    let Some(trip_count) = to_lit.checked_sub(from_lit) else {
        return false;
    };
    if trip_count < DEFAULT_STRIP_MINE_TILE.saturating_mul(2) {
        return false;
    }
    // One past the largest index a tiled nest would compute, in 64-bit so the
    // check itself cannot overflow.
    let tile_count = trip_count.div_ceil(DEFAULT_STRIP_MINE_TILE);
    let padded_end =
        u64::from(from_lit) + u64::from(tile_count) * u64::from(DEFAULT_STRIP_MINE_TILE);
    padded_end <= u64::from(u32::MAX) + 1 && !body_writes_loop_var(body, var)
}
/// Builds an identifier derived from `base` and `suffix` that is absent from
/// `used`.
///
/// The first candidate is `{base}_{suffix}`; if it is taken, `{base}_{suffix}_1`,
/// `{base}_{suffix}_2`, ... are tried in order until one is free.
fn fresh_ident(base: &Ident, suffix: &str, used: &[Ident]) -> Ident {
    let stem = format!("{}_{}", base.as_str(), suffix);
    let mut ordinal = 0;
    loop {
        let candidate = if ordinal == 0 {
            stem.clone()
        } else {
            format!("{stem}_{ordinal}")
        };
        let taken = used.iter().any(|name| name.as_str() == candidate);
        if !taken {
            return Ident::from(candidate.as_str());
        }
        ordinal += 1;
    }
}
/// Gathers the identifiers that `collect_names` finds in `nodes` into a new
/// vector. Duplicates are kept, in the order they were encountered.
fn names_in_nodes(nodes: &[Node]) -> Vec<Ident> {
    let mut collected = Vec::new();
    collect_names(nodes, &mut collected);
    collected
}
/// Appends to `out` every identifier bound or referenced inside `nodes`.
///
/// Recorded names are the targets of `Let`/`Assign`, loop induction variables,
/// and every variable that `collect_names_in_expr` finds in the expressions
/// attached to a node. Nested bodies are walked recursively.
fn collect_names(nodes: &[Node], out: &mut Vec<Ident>) {
    for node in nodes {
        match node {
            // Nothing is recorded for these variants.
            Node::IndirectDispatch { .. }
            | Node::AsyncWait { .. }
            | Node::Resume { .. }
            | Node::Return
            | Node::Barrier { .. }
            | Node::Opaque(_) => {}
            Node::Block(children) => collect_names(children, out),
            Node::Region { body: children, .. } => collect_names(children, out),
            Node::Trap { address, .. } => collect_names_in_expr(address, out),
            Node::AsyncLoad { offset, size, .. } | Node::AsyncStore { offset, size, .. } => {
                collect_names_in_expr(offset, out);
                collect_names_in_expr(size, out);
            }
            Node::Store {
                index: slot,
                value: stored,
                ..
            } => {
                collect_names_in_expr(slot, out);
                collect_names_in_expr(stored, out);
            }
            Node::Let {
                name: target,
                value: rhs,
            }
            | Node::Assign {
                name: target,
                value: rhs,
            } => {
                out.push(target.clone());
                collect_names_in_expr(rhs, out);
            }
            Node::If {
                cond,
                then,
                otherwise,
            } => {
                collect_names_in_expr(cond, out);
                collect_names(then, out);
                collect_names(otherwise, out);
            }
            Node::Loop {
                var: induction,
                from: lower,
                to: upper,
                body: inner_body,
            } => {
                out.push(induction.clone());
                collect_names_in_expr(lower, out);
                collect_names_in_expr(upper, out);
                collect_names(inner_body, out);
            }
        }
    }
}
/// Appends to `out` the name of every `Expr::Var` reachable from `expr`.
///
/// Sub-expressions are visited recursively; literals, built-in ids, buffer
/// lengths and opaque expressions contribute nothing.
fn collect_names_in_expr(expr: &Expr, out: &mut Vec<Ident>) {
    match expr {
        // Leaves without variable references.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::Opaque(_) => {}
        Expr::Var(name) => out.push(name.clone()),
        // Single-operand forms.
        Expr::Load { index, .. } => collect_names_in_expr(index, out),
        Expr::UnOp { operand, .. } => collect_names_in_expr(operand, out),
        Expr::Cast { value, .. } => collect_names_in_expr(value, out),
        Expr::SubgroupBallot { cond } => collect_names_in_expr(cond, out),
        Expr::SubgroupAdd { value } => collect_names_in_expr(value, out),
        // Multi-operand forms, visited in field order.
        Expr::BinOp { left, right, .. } => {
            collect_names_in_expr(left, out);
            collect_names_in_expr(right, out);
        }
        Expr::SubgroupShuffle { value, lane } => {
            collect_names_in_expr(value, out);
            collect_names_in_expr(lane, out);
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            collect_names_in_expr(cond, out);
            collect_names_in_expr(true_val, out);
            collect_names_in_expr(false_val, out);
        }
        Expr::Fma { a, b, c } => {
            collect_names_in_expr(a, out);
            collect_names_in_expr(b, out);
            collect_names_in_expr(c, out);
        }
        Expr::Call { args, .. } => {
            for arg in args {
                collect_names_in_expr(arg, out);
            }
        }
        Expr::Atomic {
            index,
            expected,
            value,
            ..
        } => {
            collect_names_in_expr(index, out);
            // The comparison operand is optional.
            if let Some(compare) = expected {
                collect_names_in_expr(compare, out);
            }
            collect_names_in_expr(value, out);
        }
    }
}
/// Reports whether any `Let` or `Assign` reachable from `nodes` targets `var`.
///
/// The search descends into `If` branches, `Block` and `Region` bodies, and
/// nested loops. A nested loop that reuses `var` as its own induction variable
/// is not searched, so writes inside it do not count.
fn body_writes_loop_var(nodes: &[Node], var: &Ident) -> bool {
    for node in nodes {
        let writes = match node {
            Node::Let { name, .. } | Node::Assign { name, .. } => name == var,
            Node::Block(inner_nodes) => body_writes_loop_var(inner_nodes, var),
            Node::Region {
                body: inner_nodes, ..
            } => body_writes_loop_var(inner_nodes, var),
            Node::Loop {
                var: inner,
                body: inner_nodes,
                ..
            } => inner != var && body_writes_loop_var(inner_nodes, var),
            Node::If {
                then, otherwise, ..
            } => body_writes_loop_var(then, var) || body_writes_loop_var(otherwise, var),
            _ => false,
        };
        if writes {
            return true;
        }
    }
    false
}
/// Applies `substitute_node` to each node of `nodes`, in order, and returns
/// the rewritten nodes as a new vector.
fn substitute_nodes(nodes: &[Node], var: &Ident, replacement: &Expr) -> Vec<Node> {
    let mut rewritten = Vec::with_capacity(nodes.len());
    for node in nodes {
        rewritten.push(substitute_node(node, var, replacement));
    }
    rewritten
}
/// Returns a copy of `node` in which every read of `var` is replaced by
/// `replacement`.
///
/// Expressions attached to the node go through `substitute_expr`; nested
/// bodies go through `substitute_nodes`. For a nested loop that reuses `var`
/// as its induction variable, only the bounds are rewritten and the body is
/// cloned untouched. Variants without expressions or bodies are cloned as is.
fn substitute_node(node: &Node, var: &Ident, replacement: &Expr) -> Node {
    match node {
        // Nothing to substitute in these variants.
        Node::IndirectDispatch { .. }
        | Node::AsyncWait { .. }
        | Node::Resume { .. }
        | Node::Return
        | Node::Barrier { .. }
        | Node::Opaque(_) => node.clone(),
        Node::Block(children) => Node::block(substitute_nodes(children, var, replacement)),
        Node::Region {
            generator,
            source_region,
            body,
        } => {
            let rewritten = substitute_nodes(body, var, replacement);
            Node::Region {
                generator: generator.clone(),
                source_region: source_region.clone(),
                body: Arc::new(rewritten),
            }
        }
        Node::Loop {
            var: inner,
            from,
            to,
            body,
        } => {
            // Bounds are always rewritten, even when the nested loop reuses
            // the name.
            let lower = substitute_expr(from, var, replacement);
            let upper = substitute_expr(to, var, replacement);
            let inner_body = if inner == var {
                // The nested loop rebinds `var`; its body is left untouched.
                body.clone()
            } else {
                substitute_nodes(body, var, replacement)
            };
            Node::loop_for(inner, lower, upper, inner_body)
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            let condition = substitute_expr(cond, var, replacement);
            let then_branch = substitute_nodes(then, var, replacement);
            let else_branch = substitute_nodes(otherwise, var, replacement);
            Node::if_then_else(condition, then_branch, else_branch)
        }
        Node::Let { name, value } => {
            let bound = substitute_expr(value, var, replacement);
            Node::let_bind(name, bound)
        }
        Node::Assign { name, value } => {
            let assigned = substitute_expr(value, var, replacement);
            Node::assign(name, assigned)
        }
        Node::Store {
            buffer,
            index,
            value,
        } => {
            let slot = substitute_expr(index, var, replacement);
            let stored = substitute_expr(value, var, replacement);
            Node::store(buffer, slot, stored)
        }
        Node::Trap { address, tag } => {
            let trap_address = substitute_expr(address, var, replacement);
            Node::Trap {
                address: Box::new(trap_address),
                tag: tag.clone(),
            }
        }
        Node::AsyncLoad {
            source,
            destination,
            offset,
            size,
            tag,
        } => {
            let new_offset = substitute_expr(offset, var, replacement);
            let new_size = substitute_expr(size, var, replacement);
            Node::AsyncLoad {
                source: source.clone(),
                destination: destination.clone(),
                offset: Box::new(new_offset),
                size: Box::new(new_size),
                tag: tag.clone(),
            }
        }
        Node::AsyncStore {
            source,
            destination,
            offset,
            size,
            tag,
        } => {
            let new_offset = substitute_expr(offset, var, replacement);
            let new_size = substitute_expr(size, var, replacement);
            Node::AsyncStore {
                source: source.clone(),
                destination: destination.clone(),
                offset: Box::new(new_offset),
                size: Box::new(new_size),
                tag: tag.clone(),
            }
        }
    }
}
/// Returns a copy of `expr` in which every `Expr::Var` naming `var` is
/// replaced by a clone of `replacement`.
///
/// Compound expressions are rebuilt with their operands rewritten
/// recursively. Anything else, including variables with a different name, is
/// cloned unchanged.
fn substitute_expr(expr: &Expr, var: &Ident, replacement: &Expr) -> Expr {
    match expr {
        Expr::Var(name) if name == var => replacement.clone(),
        // Single-operand forms.
        Expr::UnOp { op, operand } => {
            let rewritten = substitute_expr(operand, var, replacement);
            Expr::UnOp {
                op: op.clone(),
                operand: Box::new(rewritten),
            }
        }
        Expr::Cast { target, value } => {
            let rewritten = substitute_expr(value, var, replacement);
            Expr::cast(target.clone(), rewritten)
        }
        Expr::Load { buffer, index } => {
            let rewritten = substitute_expr(index, var, replacement);
            Expr::load(buffer, rewritten)
        }
        Expr::SubgroupBallot { cond } => {
            let rewritten = substitute_expr(cond, var, replacement);
            Expr::subgroup_ballot(rewritten)
        }
        Expr::SubgroupAdd { value } => {
            let rewritten = substitute_expr(value, var, replacement);
            Expr::subgroup_add(rewritten)
        }
        // Multi-operand forms; operands are rewritten in field order.
        Expr::BinOp { op, left, right } => {
            let lhs = substitute_expr(left, var, replacement);
            let rhs = substitute_expr(right, var, replacement);
            Expr::BinOp {
                op: *op,
                left: Box::new(lhs),
                right: Box::new(rhs),
            }
        }
        Expr::SubgroupShuffle { value, lane } => {
            let shuffled = substitute_expr(value, var, replacement);
            let source_lane = substitute_expr(lane, var, replacement);
            Expr::subgroup_shuffle(shuffled, source_lane)
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            let condition = substitute_expr(cond, var, replacement);
            let when_true = substitute_expr(true_val, var, replacement);
            let when_false = substitute_expr(false_val, var, replacement);
            Expr::select(condition, when_true, when_false)
        }
        Expr::Fma { a, b, c } => {
            let first = substitute_expr(a, var, replacement);
            let second = substitute_expr(b, var, replacement);
            let third = substitute_expr(c, var, replacement);
            Expr::fma(first, second, third)
        }
        Expr::Call { op_id, args } => Expr::call(
            op_id,
            args.iter()
                .map(|arg| substitute_expr(arg, var, replacement))
                .collect(),
        ),
        Expr::Atomic {
            op,
            buffer,
            index,
            expected,
            value,
            ordering,
        } => {
            let slot = substitute_expr(index, var, replacement);
            // The comparison operand is optional.
            let compare = expected
                .as_ref()
                .map(|inner| Box::new(substitute_expr(inner, var, replacement)));
            let operand = substitute_expr(value, var, replacement);
            Expr::Atomic {
                op: *op,
                buffer: buffer.clone(),
                index: Box::new(slot),
                expected: compare,
                value: Box::new(operand),
                ordering: *ordering,
            }
        }
        // Everything else is copied verbatim.
        _ => expr.clone(),
    }
}
// Unit tests for the strip-mining pass. Each test builds a one-buffer program
// around a single entry node and inspects the result of `transform` or
// `analyze`.
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BufferAccess, BufferDecl, DataType};
// Read-write `u32` storage buffer named "out" with 64 elements, shared by all tests.
fn buf() -> BufferDecl {
BufferDecl::storage("out", 0, BufferAccess::ReadWrite, DataType::U32).with_count(64)
}
// Wraps `entry` into a program that declares only the `out` buffer.
fn program(entry: Vec<Node>) -> Program {
Program::wrapped(vec![buf()], [1, 1, 1], entry)
}
// A 0..32 loop becomes an outer `i_tile` loop over 0..4 containing an
// `i_lane` loop over 0..DEFAULT_STRIP_MINE_TILE whose body is an `If` guard.
#[test]
fn strip_mines_large_literal_loop() {
let result = LoopStripMine::transform(program(vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(32),
vec![Node::store(
"out",
Expr::var("i"),
Expr::add(Expr::var("i"), Expr::u32(1)),
)],
)]));
assert!(result.changed);
let entry = crate::test_util::region_body(&result.program);
let Node::Loop {
var: outer,
from,
to,
body,
} = &entry[0]
else {
panic!("expected outer loop");
};
assert_eq!(outer.as_str(), "i_tile");
assert_eq!(from, &Expr::u32(0));
// 32 iterations / tile width 8 = 4 tiles.
assert_eq!(to, &Expr::u32(4));
let Node::Loop {
var: lane,
from: lane_from,
to: lane_to,
body: lane_body,
} = &body[0]
else {
panic!("expected inner lane loop");
};
assert_eq!(lane.as_str(), "i_lane");
assert_eq!(lane_from, &Expr::u32(0));
assert_eq!(lane_to, &Expr::u32(DEFAULT_STRIP_MINE_TILE));
assert!(matches!(&lane_body[0], Node::If { .. }));
}
// With a lower bound of 16, every use of `i` must be replaced by
// `16 + (i_tile * TILE + i_lane)`.
#[test]
fn preserves_non_zero_lower_bound_in_index_expression() {
let result = LoopStripMine::transform(program(vec![Node::loop_for(
"i",
Expr::u32(16),
Expr::u32(48),
vec![Node::store("out", Expr::var("i"), Expr::var("i"))],
)]));
assert!(result.changed);
let entry = crate::test_util::region_body(&result.program);
let Node::Loop { body, .. } = &entry[0] else {
panic!("expected outer loop");
};
let Node::Loop { body: inner, .. } = &body[0] else {
panic!("expected inner loop");
};
let Node::If { then, .. } = &inner[0] else {
panic!("expected tail guard");
};
let Node::Store { index, value, .. } = &then[0] else {
panic!("expected store in guarded body");
};
let expected = Expr::add(
Expr::u32(16),
Expr::add(
Expr::mul(Expr::var("i_tile"), Expr::u32(DEFAULT_STRIP_MINE_TILE)),
Expr::var("i_lane"),
),
);
assert_eq!(index, &expected);
assert_eq!(value, &expected);
}
// A trip count equal to one tile is below the two-tile threshold, so the
// loop is left alone.
#[test]
fn skips_small_loops_that_unroll_directly() {
let result = LoopStripMine::transform(program(vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(DEFAULT_STRIP_MINE_TILE),
vec![Node::store("out", Expr::var("i"), Expr::u32(1))],
)]));
assert!(!result.changed);
}
// An upper bound that is not a literal (`buf_len`) disqualifies the loop.
#[test]
fn skips_runtime_bounds() {
let result = LoopStripMine::transform(program(vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::buf_len("out"),
vec![Node::store("out", Expr::var("i"), Expr::u32(1))],
)]));
assert!(!result.changed);
}
// A body that binds the loop variable itself disqualifies the loop.
#[test]
fn skips_body_that_rebinds_loop_var() {
let result = LoopStripMine::transform(program(vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(32),
vec![Node::let_bind("i", Expr::u32(7))],
)]));
assert!(!result.changed);
}
// When the body already binds `i_tile` and `i_lane`, the generated loop
// variables fall back to the `_1` suffixed names.
#[test]
fn freshens_generated_names_when_body_already_uses_defaults() {
let result = LoopStripMine::transform(program(vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(32),
vec![
Node::let_bind("i_tile", Expr::u32(0)),
Node::let_bind("i_lane", Expr::u32(0)),
Node::store("out", Expr::var("i"), Expr::u32(1)),
],
)]));
assert!(result.changed);
let entry = crate::test_util::region_body(&result.program);
let Node::Loop { var, body, .. } = &entry[0] else {
panic!("expected outer loop");
};
assert_eq!(var.as_str(), "i_tile_1");
let Node::Loop { var: lane, .. } = &body[0] else {
panic!("expected inner loop");
};
assert_eq!(lane.as_str(), "i_lane_1");
}
// `analyze` reports SKIP for a program without loops and RUN for one with
// a 0..32 literal loop.
#[test]
fn analyze_skips_without_large_loop_and_runs_with_large_loop() {
assert_eq!(
LoopStripMine::analyze(&program(vec![Node::store(
"out",
Expr::u32(0),
Expr::u32(1)
)])),
PassAnalysis::SKIP
);
assert_eq!(
LoopStripMine::analyze(&program(vec![Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(32),
vec![Node::store("out", Expr::var("i"), Expr::u32(1))],
)])),
PassAnalysis::RUN
);
}
}