use vyre_foundation::execution_plan::fusion::fuse_programs;
use vyre_foundation::ir::{BufferDecl, DataType, Program};
/// Errors returned by [`fuse_decode_scan`].
///
/// Every message starts with `Fix:` and tells the caller what to change
/// before retrying.
#[derive(Debug, thiserror::Error)]
pub enum DecodeScanFuseError {
/// The named handoff buffer is declared by neither the decoder nor the
/// scanner program.
#[error(
"Fix: handoff buffer {handoff:?} does not appear in the decoder or scanner Program's \
buffer list. Add a `BufferDecl::storage({handoff:?}, ..., DataType::U32)` to both \
Programs before calling `fuse_decode_scan`."
)]
HandoffBufferMissing {
/// Name of the handoff buffer that was looked up.
handoff: String,
},
/// `fuse_decode_scan` was called with `handoff_byte_count == 0`.
#[error(
"Fix: fuse_decode_scan(handoff_byte_count = 0) is rejected on buffer {handoff:?}. \
Pass the decoder's peak output-bytes-per-workgroup."
)]
ZeroHandoff {
/// Name of the handoff buffer the zero count was supplied for.
handoff: String,
},
/// The underlying `fuse_programs` call failed; the inner error is carried
/// along and converted automatically through `#[from]`.
#[error(
"Fix: kernel-level fusion failed - run the autotune pass to normalise workgroup \
sizes and rename any self-aliasing buffers before calling `fuse_decode_scan`. \
Inner: {0}"
)]
Fusion(#[from] vyre_foundation::execution_plan::fusion::FusionError),
}
/// Fuses a decoder program and a scanner program into one program and
/// re-declares their handoff buffer as a workgroup buffer.
///
/// `handoff_buf` names the buffer shared between the two programs and
/// `handoff_byte_count` is forwarded as the size of the promoted buffer.
///
/// # Errors
///
/// * [`DecodeScanFuseError::ZeroHandoff`] when `handoff_byte_count` is `0`
///   (checked first).
/// * [`DecodeScanFuseError::HandoffBufferMissing`] when neither program
///   declares a buffer named `handoff_buf`. A declaration in only one of the
///   two programs is accepted.
/// * [`DecodeScanFuseError::Fusion`] when `fuse_programs` rejects the pair.
pub fn fuse_decode_scan(
    decoder: Program,
    scanner: Program,
    handoff_buf: &str,
    handoff_byte_count: u32,
) -> Result<Program, DecodeScanFuseError> {
    // Guard: a zero-sized handoff is never valid.
    if handoff_byte_count == 0 {
        return Err(DecodeScanFuseError::ZeroHandoff {
            handoff: handoff_buf.to_owned(),
        });
    }

    // Guard: at least one side has to declare the handoff buffer.
    let declares_handoff =
        |program: &Program| program.buffers.iter().any(|decl| decl.name() == handoff_buf);
    if !(declares_handoff(&decoder) || declares_handoff(&scanner)) {
        return Err(DecodeScanFuseError::HandoffBufferMissing {
            handoff: handoff_buf.to_owned(),
        });
    }

    // Fuse first, then swap the handoff declaration for a workgroup one.
    fuse_programs(&[decoder, scanner])
        .map(|fused| promote_to_workgroup(fused, handoff_buf, handoff_byte_count))
        .map_err(DecodeScanFuseError::from)
}
/// Rebuilds `program` with the buffer named `handoff_buf` replaced by a
/// workgroup declaration of `count` `U32` elements.
///
/// Every other buffer declaration is kept in its original order; the new
/// workgroup declaration is appended last. The workgroup size and entry nodes
/// are carried over and passed back through `Program::wrapped`.
///
/// NOTE(review): callers pass a value named `handoff_byte_count` as `count`,
/// while the buffer is declared with `DataType::U32` — confirm whether
/// `BufferDecl::workgroup` expects bytes or elements.
fn promote_to_workgroup(program: Program, handoff_buf: &str, count: u32) -> Program {
    let promoted = BufferDecl::workgroup(handoff_buf, count, DataType::U32);

    // Drop any existing declaration of the handoff buffer, then append the
    // workgroup replacement.
    let buffers: Vec<BufferDecl> = program
        .buffers
        .iter()
        .filter(|decl| decl.name() != handoff_buf)
        .cloned()
        .chain(std::iter::once(promoted))
        .collect();

    // Take the entry nodes without copying when this is the only owner of the
    // `Arc`; otherwise fall back to cloning the shared contents.
    let entry = match std::sync::Arc::try_unwrap(program.entry) {
        Ok(owned) => owned,
        Err(shared) => (*shared).clone(),
    };

    Program::wrapped(buffers, program.workgroup_size, entry)
}
/// Returns `2 * handoff_byte_count * invocations` as a `u64`, saturating at
/// `u64::MAX`.
///
/// The factor of two presumably accounts for one write and one read of the
/// handoff buffer per invocation — confirm against the cost model.
///
/// The product of two `u32` values always fits in a `u64`, but doubling it
/// may not (`2 * u32::MAX * u32::MAX > u64::MAX`). Saturating arithmetic keeps
/// the estimate from panicking in debug builds or silently wrapping in
/// release builds on extreme inputs.
#[must_use]
pub fn dram_bytes_saved(handoff_byte_count: u32, invocations: u32) -> u64 {
    u64::from(handoff_byte_count)
        .saturating_mul(u64::from(invocations))
        .saturating_mul(2)
}
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::ir::{BufferAccess, Expr, MemoryKind, Node};

    /// Builds a decoder program that declares `input` and the handoff buffer
    /// and stores a constant into the handoff buffer.
    fn decoder_with_handoff(handoff: &str) -> Program {
        Program::wrapped(
            vec![
                BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
                    .with_count(64),
                BufferDecl::storage(handoff, 1, BufferAccess::ReadWrite, DataType::U32)
                    .with_count(64),
            ],
            [64, 1, 1],
            vec![Node::store(
                handoff,
                Expr::InvocationId { axis: 0 },
                Expr::u32(0xAA),
            )],
        )
    }

    /// Builds a scanner program that declares the handoff buffer and
    /// `matches` and loads from the handoff buffer.
    fn scanner_with_handoff(handoff: &str) -> Program {
        Program::wrapped(
            vec![
                BufferDecl::storage(handoff, 1, BufferAccess::ReadOnly, DataType::U32)
                    .with_count(64),
                BufferDecl::storage("matches", 2, BufferAccess::ReadWrite, DataType::U32)
                    .with_count(64),
            ],
            [64, 1, 1],
            vec![Node::let_bind(
                "byte",
                Expr::load(handoff, Expr::InvocationId { axis: 0 }),
            )],
        )
    }

    #[test]
    fn missing_handoff_buffer_errors_with_actionable_fix() {
        let decoder = Program::wrapped(
            vec![
                BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
                    .with_count(1),
            ],
            [64, 1, 1],
            vec![],
        );
        let scanner = Program::wrapped(
            vec![
                BufferDecl::storage("matches", 0, BufferAccess::ReadWrite, DataType::U32)
                    .with_count(1),
            ],
            [64, 1, 1],
            vec![],
        );
        let err = fuse_decode_scan(decoder, scanner, "decoded", 64).unwrap_err();
        // Pin the variant as well as the text, mirroring the zero-handoff test.
        assert!(matches!(
            err,
            DecodeScanFuseError::HandoffBufferMissing { .. }
        ));
        let msg = format!("{err}");
        assert!(msg.contains("Fix:"));
        assert!(msg.contains("decoded"));
    }

    #[test]
    fn zero_handoff_byte_count_returns_structured_error() {
        let decoder = decoder_with_handoff("decoded");
        let scanner = scanner_with_handoff("decoded");
        let err = fuse_decode_scan(decoder, scanner, "decoded", 0).unwrap_err();
        assert!(matches!(err, DecodeScanFuseError::ZeroHandoff { .. }));
        assert!(err.to_string().contains("Fix:"));
    }

    #[test]
    fn fused_program_promotes_handoff_to_workgroup_memory() {
        let decoder = decoder_with_handoff("decoded");
        let scanner = scanner_with_handoff("decoded");
        let fused = fuse_decode_scan(decoder, scanner, "decoded", 128).unwrap();
        let handoff = fused.buffers.iter().find(|b| b.name() == "decoded").expect(
            "Fix: handoff buffer survives fusion; restore this invariant before continuing.",
        );
        assert_eq!(handoff.access(), BufferAccess::Workgroup);
        assert_eq!(handoff.kind(), MemoryKind::Shared);
        assert_eq!(handoff.count(), 128);
    }

    #[test]
    fn non_handoff_buffers_stay_as_declared() {
        let decoder = decoder_with_handoff("decoded");
        let scanner = scanner_with_handoff("decoded");
        let fused = fuse_decode_scan(decoder, scanner, "decoded", 128).unwrap();
        let input = fused.buffers.iter().find(|b| b.name() == "input").unwrap();
        assert_eq!(input.access(), BufferAccess::ReadOnly);
        let matches = fused
            .buffers
            .iter()
            .find(|b| b.name() == "matches")
            .unwrap();
        assert_eq!(matches.access(), BufferAccess::ReadWrite);
    }

    #[test]
    fn fused_body_contains_both_decoder_and_scanner_nodes() {
        let decoder = decoder_with_handoff("decoded");
        let scanner = scanner_with_handoff("decoded");
        let fused = fuse_decode_scan(decoder, scanner, "decoded", 64).unwrap();
        assert_eq!(
            fused.entry.len(),
            1,
            "wrapped entry must be a single root Region"
        );
        // Match through the `Node` imported above so the whole module uses a
        // single path to the IR types.
        let body = match &fused.entry[0] {
            Node::Region { body, .. } => body.as_ref(),
            other => panic!("Fix: fused entry root must be a Region, got {other:?}"),
        };
        assert!(
            body.len() >= 2,
            "fused root region body should contain both arms, got {} nodes",
            body.len()
        );
    }

    #[test]
    fn dram_bytes_saved_scales_with_invocations() {
        assert_eq!(dram_bytes_saved(0, 1_000_000), 0);
        assert_eq!(dram_bytes_saved(1024, 1000), 2 * 1024 * 1000);
        assert_eq!(dram_bytes_saved(1, u32::MAX), 2 * u64::from(u32::MAX));
    }

    #[test]
    fn handoff_present_in_only_scanner_still_fuses() {
        let decoder = Program::wrapped(
            vec![
                BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
                    .with_count(1),
            ],
            [64, 1, 1],
            vec![],
        );
        let scanner = scanner_with_handoff("decoded");
        let fused = fuse_decode_scan(decoder, scanner, "decoded", 64).unwrap();
        let handoff = fused
            .buffers
            .iter()
            .find(|b| b.name() == "decoded")
            .unwrap();
        assert_eq!(handoff.access(), BufferAccess::Workgroup);
    }
}