use rustc_hash::FxHashSet;
use crate::execution_plan::fusion::fuse_programs_vec;
use crate::ir::Program;
/// Fuses a batch of programs into a single program.
///
/// # Returns
///
/// * `Some(Program::empty())` when `programs` is empty.
/// * `Some(p)` with the sole program, untouched, when `programs` has exactly
///   one element.
/// * `None` when the programs do not all share the `workgroup_size` of the
///   first program.
/// * Otherwise the result of [`fuse_programs_vec`], with any error collapsed
///   to `None` (the error detail is discarded).
#[must_use]
pub fn fuse_cse(mut programs: Vec<Program>) -> Option<Program> {
    // Trivial batches never reach the fusion pass.
    match programs.len() {
        0 => return Some(Program::empty()),
        1 => return programs.pop(),
        _ => {}
    }

    // Every program is compared against the first program's workgroup size.
    let reference = programs[0].workgroup_size;
    let uniform = programs
        .iter()
        .all(|program| program.workgroup_size == reference);

    if uniform {
        fuse_programs_vec(programs).ok()
    } else {
        None
    }
}
/// Reports how many buffer declarations disappeared between `before` and
/// `after`.
///
/// The result is the sum of `buffers.len()` over every program in `before`,
/// minus `after.buffers.len()`. The subtraction saturates, so the result is
/// `0` whenever `after` declares at least as many buffers as `before` did in
/// total.
#[must_use]
pub fn cse_savings(before: &[Program], after: &Program) -> usize {
    let declared_before = before
        .iter()
        .fold(0usize, |running, program| running + program.buffers.len());
    declared_before.saturating_sub(after.buffers.len())
}
/// Counts the distinct buffer names declared in `after.buffers`.
///
/// Buffers that share a name are counted once.
#[must_use]
pub fn unique_buffer_count(after: &Program) -> usize {
    let mut names: FxHashSet<&str> = FxHashSet::default();
    // `insert` returns `true` only the first time a name is seen, so the
    // number of successful inserts equals the number of distinct names.
    after
        .buffers
        .iter()
        .filter(|decl| names.insert(decl.name()))
        .count()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// Builds a rule program declaring the shared `input` / `output_slots`
    /// buffers, plus an optional extra read-only buffer at binding 2, with a
    /// single `let` binding named `"{name}_val"`.
    fn rule(name: &str, extra_buf: Option<&str>) -> Program {
        let mut decls = vec![
            BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32).with_count(32),
            BufferDecl::storage("output_slots", 1, BufferAccess::ReadWrite, DataType::U32)
                .with_count(32),
        ];
        // `Option` yields zero or one item, so this appends the extra buffer
        // only when one was requested.
        decls.extend(extra_buf.map(|private| {
            BufferDecl::storage(private, 2, BufferAccess::ReadOnly, DataType::U32).with_count(32)
        }));
        let binding = Node::let_bind(format!("{name}_val"), Expr::InvocationId { axis: 0 });
        Program::wrapped(decls, [64, 1, 1], vec![binding])
    }

    #[test]
    fn empty_returns_empty_program() {
        let outcome = fuse_cse(Vec::new());
        let fused =
            outcome.expect("Fix: empty input is ok; restore this invariant before continuing.");
        assert_eq!(fused.buffers.len(), 0);
    }

    #[test]
    fn single_program_is_returned_unchanged() {
        let only = rule("r1", None);
        let expected_len = only.buffers.len();
        let fused = fuse_cse(vec![only]).unwrap();
        assert_eq!(fused.buffers.len(), expected_len);
    }

    #[test]
    fn two_rules_with_shared_input_collapse_buffers() {
        let first = rule("r1", None);
        let second = rule("r2", None);
        let before = [first.clone(), second.clone()];
        let fused = fuse_cse(vec![first, second]).unwrap();
        assert_eq!(fused.buffers.len(), 2);
        assert_eq!(cse_savings(&before, &fused), 2);
        assert_eq!(unique_buffer_count(&fused), 2);
    }

    #[test]
    fn unique_buffers_are_preserved() {
        let first = rule("r1", Some("r1_private"));
        let second = rule("r2", Some("r2_private"));
        let before = [first.clone(), second.clone()];
        let fused = fuse_cse(vec![first, second]).unwrap();
        assert_eq!(fused.buffers.len(), 4);
        assert_eq!(cse_savings(&before, &fused), 2);
    }

    #[test]
    fn conflicting_workgroup_sizes_are_rejected() {
        // Both programs are identical apart from their workgroup size.
        let buffers = || {
            vec![BufferDecl::storage("a", 0, BufferAccess::ReadOnly, DataType::U32).with_count(1)]
        };
        let body = || vec![Node::let_bind("x", Expr::u32(0))];
        let narrow = Program::wrapped(buffers(), [64, 1, 1], body());
        let wide = Program::wrapped(buffers(), [128, 1, 1], body());
        assert!(fuse_cse(vec![narrow, wide]).is_none());
    }

    #[test]
    fn savings_are_zero_when_no_shared_buffers() {
        // Each program declares exactly one buffer, under its own name.
        let isolated = |buffer: &str| {
            Program::wrapped(
                vec![
                    BufferDecl::storage(buffer, 0, BufferAccess::ReadOnly, DataType::U32)
                        .with_count(1),
                ],
                [64, 1, 1],
                vec![Node::let_bind("x", Expr::u32(0))],
            )
        };
        let first = isolated("r1_a");
        let second = isolated("r2_a");
        let before = [first.clone(), second.clone()];
        let fused = fuse_cse(vec![first, second]).unwrap();
        assert_eq!(cse_savings(&before, &fused), 0);
    }

    #[test]
    fn savings_scale_with_rule_count() {
        let rules: Vec<Program> = std::iter::repeat_with(|| rule("r", None))
            .take(5)
            .collect();
        let before = rules.clone();
        let fused = fuse_cse(rules).unwrap();
        assert_eq!(cse_savings(&before, &fused), 8);
    }
}