sway_ir/
pass_manager.rs

1use crate::{
2    create_arg_demotion_pass, create_arg_pointee_mutability_tagger_pass, create_ccp_pass,
3    create_const_demotion_pass, create_const_folding_pass, create_cse_pass, create_dce_pass,
4    create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass,
5    create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass,
6    create_fn_inline_pass, create_globals_dce_pass, create_mem2reg_pass, create_memcpyopt_pass,
7    create_memcpyprop_reverse_pass, create_misc_demotion_pass, create_module_printer_pass,
8    create_module_verifier_pass, create_postorder_pass, create_ret_demotion_pass,
9    create_simplify_cfg_pass, create_sroa_pass, Context, Function, IrError, Module,
10    ARG_DEMOTION_NAME, ARG_POINTEE_MUTABILITY_TAGGER_NAME, CCP_NAME, CONST_DEMOTION_NAME,
11    CONST_FOLDING_NAME, CSE_NAME, DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME,
12    FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, GLOBALS_DCE_NAME, MEM2REG_NAME, MEMCPYOPT_NAME,
13    MEMCPYPROP_REVERSE_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME, SROA_NAME,
14};
15use downcast_rs::{impl_downcast, Downcast};
16use rustc_hash::FxHashMap;
17use std::{
18    any::{type_name, TypeId},
19    collections::{hash_map, HashSet},
20};
21
/// Result of an analysis. Specific result must be downcasted to.
///
/// Results are stored type-erased in [AnalysisResults] and recovered with the downcasting
/// methods that `impl_downcast!` generates below.
pub trait AnalysisResultT: Downcast {}
// Generates the downcasting methods (e.g. `downcast_ref`) on `dyn AnalysisResultT`.
impl_downcast!(AnalysisResultT);
/// A boxed, type-erased analysis result, as returned by analysis passes.
pub type AnalysisResult = Box<dyn AnalysisResultT>;
26
27/// Program scope over which a pass executes.
28pub trait PassScope {
29    fn get_arena_idx(&self) -> slotmap::DefaultKey;
30}
31impl PassScope for Module {
32    fn get_arena_idx(&self) -> slotmap::DefaultKey {
33        self.0
34    }
35}
36impl PassScope for Function {
37    fn get_arena_idx(&self) -> slotmap::DefaultKey {
38        self.0
39    }
40}
41
/// Is a pass an Analysis or a Transformation over the IR?
///
/// Each variant holds the function that executes the pass on a scope `S`. Both kinds
/// receive the analysis results that have been computed so far.
#[derive(Clone)]
pub enum PassMutability<S: PassScope> {
    /// An analysis pass, producing an analysis result.
    ///
    /// It only gets shared access to the IR; the pass manager caches the returned result.
    Analysis(fn(&Context, analyses: &AnalysisResults, S) -> Result<AnalysisResult, IrError>),
    /// A pass over the IR that can possibly modify it.
    ///
    /// Returns `true` if the IR was modified, in which case the pass manager invalidates
    /// cached analysis results for the affected scope(s).
    Transform(fn(&mut Context, analyses: &AnalysisResults, S) -> Result<bool, IrError>),
}
50
/// A concrete version of [PassScope].
///
/// Decides whether a pass is run once per module or once per function of each module.
#[derive(Clone)]
pub enum ScopedPass {
    /// A pass executed on each [Module].
    ModulePass(PassMutability<Module>),
    /// A pass executed on each [Function].
    FunctionPass(PassMutability<Function>),
}
57
/// An analysis or transformation pass.
pub struct Pass {
    /// Pass identifier.
    ///
    /// This is the key under which the pass is registered in a [PassManager], and the
    /// name used to refer to it in a [PassGroup].
    pub name: &'static str,
    /// A short description.
    pub descr: &'static str,
    /// Other passes that this pass depends on.
    ///
    /// [PassManager::register] requires these to be already-registered analysis passes.
    /// When this pass runs, any of their results that are not cached are computed first.
    pub deps: Vec<&'static str>,
    /// The executor.
    pub runner: ScopedPass,
}
69
70impl Pass {
71    pub fn is_analysis(&self) -> bool {
72        match &self.runner {
73            ScopedPass::ModulePass(pm) => matches!(pm, PassMutability::Analysis(_)),
74            ScopedPass::FunctionPass(pm) => matches!(pm, PassMutability::Analysis(_)),
75        }
76    }
77
78    pub fn is_transform(&self) -> bool {
79        !self.is_analysis()
80    }
81
82    pub fn is_module_pass(&self) -> bool {
83        matches!(self.runner, ScopedPass::ModulePass(_))
84    }
85
86    pub fn is_function_pass(&self) -> bool {
87        matches!(self.runner, ScopedPass::FunctionPass(_))
88    }
89}
90
/// Cache of analysis results, keyed by result type and by the scope they were computed for.
#[derive(Default)]
pub struct AnalysisResults {
    // Map from (result type, (scope type, scope arena key)) to the result. The result type
    // is the `TypeId` of the concrete `AnalysisResultT` implementor; the scope type is the
    // `TypeId` of the `PassScope` implementor (e.g. `Module` or `Function`).
    results: FxHashMap<(TypeId, (TypeId, slotmap::DefaultKey)), AnalysisResult>,
    // Map from analysis pass name to the `TypeId` of the result it produced, so results can
    // be looked up by pass name.
    name_typeid_map: FxHashMap<&'static str, TypeId>,
}
97
98impl AnalysisResults {
99    /// Get the results of an analysis.
100    /// Example analyses.get_analysis_result::<DomTreeAnalysis>(foo).
101    pub fn get_analysis_result<T: AnalysisResultT, S: PassScope + 'static>(&self, scope: S) -> &T {
102        self.results
103            .get(&(
104                TypeId::of::<T>(),
105                (TypeId::of::<S>(), scope.get_arena_idx()),
106            ))
107            .unwrap_or_else(|| {
108                panic!(
109                    "Internal error. Analysis result {} unavailable for {} with idx {:?}",
110                    type_name::<T>(),
111                    type_name::<S>(),
112                    scope.get_arena_idx()
113                )
114            })
115            .downcast_ref()
116            .expect("AnalysisResult: Incorrect type")
117    }
118
119    /// Is an analysis result available at the given scope?
120    fn is_analysis_result_available<S: PassScope + 'static>(
121        &self,
122        name: &'static str,
123        scope: S,
124    ) -> bool {
125        self.name_typeid_map
126            .get(name)
127            .and_then(|result_typeid| {
128                self.results
129                    .get(&(*result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())))
130            })
131            .is_some()
132    }
133
134    /// Add a new result.
135    fn add_result<S: PassScope + 'static>(
136        &mut self,
137        name: &'static str,
138        scope: S,
139        result: AnalysisResult,
140    ) {
141        let result_typeid = (*result).type_id();
142        self.results.insert(
143            (result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())),
144            result,
145        );
146        self.name_typeid_map.insert(name, result_typeid);
147    }
148
149    /// Invalidate all results at a given scope.
150    fn invalidate_all_results_at_scope<S: PassScope + 'static>(&mut self, scope: S) {
151        self.results
152            .retain(|(_result_typeid, (scope_typeid, scope_idx)), _v| {
153                (*scope_typeid, *scope_idx) != (TypeId::of::<S>(), scope.get_arena_idx())
154            });
155    }
156}
157
/// Options for printing [Pass]es in case of running them with printing requested.
///
/// Note that states of IR can always be printed by injecting the module printer pass
/// and just running the passes. That approach however offers less control over the
/// printing. E.g., requiring the printing to happen only if the previous passes
/// modified the IR cannot be done by simply injecting a module printer.
///
/// Empty IRs are never printed.
#[derive(Debug)]
pub struct PrintPassesOpts {
    /// Print the IR before any pass runs.
    pub initial: bool,
    /// Print the IR after all passes have run.
    pub r#final: bool,
    /// Print the IR after a listed pass only if that pass modified it.
    /// Does not affect `initial` and `r#final`.
    pub modified_only: bool,
    /// Names of the passes after which the IR is printed.
    pub passes: HashSet<String>,
}
171
/// Options for verifying [Pass]es in case of running them with verifying requested.
///
/// Note that states of IR can always be verified by injecting the module verifier pass
/// and just running the passes. That approach however offers less control over the
/// verification. E.g., requiring the verification to happen only if the previous passes
/// modified the IR cannot be done by simply injecting a module verifier.
#[derive(Debug)]
pub struct VerifyPassesOpts {
    /// Verify the IR before any pass runs.
    pub initial: bool,
    /// Verify the IR after all passes have run.
    pub r#final: bool,
    /// Verify the IR after a listed pass only if that pass modified it.
    /// Does not affect `initial` and `r#final`.
    pub modified_only: bool,
    /// Names of the passes after which the IR is verified.
    pub passes: HashSet<String>,
}
185
/// Registry of [Pass]es that runs [PassGroup]s over the IR, caching analysis results
/// between passes.
#[derive(Default)]
pub struct PassManager {
    // Registered passes, keyed by pass name.
    passes: FxHashMap<&'static str, Pass>,
    // Analysis results cached across pass runs; invalidated when a transform modifies the IR.
    analyses: AnalysisResults,
}
191
192impl PassManager {
193    pub const OPTIMIZATION_PASSES: [&'static str; 15] = [
194        FN_INLINE_NAME,
195        SIMPLIFY_CFG_NAME,
196        SROA_NAME,
197        DCE_NAME,
198        GLOBALS_DCE_NAME,
199        FN_DEDUP_RELEASE_PROFILE_NAME,
200        FN_DEDUP_DEBUG_PROFILE_NAME,
201        MEM2REG_NAME,
202        MEMCPYOPT_NAME,
203        MEMCPYPROP_REVERSE_NAME,
204        CONST_FOLDING_NAME,
205        ARG_DEMOTION_NAME,
206        CONST_DEMOTION_NAME,
207        RET_DEMOTION_NAME,
208        MISC_DEMOTION_NAME,
209    ];
210
211    /// Register a pass. Should be called only once for each pass.
212    pub fn register(&mut self, pass: Pass) -> &'static str {
213        for dep in &pass.deps {
214            if let Some(dep_t) = self.lookup_registered_pass(dep) {
215                if dep_t.is_transform() {
216                    panic!(
217                        "Pass {} cannot depend on a transformation pass {}",
218                        pass.name, dep
219                    );
220                }
221                if pass.is_function_pass() && dep_t.is_module_pass() {
222                    panic!(
223                        "Function pass {} cannot depend on module pass {}",
224                        pass.name, dep
225                    );
226                }
227            } else {
228                panic!(
229                    "Pass {} depends on a (yet) unregistered pass {}",
230                    pass.name, dep
231                );
232            }
233        }
234        let pass_name = pass.name;
235        match self.passes.entry(pass.name) {
236            hash_map::Entry::Occupied(_) => {
237                panic!("Trying to register an already registered pass");
238            }
239            hash_map::Entry::Vacant(entry) => {
240                entry.insert(pass);
241            }
242        }
243        pass_name
244    }
245
246    fn actually_run(&mut self, ir: &mut Context, pass: &'static str) -> Result<bool, IrError> {
247        let mut modified = false;
248
249        fn run_module_pass(
250            pm: &mut PassManager,
251            ir: &mut Context,
252            pass: &'static str,
253            module: Module,
254        ) -> Result<bool, IrError> {
255            let mut modified = false;
256            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
257            for dep in pass_t.deps.clone() {
258                let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
259                // If pass registration allows transformations as dependents, we could remove this I guess.
260                assert!(dep_t.is_analysis());
261                match dep_t.runner {
262                    ScopedPass::ModulePass(_) => {
263                        if !pm.analyses.is_analysis_result_available(dep, module) {
264                            run_module_pass(pm, ir, dep, module)?;
265                        }
266                    }
267                    ScopedPass::FunctionPass(_) => {
268                        for f in module.function_iter(ir) {
269                            if !pm.analyses.is_analysis_result_available(dep, f) {
270                                run_function_pass(pm, ir, dep, f)?;
271                            }
272                        }
273                    }
274                }
275            }
276
277            // Get the pass again to satisfy the borrow checker.
278            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
279            let ScopedPass::ModulePass(mp) = pass_t.runner.clone() else {
280                panic!("Expected a module pass");
281            };
282            match mp {
283                PassMutability::Analysis(analysis) => {
284                    let result = analysis(ir, &pm.analyses, module)?;
285                    pm.analyses.add_result(pass, module, result);
286                }
287                PassMutability::Transform(transform) => {
288                    if transform(ir, &pm.analyses, module)? {
289                        pm.analyses.invalidate_all_results_at_scope(module);
290                        for f in module.function_iter(ir) {
291                            pm.analyses.invalidate_all_results_at_scope(f);
292                        }
293                        modified = true;
294                    }
295                }
296            }
297
298            Ok(modified)
299        }
300
301        fn run_function_pass(
302            pm: &mut PassManager,
303            ir: &mut Context,
304            pass: &'static str,
305            function: Function,
306        ) -> Result<bool, IrError> {
307            let mut modified = false;
308            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
309            for dep in pass_t.deps.clone() {
310                let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
311                // If pass registration allows transformations as dependents, we could remove this I guess.
312                assert!(dep_t.is_analysis());
313                match dep_t.runner {
314                    ScopedPass::ModulePass(_) => {
315                        panic!("Function pass {pass} cannot depend on module pass {dep}")
316                    }
317                    ScopedPass::FunctionPass(_) => {
318                        if !pm.analyses.is_analysis_result_available(dep, function) {
319                            run_function_pass(pm, ir, dep, function)?;
320                        };
321                    }
322                }
323            }
324
325            // Get the pass again to satisfy the borrow checker.
326            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
327            let ScopedPass::FunctionPass(fp) = pass_t.runner.clone() else {
328                panic!("Expected a function pass");
329            };
330            match fp {
331                PassMutability::Analysis(analysis) => {
332                    let result = analysis(ir, &pm.analyses, function)?;
333                    pm.analyses.add_result(pass, function, result);
334                }
335                PassMutability::Transform(transform) => {
336                    if transform(ir, &pm.analyses, function)? {
337                        pm.analyses.invalidate_all_results_at_scope(function);
338                        modified = true;
339                    }
340                }
341            }
342
343            Ok(modified)
344        }
345
346        for m in ir.module_iter() {
347            let pass_t = self.passes.get(pass).expect("Unregistered pass");
348            let pass_runner = pass_t.runner.clone();
349            match pass_runner {
350                ScopedPass::ModulePass(_) => {
351                    modified |= run_module_pass(self, ir, pass, m)?;
352                }
353                ScopedPass::FunctionPass(_) => {
354                    for f in m.function_iter(ir) {
355                        modified |= run_function_pass(self, ir, pass, f)?;
356                    }
357                }
358            }
359        }
360        Ok(modified)
361    }
362
363    /// Run the `passes` and return true if the `passes` modify the initial `ir`.
364    pub fn run(&mut self, ir: &mut Context, passes: &PassGroup) -> Result<bool, IrError> {
365        let mut modified = false;
366        for pass in passes.flatten_pass_group() {
367            modified |= self.actually_run(ir, pass)?;
368        }
369        Ok(modified)
370    }
371
372    /// Run the `passes` and return true if the `passes` modify the initial `ir`.
373    /// The IR states are printed and verified according to the options provided.
374    pub fn run_with_print_verify(
375        &mut self,
376        ir: &mut Context,
377        passes: &PassGroup,
378        print_opts: &PrintPassesOpts,
379        verify_opts: &VerifyPassesOpts,
380    ) -> Result<bool, IrError> {
381        // Empty IRs are result of compiling dependencies. We don't want to print those.
382        fn ir_is_empty(ir: &Context) -> bool {
383            ir.functions.is_empty()
384                && ir.blocks.is_empty()
385                && ir.values.is_empty()
386                && ir.local_vars.is_empty()
387        }
388
389        fn print_ir_after_pass(ir: &Context, pass: &Pass) {
390            if !ir_is_empty(ir) {
391                println!("// IR: [{}] {}", pass.name, pass.descr);
392                println!("{ir}");
393            }
394        }
395
396        fn print_initial_or_final_ir(ir: &Context, initial_or_final: &'static str) {
397            if !ir_is_empty(ir) {
398                println!("// IR: {initial_or_final}");
399                println!("{ir}");
400            }
401        }
402
403        if print_opts.initial {
404            print_initial_or_final_ir(ir, "Initial");
405        }
406
407        if verify_opts.initial {
408            ir.verify()?;
409        }
410
411        let mut modified = false;
412        for pass in passes.flatten_pass_group() {
413            let modified_in_pass = self.actually_run(ir, pass)?;
414
415            if print_opts.passes.contains(pass) && (!print_opts.modified_only || modified_in_pass) {
416                print_ir_after_pass(ir, self.lookup_registered_pass(pass).unwrap());
417            }
418
419            modified |= modified_in_pass;
420            if verify_opts.passes.contains(pass) && (!verify_opts.modified_only || modified_in_pass)
421            {
422                ir.verify()?;
423            }
424        }
425
426        if print_opts.r#final {
427            print_initial_or_final_ir(ir, "Final");
428        }
429
430        if verify_opts.r#final {
431            ir.verify()?;
432        }
433
434        Ok(modified)
435    }
436
437    /// Get reference to a registered pass.
438    pub fn lookup_registered_pass(&self, name: &str) -> Option<&Pass> {
439        self.passes.get(name)
440    }
441
442    pub fn help_text(&self) -> String {
443        let summary = self
444            .passes
445            .iter()
446            .map(|(name, pass)| format!("  {name:16} - {}", pass.descr))
447            .collect::<Vec<_>>()
448            .join("\n");
449
450        format!("Valid pass names are:\n\n{summary}",)
451    }
452}
453
/// A group of passes.
/// Can contain sub-groups.
///
/// Running a group runs its passes in insertion order, descending into each sub-group at
/// the position where it was appended.
#[derive(Default)]
pub struct PassGroup(Vec<PassOrGroup>);
458
/// An individual pass, or a group (with possible subgroup) of passes.
pub enum PassOrGroup {
    /// A single pass, referred to by its registered name.
    Pass(&'static str),
    /// A nested group of passes.
    Group(PassGroup),
}
464
465impl PassGroup {
466    // Flatten a group of passes into an ordered list.
467    fn flatten_pass_group(&self) -> Vec<&'static str> {
468        let mut output = Vec::<&str>::new();
469        fn inner(output: &mut Vec<&str>, input: &PassGroup) {
470            for pass_or_group in &input.0 {
471                match pass_or_group {
472                    PassOrGroup::Pass(pass) => output.push(pass),
473                    PassOrGroup::Group(pg) => inner(output, pg),
474                }
475            }
476        }
477        inner(&mut output, self);
478        output
479    }
480
481    /// Append a pass to this group.
482    pub fn append_pass(&mut self, pass: &'static str) {
483        self.0.push(PassOrGroup::Pass(pass));
484    }
485
486    /// Append a pass group.
487    pub fn append_group(&mut self, group: PassGroup) {
488        self.0.push(PassOrGroup::Group(group));
489    }
490}
491
492/// A convenience utility to register known passes.
493pub fn register_known_passes(pm: &mut PassManager) {
494    // Analysis passes.
495    pm.register(create_postorder_pass());
496    pm.register(create_dominators_pass());
497    pm.register(create_dom_fronts_pass());
498    pm.register(create_escaped_symbols_pass());
499    pm.register(create_module_printer_pass());
500    pm.register(create_module_verifier_pass());
501    // Optimization passes.
502    pm.register(create_arg_pointee_mutability_tagger_pass());
503    pm.register(create_fn_dedup_release_profile_pass());
504    pm.register(create_fn_dedup_debug_profile_pass());
505    pm.register(create_mem2reg_pass());
506    pm.register(create_sroa_pass());
507    pm.register(create_fn_inline_pass());
508    pm.register(create_const_folding_pass());
509    pm.register(create_ccp_pass());
510    pm.register(create_simplify_cfg_pass());
511    pm.register(create_globals_dce_pass());
512    pm.register(create_dce_pass());
513    pm.register(create_cse_pass());
514    pm.register(create_arg_demotion_pass());
515    pm.register(create_const_demotion_pass());
516    pm.register(create_ret_demotion_pass());
517    pm.register(create_misc_demotion_pass());
518    pm.register(create_memcpyopt_pass());
519    pm.register(create_memcpyprop_reverse_pass());
520}
521
522pub fn create_o1_pass_group() -> PassGroup {
523    // Create a create_ccp_passo specify which passes we want to run now.
524    let mut o1 = PassGroup::default();
525    // Configure to run our passes.
526    o1.append_pass(MEM2REG_NAME);
527    o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
528    o1.append_pass(FN_INLINE_NAME);
529    o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
530    o1.append_pass(SIMPLIFY_CFG_NAME);
531    o1.append_pass(GLOBALS_DCE_NAME);
532    o1.append_pass(DCE_NAME);
533    o1.append_pass(FN_INLINE_NAME);
534    o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
535    o1.append_pass(CCP_NAME);
536    o1.append_pass(CONST_FOLDING_NAME);
537    o1.append_pass(SIMPLIFY_CFG_NAME);
538    o1.append_pass(CSE_NAME);
539    o1.append_pass(CONST_FOLDING_NAME);
540    o1.append_pass(SIMPLIFY_CFG_NAME);
541    o1.append_pass(GLOBALS_DCE_NAME);
542    o1.append_pass(DCE_NAME);
543    o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
544
545    o1
546}
547
548/// Utility to insert a pass after every pass in the given group `pg`.
549/// It preserves the `pg` group's structure. This means if `pg` has subgroups
550/// and those have subgroups, the resulting [PassGroup] will have the
551/// same subgroups, but with the `pass` inserted after every pass in every
552/// subgroup, as well as all passes outside of any groups.
553pub fn insert_after_each(pg: PassGroup, pass: &'static str) -> PassGroup {
554    fn insert_after_each_rec(pg: PassGroup, pass: &'static str) -> Vec<PassOrGroup> {
555        pg.0.into_iter()
556            .flat_map(|p_o_g| match p_o_g {
557                PassOrGroup::Group(group) => vec![PassOrGroup::Group(PassGroup(
558                    insert_after_each_rec(group, pass),
559                ))],
560                PassOrGroup::Pass(_) => vec![p_o_g, PassOrGroup::Pass(pass)],
561            })
562            .collect()
563    }
564
565    PassGroup(insert_after_each_rec(pg, pass))
566}