//! Kernel fusion pattern detection.
//!
//! Identifies fusable operator sequences in a linear op stream and attaches
//! rough per-pattern metadata: fused op count, estimated speedup, and
//! whether the pattern is memory-bound.

use std::fmt;

/// A fusable sequence of operations that can be lowered to a single kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusionPattern {
    /// Matrix multiply followed by a bias add.
    MatMulBias,
    /// Matrix multiply, bias add, then ReLU activation.
    MatMulBiasRelu,
    /// Matrix multiply, bias add, then GELU activation.
    MatMulBiasGelu,
    /// Convolution with the batch-norm scale/shift folded in.
    ConvBatchNorm,
    /// Convolution, folded batch norm, then ReLU activation.
    ConvBatchNormRelu,
    /// Two or more consecutive elementwise ops fused into one pass.
    ElementwiseChain,
    /// Softmax fused into a single kernel.
    Softmax,
    /// Layer normalization fused into a single kernel.
    LayerNorm,
    /// Multi-op GELU approximation fused into a single kernel.
    GeluApprox,
    /// Elementwise add followed by ReLU.
    AddRelu,
    /// Elementwise multiply followed by add (maps to a fused multiply-add).
    MulAdd,
}

impl FusionPattern {
    /// Number of primitive ops collapsed by this pattern. For
    /// `ElementwiseChain` this is the minimum chain length (2); detected
    /// chains may be longer.
    pub fn num_ops(&self) -> usize {
        match self {
            FusionPattern::MatMulBias | FusionPattern::AddRelu | FusionPattern::MulAdd => 2,
            FusionPattern::MatMulBiasRelu
            | FusionPattern::MatMulBiasGelu
            | FusionPattern::ConvBatchNorm
            | FusionPattern::Softmax => 3,
            FusionPattern::ConvBatchNormRelu | FusionPattern::LayerNorm => 4,
            FusionPattern::GeluApprox => 5,
            FusionPattern::ElementwiseChain => 2,
        }
    }

    /// Rough speedup estimate from fusing this pattern, relative to running
    /// the ops unfused. Memory-bound elementwise patterns benefit the most
    /// because fusion removes intermediate memory round-trips; compute-bound
    /// patterns see smaller gains.
    pub fn estimated_speedup(&self) -> f32 {
        match self {
            // Memory-bound elementwise patterns: fusion eliminates whole
            // intermediate tensors, so the win is largest here.
            FusionPattern::ElementwiseChain => 2.0,
            FusionPattern::AddRelu => 1.8,
            FusionPattern::MulAdd => 1.5,

            // MatMul-anchored patterns are compute-bound; fusing the
            // epilogue saves one extra pass over the output.
            FusionPattern::MatMulBiasRelu | FusionPattern::MatMulBiasGelu => 1.3,
            FusionPattern::MatMulBias => 1.2,

            // Conv patterns: folding batch norm removes per-element work.
            FusionPattern::ConvBatchNormRelu => 1.4,
            FusionPattern::ConvBatchNorm => 1.3,
            FusionPattern::Softmax | FusionPattern::LayerNorm | FusionPattern::GeluApprox => 1.2,
        }
    }

    /// Whether the fused pattern is limited by memory bandwidth rather than
    /// compute. Memory-bound patterns gain the most from fusion.
    pub fn is_memory_bound(&self) -> bool {
        matches!(
            self,
            FusionPattern::ElementwiseChain | FusionPattern::AddRelu | FusionPattern::MulAdd
        )
    }
}

impl fmt::Display for FusionPattern {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            FusionPattern::MatMulBias => write!(f, "MatMul+Bias"),
            FusionPattern::MatMulBiasRelu => write!(f, "MatMul+Bias+ReLU"),
            FusionPattern::MatMulBiasGelu => write!(f, "MatMul+Bias+GELU"),
            FusionPattern::ConvBatchNorm => write!(f, "Conv+BatchNorm"),
            FusionPattern::ConvBatchNormRelu => write!(f, "Conv+BatchNorm+ReLU"),
            FusionPattern::ElementwiseChain => write!(f, "Elementwise Chain"),
            FusionPattern::Softmax => write!(f, "Softmax"),
            FusionPattern::LayerNorm => write!(f, "LayerNorm"),
            FusionPattern::GeluApprox => write!(f, "GELU Approximation"),
            FusionPattern::AddRelu => write!(f, "Add+ReLU"),
            FusionPattern::MulAdd => write!(f, "Mul+Add (FMA)"),
        }
    }
}
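
/// A small usage sketch (added for illustration; the name `describe` is
/// hypothetical, not part of the original API): shows how the metadata
/// accessors and the `Display` impl above are meant to be used together.
pub fn describe(pattern: FusionPattern) -> String {
    format!(
        "{}: fuses {} ops, ~{:.1}x estimated speedup{}",
        pattern,
        pattern.num_ops(),
        pattern.estimated_speedup(),
        if pattern.is_memory_bound() { " (memory-bound)" } else { "" },
    )
}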

/// Primitive operator kinds recognized by the pattern detector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpType {
    MatMul,
    Add,
    Sub,
    Mul,
    Div,
    Relu,
    Gelu,
    Sigmoid,
    Tanh,
    Softmax,
    Conv,
    BatchNorm,
    LayerNorm,
    Exp,
    Log,
    Sqrt,
    Pow,
    Reduce,
    /// Any op the detector does not model; never fused.
    Unknown,
}

/// Scans a linear op sequence left to right, greedily matching specialized
/// patterns first (3-op, then 2-op) and falling back to generic elementwise
/// chains of length >= 2.
///
/// Returns `(pattern, start, end)` triples, where `start..end` is the
/// half-open index range of the fused ops in `ops`.
pub fn detect_patterns(ops: &[OpType]) -> Vec<(FusionPattern, usize, usize)> {
    let mut patterns = Vec::new();
    let n = ops.len();

    let mut i = 0;
    while i < n {
        // Try 3-op patterns first so they win over their 2-op prefixes
        // (e.g. MatMul+Bias+ReLU beats MatMul+Bias).
        if i + 2 < n {
            if ops[i] == OpType::MatMul && ops[i + 1] == OpType::Add && ops[i + 2] == OpType::Relu {
                patterns.push((FusionPattern::MatMulBiasRelu, i, i + 3));
                i += 3;
                continue;
            }
            if ops[i] == OpType::MatMul && ops[i + 1] == OpType::Add && ops[i + 2] == OpType::Gelu {
                patterns.push((FusionPattern::MatMulBiasGelu, i, i + 3));
                i += 3;
                continue;
            }
            if ops[i] == OpType::Conv && ops[i + 1] == OpType::BatchNorm && ops[i + 2] == OpType::Relu {
                patterns.push((FusionPattern::ConvBatchNormRelu, i, i + 3));
                i += 3;
                continue;
            }
        }

        // Then 2-op patterns.
        if i + 1 < n {
            if ops[i] == OpType::MatMul && ops[i + 1] == OpType::Add {
                patterns.push((FusionPattern::MatMulBias, i, i + 2));
                i += 2;
                continue;
            }
            if ops[i] == OpType::Conv && ops[i + 1] == OpType::BatchNorm {
                patterns.push((FusionPattern::ConvBatchNorm, i, i + 2));
                i += 2;
                continue;
            }
            if ops[i] == OpType::Add && ops[i + 1] == OpType::Relu {
                patterns.push((FusionPattern::AddRelu, i, i + 2));
                i += 2;
                continue;
            }
            if ops[i] == OpType::Mul && ops[i + 1] == OpType::Add {
                patterns.push((FusionPattern::MulAdd, i, i + 2));
                i += 2;
                continue;
            }
        }

        // Fall back to a generic elementwise chain of length >= 2.
        if is_elementwise_op(ops[i]) {
            let start = i;
            while i < n && is_elementwise_op(ops[i]) {
                i += 1;
            }
            if i - start > 1 {
                patterns.push((FusionPattern::ElementwiseChain, start, i));
            }
            continue;
        }

        i += 1;
    }

    patterns
}
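
/// Illustrative helper (added as a sketch; the name `fused_op_coverage` is
/// hypothetical, not an established API): fraction of ops in the original
/// sequence that `detect_patterns` covered with some fusion pattern.
pub fn fused_op_coverage(patterns: &[(FusionPattern, usize, usize)], total_ops: usize) -> f32 {
    if total_ops == 0 {
        return 0.0;
    }
    // Ranges produced by detect_patterns never overlap, so summing their
    // lengths counts each fused op exactly once.
    let covered: usize = patterns.iter().map(|(_, start, end)| end - start).sum();
    covered as f32 / total_ops as f32
}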

/// Returns true for unary/binary elementwise ops that can participate in a
/// fused elementwise chain. Note that `Gelu` is not included here, so it
/// only fuses via the dedicated MatMul+Bias+GELU pattern.
fn is_elementwise_op(op: OpType) -> bool {
    matches!(
        op,
        OpType::Add | OpType::Sub | OpType::Mul | OpType::Div
            | OpType::Relu | OpType::Sigmoid | OpType::Tanh
            | OpType::Exp | OpType::Log | OpType::Sqrt
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_matmul_bias_relu() {
        let ops = vec![OpType::MatMul, OpType::Add, OpType::Relu];
        let patterns = detect_patterns(&ops);

        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].0, FusionPattern::MatMulBiasRelu);
    }

    #[test]
    fn test_detect_matmul_bias() {
        let ops = vec![OpType::MatMul, OpType::Add];
        let patterns = detect_patterns(&ops);

        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].0, FusionPattern::MatMulBias);
    }

    #[test]
    fn test_detect_elementwise_chain() {
        let ops = vec![OpType::Add, OpType::Mul, OpType::Relu];
        let patterns = detect_patterns(&ops);

        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].0, FusionPattern::ElementwiseChain);
    }

    #[test]
    fn test_pattern_speedup() {
        assert!(FusionPattern::ElementwiseChain.estimated_speedup() > 1.5);
        assert!(FusionPattern::MatMulBiasRelu.estimated_speedup() > 1.0);
    }

    #[test]
    fn test_detect_add_relu() {
        let ops = vec![OpType::Add, OpType::Relu];
        let patterns = detect_patterns(&ops);

        assert!(!patterns.is_empty());
        assert_eq!(patterns[0].0, FusionPattern::AddRelu);
    }
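
    // Added test (not in the original suite): specialized patterns take
    // precedence over generic ones, and scanning resumes after each match.
    #[test]
    fn test_detect_mixed_sequence() {
        let ops = vec![
            OpType::MatMul, OpType::Add, OpType::Relu,
            OpType::Conv, OpType::BatchNorm,
            OpType::Mul, OpType::Add,
        ];
        let patterns = detect_patterns(&ops);

        assert_eq!(patterns.len(), 3);
        assert_eq!(patterns[0], (FusionPattern::MatMulBiasRelu, 0, 3));
        assert_eq!(patterns[1], (FusionPattern::ConvBatchNorm, 3, 5));
        assert_eq!(patterns[2], (FusionPattern::MulAdd, 5, 7));
    }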

    #[test]
    fn test_pattern_display() {
        assert_eq!(format!("{}", FusionPattern::MatMulBiasRelu), "MatMul+Bias+ReLU");
        assert_eq!(format!("{}", FusionPattern::Softmax), "Softmax");
    }
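
    // Added edge-case test (not in the original suite): a single elementwise
    // op must not be reported as a chain, since chains require length >= 2.
    #[test]
    fn test_single_elementwise_no_pattern() {
        let patterns = detect_patterns(&[OpType::Relu]);
        assert!(patterns.is_empty());
    }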
}