use rustc_hash::{FxHashMap, FxHashSet};
use crate::execution_plan::SchedulingPolicy;
use crate::ir::{BufferAccess, BufferDecl, Ident, Node, Program};
use super::alpha_rename::push_alpha_renamed_arm_entry_node;
use super::collectors::collect_buffer_targets;
use super::divergence::has_divergent_invocation_gated_store;
use super::helpers::{fallback_composition_key, upgrade_buffer_access};
use super::{FusionError, FusionOverDispatchError, FusionSelfAliasingError};
pub fn fuse_programs(programs: &[Program]) -> Result<Program, FusionError> {
match programs.len() {
0 => Ok(Program::empty()),
1 => Ok(programs[0].clone()),
_ => fuse_programs_multi(programs),
}
}
#[inline]
#[must_use]
pub fn fuse_programs_vec(mut programs: Vec<Program>) -> Result<Program, FusionError> {
match programs.len() {
0 => Ok(Program::empty()),
1 => {
let Some(program) = programs.pop() else {
return Ok(Program::empty());
};
Ok(program)
}
_ => fuse_programs_multi(programs.as_slice()),
}
}
fn fuse_programs_multi(programs: &[Program]) -> Result<Program, FusionError> {
reject_non_composable_self_fusion(programs)?;
let mut merged_buffers: Vec<BufferDecl> = Vec::new();
let mut name_to_index: FxHashMap<Ident, usize> = FxHashMap::default();
let mut next_binding = 0_u32;
let mut read_arms_per_buffer: FxHashMap<Ident, Vec<usize>> = FxHashMap::default();
let mut write_arms_per_buffer: FxHashMap<Ident, Vec<usize>> = FxHashMap::default();
let mut barrier_after_arm: FxHashSet<usize> = FxHashSet::default();
let mut divergent_store_arms: FxHashSet<usize> = FxHashSet::default();
let mut fused_workgroup = [1u32, 1, 1];
let mut max_arm_threads: u64 = 1;
let mut arm_entries: Vec<Vec<Node>> = Vec::with_capacity(programs.len());
for (arm_idx, prog) in programs.iter().enumerate() {
let entry = prog.entry();
let mut segment = Vec::with_capacity(entry.len());
let mut atomic_targets: FxHashSet<Ident> = FxHashSet::default();
let mut load_targets: FxHashSet<Ident> = FxHashSet::default();
let mut store_targets: FxHashSet<Ident> = FxHashSet::default();
let mut divergent_store_seen = false;
for node in entry {
push_alpha_renamed_arm_entry_node(&mut segment, node, arm_idx);
collect_buffer_targets(
node,
&mut load_targets,
&mut store_targets,
&mut atomic_targets,
);
if has_divergent_invocation_gated_store(node, false) {
divergent_store_seen = true;
}
}
if divergent_store_seen {
divergent_store_arms.insert(arm_idx);
}
arm_entries.push(segment);
let mut arm_reads: FxHashSet<Ident> = FxHashSet::default();
let mut arm_explicit_writes: FxHashSet<Ident> = FxHashSet::default();
classify_and_merge_arm_buffers(
prog,
&mut arm_reads,
&mut arm_explicit_writes,
&mut merged_buffers,
&mut name_to_index,
&mut next_binding,
);
for target in &load_targets {
arm_reads.insert(target.clone());
}
for target in &store_targets {
arm_explicit_writes.insert(target.clone());
}
let mut arm_writes = arm_explicit_writes.clone();
for target in &atomic_targets {
if !arm_reads.contains(target) && !arm_explicit_writes.contains(target) {
arm_writes.insert(target.clone());
}
}
for write_buf in &arm_writes {
if let Some(read_arms) = read_arms_per_buffer.get(write_buf) {
for &read_arm in read_arms {
barrier_after_arm.insert(read_arm);
}
}
}
for read_buf in &arm_reads {
if let Some(write_arms) = write_arms_per_buffer.get(read_buf) {
for &write_arm in write_arms {
barrier_after_arm.insert(write_arm);
}
}
}
for read_buf in &arm_reads {
read_arms_per_buffer
.entry(read_buf.clone())
.or_default()
.push(arm_idx);
}
for write_buf in &arm_writes {
write_arms_per_buffer
.entry(write_buf.clone())
.or_default()
.push(arm_idx);
}
let wg = prog.workgroup_size();
fused_workgroup[0] = fused_workgroup[0].max(wg[0]);
fused_workgroup[1] = fused_workgroup[1].max(wg[1]);
fused_workgroup[2] = fused_workgroup[2].max(wg[2]);
let arm_threads = u64::from(wg[0]) * u64::from(wg[1]) * u64::from(wg[2]);
max_arm_threads = max_arm_threads.max(arm_threads);
}
let combined_entry = flatten_arm_entries(
arm_entries,
&barrier_after_arm,
&divergent_store_arms,
programs.len(),
);
reject_overdispatch(fused_workgroup, max_arm_threads)?;
Ok(Program::wrapped(
merged_buffers,
fused_workgroup,
combined_entry,
))
}
fn classify_and_merge_arm_buffers(
prog: &Program,
arm_reads: &mut FxHashSet<Ident>,
arm_explicit_writes: &mut FxHashSet<Ident>,
merged_buffers: &mut Vec<BufferDecl>,
name_to_index: &mut FxHashMap<Ident, usize>,
next_binding: &mut u32,
) {
for buf in prog.buffers() {
let name = Ident::from(buf.name());
match buf.access() {
BufferAccess::ReadOnly | BufferAccess::Uniform => {
arm_reads.insert(name.clone());
}
BufferAccess::ReadWrite => {
arm_explicit_writes.insert(name.clone());
}
_ => {}
}
if let Some(&idx) = name_to_index.get(&name) {
let existing = &mut merged_buffers[idx];
let access = buf.access();
upgrade_buffer_access(existing, &access);
if buf.count > existing.count {
existing.count = buf.count;
}
if buf.is_output() {
existing.is_output = true;
existing.pipeline_live_out = true;
}
} else {
let mut merged = buf.clone();
if merged.access() != BufferAccess::Workgroup {
merged.binding = *next_binding;
*next_binding += 1;
}
name_to_index.insert(Ident::from(merged.name()), merged_buffers.len());
merged_buffers.push(merged);
}
}
}
fn reject_non_composable_self_fusion(programs: &[Program]) -> Result<(), FusionError> {
let mut seen_op_ids: FxHashMap<String, bool> = FxHashMap::default();
for prog in programs {
let key = prog
.entry_op_id()
.map_or_else(|| fallback_composition_key(prog), ToString::to_string);
let is_non_comp = prog.is_non_composable_with_self();
match seen_op_ids.get_mut(&key) {
Some(has_non_comp) if *has_non_comp || is_non_comp => {
return Err(FusionError::SelfAliasing(FusionSelfAliasingError {
op_id: key,
fix: "rename the second parser's workgroup buffer or split into two separate dispatches",
}));
}
Some(_) => {}
None => {
seen_op_ids.insert(key, is_non_comp);
}
}
}
Ok(())
}
fn flatten_arm_entries(
arm_entries: Vec<Vec<Node>>,
barrier_after_arm: &FxHashSet<usize>,
divergent_store_arms: &FxHashSet<usize>,
program_count: usize,
) -> Vec<Node> {
let total_nodes: usize = arm_entries.iter().map(Vec::len).sum();
let mut combined_entry = Vec::with_capacity(total_nodes + program_count);
for (arm_idx, segment) in arm_entries.into_iter().enumerate() {
combined_entry.push(Node::Block(segment));
if barrier_after_arm.contains(&arm_idx) {
let ordering = if divergent_store_arms.contains(&arm_idx) {
crate::memory_model::MemoryOrdering::GridSync
} else {
crate::memory_model::MemoryOrdering::SeqCst
};
combined_entry.push(Node::barrier_with_ordering(ordering));
}
}
combined_entry
}
fn reject_overdispatch(fused_workgroup: [u32; 3], max_arm_threads: u64) -> Result<(), FusionError> {
let fused_threads = u64::from(fused_workgroup[0])
* u64::from(fused_workgroup[1])
* u64::from(fused_workgroup[2]);
let policy = SchedulingPolicy::standard();
if policy.allow_fused_threads(fused_threads, max_arm_threads) {
return Ok(());
}
Err(FusionError::OverDispatch(FusionOverDispatchError {
max_arm_threads,
fused_threads,
fix: "split the batch or use per-arm dispatch; axis-wise max exceeds the shared over-dispatch policy",
}))
}