// ghostflow_core/fusion.rs

//! Kernel fusion engine for optimizing computation graphs
//!
//! This module provides automatic fusion of operations to reduce memory bandwidth
//! and improve performance.
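//!
//! # Example
//!
//! A minimal, illustrative sketch (assuming this module is exposed as
//! `ghostflow_core::fusion`):
//!
//! ```
//! use ghostflow_core::fusion::{ComputeGraph, FusionEngine};
//!
//! let mut graph = ComputeGraph::new();
//! let input = graph.add_node("Input".to_string(), vec![], false);
//! let linear = graph.add_node("Linear".to_string(), vec![input], true);
//! let _relu = graph.add_node("ReLU".to_string(), vec![linear], true);
//!
//! let mut engine = FusionEngine::new();
//! for opp in engine.analyze(&graph) {
//!     println!("{}: nodes {:?}, ~{:.1}x", opp.pattern_name, opp.nodes, opp.estimated_speedup);
//! }
//! ```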

use std::collections::HashMap;

/// Fusion pattern for combining operations
#[derive(Debug, Clone, PartialEq)]
pub enum FusionPattern {
    /// Element-wise operations (add, mul, relu, etc.)
    ElementWise(Vec<String>),
    /// Reduction operations (sum, mean, max, etc.)
    Reduction(String),
    /// Matrix operations (matmul, conv, etc.)
    MatrixOp(String),
    /// Custom fusion pattern
    Custom(String, Vec<String>),
}

/// Computational graph node
#[derive(Debug, Clone)]
pub struct GraphNode {
    pub id: usize,
    pub op_type: String,
    pub inputs: Vec<usize>,
    pub outputs: Vec<usize>,
    pub fusible: bool,
}

/// Computational graph for fusion analysis
#[derive(Debug, Clone)]
pub struct ComputeGraph {
    nodes: Vec<GraphNode>,
    next_id: usize,
}

impl ComputeGraph {
    /// Create a new empty compute graph
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            next_id: 0,
        }
    }

    /// Add a node to the graph
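    ///
    /// Returns the ID assigned to the new node. Illustrative sketch (module
    /// path `ghostflow_core::fusion` assumed):
    ///
    /// ```
    /// use ghostflow_core::fusion::ComputeGraph;
    ///
    /// let mut graph = ComputeGraph::new();
    /// let a = graph.add_node("Input".to_string(), vec![], false);
    /// let b = graph.add_node("ReLU".to_string(), vec![a], true);
    ///
    /// // Adding `b` registers it as an output of `a`.
    /// assert_eq!(graph.get_node(a).unwrap().outputs, vec![b]);
    /// ```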
    pub fn add_node(&mut self, op_type: String, inputs: Vec<usize>, fusible: bool) -> usize {
        let id = self.next_id;
        self.next_id += 1;

        // Update outputs of input nodes first
        for &input_id in &inputs {
            if let Some(node) = self.nodes.iter_mut().find(|n| n.id == input_id) {
                node.outputs.push(id);
            }
        }

        // Then add the new node
        self.nodes.push(GraphNode {
            id,
            op_type,
            inputs,
            outputs: Vec::new(),
            fusible,
        });

        id
    }

    /// Get all nodes
    pub fn nodes(&self) -> &[GraphNode] {
        &self.nodes
    }

    /// Get a node by ID
    pub fn get_node(&self, id: usize) -> Option<&GraphNode> {
        self.nodes.iter().find(|n| n.id == id)
    }
}

impl Default for ComputeGraph {
    fn default() -> Self {
        Self::new()
    }
}

/// Fusion engine for optimizing computation graphs
pub struct FusionEngine {
    patterns: Vec<FusionPattern>,
    fused_ops: HashMap<String, Vec<String>>,
}

impl FusionEngine {
    /// Create a new fusion engine
    pub fn new() -> Self {
        let mut engine = Self {
            patterns: Vec::new(),
            fused_ops: HashMap::new(),
        };

        // Register default fusion patterns
        engine.register_default_patterns();
        engine
    }

    /// Register default fusion patterns
    fn register_default_patterns(&mut self) {
        // Conv + BatchNorm + ReLU
        self.add_pattern(FusionPattern::Custom(
            "ConvBNReLU".to_string(),
            vec!["Conv2d".to_string(), "BatchNorm".to_string(), "ReLU".to_string()],
        ));

        // Linear + ReLU
        self.add_pattern(FusionPattern::Custom(
            "LinearReLU".to_string(),
            vec!["Linear".to_string(), "ReLU".to_string()],
        ));

        // MatMul + Add (GEMM)
        self.add_pattern(FusionPattern::Custom(
            "GEMM".to_string(),
            vec!["MatMul".to_string(), "Add".to_string()],
        ));

        // Add + ReLU
        self.add_pattern(FusionPattern::Custom(
            "AddReLU".to_string(),
            vec!["Add".to_string(), "ReLU".to_string()],
        ));

        // Mul + Add (FMA - Fused Multiply-Add)
        self.add_pattern(FusionPattern::Custom(
            "FMA".to_string(),
            vec!["Mul".to_string(), "Add".to_string()],
        ));

        // BatchNorm + ReLU
        self.add_pattern(FusionPattern::Custom(
            "BNReLU".to_string(),
            vec!["BatchNorm".to_string(), "ReLU".to_string()],
        ));
    }

    /// Add a fusion pattern
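    ///
    /// A hypothetical custom pattern, for illustration only (the pattern name
    /// `"MulSigmoid"` and the op name `"Sigmoid"` are placeholders):
    ///
    /// ```
    /// use ghostflow_core::fusion::{FusionEngine, FusionPattern};
    ///
    /// let mut engine = FusionEngine::new();
    /// engine.add_pattern(FusionPattern::Custom(
    ///     "MulSigmoid".to_string(),
    ///     vec!["Mul".to_string(), "Sigmoid".to_string()],
    /// ));
    /// ```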
    pub fn add_pattern(&mut self, pattern: FusionPattern) {
        self.patterns.push(pattern);
    }

    /// Analyze a compute graph and find fusion opportunities.
    ///
    /// Currently only `FusionPattern::Custom` patterns are matched.
    pub fn analyze(&mut self, graph: &ComputeGraph) -> Vec<FusionOpportunity> {
        let mut opportunities = Vec::new();

        // Check each pattern against the graph
        for pattern in &self.patterns {
            if let FusionPattern::Custom(name, ops) = pattern {
                opportunities.extend(self.find_pattern_matches(graph, name, ops));
            }
        }

        opportunities
    }

    /// Find matches for a specific pattern in the graph
    fn find_pattern_matches(
        &self,
        graph: &ComputeGraph,
        pattern_name: &str,
        ops: &[String],
    ) -> Vec<FusionOpportunity> {
        let mut matches = Vec::new();

        // Simple pattern matching: node IDs are assigned sequentially and never
        // reused, so scan for runs of consecutive nodes whose op types match.
        for i in 0..graph.nodes().len() {
            if self.matches_pattern_at(graph, i, ops) {
                let node_ids: Vec<usize> = (i..i + ops.len()).collect();
                matches.push(FusionOpportunity {
                    pattern_name: pattern_name.to_string(),
                    nodes: node_ids,
                    estimated_speedup: self.estimate_speedup(ops),
                });
            }
        }

        matches
    }

    /// Check if a pattern matches starting at a given node ID.
    ///
    /// Each candidate node must have the expected op type, be fusible, and
    /// (after the first) consume the output of the previous node in the chain.
    fn matches_pattern_at(&self, graph: &ComputeGraph, start: usize, ops: &[String]) -> bool {
        if start + ops.len() > graph.nodes().len() {
            return false;
        }

        for (i, op) in ops.iter().enumerate() {
            match graph.get_node(start + i) {
                Some(node) => {
                    if &node.op_type != op || !node.fusible {
                        return false;
                    }
                    // Consecutive nodes must form a producer-consumer chain.
                    if i > 0 && !node.inputs.contains(&(start + i - 1)) {
                        return false;
                    }
                }
                None => return false,
            }
        }

        true
    }

    /// Estimate the speedup from fusing a chain of operations
    fn estimate_speedup(&self, ops: &[String]) -> f32 {
        // Simple heuristic: the more ops fused, the better the speedup
        match ops.len() {
            0 | 1 => 1.0, // nothing to fuse
            2 => 1.3,     // 30% speedup
            3 => 1.5,     // 50% speedup
            4 => 1.7,     // 70% speedup
            _ => 1.8,     // 80% speedup for longer chains
        }
    }

    /// Apply fusion to a graph by recording each opportunity's fused operation
    /// group, keyed by pattern name.
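    ///
    /// Illustrative sketch of the analyze/fuse round trip (module path
    /// `ghostflow_core::fusion` assumed):
    ///
    /// ```
    /// use ghostflow_core::fusion::{ComputeGraph, FusionEngine};
    ///
    /// let mut graph = ComputeGraph::new();
    /// let x = graph.add_node("Input".to_string(), vec![], false);
    /// let m = graph.add_node("MatMul".to_string(), vec![x], true);
    /// let _a = graph.add_node("Add".to_string(), vec![m], true);
    ///
    /// let mut engine = FusionEngine::new();
    /// let opportunities = engine.analyze(&graph);
    /// engine.fuse(&mut graph, &opportunities);
    ///
    /// assert!(engine.get_fused_ops().contains_key("GEMM"));
    /// ```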
    pub fn fuse(&mut self, graph: &mut ComputeGraph, opportunities: &[FusionOpportunity]) {
        for opp in opportunities {
            // Record the op types that make up this fused group.
            self.fused_ops.insert(
                opp.pattern_name.clone(),
                opp.nodes.iter().map(|&id| {
                    graph.get_node(id).map(|n| n.op_type.clone()).unwrap_or_default()
                }).collect(),
            );
        }
    }

    /// Get fused operations
    pub fn get_fused_ops(&self) -> &HashMap<String, Vec<String>> {
        &self.fused_ops
    }
}

impl Default for FusionEngine {
    fn default() -> Self {
        Self::new()
    }
}

/// Fusion opportunity found in the graph
#[derive(Debug, Clone)]
pub struct FusionOpportunity {
    pub pattern_name: String,
    pub nodes: Vec<usize>,
    pub estimated_speedup: f32,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_graph() {
        let mut graph = ComputeGraph::new();

        let n1 = graph.add_node("Input".to_string(), vec![], false);
        let n2 = graph.add_node("Conv2d".to_string(), vec![n1], true);
        let n3 = graph.add_node("ReLU".to_string(), vec![n2], true);

        assert_eq!(graph.nodes().len(), 3);
        assert_eq!(graph.get_node(n2).unwrap().op_type, "Conv2d");
        assert_eq!(graph.get_node(n3).unwrap().inputs, vec![n2]);
    }

    #[test]
    fn test_fusion_engine() {
        let mut engine = FusionEngine::new();
        let mut graph = ComputeGraph::new();

        let n1 = graph.add_node("Input".to_string(), vec![], false);
        let n2 = graph.add_node("Linear".to_string(), vec![n1], true);
        let _n3 = graph.add_node("ReLU".to_string(), vec![n2], true);

        let opportunities = engine.analyze(&graph);

        // Should find the LinearReLU fusion
        assert!(opportunities.iter().any(|o| o.pattern_name == "LinearReLU"));
    }

    #[test]
    fn test_fusion_patterns() {
        let engine = FusionEngine::new();

        // Should have default patterns registered
        assert!(!engine.patterns.is_empty());
    }
}