ronn_graph/passes/
fusion.rs1use super::{OptimizationPass, PassStats};
2use crate::error::Result;
3use ronn_core::ModelGraph;
4use tracing::debug;
5
/// Graph optimization pass that scans for operator sequences eligible for fusion.
///
/// The helper methods on this type look for `Conv` -> `BatchNormalization` -> `Relu`
/// chains and `MatMul` -> `Add` pairs and report how many were found. The type
/// carries no state, so it is cheap to copy and can be built with `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NodeFusionPass;
9
10impl OptimizationPass for NodeFusionPass {
11 fn name(&self) -> &str {
12 "NodeFusion"
13 }
14
15 fn run(&self, graph: &mut ModelGraph) -> Result<PassStats> {
16 let mut stats = PassStats::default();
17
18 stats.nodes_fused += self.fuse_conv_bn_relu(graph)?;
20 stats.nodes_fused += self.fuse_matmul_add(graph)?;
21
22 debug!(
23 "Node fusion pass completed: {} nodes fused",
24 stats.nodes_fused
25 );
26
27 Ok(stats)
28 }
29}
30
31impl NodeFusionPass {
32 fn fuse_conv_bn_relu(&self, graph: &mut ModelGraph) -> Result<usize> {
34 let mut fused_count = 0;
35
36 for node in graph.nodes() {
38 if node.op_type == "Conv" {
39 if let Some(bn_node) =
41 Self::find_successor(graph, &node.id.to_string(), "BatchNormalization")
42 {
43 if let Some(_relu_node) = Self::find_successor(graph, &bn_node, "Relu") {
45 debug!("Found Conv+BN+ReLU pattern at node: {}", node.id);
46 fused_count += 1;
49 }
50 }
51 }
52 }
53
54 Ok(fused_count)
55 }
56
57 fn fuse_matmul_add(&self, graph: &mut ModelGraph) -> Result<usize> {
59 let mut fused_count = 0;
60
61 for node in graph.nodes() {
63 if node.op_type == "MatMul" {
64 if let Some(_add_node) = Self::find_successor(graph, &node.id.to_string(), "Add") {
66 debug!("Found MatMul+Add pattern at node: {}", node.id);
67 fused_count += 1;
69 }
70 }
71 }
72
73 Ok(fused_count)
74 }
75
76 fn find_successor(graph: &ModelGraph, node_id: &str, op_type: &str) -> Option<String> {
78 None
83 }
84}