midenc_hir_transform/
canonicalization.rs

1use alloc::{boxed::Box, format, rc::Rc};
2
3use midenc_hir::{
4    pass::{OperationPass, Pass, PassExecutionState, PostPassStatus},
5    patterns::{self, FrozenRewritePatternSet, GreedyRewriteConfig, RewritePatternSet},
6    Context, EntityMut, Operation, OperationName, Report, Spanned,
7};
8use midenc_session::diagnostics::Severity;
9
10/// This pass performs various types of canonicalizations over a set of operations by iteratively
11/// applying the canonicalization patterns of all loaded dialects until either a fixpoint is reached
12/// or the maximum number of iterations/rewrites is exhausted. Canonicalization is best-effort and
13/// does not guarantee that the entire IR is in a canonical form after running this pass.
14///
15/// See the docs for [crate::traits::Canonicalizable] for more details.
16pub struct Canonicalizer {
17    config: GreedyRewriteConfig,
18    rewrites: Option<Rc<FrozenRewritePatternSet>>,
19    require_convergence: bool,
20}
21
22impl Default for Canonicalizer {
23    fn default() -> Self {
24        let mut config = GreedyRewriteConfig::default();
25        config.with_top_down_traversal(true);
26        Self {
27            config,
28            rewrites: None,
29            require_convergence: false,
30        }
31    }
32}
33
34impl Canonicalizer {
35    pub fn new(config: GreedyRewriteConfig, require_convergence: bool) -> Self {
36        Self {
37            config,
38            rewrites: None,
39            require_convergence,
40        }
41    }
42
43    /// Creates an instance of this pass, configured with default settings.
44    pub fn create() -> Box<dyn OperationPass> {
45        Box::new(Self::default())
46    }
47
48    /// Creates an instance of this pass with the specified config.
49    pub fn create_with_config(config: &GreedyRewriteConfig) -> Box<dyn OperationPass> {
50        Box::new(Self {
51            config: config.clone(),
52            rewrites: None,
53            require_convergence: false,
54        })
55    }
56}
57
58impl Pass for Canonicalizer {
59    type Target = Operation;
60
61    fn name(&self) -> &'static str {
62        "canonicalizer"
63    }
64
65    fn argument(&self) -> &'static str {
66        "canonicalizer"
67    }
68
69    fn description(&self) -> &'static str {
70        "Performs canonicalization over a set of operations"
71    }
72
73    fn can_schedule_on(&self, _name: &OperationName) -> bool {
74        true
75    }
76
77    fn initialize(&mut self, context: Rc<Context>) -> Result<(), Report> {
78        log::trace!(target: "canonicalization", "initializing canonicalizer pass");
79        let mut rewrites = RewritePatternSet::new(context.clone());
80
81        for dialect in context.registered_dialects().values() {
82            for op in dialect.registered_ops().iter() {
83                op.populate_canonicalization_patterns(&mut rewrites, context.clone());
84            }
85        }
86
87        self.rewrites = Some(Rc::new(FrozenRewritePatternSet::new(rewrites)));
88
89        Ok(())
90    }
91
92    fn run_on_operation(
93        &mut self,
94        op: EntityMut<'_, Self::Target>,
95        state: &mut PassExecutionState,
96    ) -> Result<(), Report> {
97        let Some(rewrites) = self.rewrites.as_ref() else {
98            log::debug!(target: "canonicalization", "skipping canonicalization as there are no rewrite patterns to apply");
99            state.set_post_pass_status(PostPassStatus::Unchanged);
100            return Ok(());
101        };
102        let op = {
103            let ptr = op.as_operation_ref();
104            drop(op);
105            log::debug!(target: "canonicalization", "applying canonicalization to {}", ptr.borrow());
106            log::debug!(target: "canonicalization", "  require_convergence = {}", self.require_convergence);
107            ptr
108        };
109        let converged =
110            patterns::apply_patterns_and_fold_greedily(op, rewrites.clone(), self.config.clone());
111        if self.require_convergence && converged.is_err() {
112            log::debug!(target: "canonicalization", "canonicalization could not converge");
113            let span = op.borrow().span();
114            return Err(state
115                .context()
116                .diagnostics()
117                .diagnostic(Severity::Error)
118                .with_message("canonicalization failed")
119                .with_primary_label(
120                    span,
121                    format!(
122                        "canonicalization did not converge{}",
123                        self.config
124                            .max_iterations()
125                            .map(|max| format!(" after {max} iterations"))
126                            .unwrap_or_default()
127                    ),
128                )
129                .into_report());
130        }
131
132        let op = op.borrow();
133        let changed = match converged {
134            Ok(changed) => {
135                log::debug!(target: "canonicalization", "canonicalization converged for '{}', changed={changed}", op.name());
136                changed
137            }
138            Err(changed) => {
139                log::warn!(
140                    target: "canonicalization",
141                    "canonicalization failed to converge for '{}', changed={changed}",
142                    op.name()
143                );
144                changed
145            }
146        };
147        let ir_changed = changed.into();
148        state.set_post_pass_status(ir_changed);
149
150        Ok(())
151    }
152}