Skip to main content

morok_macros/
lib.rs

//! Proc-macros for morok.
//!
//! This crate provides:
//! - `#[derive(PatternEnum)]` for generating pattern matching infrastructure from an `Op` enum
//! - `patterns!` macro for declarative pattern rewrite rules
//! - `cached_patterns!` macro, a variant of `patterns!` that builds the matcher once and caches it
6
7use proc_macro::TokenStream;
8use syn::{DeriveInput, parse_macro_input};
9
10mod pattern_enum;
11mod patterns;
12
13/// Derive macro for generating pattern matching infrastructure from an Op enum.
14///
15/// This macro analyzes your `Op` enum and generates:
16/// - `OpKey` enum for O(1) pattern dispatch
17/// - `OpKey::from_op()` method to extract the key from an `Op`
18/// - `pattern_metadata` module with variant information
19///
20/// # Usage
21///
22/// ```ignore
23/// #[derive(PatternEnum)]
24/// #[pattern(grouped = [Unary, Binary, Ternary])]
25/// pub enum Op {
26///     Const(ConstValue),
27///     Unary(UnaryOp, Arc<UOp>),
28///     Binary(BinaryOp, Arc<UOp>, Arc<UOp>),
29///     #[pattern(skip)]
30///     Invalid,
31/// }
32/// ```
33///
34/// # Attributes
35///
36/// ## Enum-level
37///
38/// - `#[pattern(grouped = [Variant1, Variant2, ...])]` - Marks variants where the first
39///   field is a sub-enum discriminant. For example, `Binary(BinaryOp, ...)` has `BinaryOp`
40///   as a sub-discriminant, so `OpKey::Binary(BinaryOp::Add)` differs from `OpKey::Binary(BinaryOp::Mul)`.
41///
42/// ## Variant-level
43///
44/// - `#[pattern(skip)]` - Skip pattern generation for this variant (e.g., `Invalid`).
45///
46/// # Field Type Detection
47///
48/// The macro automatically classifies field types:
49/// - `Arc<UOp>` → child operand (fixed arity)
50/// - `SmallVec<[Arc<UOp>; N]>` or `Vec<Arc<UOp>>` → variadic children
51/// - `Option<Arc<UOp>>` → optional child
52/// - Other types → filter/metadata (e.g., `DType`, `DeviceSpec`)
53///
54/// # Generated Items
55///
56/// ```ignore
57/// mod pattern_derived {
58///     // Discriminant enum for O(1) dispatch
59///     pub enum OpKey {
60///         Const,
61///         Unary(UnaryOp),
62///         Binary(BinaryOp),
63///         // ...
64///     }
65///
66///     impl OpKey {
67///         pub fn from_op(op: &Op) -> Self { ... }
68///     }
69///
70///     pub mod pattern_metadata {
71///         pub const BINARY_OPS: &[&str] = &["Add", "Mul", ...];
72///         // ...
73///     }
74/// }
75/// ```
76#[proc_macro_derive(PatternEnum, attributes(pattern))]
77pub fn derive_pattern_enum(input: TokenStream) -> TokenStream {
78    let input = parse_macro_input!(input as DeriveInput);
79    match pattern_enum::generate(&input) {
80        Ok(tokens) => tokens.into(),
81        Err(e) => e.to_compile_error().into(),
82    }
83}
84
85/// Proc-macro for declarative pattern rewrite rules.
86///
87/// Generates a [`SimplifiedPatternMatcher`] from a list of pattern rewrite rules.
88/// Patterns are compiled to efficient Rust code with O(1) dispatch via `OpKey`.
89///
90/// # Syntax Overview
91///
92/// ```text
93/// patterns! {
94///     // Basic rule: pattern ~> rewrite (or => for fallible)
95///     Add(x, @zero) ~> x,
96///
97///     // With guard clause
98///     Mul(x, y) if is_power_of_two(y) => { ... },
99///
100///     // For-loop to apply same pattern to multiple ops
101///     for op in binary [Add, Mul, Sub] {
102///         op(x, @zero) ~> x,
103///     }
104/// }
105/// ```
106///
107/// # Arrow Types
108///
109/// - `~>` **Infallible**: Closure returns `Arc<UOp>` directly
110/// - `=>` **Fallible**: Closure returns `Option<Arc<UOp>>`
111///
112/// # Pattern Syntax
113///
114/// ## Operation Patterns
115///
116/// ```text
117/// Add(x, y)           // Tuple-style: match by position
118/// Cast { src, dtype } // Struct-style: match by field name
119/// ```
120///
121/// ## Special Constants
122///
123/// - `@zero` - Matches constant zero (any numeric type)
124/// - `@one` - Matches constant one (any numeric type)
125/// - `@const(cv)` - Matches any constant, binds value to `cv: &ConstValue`
126/// - `_c@const(cv)` - Underscore prefix: don't bind the UOp, only the value
127///
128/// ## Duplicate Variables (Auto ptr_eq)
129///
130/// Same variable name appearing multiple times generates `Arc::ptr_eq` checks:
131///
132/// ```text
133/// Add(x, x) ~> ...    // Matches when both children are the same node
134/// Where(x, x, x) ~> ...  // All three must be ptr_eq
135/// ```
136///
137/// ## Commutative Matching
138///
139/// Square brackets enable commutative matching (tries both orderings):
140///
141/// ```text
142/// Add[x, @zero] ~> x  // Matches Add(x, 0) or Add(0, x)
143/// ```
144///
145/// ## Alternative Patterns
146///
147/// Match any of several patterns:
148///
149/// ```text
150/// (Add | Sub)(x, @zero) ~> x  // Matches Add(x, 0) or Sub(x, 0)
151/// ```
152///
153/// ## Binding Patterns
154///
155/// Bind a name to a subpattern:
156///
157/// ```text
158/// result@Add(x, y) => { ... use result, x, y ... }
159/// ```
160///
161/// # For-Loops
162///
163/// Apply the same pattern template to multiple operations:
164///
165/// ```text
166/// for op in unary [Neg, Not, Sqrt] {
167///     op(x) if is_const(x) => { fold_unary(op, x) }
168/// }
169///
170/// for op in binary [Add, Mul, Sub] {
171///     op(x, @zero) ~> x,
172/// }
173/// ```
174///
175/// # Context Types
176///
177/// Declare a context type to pass mutable state through patterns:
178///
179/// ```text
180/// patterns! {
181///     @context MyContext;
182///
183///     Add(x, y) => |ctx, x, y| {
184///         ctx.record_match();
185///         Some(x.clone())
186///     }
187/// }
188/// ```
189///
190/// # Generated Code
191///
192/// This macro generates a `SimplifiedPatternMatcher` with:
193/// - Compile-time validation of all operation names
194/// - O(1) dispatch via OpKey hashmap
195/// - Inline pattern matching (no runtime pattern interpretation)
196/// - Automatic `Arc::ptr_eq` checks for duplicate variables
197///
198/// [`SimplifiedPatternMatcher`]: morok_ir::pattern::SimplifiedPatternMatcher
199#[proc_macro]
200pub fn patterns(input: TokenStream) -> TokenStream {
201    let pattern_list = parse_macro_input!(input as patterns::PatternList);
202
203    match patterns::generate_simplified_pattern_matcher(&pattern_list) {
204        Ok(tokens) => tokens.into(),
205        Err(e) => e.to_compile_error().into(),
206    }
207}
208
209/// Like `patterns!` but wraps the matcher in `LazyLock` for zero-cost reuse.
210///
211/// Returns `&'static SimplifiedPatternMatcher<C>` instead of an owned matcher.
212/// The matcher is constructed only once on first call and cached globally.
213///
214/// Use this for stateless `pm_*()` functions that are called repeatedly
215/// (e.g., once per kernel). Avoids re-constructing closures and hashmaps
216/// on every call.
217///
218/// # Example
219///
220/// ```ignore
221/// pub fn pm_render() -> &'static TypedPatternMatcher {
222///     cached_patterns! { ... }
223/// }
224/// ```
225#[proc_macro]
226pub fn cached_patterns(input: TokenStream) -> TokenStream {
227    let pattern_list = parse_macro_input!(input as patterns::PatternList);
228
229    match patterns::generate_cached_pattern_matcher(&pattern_list) {
230        Ok(tokens) => tokens.into(),
231        Err(e) => e.to_compile_error().into(),
232    }
233}