use crate::specialization::SpecMap;
/// Footprint of a single compute dispatch, as consumed by [`FusionPass::decide`].
#[derive(Clone, Debug)]
pub struct DispatchShape {
    /// Identifier for the dispatch; `FusionPass::decide` does not inspect it.
    pub id: &'static str,
    /// Workgroup dimensions; their product is the per-workgroup invocation count.
    pub workgroup_size: [u32; 3],
    /// Shared memory this dispatch needs, in bytes.
    pub shared_memory_bytes: u32,
    /// Resource names this dispatch consumes; matched against an upstream `outputs`.
    pub inputs: Vec<&'static str>,
    /// Resource names this dispatch produces; matched against a downstream `inputs`.
    pub outputs: Vec<&'static str>,
    /// Specialization map carried with the dispatch; not read by `FusionPass::decide`.
    pub specs: SpecMap,
}
/// Device limits consulted by [`FusionPass::decide`] when judging a fusion candidate.
///
/// A small `Copy` value type; it derives `PartialEq`/`Eq`/`Hash` so callers can compare
/// capability sets or use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FusionCaps {
    /// Largest combined `shared_memory_bytes` a fused pair may use.
    pub max_shared_memory_bytes: u32,
    /// Largest invocation count (product of the workgroup dimensions) allowed.
    pub max_invocations_per_workgroup: u32,
}

impl Default for FusionCaps {
    /// Baseline limits: 16 KiB of shared memory and 256 invocations per workgroup.
    fn default() -> Self {
        Self {
            max_shared_memory_bytes: 16 * 1024,
            max_invocations_per_workgroup: 256,
        }
    }
}

impl FusionCaps {
    /// Generous limits: 128 KiB of shared memory and 1024 invocations per workgroup.
    #[must_use]
    pub const fn high_end() -> Self {
        Self {
            max_shared_memory_bytes: 128 * 1024,
            max_invocations_per_workgroup: 1024,
        }
    }
}
/// Verdict returned by [`FusionPass::decide`] for an (upstream, downstream) dispatch pair.
///
/// Any variant other than [`FusionDecision::Accept`] names the first check that failed;
/// see [`FusionPass::decide`] for the order in which the checks run.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum FusionDecision {
    /// Every check passed; the pair may be fused.
    Accept,
    /// The two dispatches declare different workgroup sizes.
    WorkgroupSizeMismatch {
        /// `workgroup_size` of the upstream dispatch.
        upstream: [u32; 3],
        /// `workgroup_size` of the downstream dispatch.
        downstream: [u32; 3],
    },
    /// The summed `shared_memory_bytes` of the pair exceeds the cap.
    SharedMemoryBudget {
        /// Sum of both dispatches' `shared_memory_bytes`, saturated at `u32::MAX`.
        needed: u32,
        /// `FusionCaps::max_shared_memory_bytes` used for the decision.
        cap: u32,
    },
    /// An upstream output read by `downstream` is also listed in `other_consumers`.
    OutputConsumedElsewhere,
    /// `downstream` reads none of `upstream`'s outputs.
    NoPipelineDependency,
    /// The (common) workgroup size has more invocations than the cap allows.
    InvocationBudget {
        /// Product of the three workgroup dimensions, saturated at `u32::MAX`.
        needed: u32,
        /// `FusionCaps::max_invocations_per_workgroup` used for the decision.
        cap: u32,
    },
}

impl std::fmt::Display for FusionDecision {
    // Renders a short hyphenated tag, e.g. `shared-memory-budget:4096/2048`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Accept => f.write_str("accept"),
            Self::WorkgroupSizeMismatch {
                upstream,
                downstream,
            } => write!(f, "workgroup-size-mismatch:{upstream:?}->{downstream:?}"),
            Self::SharedMemoryBudget { needed, cap } => {
                write!(f, "shared-memory-budget:{needed}/{cap}")
            }
            Self::OutputConsumedElsewhere => f.write_str("output-consumed-elsewhere"),
            Self::NoPipelineDependency => f.write_str("no-pipeline-dependency"),
            Self::InvocationBudget { needed, cap } => {
                write!(f, "invocation-budget:{needed}/{cap}")
            }
        }
    }
}

/// Stateless pass deciding whether a producer/consumer pair of dispatches can be fused.
pub struct FusionPass;

impl FusionPass {
    /// Decides whether `downstream` may be fused with `upstream`.
    ///
    /// Checks run in this order and the first failure is returned:
    ///
    /// 1. Both dispatches must declare the same `workgroup_size`, otherwise
    ///    [`FusionDecision::WorkgroupSizeMismatch`].
    /// 2. That workgroup's invocation count (product of its dimensions) must not exceed
    ///    `caps.max_invocations_per_workgroup`, otherwise
    ///    [`FusionDecision::InvocationBudget`].
    /// 3. The summed `shared_memory_bytes` must not exceed
    ///    `caps.max_shared_memory_bytes`, otherwise [`FusionDecision::SharedMemoryBudget`].
    /// 4. No upstream output that `downstream` reads may appear in `other_consumers`,
    ///    otherwise [`FusionDecision::OutputConsumedElsewhere`].
    /// 5. At least one upstream output must be read by `downstream`, otherwise
    ///    [`FusionDecision::NoPipelineDependency`].
    ///
    /// Upstream outputs that `downstream` does not read are ignored by step 4, even if
    /// they appear in `other_consumers`.
    ///
    /// Note: a zero workgroup dimension yields an invocation count of 0, which step 2
    /// does not reject.
    #[must_use]
    pub fn decide(
        upstream: &DispatchShape,
        downstream: &DispatchShape,
        caps: FusionCaps,
        other_consumers: &[&str],
    ) -> FusionDecision {
        if upstream.workgroup_size != downstream.workgroup_size {
            return FusionDecision::WorkgroupSizeMismatch {
                upstream: upstream.workgroup_size,
                downstream: downstream.workgroup_size,
            };
        }

        // Sizes are equal here, so checking `upstream` covers both. Saturating
        // multiplication makes a genuinely overflowing product compare as `u32::MAX`
        // (and get rejected) instead of wrapping to a small value under the cap.
        let [x, y, z] = upstream.workgroup_size;
        let invocations = x.saturating_mul(y).saturating_mul(z);
        if invocations > caps.max_invocations_per_workgroup {
            return FusionDecision::InvocationBudget {
                needed: invocations,
                cap: caps.max_invocations_per_workgroup,
            };
        }

        let needed = upstream
            .shared_memory_bytes
            .saturating_add(downstream.shared_memory_bytes);
        if needed > caps.max_shared_memory_bytes {
            return FusionDecision::SharedMemoryBudget {
                needed,
                cap: caps.max_shared_memory_bytes,
            };
        }

        let mut has_pipeline_dependency = false;
        for output in &upstream.outputs {
            // Only outputs that flow into `downstream` matter for this decision.
            if !downstream.inputs.iter().any(|input| input == output) {
                continue;
            }
            has_pipeline_dependency = true;
            if other_consumers.iter().any(|consumer| consumer == output) {
                return FusionDecision::OutputConsumedElsewhere;
            }
        }

        if has_pipeline_dependency {
            FusionDecision::Accept
        } else {
            FusionDecision::NoPipelineDependency
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dispatch with a 64x1x1 workgroup and 1 KiB of shared memory.
    fn dispatch(
        id: &'static str,
        inputs: &[&'static str],
        outputs: &[&'static str],
    ) -> DispatchShape {
        DispatchShape {
            id,
            workgroup_size: [64, 1, 1],
            shared_memory_bytes: 1024,
            inputs: inputs.to_vec(),
            outputs: outputs.to_vec(),
            specs: SpecMap::new(),
        }
    }

    #[test]
    fn straight_producer_consumer_fuses() {
        let up = dispatch("load", &["in"], &["stage"]);
        let down = dispatch("xor", &["stage"], &["out"]);
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &[]),
            FusionDecision::Accept
        );
    }

    #[test]
    fn third_consumer_rejects() {
        let up = dispatch("a", &[], &["x"]);
        let down = dispatch("b", &["x"], &[]);
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &["x"]),
            FusionDecision::OutputConsumedElsewhere
        );
    }

    #[test]
    fn unread_output_with_other_consumers_still_fuses() {
        // "debug" is consumed elsewhere but never read by `down`, so it must not block fusion.
        let up = dispatch("a", &[], &["x", "debug"]);
        let down = dispatch("b", &["x"], &[]);
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &["debug"]),
            FusionDecision::Accept
        );
    }

    #[test]
    fn mismatched_workgroup_sizes_reject() {
        let up = dispatch("a", &[], &["x"]);
        let mut down = dispatch("b", &["x"], &[]);
        down.workgroup_size = [32, 1, 1];
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &[]),
            FusionDecision::WorkgroupSizeMismatch {
                upstream: [64, 1, 1],
                downstream: [32, 1, 1],
            }
        );
    }

    #[test]
    fn oversized_workgroup_rejects() {
        // 1024 * 2 * 1 = 2048 invocations, above the high-end cap of 1024.
        let mut up = dispatch("a", &[], &["x"]);
        let mut down = dispatch("b", &["x"], &[]);
        up.workgroup_size = [1024, 2, 1];
        down.workgroup_size = [1024, 2, 1];
        assert_ne!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &[]),
            FusionDecision::Accept
        );
    }

    #[test]
    fn shared_memory_over_budget_rejects() {
        let mut up = dispatch("a", &[], &["x"]);
        let mut down = dispatch("b", &["x"], &[]);
        up.shared_memory_bytes = 12 * 1024;
        down.shared_memory_bytes = 8 * 1024;
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::default(), &[]),
            FusionDecision::SharedMemoryBudget {
                needed: 20 * 1024,
                cap: 16 * 1024,
            }
        );
    }

    #[test]
    fn unrelated_dispatches_report_no_dependency() {
        let up = dispatch("a", &[], &["x"]);
        let down = dispatch("b", &["y"], &[]);
        assert_eq!(
            FusionPass::decide(&up, &down, FusionCaps::high_end(), &[]),
            FusionDecision::NoPipelineDependency
        );
    }

    #[test]
    fn fusion_decision_formats_human_string() {
        assert_eq!(
            FusionDecision::WorkgroupSizeMismatch {
                upstream: [8, 4, 1],
                downstream: [4, 4, 1],
            }
            .to_string(),
            "workgroup-size-mismatch:[8, 4, 1]->[4, 4, 1]"
        );
        assert_eq!(
            FusionDecision::SharedMemoryBudget {
                needed: 4096,
                cap: 2048,
            }
            .to_string(),
            "shared-memory-budget:4096/2048"
        );
        assert_eq!(FusionDecision::Accept.to_string(), "accept");
        assert_eq!(
            FusionDecision::OutputConsumedElsewhere.to_string(),
            "output-consumed-elsewhere"
        );
        assert_eq!(
            FusionDecision::NoPipelineDependency.to_string(),
            "no-pipeline-dependency"
        );
    }
}