axonml_fusion/patterns.rs

//! Fusion Pattern Detection
//!
//! Detects common patterns in computational graphs that can be fused.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::fmt;

// =============================================================================
// Fusion Patterns
// =============================================================================

/// Common fusion patterns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusionPattern {
    /// MatMul followed by bias addition.
    MatMulBias,
    /// MatMul followed by bias addition and ReLU.
    MatMulBiasRelu,
    /// MatMul followed by bias addition and GELU.
    MatMulBiasGelu,
    /// Convolution followed by batch normalization.
    ConvBatchNorm,
    /// Convolution followed by batch normalization and ReLU.
    ConvBatchNormRelu,
    /// Multiple elementwise operations.
    ElementwiseChain,
    /// Softmax pattern (exp, sum, div).
    Softmax,
    /// Layer normalization pattern.
    LayerNorm,
    /// GELU approximation pattern.
    GeluApprox,
    /// Add followed by ReLU.
    AddRelu,
    /// Multiply followed by add (FMA).
    MulAdd,
}

impl FusionPattern {
    /// Returns the number of operations fused in this pattern.
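    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Illustrative only; the import path is assumed from the file location.
    /// use axonml_fusion::patterns::FusionPattern;
    ///
    /// assert_eq!(FusionPattern::MatMulBias.num_ops(), 2);
    /// assert_eq!(FusionPattern::ConvBatchNormRelu.num_ops(), 4);
    /// ```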
    pub fn num_ops(&self) -> usize {
        match self {
            FusionPattern::MatMulBias | FusionPattern::AddRelu | FusionPattern::MulAdd => 2,
            FusionPattern::MatMulBiasRelu | FusionPattern::MatMulBiasGelu |
            FusionPattern::ConvBatchNorm | FusionPattern::Softmax => 3,
            FusionPattern::ConvBatchNormRelu | FusionPattern::LayerNorm => 4,
            FusionPattern::GeluApprox => 5,
            FusionPattern::ElementwiseChain => 2, // Variable length; 2 is the conservative default
        }
    }

    /// Returns the estimated speedup from applying this fusion.
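    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Illustrative only; the import path is assumed from the file location.
    /// use axonml_fusion::patterns::FusionPattern;
    ///
    /// // Memory-bound elementwise chains are expected to gain more than
    /// // compute-bound GEMM patterns.
    /// assert!(FusionPattern::ElementwiseChain.estimated_speedup()
    ///     > FusionPattern::MatMulBias.estimated_speedup());
    /// ```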
    pub fn estimated_speedup(&self) -> f32 {
        match self {
            // Memory-bound patterns benefit most
            FusionPattern::ElementwiseChain => 2.0,
            FusionPattern::AddRelu => 1.8,
            FusionPattern::MulAdd => 1.5,

            // Compute-bound patterns still benefit
            FusionPattern::MatMulBiasRelu | FusionPattern::MatMulBiasGelu => 1.3,
            FusionPattern::MatMulBias => 1.2,

            // Complex patterns
            FusionPattern::ConvBatchNormRelu => 1.4,
            FusionPattern::ConvBatchNorm => 1.3,
            FusionPattern::Softmax | FusionPattern::LayerNorm | FusionPattern::GeluApprox => 1.2,
        }
    }

    /// Returns whether this pattern is memory-bound (vs compute-bound).
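    ///
    /// # Examples
    ///
    /// ```ignore
    /// // Illustrative only; the import path is assumed from the file location.
    /// use axonml_fusion::patterns::FusionPattern;
    ///
    /// assert!(FusionPattern::AddRelu.is_memory_bound());
    /// assert!(!FusionPattern::MatMulBiasRelu.is_memory_bound());
    /// ```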
    pub fn is_memory_bound(&self) -> bool {
        matches!(
            self,
            FusionPattern::ElementwiseChain | FusionPattern::AddRelu | FusionPattern::MulAdd
        )
    }
}

impl fmt::Display for FusionPattern {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            FusionPattern::MatMulBias => write!(f, "MatMul+Bias"),
            FusionPattern::MatMulBiasRelu => write!(f, "MatMul+Bias+ReLU"),
            FusionPattern::MatMulBiasGelu => write!(f, "MatMul+Bias+GELU"),
            FusionPattern::ConvBatchNorm => write!(f, "Conv+BatchNorm"),
            FusionPattern::ConvBatchNormRelu => write!(f, "Conv+BatchNorm+ReLU"),
            FusionPattern::ElementwiseChain => write!(f, "Elementwise Chain"),
            FusionPattern::Softmax => write!(f, "Softmax"),
            FusionPattern::LayerNorm => write!(f, "LayerNorm"),
            FusionPattern::GeluApprox => write!(f, "GELU Approximation"),
            FusionPattern::AddRelu => write!(f, "Add+ReLU"),
            FusionPattern::MulAdd => write!(f, "Mul+Add (FMA)"),
        }
    }
}

// =============================================================================
// Operation Type for Pattern Matching
// =============================================================================

/// Operation types for pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpType {
    /// Matrix multiplication.
    MatMul,
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// ReLU activation.
    Relu,
    /// GELU activation.
    Gelu,
    /// Sigmoid activation.
    Sigmoid,
    /// Tanh activation.
    Tanh,
    /// Softmax.
    Softmax,
    /// Convolution.
    Conv,
    /// Batch normalization.
    BatchNorm,
    /// Layer normalization.
    LayerNorm,
    /// Exponential.
    Exp,
    /// Logarithm.
    Log,
    /// Square root.
    Sqrt,
    /// Power.
    Pow,
    /// Reduction (sum, mean, max).
    Reduce,
    /// Unknown operation.
    Unknown,
}

// =============================================================================
// Pattern Detection
// =============================================================================

/// Detects fusion patterns in a sequence of operations.
///
/// The scan is greedy and left to right: longer patterns are tried before
/// shorter ones, and matched ranges never overlap.
///
/// # Arguments
/// * `ops` - Sequence of operation types
///
/// # Returns
/// List of `(pattern, start, end)` tuples, where `start..end` is the half-open
/// index range of the fused operations in `ops`
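///
/// # Examples
///
/// ```ignore
/// // Illustrative only; the import path is assumed from the file location.
/// use axonml_fusion::patterns::{detect_patterns, FusionPattern, OpType};
///
/// let ops = [OpType::MatMul, OpType::Add, OpType::Relu, OpType::Mul, OpType::Add];
/// let patterns = detect_patterns(&ops);
/// assert_eq!(patterns, vec![
///     (FusionPattern::MatMulBiasRelu, 0, 3),
///     (FusionPattern::MulAdd, 3, 5),
/// ]);
/// ```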
pub fn detect_patterns(ops: &[OpType]) -> Vec<(FusionPattern, usize, usize)> {
    let mut patterns = Vec::new();
    let n = ops.len();

    let mut i = 0;
    while i < n {
        // Try to match longer patterns first

        // Length-3 patterns: MatMul+Bias+Activation, Conv+BatchNorm+ReLU
        if i + 2 < n {
            if ops[i] == OpType::MatMul && ops[i + 1] == OpType::Add && ops[i + 2] == OpType::Relu {
                patterns.push((FusionPattern::MatMulBiasRelu, i, i + 3));
                i += 3;
                continue;
            }
            if ops[i] == OpType::MatMul && ops[i + 1] == OpType::Add && ops[i + 2] == OpType::Gelu {
                patterns.push((FusionPattern::MatMulBiasGelu, i, i + 3));
                i += 3;
                continue;
            }
            if ops[i] == OpType::Conv && ops[i + 1] == OpType::BatchNorm && ops[i + 2] == OpType::Relu {
                patterns.push((FusionPattern::ConvBatchNormRelu, i, i + 3));
                i += 3;
                continue;
            }
        }

        // Length-2 patterns: MatMul+Bias, Conv+BatchNorm, Add+ReLU, Mul+Add (FMA)
        if i + 1 < n {
            if ops[i] == OpType::MatMul && ops[i + 1] == OpType::Add {
                patterns.push((FusionPattern::MatMulBias, i, i + 2));
                i += 2;
                continue;
            }
            if ops[i] == OpType::Conv && ops[i + 1] == OpType::BatchNorm {
                patterns.push((FusionPattern::ConvBatchNorm, i, i + 2));
                i += 2;
                continue;
            }
            if ops[i] == OpType::Add && ops[i + 1] == OpType::Relu {
                patterns.push((FusionPattern::AddRelu, i, i + 2));
                i += 2;
                continue;
            }
            if ops[i] == OpType::Mul && ops[i + 1] == OpType::Add {
                patterns.push((FusionPattern::MulAdd, i, i + 2));
                i += 2;
                continue;
            }
        }

        // Elementwise chain detection
        if is_elementwise_op(ops[i]) {
            let start = i;
            while i < n && is_elementwise_op(ops[i]) {
                i += 1;
            }
            if i - start > 1 {
                patterns.push((FusionPattern::ElementwiseChain, start, i));
            }
            continue;
        }

        i += 1;
    }

    patterns
}

/// Checks if an operation is elementwise.
fn is_elementwise_op(op: OpType) -> bool {
    matches!(
        op,
        OpType::Add | OpType::Sub | OpType::Mul | OpType::Div |
        OpType::Relu | OpType::Sigmoid | OpType::Tanh |
        OpType::Exp | OpType::Log | OpType::Sqrt
    )
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_matmul_bias_relu() {
        let ops = vec![OpType::MatMul, OpType::Add, OpType::Relu];
        let patterns = detect_patterns(&ops);

        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].0, FusionPattern::MatMulBiasRelu);
    }

    #[test]
    fn test_detect_matmul_bias() {
        let ops = vec![OpType::MatMul, OpType::Add];
        let patterns = detect_patterns(&ops);

        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].0, FusionPattern::MatMulBias);
    }

    #[test]
    fn test_detect_elementwise_chain() {
        let ops = vec![OpType::Add, OpType::Mul, OpType::Relu];
        let patterns = detect_patterns(&ops);

        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].0, FusionPattern::ElementwiseChain);
    }

    #[test]
    fn test_pattern_speedup() {
        assert!(FusionPattern::ElementwiseChain.estimated_speedup() > 1.5);
        assert!(FusionPattern::MatMulBiasRelu.estimated_speedup() > 1.0);
    }

    #[test]
    fn test_detect_add_relu() {
        let ops = vec![OpType::Add, OpType::Relu];
        let patterns = detect_patterns(&ops);

        // The length-2 Add+ReLU match runs before the elementwise-chain scan,
        // so this sequence is reported as AddRelu, not ElementwiseChain.
        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].0, FusionPattern::AddRelu);
    }

    #[test]
    fn test_pattern_display() {
        assert_eq!(format!("{}", FusionPattern::MatMulBiasRelu), "MatMul+Bias+ReLU");
        assert_eq!(format!("{}", FusionPattern::Softmax), "Softmax");
    }
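
    // Additional check sketched in for the greedy scan: a mixed sequence should
    // yield multiple non-overlapping patterns, matched left to right.
    #[test]
    fn test_detect_mixed_sequence() {
        let ops = vec![
            OpType::Conv,
            OpType::BatchNorm,
            OpType::Relu,
            OpType::MatMul,
            OpType::Add,
            OpType::Softmax,
        ];
        let patterns = detect_patterns(&ops);

        assert_eq!(
            patterns,
            vec![
                (FusionPattern::ConvBatchNormRelu, 0, 3),
                (FusionPattern::MatMulBias, 3, 5),
            ]
        );
    }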
}