midenc_hir_transform/
canonicalization.rs1use alloc::{boxed::Box, format, rc::Rc};
2
3use midenc_hir::{
4 pass::{OperationPass, Pass, PassExecutionState, PostPassStatus},
5 patterns::{self, FrozenRewritePatternSet, GreedyRewriteConfig, RewritePatternSet},
6 Context, EntityMut, Operation, OperationName, Report, Spanned,
7};
8use midenc_session::diagnostics::Severity;
9
10pub struct Canonicalizer {
17 config: GreedyRewriteConfig,
18 rewrites: Option<Rc<FrozenRewritePatternSet>>,
19 require_convergence: bool,
20}
21
22impl Default for Canonicalizer {
23 fn default() -> Self {
24 let mut config = GreedyRewriteConfig::default();
25 config.with_top_down_traversal(true);
26 Self {
27 config,
28 rewrites: None,
29 require_convergence: false,
30 }
31 }
32}
33
34impl Canonicalizer {
35 pub fn new(config: GreedyRewriteConfig, require_convergence: bool) -> Self {
36 Self {
37 config,
38 rewrites: None,
39 require_convergence,
40 }
41 }
42
43 pub fn create() -> Box<dyn OperationPass> {
45 Box::new(Self::default())
46 }
47
48 pub fn create_with_config(config: &GreedyRewriteConfig) -> Box<dyn OperationPass> {
50 Box::new(Self {
51 config: config.clone(),
52 rewrites: None,
53 require_convergence: false,
54 })
55 }
56}
57
58impl Pass for Canonicalizer {
59 type Target = Operation;
60
61 fn name(&self) -> &'static str {
62 "canonicalizer"
63 }
64
65 fn argument(&self) -> &'static str {
66 "canonicalizer"
67 }
68
69 fn description(&self) -> &'static str {
70 "Performs canonicalization over a set of operations"
71 }
72
73 fn can_schedule_on(&self, _name: &OperationName) -> bool {
74 true
75 }
76
77 fn initialize(&mut self, context: Rc<Context>) -> Result<(), Report> {
78 log::trace!(target: "canonicalization", "initializing canonicalizer pass");
79 let mut rewrites = RewritePatternSet::new(context.clone());
80
81 for dialect in context.registered_dialects().values() {
82 for op in dialect.registered_ops().iter() {
83 op.populate_canonicalization_patterns(&mut rewrites, context.clone());
84 }
85 }
86
87 self.rewrites = Some(Rc::new(FrozenRewritePatternSet::new(rewrites)));
88
89 Ok(())
90 }
91
92 fn run_on_operation(
93 &mut self,
94 op: EntityMut<'_, Self::Target>,
95 state: &mut PassExecutionState,
96 ) -> Result<(), Report> {
97 let Some(rewrites) = self.rewrites.as_ref() else {
98 log::debug!(target: "canonicalization", "skipping canonicalization as there are no rewrite patterns to apply");
99 state.set_post_pass_status(PostPassStatus::Unchanged);
100 return Ok(());
101 };
102 let op = {
103 let ptr = op.as_operation_ref();
104 drop(op);
105 log::debug!(target: "canonicalization", "applying canonicalization to {}", ptr.borrow());
106 log::debug!(target: "canonicalization", " require_convergence = {}", self.require_convergence);
107 ptr
108 };
109 let converged =
110 patterns::apply_patterns_and_fold_greedily(op, rewrites.clone(), self.config.clone());
111 if self.require_convergence && converged.is_err() {
112 log::debug!(target: "canonicalization", "canonicalization could not converge");
113 let span = op.borrow().span();
114 return Err(state
115 .context()
116 .diagnostics()
117 .diagnostic(Severity::Error)
118 .with_message("canonicalization failed")
119 .with_primary_label(
120 span,
121 format!(
122 "canonicalization did not converge{}",
123 self.config
124 .max_iterations()
125 .map(|max| format!(" after {max} iterations"))
126 .unwrap_or_default()
127 ),
128 )
129 .into_report());
130 }
131
132 let op = op.borrow();
133 let changed = match converged {
134 Ok(changed) => {
135 log::debug!(target: "canonicalization", "canonicalization converged for '{}', changed={changed}", op.name());
136 changed
137 }
138 Err(changed) => {
139 log::warn!(
140 target: "canonicalization",
141 "canonicalization failed to converge for '{}', changed={changed}",
142 op.name()
143 );
144 changed
145 }
146 };
147 let ir_changed = changed.into();
148 state.set_post_pass_status(ir_changed);
149
150 Ok(())
151 }
152}