use rustc_hash::{FxHashMap, FxHashSet};
use crate::ir::{BufferAccess, BufferDecl, Ident, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Optimizer pass that promotes eligible read-write buffers to workgroup
/// declarations (see [`run`] for the exact eligibility test: read-write
/// access, a nonzero count, and not marked as a pipeline live-out).
///
/// Registered under the name `decode_scan_fuse` with no required passes and
/// `invalidates = ["buffer_layout", "fusion"]`. The name suggests the intended
/// target is a decode-stage → scan-stage handoff buffer; NOTE(review): confirm
/// that intent against the pipeline that schedules this pass.
#[derive(Debug, Default)]
#[vyre_pass(
name = "decode_scan_fuse",
requires = [],
invalidates = ["buffer_layout", "fusion"]
)]
pub struct DecodeScanFuse;
impl DecodeScanFuse {
    /// Cheap pre-check deciding whether the pass needs to run at all.
    ///
    /// Yields `PassAnalysis::SKIP` when [`count_opportunities`] reports zero
    /// promotable buffers, and `PassAnalysis::RUN` otherwise.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        match count_opportunities(program) {
            0 => PassAnalysis::SKIP,
            _ => PassAnalysis::RUN,
        }
    }

    /// Apply [`run`] to `program` and report whether anything changed.
    ///
    /// Change detection compares `fingerprint_program` of the input against
    /// that of the rewritten program instead of diffing them structurally.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        let fingerprint_before = fingerprint_program(&program);
        let rewritten = run(program);
        // Fingerprint the result before it is moved into the PassResult.
        let changed = fingerprint_program(&rewritten) != fingerprint_before;
        PassResult {
            changed,
            program: rewritten,
        }
    }

    /// Fingerprint of `program`; a direct delegation to `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
#[must_use]
pub fn run(program: Program) -> Program {
let promotable: FxHashSet<Ident> = program
.buffers
.iter()
.filter(|b| {
b.access() == BufferAccess::ReadWrite && b.count() > 0 && !b.is_pipeline_live_out()
})
.map(|b| Ident::from(b.name()))
.collect();
if promotable.is_empty() {
return program;
}
let new_buffers: Vec<BufferDecl> = program
.buffers
.iter()
.map(|b| {
if promotable.contains(&Ident::from(b.name())) {
BufferDecl::workgroup(b.name(), b.count(), b.element())
} else {
b.clone()
}
})
.collect();
let entry = std::sync::Arc::try_unwrap(program.entry).unwrap_or_else(|arc| (*arc).clone());
Program::wrapped(new_buffers, program.workgroup_size, entry)
}
/// Count the buffers in `program` that [`run`] would promote.
///
/// A buffer counts when it has `BufferAccess::ReadWrite` access, a nonzero
/// `count()`, and is not a pipeline live-out. Duplicate names are counted
/// once per declaration.
#[must_use]
pub fn count_opportunities(program: &Program) -> usize {
    program.buffers.iter().fold(0usize, |tally, decl| {
        let qualifies = decl.access() == BufferAccess::ReadWrite
            && decl.count() > 0
            && !decl.is_pipeline_live_out();
        tally + usize::from(qualifies)
    })
}
/// Map each promotable buffer's name to its declared `count()`.
///
/// Uses the same eligibility test as [`run`] and [`count_opportunities`]:
/// read-write access, a nonzero count, and not a pipeline live-out. When two
/// eligible declarations share a name, the later one's count is kept.
#[must_use]
pub fn candidate_handoffs(program: &Program) -> FxHashMap<Ident, u32> {
    program
        .buffers
        .iter()
        .filter(|decl| {
            decl.access() == BufferAccess::ReadWrite
                && decl.count() > 0
                && !decl.is_pipeline_live_out()
        })
        .map(|decl| (Ident::from(decl.name()), decl.count()))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferDecl, DataType, Program};

    /// Two-buffer fixture: a read-only `input` and a read-write `decoded`
    /// intermediate, which is the only promotable buffer.
    fn decoder_like() -> Program {
        Program::wrapped(
            vec![
                BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
                    .with_count(64),
                BufferDecl::storage("decoded", 1, BufferAccess::ReadWrite, DataType::U32)
                    .with_count(128),
            ],
            [64, 1, 1],
            vec![],
        )
    }

    /// Single read-only buffer with a nonzero count: nothing is promotable.
    fn read_only_program() -> Program {
        Program::wrapped(
            vec![
                BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
                    .with_count(1),
            ],
            [64, 1, 1],
            vec![],
        )
    }

    /// Look up a buffer declaration by name, panicking with the name if absent.
    fn find_buffer<'a>(program: &'a Program, name: &str) -> &'a BufferDecl {
        program
            .buffers
            .iter()
            .find(|b| b.name() == name)
            .unwrap_or_else(|| panic!("buffer `{}` missing from program", name))
    }

    #[test]
    fn run_promotes_readwrite_handoff_to_workgroup() {
        let p = decoder_like();
        let before_bufs = p.buffers.len();
        let after = run(p);
        assert_eq!(after.buffers.len(), before_bufs);
        assert_eq!(
            find_buffer(&after, "decoded").access(),
            BufferAccess::Workgroup
        );
    }

    #[test]
    fn run_leaves_read_only_buffers_alone() {
        let after = run(decoder_like());
        assert_eq!(find_buffer(&after, "input").access(), BufferAccess::ReadOnly);
    }

    #[test]
    fn run_preserves_pipeline_live_out_buffer() {
        let p = Program::wrapped(
            vec![
                BufferDecl::storage("result", 0, BufferAccess::ReadWrite, DataType::U32)
                    .with_count(16)
                    .with_pipeline_live_out(true),
            ],
            [64, 1, 1],
            vec![],
        );
        let after = run(p);
        let r = find_buffer(&after, "result");
        assert_eq!(r.access(), BufferAccess::ReadWrite);
        assert!(r.is_pipeline_live_out());
    }

    #[test]
    fn run_is_identity_when_no_candidates() {
        let after = run(read_only_program());
        assert_eq!(after.buffers.len(), 1);
        assert_eq!(after.buffers[0].access(), BufferAccess::ReadOnly);
    }

    #[test]
    fn run_skips_runtime_sized_buffers() {
        let p = Program::wrapped(
            vec![BufferDecl::storage(
                "dynamic",
                0,
                BufferAccess::ReadWrite,
                DataType::U32,
            )],
            [64, 1, 1],
            vec![],
        );
        let after = run(p);
        assert_eq!(
            find_buffer(&after, "dynamic").access(),
            BufferAccess::ReadWrite
        );
    }

    #[test]
    fn count_opportunities_finds_one_candidate() {
        assert_eq!(count_opportunities(&decoder_like()), 1);
    }

    #[test]
    fn count_opportunities_zero_on_read_only_program() {
        assert_eq!(count_opportunities(&read_only_program()), 0);
    }

    #[test]
    fn candidate_handoffs_exposes_name_and_count() {
        let cands = candidate_handoffs(&decoder_like());
        assert_eq!(cands.get(&Ident::from("decoded")).copied(), Some(128));
        assert!(!cands.contains_key(&Ident::from("input")));
    }

    #[test]
    fn multiple_candidates_all_surface() {
        let p = Program::wrapped(
            vec![
                BufferDecl::storage("a", 0, BufferAccess::ReadWrite, DataType::U32).with_count(32),
                BufferDecl::storage("b", 1, BufferAccess::ReadWrite, DataType::U32).with_count(64),
                BufferDecl::storage("c", 2, BufferAccess::ReadOnly, DataType::U32).with_count(16),
            ],
            [64, 1, 1],
            vec![],
        );
        let cands = candidate_handoffs(&p);
        assert_eq!(cands.len(), 2);
        assert_eq!(cands.get(&Ident::from("a")).copied(), Some(32));
        assert_eq!(cands.get(&Ident::from("b")).copied(), Some(64));
    }

    // The pass entry point must flag a real rewrite as a change.
    #[test]
    fn transform_reports_change_when_promoting() {
        let result = DecodeScanFuse::transform(decoder_like());
        assert!(result.changed);
        assert_eq!(
            find_buffer(&result.program, "decoded").access(),
            BufferAccess::Workgroup
        );
    }

    // With no candidates `run` returns the program untouched, so the
    // fingerprints match and `changed` must be false.
    #[test]
    fn transform_reports_no_change_without_candidates() {
        let result = DecodeScanFuse::transform(read_only_program());
        assert!(!result.changed);
        assert_eq!(result.program.buffers.len(), 1);
    }
}