midenc_hir/patterns/
rewriter.rs

1use alloc::{boxed::Box, format, rc::Rc};
2use core::ops::{Deref, DerefMut};
3
4use smallvec::SmallVec;
5
6use crate::{
7    patterns::Pattern, BlockRef, Builder, Context, InsertionGuard, Listener, ListenerType,
8    OpBuilder, OpOperandImpl, OperationRef, PostOrderBlockIter, ProgramPoint, RegionRef, Report,
9    SourceSpan, Usable, ValueRef,
10};
11
12/// A [Rewriter] is a [Builder] extended with additional functionality that is of primary use when
13/// rewriting the IR after it is initially constructed. It is the basis on which the pattern
14/// rewriter infrastructure is built.
15pub trait Rewriter: Builder + RewriterListener {
    /// Returns true if this rewriter has a listener attached.
    ///
    /// When no listener is present, fast paths can be taken when rewriting the IR, whereas a
    /// listener requires breaking mutations up into individual actions so that the listener can
    /// be made aware of all of them, in the order they occur.
    ///
    /// Several of the provided methods of this trait (e.g. erasing ops, inlining and splitting
    /// blocks) consult this method to choose between their bulk path and their step-by-step path.
    fn has_listener(&self) -> bool;
22
23    /// Replace the results of the given operation with the specified list of values (replacements).
24    ///
25    /// The result types of the given op and the replacements must match. The original op is erased.
26    fn replace_op_with_values(&mut self, op: OperationRef, values: &[Option<ValueRef>]) {
27        assert_eq!(op.borrow().num_results(), values.len());
28
29        // Replace all result uses, notifies listener of the modifications
30        self.replace_all_op_uses_with_values(op, values);
31
32        // Erase the op and notify the listener
33        self.erase_op(op);
34    }
35
36    /// Replace the results of the given operation with the specified replacement op.
37    ///
38    /// The result types of the two ops must match. The original op is erased.
39    fn replace_op(&mut self, op: OperationRef, new_op: OperationRef) {
40        assert_eq!(op.borrow().num_results(), new_op.borrow().num_results());
41
42        // Replace all result uses, notifies listener of the modifications
43        self.replace_all_op_uses_with(op, new_op);
44
45        // Erase the op and notify the listener
46        self.erase_op(op);
47    }
48
    /// This method erases an operation that is known to have no uses.
    ///
    /// Without a listener, the operation is erased in a single step. With a listener attached,
    /// everything nested under `op` is erased piece by piece first (operations via
    /// `notify_operation_erased`, blocks via [Rewriter::erase_block]), and `op` itself is erased
    /// last.
    ///
    /// # Panics
    ///
    /// Panics if `op` still has uses.
    fn erase_op(&mut self, mut op: OperationRef) {
        assert!(!op.borrow().is_used(), "expected op to have no uses");

        // If no listener is attached, the op can be dropped all at once.
        if !self.has_listener() {
            op.borrow_mut().erase();
            return;
        }

        // Helper function that erases a single operation, after notifying the listener about it.
        // In debug builds it also checks that nested regions are already empty.
        fn erase_single_op<R: ?Sized + RewriterListener>(
            mut operation: OperationRef,
            rewrite_listener: &mut R,
        ) {
            let op = operation.borrow();
            if cfg!(debug_assertions) {
                // All nested ops should have been erased already
                assert!(op.regions().iter().all(|r| r.is_empty()), "expected empty regions");
                // All users should have been erased already if the op is in a region with SSA dominance
                if op.is_used() {
                    if let Some(region) = op.parent_region() {
                        assert!(
                            region.borrow().may_be_graph_region(),
                            "expected that op has no uses"
                        );
                    }
                }
            }

            // The listener is notified while the op is still intact (and still immutably borrowed)
            rewrite_listener.notify_operation_erased(operation);

            // Explicitly drop all uses in case the op is in a graph region
            //
            // The immutable borrow must be released before the op can be borrowed mutably.
            drop(op);
            let mut op = operation.borrow_mut();
            op.drop_all_uses();
            op.erase();
        }

        // Nested ops must be erased one-by-one, so that listeners have a consistent view of the
        // IR every time a notification is triggered. Users must be erased before definitions, i.e.
        // in post-order, reverse dominance.
        fn erase_tree<R: ?Sized + Rewriter>(op: OperationRef, rewriter: &mut R) {
            // Erase nested ops
            //
            // The next region is captured before the current one is processed.
            let mut next_region = op.borrow().regions().front().as_pointer();
            while let Some(region) = next_region.take() {
                next_region = region.next();
                // Erase all blocks in the right order. Successors should be erased before
                // predecessors because successor blocks may use values defined in predecessor
                // blocks. A post-order traversal of blocks within a region visits successors before
                // predecessors. Repeat the traversal until the region is empty. (The block graph
                // could be disconnected.)
                let mut erased_blocks = SmallVec::<[BlockRef; 4]>::default();
                let mut region_entry = region.borrow().entry_block_ref();
                while let Some(entry) = region_entry.take() {
                    erased_blocks.clear();
                    for block in PostOrderBlockIter::new(entry) {
                        // Recursively erase the contents of the block; the successor of each op
                        // is captured before that op is erased.
                        let mut next_op = block.borrow().body().front().as_pointer();
                        while let Some(op) = next_op.take() {
                            next_op = op.next();
                            erase_tree(op, rewriter);
                        }
                        // The block itself is only erased once the traversal is complete
                        erased_blocks.push(block);
                    }
                    for mut block in erased_blocks.drain(..) {
                        // Explicitly drop all uses in case there is a cycle in the block
                        // graph.
                        for arg in block.borrow_mut().arguments_mut() {
                            arg.borrow_mut().uses_mut().clear();
                        }
                        block.borrow_mut().drop_all_uses();
                        rewriter.erase_block(block);
                    }

                    // Whatever block is now the entry block (if any) was not reachable from the
                    // previous entry block, and starts the next round.
                    region_entry = region.borrow().entry_block_ref();
                }
            }
            // With all nested regions emptied, the op itself can be erased
            erase_single_op(op, rewriter);
        }

        erase_tree(op, self);
    }
131
132    /// This method erases all operations in a block.
133    fn erase_block(&mut self, block: BlockRef) {
134        assert!(!block.borrow().is_used(), "expected 'block' to be unused");
135
136        let mut next_op = block.borrow().body().back().as_pointer();
137        while let Some(op) = next_op.take() {
138            next_op = op.prev();
139            assert!(!op.borrow().is_used(), "expected 'op' to be unused");
140            self.erase_op(op);
141        }
142
143        // Notify the listener that the block is about to be removed.
144        self.notify_block_erased(block);
145
146        // Remove block from parent region
147        let mut region = block.parent().expect("expected 'block' to have a parent region");
148        let mut region_mut = region.borrow_mut();
149        let mut cursor = unsafe { region_mut.body_mut().cursor_mut_from_ptr(block) };
150        cursor.remove();
151    }
152
153    /// Move the blocks that belong to `region` before the given insertion point in another region,
154    /// `ip`. The two regions must be different. The caller is responsible for creating or
155    /// updating the operation transferring flow of control to the region, and passing it the
156    /// correct block arguments.
157    fn inline_region_before(&mut self, mut region: RegionRef, mut ip: RegionRef) {
158        assert!(!RegionRef::ptr_eq(&region, &ip), "cannot inline a region into itself");
159        let region_body = region.borrow_mut().body_mut().take();
160        if !self.has_listener() {
161            let mut parent_region = ip.borrow_mut();
162            let parent_body = parent_region.body_mut();
163            let mut cursor = parent_body.front_mut();
164            cursor.splice_before(region_body);
165        } else {
166            // Move blocks from beginning of the region one-by-one
167            let ip = ip.borrow().entry_block_ref().unwrap();
168            for block in region_body {
169                self.move_block_before(block, ip);
170            }
171        }
172    }
173
174    /// Inline the operations of block `src` before the given insertion point in `dest`.
175    ///
176    /// If the insertion point is `None`, the block will be inlined at the end of the target block.
177    ///
178    /// The source block will be deleted and must have no uses. The `args` values, if provided, are
179    /// used to replace the block arguments of `src`, with `None` used to signal that an argument
180    /// should be ignored.
181    ///
182    /// If the source block is inserted at the end of the dest block, the dest block must have no
183    /// successors. Similarly, if the source block is inserted somewhere in the middle (or
184    /// beginning) of the dest block, the source block must have no successors. Otherwise, the
185    /// resulting IR would have unreachable operations.
186    fn inline_block_before(
187        &mut self,
188        mut src: BlockRef,
189        mut dest: BlockRef,
190        ip: Option<OperationRef>,
191        args: &[Option<ValueRef>],
192    ) {
193        assert!(
194            args.len() == src.borrow().num_arguments(),
195            "incorrect # of argument replacement values"
196        );
197
198        // The source block will be deleted, so it should not have any users (i.e., there should be
199        // no predecessors).
200        assert!(!src.borrow().has_predecessors(), "expected 'src' to have no predecessors");
201
202        // Ensure insertion point belongs to destination block if present
203        let insert_at_block_end = if let Some(ip) = ip {
204            let ip_block = ip.parent().expect("expected 'ip' to belong to a block");
205            assert_eq!(ip_block, dest, "invalid insertion point: must be an op in 'dest'");
206            ip.next().is_none()
207        } else {
208            true
209        };
210
211        if insert_at_block_end {
212            // The source block will be inserted at the end of the dest block, so the
213            // dest block should have no successors. Otherwise, the inserted operations
214            // will be unreachable.
215            assert!(!dest.borrow().has_successors(), "expected 'dest' to have no successors");
216        } else {
217            // The source block will be inserted in the middle of the dest block, so
218            // the source block should have no successors. Otherwise, the remainder of
219            // the dest block would be unreachable.
220            assert!(!src.borrow().has_successors(), "expected 'src' to have no successors");
221        }
222
223        // Replace all of the successor arguments with the provided values.
224        for (arg, replacement) in src.borrow().arguments().iter().copied().zip(args.iter().copied())
225        {
226            if let Some(replacement) = replacement {
227                self.replace_all_uses_of_value_with(arg.upcast(), replacement);
228            }
229        }
230
231        // Move operations from the source block to the dest block and erase the source block.
232        if self.has_listener() {
233            let mut src_mut = src.borrow_mut();
234            let mut src_cursor = src_mut.body_mut().front_mut();
235            while let Some(op) = src_cursor.remove() {
236                if insert_at_block_end {
237                    self.insert_op_at_end(op, dest);
238                } else {
239                    self.insert_op_before(op, ip.unwrap());
240                }
241            }
242        } else {
243            // Fast path: If no listener is attached, move all operations at once.
244            let mut dest_block = dest.borrow_mut();
245            if let Some(ip) = ip {
246                dest_block.splice_block_before(&mut src.borrow_mut(), ip);
247            } else {
248                dest_block.splice_block(&mut src.borrow_mut());
249            }
250        }
251
252        // Erase the source block.
253        assert!(src.borrow().body().is_empty(), "expected 'src' to be empty");
254        self.erase_block(src);
255    }
256
257    /// Inline the operations of block `src` into the end of block `dest`. The source block will be
258    /// deleted and must have no uses. The `args` values, if present, are used to replace the block
259    /// arguments of `src`, where any `None` values are ignored.
260    ///
261    /// The dest block must have no successors. Otherwise, the resulting IR will have unreachable
262    /// operations.
263    fn merge_blocks(&mut self, src: BlockRef, dest: BlockRef, args: &[Option<ValueRef>]) {
264        let ip = dest.borrow().body().back().as_pointer();
265        self.inline_block_before(src, dest, ip, args);
266    }
267
268    /// Split the operations starting at `ip` (inclusive) out of the given block into a new block,
269    /// and return it.
270    fn split_block(&mut self, mut block: BlockRef, ip: OperationRef) -> BlockRef {
271        // Fast path: if no listener is attached, split the block directly
272        if !self.has_listener() {
273            return block.borrow_mut().split_block(ip);
274        }
275
276        assert_eq!(
277            block,
278            ip.parent().expect("expected 'ip' to be attached to a block"),
279            "expected 'ip' to be in 'block'"
280        );
281
282        let region =
283            block.parent().expect("cannot split a block which is not attached to a region");
284
285        // `create_block` sets the insertion point to the start of the new block
286        let mut guard = InsertionGuard::new(self);
287        let new_block = guard.create_block(region, Some(block), &[]);
288
289        // If `ip` points to the end of the block, no ops should be moved
290        if ip.next().is_none() {
291            return new_block;
292        }
293
294        // Move ops one-by-one from the end of `block` to the start of `new_block`.
295        // Stop when the operation pointed to by `ip` has been moved.
296        let mut block_mut = block.borrow_mut();
297        let mut cursor = block_mut.body_mut().back_mut();
298        let ip = new_block.borrow().body().front().as_pointer().unwrap();
299        while let Some(op) = cursor.remove() {
300            let is_last_move = OperationRef::ptr_eq(&op, &ip);
301            guard.insert_op_before(op, ip);
302            if is_last_move {
303                break;
304            }
305        }
306
307        new_block
308    }
309
310    /// Unlink this block and insert it right before `ip`.
311    fn move_block_before(&mut self, mut block: BlockRef, ip: BlockRef) {
312        let current_region = block.parent();
313        if current_region.is_none() {
314            block.borrow_mut().insert_before(ip);
315        } else {
316            block.borrow_mut().move_before(ip);
317        }
318        self.notify_block_inserted(block, current_region, Some(ip));
319    }
320
321    /// Unlink this operation from its current block and insert it right before `ip`, which
322    /// may be in the same or another block in the same function.
323    fn move_op_before(&mut self, mut op: OperationRef, ip: OperationRef) {
324        let prev = ProgramPoint::before(op);
325        op.borrow_mut().move_to(ProgramPoint::before(ip));
326        self.notify_operation_inserted(op, prev);
327    }
328
329    /// Unlink this operation from its current block and insert it right after `ip`, which may be
330    /// in the same or another block in the same function.
331    fn move_op_after(&mut self, mut op: OperationRef, ip: OperationRef) {
332        let prev = ProgramPoint::before(op);
333        op.borrow_mut().move_to(ProgramPoint::after(ip));
334        self.notify_operation_inserted(op, prev);
335    }
336
337    /// Unlink this operation from its current block and insert it at the end of `ip`.
338    fn move_op_to_end(&mut self, mut op: OperationRef, ip: BlockRef) {
339        let prev = ProgramPoint::before(op);
340        op.borrow_mut().move_to(ProgramPoint::at_end_of(ip));
341        self.notify_operation_inserted(op, prev);
342    }
343
344    /// Insert an unlinked operation right before `ip`
345    fn insert_op_before(&mut self, mut op: OperationRef, ip: OperationRef) {
346        let prev = ProgramPoint::before(op);
347        op.borrow_mut().as_operation_ref().insert_before(ip);
348        self.notify_operation_inserted(op, prev);
349    }
350
351    /// Insert an unlinked operation right after `ip`
352    fn insert_op_after(&mut self, mut op: OperationRef, ip: OperationRef) {
353        let prev = ProgramPoint::before(op);
354        op.borrow_mut().as_operation_ref().insert_after(ip);
355        self.notify_operation_inserted(op, prev);
356    }
357
358    /// Insert an unlinked operation at the end of `ip`
359    fn insert_op_at_end(&mut self, op: OperationRef, ip: BlockRef) {
360        let prev = ProgramPoint::before(op);
361        op.insert_at_end(ip);
362        self.notify_operation_inserted(op, prev);
363    }
364
365    /// Find uses of `from` and replace them with `to`.
366    ///
367    /// Notifies the listener about every in-place op modification (for every use that was replaced).
368    fn replace_all_uses_of_value_with(&mut self, mut from: ValueRef, mut to: ValueRef) {
369        let mut from_val = from.borrow_mut();
370        let from_uses = from_val.uses_mut();
371        let mut cursor = from_uses.front_mut();
372        while let Some(mut operand) = cursor.remove() {
373            let op = operand.borrow().owner;
374            self.notify_operation_modification_started(&op);
375            operand.borrow_mut().value = Some(to);
376            to.borrow_mut().insert_use(operand);
377            self.notify_operation_modified(op);
378        }
379    }
380
381    /// Find uses of `from` and replace them with `to`.
382    ///
383    /// Notifies the listener about every in-place op modification (for every use that was replaced).
384    fn replace_all_uses_of_block_with(&mut self, mut from: BlockRef, mut to: BlockRef) {
385        let mut from_block = from.borrow_mut();
386        let from_uses = from_block.uses_mut();
387        let mut cursor = from_uses.front_mut();
388        while let Some(operand) = cursor.remove() {
389            let op = operand.borrow().owner;
390            self.notify_operation_modification_started(&op);
391            to.borrow_mut().insert_use(operand);
392            self.notify_operation_modified(op);
393        }
394    }
395
396    /// Find uses of `from` and replace them with `to`.
397    ///
398    /// Notifies the listener about every in-place op modification (for every use that was replaced).
399    fn replace_all_uses_with(&mut self, from: &[ValueRef], to: &[Option<ValueRef>]) {
400        assert_eq!(from.len(), to.len(), "incorrect number of replacements");
401        for (from, to) in from.iter().cloned().zip(to.iter().cloned()) {
402            if let Some(to) = to {
403                self.replace_all_uses_of_value_with(from, to);
404            }
405        }
406    }
407
408    /// Find uses of `from` and replace them with `to`.
409    ///
410    /// Notifies the listener about every in-place modification (for every use that was replaced),
411    /// and that the `from` operation is about to be replaced.
412    fn replace_all_op_uses_with_values(&mut self, from: OperationRef, to: &[Option<ValueRef>]) {
413        self.notify_operation_replaced_with_values(from, to);
414
415        let results = from
416            .borrow()
417            .results()
418            .all()
419            .iter()
420            .copied()
421            .map(|result| result as ValueRef)
422            .collect::<SmallVec<[ValueRef; 2]>>();
423
424        self.replace_all_uses_with(&results, to);
425    }
426
427    /// Find uses of `from` and replace them with `to`.
428    ///
429    /// Notifies the listener about every in-place modification (for every use that was replaced),
430    /// and that the `from` operation is about to be replaced.
431    fn replace_all_op_uses_with(&mut self, from: OperationRef, to: OperationRef) {
432        self.notify_operation_replaced(from, to);
433
434        let from_results = from
435            .borrow()
436            .results()
437            .all()
438            .iter()
439            .copied()
440            .map(|result| result as ValueRef)
441            .collect::<SmallVec<[ValueRef; 2]>>();
442
443        let to_results = to
444            .borrow()
445            .results()
446            .all()
447            .iter()
448            .copied()
449            .map(|result| Some(result as ValueRef))
450            .collect::<SmallVec<[Option<ValueRef>; 2]>>();
451
452        self.replace_all_uses_with(&from_results, &to_results);
453    }
454
455    /// Find uses of `from` within `block` and replace them with `to`.
456    ///
457    /// Notifies the listener about every in-place op modification (for every use that was replaced).
458    ///
459    /// Returns true if all uses were replaced, otherwise false.
460    fn replace_op_uses_within_block(
461        &mut self,
462        from: OperationRef,
463        to: &[ValueRef],
464        block: BlockRef,
465    ) -> bool {
466        let parent_op = block.grandparent();
467        self.maybe_replace_op_uses_with(from, to, |operand| {
468            !parent_op
469                .as_ref()
470                .is_some_and(|op| op.borrow().is_proper_ancestor_of(&operand.owner.borrow()))
471        })
472    }
473
474    /// Find uses of `from` and replace them with `to`, except if the user is in `exceptions`.
475    ///
476    /// Notifies the listener about every in-place op modification (for every use that was replaced).
477    fn replace_all_uses_except(
478        &mut self,
479        from: ValueRef,
480        to: ValueRef,
481        exceptions: &[OperationRef],
482    ) {
483        self.maybe_replace_uses_of_value_with(from, to, |operand| {
484            !exceptions.contains(&operand.owner)
485        });
486    }
487}
488
489/// An extension trait for [Rewriter] implementations.
490///
491/// This trait contains functionality that is not object safe, and would prevent using [Rewriter] as
492/// a trait object. It is automatically implemented for all [Rewriter] impls.
493pub trait RewriterExt: Rewriter {
    /// This is a utility function that wraps an in-place modification of an operation, such that
    /// the rewriter is guaranteed to be notified when the modifications start and stop.
    ///
    /// The returned [InPlaceModificationGuard] borrows this rewriter mutably for as long as it is
    /// alive.
    ///
    /// NOTE(review): the notifications themselves are issued by the guard, whose implementation
    /// is not shown here - see [InPlaceModificationGuard] for exactly when they fire.
    fn modify_op_in_place(&mut self, op: OperationRef) -> InPlaceModificationGuard<'_, Self> {
        InPlaceModificationGuard::new(self, op)
    }
499
500    /// Find uses of `from` and replace them with `to`, if `should_replace` returns true.
501    ///
502    /// Notifies the listener about every in-place op modification (for every use that was replaced).
503    ///
504    /// Returns true if all uses were replaced, otherwise false.
505    fn maybe_replace_uses_of_value_with<P>(
506        &mut self,
507        mut from: ValueRef,
508        mut to: ValueRef,
509        should_replace: P,
510    ) -> bool
511    where
512        P: Fn(&OpOperandImpl) -> bool,
513    {
514        let mut all_replaced = true;
515        let mut from = from.borrow_mut();
516        let from_uses = from.uses_mut();
517        let mut cursor = from_uses.front_mut();
518        while let Some(user) = cursor.as_pointer() {
519            if should_replace(&user.borrow()) {
520                let owner = user.borrow().owner;
521                self.notify_operation_modification_started(&owner);
522                let operand = cursor.remove().unwrap();
523                to.borrow_mut().insert_use(operand);
524                self.notify_operation_modified(owner);
525            } else {
526                all_replaced = false;
527                cursor.move_next();
528            }
529        }
530        all_replaced
531    }
532
533    /// Find uses of `from` and replace them with `to`, if `should_replace` returns true.
534    ///
535    /// Notifies the listener about every in-place op modification (for every use that was replaced).
536    ///
537    /// Returns true if all uses were replaced, otherwise false.
538    fn maybe_replace_uses_with<P>(
539        &mut self,
540        from: &[ValueRef],
541        to: &[ValueRef],
542        should_replace: P,
543    ) -> bool
544    where
545        P: Fn(&OpOperandImpl) -> bool,
546    {
547        assert_eq!(from.len(), to.len(), "incorrect number of replacements");
548        let mut all_replaced = true;
549        for (from, to) in from.iter().cloned().zip(to.iter().cloned()) {
550            all_replaced &= self.maybe_replace_uses_of_value_with(from, to, &should_replace);
551        }
552        all_replaced
553    }
554
555    /// Find uses of `from` and replace them with `to`, if `should_replace` returns true.
556    ///
557    /// Notifies the listener about every in-place op modification (for every use that was replaced).
558    ///
559    /// Returns true if all uses were replaced, otherwise false.
560    fn maybe_replace_op_uses_with<P>(
561        &mut self,
562        from: OperationRef,
563        to: &[ValueRef],
564        should_replace: P,
565    ) -> bool
566    where
567        P: Fn(&OpOperandImpl) -> bool,
568    {
569        let results = SmallVec::<[ValueRef; 2]>::from_iter(
570            from.borrow().results.all().iter().cloned().map(|result| result as ValueRef),
571        );
572        self.maybe_replace_uses_with(&results, to, should_replace)
573    }
574}
575
// Blanket implementation: every [Rewriter], sized or not (so trait objects as well), gets the
// [RewriterExt] methods without any further code.
impl<R: ?Sized + Rewriter> RewriterExt for R {}
577
578#[allow(unused_variables)]
579pub trait RewriterListener: Listener {
580    /// Notify the listener that the specified block is about to be erased.
581    ///
582    /// At this point, the block has zero uses.
583    fn notify_block_erased(&self, block: BlockRef) {}
584
585    /// Notify the listener that an in-place modification of the specified operation has started
586    fn notify_operation_modification_started(&self, op: &OperationRef) {}
587
588    /// Notify the listener that an in-place modification of the specified operation was canceled
589    fn notify_operation_modification_canceled(&self, op: &OperationRef) {}
590
591    /// Notify the listener that the specified operation was modified in-place.
592    fn notify_operation_modified(&self, op: OperationRef) {}
593
594    /// Notify the listener that all uses of the specified operation's results are about to be
595    /// replaced with the results of another operation. This is called before the uses of the old
596    /// operation have been changed.
597    ///
598    /// By default, this function calls the "operation replaced with values" notification.
599    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
600        let replacement = replacement.borrow();
601        let values = replacement
602            .results()
603            .all()
604            .iter()
605            .cloned()
606            .map(|result| Some(result as ValueRef))
607            .collect::<SmallVec<[Option<ValueRef>; 2]>>();
608        self.notify_operation_replaced_with_values(op, &values);
609    }
610
611    /// Notify the listener that all uses of the specified operation's results are about to be
612    /// replaced with the given range of values, potentially produced by other operations. This is
613    /// called before the uses of the operation have been changed.
614    fn notify_operation_replaced_with_values(
615        &self,
616        op: OperationRef,
617        replacement: &[Option<ValueRef>],
618    ) {
619    }
620
621    /// Notify the listener that the specified operation is about to be erased. At this point, the
622    /// operation has zero uses.
623    ///
624    /// NOTE: This notification is not triggered when unlinking an operation.
625    fn notify_operation_erased(&self, op: OperationRef) {}
626
627    /// Notify the listener that the specified pattern is about to be applied at the specified root
628    /// operation.
629    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {}
630
631    /// Notify the listener that a pattern application finished with the specified status.
632    ///
633    /// `true` indicates that the pattern was applied successfully. `false` indicates that the
634    /// pattern could not be applied. The pattern may have communicated the reason for the failure
635    /// with `notify_match_failure`
636    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {}
637
638    /// Notify the listener that the pattern failed to match, and provide a diagnostic explaining
639    /// the reason why the failure occurred.
640    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {}
641}
642
643impl<L: RewriterListener> RewriterListener for Option<L> {
644    fn notify_block_erased(&self, block: BlockRef) {
645        if let Some(listener) = self.as_ref() {
646            listener.notify_block_erased(block);
647        }
648    }
649
650    fn notify_operation_modification_started(&self, op: &OperationRef) {
651        if let Some(listener) = self.as_ref() {
652            listener.notify_operation_modification_started(op);
653        }
654    }
655
656    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
657        if let Some(listener) = self.as_ref() {
658            listener.notify_operation_modification_canceled(op);
659        }
660    }
661
662    fn notify_operation_modified(&self, op: OperationRef) {
663        if let Some(listener) = self.as_ref() {
664            listener.notify_operation_modified(op);
665        }
666    }
667
668    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
669        if let Some(listener) = self.as_ref() {
670            listener.notify_operation_replaced(op, replacement);
671        }
672    }
673
674    fn notify_operation_replaced_with_values(
675        &self,
676        op: OperationRef,
677        replacement: &[Option<ValueRef>],
678    ) {
679        if let Some(listener) = self.as_ref() {
680            listener.notify_operation_replaced_with_values(op, replacement);
681        }
682    }
683
684    fn notify_operation_erased(&self, op: OperationRef) {
685        if let Some(listener) = self.as_ref() {
686            listener.notify_operation_erased(op);
687        }
688    }
689
690    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
691        if let Some(listener) = self.as_ref() {
692            listener.notify_pattern_begin(pattern, op);
693        }
694    }
695
696    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
697        if let Some(listener) = self.as_ref() {
698            listener.notify_pattern_end(pattern, success);
699        }
700    }
701
702    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
703        if let Some(listener) = self.as_ref() {
704            listener.notify_match_failure(span, reason);
705        }
706    }
707}
708
709impl<L: ?Sized + RewriterListener> RewriterListener for Box<L> {
710    fn notify_block_erased(&self, block: BlockRef) {
711        (**self).notify_block_erased(block);
712    }
713
714    fn notify_operation_modification_started(&self, op: &OperationRef) {
715        (**self).notify_operation_modification_started(op);
716    }
717
718    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
719        (**self).notify_operation_modification_canceled(op);
720    }
721
722    fn notify_operation_modified(&self, op: OperationRef) {
723        (**self).notify_operation_modified(op);
724    }
725
726    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
727        (**self).notify_operation_replaced(op, replacement);
728    }
729
730    fn notify_operation_replaced_with_values(
731        &self,
732        op: OperationRef,
733        replacement: &[Option<ValueRef>],
734    ) {
735        (**self).notify_operation_replaced_with_values(op, replacement);
736    }
737
738    fn notify_operation_erased(&self, op: OperationRef) {
739        (**self).notify_operation_erased(op)
740    }
741
742    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
743        (**self).notify_pattern_begin(pattern, op);
744    }
745
746    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
747        (**self).notify_pattern_end(pattern, success);
748    }
749
750    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
751        (**self).notify_match_failure(span, reason);
752    }
753}
754
755impl<L: ?Sized + RewriterListener> RewriterListener for Rc<L> {
756    fn notify_block_erased(&self, block: BlockRef) {
757        (**self).notify_block_erased(block);
758    }
759
760    fn notify_operation_modification_started(&self, op: &OperationRef) {
761        (**self).notify_operation_modification_started(op);
762    }
763
764    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
765        (**self).notify_operation_modification_canceled(op);
766    }
767
768    fn notify_operation_modified(&self, op: OperationRef) {
769        (**self).notify_operation_modified(op);
770    }
771
772    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
773        (**self).notify_operation_replaced(op, replacement);
774    }
775
776    fn notify_operation_replaced_with_values(
777        &self,
778        op: OperationRef,
779        replacement: &[Option<ValueRef>],
780    ) {
781        (**self).notify_operation_replaced_with_values(op, replacement);
782    }
783
784    fn notify_operation_erased(&self, op: OperationRef) {
785        (**self).notify_operation_erased(op)
786    }
787
788    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
789        (**self).notify_pattern_begin(pattern, op);
790    }
791
792    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
793        (**self).notify_pattern_end(pattern, success);
794    }
795
796    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
797        (**self).notify_match_failure(span, reason);
798    }
799}
800
/// A listener of kind `Rewriter` that does nothing.
///
/// It is the default listener type parameter of [PatternRewriter] and [RewriterImpl].
pub struct NoopRewriterListener;
impl Listener for NoopRewriterListener {
    /// Always reports [ListenerType::Rewriter].
    #[inline]
    fn kind(&self) -> ListenerType {
        ListenerType::Rewriter
    }

    /// Ignores operation insertions.
    #[inline(always)]
    fn notify_operation_inserted(&self, _op: OperationRef, _prev: ProgramPoint) {}

    /// Ignores block insertions.
    #[inline(always)]
    fn notify_block_inserted(
        &self,
        _block: BlockRef,
        _prev: Option<RegionRef>,
        _ip: Option<BlockRef>,
    ) {
    }
}
impl RewriterListener for NoopRewriterListener {
    /// Ignores operation replacements.
    ///
    /// This is the only `RewriterListener` callback overridden here; the remaining callbacks
    /// are left to whatever the trait itself provides.
    fn notify_operation_replaced(&self, _op: OperationRef, _replacement: OperationRef) {}
}
824
/// Combines two listeners into one.
///
/// The `Listener` and `RewriterListener` implementations for this type deliver every
/// notification to `base` first and then to `derived`; the reported listener kind is that of
/// `derived`.
pub struct ForwardingListener<Base, Derived> {
    /// The listener that is notified first.
    base: Base,
    /// The listener that is notified second, and whose `kind` is reported.
    derived: Derived,
}
impl<Base, Derived> ForwardingListener<Base, Derived> {
    /// Creates a forwarding listener from the given `base` and `derived` listeners.
    pub fn new(base: Base, derived: Derived) -> Self {
        Self { base, derived }
    }
}
834impl<Base: Listener, Derived: Listener> Listener for ForwardingListener<Base, Derived> {
835    fn kind(&self) -> ListenerType {
836        self.derived.kind()
837    }
838
839    fn notify_block_inserted(
840        &self,
841        block: BlockRef,
842        prev: Option<RegionRef>,
843        ip: Option<BlockRef>,
844    ) {
845        self.base.notify_block_inserted(block, prev, ip);
846        self.derived.notify_block_inserted(block, prev, ip);
847    }
848
849    fn notify_operation_inserted(&self, op: OperationRef, prev: ProgramPoint) {
850        self.base.notify_operation_inserted(op, prev);
851        self.derived.notify_operation_inserted(op, prev);
852    }
853}
854impl<Base: RewriterListener, Derived: RewriterListener> RewriterListener
855    for ForwardingListener<Base, Derived>
856{
857    fn notify_block_erased(&self, block: BlockRef) {
858        self.base.notify_block_erased(block);
859        self.derived.notify_block_erased(block);
860    }
861
862    fn notify_operation_modification_started(&self, op: &OperationRef) {
863        self.base.notify_operation_modification_started(op);
864        self.derived.notify_operation_modification_started(op);
865    }
866
867    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
868        self.base.notify_operation_modification_canceled(op);
869        self.derived.notify_operation_modification_canceled(op);
870    }
871
872    fn notify_operation_modified(&self, op: OperationRef) {
873        self.base.notify_operation_modified(op);
874        self.derived.notify_operation_modified(op);
875    }
876
877    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
878        self.base.notify_operation_replaced(op, replacement);
879        self.derived.notify_operation_replaced(op, replacement);
880    }
881
882    fn notify_operation_replaced_with_values(
883        &self,
884        op: OperationRef,
885        replacement: &[Option<ValueRef>],
886    ) {
887        self.base.notify_operation_replaced_with_values(op, replacement);
888        self.derived.notify_operation_replaced_with_values(op, replacement);
889    }
890
891    fn notify_operation_erased(&self, op: OperationRef) {
892        self.base.notify_operation_erased(op);
893        self.derived.notify_operation_erased(op);
894    }
895
896    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
897        self.base.notify_pattern_begin(pattern, op);
898        self.derived.notify_pattern_begin(pattern, op);
899    }
900
901    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
902        self.base.notify_pattern_end(pattern, success);
903        self.derived.notify_pattern_end(pattern, success);
904    }
905
906    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
907        let err = Report::msg(format!("{reason}"));
908        self.base.notify_match_failure(span, reason);
909        self.derived.notify_match_failure(span, err);
910    }
911}
912
/// Wraps an in-place modification of an [Operation] to ensure the rewriter is properly notified
/// about the progress and outcome of the in-place modification.
///
/// This is a minor efficiency win, as it avoids creating a new operation, and removing the old one,
/// but also often allows simpler code in the client.
///
/// The rewriter is notified that the modification started when the guard is created, and is
/// notified of the outcome (modified, or canceled) when the guard is dropped.
pub struct InPlaceModificationGuard<'a, R: ?Sized + Rewriter> {
    /// The rewriter that receives the start/outcome notifications.
    rewriter: &'a mut R,
    /// The operation being modified in place.
    op: OperationRef,
    /// Set by `cancel`; selects which notification is sent on drop.
    canceled: bool,
}
923impl<'a, R> InPlaceModificationGuard<'a, R>
924where
925    R: ?Sized + Rewriter,
926{
927    pub fn new(rewriter: &'a mut R, op: OperationRef) -> Self {
928        rewriter.notify_operation_modification_started(&op);
929        Self {
930            rewriter,
931            op,
932            canceled: false,
933        }
934    }
935
936    #[inline]
937    pub fn rewriter(&mut self) -> &mut R {
938        self.rewriter
939    }
940
941    #[inline]
942    pub fn op(&self) -> &OperationRef {
943        &self.op
944    }
945
946    /// Cancels the pending in-place modification.
947    pub fn cancel(mut self) {
948        self.canceled = true;
949    }
950
951    /// Signals the end of an in-place modification of the current operation.
952    pub fn finalize(self) {}
953}
954impl<R: ?Sized + Rewriter> core::ops::Deref for InPlaceModificationGuard<'_, R> {
955    type Target = R;
956
957    #[inline(always)]
958    fn deref(&self) -> &Self::Target {
959        self.rewriter
960    }
961}
962impl<R: ?Sized + Rewriter> core::ops::DerefMut for InPlaceModificationGuard<'_, R> {
963    #[inline(always)]
964    fn deref_mut(&mut self) -> &mut Self::Target {
965        self.rewriter
966    }
967}
968impl<R: ?Sized + Rewriter> Drop for InPlaceModificationGuard<'_, R> {
969    fn drop(&mut self) {
970        if self.canceled {
971            self.rewriter.notify_operation_modification_canceled(&self.op);
972        } else {
973            self.rewriter.notify_operation_modified(self.op);
974        }
975    }
976}
977
/// A special type of `RewriterBase` that coordinates the application of a rewrite pattern on the
/// current IR being matched, providing a way to keep track of any mutations made.
///
/// This type should be used to perform all necessary IR mutations within a rewrite pattern, as
/// the pattern driver may be tracking various state that would be invalidated when a mutation takes
/// place.
///
/// It wraps a [RewriterImpl] and dereferences to it.
pub struct PatternRewriter<L = NoopRewriterListener> {
    /// The underlying rewriter that performs the actual work.
    rewriter: RewriterImpl<L>,
    /// Value reported by `can_recover_from_rewrite_failure`; the constructors in this file
    /// initialize it to `false`.
    recoverable: bool,
}
988
989impl PatternRewriter {
990    pub fn new(context: Rc<Context>) -> Self {
991        let rewriter = RewriterImpl::new(context);
992        Self {
993            rewriter,
994            recoverable: false,
995        }
996    }
997
998    pub fn from_builder(builder: OpBuilder) -> Self {
999        let (context, _, ip) = builder.into_parts();
1000        let mut rewriter = RewriterImpl::new(context);
1001        rewriter.restore_insertion_point(ip);
1002        Self {
1003            rewriter,
1004            recoverable: false,
1005        }
1006    }
1007}
1008
1009impl<L: RewriterListener> PatternRewriter<L> {
1010    pub fn new_with_listener(context: Rc<Context>, listener: L) -> Self {
1011        let rewriter = RewriterImpl::<NoopRewriterListener>::new(context).with_listener(listener);
1012        Self {
1013            rewriter,
1014            recoverable: false,
1015        }
1016    }
1017
1018    #[inline]
1019    pub const fn can_recover_from_rewrite_failure(&self) -> bool {
1020        self.recoverable
1021    }
1022}
1023impl<L> Deref for PatternRewriter<L> {
1024    type Target = RewriterImpl<L>;
1025
1026    #[inline(always)]
1027    fn deref(&self) -> &Self::Target {
1028        &self.rewriter
1029    }
1030}
1031impl<L> DerefMut for PatternRewriter<L> {
1032    #[inline(always)]
1033    fn deref_mut(&mut self) -> &mut Self::Target {
1034        &mut self.rewriter
1035    }
1036}
1037
/// The concrete rewriter: a context, an optional listener, and the current insertion point.
///
/// It implements `Builder`, `Rewriter`, `Listener` and `RewriterListener`; the listener
/// implementations relay notifications to `listener` when one is present.
pub struct RewriterImpl<L = NoopRewriterListener> {
    /// The context this rewriter operates on.
    context: Rc<Context>,
    /// The attached listener, if any; `None` when constructed via `new`.
    listener: Option<L>,
    /// The current insertion point.
    ip: ProgramPoint,
}
1043
1044impl<L> RewriterImpl<L> {
1045    pub fn new(context: Rc<Context>) -> Self {
1046        Self {
1047            context,
1048            listener: None,
1049            ip: ProgramPoint::default(),
1050        }
1051    }
1052
1053    pub fn with_listener<L2>(self, listener: L2) -> RewriterImpl<L2>
1054    where
1055        L2: Listener,
1056    {
1057        RewriterImpl {
1058            context: self.context,
1059            listener: Some(listener),
1060            ip: self.ip,
1061        }
1062    }
1063}
1064
1065impl<L: RewriterListener> From<OpBuilder<L>> for RewriterImpl<L> {
1066    #[inline]
1067    fn from(builder: OpBuilder<L>) -> Self {
1068        let (context, listener, ip) = builder.into_parts();
1069        Self {
1070            context,
1071            listener,
1072            ip,
1073        }
1074    }
1075}
1076
1077impl<L: Listener> Builder for RewriterImpl<L> {
1078    #[inline(always)]
1079    fn context(&self) -> &Context {
1080        &self.context
1081    }
1082
1083    #[inline(always)]
1084    fn context_rc(&self) -> Rc<Context> {
1085        self.context.clone()
1086    }
1087
1088    #[inline(always)]
1089    fn insertion_point(&self) -> &ProgramPoint {
1090        &self.ip
1091    }
1092
1093    #[inline(always)]
1094    fn clear_insertion_point(&mut self) -> ProgramPoint {
1095        let ip = self.ip;
1096        self.ip = ProgramPoint::Invalid;
1097        ip
1098    }
1099
1100    #[inline(always)]
1101    fn restore_insertion_point(&mut self, ip: ProgramPoint) {
1102        self.ip = ip;
1103    }
1104
1105    #[inline(always)]
1106    fn set_insertion_point(&mut self, ip: ProgramPoint) {
1107        self.ip = ip;
1108    }
1109}
1110
impl<L: RewriterListener> Rewriter for RewriterImpl<L> {
    /// Returns true when a listener is attached, i.e. the `listener` field is `Some`.
    #[inline(always)]
    fn has_listener(&self) -> bool {
        self.listener.is_some()
    }
}
1117
1118impl<L: Listener> Listener for RewriterImpl<L> {
1119    fn kind(&self) -> ListenerType {
1120        ListenerType::Rewriter
1121    }
1122
1123    fn notify_operation_inserted(&self, op: OperationRef, prev: ProgramPoint) {
1124        if let Some(listener) = self.listener.as_ref() {
1125            listener.notify_operation_inserted(op, prev);
1126        }
1127    }
1128
1129    fn notify_block_inserted(
1130        &self,
1131        block: BlockRef,
1132        prev: Option<RegionRef>,
1133        ip: Option<BlockRef>,
1134    ) {
1135        if let Some(listener) = self.listener.as_ref() {
1136            listener.notify_block_inserted(block, prev, ip);
1137        }
1138    }
1139}
1140
1141impl<L: RewriterListener> RewriterListener for RewriterImpl<L> {
1142    fn notify_block_erased(&self, block: BlockRef) {
1143        if let Some(listener) = self.listener.as_ref() {
1144            listener.notify_block_erased(block);
1145        }
1146    }
1147
1148    fn notify_operation_modification_started(&self, op: &OperationRef) {
1149        if let Some(listener) = self.listener.as_ref() {
1150            listener.notify_operation_modification_started(op);
1151        }
1152    }
1153
1154    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
1155        if let Some(listener) = self.listener.as_ref() {
1156            listener.notify_operation_modification_canceled(op);
1157        }
1158    }
1159
1160    fn notify_operation_modified(&self, op: OperationRef) {
1161        if let Some(listener) = self.listener.as_ref() {
1162            listener.notify_operation_modified(op);
1163        }
1164    }
1165
1166    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
1167        if self.listener.is_some() {
1168            let replacement = replacement.borrow();
1169            let values = replacement
1170                .results()
1171                .all()
1172                .iter()
1173                .cloned()
1174                .map(|result| Some(result.upcast()))
1175                .collect::<SmallVec<[Option<ValueRef>; 2]>>();
1176            self.notify_operation_replaced_with_values(op, &values);
1177        }
1178    }
1179
1180    fn notify_operation_replaced_with_values(
1181        &self,
1182        op: OperationRef,
1183        replacement: &[Option<ValueRef>],
1184    ) {
1185        if let Some(listener) = self.listener.as_ref() {
1186            listener.notify_operation_replaced_with_values(op, replacement);
1187        }
1188    }
1189
1190    fn notify_operation_erased(&self, op: OperationRef) {
1191        if let Some(listener) = self.listener.as_ref() {
1192            listener.notify_operation_erased(op);
1193        }
1194    }
1195
1196    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
1197        if let Some(listener) = self.listener.as_ref() {
1198            listener.notify_pattern_begin(pattern, op);
1199        }
1200    }
1201
1202    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
1203        if let Some(listener) = self.listener.as_ref() {
1204            listener.notify_pattern_end(pattern, success);
1205        }
1206    }
1207
1208    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
1209        if let Some(listener) = self.listener.as_ref() {
1210            listener.notify_match_failure(span, reason);
1211        }
1212    }
1213}