morok_macros/lib.rs
//! Proc-macros for morok.
//!
//! This crate provides:
//! - `#[derive(PatternEnum)]` for generating pattern-matching infrastructure from an `Op` enum
//! - `patterns!` macro for declarative pattern rewrite rules
//! - `cached_patterns!` macro, a variant of `patterns!` that builds its matcher once and caches it
6
7use proc_macro::TokenStream;
8use syn::{DeriveInput, parse_macro_input};
9
10mod pattern_enum;
11mod patterns;
12
13/// Derive macro for generating pattern matching infrastructure from an Op enum.
14///
15/// This macro analyzes your `Op` enum and generates:
16/// - `OpKey` enum for O(1) pattern dispatch
17/// - `OpKey::from_op()` method to extract the key from an `Op`
18/// - `pattern_metadata` module with variant information
19///
20/// # Usage
21///
22/// ```ignore
23/// #[derive(PatternEnum)]
24/// #[pattern(grouped = [Unary, Binary, Ternary])]
25/// pub enum Op {
26/// Const(ConstValue),
27/// Unary(UnaryOp, Arc<UOp>),
28/// Binary(BinaryOp, Arc<UOp>, Arc<UOp>),
29/// #[pattern(skip)]
30/// Invalid,
31/// }
32/// ```
33///
34/// # Attributes
35///
36/// ## Enum-level
37///
38/// - `#[pattern(grouped = [Variant1, Variant2, ...])]` - Marks variants where the first
39/// field is a sub-enum discriminant. For example, `Binary(BinaryOp, ...)` has `BinaryOp`
40/// as a sub-discriminant, so `OpKey::Binary(BinaryOp::Add)` differs from `OpKey::Binary(BinaryOp::Mul)`.
41///
42/// ## Variant-level
43///
44/// - `#[pattern(skip)]` - Skip pattern generation for this variant (e.g., `Invalid`).
45///
46/// # Field Type Detection
47///
48/// The macro automatically classifies field types:
49/// - `Arc<UOp>` → child operand (fixed arity)
50/// - `SmallVec<[Arc<UOp>; N]>` or `Vec<Arc<UOp>>` → variadic children
51/// - `Option<Arc<UOp>>` → optional child
52/// - Other types → filter/metadata (e.g., `DType`, `DeviceSpec`)
53///
54/// # Generated Items
55///
56/// ```ignore
57/// mod pattern_derived {
58/// // Discriminant enum for O(1) dispatch
59/// pub enum OpKey {
60/// Const,
61/// Unary(UnaryOp),
62/// Binary(BinaryOp),
63/// // ...
64/// }
65///
66/// impl OpKey {
67/// pub fn from_op(op: &Op) -> Self { ... }
68/// }
69///
70/// pub mod pattern_metadata {
71/// pub const BINARY_OPS: &[&str] = &["Add", "Mul", ...];
72/// // ...
73/// }
74/// }
75/// ```
76#[proc_macro_derive(PatternEnum, attributes(pattern))]
77pub fn derive_pattern_enum(input: TokenStream) -> TokenStream {
78 let input = parse_macro_input!(input as DeriveInput);
79 match pattern_enum::generate(&input) {
80 Ok(tokens) => tokens.into(),
81 Err(e) => e.to_compile_error().into(),
82 }
83}
84
85/// Proc-macro for declarative pattern rewrite rules.
86///
87/// Generates a [`SimplifiedPatternMatcher`] from a list of pattern rewrite rules.
88/// Patterns are compiled to efficient Rust code with O(1) dispatch via `OpKey`.
89///
90/// # Syntax Overview
91///
92/// ```text
93/// patterns! {
94/// // Basic rule: pattern ~> rewrite (or => for fallible)
95/// Add(x, @zero) ~> x,
96///
97/// // With guard clause
98/// Mul(x, y) if is_power_of_two(y) => { ... },
99///
100/// // For-loop to apply same pattern to multiple ops
101/// for op in binary [Add, Mul, Sub] {
102/// op(x, @zero) ~> x,
103/// }
104/// }
105/// ```
106///
107/// # Arrow Types
108///
109/// - `~>` **Infallible**: Closure returns `Arc<UOp>` directly
110/// - `=>` **Fallible**: Closure returns `Option<Arc<UOp>>`
111///
112/// # Pattern Syntax
113///
114/// ## Operation Patterns
115///
116/// ```text
117/// Add(x, y) // Tuple-style: match by position
118/// Cast { src, dtype } // Struct-style: match by field name
119/// ```
120///
121/// ## Special Constants
122///
123/// - `@zero` - Matches constant zero (any numeric type)
124/// - `@one` - Matches constant one (any numeric type)
125/// - `@const(cv)` - Matches any constant, binds value to `cv: &ConstValue`
126/// - `_c@const(cv)` - Underscore prefix: don't bind the UOp, only the value
127///
128/// ## Duplicate Variables (Auto ptr_eq)
129///
130/// Same variable name appearing multiple times generates `Arc::ptr_eq` checks:
131///
132/// ```text
133/// Add(x, x) ~> ... // Matches when both children are the same node
134/// Where(x, x, x) ~> ... // All three must be ptr_eq
135/// ```
136///
137/// ## Commutative Matching
138///
139/// Square brackets enable commutative matching (tries both orderings):
140///
141/// ```text
142/// Add[x, @zero] ~> x // Matches Add(x, 0) or Add(0, x)
143/// ```
144///
145/// ## Alternative Patterns
146///
147/// Match any of several patterns:
148///
149/// ```text
150/// (Add | Sub)(x, @zero) ~> x // Matches Add(x, 0) or Sub(x, 0)
151/// ```
152///
153/// ## Binding Patterns
154///
155/// Bind a name to a subpattern:
156///
157/// ```text
158/// result@Add(x, y) => { ... use result, x, y ... }
159/// ```
160///
161/// # For-Loops
162///
163/// Apply the same pattern template to multiple operations:
164///
165/// ```text
166/// for op in unary [Neg, Not, Sqrt] {
167/// op(x) if is_const(x) => { fold_unary(op, x) }
168/// }
169///
170/// for op in binary [Add, Mul, Sub] {
171/// op(x, @zero) ~> x,
172/// }
173/// ```
174///
175/// # Context Types
176///
177/// Declare a context type to pass mutable state through patterns:
178///
179/// ```text
180/// patterns! {
181/// @context MyContext;
182///
183/// Add(x, y) => |ctx, x, y| {
184/// ctx.record_match();
185/// Some(x.clone())
186/// }
187/// }
188/// ```
189///
190/// # Generated Code
191///
192/// This macro generates a `SimplifiedPatternMatcher` with:
193/// - Compile-time validation of all operation names
194/// - O(1) dispatch via OpKey hashmap
195/// - Inline pattern matching (no runtime pattern interpretation)
196/// - Automatic `Arc::ptr_eq` checks for duplicate variables
197///
198/// [`SimplifiedPatternMatcher`]: morok_ir::pattern::SimplifiedPatternMatcher
199#[proc_macro]
200pub fn patterns(input: TokenStream) -> TokenStream {
201 let pattern_list = parse_macro_input!(input as patterns::PatternList);
202
203 match patterns::generate_simplified_pattern_matcher(&pattern_list) {
204 Ok(tokens) => tokens.into(),
205 Err(e) => e.to_compile_error().into(),
206 }
207}
208
209/// Like `patterns!` but wraps the matcher in `LazyLock` for zero-cost reuse.
210///
211/// Returns `&'static SimplifiedPatternMatcher<C>` instead of an owned matcher.
212/// The matcher is constructed only once on first call and cached globally.
213///
214/// Use this for stateless `pm_*()` functions that are called repeatedly
215/// (e.g., once per kernel). Avoids re-constructing closures and hashmaps
216/// on every call.
217///
218/// # Example
219///
220/// ```ignore
221/// pub fn pm_render() -> &'static TypedPatternMatcher {
222/// cached_patterns! { ... }
223/// }
224/// ```
225#[proc_macro]
226pub fn cached_patterns(input: TokenStream) -> TokenStream {
227 let pattern_list = parse_macro_input!(input as patterns::PatternList);
228
229 match patterns::generate_cached_pattern_matcher(&pattern_list) {
230 Ok(tokens) => tokens.into(),
231 Err(e) => e.to_compile_error().into(),
232 }
233}