use crate::graph::{Attributes, Node, OpKind};
use crate::tensor::Tensor;
use std::collections::{HashMap, HashSet};
/// Fuses each Conv → BatchNorm pair into a single Conv node.
///
/// For a BatchNorm with inputs `[x, gamma, beta, mean, var]` whose input `x`
/// is produced by a Conv that feeds nothing else, the BN statistics are
/// folded into the convolution:
///
/// ```text
/// factor = gamma / sqrt(var + epsilon)   (per output channel)
/// W'     = W * factor
/// b'     = (b - mean) * factor + beta
/// ```
///
/// The fused Conv takes over the BatchNorm's outputs, so downstream
/// consumers are unaffected. New fused weight/bias tensors are inserted into
/// `weights`; the original tensors are left in place (they may be shared).
/// Pairs that cannot be fused safely — missing tensors, channel-count
/// mismatches, or a Conv output with more than one consumer — are passed
/// through unchanged.
pub fn fuse_conv_batchnorm(nodes: Vec<Node>, weights: &mut HashMap<String, Tensor>) -> Vec<Node> {
    if nodes.len() < 2 {
        return nodes;
    }
    // Tensor name -> index of the node that produces it.
    let mut producer: HashMap<String, usize> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        for out in &node.outputs {
            producer.insert(out.clone(), i);
        }
    }
    // Tensor name -> number of consuming nodes. Fusion is only valid when the
    // Conv output feeds exactly one node (the BatchNorm being folded).
    let mut consumer_count: HashMap<String, usize> = HashMap::new();
    for node in &nodes {
        for inp in &node.inputs {
            if !inp.is_empty() {
                *consumer_count.entry(inp.clone()).or_insert(0) += 1;
            }
        }
    }
    let mut skip: HashSet<usize> = HashSet::new();
    let mut replacements: HashMap<usize, Node> = HashMap::new();
    for (i, node) in nodes.iter().enumerate() {
        if skip.contains(&i) {
            continue;
        }
        if !matches!(node.op, OpKind::BatchNorm) {
            continue;
        }
        // Expected BatchNorm inputs: [x, scale, bias, mean, variance].
        if node.inputs.len() < 5 {
            continue;
        }
        let conv_tensor = &node.inputs[0];
        let bn_scale_name = &node.inputs[1];
        let bn_bias_name = &node.inputs[2];
        let bn_mean_name = &node.inputs[3];
        let bn_var_name = &node.inputs[4];
        if consumer_count.get(conv_tensor).copied().unwrap_or(0) != 1 {
            continue;
        }
        let conv_idx = match producer.get(conv_tensor) {
            Some(&idx) => idx,
            None => continue,
        };
        if !matches!(nodes[conv_idx].op, OpKind::Conv) {
            continue;
        }
        let bn_scale = match weights.get(bn_scale_name) {
            Some(t) => t.clone(),
            None => continue,
        };
        let bn_bias = match weights.get(bn_bias_name) {
            Some(t) => t.clone(),
            None => continue,
        };
        let bn_mean = match weights.get(bn_mean_name) {
            Some(t) => t.clone(),
            None => continue,
        };
        let bn_var = match weights.get(bn_var_name) {
            Some(t) => t.clone(),
            None => continue,
        };
        let epsilon = node.attrs.floats.get("epsilon").copied().unwrap_or(1e-5);
        let conv_node = &nodes[conv_idx];
        if conv_node.inputs.len() < 2 {
            continue;
        }
        let conv_weight_name = &conv_node.inputs[1];
        let conv_bias_name = conv_node.inputs.get(2).cloned();
        let conv_weight = match weights.get(conv_weight_name) {
            Some(t) => t.clone(),
            None => continue,
        };
        let c_out = bn_scale.data.len();
        // Reject mismatched channel counts up front: the per-channel loop
        // below indexes all four BN tensors (and the bias) by `c`, so an
        // inconsistent graph would otherwise panic with an out-of-bounds
        // index. This mirrors the validation in `fold_batch_norm_inference`.
        if c_out == 0
            || bn_bias.data.len() != c_out
            || bn_mean.data.len() != c_out
            || bn_var.data.len() != c_out
            || conv_weight.data.len() % c_out != 0
        {
            continue;
        }
        let weight_per_channel: usize = conv_weight.data.len() / c_out;
        // A missing bias tensor is treated as zeros; a bias present with the
        // wrong channel count means the graph is inconsistent, so skip fusion
        // rather than fabricating values.
        let conv_bias_data = match conv_bias_name.as_deref().and_then(|n| weights.get(n)) {
            Some(b) if b.data.len() == c_out => b.data.clone(),
            Some(_) => continue,
            None => vec![0.0f32; c_out],
        };
        let mut fused_weight = conv_weight.data.clone();
        let mut fused_bias = vec![0.0f32; c_out];
        for c in 0..c_out {
            let inv_std = 1.0 / (bn_var.data[c] + epsilon).sqrt();
            let factor = bn_scale.data[c] * inv_std;
            let start = c * weight_per_channel;
            for w in &mut fused_weight[start..start + weight_per_channel] {
                *w *= factor;
            }
            fused_bias[c] = (conv_bias_data[c] - bn_mean.data[c]) * factor + bn_bias.data[c];
        }
        let fused_weight_name = format!("{}_fused_weight", conv_node.name);
        let fused_bias_name = format!("{}_fused_bias", conv_node.name);
        weights.insert(
            fused_weight_name.clone(),
            Tensor::new(fused_weight, conv_weight.shape.clone()),
        );
        weights.insert(
            fused_bias_name.clone(),
            Tensor::new(fused_bias, vec![c_out]),
        );
        // The fused Conv keeps the Conv's data input and attrs but adopts the
        // BatchNorm's outputs so downstream consumers stay wired correctly.
        let fused_conv = Node {
            op: OpKind::Conv,
            name: format!("{}_fused_convbn", conv_node.name),
            inputs: vec![
                conv_node.inputs[0].clone(),
                fused_weight_name,
                fused_bias_name,
            ],
            outputs: node.outputs.clone(),
            attrs: conv_node.attrs.clone(),
        };
        replacements.insert(conv_idx, fused_conv);
        // Drop the BatchNorm node; its work now lives inside the fused Conv.
        skip.insert(i);
    }
    nodes
        .into_iter()
        .enumerate()
        .filter(|(i, _)| !skip.contains(i))
        .map(|(i, n)| replacements.remove(&i).unwrap_or(n))
        .collect()
}
/// Replaces standalone inference-mode BatchNorm nodes with an equivalent
/// per-channel Mul followed by an Add.
///
/// ```text
/// y = gamma * (x - mean) / sqrt(var + eps) + beta
///   = x * factor + shift
/// factor = gamma / sqrt(var + eps)
/// shift  = beta - mean * factor
/// ```
///
/// BatchNorm nodes fed directly by a Conv are left untouched so that the
/// Conv→BN fusion pass can fold them into the convolution instead. The
/// precomputed factor/shift vectors are inserted into `weights`; nodes with
/// missing tensors or mismatched channel counts are passed through unchanged.
pub fn fold_batch_norm_inference(
    nodes: Vec<Node>,
    weights: &mut HashMap<String, Tensor>,
) -> Vec<Node> {
    if nodes.is_empty() {
        return nodes;
    }
    // Tensor name -> index of the node producing it.
    let producer: HashMap<String, usize> = nodes
        .iter()
        .enumerate()
        .flat_map(|(idx, n)| n.outputs.iter().map(move |o| (o.clone(), idx)))
        .collect();
    // Index of each folded BatchNorm -> the [Mul, Add] pair that replaces it.
    let mut folded: HashMap<usize, Vec<Node>> = HashMap::new();
    for (idx, bn) in nodes.iter().enumerate() {
        if !matches!(bn.op, OpKind::BatchNorm) || bn.inputs.len() < 5 {
            continue;
        }
        let x_name = &bn.inputs[0];
        // Leave Conv→BN pairs for the dedicated fusion pass.
        if producer
            .get(x_name)
            .map_or(false, |&p| matches!(nodes[p].op, OpKind::Conv))
        {
            continue;
        }
        // All four statistic tensors must be present, or the node is skipped.
        let fetch = |name: &String| weights.get(name).cloned();
        let (scale, bias, mean, var) = match (
            fetch(&bn.inputs[1]),
            fetch(&bn.inputs[2]),
            fetch(&bn.inputs[3]),
            fetch(&bn.inputs[4]),
        ) {
            (Some(s), Some(b), Some(m), Some(v)) => (s, b, m, v),
            _ => continue,
        };
        let epsilon = bn.attrs.floats.get("epsilon").copied().unwrap_or(1e-5);
        let channels = scale.data.len();
        // Every statistic tensor must have one entry per channel.
        if channels == 0
            || bias.data.len() != channels
            || mean.data.len() != channels
            || var.data.len() != channels
        {
            continue;
        }
        let factors: Vec<f32> = scale
            .data
            .iter()
            .zip(&var.data)
            .map(|(&g, &v)| {
                let inv_std = 1.0 / (v + epsilon).sqrt();
                g * inv_std
            })
            .collect();
        let shifts: Vec<f32> = bias
            .data
            .iter()
            .zip(mean.data.iter().zip(&factors))
            .map(|(&b, (&m, &f))| b - m * f)
            .collect();
        let factor_key = format!("{}_bn_factor", bn.name);
        let shift_key = format!("{}_bn_shift", bn.name);
        let mul_output = format!("{}_bn_mul_out", bn.name);
        weights.insert(factor_key.clone(), Tensor::new(factors, vec![channels]));
        weights.insert(shift_key.clone(), Tensor::new(shifts, vec![channels]));
        let scale_node = Node {
            op: OpKind::Mul,
            name: format!("{}_bn_mul", bn.name),
            inputs: vec![x_name.clone(), factor_key],
            outputs: vec![mul_output.clone()],
            attrs: Attributes::default(),
        };
        // The Add consumes the Mul's result and adopts the BatchNorm's
        // outputs so downstream consumers stay wired correctly.
        let shift_node = Node {
            op: OpKind::Add,
            name: format!("{}_bn_add", bn.name),
            inputs: vec![mul_output, shift_key],
            outputs: bn.outputs.clone(),
            attrs: Attributes::default(),
        };
        folded.insert(idx, vec![scale_node, shift_node]);
    }
    // Splice each replacement pair in where its BatchNorm used to sit.
    let mut result = Vec::with_capacity(nodes.len() + folded.len());
    for (idx, node) in nodes.into_iter().enumerate() {
        match folded.remove(&idx) {
            Some(pair) => result.extend(pair),
            None => result.push(node),
        }
    }
    result
}