use crate::ir::{Ident, Node, Program};
use crate::optimizer::program_soa::ProgramFacts;
use rustc_hash::{FxHashMap, FxHashSet};
/// Records, for each region generator, which non-escaping buffers are referenced inside
/// regions produced by that generator, together with the full set of non-escaping buffers.
///
/// Constructed with `ScratchReusePlan::build`; `Default` yields an empty plan.
#[derive(Default, Debug)]
pub struct ScratchReusePlan {
    /// Region generator name -> non-escaping buffers referenced beneath that region.
    arm_recyclable: FxHashMap<Ident, FxHashSet<Ident>>,
    /// Declared buffers that `ProgramFacts::escaping_buffers` did not report.
    non_escaping: FxHashSet<Ident>,
}
impl ScratchReusePlan {
    /// Builds the plan for `program`.
    ///
    /// A declared buffer is classified as non-escaping when it is absent from
    /// `ProgramFacts::escaping_buffers`. Then, for every region reported by
    /// `ProgramFacts::regions`, the non-escaping buffers referenced by any node beneath the
    /// region's node are recorded under the region's generator name. Regions sharing a
    /// generator have their sets merged, and generators with no such buffers get no entry.
    #[must_use]
    pub fn build(program: &Program) -> Self {
        let facts = ProgramFacts::build(program);
        let escaping = facts.escaping_buffers();
        let non_escaping: FxHashSet<Ident> = program
            .buffers()
            .iter()
            .map(|b| Ident::from(b.name.as_ref()))
            .filter(|name| !escaping.contains_key(name))
            .collect();
        let mut arm_recyclable: FxHashMap<Ident, FxHashSet<Ident>> = FxHashMap::default();
        for region in facts.regions() {
            // Keep only the buffers classified as non-escaping above.
            let recyclable: FxHashSet<Ident> =
                collect_buffer_uses(program.entry(), &region.node, &facts)
                    .into_iter()
                    .filter(|b| non_escaping.contains(b))
                    .collect();
            // Skip empty sets so `arm_recyclable` only holds generators with real candidates.
            if !recyclable.is_empty() {
                arm_recyclable
                    .entry(region.generator.clone())
                    .or_default()
                    .extend(recyclable);
            }
        }
        Self {
            arm_recyclable,
            non_escaping,
        }
    }

    /// Returns the recyclable buffers recorded for `arm_generator`, or an empty set when
    /// the generator has no entry.
    #[must_use]
    pub fn recyclable_for(&self, arm_generator: &str) -> &FxHashSet<Ident> {
        self.arm_recyclable.get(arm_generator).unwrap_or(&EMPTY_SET)
    }

    /// Returns every declared buffer that was not reported as escaping.
    #[must_use]
    pub fn non_escaping(&self) -> &FxHashSet<Ident> {
        &self.non_escaping
    }

    /// Returns `true` when `name` is one of the program's non-escaping buffers.
    ///
    /// Uses a hashed lookup via `Ident: Borrow<str>` (the same bound `recyclable_for`
    /// relies on) instead of scanning the whole set.
    #[must_use]
    pub fn is_recyclable(&self, name: &str) -> bool {
        self.non_escaping.contains(name)
    }

    /// Total number of (generator, buffer) pairs across all generators.
    #[must_use]
    pub fn pair_count(&self) -> usize {
        self.arm_recyclable.values().map(|s| s.len()).sum()
    }
}
static EMPTY_SET: std::sync::LazyLock<FxHashSet<Ident>> =
std::sync::LazyLock::new(FxHashSet::default);
/// Collects the names of all buffers referenced by nodes that `facts` reports as
/// descendants of `region_node`.
///
/// `_entry` is accepted but not read; every lookup goes through `facts`.
fn collect_buffer_uses(
    _entry: &[Node],
    region_node: &crate::optimizer::program_soa::NodeIndex,
    facts: &ProgramFacts,
) -> FxHashSet<Ident> {
    facts
        .buffer_refs()
        .into_iter()
        .filter_map(|(node, name, _)| {
            facts
                .is_descendant_of(*node, *region_node)
                .then(|| name.clone())
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// Four-element `u32` storage buffer with the given access mode.
    fn storage(name: &str, binding: u32, access: BufferAccess) -> BufferDecl {
        BufferDecl::storage(name, binding, access, DataType::U32).with_count(4)
    }

    fn buf_rw(name: &str, binding: u32) -> BufferDecl {
        storage(name, binding, BufferAccess::ReadWrite)
    }

    fn buf_ro(name: &str, binding: u32) -> BufferDecl {
        storage(name, binding, BufferAccess::ReadOnly)
    }

    fn buf_output(name: &str, binding: u32) -> BufferDecl {
        BufferDecl::output(name, binding, DataType::U32).with_count(4)
    }

    /// Wraps `body` in a region tagged with `generator`.
    fn region(generator: &str, body: Vec<Node>) -> Node {
        Node::Region {
            generator: Ident::from(generator),
            source_region: None,
            body: std::sync::Arc::new(body),
        }
    }

    /// Binds `var` to element 0 of `buffer`.
    fn load_first(var: &'static str, buffer: &str) -> Node {
        Node::let_bind(
            var,
            Expr::Load {
                buffer: Ident::from(buffer),
                index: Box::new(Expr::u32(0)),
            },
        )
    }

    /// Whether `set` holds an identifier spelled `name`.
    fn mentions(set: &FxHashSet<Ident>, name: &str) -> bool {
        set.iter().any(|n| n.as_str() == name)
    }

    #[test]
    fn read_only_buffer_appears_as_non_escaping() {
        let program = Program::wrapped(
            vec![buf_ro("input", 0)],
            [1, 1, 1],
            vec![region("arm_a", vec![load_first("x", "input")])],
        );
        assert!(ScratchReusePlan::build(&program).is_recyclable("input"));
    }

    #[test]
    fn output_buffer_does_not_appear_as_recyclable() {
        let program = Program::wrapped(
            vec![buf_output("out", 0)],
            [1, 1, 1],
            vec![region(
                "arm_a",
                vec![Node::store("out", Expr::u32(0), Expr::u32(7))],
            )],
        );
        assert!(!ScratchReusePlan::build(&program).is_recyclable("out"));
    }

    #[test]
    fn pair_count_reports_total_pairs() {
        let program = Program::wrapped(
            vec![buf_ro("scratch_a", 0), buf_ro("scratch_b", 1)],
            [1, 1, 1],
            vec![
                region("arm_a", vec![load_first("x", "scratch_a")]),
                region("arm_b", vec![load_first("y", "scratch_b")]),
            ],
        );
        let plan = ScratchReusePlan::build(&program);
        // Each arm must see its own scratch buffer and not the other arm's.
        for (arm, own, other) in [
            ("arm_a", "scratch_a", "scratch_b"),
            ("arm_b", "scratch_b", "scratch_a"),
        ] {
            let recyclable = plan.recyclable_for(arm);
            assert!(mentions(recyclable, own));
            assert!(!mentions(recyclable, other));
        }
        assert!(plan.pair_count() >= 2);
    }

    #[test]
    fn non_escaping_excludes_store_targets() {
        let program = Program::wrapped(
            vec![buf_ro("input", 0), buf_rw("rw", 1)],
            [1, 1, 1],
            vec![region(
                "arm",
                vec![
                    Node::store("rw", Expr::u32(0), Expr::u32(7)),
                    load_first("x", "input"),
                ],
            )],
        );
        let plan = ScratchReusePlan::build(&program);
        let kept = plan.non_escaping();
        assert!(mentions(kept, "input"));
        assert!(!mentions(kept, "rw"));
    }
}