use rustc_hash::{FxHashMap, FxHashSet};
use crate::ir::{BufferAccess, BufferDecl, DataType, Ident, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Inclusive upper bound, in bytes (16 KiB), on the total size of a buffer
/// that `fits_workgroup_budget` accepts.
const MAX_WORKGROUP_PROMOTION_BYTES: u64 = 16 * 1024;
/// Returns the size of one `element` in bytes, widened to `u64`.
///
/// Yields `None` whenever `DataType::size_bytes` reports no size.
fn element_bytes(element: &DataType) -> Option<u64> {
    let size = element.size_bytes()?;
    Some(size as u64)
}
/// Reports whether the total byte size of `buf` lies in
/// `1..=MAX_WORKGROUP_PROMOTION_BYTES`.
///
/// The total is `count * element size`. The answer is `false` when the element
/// has no known size, when the multiplication overflows `u64`, or when the
/// total is zero.
fn fits_workgroup_budget(buf: &BufferDecl) -> bool {
    // The count is only read once the element size is known; overflow
    // collapses to `None` rather than wrapping.
    let total = element_bytes(&buf.element())
        .and_then(|per_element| u64::from(buf.count()).checked_mul(per_element));
    match total {
        Some(bytes) => (1..=MAX_WORKGROUP_PROMOTION_BYTES).contains(&bytes),
        None => false,
    }
}
/// Optimizer pass declared through `vyre_pass` under the name
/// `decode_scan_fuse`.
///
/// The attribute lists no `requires` entries and names `buffer_layout` and
/// `fusion` under `invalidates`. The rewrite itself is implemented by the
/// free function `run`.
#[derive(Debug, Default)]
#[vyre_pass(
    name = "decode_scan_fuse",
    requires = [],
    invalidates = ["buffer_layout", "fusion"]
)]
pub struct DecodeScanFuse;
impl DecodeScanFuse {
    /// Decides whether the pass has anything to do on `program`.
    ///
    /// Returns `PassAnalysis::SKIP` when `count_opportunities` finds no
    /// candidate buffer, and `PassAnalysis::RUN` otherwise.
    #[must_use]
    fn analyze_impl(program: &Program) -> PassAnalysis {
        match count_opportunities(program) {
            0 => PassAnalysis::SKIP,
            _ => PassAnalysis::RUN,
        }
    }

    /// Applies `run` to `program` and reports whether the result differs.
    ///
    /// `changed` is computed by comparing `fingerprint_program` of the input
    /// against `fingerprint_program` of the output.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        let fingerprint_before = fingerprint_program(&program);
        let program = run(program);
        let changed = fingerprint_program(&program) != fingerprint_before;
        PassResult { program, changed }
    }
}
#[must_use]
pub fn run(program: Program) -> Program {
let promotable: FxHashSet<Ident> = program
.buffers
.iter()
.filter(|b| {
b.access() == BufferAccess::ReadWrite
&& b.count() > 0
&& !b.is_pipeline_live_out()
&& fits_workgroup_budget(b)
})
.map(|b| Ident::from(b.name()))
.collect();
if promotable.is_empty() {
return program;
}
let new_buffers: Vec<BufferDecl> = program
.buffers
.iter()
.map(|b| {
if promotable.contains(&Ident::from(b.name())) {
BufferDecl::workgroup(b.name(), b.count(), b.element())
} else {
b.clone()
}
})
.collect();
let entry = std::sync::Arc::try_unwrap(program.entry).unwrap_or_else(|arc| (*arc).clone());
Program::wrapped(new_buffers, program.workgroup_size, entry)
}
/// Counts the buffers of `program` that `run` would treat as eligible.
///
/// A buffer is counted when its access is `BufferAccess::ReadWrite`, its count
/// is non-zero, `is_pipeline_live_out()` is false, and it passes
/// `fits_workgroup_budget`.
#[must_use]
pub fn count_opportunities(program: &Program) -> usize {
    let mut opportunities = 0usize;
    for buffer in program.buffers.iter() {
        // Guards run in the same order as the eligibility test in `run`.
        if buffer.access() != BufferAccess::ReadWrite {
            continue;
        }
        if buffer.count() == 0 {
            continue;
        }
        if buffer.is_pipeline_live_out() {
            continue;
        }
        if !fits_workgroup_budget(buffer) {
            continue;
        }
        opportunities += 1;
    }
    opportunities
}
/// Maps the name of each candidate buffer in `program` to its element count.
///
/// A buffer is a candidate when its access is `BufferAccess::ReadWrite`, its
/// count is non-zero, and `is_pipeline_live_out()` is false. If several
/// candidates share a name, the last one in declaration order wins.
///
/// NOTE(review): unlike `run` and `count_opportunities`, this does not apply
/// `fits_workgroup_budget`, so it can list buffers that `run` leaves alone.
/// Confirm whether that is intentional.
#[must_use]
pub fn candidate_handoffs(program: &Program) -> FxHashMap<Ident, u32> {
    program
        .buffers
        .iter()
        .filter(|buf| {
            buf.access() == BufferAccess::ReadWrite
                && buf.count() > 0
                && !buf.is_pipeline_live_out()
        })
        .map(|buf| (Ident::from(buf.name()), buf.count()))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferDecl, DataType, Program};

    /// Wraps `buffers` in a program with a `[64, 1, 1]` workgroup size and an
    /// empty entry argument.
    fn program_of(buffers: Vec<BufferDecl>) -> Program {
        Program::wrapped(buffers, [64, 1, 1], vec![])
    }

    /// Returns the first buffer of `program` called `name`; panics if absent.
    fn buffer_named<'a>(program: &'a Program, name: &str) -> &'a BufferDecl {
        program.buffers.iter().find(|b| b.name() == name).unwrap()
    }

    /// A read-only `input` (64 elements) feeding a read-write `decoded`
    /// (128 elements).
    fn decoder_like() -> Program {
        program_of(vec![
            BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32).with_count(64),
            BufferDecl::storage("decoded", 1, BufferAccess::ReadWrite, DataType::U32)
                .with_count(128),
        ])
    }

    #[test]
    fn run_promotes_readwrite_handoff_to_workgroup() {
        let original = decoder_like();
        let buffer_total = original.buffers.len();
        let rewritten = run(original);
        assert_eq!(rewritten.buffers.len(), buffer_total);
        assert_eq!(
            buffer_named(&rewritten, "decoded").access(),
            BufferAccess::Workgroup
        );
    }

    #[test]
    fn run_leaves_read_only_buffers_alone() {
        let rewritten = run(decoder_like());
        assert_eq!(
            buffer_named(&rewritten, "input").access(),
            BufferAccess::ReadOnly
        );
    }

    #[test]
    fn run_preserves_pipeline_live_out_buffer() {
        let original = program_of(vec![BufferDecl::storage(
            "result",
            0,
            BufferAccess::ReadWrite,
            DataType::U32,
        )
        .with_count(16)
        .with_pipeline_live_out(true)]);
        let rewritten = run(original);
        let result = buffer_named(&rewritten, "result");
        assert_eq!(result.access(), BufferAccess::ReadWrite);
        assert!(result.is_pipeline_live_out());
    }

    #[test]
    fn run_is_identity_when_no_candidates() {
        let original = program_of(vec![BufferDecl::storage(
            "input",
            0,
            BufferAccess::ReadOnly,
            DataType::U32,
        )
        .with_count(1)]);
        let rewritten = run(original);
        assert_eq!(rewritten.buffers.len(), 1);
        assert_eq!(rewritten.buffers[0].access(), BufferAccess::ReadOnly);
    }

    #[test]
    fn run_skips_runtime_sized_buffers() {
        // No `with_count` call here: the declaration keeps whatever count
        // `BufferDecl::storage` assigns by default.
        let original = program_of(vec![BufferDecl::storage(
            "dynamic",
            0,
            BufferAccess::ReadWrite,
            DataType::U32,
        )]);
        let rewritten = run(original);
        assert_eq!(
            buffer_named(&rewritten, "dynamic").access(),
            BufferAccess::ReadWrite
        );
    }

    #[test]
    fn count_opportunities_finds_one_candidate() {
        assert_eq!(count_opportunities(&decoder_like()), 1);
    }

    #[test]
    fn run_leaves_oversize_handoff_in_storage() {
        let original = program_of(vec![
            BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32).with_count(64),
            BufferDecl::storage("decoded", 1, BufferAccess::ReadWrite, DataType::U32)
                .with_count(4097),
        ]);
        assert_eq!(count_opportunities(&original), 0);
        let rewritten = run(original);
        let decoded = buffer_named(&rewritten, "decoded");
        assert_eq!(
            decoded.access(),
            BufferAccess::ReadWrite,
            "oversize handoff must not be promoted; would exceed 16 KiB shared-memory floor"
        );
    }

    #[test]
    fn run_promotes_at_workgroup_byte_ceiling() {
        let original = program_of(vec![BufferDecl::storage(
            "decoded",
            1,
            BufferAccess::ReadWrite,
            DataType::U32,
        )
        .with_count(4096)]);
        let rewritten = run(original);
        assert_eq!(
            buffer_named(&rewritten, "decoded").access(),
            BufferAccess::Workgroup
        );
    }

    #[test]
    fn count_opportunities_zero_on_read_only_program() {
        let program = program_of(vec![BufferDecl::storage(
            "input",
            0,
            BufferAccess::ReadOnly,
            DataType::U32,
        )
        .with_count(1)]);
        assert_eq!(count_opportunities(&program), 0);
    }

    #[test]
    fn candidate_handoffs_exposes_name_and_count() {
        let program = decoder_like();
        let handoffs = candidate_handoffs(&program);
        assert_eq!(handoffs.get(&Ident::from("decoded")).copied(), Some(128));
        assert!(!handoffs.contains_key(&Ident::from("input")));
    }

    #[test]
    fn multiple_candidates_all_surface() {
        let program = program_of(vec![
            BufferDecl::storage("a", 0, BufferAccess::ReadWrite, DataType::U32).with_count(32),
            BufferDecl::storage("b", 1, BufferAccess::ReadWrite, DataType::U32).with_count(64),
            BufferDecl::storage("c", 2, BufferAccess::ReadOnly, DataType::U32).with_count(16),
        ]);
        let handoffs = candidate_handoffs(&program);
        assert_eq!(handoffs.len(), 2);
        assert_eq!(handoffs.get(&Ident::from("a")).copied(), Some(32));
        assert_eq!(handoffs.get(&Ident::from("b")).copied(), Some(64));
    }
}