use crate::quantize::calibration::QuantizationParams;
use std::collections::HashMap;
/// Kind of operation a [`GraphNode`] represents.
///
/// The enum carries no payload; operator parameters are stored separately in
/// [`NodeParams`]. It derives `Eq` and `Hash` so it can key maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// 2-D convolution.
    Conv2d,
    /// Batch normalization.
    BatchNorm,
    /// ReLU activation.
    ReLU,
    /// HardSwish activation.
    HardSwish,
    /// Quantization boundary node (carries a scale and zero point).
    Quantize,
    /// Dequantization boundary node (carries a scale and zero point).
    Dequantize,
    /// Graph input placeholder.
    Input,
    /// Graph output placeholder.
    Output,
}
/// A node of a [`ComputationGraph`]: one operation plus its incoming and
/// outgoing edges.
#[derive(Debug, Clone)]
pub struct GraphNode {
/// Id of this node; `ComputationGraph::add_node` also uses it as the key
/// under which the node is stored.
pub id: usize,
/// Which operation the node performs.
pub node_type: NodeType,
/// Ids of the nodes whose output this node consumes.
pub inputs: Vec<usize>,
/// Ids of the nodes that consume this node's output.
pub outputs: Vec<usize>,
/// Operator-specific parameters.
pub params: NodeParams,
}
/// Operator-specific parameters attached to a [`GraphNode`].
#[derive(Debug, Clone)]
pub enum NodeParams {
/// Parameters of a convolution node.
Conv2d {
/// Flattened filter weights. The fusion passes in this file treat the
/// buffer as consecutive per-output-channel chunks (channel `c` first
/// element at `c * weights_per_channel`).
weights: Vec<f32>,
/// Optional per-output-channel bias; the fusion passes treat `None` as
/// all zeros.
bias: Option<Vec<f32>>,
/// Number of input channels.
in_channels: usize,
/// Number of output channels.
out_channels: usize,
/// Spatial kernel size as a single value (presumably square kernels --
/// TODO confirm).
kernel_size: usize,
},
/// Parameters of a batch-normalization node. `fuse_batchnorm_to_conv`
/// indexes every vector by output channel.
BatchNorm {
/// Per-channel multiplicative factor.
gamma: Vec<f32>,
/// Per-channel additive offset.
beta: Vec<f32>,
/// Per-channel mean subtracted during normalization.
mean: Vec<f32>,
/// Per-channel variance.
var: Vec<f32>,
/// Constant added to the variance before taking the square root.
eps: f32,
},
/// Parameterless activation; the in-file tests use it for ReLU and
/// HardSwish nodes.
Activation,
/// Parameters of a `Quantize` node, filled in by `insert_qdq_nodes`.
Quantize {
/// Quantization scale.
scale: f32,
/// Quantization zero point.
zero_point: i32,
},
/// Parameters of a `Dequantize` node, filled in by `insert_qdq_nodes`.
Dequantize {
/// Quantization scale.
scale: f32,
/// Quantization zero point.
zero_point: i32,
},
/// No parameters; the in-file tests use it for Input and Output nodes.
None,
}
/// Directed graph of operations, stored as an id-keyed node table in which
/// every node keeps its own input and output id lists.
#[derive(Debug, Clone)]
pub struct ComputationGraph {
/// All nodes, keyed by node id.
pub nodes: HashMap<usize, GraphNode>,
/// Id that the next call to `add_node` will hand out.
pub next_id: usize,
}
impl ComputationGraph {
    /// Creates an empty graph; the first node added receives id `0`.
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            next_id: 0,
        }
    }

    /// Adds an unconnected node and returns its id.
    ///
    /// Ids are taken from `next_id`, which this type only ever increments, so
    /// removing a node does not make its id available again.
    pub fn add_node(&mut self, node_type: NodeType, params: NodeParams) -> usize {
        let id = self.next_id;
        self.next_id += 1;
        self.nodes.insert(
            id,
            GraphNode {
                id,
                node_type,
                inputs: Vec::new(),
                outputs: Vec::new(),
                params,
            },
        );
        id
    }

    /// Records a directed edge `from -> to` on both endpoints.
    ///
    /// The call is a no-op unless both ids are present, so the adjacency
    /// lists never contain an edge that is known to only one of its ends.
    pub fn connect(&mut self, from: usize, to: usize) {
        if !(self.nodes.contains_key(&from) && self.nodes.contains_key(&to)) {
            return;
        }
        if let Some(from_node) = self.nodes.get_mut(&from) {
            from_node.outputs.push(to);
        }
        if let Some(to_node) = self.nodes.get_mut(&to) {
            to_node.inputs.push(from);
        }
    }

    /// Removes node `id` and bypasses it: every producer of the node is
    /// connected to every consumer of the node.
    ///
    /// The replacement edges take the position the removed node occupied in
    /// each neighbour's list, so passes that look at `inputs.first()` /
    /// `outputs.first()` keep seeing the same logical edge. Unknown ids are
    /// ignored.
    pub fn remove_node(&mut self, id: usize) {
        // Replaces every occurrence of `removed` with `replacement`, in place.
        // If the neighbour did not list `removed` (one-sided edge), the
        // replacement is appended so both ends stay in agreement.
        fn splice_edges(edges: &mut Vec<usize>, removed: usize, replacement: &[usize]) {
            if !edges.contains(&removed) {
                edges.extend_from_slice(replacement);
                return;
            }
            let mut rewired = Vec::with_capacity(edges.len() + replacement.len());
            for &edge in edges.iter() {
                if edge == removed {
                    rewired.extend_from_slice(replacement);
                } else {
                    rewired.push(edge);
                }
            }
            *edges = rewired;
        }

        // Each neighbour is rewritten once, even if it is linked by several
        // parallel edges (splice_edges already handles every occurrence).
        fn unique_ids(ids: &[usize]) -> Vec<usize> {
            let mut unique = ids.to_vec();
            unique.sort_unstable();
            unique.dedup();
            unique
        }

        let removed = match self.nodes.remove(&id) {
            Some(node) => node,
            None => return,
        };

        // A self-loop on the removed node must not be forwarded to its
        // neighbours, otherwise they would reference a node that is gone.
        let producers: Vec<usize> = removed
            .inputs
            .iter()
            .copied()
            .filter(|&other| other != id)
            .collect();
        let consumers: Vec<usize> = removed
            .outputs
            .iter()
            .copied()
            .filter(|&other| other != id)
            .collect();

        for producer_id in unique_ids(&producers) {
            if let Some(producer) = self.nodes.get_mut(&producer_id) {
                splice_edges(&mut producer.outputs, id, &consumers);
            }
        }
        for consumer_id in unique_ids(&consumers) {
            if let Some(consumer) = self.nodes.get_mut(&consumer_id) {
                splice_edges(&mut consumer.inputs, id, &producers);
            }
        }
    }

    /// Returns the node with the given id, if present.
    pub fn get_node(&self, id: usize) -> Option<&GraphNode> {
        self.nodes.get(&id)
    }

    /// Returns a mutable reference to the node with the given id, if present.
    pub fn get_node_mut(&mut self, id: usize) -> Option<&mut GraphNode> {
        self.nodes.get_mut(&id)
    }
}
/// The default graph is the empty graph produced by `ComputationGraph::new`.
impl Default for ComputationGraph {
// Delegates to `new` so both constructors always agree on the initial state.
fn default() -> Self {
Self::new()
}
}
pub fn fuse_batchnorm_to_conv(graph: &mut ComputationGraph) -> usize {
let mut fused_count = 0;
let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
for conv_id in node_ids {
let conv_node = match graph.get_node(conv_id) {
Some(node) if node.node_type == NodeType::Conv2d => node,
_ => continue,
};
let bn_id = match conv_node.outputs.first() {
Some(&id) => id,
None => continue,
};
let bn_node = match graph.get_node(bn_id) {
Some(node) if node.node_type == NodeType::BatchNorm => node,
_ => continue,
};
let (weights, bias, out_channels) = match &conv_node.params {
NodeParams::Conv2d {
weights,
bias,
out_channels,
..
} => (weights.clone(), bias.clone(), *out_channels),
_ => continue,
};
let (gamma, beta, mean, var, eps) = match &bn_node.params {
NodeParams::BatchNorm {
gamma,
beta,
mean,
var,
eps,
} => (gamma, beta, mean, var, *eps),
_ => continue,
};
let mut fused_weights = weights;
let mut fused_bias = bias.unwrap_or_else(|| vec![0.0; out_channels]);
for c in 0..out_channels {
let scale = gamma[c] / (var[c] + eps).sqrt();
let weights_per_channel = fused_weights.len() / out_channels;
for i in 0..weights_per_channel {
fused_weights[c * weights_per_channel + i] *= scale;
}
fused_bias[c] = (fused_bias[c] - mean[c]) * scale + beta[c];
}
if let Some(conv_node) = graph.get_node_mut(conv_id) {
if let NodeParams::Conv2d { weights, bias, .. } = &mut conv_node.params {
*weights = fused_weights;
*bias = Some(fused_bias);
}
}
graph.remove_node(bn_id);
fused_count += 1;
}
fused_count
}
/// Folds the input zero point of every `Conv2d` into its bias.
///
/// For each convolution whose first input has an entry in `quant_params`,
/// every output channel `c` is updated as
/// `bias[c] -= zero_point * sum(weights of channel c)`, where `zero_point`
/// belongs to that first input. A missing bias is treated as zeros and is
/// materialised by this pass.
///
/// The per-channel weight count is derived from the weight buffer itself
/// (`weights.len() / out_channels`), the same way `fuse_batchnorm_to_conv`
/// slices it, so the pass never reads outside the buffer even if the declared
/// `in_channels` / `kernel_size` do not describe it exactly.
///
/// Convolutions are skipped when they have no input, when the input has no
/// quantization parameters, or when shapes are inconsistent
/// (`out_channels == 0`, a weight buffer that does not divide evenly, or a
/// bias whose length is not `out_channels`).
///
/// NOTE(review): the pass is not idempotent -- running it twice subtracts the
/// correction twice.
///
/// Returns the number of convolutions whose bias was rewritten.
pub fn fuse_zp_to_bias(
    graph: &mut ComputationGraph,
    quant_params: &HashMap<usize, QuantizationParams>,
) -> usize {
    let mut fused_count = 0;
    let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
    for conv_id in node_ids {
        let conv_node = match graph.get_node(conv_id) {
            Some(node) if node.node_type == NodeType::Conv2d => node,
            _ => continue,
        };
        let input_id = match conv_node.inputs.first() {
            Some(&id) => id,
            None => continue,
        };
        let zp_input = match quant_params.get(&input_id) {
            Some(qp) => qp.zero_point as f32,
            None => continue,
        };
        let (weights, bias, out_channels) = match &conv_node.params {
            NodeParams::Conv2d {
                weights,
                bias,
                out_channels,
                ..
            } => (weights, bias, *out_channels),
            _ => continue,
        };

        // Refuse inconsistent shapes instead of panicking on an index.
        if out_channels == 0 || weights.len() % out_channels != 0 {
            continue;
        }
        let mut fused_bias = match bias {
            Some(existing) if existing.len() == out_channels => existing.clone(),
            Some(_) => continue,
            None => vec![0.0; out_channels],
        };

        let weights_per_channel = weights.len() / out_channels;
        for (c, bias_c) in fused_bias.iter_mut().enumerate() {
            let start = c * weights_per_channel;
            let weight_sum: f32 = weights[start..start + weights_per_channel].iter().sum();
            *bias_c -= zp_input * weight_sum;
        }

        if let Some(conv_node) = graph.get_node_mut(conv_id) {
            if let NodeParams::Conv2d { bias, .. } = &mut conv_node.params {
                *bias = Some(fused_bias);
            }
        }
        fused_count += 1;
    }
    fused_count
}
/// Inserts `Quantize` / `Dequantize` nodes around quantized operations.
///
/// A node is considered only if `is_quantized_op` accepts its type, it is not
/// itself a `Quantize`/`Dequantize` node, and `quant_params` has an entry for
/// it. For such a node:
///
/// * every edge coming from a producer that is not a quantized op gets a
///   `Quantize` node spliced in;
/// * every edge going to a consumer that is not a quantized op gets a
///   `Dequantize` node spliced in.
///
/// Both kinds of inserted node copy `scale` and `zero_point` from the
/// quantized node's own entry in `quant_params`. Only nodes that existed when
/// the pass started are visited.
///
/// Returns the number of nodes inserted.
pub fn insert_qdq_nodes(
    graph: &mut ComputationGraph,
    quant_params: &HashMap<usize, QuantizationParams>,
) -> usize {
    // Turns the edge `from -> to` into `from -> via -> to`. All three nodes
    // are known to exist at every call site below.
    fn reroute(graph: &mut ComputationGraph, from: usize, via: usize, to: usize) {
        let upstream = graph.nodes.get_mut(&from).unwrap();
        upstream.outputs.retain(|&x| x != to);
        upstream.outputs.push(via);

        let downstream = graph.nodes.get_mut(&to).unwrap();
        downstream.inputs.retain(|&x| x != from);
        downstream.inputs.push(via);

        let bridge = graph.nodes.get_mut(&via).unwrap();
        bridge.inputs.push(from);
        bridge.outputs.push(to);
    }

    let mut inserted_count = 0;
    let snapshot: Vec<usize> = graph.nodes.keys().copied().collect();

    for node_id in snapshot {
        let (node_type, producers, consumers) = match graph.get_node(node_id) {
            Some(n) => (n.node_type.clone(), n.inputs.clone(), n.outputs.clone()),
            None => continue,
        };
        if node_type == NodeType::Quantize || node_type == NodeType::Dequantize {
            continue;
        }

        // Neither boundary can be inserted unless this node is a quantized
        // op with known parameters, so decide that once per node.
        let (scale, zero_point) = match quant_params.get(&node_id) {
            Some(found) if is_quantized_op(&node_type) => (found.scale, found.zero_point),
            _ => continue,
        };

        for &producer_id in &producers {
            let producer_type = match graph.get_node(producer_id) {
                Some(n) => n.node_type.clone(),
                None => continue,
            };
            if is_quantized_op(&producer_type) {
                continue;
            }
            let q_id = graph.add_node(
                NodeType::Quantize,
                NodeParams::Quantize { scale, zero_point },
            );
            reroute(graph, producer_id, q_id, node_id);
            inserted_count += 1;
        }

        for &consumer_id in &consumers {
            let consumer_type = match graph.get_node(consumer_id) {
                Some(n) => n.node_type.clone(),
                None => continue,
            };
            if is_quantized_op(&consumer_type) {
                continue;
            }
            let dq_id = graph.add_node(
                NodeType::Dequantize,
                NodeParams::Dequantize { scale, zero_point },
            );
            reroute(graph, node_id, dq_id, consumer_id);
            inserted_count += 1;
        }
    }
    inserted_count
}
/// Reports whether a node of this type counts as a quantized op for
/// `insert_qdq_nodes`: convolutions and the quantize/dequantize boundary
/// nodes do, every other node type does not.
fn is_quantized_op(node_type: &NodeType) -> bool {
    // Spelled out exhaustively so that adding a `NodeType` variant forces a
    // decision here.
    match node_type {
        NodeType::Conv2d | NodeType::Quantize | NodeType::Dequantize => true,
        NodeType::BatchNorm
        | NodeType::ReLU
        | NodeType::HardSwish
        | NodeType::Input
        | NodeType::Output => false,
    }
}
/// Fuses each `ReLU` that directly follows a `Conv2d` by removing the ReLU
/// node and reconnecting the convolution to the ReLU's consumers.
///
/// The ReLU must be the convolution's only consumer: if the convolution also
/// feeds other nodes, absorbing the activation would change the values those
/// nodes receive, so such pairs are left untouched.
///
/// NOTE(review): the pass only deletes the node; nothing on the `Conv2d` node
/// records that a ReLU was fused. Presumably the caller tracks that
/// separately -- confirm.
///
/// Returns the number of ReLU nodes removed.
pub fn fuse_relu(graph: &mut ComputationGraph) -> usize {
    let mut fused_count = 0;
    // Snapshot the ids: the graph is mutated while we iterate.
    let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
    for conv_id in node_ids {
        let conv_node = match graph.get_node(conv_id) {
            Some(node) if node.node_type == NodeType::Conv2d => node,
            _ => continue,
        };
        // Only a sole consumer can be absorbed into the convolution.
        let relu_id = match conv_node.outputs.as_slice() {
            [only] => *only,
            _ => continue,
        };
        match graph.get_node(relu_id) {
            Some(node) if node.node_type == NodeType::ReLU => {}
            _ => continue,
        }
        graph.remove_node(relu_id);
        fused_count += 1;
    }
    fused_count
}
/// Fuses each `HardSwish` that directly follows a `Conv2d` by removing the
/// HardSwish node and reconnecting the convolution to its consumers.
///
/// The HardSwish must be the convolution's only consumer: if the convolution
/// also feeds other nodes, absorbing the activation would change the values
/// those nodes receive, so such pairs are left untouched.
///
/// NOTE(review): the pass only deletes the node; nothing on the `Conv2d` node
/// records that a HardSwish was fused. Presumably the caller tracks that
/// separately -- confirm.
///
/// Returns the number of HardSwish nodes removed.
pub fn fuse_hardswish(graph: &mut ComputationGraph) -> usize {
    let mut fused_count = 0;
    // Snapshot the ids: the graph is mutated while we iterate.
    let node_ids: Vec<usize> = graph.nodes.keys().copied().collect();
    for conv_id in node_ids {
        let conv_node = match graph.get_node(conv_id) {
            Some(node) if node.node_type == NodeType::Conv2d => node,
            _ => continue,
        };
        // Only a sole consumer can be absorbed into the convolution.
        let hs_id = match conv_node.outputs.as_slice() {
            [only] => *only,
            _ => continue,
        };
        match graph.get_node(hs_id) {
            Some(node) if node.node_type == NodeType::HardSwish => {}
            _ => continue,
        }
        graph.remove_node(hs_id);
        fused_count += 1;
    }
    fused_count
}
/// Builds a 256-entry lookup table that applies HardSwish to quantized
/// 8-bit values.
///
/// `HardSwish(x) = x * min(max(x + 3, 0), 6) / 6`. Input and output use the
/// same `scale` and `zero_point`: an input `q` is dequantized as
/// `(q - zero_point) * scale`, and the result is requantized with
/// `round(y / scale) + zero_point`, then clamped to `-128..=127`.
///
/// # Indexing
///
/// Entry `i` belongs to the quantized input whose two's-complement byte is
/// `i`, i.e. look values up with `lut[q as u8 as usize]`: indices `0..=127`
/// hold `q = 0..=127` and indices `128..=255` hold `q = -128..=-1`.
pub fn generate_hardswish_lut(scale: f32, zero_point: i32) -> [i8; 256] {
    let mut table = [0i8; 256];
    for (index, slot) in table.iter_mut().enumerate() {
        // Deliberately wrapping: reinterpret the index byte as a signed value.
        let quantized_in = index as i8;
        let real_in = (i32::from(quantized_in) - zero_point) as f32 * scale;
        let gate = (real_in + 3.0).max(0.0).min(6.0);
        let real_out = real_in * gate / 6.0;
        let requantized = (real_out / scale).round() as i32 + zero_point;
        *slot = requantized.clamp(-128, 127) as i8;
    }
    table
}
#[cfg(test)]
mod tests {
    use super::*;

    /// BatchNorm folding scales the weights per output channel and rewrites
    /// the bias.
    #[test]
    fn test_fuse_batchnorm_to_conv() {
        let mut graph = ComputationGraph::new();
        // 2 output channels x 2 input channels x 1x1 kernel = 4 weights.
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0, 2.0, 3.0, 4.0],
                bias: Some(vec![0.5, 1.0]),
                in_channels: 2,
                out_channels: 2,
                kernel_size: 1,
            },
        );
        let bn_id = graph.add_node(
            NodeType::BatchNorm,
            NodeParams::BatchNorm {
                gamma: vec![2.0, 3.0],
                beta: vec![0.1, 0.2],
                mean: vec![0.5, 1.0],
                var: vec![1.0, 4.0],
                eps: 1e-5,
            },
        );
        graph.connect(conv_id, bn_id);
        let fused = fuse_batchnorm_to_conv(&mut graph);
        assert_eq!(fused, 1);
        assert!(graph.get_node(bn_id).is_none());
        let conv_node = graph.get_node(conv_id).unwrap();
        if let NodeParams::Conv2d { weights, bias, .. } = &conv_node.params {
            // scale = gamma / sqrt(var + eps) ~= [2.0, 1.5]
            assert!((weights[0] - 2.0).abs() < 0.01);
            assert!((weights[1] - 4.0).abs() < 0.01);
            assert!((weights[2] - 4.5).abs() < 0.01);
            assert!((weights[3] - 6.0).abs() < 0.01);
            let bias = bias.as_ref().unwrap();
            assert!((bias[0] - 0.1).abs() < 0.01);
            assert!((bias[1] - 0.2).abs() < 0.01);
        } else {
            panic!("Expected Conv2d params");
        }
    }

    /// The input zero point is folded into the bias:
    /// `bias[c] -= zero_point * sum(weights[c])`.
    #[test]
    fn test_fuse_zp_to_bias() {
        let mut graph = ComputationGraph::new();
        let input_id = graph.add_node(NodeType::Input, NodeParams::None);
        // 2 output channels x 2 input channels x 1x1 kernel = 4 weights, so
        // the declared shape matches the buffer.
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0, 2.0, 3.0, 4.0],
                bias: Some(vec![1.0, 2.0]),
                in_channels: 2,
                out_channels: 2,
                kernel_size: 1,
            },
        );
        graph.connect(input_id, conv_id);
        let mut quant_params = HashMap::new();
        quant_params.insert(
            input_id,
            QuantizationParams {
                scale: 0.1,
                zero_point: 10,
                min_val: -12.8,
                max_val: 12.7,
                num_bins: 256,
            },
        );
        let fused = fuse_zp_to_bias(&mut graph, &quant_params);
        assert_eq!(fused, 1);
        let conv_node = graph.get_node(conv_id).unwrap();
        if let NodeParams::Conv2d { bias, .. } = &conv_node.params {
            let bias = bias.as_ref().unwrap();
            // 1 - 10 * (1 + 2) and 2 - 10 * (3 + 4)
            assert!((bias[0] - (-29.0)).abs() < 0.01);
            assert!((bias[1] - (-68.0)).abs() < 0.01);
        } else {
            panic!("Expected Conv2d params");
        }
    }

    /// A quantized conv between float nodes gets a Quantize in front and a
    /// Dequantize behind.
    #[test]
    fn test_insert_qdq_nodes() {
        let mut graph = ComputationGraph::new();
        let input_id = graph.add_node(NodeType::Input, NodeParams::None);
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );
        let output_id = graph.add_node(NodeType::Output, NodeParams::None);
        graph.connect(input_id, conv_id);
        graph.connect(conv_id, output_id);
        let mut quant_params = HashMap::new();
        quant_params.insert(
            conv_id,
            QuantizationParams {
                scale: 0.1,
                zero_point: 0,
                min_val: -12.8,
                max_val: 12.7,
                num_bins: 256,
            },
        );
        let inserted = insert_qdq_nodes(&mut graph, &quant_params);
        assert_eq!(inserted, 2);
        let conv_node = graph.get_node(conv_id).unwrap();
        assert_eq!(conv_node.inputs.len(), 1);
        let q_id = conv_node.inputs[0];
        let q_node = graph.get_node(q_id).unwrap();
        assert_eq!(q_node.node_type, NodeType::Quantize);
        assert_eq!(conv_node.outputs.len(), 1);
        let dq_id = conv_node.outputs[0];
        let dq_node = graph.get_node(dq_id).unwrap();
        assert_eq!(dq_node.node_type, NodeType::Dequantize);
    }

    /// A ReLU directly after a conv is removed from the graph.
    #[test]
    fn test_fuse_relu() {
        let mut graph = ComputationGraph::new();
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );
        let relu_id = graph.add_node(NodeType::ReLU, NodeParams::Activation);
        graph.connect(conv_id, relu_id);
        let fused = fuse_relu(&mut graph);
        assert_eq!(fused, 1);
        assert!(graph.get_node(relu_id).is_none());
        let conv_node = graph.get_node(conv_id).unwrap();
        // The ReLU had no consumers, so the conv is left without outputs.
        assert!(conv_node.outputs.is_empty());
    }

    /// A HardSwish directly after a conv is removed from the graph.
    #[test]
    fn test_fuse_hardswish() {
        let mut graph = ComputationGraph::new();
        let conv_id = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );
        let hs_id = graph.add_node(NodeType::HardSwish, NodeParams::Activation);
        graph.connect(conv_id, hs_id);
        let fused = fuse_hardswish(&mut graph);
        assert_eq!(fused, 1);
        assert!(graph.get_node(hs_id).is_none());
    }

    /// Spot-checks HardSwish at x = 0, x = -3 and x = 3.
    #[test]
    fn test_hardswish_lut_generation() {
        let scale = 0.1;
        let zero_point = 0;
        let lut = generate_hardswish_lut(scale, zero_point);
        // The table is indexed by the raw (two's-complement) byte of the
        // quantized input, i.e. `q as u8`.
        let index_of = |q: i32| (q as i8) as u8 as usize;

        // x = 0 -> HardSwish(0) = 0
        assert_eq!(lut[index_of(zero_point)], 0);
        // x = -3 -> HardSwish(-3) = 0
        assert_eq!(lut[index_of(zero_point - 30)], 0);
        // x = 3 -> HardSwish(3) = 3
        let q_out = lut[index_of(zero_point + 30)];
        let x_pos3 = (q_out as i32 - zero_point) as f32 * scale;
        assert!((x_pos3 - 3.0).abs() < 0.5);
    }

    /// `connect` records each edge on both endpoints.
    #[test]
    fn test_graph_construction() {
        let mut graph = ComputationGraph::new();
        let id1 = graph.add_node(NodeType::Input, NodeParams::None);
        let id2 = graph.add_node(
            NodeType::Conv2d,
            NodeParams::Conv2d {
                weights: vec![1.0; 4],
                bias: None,
                in_channels: 1,
                out_channels: 1,
                kernel_size: 2,
            },
        );
        let id3 = graph.add_node(NodeType::Output, NodeParams::None);
        graph.connect(id1, id2);
        graph.connect(id2, id3);
        assert_eq!(graph.nodes.len(), 3);
        assert_eq!(graph.get_node(id2).unwrap().inputs, vec![id1]);
        assert_eq!(graph.get_node(id2).unwrap().outputs, vec![id3]);
    }

    /// Removing a node reconnects its producer to its consumer.
    #[test]
    fn test_remove_node() {
        let mut graph = ComputationGraph::new();
        let id1 = graph.add_node(NodeType::Input, NodeParams::None);
        let id2 = graph.add_node(NodeType::ReLU, NodeParams::Activation);
        let id3 = graph.add_node(NodeType::Output, NodeParams::None);
        graph.connect(id1, id2);
        graph.connect(id2, id3);
        graph.remove_node(id2);
        assert!(graph.get_node(id2).is_none());
        assert_eq!(graph.get_node(id1).unwrap().outputs, vec![id3]);
        assert_eq!(graph.get_node(id3).unwrap().inputs, vec![id1]);
    }
}