/// A compute graph of tensor operations: `nodes` holds the operations and
/// `edges` the (from, to) node-id pairs connecting them.
#[derive(Clone, Debug)]
pub struct ComputeGraph {
    pub nodes: Vec<GraphNode>,
    pub edges: Vec<(usize, usize)>,
}

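/// One operation in the graph, plus the ids of the nodes it reads from and
/// writes to.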
#[derive(Clone, Debug)]
pub struct GraphNode {
    pub id: usize,
    pub op: Operation,
    pub inputs: Vec<usize>,
    pub outputs: Vec<usize>,
}

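/// The operator vocabulary the fusion pass understands.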
#[derive(Clone, Debug, PartialEq)]
pub enum Operation {
    Conv2d { channels: usize, kernel: (usize, usize) },
    BatchNorm { channels: usize },
    ReLU,
    GELU,
    MatMul { m: usize, n: usize, k: usize },
    Add,
    Mul,
    Softmax { dim: i32 },
    LayerNorm,
    Attention { heads: usize, dim: usize },
}

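/// Multi-node patterns the engine knows how to collapse.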
#[derive(Clone, Debug)]
pub enum FusionPattern {
    /// Conv2d -> BatchNorm -> ReLU, collapsed into the convolution.
    ConvBnRelu,
    /// MatMul -> Add (bias) -> ReLU or GELU.
    MatMulAddActivation,
    /// A run of element-wise operations (Add, Mul, activations).
    ElementWiseChain,
    /// The multi-head attention subgraph.
    AttentionPattern,
    /// LayerNorm feeding a linear (MatMul) layer.
    LayerNormLinear,
}

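/// Applies a fixed list of fusion patterns to a `ComputeGraph`.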
pub struct FusionEngine {
    /// Patterns to try, applied in declaration order.
    patterns: Vec<FusionPattern>,
    /// When false, `optimize` returns graphs untouched.
    enabled: bool,
}

impl FusionEngine {
    /// Builds an engine with every built-in pattern enabled.
    pub fn new() -> Self {
        Self {
            patterns: vec![
                FusionPattern::ConvBnRelu,
                FusionPattern::MatMulAddActivation,
                FusionPattern::ElementWiseChain,
                FusionPattern::AttentionPattern,
                FusionPattern::LayerNormLinear,
            ],
            enabled: true,
        }
    }

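    /// Runs every registered pattern over `graph` and returns the rewritten
    /// graph; when the engine is disabled the graph passes through untouched.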
    pub fn optimize(&self, graph: ComputeGraph) -> ComputeGraph {
        if !self.enabled {
            return graph;
        }

        let mut optimized = graph;

        for pattern in &self.patterns {
            optimized = self.apply_pattern(optimized, pattern);
        }

        optimized
    }

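    /// Dispatches one pattern to its rewrite routine.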
    fn apply_pattern(&self, mut graph: ComputeGraph, pattern: &FusionPattern) -> ComputeGraph {
        match pattern {
            FusionPattern::ConvBnRelu => self.fuse_conv_bn_relu(&mut graph),
            FusionPattern::MatMulAddActivation => self.fuse_matmul_add_act(&mut graph),
            FusionPattern::ElementWiseChain => self.fuse_elementwise_chain(&mut graph),
            FusionPattern::AttentionPattern => self.fuse_attention(&mut graph),
            FusionPattern::LayerNormLinear => self.fuse_layernorm_linear(&mut graph),
        }

        graph
    }

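    /// Pairwise legality check: true when `op2` can be folded into a
    /// preceding `op1` (e.g. BatchNorm into Conv2d).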
    pub fn can_fuse(&self, op1: &Operation, op2: &Operation) -> bool {
        matches!(
            (op1, op2),
            (Operation::Conv2d { .. }, Operation::BatchNorm { .. })
                | (Operation::BatchNorm { .. }, Operation::ReLU)
                | (Operation::MatMul { .. }, Operation::Add)
                | (Operation::Add, Operation::ReLU)
                | (Operation::Add, Operation::GELU)
        )
    }

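    /// Collapses each Conv2d -> BatchNorm -> ReLU run into a single Conv2d
    /// node, rewiring inputs and outputs around the removed nodes.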
    fn fuse_conv_bn_relu(&self, graph: &mut ComputeGraph) {
        let mut fused_indices = Vec::new();

        // First pass: record the start index of every Conv2d -> BatchNorm ->
        // ReLU window whose nodes actually feed one another.
        let mut i = 0;
        while i + 2 < graph.nodes.len() {
            let is_pattern = matches!(
                (&graph.nodes[i].op, &graph.nodes[i + 1].op, &graph.nodes[i + 2].op),
                (Operation::Conv2d { .. }, Operation::BatchNorm { .. }, Operation::ReLU)
            );

            if is_pattern && self.is_sequential(&graph.nodes[i..i + 3]) {
                fused_indices.push(i);
                i += 3;
            } else {
                i += 1;
            }
        }

        // Second pass: rewrite in reverse so earlier indices stay valid
        // while later nodes are removed.
        for &idx in fused_indices.iter().rev() {
            // The fused node keeps the convolution's parameters; in a real
            // backend the BatchNorm scale/shift would be folded into the
            // convolution weights.
            let (channels, kernel) = match &graph.nodes[idx].op {
                Operation::Conv2d { channels, kernel } => (*channels, *kernel),
                _ => unreachable!("first pass guarantees a Conv2d at idx"),
            };

            let fused = GraphNode {
                id: graph.nodes[idx].id,
                op: Operation::Conv2d { channels, kernel },
                inputs: graph.nodes[idx].inputs.clone(),
                outputs: graph.nodes[idx + 2].outputs.clone(),
            };

            graph.nodes[idx] = fused;
            graph.nodes.remove(idx + 1); // BatchNorm
            graph.nodes.remove(idx + 1); // ReLU, shifted down by the remove

            // Drop the two interior edges; like the rest of this pass, this
            // assumes node ids match vector indices, as in the tests below.
            graph.edges.retain(|(from, to)| {
                !(*from == idx && *to == idx + 1) && !(*from == idx + 1 && *to == idx + 2)
            });
            // Redirect any edge leaving the removed ReLU to the fused node.
            for (from, _) in graph.edges.iter_mut() {
                if *from == idx + 2 {
                    *from = idx;
                }
            }
        }
    }

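    /// Collapses MatMul -> Add -> ReLU/GELU runs into the MatMul node.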
    fn fuse_matmul_add_act(&self, graph: &mut ComputeGraph) {
        let mut i = 0;
        while i + 2 < graph.nodes.len() {
            let is_pattern = matches!(
                (&graph.nodes[i].op, &graph.nodes[i + 1].op, &graph.nodes[i + 2].op),
                (Operation::MatMul { .. }, Operation::Add, Operation::ReLU | Operation::GELU)
            );

            if is_pattern && self.is_sequential(&graph.nodes[i..i + 3]) {
                // Reconstructed body, mirroring fuse_conv_bn_relu above: keep
                // the MatMul node, splice in the activation's outputs, and
                // drop the Add/activation nodes and their interior edges.
                graph.nodes[i].outputs = graph.nodes[i + 2].outputs.clone();
                graph.nodes.remove(i + 1); // Add
                graph.nodes.remove(i + 1); // activation, shifted down
                graph.edges.retain(|(from, to)| {
                    !(*from == i && *to == i + 1) && !(*from == i + 1 && *to == i + 2)
                });
            }

            i += 1;
        }
    }

    fn fuse_elementwise_chain(&self, _graph: &mut ComputeGraph) {
        // Placeholder: would collapse runs of element-wise ops (Add, Mul,
        // activations) into a single kernel.
    }

    fn fuse_attention(&self, _graph: &mut ComputeGraph) {
        // Placeholder: would fuse the attention subgraph (MatMul, Softmax,
        // MatMul) into one Attention node.
    }

    fn fuse_layernorm_linear(&self, _graph: &mut ComputeGraph) {
        // Placeholder: would fold LayerNorm into the following MatMul.
    }

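    /// True when each node in the slice feeds the next one.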
    fn is_sequential(&self, nodes: &[GraphNode]) -> bool {
        // Adjacent nodes must be wired together: the successor's inputs must
        // contain the current node's id.
        for pair in nodes.windows(2) {
            if !pair[1].inputs.contains(&pair[0].id) {
                return false;
            }
        }
        true
    }

    /// Removes every edge whose endpoints both fall inside `[start, end]`.
    #[allow(dead_code)]
    fn update_edges(&self, graph: &mut ComputeGraph, start: usize, end: usize) {
        graph.edges.retain(|(from, to)| {
            !(*from >= start && *from <= end && *to >= start && *to <= end)
        });
    }
}

impl Default for FusionEngine {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv_bn_relu_fusion() {
        let graph = ComputeGraph {
            nodes: vec![
                GraphNode {
                    id: 0,
                    op: Operation::Conv2d { channels: 64, kernel: (3, 3) },
                    inputs: vec![],
                    outputs: vec![1],
                },
                GraphNode {
                    id: 1,
                    op: Operation::BatchNorm { channels: 64 },
                    inputs: vec![0],
                    outputs: vec![2],
                },
                GraphNode {
                    id: 2,
                    op: Operation::ReLU,
                    inputs: vec![1],
                    outputs: vec![],
                },
            ],
            edges: vec![(0, 1), (1, 2)],
        };

        let engine = FusionEngine::new();
        let optimized = engine.optimize(graph.clone());

        assert_eq!(optimized.nodes.len(), 1, "Should fuse 3 nodes into 1");
        assert!(matches!(optimized.nodes[0].op, Operation::Conv2d { .. }));
        assert!(engine.can_fuse(&graph.nodes[0].op, &graph.nodes[1].op));
    }
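
    // Exercises the reconstructed MatMul -> Add -> activation pass above;
    // the node ids and matrix shapes here are illustrative, not taken from
    // a real model.
    #[test]
    fn test_matmul_add_act_fusion() {
        let graph = ComputeGraph {
            nodes: vec![
                GraphNode {
                    id: 0,
                    op: Operation::MatMul { m: 128, n: 64, k: 256 },
                    inputs: vec![],
                    outputs: vec![1],
                },
                GraphNode {
                    id: 1,
                    op: Operation::Add,
                    inputs: vec![0],
                    outputs: vec![2],
                },
                GraphNode {
                    id: 2,
                    op: Operation::GELU,
                    inputs: vec![1],
                    outputs: vec![],
                },
            ],
            edges: vec![(0, 1), (1, 2)],
        };

        let engine = FusionEngine::new();
        let optimized = engine.optimize(graph);

        assert_eq!(optimized.nodes.len(), 1, "Should fuse 3 nodes into 1");
        assert!(matches!(optimized.nodes[0].op, Operation::MatMul { .. }));
    }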
}