Skip to main content

morok_ir/rewrite/
engine.rs

1//! Graph rewrite engine implementation.
2//!
3//! # Algorithm
4//!
5//! Stack-based 3-stage DFS traversal with waitlist for dependency resolution:
6//!
7//! - **Stage 0 (PushChildren)**: Apply `bpm` patterns (if present), then push children
8//!   - `bpm` patterns see **ORIGINAL** children
9//!   - Used for bottom-up patterns that need to transform before descent
10//!
11//! - **Stage 1 (ApplyPatterns)**: Reconstruct with optimized children, then apply `pm` patterns
12//!   - `pm` patterns see **OPTIMIZED** children
13//!   - This is the default mode - patterns run after children are processed
14//!
15//! - **Stage 2 (Link)**: Link original node to final result
16//!
17//! # API
18//!
//! - `graph_rewrite(pm, root, ctx)` - Default: patterns see optimized children (Stage 1)
//! - `graph_rewrite_bottom_up(bpm, root, ctx)` - Patterns see original children (Stage 0)
//! - `graph_rewrite_with_bpm(pm, bpm, root, ctx)` - Both: `bpm` runs in Stage 0, `pm` in Stage 1
21//!
22//! # Pattern Context
23//!
24//! Context is passed at rewrite-time through `graph_rewrite()`, not captured in
25//! closures. This provides compile-time type safety without `Rc<RefCell<>>`
26//! boilerplate.
27//!
28//! ## Example
29//!
30//! ```ignore
31//! use morok_ir::pattern::SimplifiedPatternMatcher;
32//!
33//! // Create context
34//! let mut ctx = KernelContext::new();
35//!
36//! // Create matcher using the patterns! macro
37//! let matcher = patterns! {
38//!     Add(x, @zero) ~> |x| x.clone(),
39//!     Mul(x, @one) ~> |x| x.clone(),
40//! };
41//!
42//! // Pass context at rewrite time - patterns see OPTIMIZED children
43//! let result = graph_rewrite(&matcher, root, &mut ctx);
44//! ```
45//!
46//! Patterns that don't need context use `()` as the context type:
47//!
48//! ```ignore
49//! let matcher = patterns! {
50//!     Add(x, @zero) ~> |x| x.clone(),
51//! };
52//! let result = graph_rewrite(&matcher, root, &mut ());
53//! ```
54
55use crate::{UOp, UOpKey};
56use std::collections::{HashMap, HashSet};
57use std::sync::Arc;
58
59use crate::pattern::{Matcher, RewriteResult};
60
/// Upper bound on the rewrite stack length.
///
/// A stack longer than this is treated as a runaway (infinite) rewrite and aborts with a panic.
const REWRITE_STACK_LIMIT: usize = 500 * 1_000;
63
64/// Stack entry for the 3-stage rewrite algorithm.
65///
66/// - `n`: the original node (used as key in `replace` dict)
67/// - `stage`: 0 (PushChildren), 1 (ApplyPatterns), or 2 (Link)
68/// - `new_n`: the working copy (may differ from `n` after bpm rewrites or reconstruction)
69#[derive(Clone)]
70struct Entry {
71    n: Arc<UOp>,
72    stage: u8,
73    new_n: Arc<UOp>,
74}
75
76/// Internal rewrite engine.
77///
78/// Generic over matcher types and context type for compile-time type-safe matching.
79/// Supports separate `pm` (top-down) and `bpm` (bottom-up) matchers.
80struct RewriteEngine<'a, PM, BPM, C>
81where
82    PM: Matcher<C>,
83    BPM: Matcher<C>,
84{
85    /// Top-down pattern matcher: applied in Stage 1 (ApplyPatterns).
86    /// Patterns see OPTIMIZED children.
87    pm: Option<&'a PM>,
88
89    /// Bottom-up pattern matcher: applied in Stage 0 (PushChildren).
90    /// Patterns see ORIGINAL children.
91    bpm: Option<&'a BPM>,
92
93    /// Mutable reference to context passed through to patterns.
94    ctx: &'a mut C,
95
96    /// Results cache: maps original node → optimized result.
97    replace: HashMap<UOpKey, Arc<UOp>>,
98
99    /// BPM result cache: prevents re-running pattern matching on nodes already seen.
100    bpm_cache: HashMap<UOpKey, Option<Arc<UOp>>>,
101}
102
103impl<'a, PM, BPM, C> RewriteEngine<'a, PM, BPM, C>
104where
105    PM: Matcher<C>,
106    BPM: Matcher<C>,
107{
108    fn new(pm: Option<&'a PM>, bpm: Option<&'a BPM>, ctx: &'a mut C) -> Self {
109        Self { pm, bpm, ctx, replace: HashMap::new(), bpm_cache: HashMap::new() }
110    }
111
112    /// Single-shot top-down pattern application.
113    /// No cache needed: pm_rewrite is called at most once per UOp due to the
114    /// replace dict check in the main loop.
115    #[inline]
116    fn pm_rewrite(&mut self, x: &Arc<UOp>) -> Option<Arc<UOp>> {
117        let pm = self.pm.as_ref()?;
118        match pm.rewrite(x, self.ctx) {
119            RewriteResult::Rewritten(new_node) => {
120                debug_assert!(
121                    !Arc::ptr_eq(&new_node, x),
122                    "PM pattern returned Rewritten but produced the same node (id={}). \
123                     This causes infinite loops. Return NoMatch instead.\nOp: {:?}",
124                    x.id,
125                    x.op().as_ref(),
126                );
127                Some(new_node)
128            }
129            RewriteResult::Gate(_) | RewriteResult::NoMatch => None,
130        }
131    }
132
133    /// Cached bottom-up pattern application.
134    /// Cache prevents re-running patterns on nodes already seen during fixed-point.
135    /// Gate results are NOT cached — Gate is an exception that bypasses the cache.
136    #[inline]
137    fn cached_bpm_rewrite(&mut self, x: &Arc<UOp>) -> Result<Option<Arc<UOp>>, Arc<UOp>> {
138        let key = UOpKey(x.clone());
139        if let Some(cached) = self.bpm_cache.get(&key) {
140            return match cached {
141                Some(node) => Ok(Some(node.clone())),
142                None => Ok(None),
143            };
144        }
145        let bpm = self.bpm.as_ref().unwrap();
146        match bpm.rewrite(x, self.ctx) {
147            RewriteResult::Rewritten(new_node) => {
148                debug_assert!(
149                    !Arc::ptr_eq(&new_node, x),
150                    "BPM pattern returned Rewritten but produced the same node (id={}). \
151                     This causes infinite loops. Return NoMatch instead.\nOp: {:?}",
152                    x.id,
153                    x.op().as_ref(),
154                );
155                self.bpm_cache.insert(key, Some(new_node.clone()));
156                Ok(Some(new_node))
157            }
158            RewriteResult::Gate(gate_node) => Err(gate_node),
159            RewriteResult::NoMatch => {
160                self.bpm_cache.insert(key, None);
161                Ok(None)
162            }
163        }
164    }
165
166    /// Record a result in the replace map, with provenance tracking.
167    #[inline]
168    fn record_replace(&mut self, original: &Arc<UOp>, result: Arc<UOp>) {
169        if !Arc::ptr_eq(original, &result) {
170            use crate::provenance::{PROVENANCE_TRACKER, PassName};
171            PROVENANCE_TRACKER.with(|tracker| {
172                tracker.borrow_mut().record_transform(result.id, original.id, PassName::RewritePattern);
173            });
174        }
175        self.replace.insert(UOpKey(original.clone()), result);
176    }
177
178    /// Main rewrite method — stack-based 3-stage traversal.
179    #[allow(clippy::mutable_key_type)]
180    fn rewrite(&mut self, root: Arc<UOp>) -> Arc<UOp> {
181        let mut stack: Vec<Entry> = vec![Entry { n: root.clone(), stage: 0, new_n: root.clone() }];
182
183        // All UOps either on the stack or in self.replace — don't have to be placed again.
184        let mut on_stack: HashSet<UOpKey> = HashSet::new();
185        on_stack.insert(UOpKey(root.clone()));
186
187        // UOps waiting on a dependency to be in self.replace.
188        let mut waitlist: HashMap<UOpKey, Vec<Entry>> = HashMap::new();
189
190        while let Some(Entry { n, stage, new_n }) = stack.pop() {
191            if stack.len() > REWRITE_STACK_LIMIT {
192                panic!(
193                    "infinite loop in graph_rewrite (stack too big: {}). results cached: {}",
194                    stack.len(),
195                    self.replace.len(),
196                );
197            }
198
199            let n_key = UOpKey(n.clone());
200
201            if self.replace.contains_key(&n_key) {
202                continue;
203            }
204
205            if stage == 0 {
206                // Stage 0: PushChildren
207                let mut working = new_n;
208
209                if self.bpm.is_some() {
210                    // Apply bpm rewrite rules until a fixed point is reached.
211                    let mut seen: HashSet<UOpKey> = HashSet::new();
212                    let mut gated = false;
213                    loop {
214                        let working_key = UOpKey(working.clone());
215                        if seen.contains(&working_key) {
216                            panic!(
217                                "infinite loop in fixed_point_rewrite: node {:?} (id={}) seen twice",
218                                working.op().as_ref(),
219                                working.id
220                            );
221                        }
222                        seen.insert(working_key);
223                        match self.cached_bpm_rewrite(&working) {
224                            Ok(Some(rewritten)) => {
225                                working = rewritten;
226                            }
227                            Ok(None) => break,
228                            Err(gate_node) => {
229                                // Gate: done with this node, don't descend into children
230                                self.record_replace(&n, gate_node);
231                                if let Some(entries) = waitlist.remove(&n_key) {
232                                    stack.extend(entries);
233                                }
234                                gated = true;
235                                break;
236                            }
237                        }
238                    }
239                    if gated {
240                        continue;
241                    }
242                }
243
244                stack.push(Entry { n: n.clone(), stage: 1, new_n: working.clone() });
245
246                let sources = working.op().sources();
247                for child in sources.iter().rev() {
248                    let child_key = UOpKey(child.clone());
249                    if on_stack.contains(&child_key) {
250                        continue;
251                    }
252                    stack.push(Entry { n: child.clone(), stage: 0, new_n: child.clone() });
253                    on_stack.insert(child_key);
254                }
255            } else if stage == 1 {
256                // Stage 1: ApplyPatterns
257                let sources = new_n.op().sources();
258
259                let mut tmp: Vec<Arc<UOp>> = Vec::with_capacity(sources.len());
260                let mut waiting = false;
261
262                for src in &sources {
263                    let src_key = UOpKey(src.clone());
264                    if let Some(rx) = self.replace.get(&src_key) {
265                        tmp.push(rx.clone());
266                    } else {
267                        // Source not ready: register in waitlist
268                        waitlist.entry(src_key).or_default().push(Entry {
269                            n: n.clone(),
270                            stage: 1,
271                            new_n: new_n.clone(),
272                        });
273                        waiting = true;
274                        break;
275                    }
276                }
277
278                if waiting {
279                    continue;
280                }
281
282                // All sources ready — reconstruct if any changed
283                let sources_changed = tmp.iter().zip(sources.iter()).any(|(a, b)| !Arc::ptr_eq(a, b));
284
285                // Hash consing may collapse reconstruction to same node even when
286                // sources logically changed. Detect this and treat as unchanged.
287                let node = if sources_changed {
288                    let reconstructed = new_n.with_sources(tmp);
289                    if Arc::ptr_eq(&reconstructed, &new_n) { new_n.clone() } else { reconstructed }
290                } else {
291                    new_n.clone()
292                };
293
294                if Arc::ptr_eq(&node, &new_n) {
295                    // Sources effectively unchanged: try pm rewrite
296                    if let Some(new_src_n) = self.pm_rewrite(&new_n) {
297                        stack.push(Entry { n: n.clone(), stage: 2, new_n: new_src_n.clone() });
298                        stack.push(Entry { n: new_src_n.clone(), stage: 0, new_n: new_src_n });
299                    } else {
300                        // No pm match — done with this node
301                        self.record_replace(&n, new_n);
302                        if let Some(entries) = waitlist.remove(&n_key) {
303                            stack.extend(entries);
304                        }
305                    }
306                } else {
307                    // Reconstruction produced a new node — process it, then link back
308                    stack.push(Entry { n: n.clone(), stage: 2, new_n: node.clone() });
309                    stack.push(Entry { n: node.clone(), stage: 0, new_n: node });
310                }
311            } else {
312                // Stage 2: Link
313                let new_n_key = UOpKey(new_n.clone());
314
315                if let Some(replaced_new_n) = self.replace.get(&new_n_key).cloned() {
316                    self.record_replace(&n, replaced_new_n);
317                    if let Some(entries) = waitlist.remove(&n_key) {
318                        stack.extend(entries);
319                    }
320                } else {
321                    // Not ready: register in waitlist
322                    waitlist.entry(new_n_key).or_default().push(Entry { n, stage: 2, new_n });
323                }
324            }
325        }
326
327        self.replace.get(&UOpKey(root.clone())).cloned().unwrap_or(root)
328    }
329}
330
/// Placeholder matcher for an absent `pm` or `bpm` slot in generic code.
///
/// A stateless unit type. Its `Matcher` implementation never matches.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoMatcher;
333
334impl<C> Matcher<C> for NoMatcher {
335    fn rewrite(&self, _node: &Arc<UOp>, _ctx: &mut C) -> RewriteResult {
336        RewriteResult::NoMatch
337    }
338}
339
340/// Apply graph rewriting. Patterns see **OPTIMIZED** children (Stage 1).
341pub fn graph_rewrite<M: Matcher<C>, C>(matcher: &M, root: Arc<UOp>, ctx: &mut C) -> Arc<UOp> {
342    RewriteEngine::new(Some(matcher), None::<&NoMatcher>, ctx).rewrite(root)
343}
344
345/// Apply graph rewriting with bottom-up pattern application.
346/// Patterns see **ORIGINAL** children (Stage 0).
347pub fn graph_rewrite_bottom_up<M: Matcher<C>, C>(matcher: &M, root: Arc<UOp>, ctx: &mut C) -> Arc<UOp> {
348    RewriteEngine::new(None::<&NoMatcher>, Some(matcher), ctx).rewrite(root)
349}
350
351/// Apply graph rewriting with both top-down and bottom-up patterns.
352/// - `bpm` patterns see ORIGINAL children (Stage 0)
353/// - `pm` patterns see OPTIMIZED children (Stage 1)
354pub fn graph_rewrite_with_bpm<PM, BPM, C>(pm: &PM, bpm: &BPM, root: Arc<UOp>, ctx: &mut C) -> Arc<UOp>
355where
356    PM: Matcher<C>,
357    BPM: Matcher<C>,
358{
359    RewriteEngine::new(Some(pm), Some(bpm), ctx).rewrite(root)
360}