use crate::graph::{Attributes, Node, OpKind};
use crate::tensor::Tensor;
use std::collections::{HashMap, HashSet};
/// Fuses an `Add(x, bias)` feeding a `MatMul(·, W)` into a single
/// `Gemm(x, W, bias · W)` node.
///
/// The rewrite uses the identity `(x + b) · W = x · W + (b · W)`: the 1-D
/// bias of the `Add` is folded through the 2-D weight `W` into a new 1-D
/// initializer of length `n = W.shape[1]`, which becomes the Gemm `C` input
/// (`beta = 1`, no transposes).
///
/// A pair is fused only when:
/// * the MatMul's second input is a well-formed 2-D initializer in `weights`,
/// * the MatMul's first input is produced by an `Add` node,
/// * that `Add` output has exactly one consumer (this MatMul), and
/// * one `Add` input is a 1-D initializer of length `k = W.shape[0]`.
///
/// The fused Gemm is written at the `Add` node's position (all of its inputs
/// already exist there, so topological order is preserved) and the `MatMul`
/// node is dropped. The folded bias is inserted into `weights`; the original
/// initializers are left untouched.
///
/// NOTE(review): consumer counting only sees node inputs — this assumes the
/// `Add` output is never itself a graph output. Confirm against the caller.
pub fn fuse_add_matmul_to_gemm(
    nodes: Vec<Node>,
    weights: &mut HashMap<String, Tensor>,
) -> Vec<Node> {
    // A fusion needs at least one Add and one MatMul.
    if nodes.len() < 2 {
        return nodes;
    }

    // Map each tensor name to the index of the node producing it.
    let mut producer: HashMap<String, usize> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        for out in &node.outputs {
            producer.insert(out.clone(), i);
        }
    }

    // Count how many node inputs reference each tensor. Empty names are
    // optional-input placeholders and are ignored.
    let mut consumer_count: HashMap<String, usize> = HashMap::new();
    for node in &nodes {
        for inp in &node.inputs {
            if !inp.is_empty() {
                *consumer_count.entry(inp.clone()).or_insert(0) += 1;
            }
        }
    }

    let mut skip: HashSet<usize> = HashSet::new();
    let mut replacements: HashMap<usize, Node> = HashMap::new();

    for (i, node) in nodes.iter().enumerate() {
        if skip.contains(&i) || !matches!(node.op, OpKind::MatMul) || node.inputs.len() < 2 {
            continue;
        }
        let add_out = &node.inputs[0];
        let w_name = &node.inputs[1];

        // The MatMul weight must be a 2-D initializer. Borrow it rather than
        // clone the whole tensor; the borrow ends before `weights.insert`.
        let w_tensor = match weights.get(w_name) {
            Some(t) if t.shape.len() == 2 => t,
            _ => continue,
        };
        let (k, n) = (w_tensor.shape[0], w_tensor.shape[1]);
        // Degenerate weights produce no meaningful fused bias; skip them
        // (this also keeps the chunked fold below well-defined).
        if k == 0 || n == 0 || w_tensor.data.len() != k * n {
            continue;
        }

        // The Add output must feed only this MatMul; removing the Add would
        // otherwise break its other consumers.
        if consumer_count.get(add_out).copied().unwrap_or(0) != 1 {
            continue;
        }
        let add_idx = match producer.get(add_out) {
            Some(&idx) => idx,
            None => continue,
        };
        if skip.contains(&add_idx)
            || !matches!(nodes[add_idx].op, OpKind::Add)
            || nodes[add_idx].inputs.len() < 2
        {
            continue;
        }

        // Decide which Add input is the 1-D bias initializer; the other
        // becomes the Gemm `A` input. inputs[1] is tried first (the common
        // `x + b` layout); a non-1-D initializer on either side aborts.
        let (x_name, bias_name) = {
            let inp0 = &nodes[add_idx].inputs[0];
            let inp1 = &nodes[add_idx].inputs[1];
            match (weights.get(inp0), weights.get(inp1)) {
                (_, Some(b)) if b.ndim() == 1 => (inp0.clone(), inp1.clone()),
                (_, Some(_)) => continue,
                (Some(b), None) if b.ndim() == 1 => (inp1.clone(), inp0.clone()),
                _ => continue,
            }
        };

        // (x + b) · W only folds when b has length k = W.shape[0]. Validate
        // the raw buffer too so the fold cannot read out of bounds.
        let bias = match weights.get(&bias_name) {
            Some(t) if t.shape.len() == 1 && t.shape[0] == k && t.data.len() == k => t,
            _ => continue,
        };

        // Fold the bias through the weight: fused[j] = Σ_k b[k] · W[k][j].
        // Walking W row-by-row keeps the access pattern contiguous and lets
        // the zipped iterators elide per-element bounds checks.
        let mut fused_bias_data = vec![0.0f32; n];
        for (row, &b) in w_tensor.data.chunks(n).zip(&bias.data) {
            for (acc, &w) in fused_bias_data.iter_mut().zip(row) {
                *acc += b * w;
            }
        }

        let fused_bias_name = format!("{}_fused_add_matmul_bias", nodes[add_idx].name);
        weights.insert(
            fused_bias_name.clone(),
            Tensor::new(fused_bias_data, vec![n]),
        );

        // Plain Gemm: Y = 1·A·B + 1·C, no transposes.
        let mut attrs = Attributes::default();
        attrs.floats.insert("alpha".to_string(), 1.0);
        attrs.floats.insert("beta".to_string(), 1.0);
        attrs.ints.insert("transA".to_string(), 0);
        attrs.ints.insert("transB".to_string(), 0);

        // The Gemm replaces the Add in place; the MatMul is dropped and the
        // Gemm inherits its outputs so downstream consumers are untouched.
        replacements.insert(
            add_idx,
            Node {
                op: OpKind::Gemm,
                name: format!("{}_fused_add_matmul_gemm", nodes[add_idx].name),
                inputs: vec![x_name, w_name.clone(), fused_bias_name],
                outputs: node.outputs.clone(),
                attrs,
            },
        );
        skip.insert(i);
    }

    // Emit surviving nodes, substituting each fused Add with its Gemm.
    nodes
        .into_iter()
        .enumerate()
        .filter(|(i, _)| !skip.contains(i))
        .map(|(i, n)| replacements.remove(&i).unwrap_or(n))
        .collect()
}