morok_ir/pattern/mod.rs
1//! Pattern matching infrastructure for UOp graphs.
2//!
3//! This module provides pattern matching using `SimplifiedPatternMatcher`, which
4//! uses closures that do inline Rust pattern matching. The `patterns!` macro
5//! generates closures with native `match` expressions for O(1) OpKey dispatch.
6
7pub mod helpers;
8pub mod simplified;
9
10use crate::UOp;
11use std::sync::Arc;
12
13// =============================================================================
14// RewriteResult - Result of pattern matching
15// =============================================================================
16
17/// Result of applying a pattern rewrite.
18#[derive(Debug, Clone)]
19pub enum RewriteResult {
20 /// Pattern didn't match or rewrite function declined to rewrite
21 NoMatch,
22 /// Pattern matched and returned a replacement UOp
23 Rewritten(Arc<UOp>),
24 /// Pattern matched and indicates bottom-up gate (Tinygrad's BottomUpGate)
25 /// This signals that children should be processed before proceeding
26 Gate(Arc<UOp>),
27}
28
29// =============================================================================
30// Pattern Exports
31// =============================================================================
32
33pub use helpers::{const_matches, is_any_const, is_neg_one, is_nonzero, is_one, is_zero, try_const};
34pub use simplified::{PatternClosure, SimplifiedPatternMatcher};
35
36/// Type alias for backwards compatibility.
37pub type TypedPatternMatcher<C = ()> = SimplifiedPatternMatcher<C>;
38
39// =============================================================================
40// Matcher Trait - Unified interface for pattern matchers
41// =============================================================================
42
43/// Trait for pattern matchers used by the rewrite engine.
44///
45/// This trait provides a unified interface for pattern matching,
46/// allowing the rewrite engine to work with different matcher implementations.
47pub trait Matcher<C> {
48 /// Attempt to rewrite a UOp using registered patterns.
49 fn rewrite(&self, uop: &Arc<UOp>, ctx: &mut C) -> RewriteResult;
50}
51
52impl<C, M: Matcher<C> + ?Sized> Matcher<C> for &M {
53 fn rewrite(&self, uop: &Arc<UOp>, ctx: &mut C) -> RewriteResult {
54 (**self).rewrite(uop, ctx)
55 }
56}