1use alloc::{rc::Rc, vec::Vec};
2use core::cell::RefCell;
3
4use smallvec::SmallVec;
5
6use super::{
7 ForwardingListener, FrozenRewritePatternSet, PatternApplicator, PatternRewriter, Rewriter,
8 RewriterListener,
9};
10use crate::{
11 adt::SmallSet,
12 patterns::{PatternApplicationError, RewritePattern},
13 traits::{ConstantLike, Foldable, IsolatedFromAbove},
14 AttrPrinter, BlockRef, Builder, Context, Forward, InsertionGuard, Listener, OpFoldResult,
15 OperationFolder, OperationRef, ProgramPoint, RawWalk, Region, RegionRef, Report, SourceSpan,
16 Spanned, Value, ValueRef, WalkResult,
17};
18
19pub fn apply_patterns_and_fold_region_greedily(
38 region: RegionRef,
39 patterns: Rc<FrozenRewritePatternSet>,
40 mut config: GreedyRewriteConfig,
41) -> Result<bool, bool> {
42 let context = {
45 let parent_op = region.parent().unwrap().borrow();
46 assert!(
47 parent_op.implements::<dyn IsolatedFromAbove>(),
48 "patterns can only be applied to operations which are isolated from above"
49 );
50 parent_op.context_rc()
51 };
52
53 if config.scope.is_none() {
55 config.scope = Some(region);
56 }
57
58 let mut driver = RegionPatternRewriteDriver::new(context, patterns, config, region);
59 let converged = driver.simplify();
60 if converged.is_err() {
61 if let Some(max_iterations) = driver.driver.config.max_iterations {
62 log::trace!(target: "pattern-rewrite-driver", "pattern rewrite did not converge after scanning {max_iterations} times");
63 } else {
64 log::trace!(target: "pattern-rewrite-driver", "pattern rewrite did not converge");
65 }
66 }
67 converged
68}
69
70pub fn apply_patterns_and_fold_greedily(
93 op: OperationRef,
94 patterns: Rc<FrozenRewritePatternSet>,
95 config: GreedyRewriteConfig,
96) -> Result<bool, bool> {
97 let mut any_region_changed = false;
98 let mut failed = false;
99 let op = op.borrow();
100 let mut cursor = op.regions().front();
101 while let Some(region) = cursor.as_pointer() {
102 cursor.move_next();
103 match apply_patterns_and_fold_region_greedily(region, patterns.clone(), config.clone()) {
104 Ok(region_changed) => {
105 any_region_changed |= region_changed;
106 }
107 Err(region_changed) => {
108 any_region_changed |= region_changed;
109 failed = true;
110 }
111 }
112 }
113
114 if failed {
115 Err(any_region_changed)
116 } else {
117 Ok(any_region_changed)
118 }
119}
120
/// The effect that [`apply_patterns_and_fold`] had on the operations it was given.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ApplyPatternsAndFoldEffect {
    /// No operation was changed.
    None = 0,
    /// At least one operation was changed, but not all of the given operations were erased.
    Changed = 1,
    /// Every one of the given operations was erased.
    Erased = 2,
}

/// Outcome of [`apply_patterns_and_fold`]: `Ok` when the rewrite converged, `Err` otherwise. The
/// payload describes the effect in either case.
pub type ApplyPatternsAndFoldResult =
    Result<ApplyPatternsAndFoldEffect, ApplyPatternsAndFoldEffect>;
134
135pub fn apply_patterns_and_fold(
161 ops: &[OperationRef],
162 patterns: Rc<FrozenRewritePatternSet>,
163 mut config: GreedyRewriteConfig,
164) -> ApplyPatternsAndFoldResult {
165 if ops.is_empty() {
166 return Ok(ApplyPatternsAndFoldEffect::None);
167 }
168
169 if let Some(scope) = config.scope.as_ref() {
171 let all_ops_in_scope = ops.iter().all(|op| scope.borrow().find_ancestor_op(*op).is_some());
173 assert!(all_ops_in_scope, "ops must be within the specified scope");
174 } else {
175 config.scope = Region::find_common_ancestor(ops);
178 }
179
180 let max_rewrites = config.max_rewrites.map(|max| max.get()).unwrap_or(u32::MAX);
182 let context = ops[0].borrow().context_rc();
183 let mut driver = MultiOpPatternRewriteDriver::new(context, patterns, config, ops);
184 let converged = driver.simplify(ops);
185 let changed = match converged.as_ref() {
186 Ok(changed) | Err(changed) => *changed,
187 };
188 let erased = driver.inner.surviving_ops.borrow().is_empty();
189 let effect = if erased {
190 ApplyPatternsAndFoldEffect::Erased
191 } else if changed {
192 ApplyPatternsAndFoldEffect::Changed
193 } else {
194 ApplyPatternsAndFoldEffect::None
195 };
196 if converged.is_ok() {
197 Ok(effect)
198 } else {
199 log::trace!(target: "pattern-rewrite-driver", "pattern rewrite did not converge after {max_rewrites} rewrites");
200 Err(effect)
201 }
202}
203
/// Controls which operations the greedy driver may add to its worklist.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum GreedyRewriteStrictness {
    /// No restriction: any operation may be added to the worklist.
    #[default]
    Any,
    /// The operations the driver was seeded with, plus any operations inserted during the
    /// rewrite, may be added to the worklist.
    ExistingAndNew,
    /// Only the operations the driver was seeded with may be added to the worklist.
    Existing,
}
219
/// How much region simplification the region driver performs after each iteration.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RegionSimplificationLevel {
    /// Region simplification is skipped entirely.
    None,
    /// The default simplification level.
    #[default]
    Normal,
    /// A more aggressive simplification level. NOTE(review): the exact extra transformations are
    /// decided by `Region::simplify_all`; confirm there.
    Aggressive,
}
232
233#[derive(Clone)]
235pub struct GreedyRewriteConfig {
236 listener: Option<Rc<dyn RewriterListener>>,
237 scope: Option<RegionRef>,
242 max_iterations: Option<core::num::NonZeroU32>,
247 max_rewrites: Option<core::num::NonZeroU32>,
249 region_simplification: RegionSimplificationLevel,
253 restrict: GreedyRewriteStrictness,
255 use_top_down_traversal: bool,
265}
266impl Default for GreedyRewriteConfig {
267 fn default() -> Self {
268 Self {
269 listener: None,
270 scope: None,
271 max_iterations: core::num::NonZeroU32::new(10),
272 max_rewrites: None,
273 region_simplification: Default::default(),
274 restrict: Default::default(),
275 use_top_down_traversal: false,
276 }
277 }
278}
279impl GreedyRewriteConfig {
280 pub fn new_with_listener(listener: impl RewriterListener + 'static) -> Self {
281 Self {
282 listener: Some(Rc::new(listener)),
283 ..Default::default()
284 }
285 }
286
287 pub fn with_scope(&mut self, region: RegionRef) -> &mut Self {
289 self.scope = Some(region);
290 self
291 }
292
293 pub fn with_max_iterations(&mut self, max: u32) -> &mut Self {
300 self.max_iterations = core::num::NonZeroU32::new(max);
301 self
302 }
303
304 pub fn with_max_rewrites(&mut self, max: u32) -> &mut Self {
310 self.max_rewrites = core::num::NonZeroU32::new(max);
311 self
312 }
313
314 pub fn with_region_simplification_level(
318 &mut self,
319 level: RegionSimplificationLevel,
320 ) -> &mut Self {
321 self.region_simplification = level;
322 self
323 }
324
325 pub fn with_restrictions(&mut self, level: GreedyRewriteStrictness) -> &mut Self {
327 self.restrict = level;
328 self
329 }
330
331 pub fn with_top_down_traversal(&mut self, yes: bool) -> &mut Self {
334 self.use_top_down_traversal = yes;
335 self
336 }
337
338 #[inline]
339 pub fn scope(&self) -> Option<RegionRef> {
340 self.scope
341 }
342
343 #[inline]
344 pub fn max_iterations(&self) -> Option<core::num::NonZeroU32> {
345 self.max_iterations
346 }
347
348 #[inline]
349 pub fn max_rewrites(&self) -> Option<core::num::NonZeroU32> {
350 self.max_rewrites
351 }
352
353 #[inline]
354 pub fn region_simplification_level(&self) -> RegionSimplificationLevel {
355 self.region_simplification
356 }
357
358 #[inline]
359 pub fn strictness(&self) -> GreedyRewriteStrictness {
360 self.restrict
361 }
362
363 #[inline]
364 pub fn use_top_down_traversal(&self) -> bool {
365 self.use_top_down_traversal
366 }
367}
368
369pub struct GreedyPatternRewriteDriver {
370 context: Rc<Context>,
371 worklist: RefCell<Worklist>,
372 config: GreedyRewriteConfig,
373 filtered_ops: RefCell<SmallSet<OperationRef, 8>>,
375 matcher: RefCell<PatternApplicator>,
376}
377
378impl GreedyPatternRewriteDriver {
379 pub fn new(
380 context: Rc<Context>,
381 patterns: Rc<FrozenRewritePatternSet>,
382 config: GreedyRewriteConfig,
383 ) -> Self {
384 let mut matcher = PatternApplicator::new(patterns);
386 matcher.apply_default_cost_model();
387
388 Self {
389 context,
390 worklist: Default::default(),
391 config,
392 filtered_ops: Default::default(),
393 matcher: RefCell::new(matcher),
394 }
395 }
396}
397
398impl GreedyPatternRewriteDriver {
400 pub fn add_single_op_to_worklist(&self, op: OperationRef) {
402 if matches!(self.config.restrict, GreedyRewriteStrictness::Any)
403 || self.filtered_ops.borrow().contains(&op)
404 {
405 log::trace!(target: "pattern-rewrite-driver", "adding single op '{}' to worklist", op.name());
406 self.worklist.borrow_mut().push(op);
407 } else {
408 log::trace!(
409 target: "pattern-rewrite-driver", "skipped adding single op '{}' to worklist due to strictness level",
410 op.name()
411 );
412 }
413 }
414
415 pub fn add_to_worklist(&self, op: OperationRef) {
417 let mut ancestors = SmallVec::<[OperationRef; 8]>::default();
419 let mut op = Some(op);
420 while let Some(ancestor_op) = op.take() {
421 let region = ancestor_op.grandparent();
422 if self.config.scope.as_ref() == region.as_ref() {
423 ancestors.push(ancestor_op);
424 for op in ancestors {
425 self.add_single_op_to_worklist(op);
426 }
427 return;
428 } else {
429 log::trace!(target: "pattern-rewrite-driver", "gathering ancestors of '{}' for worklist", ancestor_op.name());
430 ancestors.push(ancestor_op);
431 }
432 if let Some(region) = region {
433 op = region.parent();
434 } else {
435 log::trace!(target: "pattern-rewrite-driver", "reached top level op while searching for ancestors");
436 }
437 }
438 }
439
440 pub fn process_worklist(self: Rc<Self>) -> bool {
444 log::debug!(target: "pattern-rewrite-driver", "starting processing of greedy pattern rewrite driver worklist");
445 let mut rewriter =
446 PatternRewriter::new_with_listener(self.context.clone(), Rc::clone(&self));
447
448 let mut changed = false;
449 let mut num_rewrites = 0u32;
450 while self.config.max_rewrites.is_none_or(|max| num_rewrites < max.get()) {
451 let Some(op) = self.worklist.borrow_mut().pop() else {
452 log::debug!(target: "pattern-rewrite-driver", "processing worklist complete, rewrites have converged");
454 return changed;
455 };
456
457 if self.process_worklist_item(&mut rewriter, op) {
458 changed = true;
459 num_rewrites += 1;
460 }
461 }
462
463 log::debug!(
464 target: "pattern-rewrite-driver", "processing worklist was canceled after {} rewrites without converging (reached max \
465 rewrite limit)",
466 self.config.max_rewrites.map(|max| max.get()).unwrap_or(u32::MAX)
467 );
468
469 changed
470 }
471
472 fn process_worklist_item(
476 &self,
477 rewriter: &mut PatternRewriter<Rc<Self>>,
478 mut op_ref: OperationRef,
479 ) -> bool {
480 let op = op_ref.borrow();
481
482 log::trace!(target: "pattern-rewrite-driver", "processing operation '{op}'");
483
484 if op.is_trivially_dead() {
486 drop(op);
487 rewriter.erase_op(op_ref);
488 log::trace!(target: "pattern-rewrite-driver", "processing complete: operation is trivially dead");
489 return true;
490 }
491
492 if !op.implements::<dyn ConstantLike>() {
496 drop(op);
498 let op = op_ref.borrow_mut();
499
500 let mut results = SmallVec::<[OpFoldResult; 1]>::default();
501 log::trace!(target: "pattern-rewrite-driver", "attempting to fold operation..");
502 if op.fold(&mut results).is_ok() {
503 if results.is_empty() {
504 self.notify_operation_modified(op_ref);
506 log::trace!(
507 target: "pattern-rewrite-driver",
508 "operation was succesfully folded/modified in-place"
509 );
510 return true;
511 } else {
512 log::trace!(
513 target: "pattern-rewrite-driver",
514 "operation was succesfully folded away, to be replaced with: {}",
515 results.as_slice().print(&crate::OpPrintingFlags::default(), op.context())
516 );
517 }
518
519 assert_eq!(
521 results.len(),
522 op.num_results(),
523 "folder produced incorrect number of results"
524 );
525 let mut rewriter = InsertionGuard::new(&mut **rewriter);
526 rewriter.set_insertion_point(ProgramPoint::before(op_ref));
527
528 log::trace!(target: "pattern-rewrite-driver", "replacing op with fold results..");
529 let mut replacements = SmallVec::<[Option<ValueRef>; 2]>::default();
530 let mut materialization_succeeded = true;
531 for (fold_result, result_ty) in results
532 .into_iter()
533 .zip(op.results().all().iter().map(|r| r.borrow().ty().clone()))
534 {
535 match fold_result {
536 OpFoldResult::Value(value) => {
537 assert_eq!(
538 value.borrow().ty(),
539 &result_ty,
540 "folder produced value of incorrect type"
541 );
542 replacements.push(Some(value));
543 }
544 OpFoldResult::Attribute(attr) => {
545 let span = op.span();
547 log::trace!(
548 target: "pattern-rewrite-driver",
549 "materializing constant for value '{}' and type '{result_ty}'",
550 attr.print(&crate::OpPrintingFlags::default(), op.context())
551 );
552 let constant_op = op.dialect().materialize_constant(
553 &mut *rewriter,
554 attr,
555 &result_ty,
556 span,
557 );
558 match constant_op {
559 None => {
560 log::trace!(
561 target: "pattern-rewrite-driver",
562 "materialization failed: cleaning up any materialized ops \
563 for {} previous results",
564 replacements.len()
565 );
566 let mut replacement_ops =
568 SmallVec::<[OperationRef; 2]>::default();
569 for replacement in replacements.iter().filter_map(|repl| *repl)
570 {
571 let replacement = replacement.borrow();
572 assert!(
573 !replacement.is_used(),
574 "folder reused existing op for one result, but \
575 constant materialization failed for another result"
576 );
577 let replacement_op = replacement.get_defining_op().unwrap();
578 if replacement_ops.contains(&replacement_op) {
579 continue;
580 }
581 replacement_ops.push(replacement_op);
582 }
583 for replacement_op in replacement_ops {
584 rewriter.erase_op(replacement_op);
585 }
586 materialization_succeeded = false;
587 break;
588 }
589 Some(constant_op) => {
590 let const_op = constant_op.borrow();
591 assert!(
592 const_op.implements::<dyn ConstantLike>(),
593 "materialize_constant produced op that does not implement \
594 ConstantLike"
595 );
596 let result: ValueRef = const_op.results().all()[0].upcast();
597 assert_eq!(
598 result.borrow().ty(),
599 &result_ty,
600 "materialize_constant produced incorrect result type"
601 );
602 log::trace!(
603 target: "pattern-rewrite-driver",
604 "successfully materialized constant as {}",
605 result.borrow().id()
606 );
607 replacements.push(Some(result));
608 }
609 }
610 }
611 }
612 }
613
614 if materialization_succeeded {
615 log::trace!(
616 target: "pattern-rewrite-driver",
617 "materialization of fold results was successful, performing replacement.."
618 );
619 drop(op);
620 rewriter.replace_op_with_values(op_ref, &replacements);
621 log::trace!(
622 target: "pattern-rewrite-driver",
623 "fold succeeded: operation was replaced with materialized constants"
624 );
625 return true;
626 } else {
627 log::trace!(
628 target: "pattern-rewrite-driver",
629 "materialization of fold results failed, proceeding without folding"
630 );
631 }
632 }
633 } else {
634 log::trace!(target: "pattern-rewrite-driver", "operation could not be folded");
635 }
636
637 log::trace!(target: "pattern-rewrite-driver", "attempting to match and rewrite one of the input patterns..");
649 let result = if let Some(listener) = self.config.listener.as_deref() {
650 let op_name = op_ref.borrow().name();
651 let can_apply = |pattern: &dyn RewritePattern| {
652 log::trace!(target: "pattern-rewrite-driver", "applying pattern {} to op {}", pattern.name(), &op_name);
653 listener.notify_pattern_begin(pattern, op_ref);
654 true
655 };
656 let on_failure = |pattern: &dyn RewritePattern| {
657 log::trace!(target: "pattern-rewrite-driver", "pattern failed to match");
658 listener.notify_pattern_end(pattern, false);
659 };
660 let on_success = |pattern: &dyn RewritePattern| {
661 log::trace!(target: "pattern-rewrite-driver", "pattern applied successfully");
662 listener.notify_pattern_end(pattern, true);
663 Ok(())
664 };
665 self.matcher.borrow_mut().match_and_rewrite(
666 op_ref,
667 &mut **rewriter,
668 can_apply,
669 on_failure,
670 on_success,
671 )
672 } else {
673 self.matcher.borrow_mut().match_and_rewrite(
674 op_ref,
675 &mut **rewriter,
676 |_| true,
677 |_| {},
678 |_| Ok(()),
679 )
680 };
681
682 match result {
683 Ok(_) => {
684 log::trace!(target: "pattern-rewrite-driver", "processing complete: pattern matched and operation was rewritten");
685 true
686 }
687 Err(PatternApplicationError::NoMatchesFound) => {
688 log::debug!(target: "pattern-rewrite-driver", "processing complete: exhausted all patterns without finding a match");
689 false
690 }
691 Err(PatternApplicationError::Report(report)) => {
692 log::debug!(
693 target: "pattern-rewrite-driver", "processing complete: error occurred during match and rewrite: {report}"
694 );
695 false
696 }
697 }
698 }
699
700 fn add_operands_to_worklist(&self, op: OperationRef) {
704 let current_op = op.borrow();
705 for operand in current_op.operands().all() {
706 let operand = operand.borrow();
711 let Some(def_op) = operand.value().get_defining_op() else {
712 continue;
713 };
714
715 let mut other_user = None;
716 let mut has_more_than_two_uses = false;
717 for user in operand.value().iter_uses() {
718 if user.owner == op || other_user.as_ref().is_some_and(|ou| ou == &user.owner) {
719 continue;
720 }
721 if other_user.is_none() {
722 other_user = Some(user.owner);
723 continue;
724 }
725 has_more_than_two_uses = true;
726 break;
727 }
728 if !has_more_than_two_uses {
729 self.add_to_worklist(def_op);
730 }
731 }
732 }
733}
734
735impl Listener for GreedyPatternRewriteDriver {
737 fn kind(&self) -> crate::ListenerType {
738 crate::ListenerType::Rewriter
739 }
740
741 fn notify_block_inserted(
743 &self,
744 block: crate::BlockRef,
745 prev: Option<RegionRef>,
746 ip: Option<crate::BlockRef>,
747 ) {
748 if let Some(listener) = self.config.listener.as_deref() {
749 listener.notify_block_inserted(block, prev, ip);
750 }
751 }
752
753 fn notify_operation_inserted(&self, op: OperationRef, prev: ProgramPoint) {
757 if let Some(listener) = self.config.listener.as_deref() {
758 listener.notify_operation_inserted(op, prev);
759 }
760 if matches!(self.config.restrict, GreedyRewriteStrictness::ExistingAndNew) {
761 self.filtered_ops.borrow_mut().insert(op);
762 }
763 self.add_to_worklist(op);
764 }
765}
766impl RewriterListener for GreedyPatternRewriteDriver {
767 fn notify_block_erased(&self, block: BlockRef) {
769 if let Some(listener) = self.config.listener.as_deref() {
770 listener.notify_block_erased(block);
771 }
772 }
773
774 fn notify_operation_modified(&self, op: OperationRef) {
777 if let Some(listener) = self.config.listener.as_deref() {
778 listener.notify_operation_modified(op);
779 }
780 self.add_to_worklist(op);
781 }
782
783 fn notify_operation_erased(&self, op: OperationRef) {
787 if let Some(scope) = self.config.scope.as_ref() {
793 assert!(
794 scope.parent().is_some_and(|parent_op| parent_op != op),
795 "scope region must not be erased during greedy pattern rewrite"
796 );
797 }
798
799 if let Some(listener) = self.config.listener.as_deref() {
800 listener.notify_operation_erased(op);
801 }
802
803 self.add_operands_to_worklist(op);
804 self.worklist.borrow_mut().remove(&op);
805
806 if self.config.restrict != GreedyRewriteStrictness::Any {
807 self.filtered_ops.borrow_mut().remove(&op);
808 }
809 }
810
811 fn notify_operation_replaced_with_values(
815 &self,
816 op: OperationRef,
817 replacement: &[Option<ValueRef>],
818 ) {
819 if let Some(listener) = self.config.listener.as_deref() {
820 listener.notify_operation_replaced_with_values(op, replacement);
821 }
822 }
823
824 fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
825 if let Some(listener) = self.config.listener.as_deref() {
826 listener.notify_match_failure(span, reason);
827 }
828 }
829}
830
831pub struct RegionPatternRewriteDriver {
832 driver: Rc<GreedyPatternRewriteDriver>,
833 region: RegionRef,
834}
835impl RegionPatternRewriteDriver {
836 pub fn new(
837 context: Rc<Context>,
838 patterns: Rc<FrozenRewritePatternSet>,
839 config: GreedyRewriteConfig,
840 region: RegionRef,
841 ) -> Self {
842 let mut driver = GreedyPatternRewriteDriver::new(context, patterns, config);
843 if driver.config.restrict != GreedyRewriteStrictness::Any {
845 let filtered_ops = driver.filtered_ops.get_mut();
846 region.raw_postwalk_all::<Forward, _>(|op| {
847 filtered_ops.insert(op);
848 });
849 }
850 Self {
851 driver: Rc::new(driver),
852 region,
853 }
854 }
855
856 pub fn simplify(&mut self) -> Result<bool, bool> {
861 use crate::matchers::Matcher;
862
863 let mut continue_rewrites = false;
864 let mut iteration = 0;
865
866 while self.driver.config.max_iterations.is_none_or(|max| iteration < max.get()) {
867 log::trace!(target: "pattern-rewrite-driver", "starting iteration {iteration} of region pattern rewrite driver");
868 iteration += 1;
869
870 self.driver.worklist.borrow_mut().clear();
872
873 let context = self.driver.context.clone();
876 let mut folder = OperationFolder::new(context, Rc::clone(&self.driver));
877 let mut insert_known_constant = |op: OperationRef| {
878 let operation = op.borrow();
881 if let Some(const_value) = crate::matchers::constant().matches(&operation) {
882 drop(operation);
883 if !folder.insert_known_constant(op, Some(const_value)) {
884 return true;
885 }
886 }
887 false
888 };
889
890 if !self.driver.config.use_top_down_traversal {
891 log::trace!(target: "pattern-rewrite-driver", "adding operations in postorder");
893 self.region.raw_postwalk_all::<Forward, _>(|op| {
894 if !insert_known_constant(op) {
895 self.driver.add_to_worklist(op);
896 }
897 });
898 } else {
899 log::trace!(target: "pattern-rewrite-driver", "adding operations in preorder");
901 self.region
902 .raw_prewalk::<Forward, _, _>(|op| {
903 if !insert_known_constant(op) {
904 self.driver.add_to_worklist(op);
905 WalkResult::<Report>::Continue(())
906 } else {
907 WalkResult::Skip
908 }
909 })
910 .into_result()
911 .expect("unexpected error occurred while walking region");
912
913 self.driver.worklist.borrow_mut().reverse();
915 }
916
917 continue_rewrites = self.driver.clone().process_worklist();
918 log::trace!(
919 target: "pattern-rewrite-driver", "processing of worklist for this iteration has completed, \
920 changed={continue_rewrites}"
921 );
922
923 if self.driver.config.region_simplification != RegionSimplificationLevel::None {
926 let mut rewriter = PatternRewriter::new_with_listener(
927 self.driver.context.clone(),
928 Rc::clone(&self.driver),
929 );
930 continue_rewrites |= Region::simplify_all(
931 &[self.region],
932 &mut *rewriter,
933 self.driver.config.region_simplification,
934 )
935 .is_ok();
936 } else {
937 log::debug!(target: "pattern-rewrite-driver", "region simplification was disabled, skipping simplification rewrites");
938 }
939
940 if !continue_rewrites {
941 log::trace!(target: "pattern-rewrite-driver", "region pattern rewrites have converged");
942 break;
943 }
944 }
945
946 if !continue_rewrites {
949 Ok(iteration > 1)
950 } else {
951 Err(iteration > 1)
952 }
953 }
954}
955
956pub struct MultiOpPatternRewriteDriver {
957 driver: Rc<GreedyPatternRewriteDriver>,
958 inner: Rc<MultiOpPatternRewriteDriverImpl>,
959}
960
961struct MultiOpPatternRewriteDriverImpl {
962 surviving_ops: RefCell<SmallSet<OperationRef, 8>>,
963}
964
965impl MultiOpPatternRewriteDriver {
966 pub fn new(
967 context: Rc<Context>,
968 patterns: Rc<FrozenRewritePatternSet>,
969 mut config: GreedyRewriteConfig,
970 ops: &[OperationRef],
971 ) -> Self {
972 let surviving_ops = SmallSet::from_iter(ops.iter().copied());
973 let inner = Rc::new(MultiOpPatternRewriteDriverImpl {
974 surviving_ops: RefCell::new(surviving_ops),
975 });
976 let listener = Rc::new(ForwardingListener::new(config.listener.take(), Rc::clone(&inner)));
977 config.listener = Some(listener);
978
979 let mut driver = GreedyPatternRewriteDriver::new(context.clone(), patterns, config);
980 if driver.config.restrict != GreedyRewriteStrictness::Any {
981 driver.filtered_ops.get_mut().extend(ops.iter().cloned());
982 }
983
984 Self {
985 driver: Rc::new(driver),
986 inner,
987 }
988 }
989
990 pub fn simplify(&mut self, ops: &[OperationRef]) -> Result<bool, bool> {
991 for op in ops.iter().copied() {
993 self.driver.add_single_op_to_worklist(op);
994 }
995
996 let changed = self.driver.clone().process_worklist();
998 if self.driver.worklist.borrow().is_empty() {
999 Ok(changed)
1000 } else {
1001 Err(changed)
1002 }
1003 }
1004}
1005
1006impl Listener for MultiOpPatternRewriteDriverImpl {
1007 fn kind(&self) -> crate::ListenerType {
1008 crate::ListenerType::Rewriter
1009 }
1010}
1011impl RewriterListener for MultiOpPatternRewriteDriverImpl {
1012 fn notify_operation_erased(&self, op: OperationRef) {
1013 self.surviving_ops.borrow_mut().remove(&op);
1014 }
1015}
1016
1017#[derive(Default)]
1018struct Worklist(Vec<OperationRef>);
1019impl Worklist {
1020 #[inline]
1022 pub fn clear(&mut self) {
1023 self.0.clear()
1024 }
1025
1026 #[inline(always)]
1028 pub fn is_empty(&self) -> bool {
1029 self.0.is_empty()
1030 }
1031
1032 pub fn push(&mut self, op: OperationRef) {
1034 if self.0.contains(&op) {
1035 return;
1036 }
1037 self.0.push(op);
1038 }
1039
1040 #[inline]
1042 pub fn pop(&mut self) -> Option<OperationRef> {
1043 self.0.pop()
1044 }
1045
1046 pub fn remove(&mut self, op: &OperationRef) {
1048 if let Some(index) = self.0.iter().position(|o| o == op) {
1049 self.0.remove(index);
1050 }
1051 }
1052
1053 pub fn reverse(&mut self) {
1055 self.0.reverse();
1056 }
1057}