Skip to main content

morok_ir/pattern/
simplified.rs

//! High-performance pattern matcher with OpKey-based O(1) dispatch.
//!
//! # Architecture
//!
//! `SimplifiedPatternMatcher` uses a two-tier dispatch strategy:
//!
//! 1. **Indexed patterns**: Stored in a `HashMap<OpKey, Vec<Closure>>` for O(1) lookup
//! 2. **Wildcard patterns**: Tried for every op, after the indexed patterns for its
//!    `OpKey` (if any) have all failed to match
//!
//! The `patterns!` macro generates closures that use native Rust `match` expressions,
//! avoiding runtime pattern interpretation overhead.
//!
//! # Performance
//!
//! - O(1) dispatch to relevant patterns via `OpKey`
//! - Only indexed patterns matching the input's `OpKey` are tried
//! - Wildcard patterns act as a fallback whenever no indexed pattern matched
//! - 5-10x faster than linear pattern scanning
//!
//! # Usage with patterns! macro
//!
//! ```ignore
//! use morok_macros::patterns;
//!
//! let matcher = patterns! {
//!     Add(x, @zero) ~> x,              // Indexed under OpKey::Binary(BinaryOp::Add)
//!     Mul(x, @one) ~> x,               // Indexed under OpKey::Binary(BinaryOp::Mul)
//!     x if is_const(x) => fold(x),     // Wildcard - tried for all ops
//! };
//! ```
31
32use std::collections::HashMap;
33use std::sync::Arc;
34
35use crate::UOp;
36use crate::op::pattern_derived::OpKey;
37
38use super::RewriteResult;
39
40/// Closure type for pattern matching + rewriting.
41///
42/// Takes a UOp and mutable context, returns a RewriteResult.
43/// Uses `Arc` instead of `Box` to enable `Clone` on `SimplifiedPatternMatcher`,
44/// which is needed for caching combined matchers via `LazyLock`.
45pub type PatternClosure<C> = Arc<dyn Fn(&Arc<UOp>, &mut C) -> RewriteResult + Send + Sync>;
46
47/// High-performance pattern matcher with O(1) OpKey-based dispatch.
48///
49/// # Design
50///
51/// Instead of a single list of patterns that must be linearly scanned,
52/// patterns are indexed by their `OpKey` in a `HashMap`. When matching:
53///
54/// 1. Extract `OpKey` from the input UOp
55/// 2. Look up patterns for that key (O(1) HashMap lookup)
56/// 3. Try only those patterns (typically 1-3 per key)
57/// 4. Fall back to wildcard patterns if no match
58///
59/// # Type Parameter
60///
61/// - `C`: Context type passed to all pattern closures. Use `()` for stateless matching.
62///
63/// # Example
64///
65/// Typically used via the `patterns!` macro:
66///
67/// ```ignore
68/// use morok_macros::patterns;
69///
70/// let matcher = patterns! {
71///     Add(x, @zero) ~> x,
72///     Mul(x, @one) ~> x,
73/// };
74///
75/// // Use with graph_rewrite
76/// let result = graph_rewrite(&ast, &matcher, &mut ());
77/// ```
78///
79/// Manual construction (rarely needed):
80///
81/// ```ignore
82/// let mut matcher = SimplifiedPatternMatcher::<()>::new();
83/// matcher.add(
84///     &[OpKey::Binary(BinaryOp::Add)],
85///     |uop, _ctx| {
86///         let Op::Binary(BinaryOp::Add, left, right) = uop.op() else {
87///             return RewriteResult::NoMatch;
88///         };
89///         if is_zero(right) { RewriteResult::Rewritten(left.clone()) }
90///         else { RewriteResult::NoMatch }
91///     }
92/// );
93/// ```
94pub struct SimplifiedPatternMatcher<C = ()> {
95    /// Patterns indexed by OpKey - tried first for O(1) dispatch
96    indexed: HashMap<OpKey, Vec<PatternClosure<C>>>,
97    /// Wildcard patterns - tried after indexed patterns
98    wildcards: Vec<PatternClosure<C>>,
99}
100
101impl<C> SimplifiedPatternMatcher<C> {
102    /// Create a new empty pattern matcher.
103    pub fn new() -> Self {
104        Self { indexed: HashMap::new(), wildcards: Vec::new() }
105    }
106
107    /// Add pattern for specific OpKey(s).
108    ///
109    /// If `keys` is empty, the pattern is treated as a wildcard and will be
110    /// tried for every UOp after all indexed patterns have been tried.
111    pub fn add<F>(&mut self, keys: &[OpKey], closure: F)
112    where
113        F: Fn(&Arc<UOp>, &mut C) -> RewriteResult + Send + Sync + 'static,
114    {
115        if keys.is_empty() {
116            // No keys = wildcard pattern
117            self.wildcards.push(Arc::new(closure));
118        } else if keys.len() == 1 {
119            // Single key - store directly
120            self.indexed.entry(keys[0].clone()).or_default().push(Arc::new(closure));
121        } else {
122            // Multiple keys - share the closure via Arc clone
123            let shared: PatternClosure<C> = Arc::new(closure);
124            for key in keys {
125                self.indexed.entry(key.clone()).or_default().push(Arc::clone(&shared));
126            }
127        }
128    }
129
130    /// Add wildcard pattern (matches any op).
131    ///
132    /// Wildcard patterns are tried after all indexed patterns have been tried.
133    pub fn add_wildcard<F>(&mut self, closure: F)
134    where
135        F: Fn(&Arc<UOp>, &mut C) -> RewriteResult + Send + Sync + 'static,
136    {
137        self.wildcards.push(Arc::new(closure));
138    }
139
140    /// Number of registered patterns.
141    pub fn len(&self) -> usize {
142        self.indexed.values().map(|v| v.len()).sum::<usize>() + self.wildcards.len()
143    }
144
145    /// Check if no patterns are registered.
146    pub fn is_empty(&self) -> bool {
147        self.indexed.is_empty() && self.wildcards.is_empty()
148    }
149
150    /// Number of wildcard patterns (tried for every op).
151    pub fn wildcard_count(&self) -> usize {
152        self.wildcards.len()
153    }
154
155    /// Number of indexed buckets (unique OpKeys with patterns).
156    pub fn indexed_count(&self) -> usize {
157        self.indexed.len()
158    }
159
160    /// Attempt to rewrite a UOp using registered patterns.
161    ///
162    /// This is an inherent method that provides the same functionality as
163    /// `Matcher::rewrite()` without requiring the trait to be in scope.
164    ///
165    /// # Tracing
166    ///
167    /// Enable debug-level tracing to see pattern matching activity:
168    /// ```bash
169    /// RUST_LOG=morok_ir::pattern=debug cargo run
170    /// ```
171    pub fn rewrite(&self, uop: &Arc<UOp>, ctx: &mut C) -> RewriteResult {
172        let key = OpKey::from_op(uop.op());
173
174        // Try patterns indexed by this OpKey
175        if let Some(patterns) = self.indexed.get(&key) {
176            let pattern_count = patterns.len();
177            tracing::trace!(op_key = ?key, pattern_count, "trying indexed patterns");
178
179            for (idx, closure) in patterns.iter().enumerate() {
180                let result = closure(uop, ctx);
181                if !matches!(result, RewriteResult::NoMatch) {
182                    tracing::debug!(op_key = ?key, pattern_idx = idx, "pattern matched");
183                    return result;
184                }
185            }
186        }
187
188        // Try wildcard patterns
189        if !self.wildcards.is_empty() {
190            tracing::trace!(wildcard_count = self.wildcards.len(), "trying wildcard patterns");
191
192            for (idx, closure) in self.wildcards.iter().enumerate() {
193                let result = closure(uop, ctx);
194                if !matches!(result, RewriteResult::NoMatch) {
195                    tracing::debug!(wildcard_idx = idx, "wildcard pattern matched");
196                    return result;
197                }
198            }
199        }
200
201        RewriteResult::NoMatch
202    }
203}
204
205impl<C> Clone for SimplifiedPatternMatcher<C> {
206    fn clone(&self) -> Self {
207        Self { indexed: self.indexed.clone(), wildcards: self.wildcards.clone() }
208    }
209}
210
211impl<C> Default for SimplifiedPatternMatcher<C> {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
217impl SimplifiedPatternMatcher<()> {
218    /// Lift a context-free matcher into any context type.
219    ///
220    /// Since `()` patterns ignore the context parameter, they can safely run
221    /// under any `D`. Each closure is re-wrapped to discard `&mut D` and pass
222    /// `&mut ()` to the original. This enables combining context-free matchers
223    /// with context-dependent ones via `+`:
224    ///
225    /// ```ignore
226    /// let mega = symbolic().with_context::<PcontigConfig>()
227    ///     + buffer_removal_with_pcontig(); // TypedPatternMatcher<PcontigConfig>
228    /// ```
229    pub fn with_context<D: 'static + Send + Sync>(&self) -> SimplifiedPatternMatcher<D> {
230        let mut result = SimplifiedPatternMatcher::<D>::new();
231        for (key, closures) in &self.indexed {
232            for closure in closures {
233                let closure = Arc::clone(closure);
234                result
235                    .indexed
236                    .entry(key.clone())
237                    .or_default()
238                    .push(Arc::new(move |uop: &Arc<UOp>, _ctx: &mut D| closure(uop, &mut ())));
239            }
240        }
241        for closure in &self.wildcards {
242            let closure = Arc::clone(closure);
243            result.wildcards.push(Arc::new(move |uop: &Arc<UOp>, _ctx: &mut D| closure(uop, &mut ())));
244        }
245        result
246    }
247}
248
249// Implement Matcher trait for graph_rewrite compatibility
250impl<C> super::Matcher<C> for SimplifiedPatternMatcher<C> {
251    fn rewrite(&self, uop: &Arc<UOp>, ctx: &mut C) -> RewriteResult {
252        // Delegate to inherent method
253        SimplifiedPatternMatcher::rewrite(self, uop, ctx)
254    }
255}
256
257// Implement Add<Self> for composition (matcher1 + matcher2)
258impl<C> std::ops::Add for SimplifiedPatternMatcher<C> {
259    type Output = Self;
260
261    /// Combine two matchers. Patterns from `rhs` are appended.
262    fn add(mut self, rhs: Self) -> Self::Output {
263        // Merge indexed patterns
264        for (key, patterns) in rhs.indexed {
265            self.indexed.entry(key).or_default().extend(patterns);
266        }
267        // Merge wildcards
268        self.wildcards.extend(rhs.wildcards);
269        self
270    }
271}
272
273// Implement Add for references — clones both sides then combines.
274// Enables `pm_a() + pm_b()` when both return `&'static TypedPatternMatcher`.
275impl<C> std::ops::Add for &SimplifiedPatternMatcher<C> {
276    type Output = SimplifiedPatternMatcher<C>;
277
278    fn add(self, rhs: Self) -> Self::Output {
279        self.clone() + rhs.clone()
280    }
281}
282
283impl<C> std::ops::Add<&SimplifiedPatternMatcher<C>> for SimplifiedPatternMatcher<C> {
284    type Output = SimplifiedPatternMatcher<C>;
285
286    fn add(self, rhs: &SimplifiedPatternMatcher<C>) -> Self::Output {
287        self + rhs.clone()
288    }
289}
290
291impl<C> std::ops::Add<SimplifiedPatternMatcher<C>> for &SimplifiedPatternMatcher<C> {
292    type Output = SimplifiedPatternMatcher<C>;
293
294    fn add(self, rhs: SimplifiedPatternMatcher<C>) -> Self::Output {
295        self.clone() + rhs
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BinaryOp;
    use crate::{ConstValue, Op, UOp};
    use morok_dtype::DType;

    /// The context-free matcher flavour used by every test in this module.
    type UnitMatcher = SimplifiedPatternMatcher<()>;

    /// Int32 constant node holding `v`.
    fn const_int(v: i64) -> Arc<UOp> {
        UOp::const_(DType::Int32, ConstValue::Int(v))
    }

    /// Int32 binary node, built directly through `UOp::new` for test purposes.
    fn binary(op: BinaryOp, lhs: Arc<UOp>, rhs: Arc<UOp>) -> Arc<UOp> {
        UOp::new(Op::Binary(op, lhs, rhs), DType::Int32)
    }

    /// Pattern closure that never matches anything.
    fn never_matches(_uop: &Arc<UOp>, _ctx: &mut ()) -> RewriteResult {
        RewriteResult::NoMatch
    }

    /// Whether `node` is a constant whose value is zero.
    fn is_zero_const(node: &Arc<UOp>) -> bool {
        matches!(node.op(), Op::Const(cv) if cv.0.is_zero())
    }

    /// Whether `result` is a rewrite to exactly the node `expected` (pointer identity).
    fn rewrites_to(result: &RewriteResult, expected: &Arc<UOp>) -> bool {
        matches!(result, RewriteResult::Rewritten(r) if Arc::ptr_eq(r, expected))
    }

    #[test]
    fn test_empty_matcher() {
        let matcher = UnitMatcher::new();
        assert!(matcher.is_empty());
        assert_eq!(matcher.len(), 0);
    }

    #[test]
    fn test_add_indexed_pattern() {
        let mut matcher = UnitMatcher::new();
        let keys = [OpKey::Binary(BinaryOp::Add)];

        matcher.add(&keys, never_matches);

        assert_eq!(matcher.len(), 1);
        assert!(!matcher.is_empty());
    }

    #[test]
    fn test_add_wildcard_pattern() {
        let mut matcher = UnitMatcher::new();

        matcher.add_wildcard(never_matches);

        assert_eq!(matcher.len(), 1);
        assert_eq!(matcher.wildcards.len(), 1);
    }

    #[test]
    fn test_combine_matchers() {
        let mut add_only = UnitMatcher::new();
        add_only.add(&[OpKey::Binary(BinaryOp::Add)], never_matches);

        let mut mul_only = UnitMatcher::new();
        mul_only.add(&[OpKey::Binary(BinaryOp::Mul)], never_matches);

        let combined = add_only + mul_only;
        assert_eq!(combined.len(), 2);
    }

    #[test]
    fn test_rewrite_basic() {
        let mut matcher = UnitMatcher::new();

        // Pattern: Add(x, 0) -> x, and (commutatively) Add(0, x) -> x
        matcher.add(&[OpKey::Binary(BinaryOp::Add)], |uop, _ctx| {
            let Op::Binary(BinaryOp::Add, left, right) = uop.op() else {
                return RewriteResult::NoMatch;
            };
            // The right operand is checked first, then the left one.
            if is_zero_const(right) {
                RewriteResult::Rewritten(left.clone())
            } else if is_zero_const(left) {
                RewriteResult::Rewritten(right.clone())
            } else {
                RewriteResult::NoMatch
            }
        });

        let five = const_int(5);

        // 5 + 0 -> 5
        let zero_on_right = binary(BinaryOp::Add, five.clone(), const_int(0));
        let outcome = matcher.rewrite(&zero_on_right, &mut ());
        assert!(rewrites_to(&outcome, &five));

        // 0 + 5 -> 5
        let zero_on_left = binary(BinaryOp::Add, const_int(0), five.clone());
        let outcome = matcher.rewrite(&zero_on_left, &mut ());
        assert!(rewrites_to(&outcome, &five));

        // 3 + 4 -> NoMatch
        let no_zero = binary(BinaryOp::Add, const_int(3), const_int(4));
        let outcome = matcher.rewrite(&no_zero, &mut ());
        assert!(matches!(outcome, RewriteResult::NoMatch));
    }

    #[test]
    fn test_wildcard_after_indexed() {
        let mut matcher = UnitMatcher::new();

        // The indexed pattern for Add declines every input ...
        matcher.add(&[OpKey::Binary(BinaryOp::Add)], never_matches);
        // ... while the wildcard accepts every input.
        matcher.add_wildcard(|uop, _ctx| RewriteResult::Rewritten(uop.clone()));

        let expr = binary(BinaryOp::Add, const_int(1), const_int(2));

        // The indexed bucket yields nothing, so the wildcard must fire.
        let outcome = matcher.rewrite(&expr, &mut ());
        assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    }
}