midenc_hir/patterns/
driver.rs

1use alloc::{rc::Rc, vec::Vec};
2use core::cell::RefCell;
3
4use smallvec::SmallVec;
5
6use super::{
7    ForwardingListener, FrozenRewritePatternSet, PatternApplicator, PatternRewriter, Rewriter,
8    RewriterListener,
9};
10use crate::{
11    adt::SmallSet,
12    patterns::{PatternApplicationError, RewritePattern},
13    traits::{ConstantLike, Foldable, IsolatedFromAbove},
14    AttrPrinter, BlockRef, Builder, Context, Forward, InsertionGuard, Listener, OpFoldResult,
15    OperationFolder, OperationRef, ProgramPoint, RawWalk, Region, RegionRef, Report, SourceSpan,
16    Spanned, Value, ValueRef, WalkResult,
17};
18
19/// Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the
20/// highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
21///
22/// The greedy rewrite may prematurely stop after a maximum number of iterations, which can be
23/// configured using [GreedyRewriteConfig].
24///
25/// This function also performs folding and simple dead-code elimination before attempting to match
26/// any of the provided patterns.
27///
28/// A region scope can be set using [GreedyRewriteConfig]. By default, the scope is set to the
29/// specified region. Only in-scope ops are added to the worklist and only in-scope ops are allowed
30/// to be modified by the patterns.
31///
32/// Returns `Ok(changed)` if the iterative process converged (i.e., fixpoint was reached) and no
33/// more patterns can be matched within the region. The `changed` flag is set to `true` if the IR
34/// was modified at all.
35///
36/// NOTE: This function does not apply patterns to the region's parent operation.
37pub fn apply_patterns_and_fold_region_greedily(
38    region: RegionRef,
39    patterns: Rc<FrozenRewritePatternSet>,
40    mut config: GreedyRewriteConfig,
41) -> Result<bool, bool> {
42    // The top-level operation must be known to be isolated from above to prevent performing
43    // canonicalizations on operations defined at or above the region containing 'op'.
44    let context = {
45        let parent_op = region.parent().unwrap().borrow();
46        assert!(
47            parent_op.implements::<dyn IsolatedFromAbove>(),
48            "patterns can only be applied to operations which are isolated from above"
49        );
50        parent_op.context_rc()
51    };
52
53    // Set scope if not specified
54    if config.scope.is_none() {
55        config.scope = Some(region);
56    }
57
58    let mut driver = RegionPatternRewriteDriver::new(context, patterns, config, region);
59    let converged = driver.simplify();
60    if converged.is_err() {
61        if let Some(max_iterations) = driver.driver.config.max_iterations {
62            log::trace!(target: "pattern-rewrite-driver", "pattern rewrite did not converge after scanning {max_iterations} times");
63        } else {
64            log::trace!(target: "pattern-rewrite-driver", "pattern rewrite did not converge");
65        }
66    }
67    converged
68}
69
70/// Rewrite ops nested under the given operation, which must be isolated from above, by repeatedly
71/// applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is
72/// reached.
73///
74/// The greedy rewrite may prematurely stop after a maximum number of iterations, which can be
75/// configured using [GreedyRewriteConfig].
76///
77/// Also performs folding and simple dead-code elimination before attempting to match any of the
78/// provided patterns.
79///
80/// This overload runs a separate greedy rewrite for each region of the specified op. A region
81/// scope can be set in the configuration parameter. By default, the scope is set to the region of
82/// the current greedy rewrite. Only in-scope ops are added to the worklist and only in-scope ops
83/// and the specified op itself are allowed to be modified by the patterns.
84///
85/// NOTE: The specified op may be modified, but it may not be removed by the patterns.
86///
87/// Returns `Ok(changed)` if the iterative process converged (i.e., fixpoint was reached) and no
88/// more patterns can be matched within the region. The `changed` flag is set to `true` if the IR
89/// was modified at all.
90///
91/// NOTE: This function does not apply patterns to the given operation itself.
92pub fn apply_patterns_and_fold_greedily(
93    op: OperationRef,
94    patterns: Rc<FrozenRewritePatternSet>,
95    config: GreedyRewriteConfig,
96) -> Result<bool, bool> {
97    let mut any_region_changed = false;
98    let mut failed = false;
99    let op = op.borrow();
100    let mut cursor = op.regions().front();
101    while let Some(region) = cursor.as_pointer() {
102        cursor.move_next();
103        match apply_patterns_and_fold_region_greedily(region, patterns.clone(), config.clone()) {
104            Ok(region_changed) => {
105                any_region_changed |= region_changed;
106            }
107            Err(region_changed) => {
108                any_region_changed |= region_changed;
109                failed = true;
110            }
111        }
112    }
113
114    if failed {
115        Err(any_region_changed)
116    } else {
117        Ok(any_region_changed)
118    }
119}
120
/// Describes the effect that [apply_patterns_and_fold] had on the operations it was given.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum ApplyPatternsAndFoldEffect {
    /// Nothing happened: the IR is unchanged
    None,
    /// The IR was modified
    Changed,
    /// The input IR was erased
    Erased,
}

/// The result of [apply_patterns_and_fold].
///
/// `Ok` indicates that the rewrite converged, `Err` that it did not. The effect on the IR is
/// reported in both cases.
pub type ApplyPatternsAndFoldResult =
    Result<ApplyPatternsAndFoldEffect, ApplyPatternsAndFoldEffect>;
134
135/// Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy
136/// worklist driven manner until a fixpoint is reached.
137///
138/// The greedy rewrite may prematurely stop after a maximum number of iterations, which can be
139/// configured using [GreedyRewriteConfig].
140///
141/// This function also performs folding and simple dead-code elimination before attempting to match
142/// any of the provided patterns.
143///
144/// Newly created ops and other pre-existing ops that use results of rewritten ops or supply
145/// operands to such ops are also processed, unless such ops are excluded via `config.restrict`.
146/// Any other ops remain unmodified (i.e., regardless of restrictions).
147///
148/// In addition to op restrictions, a region scope can be specified. Only ops within the scope are
149/// simplified. This is similar to [apply_patterns_and_fold_greedily], where only ops within the
150/// given region/op are simplified by default. If no scope is specified, it is assumed to be the
151/// first common enclosing region of the given ops.
152///
153/// Note that ops in `ops` could be erased as result of folding, becoming dead, or via pattern
154/// rewrites. If more far reaching simplification is desired, [apply_patterns_and_fold_greedily]
155/// should be used.
156///
157/// Returns `Ok(effect)` if the iterative process converged (i.e., fixpoint was reached) and no more
158/// patterns can be matched. `effect` is set to `Changed` if the IR was modified, but at least one
159/// operation was not erased. It is set to `Erased` if all of the input ops were erased.
160pub fn apply_patterns_and_fold(
161    ops: &[OperationRef],
162    patterns: Rc<FrozenRewritePatternSet>,
163    mut config: GreedyRewriteConfig,
164) -> ApplyPatternsAndFoldResult {
165    if ops.is_empty() {
166        return Ok(ApplyPatternsAndFoldEffect::None);
167    }
168
169    // Determine scope of rewrite
170    if let Some(scope) = config.scope.as_ref() {
171        // If a scope was provided, make sure that all ops are in scope.
172        let all_ops_in_scope = ops.iter().all(|op| scope.borrow().find_ancestor_op(*op).is_some());
173        assert!(all_ops_in_scope, "ops must be within the specified scope");
174    } else {
175        // Compute scope if none was provided. The scope will remain `None` if there is a top-level
176        // op among `ops`.
177        config.scope = Region::find_common_ancestor(ops);
178    }
179
180    // Start the pattern driver
181    let max_rewrites = config.max_rewrites.map(|max| max.get()).unwrap_or(u32::MAX);
182    let context = ops[0].borrow().context_rc();
183    let mut driver = MultiOpPatternRewriteDriver::new(context, patterns, config, ops);
184    let converged = driver.simplify(ops);
185    let changed = match converged.as_ref() {
186        Ok(changed) | Err(changed) => *changed,
187    };
188    let erased = driver.inner.surviving_ops.borrow().is_empty();
189    let effect = if erased {
190        ApplyPatternsAndFoldEffect::Erased
191    } else if changed {
192        ApplyPatternsAndFoldEffect::Changed
193    } else {
194        ApplyPatternsAndFoldEffect::None
195    };
196    if converged.is_ok() {
197        Ok(effect)
198    } else {
199        log::trace!(target: "pattern-rewrite-driver", "pattern rewrite did not converge after {max_rewrites} rewrites");
200        Err(effect)
201    }
202}
203
/// Controls which operations may be placed on the worklist during a greedy pattern rewrite.
///
/// In the variant descriptions below, "pre-existing" ops are those that were on the worklist at
/// the very beginning of the rewrite.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum GreedyRewriteStrictness {
    /// Any op may be processed, no restrictions apply.
    #[default]
    Any,
    /// Only pre-existing ops, and ops newly created during the rewrite, are processed.
    ExistingAndNew,
    /// Only pre-existing ops are processed.
    Existing,
}
219
/// Selects how much simplification is applied to regions during a greedy pattern rewrite.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RegionSimplificationLevel {
    /// Regions are not simplified at all.
    None,
    /// Basic simplifications are performed (e.g. dead argument elimination)
    #[default]
    Normal,
    /// Additional complex/expensive simplifications are performed (e.g. block merging)
    Aggressive,
}
232
233/// Configuration for [GreedyPatternRewriteDriver]
234#[derive(Clone)]
235pub struct GreedyRewriteConfig {
236    listener: Option<Rc<dyn RewriterListener>>,
237    /// If set, only ops within the given region are added to the worklist.
238    ///
239    /// If no scope is specified, and no specific region is given when starting the greedy rewrite,
240    /// then the closest enclosing region of the initial list of operations is used.
241    scope: Option<RegionRef>,
242    /// If set, specifies the maximum number of times the rewriter will iterate between applying
243    /// patterns and simplifying regions.
244    ///
245    /// NOTE: Only applicable when simplifying entire regions.
246    max_iterations: Option<core::num::NonZeroU32>,
247    /// If set, specifies the maximum number of rewrites within an iteration.
248    max_rewrites: Option<core::num::NonZeroU32>,
249    /// Perform control flow optimizations to the region tree after applying all patterns.
250    ///
251    /// NOTE: Only applicable when simplifying entire regions.
252    region_simplification: RegionSimplificationLevel,
253    /// The restrictions to apply, if any, to operations added to the worklist during the rewrite.
254    restrict: GreedyRewriteStrictness,
255    /// This flag specifies the order of initial traversal that populates the rewriter worklist.
256    ///
257    /// When true, operations are visited top-down, which is generally more efficient in terms of
258    /// compilation time.
259    ///
260    /// When false, the initial traversal of the region tree is bottom up on each block, which may
261    /// match larger patterns when given an ambiguous pattern set.
262    ///
263    /// NOTE: Only applicable when simplifying entire regions.
264    use_top_down_traversal: bool,
265}
266impl Default for GreedyRewriteConfig {
267    fn default() -> Self {
268        Self {
269            listener: None,
270            scope: None,
271            max_iterations: core::num::NonZeroU32::new(10),
272            max_rewrites: None,
273            region_simplification: Default::default(),
274            restrict: Default::default(),
275            use_top_down_traversal: false,
276        }
277    }
278}
279impl GreedyRewriteConfig {
280    pub fn new_with_listener(listener: impl RewriterListener + 'static) -> Self {
281        Self {
282            listener: Some(Rc::new(listener)),
283            ..Default::default()
284        }
285    }
286
287    /// Scope rewrites to operations within `region`
288    pub fn with_scope(&mut self, region: RegionRef) -> &mut Self {
289        self.scope = Some(region);
290        self
291    }
292
293    /// Set the maximum number of times the rewriter will iterate between applying patterns and
294    /// simplifying regions.
295    ///
296    /// If `0` is given, the number of iterations is unlimited.
297    ///
298    /// NOTE: Only applicable when simplifying entire regions.
299    pub fn with_max_iterations(&mut self, max: u32) -> &mut Self {
300        self.max_iterations = core::num::NonZeroU32::new(max);
301        self
302    }
303
304    /// Set the maximum number of rewrites per iteration.
305    ///
306    /// If `0` is given, the number of rewrites is unlimited.
307    ///
308    /// NOTE: Only applicable when simplifying entire regions.
309    pub fn with_max_rewrites(&mut self, max: u32) -> &mut Self {
310        self.max_rewrites = core::num::NonZeroU32::new(max);
311        self
312    }
313
314    /// Set the level of control flow optimizations to apply to the region tree.
315    ///
316    /// NOTE: Only applicable when simplifying entire regions.
317    pub fn with_region_simplification_level(
318        &mut self,
319        level: RegionSimplificationLevel,
320    ) -> &mut Self {
321        self.region_simplification = level;
322        self
323    }
324
325    /// Set the level of restriction to apply to operations added to the worklist during the rewrite.
326    pub fn with_restrictions(&mut self, level: GreedyRewriteStrictness) -> &mut Self {
327        self.restrict = level;
328        self
329    }
330
331    /// Specify whether or not to use a top-down traversal when initially adding operations to the
332    /// worklist.
333    pub fn with_top_down_traversal(&mut self, yes: bool) -> &mut Self {
334        self.use_top_down_traversal = yes;
335        self
336    }
337
338    #[inline]
339    pub fn scope(&self) -> Option<RegionRef> {
340        self.scope
341    }
342
343    #[inline]
344    pub fn max_iterations(&self) -> Option<core::num::NonZeroU32> {
345        self.max_iterations
346    }
347
348    #[inline]
349    pub fn max_rewrites(&self) -> Option<core::num::NonZeroU32> {
350        self.max_rewrites
351    }
352
353    #[inline]
354    pub fn region_simplification_level(&self) -> RegionSimplificationLevel {
355        self.region_simplification
356    }
357
358    #[inline]
359    pub fn strictness(&self) -> GreedyRewriteStrictness {
360        self.restrict
361    }
362
363    #[inline]
364    pub fn use_top_down_traversal(&self) -> bool {
365        self.use_top_down_traversal
366    }
367}
368
369pub struct GreedyPatternRewriteDriver {
370    context: Rc<Context>,
371    worklist: RefCell<Worklist>,
372    config: GreedyRewriteConfig,
373    /// Not maintained when `config.restrict` is `GreedyRewriteStrictness::Any`
374    filtered_ops: RefCell<SmallSet<OperationRef, 8>>,
375    matcher: RefCell<PatternApplicator>,
376}
377
378impl GreedyPatternRewriteDriver {
379    pub fn new(
380        context: Rc<Context>,
381        patterns: Rc<FrozenRewritePatternSet>,
382        config: GreedyRewriteConfig,
383    ) -> Self {
384        // Apply a simple cost model based solely on pattern benefit
385        let mut matcher = PatternApplicator::new(patterns);
386        matcher.apply_default_cost_model();
387
388        Self {
389            context,
390            worklist: Default::default(),
391            config,
392            filtered_ops: Default::default(),
393            matcher: RefCell::new(matcher),
394        }
395    }
396}
397
/// Worklist Management
399impl GreedyPatternRewriteDriver {
400    /// Add the given operation to the worklist
401    pub fn add_single_op_to_worklist(&self, op: OperationRef) {
402        if matches!(self.config.restrict, GreedyRewriteStrictness::Any)
403            || self.filtered_ops.borrow().contains(&op)
404        {
405            log::trace!(target: "pattern-rewrite-driver", "adding single op '{}' to worklist", op.name());
406            self.worklist.borrow_mut().push(op);
407        } else {
408            log::trace!(
409                target: "pattern-rewrite-driver", "skipped adding single op '{}' to worklist due to strictness level",
410                op.name()
411            );
412        }
413    }
414
415    /// Add the given operation, and its ancestors, to the worklist
416    pub fn add_to_worklist(&self, op: OperationRef) {
417        // Gather potential ancestors while looking for a `scope` parent region
418        let mut ancestors = SmallVec::<[OperationRef; 8]>::default();
419        let mut op = Some(op);
420        while let Some(ancestor_op) = op.take() {
421            let region = ancestor_op.grandparent();
422            if self.config.scope.as_ref() == region.as_ref() {
423                ancestors.push(ancestor_op);
424                for op in ancestors {
425                    self.add_single_op_to_worklist(op);
426                }
427                return;
428            } else {
429                log::trace!(target: "pattern-rewrite-driver", "gathering ancestors of '{}' for worklist", ancestor_op.name());
430                ancestors.push(ancestor_op);
431            }
432            if let Some(region) = region {
433                op = region.parent();
434            } else {
435                log::trace!(target: "pattern-rewrite-driver", "reached top level op while searching for ancestors");
436            }
437        }
438    }
439
440    /// Process operations until the worklist is empty, or `config.max_rewrites` is reached.
441    ///
442    /// Returns true if the IR was changed.
443    pub fn process_worklist(self: Rc<Self>) -> bool {
444        log::debug!(target: "pattern-rewrite-driver", "starting processing of greedy pattern rewrite driver worklist");
445        let mut rewriter =
446            PatternRewriter::new_with_listener(self.context.clone(), Rc::clone(&self));
447
448        let mut changed = false;
449        let mut num_rewrites = 0u32;
450        while self.config.max_rewrites.is_none_or(|max| num_rewrites < max.get()) {
451            let Some(op) = self.worklist.borrow_mut().pop() else {
452                // Worklist is empty, we've converged
453                log::debug!(target: "pattern-rewrite-driver", "processing worklist complete, rewrites have converged");
454                return changed;
455            };
456
457            if self.process_worklist_item(&mut rewriter, op) {
458                changed = true;
459                num_rewrites += 1;
460            }
461        }
462
463        log::debug!(
464            target: "pattern-rewrite-driver", "processing worklist was canceled after {} rewrites without converging (reached max \
465             rewrite limit)",
466            self.config.max_rewrites.map(|max| max.get()).unwrap_or(u32::MAX)
467        );
468
469        changed
470    }
471
472    /// Process a single operation from the worklist.
473    ///
474    /// Returns true if the IR was changed.
475    fn process_worklist_item(
476        &self,
477        rewriter: &mut PatternRewriter<Rc<Self>>,
478        mut op_ref: OperationRef,
479    ) -> bool {
480        let op = op_ref.borrow();
481
482        log::trace!(target: "pattern-rewrite-driver", "processing operation '{op}'");
483
484        // If the operation is trivially dead - remove it.
485        if op.is_trivially_dead() {
486            drop(op);
487            rewriter.erase_op(op_ref);
488            log::trace!(target: "pattern-rewrite-driver", "processing complete: operation is trivially dead");
489            return true;
490        }
491
492        // Try to fold this op, unless it is a constant op, as that would lead to an infinite
493        // folding loop, since the folded result would be immediately materialized as a constant
494        // op, and then revisited.
495        if !op.implements::<dyn ConstantLike>() {
496            // Re-borrow mutably since we're going to try and rewrite `op` now
497            drop(op);
498            let op = op_ref.borrow_mut();
499
500            let mut results = SmallVec::<[OpFoldResult; 1]>::default();
501            log::trace!(target: "pattern-rewrite-driver", "attempting to fold operation..");
502            if op.fold(&mut results).is_ok() {
503                if results.is_empty() {
504                    // Op was modified in-place
505                    self.notify_operation_modified(op_ref);
506                    log::trace!(
507                        target: "pattern-rewrite-driver",
508                        "operation was succesfully folded/modified in-place"
509                    );
510                    return true;
511                } else {
512                    log::trace!(
513                        target: "pattern-rewrite-driver",
514                        "operation was succesfully folded away, to be replaced with: {}",
515                        results.as_slice().print(&crate::OpPrintingFlags::default(), op.context())
516                    );
517                }
518
519                // Op results can be replaced with `results`
520                assert_eq!(
521                    results.len(),
522                    op.num_results(),
523                    "folder produced incorrect number of results"
524                );
525                let mut rewriter = InsertionGuard::new(&mut **rewriter);
526                rewriter.set_insertion_point(ProgramPoint::before(op_ref));
527
528                log::trace!(target: "pattern-rewrite-driver", "replacing op with fold results..");
529                let mut replacements = SmallVec::<[Option<ValueRef>; 2]>::default();
530                let mut materialization_succeeded = true;
531                for (fold_result, result_ty) in results
532                    .into_iter()
533                    .zip(op.results().all().iter().map(|r| r.borrow().ty().clone()))
534                {
535                    match fold_result {
536                        OpFoldResult::Value(value) => {
537                            assert_eq!(
538                                value.borrow().ty(),
539                                &result_ty,
540                                "folder produced value of incorrect type"
541                            );
542                            replacements.push(Some(value));
543                        }
544                        OpFoldResult::Attribute(attr) => {
545                            // Materialize attributes as SSA values using a constant op
546                            let span = op.span();
547                            log::trace!(
548                                target: "pattern-rewrite-driver",
549                                "materializing constant for value '{}' and type '{result_ty}'",
550                                attr.print(&crate::OpPrintingFlags::default(), op.context())
551                            );
552                            let constant_op = op.dialect().materialize_constant(
553                                &mut *rewriter,
554                                attr,
555                                &result_ty,
556                                span,
557                            );
558                            match constant_op {
559                                None => {
560                                    log::trace!(
561                                        target: "pattern-rewrite-driver",
562                                        "materialization failed: cleaning up any materialized ops \
563                                         for {} previous results",
564                                        replacements.len()
565                                    );
566                                    // If materialization fails, clean up any operations generated for the previous results
567                                    let mut replacement_ops =
568                                        SmallVec::<[OperationRef; 2]>::default();
569                                    for replacement in replacements.iter().filter_map(|repl| *repl)
570                                    {
571                                        let replacement = replacement.borrow();
572                                        assert!(
573                                            !replacement.is_used(),
574                                            "folder reused existing op for one result, but \
575                                             constant materialization failed for another result"
576                                        );
577                                        let replacement_op = replacement.get_defining_op().unwrap();
578                                        if replacement_ops.contains(&replacement_op) {
579                                            continue;
580                                        }
581                                        replacement_ops.push(replacement_op);
582                                    }
583                                    for replacement_op in replacement_ops {
584                                        rewriter.erase_op(replacement_op);
585                                    }
586                                    materialization_succeeded = false;
587                                    break;
588                                }
589                                Some(constant_op) => {
590                                    let const_op = constant_op.borrow();
591                                    assert!(
592                                        const_op.implements::<dyn ConstantLike>(),
593                                        "materialize_constant produced op that does not implement \
594                                         ConstantLike"
595                                    );
596                                    let result: ValueRef = const_op.results().all()[0].upcast();
597                                    assert_eq!(
598                                        result.borrow().ty(),
599                                        &result_ty,
600                                        "materialize_constant produced incorrect result type"
601                                    );
602                                    log::trace!(
603                                        target: "pattern-rewrite-driver",
604                                        "successfully materialized constant as {}",
605                                        result.borrow().id()
606                                    );
607                                    replacements.push(Some(result));
608                                }
609                            }
610                        }
611                    }
612                }
613
614                if materialization_succeeded {
615                    log::trace!(
616                        target: "pattern-rewrite-driver",
617                        "materialization of fold results was successful, performing replacement.."
618                    );
619                    drop(op);
620                    rewriter.replace_op_with_values(op_ref, &replacements);
621                    log::trace!(
622                        target: "pattern-rewrite-driver",
623                        "fold succeeded: operation was replaced with materialized constants"
624                    );
625                    return true;
626                } else {
627                    log::trace!(
628                        target: "pattern-rewrite-driver",
629                        "materialization of fold results failed, proceeding without folding"
630                    );
631                }
632            }
633        } else {
634            log::trace!(target: "pattern-rewrite-driver", "operation could not be folded");
635        }
636
637        // Try to match one of the patterns.
638        //
639        // The rewriter is automatically notified of any necessary changes, so there is nothing
640        // else to do here.
641        // TODO(pauls): if self.config.listener.is_some() {
642        //
643        // We need to trigger `notify_pattern_begin` in `can_apply`, and `notify_pattern_end`
644        // in `on_failure` and `on_success`, but we can't have multiple mutable aliases of
645        // the listener captured by these closures.
646        //
647        // This is another aspect of the listener infra that needs to be handled
648        log::trace!(target: "pattern-rewrite-driver", "attempting to match and rewrite one of the input patterns..");
649        let result = if let Some(listener) = self.config.listener.as_deref() {
650            let op_name = op_ref.borrow().name();
651            let can_apply = |pattern: &dyn RewritePattern| {
652                log::trace!(target: "pattern-rewrite-driver", "applying pattern {} to op {}", pattern.name(), &op_name);
653                listener.notify_pattern_begin(pattern, op_ref);
654                true
655            };
656            let on_failure = |pattern: &dyn RewritePattern| {
657                log::trace!(target: "pattern-rewrite-driver", "pattern failed to match");
658                listener.notify_pattern_end(pattern, false);
659            };
660            let on_success = |pattern: &dyn RewritePattern| {
661                log::trace!(target: "pattern-rewrite-driver", "pattern applied successfully");
662                listener.notify_pattern_end(pattern, true);
663                Ok(())
664            };
665            self.matcher.borrow_mut().match_and_rewrite(
666                op_ref,
667                &mut **rewriter,
668                can_apply,
669                on_failure,
670                on_success,
671            )
672        } else {
673            self.matcher.borrow_mut().match_and_rewrite(
674                op_ref,
675                &mut **rewriter,
676                |_| true,
677                |_| {},
678                |_| Ok(()),
679            )
680        };
681
682        match result {
683            Ok(_) => {
684                log::trace!(target: "pattern-rewrite-driver", "processing complete: pattern matched and operation was rewritten");
685                true
686            }
687            Err(PatternApplicationError::NoMatchesFound) => {
688                log::debug!(target: "pattern-rewrite-driver", "processing complete: exhausted all patterns without finding a match");
689                false
690            }
691            Err(PatternApplicationError::Report(report)) => {
692                log::debug!(
693                    target: "pattern-rewrite-driver", "processing complete: error occurred during match and rewrite: {report}"
694                );
695                false
696            }
697        }
698    }
699
700    /// Look over the operands of the provided op for any defining operations that should be re-
701    /// added to the worklist. This function should be called when an operation is modified or
702    /// removed, as it may trigger further simplifications.
703    fn add_operands_to_worklist(&self, op: OperationRef) {
704        let current_op = op.borrow();
705        for operand in current_op.operands().all() {
706            // If this operand currently has at most 2 users, add its defining op to the worklist.
707            // After the op is deleted, then the operand will have at most 1 user left. If it has
708            // 0 users left, it can be deleted as well, and if it has 1 user left, there may be
709            // further canonicalization opportunities.
710            let operand = operand.borrow();
711            let Some(def_op) = operand.value().get_defining_op() else {
712                continue;
713            };
714
715            let mut other_user = None;
716            let mut has_more_than_two_uses = false;
717            for user in operand.value().iter_uses() {
718                if user.owner == op || other_user.as_ref().is_some_and(|ou| ou == &user.owner) {
719                    continue;
720                }
721                if other_user.is_none() {
722                    other_user = Some(user.owner);
723                    continue;
724                }
725                has_more_than_two_uses = true;
726                break;
727            }
728            if !has_more_than_two_uses {
729                self.add_to_worklist(def_op);
730            }
731        }
732    }
733}
734
735/// Notifications
736impl Listener for GreedyPatternRewriteDriver {
737    fn kind(&self) -> crate::ListenerType {
738        crate::ListenerType::Rewriter
739    }
740
741    /// Notify the driver that the given block was inserted
742    fn notify_block_inserted(
743        &self,
744        block: crate::BlockRef,
745        prev: Option<RegionRef>,
746        ip: Option<crate::BlockRef>,
747    ) {
748        if let Some(listener) = self.config.listener.as_deref() {
749            listener.notify_block_inserted(block, prev, ip);
750        }
751    }
752
753    /// Notify the driver that the specified operation was inserted.
754    ///
755    /// Update the worklist as needed: the operation is enqueued depending on scope and strictness
756    fn notify_operation_inserted(&self, op: OperationRef, prev: ProgramPoint) {
757        if let Some(listener) = self.config.listener.as_deref() {
758            listener.notify_operation_inserted(op, prev);
759        }
760        if matches!(self.config.restrict, GreedyRewriteStrictness::ExistingAndNew) {
761            self.filtered_ops.borrow_mut().insert(op);
762        }
763        self.add_to_worklist(op);
764    }
765}
766impl RewriterListener for GreedyPatternRewriteDriver {
767    /// Notify the driver that the given block is about to be removed.
768    fn notify_block_erased(&self, block: BlockRef) {
769        if let Some(listener) = self.config.listener.as_deref() {
770            listener.notify_block_erased(block);
771        }
772    }
773
774    /// Notify the driver that the sepcified operation may have been modified in-place. The
775    /// operation is added to the worklist.
776    fn notify_operation_modified(&self, op: OperationRef) {
777        if let Some(listener) = self.config.listener.as_deref() {
778            listener.notify_operation_modified(op);
779        }
780        self.add_to_worklist(op);
781    }
782
783    /// Notify the driver that the specified operation was removed.
784    ///
785    /// Update the worklist as needed: the operation and its children are removed from the worklist
786    fn notify_operation_erased(&self, op: OperationRef) {
787        // Only ops that are within the configured scope are added to the worklist of the greedy
788        // pattern rewriter.
789        //
790        // A greedy pattern rewrite is not allowed to erase the parent op of the scope region, as
791        // that would break the worklist handling and some sanity checks.
792        if let Some(scope) = self.config.scope.as_ref() {
793            assert!(
794                scope.parent().is_some_and(|parent_op| parent_op != op),
795                "scope region must not be erased during greedy pattern rewrite"
796            );
797        }
798
799        if let Some(listener) = self.config.listener.as_deref() {
800            listener.notify_operation_erased(op);
801        }
802
803        self.add_operands_to_worklist(op);
804        self.worklist.borrow_mut().remove(&op);
805
806        if self.config.restrict != GreedyRewriteStrictness::Any {
807            self.filtered_ops.borrow_mut().remove(&op);
808        }
809    }
810
811    /// Notify the driver that the specified operation was replaced.
812    ///
813    /// Update the worklist as needed: new users are enqueued
814    fn notify_operation_replaced_with_values(
815        &self,
816        op: OperationRef,
817        replacement: &[Option<ValueRef>],
818    ) {
819        if let Some(listener) = self.config.listener.as_deref() {
820            listener.notify_operation_replaced_with_values(op, replacement);
821        }
822    }
823
824    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
825        if let Some(listener) = self.config.listener.as_deref() {
826            listener.notify_match_failure(span, reason);
827        }
828    }
829}
830
831pub struct RegionPatternRewriteDriver {
832    driver: Rc<GreedyPatternRewriteDriver>,
833    region: RegionRef,
834}
835impl RegionPatternRewriteDriver {
836    pub fn new(
837        context: Rc<Context>,
838        patterns: Rc<FrozenRewritePatternSet>,
839        config: GreedyRewriteConfig,
840        region: RegionRef,
841    ) -> Self {
842        let mut driver = GreedyPatternRewriteDriver::new(context, patterns, config);
843        // Populate strict mode ops, if applicable
844        if driver.config.restrict != GreedyRewriteStrictness::Any {
845            let filtered_ops = driver.filtered_ops.get_mut();
846            region.raw_postwalk_all::<Forward, _>(|op| {
847                filtered_ops.insert(op);
848            });
849        }
850        Self {
851            driver: Rc::new(driver),
852            region,
853        }
854    }
855
856    /// Simplify ops inside `self.region`, and simplify the region itself.
857    ///
858    /// Returns `Ok(changed)` if the transformation converged, with `changed` indicating whether or
859    /// not the IR was changed. Otherwise, `Err(changed)` is returned.
860    pub fn simplify(&mut self) -> Result<bool, bool> {
861        use crate::matchers::Matcher;
862
863        let mut continue_rewrites = false;
864        let mut iteration = 0;
865
866        while self.driver.config.max_iterations.is_none_or(|max| iteration < max.get()) {
867            log::trace!(target: "pattern-rewrite-driver", "starting iteration {iteration} of region pattern rewrite driver");
868            iteration += 1;
869
870            // New iteration: start with an empty worklist
871            self.driver.worklist.borrow_mut().clear();
872
873            // `OperationFolder` CSE's constant ops (and may move them into parents regions to
874            // enable more aggressive CSE'ing).
875            let context = self.driver.context.clone();
876            let mut folder = OperationFolder::new(context, Rc::clone(&self.driver));
877            let mut insert_known_constant = |op: OperationRef| {
878                // Check for existing constants when populating the worklist. This avoids
879                // accidentally reversing the constant order during processing.
880                let operation = op.borrow();
881                if let Some(const_value) = crate::matchers::constant().matches(&operation) {
882                    drop(operation);
883                    if !folder.insert_known_constant(op, Some(const_value)) {
884                        return true;
885                    }
886                }
887                false
888            };
889
890            if !self.driver.config.use_top_down_traversal {
891                // Add operations to the worklist in postorder.
892                log::trace!(target: "pattern-rewrite-driver", "adding operations in postorder");
893                self.region.raw_postwalk_all::<Forward, _>(|op| {
894                    if !insert_known_constant(op) {
895                        self.driver.add_to_worklist(op);
896                    }
897                });
898            } else {
899                // Add all nested operations to the worklist in preorder.
900                log::trace!(target: "pattern-rewrite-driver", "adding operations in preorder");
901                self.region
902                    .raw_prewalk::<Forward, _, _>(|op| {
903                        if !insert_known_constant(op) {
904                            self.driver.add_to_worklist(op);
905                            WalkResult::<Report>::Continue(())
906                        } else {
907                            WalkResult::Skip
908                        }
909                    })
910                    .into_result()
911                    .expect("unexpected error occurred while walking region");
912
913                // Reverse the list so our loop processes them in-order
914                self.driver.worklist.borrow_mut().reverse();
915            }
916
917            continue_rewrites = self.driver.clone().process_worklist();
918            log::trace!(
919                target: "pattern-rewrite-driver", "processing of worklist for this iteration has completed, \
920                 changed={continue_rewrites}"
921            );
922
923            // After applying patterns, make sure that the CFG of each of the regions is kept up to
924            // date.
925            if self.driver.config.region_simplification != RegionSimplificationLevel::None {
926                let mut rewriter = PatternRewriter::new_with_listener(
927                    self.driver.context.clone(),
928                    Rc::clone(&self.driver),
929                );
930                continue_rewrites |= Region::simplify_all(
931                    &[self.region],
932                    &mut *rewriter,
933                    self.driver.config.region_simplification,
934                )
935                .is_ok();
936            } else {
937                log::debug!(target: "pattern-rewrite-driver", "region simplification was disabled, skipping simplification rewrites");
938            }
939
940            if !continue_rewrites {
941                log::trace!(target: "pattern-rewrite-driver", "region pattern rewrites have converged");
942                break;
943            }
944        }
945
946        // If `continue_rewrites` is false, then the rewrite converged, i.e. the IR wasn't changed
947        // in the last iteration.
948        if !continue_rewrites {
949            Ok(iteration > 1)
950        } else {
951            Err(iteration > 1)
952        }
953    }
954}
955
956pub struct MultiOpPatternRewriteDriver {
957    driver: Rc<GreedyPatternRewriteDriver>,
958    inner: Rc<MultiOpPatternRewriteDriverImpl>,
959}
960
961struct MultiOpPatternRewriteDriverImpl {
962    surviving_ops: RefCell<SmallSet<OperationRef, 8>>,
963}
964
965impl MultiOpPatternRewriteDriver {
966    pub fn new(
967        context: Rc<Context>,
968        patterns: Rc<FrozenRewritePatternSet>,
969        mut config: GreedyRewriteConfig,
970        ops: &[OperationRef],
971    ) -> Self {
972        let surviving_ops = SmallSet::from_iter(ops.iter().copied());
973        let inner = Rc::new(MultiOpPatternRewriteDriverImpl {
974            surviving_ops: RefCell::new(surviving_ops),
975        });
976        let listener = Rc::new(ForwardingListener::new(config.listener.take(), Rc::clone(&inner)));
977        config.listener = Some(listener);
978
979        let mut driver = GreedyPatternRewriteDriver::new(context.clone(), patterns, config);
980        if driver.config.restrict != GreedyRewriteStrictness::Any {
981            driver.filtered_ops.get_mut().extend(ops.iter().cloned());
982        }
983
984        Self {
985            driver: Rc::new(driver),
986            inner,
987        }
988    }
989
990    pub fn simplify(&mut self, ops: &[OperationRef]) -> Result<bool, bool> {
991        // Populate the initial worklist
992        for op in ops.iter().copied() {
993            self.driver.add_single_op_to_worklist(op);
994        }
995
996        // Process ops on the worklist
997        let changed = self.driver.clone().process_worklist();
998        if self.driver.worklist.borrow().is_empty() {
999            Ok(changed)
1000        } else {
1001            Err(changed)
1002        }
1003    }
1004}
1005
1006impl Listener for MultiOpPatternRewriteDriverImpl {
1007    fn kind(&self) -> crate::ListenerType {
1008        crate::ListenerType::Rewriter
1009    }
1010}
1011impl RewriterListener for MultiOpPatternRewriteDriverImpl {
1012    fn notify_operation_erased(&self, op: OperationRef) {
1013        self.surviving_ops.borrow_mut().remove(&op);
1014    }
1015}
1016
1017#[derive(Default)]
1018struct Worklist(Vec<OperationRef>);
1019impl Worklist {
1020    /// Clear all operations from the worklist
1021    #[inline]
1022    pub fn clear(&mut self) {
1023        self.0.clear()
1024    }
1025
1026    /// Returns true if the worklist is empty
1027    #[inline(always)]
1028    pub fn is_empty(&self) -> bool {
1029        self.0.is_empty()
1030    }
1031
1032    /// Push an operation to the end of the worklist, unless it is already in the worklist.
1033    pub fn push(&mut self, op: OperationRef) {
1034        if self.0.contains(&op) {
1035            return;
1036        }
1037        self.0.push(op);
1038    }
1039
1040    /// Pop the next operation from the worklist
1041    #[inline]
1042    pub fn pop(&mut self) -> Option<OperationRef> {
1043        self.0.pop()
1044    }
1045
1046    /// Remove `op` from the worklist
1047    pub fn remove(&mut self, op: &OperationRef) {
1048        if let Some(index) = self.0.iter().position(|o| o == op) {
1049            self.0.remove(index);
1050        }
1051    }
1052
1053    /// Reverse the worklist
1054    pub fn reverse(&mut self) {
1055        self.0.reverse();
1056    }
1057}