morok_ir/rewrite/engine.rs
1//! Graph rewrite engine implementation.
2//!
3//! # Algorithm
4//!
5//! Stack-based 3-stage DFS traversal with waitlist for dependency resolution:
6//!
7//! - **Stage 0 (PushChildren)**: Apply `bpm` patterns (if present), then push children
8//! - `bpm` patterns see **ORIGINAL** children
9//! - Used for bottom-up patterns that need to transform before descent
10//!
11//! - **Stage 1 (ApplyPatterns)**: Reconstruct with optimized children, then apply `pm` patterns
12//! - `pm` patterns see **OPTIMIZED** children
13//! - This is the default mode - patterns run after children are processed
14//!
15//! - **Stage 2 (Link)**: Link original node to final result
16//!
17//! # API
18//!
//! - `graph_rewrite(pm, root, ctx)` - Default: patterns see optimized children (Stage 1)
//! - `graph_rewrite_bottom_up(bpm, root, ctx)` - Patterns see original children (Stage 0)
//! - `graph_rewrite_with_bpm(pm, bpm, root, ctx)` - Both: `bpm` runs in Stage 0, then `pm` in Stage 1
21//!
22//! # Pattern Context
23//!
24//! Context is passed at rewrite-time through `graph_rewrite()`, not captured in
25//! closures. This provides compile-time type safety without `Rc<RefCell<>>`
26//! boilerplate.
27//!
28//! ## Example
29//!
30//! ```ignore
31//! use morok_ir::pattern::SimplifiedPatternMatcher;
32//!
33//! // Create context
34//! let mut ctx = KernelContext::new();
35//!
36//! // Create matcher using the patterns! macro
37//! let matcher = patterns! {
38//! Add(x, @zero) ~> |x| x.clone(),
39//! Mul(x, @one) ~> |x| x.clone(),
40//! };
41//!
42//! // Pass context at rewrite time - patterns see OPTIMIZED children
43//! let result = graph_rewrite(&matcher, root, &mut ctx);
44//! ```
45//!
46//! Patterns that don't need context use `()` as the context type:
47//!
48//! ```ignore
49//! let matcher = patterns! {
50//! Add(x, @zero) ~> |x| x.clone(),
51//! };
52//! let result = graph_rewrite(&matcher, root, &mut ());
53//! ```
54
55use crate::{UOp, UOpKey};
56use std::collections::{HashMap, HashSet};
57use std::sync::Arc;
58
59use crate::pattern::{Matcher, RewriteResult};
60
/// Upper bound on the number of pending work entries in the rewrite stack.
///
/// Growing past this many entries is taken as evidence of an infinite rewrite
/// loop, and the engine aborts with a panic.
const REWRITE_STACK_LIMIT: usize = 500 * 1_000;
63
64/// Stack entry for the 3-stage rewrite algorithm.
65///
66/// - `n`: the original node (used as key in `replace` dict)
67/// - `stage`: 0 (PushChildren), 1 (ApplyPatterns), or 2 (Link)
68/// - `new_n`: the working copy (may differ from `n` after bpm rewrites or reconstruction)
69#[derive(Clone)]
70struct Entry {
71 n: Arc<UOp>,
72 stage: u8,
73 new_n: Arc<UOp>,
74}
75
76/// Internal rewrite engine.
77///
78/// Generic over matcher types and context type for compile-time type-safe matching.
79/// Supports separate `pm` (top-down) and `bpm` (bottom-up) matchers.
80struct RewriteEngine<'a, PM, BPM, C>
81where
82 PM: Matcher<C>,
83 BPM: Matcher<C>,
84{
85 /// Top-down pattern matcher: applied in Stage 1 (ApplyPatterns).
86 /// Patterns see OPTIMIZED children.
87 pm: Option<&'a PM>,
88
89 /// Bottom-up pattern matcher: applied in Stage 0 (PushChildren).
90 /// Patterns see ORIGINAL children.
91 bpm: Option<&'a BPM>,
92
93 /// Mutable reference to context passed through to patterns.
94 ctx: &'a mut C,
95
96 /// Results cache: maps original node → optimized result.
97 replace: HashMap<UOpKey, Arc<UOp>>,
98
99 /// BPM result cache: prevents re-running pattern matching on nodes already seen.
100 bpm_cache: HashMap<UOpKey, Option<Arc<UOp>>>,
101}
102
103impl<'a, PM, BPM, C> RewriteEngine<'a, PM, BPM, C>
104where
105 PM: Matcher<C>,
106 BPM: Matcher<C>,
107{
108 fn new(pm: Option<&'a PM>, bpm: Option<&'a BPM>, ctx: &'a mut C) -> Self {
109 Self { pm, bpm, ctx, replace: HashMap::new(), bpm_cache: HashMap::new() }
110 }
111
112 /// Single-shot top-down pattern application.
113 /// No cache needed: pm_rewrite is called at most once per UOp due to the
114 /// replace dict check in the main loop.
115 #[inline]
116 fn pm_rewrite(&mut self, x: &Arc<UOp>) -> Option<Arc<UOp>> {
117 let pm = self.pm.as_ref()?;
118 match pm.rewrite(x, self.ctx) {
119 RewriteResult::Rewritten(new_node) => {
120 debug_assert!(
121 !Arc::ptr_eq(&new_node, x),
122 "PM pattern returned Rewritten but produced the same node (id={}). \
123 This causes infinite loops. Return NoMatch instead.\nOp: {:?}",
124 x.id,
125 x.op().as_ref(),
126 );
127 Some(new_node)
128 }
129 RewriteResult::Gate(_) | RewriteResult::NoMatch => None,
130 }
131 }
132
133 /// Cached bottom-up pattern application.
134 /// Cache prevents re-running patterns on nodes already seen during fixed-point.
135 /// Gate results are NOT cached — Gate is an exception that bypasses the cache.
136 #[inline]
137 fn cached_bpm_rewrite(&mut self, x: &Arc<UOp>) -> Result<Option<Arc<UOp>>, Arc<UOp>> {
138 let key = UOpKey(x.clone());
139 if let Some(cached) = self.bpm_cache.get(&key) {
140 return match cached {
141 Some(node) => Ok(Some(node.clone())),
142 None => Ok(None),
143 };
144 }
145 let bpm = self.bpm.as_ref().unwrap();
146 match bpm.rewrite(x, self.ctx) {
147 RewriteResult::Rewritten(new_node) => {
148 debug_assert!(
149 !Arc::ptr_eq(&new_node, x),
150 "BPM pattern returned Rewritten but produced the same node (id={}). \
151 This causes infinite loops. Return NoMatch instead.\nOp: {:?}",
152 x.id,
153 x.op().as_ref(),
154 );
155 self.bpm_cache.insert(key, Some(new_node.clone()));
156 Ok(Some(new_node))
157 }
158 RewriteResult::Gate(gate_node) => Err(gate_node),
159 RewriteResult::NoMatch => {
160 self.bpm_cache.insert(key, None);
161 Ok(None)
162 }
163 }
164 }
165
166 /// Record a result in the replace map, with provenance tracking.
167 #[inline]
168 fn record_replace(&mut self, original: &Arc<UOp>, result: Arc<UOp>) {
169 if !Arc::ptr_eq(original, &result) {
170 use crate::provenance::{PROVENANCE_TRACKER, PassName};
171 PROVENANCE_TRACKER.with(|tracker| {
172 tracker.borrow_mut().record_transform(result.id, original.id, PassName::RewritePattern);
173 });
174 }
175 self.replace.insert(UOpKey(original.clone()), result);
176 }
177
178 /// Main rewrite method — stack-based 3-stage traversal.
179 #[allow(clippy::mutable_key_type)]
180 fn rewrite(&mut self, root: Arc<UOp>) -> Arc<UOp> {
181 let mut stack: Vec<Entry> = vec![Entry { n: root.clone(), stage: 0, new_n: root.clone() }];
182
183 // All UOps either on the stack or in self.replace — don't have to be placed again.
184 let mut on_stack: HashSet<UOpKey> = HashSet::new();
185 on_stack.insert(UOpKey(root.clone()));
186
187 // UOps waiting on a dependency to be in self.replace.
188 let mut waitlist: HashMap<UOpKey, Vec<Entry>> = HashMap::new();
189
190 while let Some(Entry { n, stage, new_n }) = stack.pop() {
191 if stack.len() > REWRITE_STACK_LIMIT {
192 panic!(
193 "infinite loop in graph_rewrite (stack too big: {}). results cached: {}",
194 stack.len(),
195 self.replace.len(),
196 );
197 }
198
199 let n_key = UOpKey(n.clone());
200
201 if self.replace.contains_key(&n_key) {
202 continue;
203 }
204
205 if stage == 0 {
206 // Stage 0: PushChildren
207 let mut working = new_n;
208
209 if self.bpm.is_some() {
210 // Apply bpm rewrite rules until a fixed point is reached.
211 let mut seen: HashSet<UOpKey> = HashSet::new();
212 let mut gated = false;
213 loop {
214 let working_key = UOpKey(working.clone());
215 if seen.contains(&working_key) {
216 panic!(
217 "infinite loop in fixed_point_rewrite: node {:?} (id={}) seen twice",
218 working.op().as_ref(),
219 working.id
220 );
221 }
222 seen.insert(working_key);
223 match self.cached_bpm_rewrite(&working) {
224 Ok(Some(rewritten)) => {
225 working = rewritten;
226 }
227 Ok(None) => break,
228 Err(gate_node) => {
229 // Gate: done with this node, don't descend into children
230 self.record_replace(&n, gate_node);
231 if let Some(entries) = waitlist.remove(&n_key) {
232 stack.extend(entries);
233 }
234 gated = true;
235 break;
236 }
237 }
238 }
239 if gated {
240 continue;
241 }
242 }
243
244 stack.push(Entry { n: n.clone(), stage: 1, new_n: working.clone() });
245
246 let sources = working.op().sources();
247 for child in sources.iter().rev() {
248 let child_key = UOpKey(child.clone());
249 if on_stack.contains(&child_key) {
250 continue;
251 }
252 stack.push(Entry { n: child.clone(), stage: 0, new_n: child.clone() });
253 on_stack.insert(child_key);
254 }
255 } else if stage == 1 {
256 // Stage 1: ApplyPatterns
257 let sources = new_n.op().sources();
258
259 let mut tmp: Vec<Arc<UOp>> = Vec::with_capacity(sources.len());
260 let mut waiting = false;
261
262 for src in &sources {
263 let src_key = UOpKey(src.clone());
264 if let Some(rx) = self.replace.get(&src_key) {
265 tmp.push(rx.clone());
266 } else {
267 // Source not ready: register in waitlist
268 waitlist.entry(src_key).or_default().push(Entry {
269 n: n.clone(),
270 stage: 1,
271 new_n: new_n.clone(),
272 });
273 waiting = true;
274 break;
275 }
276 }
277
278 if waiting {
279 continue;
280 }
281
282 // All sources ready — reconstruct if any changed
283 let sources_changed = tmp.iter().zip(sources.iter()).any(|(a, b)| !Arc::ptr_eq(a, b));
284
285 // Hash consing may collapse reconstruction to same node even when
286 // sources logically changed. Detect this and treat as unchanged.
287 let node = if sources_changed {
288 let reconstructed = new_n.with_sources(tmp);
289 if Arc::ptr_eq(&reconstructed, &new_n) { new_n.clone() } else { reconstructed }
290 } else {
291 new_n.clone()
292 };
293
294 if Arc::ptr_eq(&node, &new_n) {
295 // Sources effectively unchanged: try pm rewrite
296 if let Some(new_src_n) = self.pm_rewrite(&new_n) {
297 stack.push(Entry { n: n.clone(), stage: 2, new_n: new_src_n.clone() });
298 stack.push(Entry { n: new_src_n.clone(), stage: 0, new_n: new_src_n });
299 } else {
300 // No pm match — done with this node
301 self.record_replace(&n, new_n);
302 if let Some(entries) = waitlist.remove(&n_key) {
303 stack.extend(entries);
304 }
305 }
306 } else {
307 // Reconstruction produced a new node — process it, then link back
308 stack.push(Entry { n: n.clone(), stage: 2, new_n: node.clone() });
309 stack.push(Entry { n: node.clone(), stage: 0, new_n: node });
310 }
311 } else {
312 // Stage 2: Link
313 let new_n_key = UOpKey(new_n.clone());
314
315 if let Some(replaced_new_n) = self.replace.get(&new_n_key).cloned() {
316 self.record_replace(&n, replaced_new_n);
317 if let Some(entries) = waitlist.remove(&n_key) {
318 stack.extend(entries);
319 }
320 } else {
321 // Not ready: register in waitlist
322 waitlist.entry(new_n_key).or_default().push(Entry { n, stage: 2, new_n });
323 }
324 }
325 }
326
327 self.replace.get(&UOpKey(root.clone())).cloned().unwrap_or(root)
328 }
329}
330
331/// Marker type for "no matcher" in generic contexts.
332pub struct NoMatcher;
333
334impl<C> Matcher<C> for NoMatcher {
335 fn rewrite(&self, _node: &Arc<UOp>, _ctx: &mut C) -> RewriteResult {
336 RewriteResult::NoMatch
337 }
338}
339
340/// Apply graph rewriting. Patterns see **OPTIMIZED** children (Stage 1).
341pub fn graph_rewrite<M: Matcher<C>, C>(matcher: &M, root: Arc<UOp>, ctx: &mut C) -> Arc<UOp> {
342 RewriteEngine::new(Some(matcher), None::<&NoMatcher>, ctx).rewrite(root)
343}
344
345/// Apply graph rewriting with bottom-up pattern application.
346/// Patterns see **ORIGINAL** children (Stage 0).
347pub fn graph_rewrite_bottom_up<M: Matcher<C>, C>(matcher: &M, root: Arc<UOp>, ctx: &mut C) -> Arc<UOp> {
348 RewriteEngine::new(None::<&NoMatcher>, Some(matcher), ctx).rewrite(root)
349}
350
351/// Apply graph rewriting with both top-down and bottom-up patterns.
352/// - `bpm` patterns see ORIGINAL children (Stage 0)
353/// - `pm` patterns see OPTIMIZED children (Stage 1)
354pub fn graph_rewrite_with_bpm<PM, BPM, C>(pm: &PM, bpm: &BPM, root: Arc<UOp>, ctx: &mut C) -> Arc<UOp>
355where
356 PM: Matcher<C>,
357 BPM: Matcher<C>,
358{
359 RewriteEngine::new(Some(pm), Some(bpm), ctx).rewrite(root)
360}