1use std::collections::HashMap;
7
/// Describes a group of operations that the fusion engine may combine.
///
/// Only [`FusionPattern::Custom`] is consulted by the matching code in this
/// file; the other variants are carried as data.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionPattern {
    /// Element-wise pattern carrying a list of strings
    /// (presumably operation names — TODO confirm against callers).
    ElementWise(Vec<String>),

    /// Reduction pattern carrying a single string
    /// (presumably the reduction's operation name — TODO confirm).
    Reduction(String),

    /// Matrix-operation pattern carrying a single string
    /// (presumably the operation name — TODO confirm).
    MatrixOp(String),

    /// Named pattern: the fused operation's name followed by the ordered
    /// sequence of operation types that make it up.
    Custom(String, Vec<String>),
}
20
/// A single operation in a [`ComputeGraph`].
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Identifier handed out by the graph when the node was added.
    pub id: usize,

    /// Operation type name (e.g. `"Conv2d"`); fusion patterns are compared
    /// against this string.
    pub op_type: String,

    /// Ids of the nodes this node reads from, exactly as supplied by the caller.
    pub inputs: Vec<usize>,

    /// Ids of nodes added later that listed this node among their inputs.
    pub outputs: Vec<usize>,

    /// Whether this node may take part in a fusion; pattern matching rejects
    /// nodes where this is `false`.
    pub fusible: bool,
}
30
31#[derive(Debug, Clone)]
33pub struct ComputeGraph {
34 nodes: Vec<GraphNode>,
35 next_id: usize,
36}
37
38impl ComputeGraph {
39 pub fn new() -> Self {
41 Self {
42 nodes: Vec::new(),
43 next_id: 0,
44 }
45 }
46
47 pub fn add_node(&mut self, op_type: String, inputs: Vec<usize>, fusible: bool) -> usize {
49 let id = self.next_id;
50 self.next_id += 1;
51
52 for &input_id in &inputs {
54 if let Some(node) = self.nodes.iter_mut().find(|n| n.id == input_id) {
55 node.outputs.push(id);
56 }
57 }
58
59 self.nodes.push(GraphNode {
61 id,
62 op_type,
63 inputs,
64 outputs: Vec::new(),
65 fusible,
66 });
67
68 id
69 }
70
71 pub fn nodes(&self) -> &[GraphNode] {
73 &self.nodes
74 }
75
76 pub fn get_node(&self, id: usize) -> Option<&GraphNode> {
78 self.nodes.iter().find(|n| n.id == id)
79 }
80}
81
82impl Default for ComputeGraph {
83 fn default() -> Self {
84 Self::new()
85 }
86}
87
88pub struct FusionEngine {
90 patterns: Vec<FusionPattern>,
91 fused_ops: HashMap<String, Vec<String>>,
92}
93
94impl FusionEngine {
95 pub fn new() -> Self {
97 let mut engine = Self {
98 patterns: Vec::new(),
99 fused_ops: HashMap::new(),
100 };
101
102 engine.register_default_patterns();
104 engine
105 }
106
107 fn register_default_patterns(&mut self) {
109 self.add_pattern(FusionPattern::Custom(
111 "ConvBNReLU".to_string(),
112 vec!["Conv2d".to_string(), "BatchNorm".to_string(), "ReLU".to_string()],
113 ));
114
115 self.add_pattern(FusionPattern::Custom(
117 "LinearReLU".to_string(),
118 vec!["Linear".to_string(), "ReLU".to_string()],
119 ));
120
121 self.add_pattern(FusionPattern::Custom(
123 "GEMM".to_string(),
124 vec!["MatMul".to_string(), "Add".to_string()],
125 ));
126
127 self.add_pattern(FusionPattern::Custom(
129 "AddReLU".to_string(),
130 vec!["Add".to_string(), "ReLU".to_string()],
131 ));
132
133 self.add_pattern(FusionPattern::Custom(
135 "FMA".to_string(),
136 vec!["Mul".to_string(), "Add".to_string()],
137 ));
138
139 self.add_pattern(FusionPattern::Custom(
141 "BNReLU".to_string(),
142 vec!["BatchNorm".to_string(), "ReLU".to_string()],
143 ));
144 }
145
146 pub fn add_pattern(&mut self, pattern: FusionPattern) {
148 self.patterns.push(pattern);
149 }
150
151 pub fn analyze(&mut self, graph: &ComputeGraph) -> Vec<FusionOpportunity> {
153 let mut opportunities = Vec::new();
154
155 for pattern in &self.patterns {
157 if let FusionPattern::Custom(name, ops) = pattern {
158 opportunities.extend(self.find_pattern_matches(graph, name, ops));
159 }
160 }
161
162 opportunities
163 }
164
165 fn find_pattern_matches(
167 &self,
168 graph: &ComputeGraph,
169 pattern_name: &str,
170 ops: &[String],
171 ) -> Vec<FusionOpportunity> {
172 let mut matches = Vec::new();
173
174 for i in 0..graph.nodes().len() {
176 if self.matches_pattern_at(graph, i, ops) {
177 let node_ids: Vec<usize> = (i..i + ops.len()).collect();
178 matches.push(FusionOpportunity {
179 pattern_name: pattern_name.to_string(),
180 nodes: node_ids,
181 estimated_speedup: self.estimate_speedup(ops),
182 });
183 }
184 }
185
186 matches
187 }
188
189 fn matches_pattern_at(&self, graph: &ComputeGraph, start: usize, ops: &[String]) -> bool {
191 if start + ops.len() > graph.nodes().len() {
192 return false;
193 }
194
195 for (i, op) in ops.iter().enumerate() {
196 if let Some(node) = graph.get_node(start + i) {
197 if &node.op_type != op || !node.fusible {
198 return false;
199 }
200 } else {
201 return false;
202 }
203 }
204
205 true
206 }
207
208 fn estimate_speedup(&self, ops: &[String]) -> f32 {
210 match ops.len() {
212 2 => 1.3, 3 => 1.5, 4 => 1.7, _ => 1.2, }
217 }
218
219 pub fn fuse(&mut self, graph: &mut ComputeGraph, opportunities: &[FusionOpportunity]) {
221 for opp in opportunities {
222 self.fused_ops.insert(
223 opp.pattern_name.clone(),
224 opp.nodes.iter().map(|&id| {
225 graph.get_node(id).map(|n| n.op_type.clone()).unwrap_or_default()
226 }).collect(),
227 );
228 }
229 }
230
231 pub fn get_fused_ops(&self) -> &HashMap<String, Vec<String>> {
233 &self.fused_ops
234 }
235}
236
237impl Default for FusionEngine {
238 fn default() -> Self {
239 Self::new()
240 }
241}
242
/// One place in a graph where a registered pattern was found.
#[derive(Debug, Clone)]
pub struct FusionOpportunity {
    /// Name of the pattern that matched.
    pub pattern_name: String,

    /// Ids of the matched nodes, in pattern order.
    pub nodes: Vec<usize>,

    /// Heuristic speedup estimate derived from the pattern's length.
    pub estimated_speedup: f32,
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Nodes receive sequential ids, can be looked up again, and are linked
    /// to their producers.
    #[test]
    fn test_compute_graph() {
        let mut graph = ComputeGraph::new();

        let n1 = graph.add_node("Input".to_string(), vec![], false);
        let n2 = graph.add_node("Conv2d".to_string(), vec![n1], true);
        let n3 = graph.add_node("ReLU".to_string(), vec![n2], true);

        assert_eq!(graph.nodes().len(), 3);
        assert_eq!(graph.get_node(n2).unwrap().op_type, "Conv2d");
        assert_eq!(graph.get_node(n3).unwrap().inputs, vec![n2]);
        assert_eq!(graph.get_node(n2).unwrap().outputs, vec![n3]);
        assert!(graph.get_node(n3 + 1).is_none());
    }

    /// A fusible Linear feeding a fusible ReLU is reported as `LinearReLU`.
    #[test]
    fn test_fusion_engine() {
        let mut engine = FusionEngine::new();
        let mut graph = ComputeGraph::new();

        let n1 = graph.add_node("Input".to_string(), vec![], false);
        let n2 = graph.add_node("Linear".to_string(), vec![n1], true);
        let n3 = graph.add_node("ReLU".to_string(), vec![n2], true);

        let opportunities = engine.analyze(&graph);

        assert!(!opportunities.is_empty());

        let linear_relu = opportunities
            .iter()
            .find(|opp| opp.pattern_name == "LinearReLU")
            .expect("Linear followed by ReLU should be reported");
        assert_eq!(linear_relu.nodes, vec![n2, n3]);
        assert!((linear_relu.estimated_speedup - 1.3).abs() < f32::EPSILON);
    }

    /// A new engine comes with the built-in patterns registered.
    #[test]
    fn test_fusion_patterns() {
        let engine = FusionEngine::new();

        assert!(!engine.patterns.is_empty());
        assert!(engine.patterns.contains(&FusionPattern::Custom(
            "LinearReLU".to_string(),
            vec!["Linear".to_string(), "ReLU".to_string()],
        )));
    }
}