use crate::graph::{Node, OpKind};
use std::collections::{HashMap, HashSet};
/// Fuses `Conv -> Add -> Relu` chains into single `ConvAddRelu` nodes.
///
/// A chain is fused only when all of the following hold:
///
/// * the `Relu`'s first input is produced by an `Add` node and is consumed by
///   exactly one node input in `nodes`;
/// * one of the `Add`'s first two inputs is produced by a `Conv` node and is
///   consumed exactly once (input 0 is tried before input 1); the other of the
///   two inputs becomes the residual;
/// * the `Conv` has at least two inputs.
///
/// The fused node has:
///
/// * inputs `[conv.inputs[0], conv.inputs[1], conv.inputs[2] or "", residual]`;
/// * the `Relu`'s outputs;
/// * the `Conv`'s attributes plus `strings["fused_ops"] = "Add,Relu"`;
/// * the name `<conv name>_fused_conv_add_relu`.
///
/// The fused node is emitted at the position of the `Relu`, and the `Conv` and
/// `Add` nodes are dropped. Emitting it at the `Conv`'s position would place it
/// before the producer of the residual whenever that producer sits between the
/// `Conv` and the `Add` (for example a second `Conv` on the shortcut branch).
///
/// Lists with fewer than three nodes are returned unchanged.
///
/// NOTE(review): assumes `nodes` is ordered so that producers precede their
/// consumers — confirm against callers. Consumers are counted from `nodes`
/// only, so tensors that are additionally exposed as graph outputs are not
/// visible to this pass.
pub fn fuse_conv_add_relu(nodes: Vec<Node>) -> Vec<Node> {
    // A chain needs at least a Conv, an Add and a Relu.
    if nodes.len() < 3 {
        return nodes;
    }

    // The lookup tables borrow tensor names from `nodes`; keep them in an inner
    // scope so the borrows end before `nodes` is consumed below.
    let (removed, mut fused_at) = {
        // Tensor name -> index of the node producing it (last producer wins).
        let mut producer: HashMap<&str, usize> = HashMap::new();
        for (idx, node) in nodes.iter().enumerate() {
            for out in &node.outputs {
                producer.insert(out.as_str(), idx);
            }
        }

        // Tensor name -> number of node inputs referring to it. Empty names are
        // placeholders for absent optional inputs and are not counted.
        let mut consumer_count: HashMap<&str, usize> = HashMap::new();
        for node in &nodes {
            for name in node.inputs.iter().filter(|name| !name.is_empty()) {
                *consumer_count.entry(name.as_str()).or_insert(0) += 1;
            }
        }
        let single_consumer =
            |name: &str| consumer_count.get(name).copied().unwrap_or(0) == 1;

        // Indices of nodes absorbed into a fused node.
        let mut removed: HashSet<usize> = HashSet::new();
        // Relu index -> fused node that takes its place.
        let mut fused_at: HashMap<usize, Node> = HashMap::new();

        for (relu_idx, relu) in nodes.iter().enumerate() {
            if removed.contains(&relu_idx) || !matches!(relu.op, OpKind::Relu) {
                continue;
            }
            let relu_input = match relu.inputs.first() {
                Some(name) => name.as_str(),
                None => continue,
            };
            // The Add result must feed this Relu and nothing else.
            if !single_consumer(relu_input) {
                continue;
            }
            let add_idx = match producer.get(relu_input) {
                Some(&idx) if !removed.contains(&idx) => idx,
                _ => continue,
            };
            let add = &nodes[add_idx];
            if !matches!(add.op, OpKind::Add) || add.inputs.len() < 2 {
                continue;
            }

            let (conv_idx, residual) = {
                // Index of the Conv producing `name`, provided its output is
                // consumed exactly once and it has not already been absorbed.
                let conv_producer = |name: &str| -> Option<usize> {
                    if !single_consumer(name) {
                        return None;
                    }
                    let idx = *producer.get(name)?;
                    if removed.contains(&idx) || !matches!(nodes[idx].op, OpKind::Conv) {
                        return None;
                    }
                    Some(idx)
                };
                let lhs = add.inputs[0].as_str();
                let rhs = add.inputs[1].as_str();
                // Prefer input 0 as the Conv branch; fall back to input 1.
                let found = conv_producer(lhs)
                    .map(|idx| (idx, rhs))
                    .or_else(|| conv_producer(rhs).map(|idx| (idx, lhs)));
                match found {
                    Some(pair) => pair,
                    None => continue,
                }
            };

            let conv = &nodes[conv_idx];
            // Data and weight inputs are mandatory for the fused op.
            if conv.inputs.len() < 2 {
                continue;
            }

            let fused_inputs = vec![
                conv.inputs[0].clone(),
                conv.inputs[1].clone(),
                // Optional third Conv input; an empty name marks it as absent.
                conv.inputs.get(2).cloned().unwrap_or_default(),
                residual.to_string(),
            ];
            let mut fused_attrs = conv.attrs.clone();
            fused_attrs
                .strings
                .insert("fused_ops".to_string(), "Add,Relu".to_string());

            // Place the fused node where the Relu was: every input of the chain,
            // including the residual, has been produced by that point.
            fused_at.insert(
                relu_idx,
                Node {
                    op: OpKind::ConvAddRelu,
                    name: format!("{}_fused_conv_add_relu", conv.name),
                    inputs: fused_inputs,
                    outputs: relu.outputs.clone(),
                    attrs: fused_attrs,
                },
            );
            removed.insert(conv_idx);
            removed.insert(add_idx);
        }

        (removed, fused_at)
    };

    nodes
        .into_iter()
        .enumerate()
        .filter(|(idx, _)| !removed.contains(idx))
        .map(|(idx, node)| fused_at.remove(&idx).unwrap_or(node))
        .collect()
}