use crate::ir::{BinOp, Expr, Ident, Node, Program};
use crate::optimizer::program_soa::ProgramFacts;
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Optimizer pass that software-pipelines simple load-then-store loops.
///
/// A loop of the form `let x = load(in, i); store(out, i, f(x))` with literal
/// bounds is rewritten into a prologue load, a steady-state loop that loads
/// element `i + 1` while storing the result for element `i`, and an epilogue
/// store for the final element.
///
/// The pass is declared under the name `loop_software_pipeline`, lists
/// `const_fold` under `requires`, and lists `loop_unroll` and
/// `loop_strip_mine` under `invalidates`.
#[derive(Debug, Default)]
#[vyre_pass(
name = "loop_software_pipeline",
requires = ["const_fold"],
invalidates = ["loop_unroll", "loop_strip_mine"]
)]
pub struct LoopSoftwarePipeline;
impl LoopSoftwarePipeline {
    /// Reports whether `program` contains at least one loop this pass can
    /// pipeline.
    ///
    /// Returns [`PassAnalysis::RUN`] when any node reachable from the entry
    /// holds a pipelinable loop, and [`PassAnalysis::SKIP`] otherwise.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let facts = ProgramFacts::build(program);
        let has_candidate = program
            .entry()
            .iter()
            .any(|node| node_has_pipelinable_loop(node, &facts));
        if has_candidate {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Rewrites every pipelinable loop in `program`.
    ///
    /// The returned [`PassResult`] carries the rewritten program and a
    /// `changed` flag that is `true` when at least one loop was pipelined.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        let facts = ProgramFacts::build(&program);
        // Keep an entry-less copy of the program so the rewritten entry can be
        // attached to it after the original has been consumed below.
        let scaffold = program.with_rewritten_entry(Vec::new());

        let mut changed = false;
        let mut rewritten: Vec<Node> = Vec::new();
        for node in program.into_entry_vec() {
            rewritten.extend(rewrite_node(node, &facts, &mut changed));
        }

        PassResult {
            program: scaffold.with_rewritten_entry(rewritten),
            changed,
        }
    }

    /// Returns the fingerprint of `program` as computed by
    /// [`fingerprint_program`].
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `node` and everything nested inside it, pipelining each loop that
/// `analyse_pipelinable` accepts.
///
/// Returns the replacement nodes: three nodes (prologue, steady-state loop,
/// epilogue) for a pipelined loop, otherwise a single node of the same kind
/// with its children rewritten. `changed` is set to `true` whenever a loop is
/// pipelined and is never reset.
fn rewrite_node(node: Node, facts: &ProgramFacts, changed: &mut bool) -> Vec<Node> {
    /// Rewrites a list of sibling nodes in order, splicing in the replacements.
    fn rewrite_children(
        children: Vec<Node>,
        facts: &ProgramFacts,
        changed: &mut bool,
    ) -> Vec<Node> {
        let mut rewritten = Vec::with_capacity(children.len());
        for child in children {
            rewritten.extend(rewrite_node(child, facts, changed));
        }
        rewritten
    }

    match node {
        Node::Loop {
            var,
            from,
            to,
            body,
        } => {
            let plan = analyse_pipelinable(&var, &from, &to, &body, facts);
            match plan {
                Some(plan) => {
                    *changed = true;
                    apply_pipeline(plan)
                }
                // Not pipelinable itself; nested loops still may be.
                None => vec![Node::Loop {
                    var,
                    from,
                    to,
                    body: rewrite_children(body, facts, changed),
                }],
            }
        }
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            let then = rewrite_children(then, facts, changed);
            let otherwise = rewrite_children(otherwise, facts, changed);
            vec![Node::If {
                cond,
                then,
                otherwise,
            }]
        }
        Node::Block(body) => vec![Node::Block(rewrite_children(body, facts, changed))],
        Node::Region {
            generator,
            source_region,
            body,
        } => {
            // Take the body out of the `Arc` without copying when this is the
            // only handle; clone it when it is still shared.
            let owned_body: Vec<Node> =
                std::sync::Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
            vec![Node::Region {
                generator,
                source_region,
                body: std::sync::Arc::new(rewrite_children(owned_body, facts, changed)),
            }]
        }
        leaf => vec![leaf],
    }
}
/// Everything `apply_pipeline` needs to rewrite one matched loop, as captured
/// by `analyse_pipelinable`.
struct PipelinePlan {
/// Induction variable of the matched loop.
loop_var: Ident,
/// Literal value of the loop's `from` bound.
lo: u32,
/// Literal value of the loop's `to` bound.
hi: u32,
/// Generated name `__sp_<let_name>_pipe`; holds the value loaded for the
/// element currently being stored.
pipe_name: Ident,
/// Generated name `__sp_<let_name>_next`; holds the value loaded for
/// element `loop_var + 1`.
next_name: Ident,
/// Name bound by the loop body's leading `Let`.
let_name: Ident,
/// Buffer read by the loop body's `Load`.
buf_in: Ident,
/// Buffer written by the loop body's `Store`.
buf_out: Ident,
/// Value expression of the original `Store`, still written in terms of
/// `let_name`.
store_value_template: Expr,
}
fn analyse_pipelinable(
var: &Ident,
from: &Expr,
to: &Expr,
body: &[Node],
facts: &ProgramFacts,
) -> Option<PipelinePlan> {
let (lo, hi) = match (from, to) {
(Expr::LitU32(lo), Expr::LitU32(hi)) if *hi >= lo + 2 => (*lo, *hi),
_ => return None,
};
if body.len() != 2 {
return None;
}
let (let_name, buf_in) = match &body[0] {
Node::Let { name, value } => match value {
Expr::Load { buffer, index } => match index.as_ref() {
Expr::Var(idx_var) if idx_var == var => (name.clone(), buffer.clone()),
_ => return None,
},
_ => return None,
},
_ => return None,
};
let (buf_out, store_index, store_value) = match &body[1] {
Node::Store {
buffer,
index,
value,
} => match index {
Expr::Var(idx_var) if idx_var == var => (buffer.clone(), index.clone(), value.clone()),
_ => return None,
},
_ => return None,
};
let _ = store_index;
if buf_in == buf_out {
return None;
}
if !facts.buffers_provably_distinct(buf_in.as_str(), buf_out.as_str()) {
return None;
}
if !expr_reads_only(&store_value, &let_name) {
return None;
}
let pipe_name = Ident::from(format!("__sp_{}_pipe", let_name.as_str()));
let next_name = Ident::from(format!("__sp_{}_next", let_name.as_str()));
Some(PipelinePlan {
loop_var: var.clone(),
lo,
hi,
pipe_name,
next_name,
let_name,
buf_in,
buf_out,
store_value_template: store_value,
})
}
fn apply_pipeline(plan: PipelinePlan) -> Vec<Node> {
let prologue = Node::let_bind(
plan.pipe_name.clone(),
Expr::Load {
buffer: plan.buf_in.clone(),
index: Box::new(Expr::u32(plan.lo)),
},
);
let prefetch = Node::let_bind(
plan.next_name.clone(),
Expr::Load {
buffer: plan.buf_in.clone(),
index: Box::new(Expr::BinOp {
op: BinOp::Add,
left: Box::new(Expr::Var(plan.loop_var.clone())),
right: Box::new(Expr::u32(1)),
}),
},
);
let pipe_value = substitute_var(
plan.store_value_template.clone(),
&plan.let_name,
&plan.pipe_name,
);
let store_current = Node::Store {
buffer: plan.buf_out.clone(),
index: Expr::Var(plan.loop_var.clone()),
value: pipe_value.clone(),
};
let shuffle = Node::Assign {
name: plan.pipe_name.clone(),
value: Expr::Var(plan.next_name.clone()),
};
let steady = Node::Loop {
var: plan.loop_var.clone(),
from: Expr::u32(plan.lo),
to: Expr::u32(plan.hi - 1),
body: vec![prefetch, store_current, shuffle],
};
let epilogue = Node::Store {
buffer: plan.buf_out.clone(),
index: Expr::u32(plan.hi - 1),
value: pipe_value,
};
vec![prologue, steady, epilogue]
}
/// Returns `true` when `expr` is accepted in full by `expr_visit_check` and
/// mentions the variable `name` at least once.
fn expr_reads_only(expr: &Expr, name: &Ident) -> bool {
    let mut saw_name = false;
    if !expr_visit_check(expr, name, &mut saw_name) {
        return false;
    }
    saw_name
}
/// Walks `expr` and reports whether it is built solely from accepted forms:
/// variables, literals, `BufLen`, the invocation/workgroup/local id reads, and
/// `BinOp`/`UnOp`/`Select`/`Cast`/`Fma` over accepted operands.
///
/// Sets `*reads_name` to `true` when a variable equal to `name` is visited;
/// the flag is never cleared. Operands are visited left to right and the walk
/// stops at the first rejected operand, so the flag is only guaranteed to be
/// complete when the function returns `true`.
fn expr_visit_check(expr: &Expr, name: &Ident, reads_name: &mut bool) -> bool {
    match expr {
        // Rejected outright: loads, atomics, calls, opaque and subgroup forms.
        Expr::Load { .. }
        | Expr::Atomic { .. }
        | Expr::Call { .. }
        | Expr::Opaque(_)
        | Expr::SubgroupBallot { .. }
        | Expr::SubgroupShuffle { .. }
        | Expr::SubgroupAdd { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize => false,

        // Accepted leaves.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::BufLen { .. }
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. } => true,
        Expr::Var(found) => {
            *reads_name |= found == name;
            true
        }

        // Accepted when every operand is accepted.
        Expr::UnOp { operand, .. } => expr_visit_check(operand, name, reads_name),
        Expr::Cast { value, .. } => expr_visit_check(value, name, reads_name),
        Expr::BinOp { left, right, .. } => [left, right]
            .iter()
            .all(|part| expr_visit_check(part, name, reads_name)),
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => [cond, true_val, false_val]
            .iter()
            .all(|part| expr_visit_check(part, name, reads_name)),
        Expr::Fma { a, b, c } => [a, b, c]
            .iter()
            .all(|part| expr_visit_check(part, name, reads_name)),
    }
}
/// Returns `expr` with every variable reference equal to `from` replaced by a
/// reference to `to`.
///
/// The replacement recurses through `Load` indices, `BinOp`, `UnOp`, `Call`
/// arguments, `Select`, `Cast` and `Fma`; every other expression form is
/// returned unchanged.
fn substitute_var(expr: Expr, from: &Ident, to: &Ident) -> Expr {
    // Applies the substitution to a boxed child and re-boxes the result.
    let renamed = |child: Box<Expr>| Box::new(substitute_var(*child, from, to));

    match expr {
        Expr::Var(found) => {
            if &found == from {
                Expr::Var(to.clone())
            } else {
                Expr::Var(found)
            }
        }
        Expr::Load { buffer, index } => Expr::Load {
            buffer,
            index: renamed(index),
        },
        Expr::BinOp { op, left, right } => Expr::BinOp {
            op,
            left: renamed(left),
            right: renamed(right),
        },
        Expr::UnOp { op, operand } => Expr::UnOp {
            op,
            operand: renamed(operand),
        },
        Expr::Call { op_id, args } => Expr::Call {
            op_id,
            args: args
                .into_iter()
                .map(|arg| substitute_var(arg, from, to))
                .collect(),
        },
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => Expr::Select {
            cond: renamed(cond),
            true_val: renamed(true_val),
            false_val: renamed(false_val),
        },
        Expr::Cast { target, value } => Expr::Cast {
            target,
            value: renamed(value),
        },
        Expr::Fma { a, b, c } => Expr::Fma {
            a: renamed(a),
            b: renamed(b),
            c: renamed(c),
        },
        untouched => untouched,
    }
}
/// Returns `true` when `node` is a loop accepted by `analyse_pipelinable`, or
/// when such a loop is nested anywhere inside a `Loop`, `If`, `Block` or
/// `Region` reachable from `node`.
fn node_has_pipelinable_loop(node: &Node, facts: &ProgramFacts) -> bool {
    /// Checks a list of sibling nodes, stopping at the first match.
    fn any_in(nodes: &[Node], facts: &ProgramFacts) -> bool {
        nodes
            .iter()
            .any(|child| node_has_pipelinable_loop(child, facts))
    }

    match node {
        Node::Loop {
            var,
            from,
            to,
            body,
        } => analyse_pipelinable(var, from, to, body, facts).is_some() || any_in(body, facts),
        Node::If {
            then, otherwise, ..
        } => any_in(then, facts) || any_in(otherwise, facts),
        Node::Block(body) => any_in(body, facts),
        Node::Region { body, .. } => any_in(body, facts),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Ident, Node};

    /// Read-only `u32` storage buffer with 16 elements at binding 0.
    fn ro(name: &str) -> BufferDecl {
        BufferDecl::storage(name, 0, BufferAccess::ReadOnly, DataType::U32).with_count(16)
    }

    /// Read-write `u32` storage buffer with 16 elements at `binding`.
    fn rw(name: &str, binding: u32) -> BufferDecl {
        BufferDecl::storage(name, binding, BufferAccess::ReadWrite, DataType::U32).with_count(16)
    }

    /// Wraps `entry` in a program with a `[1, 1, 1]` workgroup size.
    fn program(buffers: Vec<BufferDecl>, entry: Vec<Node>) -> Program {
        Program::wrapped(buffers, [1, 1, 1], entry)
    }

    /// `Load` of `buffer` at `index`.
    fn load(buffer: &'static str, index: Expr) -> Expr {
        Expr::Load {
            buffer: Ident::from(buffer),
            index: Box::new(index),
        }
    }

    /// `left + right` as an IR addition.
    fn add(left: Expr, right: Expr) -> Expr {
        Expr::BinOp {
            op: BinOp::Add,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Loop over `i` whose body is `let x = <loaded>` followed by
    /// `rw[i] = <stored>`.
    fn let_then_store_loop(from: Expr, to: Expr, loaded: Expr, stored: Expr) -> Vec<Node> {
        vec![Node::Loop {
            var: Ident::from("i"),
            from,
            to,
            body: vec![
                Node::let_bind("x", loaded),
                Node::store("rw", Expr::var("i"), stored),
            ],
        }]
    }

    /// The canonical pipelinable loop: `let x = ro[i]; rw[i] = x + 1`.
    fn pipelinable_loop(lo: u32, hi: u32) -> Vec<Node> {
        let_then_store_loop(
            Expr::u32(lo),
            Expr::u32(hi),
            load("ro", Expr::var("i")),
            add(Expr::var("x"), Expr::u32(1)),
        )
    }

    /// Counts `(loops, stores)` in `nodes`, descending into loop, block,
    /// region and both `If` branches.
    fn count_loops_and_stores(nodes: &[Node]) -> (usize, usize) {
        nodes.iter().fold((0, 0), |(loops, stores), node| {
            let (more_loops, more_stores) = match node {
                Node::Loop { body, .. } => {
                    let (inner_loops, inner_stores) = count_loops_and_stores(body);
                    (inner_loops + 1, inner_stores)
                }
                Node::Store { .. } => (0, 1),
                Node::Block(body) => count_loops_and_stores(body),
                Node::Region { body, .. } => count_loops_and_stores(body.as_ref()),
                Node::If {
                    then, otherwise, ..
                } => {
                    let (then_loops, then_stores) = count_loops_and_stores(then);
                    let (else_loops, else_stores) = count_loops_and_stores(otherwise);
                    (then_loops + else_loops, then_stores + else_stores)
                }
                _ => (0, 0),
            };
            (loops + more_loops, stores + more_stores)
        })
    }

    #[test]
    fn pipelines_simple_load_store_loop() {
        let prog = program(vec![ro("ro"), rw("rw", 1)], pipelinable_loop(0, 8));
        let result = LoopSoftwarePipeline::transform(prog);
        assert!(result.changed, "Load-then-Store loop must pipeline");
        let (loops, stores) = count_loops_and_stores(result.program.entry());
        assert_eq!(loops, 1, "exactly one steady-state Loop after pipelining");
        assert_eq!(stores, 2);
    }

    #[test]
    fn keeps_loop_with_trip_count_one() {
        let prog = program(vec![ro("ro"), rw("rw", 1)], pipelinable_loop(0, 1));
        let result = LoopSoftwarePipeline::transform(prog);
        assert!(!result.changed);
    }

    #[test]
    fn keeps_loop_when_buffers_alias() {
        // Load and store both target `rw`.
        let entry = let_then_store_loop(
            Expr::u32(0),
            Expr::u32(8),
            load("rw", Expr::var("i")),
            add(Expr::var("x"), Expr::u32(1)),
        );
        let prog = program(vec![rw("rw", 1)], entry);
        let result = LoopSoftwarePipeline::transform(prog);
        assert!(
            !result.changed,
            "self-aliasing Load+Store must not pipeline"
        );
    }

    #[test]
    fn keeps_loop_with_non_var_index() {
        // The load index is `i + 1` rather than the bare loop variable.
        let entry = let_then_store_loop(
            Expr::u32(0),
            Expr::u32(8),
            load("ro", add(Expr::var("i"), Expr::u32(1))),
            Expr::var("x"),
        );
        let prog = program(vec![ro("ro"), rw("rw", 1)], entry);
        let result = LoopSoftwarePipeline::transform(prog);
        assert!(!result.changed);
    }

    #[test]
    fn keeps_loop_with_three_body_stmts() {
        let entry = vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(8),
            body: vec![
                Node::let_bind("x", load("ro", Expr::var("i"))),
                Node::let_bind("y", Expr::u32(1)),
                Node::store("rw", Expr::var("i"), Expr::var("x")),
            ],
        }];
        let prog = program(vec![ro("ro"), rw("rw", 1)], entry);
        let result = LoopSoftwarePipeline::transform(prog);
        assert!(!result.changed);
    }

    #[test]
    fn keeps_loop_when_store_does_not_use_load() {
        let entry = let_then_store_loop(
            Expr::u32(0),
            Expr::u32(8),
            load("ro", Expr::var("i")),
            Expr::u32(99),
        );
        let prog = program(vec![ro("ro"), rw("rw", 1)], entry);
        let result = LoopSoftwarePipeline::transform(prog);
        assert!(!result.changed);
    }

    #[test]
    fn keeps_loop_when_store_value_has_other_load() {
        // The stored value is `x + ro[0]`, i.e. it contains a second load.
        let entry = let_then_store_loop(
            Expr::u32(0),
            Expr::u32(8),
            load("ro", Expr::var("i")),
            add(Expr::var("x"), load("ro", Expr::u32(0))),
        );
        let prog = program(vec![ro("ro"), rw("rw", 1)], entry);
        let result = LoopSoftwarePipeline::transform(prog);
        assert!(!result.changed);
    }

    #[test]
    fn keeps_loop_with_runtime_bounds() {
        // The upper bound is the variable `n`, not a literal.
        let entry = let_then_store_loop(
            Expr::u32(0),
            Expr::var("n"),
            load("ro", Expr::var("i")),
            Expr::var("x"),
        );
        let prog = program(vec![ro("ro"), rw("rw", 1)], entry);
        let result = LoopSoftwarePipeline::transform(prog);
        assert!(!result.changed);
    }

    #[test]
    fn analyze_skips_program_without_pipelinable_loop() {
        let entry = vec![Node::store("rw", Expr::u32(0), Expr::u32(1))];
        let prog = program(vec![rw("rw", 1)], entry);
        match LoopSoftwarePipeline::analyze(&prog) {
            PassAnalysis::SKIP => {}
            other => panic!("expected SKIP, got {other:?}"),
        }
    }
}