//! Node fusion optimization pass (`ronn_graph/passes/fusion.rs`).

1use super::{OptimizationPass, PassStats};
2use crate::error::Result;
3use ronn_core::ModelGraph;
4use tracing::debug;
5
/// Node fusion pass - combines compatible operations into single fused ops.
///
/// Target patterns: Conv+BatchNorm+ReLU and MatMul+Add (bias).
///
/// The pass is a stateless unit struct; all behavior lives in its
/// [`OptimizationPass`] implementation. Pattern detection is scaffolded, but
/// the graph rewrite that performs the actual fusion is not implemented yet.
#[derive(Debug, Clone, Copy, Default)]
pub struct NodeFusionPass;

10impl OptimizationPass for NodeFusionPass {
11    fn name(&self) -> &str {
12        "NodeFusion"
13    }
14
15    fn run(&self, graph: &mut ModelGraph) -> Result<PassStats> {
16        let mut stats = PassStats::default();
17
18        // Look for fusion patterns
19        stats.nodes_fused += self.fuse_conv_bn_relu(graph)?;
20        stats.nodes_fused += self.fuse_matmul_add(graph)?;
21
22        debug!(
23            "Node fusion pass completed: {} nodes fused",
24            stats.nodes_fused
25        );
26
27        Ok(stats)
28    }
29}
30
31impl NodeFusionPass {
32    /// Fuse Conv + BatchNorm + ReLU into a single operation
33    fn fuse_conv_bn_relu(&self, graph: &mut ModelGraph) -> Result<usize> {
34        let mut fused_count = 0;
35
36        // Pattern: Conv -> BatchNorm -> ReLU
37        for node in graph.nodes() {
38            if node.op_type == "Conv" {
39                // Check if followed by BatchNorm
40                if let Some(bn_node) =
41                    Self::find_successor(graph, &node.id.to_string(), "BatchNormalization")
42                {
43                    // Check if BatchNorm is followed by ReLU
44                    if let Some(_relu_node) = Self::find_successor(graph, &bn_node, "Relu") {
45                        debug!("Found Conv+BN+ReLU pattern at node: {}", node.id);
46                        // Fuse these three nodes into one
47                        // This would create a fused op with combined parameters
48                        fused_count += 1;
49                    }
50                }
51            }
52        }
53
54        Ok(fused_count)
55    }
56
57    /// Fuse MatMul + Add (bias) into a single operation
58    fn fuse_matmul_add(&self, graph: &mut ModelGraph) -> Result<usize> {
59        let mut fused_count = 0;
60
61        // Pattern: MatMul -> Add (where Add is adding a bias vector)
62        for node in graph.nodes() {
63            if node.op_type == "MatMul" {
64                // Check if followed by Add
65                if let Some(_add_node) = Self::find_successor(graph, &node.id.to_string(), "Add") {
66                    debug!("Found MatMul+Add pattern at node: {}", node.id);
67                    // Fuse into MatMul with bias
68                    fused_count += 1;
69                }
70            }
71        }
72
73        Ok(fused_count)
74    }
75
76    /// Find a successor node with the given op type
77    fn find_successor(graph: &ModelGraph, node_id: &str, op_type: &str) -> Option<String> {
78        // Get the node's outputs
79        // Find nodes that consume those outputs
80        // Check if any match the op_type
81        // Simplified implementation - would require proper graph traversal
82        None
83    }
84}