sway_ir/
pass_manager.rs

1use crate::{
2    create_arg_demotion_pass, create_arg_pointee_mutability_tagger_pass, create_ccp_pass,
3    create_const_demotion_pass, create_const_folding_pass, create_cse_pass, create_dce_pass,
4    create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass,
5    create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass,
6    create_fn_inline_pass, create_globals_dce_pass, create_mem2reg_pass, create_memcpyopt_pass,
7    create_misc_demotion_pass, create_module_printer_pass, create_module_verifier_pass,
8    create_postorder_pass, create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass,
9    Context, Function, IrError, Module, ARG_DEMOTION_NAME, ARG_POINTEE_MUTABILITY_TAGGER_NAME,
10    CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, CSE_NAME, DCE_NAME,
11    FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, GLOBALS_DCE_NAME,
12    MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME,
13    SROA_NAME,
14};
15use downcast_rs::{impl_downcast, Downcast};
16use rustc_hash::FxHashMap;
17use std::{
18    any::{type_name, TypeId},
19    collections::{hash_map, HashSet},
20};
21
/// Marker trait for the result of an analysis pass.
///
/// Results are stored type-erased (see [AnalysisResult]); a consumer must
/// downcast to the specific result type it expects.
pub trait AnalysisResultT: Downcast {}
// Generates the downcasting helpers (e.g. `downcast_ref`) for `dyn AnalysisResultT`.
impl_downcast!(AnalysisResultT);
/// A boxed, type-erased analysis result.
pub type AnalysisResult = Box<dyn AnalysisResultT>;
26
/// Program scope over which a pass executes.
///
/// In this file it is implemented for [Module] and [Function].
pub trait PassScope {
    /// Returns the key identifying this particular scope instance.
    /// Together with the scope's type, it is used to key cached analysis results.
    fn get_arena_idx(&self) -> slotmap::DefaultKey;
}
/// Module-level scope: a pass runs once per [Module].
impl PassScope for Module {
    /// Returns the key wrapped by this [Module] handle (its first tuple field).
    fn get_arena_idx(&self) -> slotmap::DefaultKey {
        self.0
    }
}
/// Function-level scope: a pass runs once per [Function].
impl PassScope for Function {
    /// Returns the key wrapped by this [Function] handle (its first tuple field).
    fn get_arena_idx(&self) -> slotmap::DefaultKey {
        self.0
    }
}
41
/// Is a pass an Analysis or a Transformation over the IR?
///
/// Both variants hold a plain function pointer that receives the IR context,
/// the currently cached analysis results, and the scope `S` to operate on.
#[derive(Clone)]
pub enum PassMutability<S: PassScope> {
    /// An analysis pass, producing an analysis result.
    /// It only gets shared (`&`) access to the [Context].
    Analysis(fn(&Context, analyses: &AnalysisResults, S) -> Result<AnalysisResult, IrError>),
    /// A pass over the IR that can possibly modify it.
    /// The pass manager treats a returned `Ok(true)` as "the IR was modified"
    /// and invalidates cached analysis results for the affected scope.
    Transform(fn(&mut Context, analyses: &AnalysisResults, S) -> Result<bool, IrError>),
}
50
/// A pass runner bound to one of the concrete [PassScope]s.
#[derive(Clone)]
pub enum ScopedPass {
    /// A pass that runs once per [Module].
    ModulePass(PassMutability<Module>),
    /// A pass that runs once per [Function].
    FunctionPass(PassMutability<Function>),
}
57
/// An analysis or transformation pass.
pub struct Pass {
    /// Pass identifier. Used as the key under which the pass is registered
    /// in the [PassManager].
    pub name: &'static str,
    /// A short description. Shown in the pass manager's help text and in
    /// the header printed after a pass when IR printing is requested.
    pub descr: &'static str,
    /// Names of other passes that this pass depends on.
    /// [PassManager::register] requires each of them to be an already
    /// registered analysis pass.
    pub deps: Vec<&'static str>,
    /// The executor.
    pub runner: ScopedPass,
}
69
70impl Pass {
71    pub fn is_analysis(&self) -> bool {
72        match &self.runner {
73            ScopedPass::ModulePass(pm) => matches!(pm, PassMutability::Analysis(_)),
74            ScopedPass::FunctionPass(pm) => matches!(pm, PassMutability::Analysis(_)),
75        }
76    }
77
78    pub fn is_transform(&self) -> bool {
79        !self.is_analysis()
80    }
81
82    pub fn is_module_pass(&self) -> bool {
83        matches!(self.runner, ScopedPass::ModulePass(_))
84    }
85
86    pub fn is_function_pass(&self) -> bool {
87        matches!(self.runner, ScopedPass::FunctionPass(_))
88    }
89}
90
/// Cache of analysis results, keyed by result type and by the scope
/// (scope type plus scope identity) the result was computed for.
#[derive(Default)]
pub struct AnalysisResults {
    // Hash from (AnalysisResultT, (PassScope, Scope Identity)) to an actual result.
    results: FxHashMap<(TypeId, (TypeId, slotmap::DefaultKey)), AnalysisResult>,
    // Maps a pass name to the `TypeId` of the result that pass produced.
    // Filled in when a result is added, so a name is only known here
    // after the pass has produced a result at least once.
    name_typeid_map: FxHashMap<&'static str, TypeId>,
}
97
98impl AnalysisResults {
99    /// Get the results of an analysis.
100    /// Example analyses.get_analysis_result::<DomTreeAnalysis>(foo).
101    pub fn get_analysis_result<T: AnalysisResultT, S: PassScope + 'static>(&self, scope: S) -> &T {
102        self.results
103            .get(&(
104                TypeId::of::<T>(),
105                (TypeId::of::<S>(), scope.get_arena_idx()),
106            ))
107            .unwrap_or_else(|| {
108                panic!(
109                    "Internal error. Analysis result {} unavailable for {} with idx {:?}",
110                    type_name::<T>(),
111                    type_name::<S>(),
112                    scope.get_arena_idx()
113                )
114            })
115            .downcast_ref()
116            .expect("AnalysisResult: Incorrect type")
117    }
118
119    /// Is an analysis result available at the given scope?
120    fn is_analysis_result_available<S: PassScope + 'static>(
121        &self,
122        name: &'static str,
123        scope: S,
124    ) -> bool {
125        self.name_typeid_map
126            .get(name)
127            .and_then(|result_typeid| {
128                self.results
129                    .get(&(*result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())))
130            })
131            .is_some()
132    }
133
134    /// Add a new result.
135    fn add_result<S: PassScope + 'static>(
136        &mut self,
137        name: &'static str,
138        scope: S,
139        result: AnalysisResult,
140    ) {
141        let result_typeid = (*result).type_id();
142        self.results.insert(
143            (result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())),
144            result,
145        );
146        self.name_typeid_map.insert(name, result_typeid);
147    }
148
149    /// Invalidate all results at a given scope.
150    fn invalidate_all_results_at_scope<S: PassScope + 'static>(&mut self, scope: S) {
151        self.results
152            .retain(|(_result_typeid, (scope_typeid, scope_idx)), _v| {
153                (*scope_typeid, *scope_idx) != (TypeId::of::<S>(), scope.get_arena_idx())
154            });
155    }
156}
157
/// Options for printing [Pass]es in case of running them with printing requested.
///
/// Note that states of IR can always be printed by injecting the module printer pass
/// and just running the passes. That approach however offers less control over the
/// printing. E.g., requiring the printing to happen only if the previous passes
/// modified the IR cannot be done by simply injecting a module printer.
#[derive(Debug)]
pub struct PrintPassesOpts {
    /// Print the IR before any pass is run.
    pub initial: bool,
    /// Print the IR after all passes have run.
    pub r#final: bool,
    /// Print the IR after a pass listed in `passes` only if that pass
    /// reported modifying the IR.
    pub modified_only: bool,
    /// Names of the passes after which the IR should be printed.
    pub passes: HashSet<String>,
}
171
/// Registry of passes, together with the cache of analysis results
/// produced while running them.
#[derive(Default)]
pub struct PassManager {
    // Registered passes, keyed by pass name.
    passes: FxHashMap<&'static str, Pass>,
    // Cached analysis results; entries are dropped when a transform
    // reports that it modified the corresponding scope.
    analyses: AnalysisResults,
}
177
178impl PassManager {
    /// Names of passes considered optimization passes.
    ///
    /// NOTE(review): `CCP_NAME`, `CSE_NAME` and
    /// `ARG_POINTEE_MUTABILITY_TAGGER_NAME` are registered under
    /// "Optimization passes" in `register_known_passes` but are not listed
    /// here — confirm whether that omission is intentional.
    pub const OPTIMIZATION_PASSES: [&'static str; 14] = [
        FN_INLINE_NAME,
        SIMPLIFY_CFG_NAME,
        SROA_NAME,
        DCE_NAME,
        GLOBALS_DCE_NAME,
        FN_DEDUP_RELEASE_PROFILE_NAME,
        FN_DEDUP_DEBUG_PROFILE_NAME,
        MEM2REG_NAME,
        MEMCPYOPT_NAME,
        CONST_FOLDING_NAME,
        ARG_DEMOTION_NAME,
        CONST_DEMOTION_NAME,
        RET_DEMOTION_NAME,
        MISC_DEMOTION_NAME,
    ];
195
196    /// Register a pass. Should be called only once for each pass.
197    pub fn register(&mut self, pass: Pass) -> &'static str {
198        for dep in &pass.deps {
199            if let Some(dep_t) = self.lookup_registered_pass(dep) {
200                if dep_t.is_transform() {
201                    panic!(
202                        "Pass {} cannot depend on a transformation pass {}",
203                        pass.name, dep
204                    );
205                }
206                if pass.is_function_pass() && dep_t.is_module_pass() {
207                    panic!(
208                        "Function pass {} cannot depend on module pass {}",
209                        pass.name, dep
210                    );
211                }
212            } else {
213                panic!(
214                    "Pass {} depends on a (yet) unregistered pass {}",
215                    pass.name, dep
216                );
217            }
218        }
219        let pass_name = pass.name;
220        match self.passes.entry(pass.name) {
221            hash_map::Entry::Occupied(_) => {
222                panic!("Trying to register an already registered pass");
223            }
224            hash_map::Entry::Vacant(entry) => {
225                entry.insert(pass);
226            }
227        }
228        pass_name
229    }
230
    /// Run the registered pass named `pass` over the whole `ir`.
    ///
    /// A module pass is run once for every module in `ir`; a function pass is
    /// run once for every function of every module. Before a pass runs on a
    /// scope, each of its dependencies that has no cached result for that
    /// scope is run first (recursively).
    ///
    /// Returns `Ok(true)` if at least one transform invocation reported that
    /// it modified the IR. Analysis passes never set the flag; their result
    /// is cached instead. Errors returned by a pass are propagated.
    ///
    /// # Panics
    ///
    /// Panics if `pass` or one of its dependencies is not registered, if a
    /// dependency is not an analysis, or if a function pass depends on a
    /// module pass.
    fn actually_run(&mut self, ir: &mut Context, pass: &'static str) -> Result<bool, IrError> {
        let mut modified = false;

        // Runs `pass` (which must be a module pass) on a single `module`.
        fn run_module_pass(
            pm: &mut PassManager,
            ir: &mut Context,
            pass: &'static str,
            module: Module,
        ) -> Result<bool, IrError> {
            let mut modified = false;
            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
            // The dependency list is cloned so that `pm` can be borrowed
            // mutably by the recursive calls below.
            for dep in pass_t.deps.clone() {
                let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
                // If pass registration allows transformations as dependents, we could remove this I guess.
                assert!(dep_t.is_analysis());
                match dep_t.runner {
                    ScopedPass::ModulePass(_) => {
                        // Module-scope dependency: compute it for this module if not cached.
                        if !pm.analyses.is_analysis_result_available(dep, module) {
                            run_module_pass(pm, ir, dep, module)?;
                        }
                    }
                    ScopedPass::FunctionPass(_) => {
                        // Function-scope dependency of a module pass: make sure the
                        // result is available for every function in the module.
                        for f in module.function_iter(ir) {
                            if !pm.analyses.is_analysis_result_available(dep, f) {
                                run_function_pass(pm, ir, dep, f)?;
                            }
                        }
                    }
                }
            }

            // Get the pass again to satisfy the borrow checker.
            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
            let ScopedPass::ModulePass(mp) = pass_t.runner.clone() else {
                panic!("Expected a module pass");
            };
            match mp {
                PassMutability::Analysis(analysis) => {
                    // Cache the result under this pass' name for this module.
                    let result = analysis(ir, &pm.analyses, module)?;
                    pm.analyses.add_result(pass, module, result);
                }
                PassMutability::Transform(transform) => {
                    if transform(ir, &pm.analyses, module)? {
                        // The module changed: drop cached results for the module
                        // itself and for every function currently in it.
                        pm.analyses.invalidate_all_results_at_scope(module);
                        for f in module.function_iter(ir) {
                            pm.analyses.invalidate_all_results_at_scope(f);
                        }
                        modified = true;
                    }
                }
            }

            Ok(modified)
        }

        // Runs `pass` (which must be a function pass) on a single `function`.
        fn run_function_pass(
            pm: &mut PassManager,
            ir: &mut Context,
            pass: &'static str,
            function: Function,
        ) -> Result<bool, IrError> {
            let mut modified = false;
            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
            // The dependency list is cloned so that `pm` can be borrowed
            // mutably by the recursive calls below.
            for dep in pass_t.deps.clone() {
                let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
                // If pass registration allows transformations as dependents, we could remove this I guess.
                assert!(dep_t.is_analysis());
                match dep_t.runner {
                    ScopedPass::ModulePass(_) => {
                        // Registration already rejects this combination.
                        panic!("Function pass {pass} cannot depend on module pass {dep}")
                    }
                    ScopedPass::FunctionPass(_) => {
                        // Compute the dependency for this function if not cached.
                        if !pm.analyses.is_analysis_result_available(dep, function) {
                            run_function_pass(pm, ir, dep, function)?;
                        };
                    }
                }
            }

            // Get the pass again to satisfy the borrow checker.
            let pass_t = pm.passes.get(pass).expect("Unregistered pass");
            let ScopedPass::FunctionPass(fp) = pass_t.runner.clone() else {
                panic!("Expected a function pass");
            };
            match fp {
                PassMutability::Analysis(analysis) => {
                    // Cache the result under this pass' name for this function.
                    let result = analysis(ir, &pm.analyses, function)?;
                    pm.analyses.add_result(pass, function, result);
                }
                PassMutability::Transform(transform) => {
                    if transform(ir, &pm.analyses, function)? {
                        // NOTE(review): only results cached for this function are
                        // dropped; results cached at module scope are left in
                        // place — confirm that this is intended.
                        pm.analyses.invalidate_all_results_at_scope(function);
                        modified = true;
                    }
                }
            }

            Ok(modified)
        }

        // Dispatch on the pass' scope for every module in the context.
        for m in ir.module_iter() {
            let pass_t = self.passes.get(pass).expect("Unregistered pass");
            let pass_runner = pass_t.runner.clone();
            match pass_runner {
                ScopedPass::ModulePass(_) => {
                    modified |= run_module_pass(self, ir, pass, m)?;
                }
                ScopedPass::FunctionPass(_) => {
                    for f in m.function_iter(ir) {
                        modified |= run_function_pass(self, ir, pass, f)?;
                    }
                }
            }
        }
        Ok(modified)
    }
347
348    /// Run the `passes` and return true if the `passes` modify the initial `ir`.
349    pub fn run(&mut self, ir: &mut Context, passes: &PassGroup) -> Result<bool, IrError> {
350        let mut modified = false;
351        for pass in passes.flatten_pass_group() {
352            modified |= self.actually_run(ir, pass)?;
353        }
354        Ok(modified)
355    }
356
357    /// Run the `passes` and return true if the `passes` modify the initial `ir`.
358    /// The IR states are printed according to the printing options provided in `print_opts`.
359    pub fn run_with_print(
360        &mut self,
361        ir: &mut Context,
362        passes: &PassGroup,
363        print_opts: &PrintPassesOpts,
364    ) -> Result<bool, IrError> {
365        // Empty IRs are result of compiling dependencies. We don't want to print those.
366        fn ir_is_empty(ir: &Context) -> bool {
367            ir.functions.is_empty()
368                && ir.blocks.is_empty()
369                && ir.values.is_empty()
370                && ir.local_vars.is_empty()
371        }
372
373        fn print_ir_after_pass(ir: &Context, pass: &Pass) {
374            if !ir_is_empty(ir) {
375                println!("// IR: [{}] {}", pass.name, pass.descr);
376                println!("{ir}");
377            }
378        }
379
380        fn print_initial_or_final_ir(ir: &Context, initial_or_final: &'static str) {
381            if !ir_is_empty(ir) {
382                println!("// IR: {initial_or_final}");
383                println!("{ir}");
384            }
385        }
386
387        if print_opts.initial {
388            print_initial_or_final_ir(ir, "Initial");
389        }
390
391        let mut modified = false;
392        for pass in passes.flatten_pass_group() {
393            let modified_in_pass = self.actually_run(ir, pass)?;
394
395            if print_opts.passes.contains(pass) && (!print_opts.modified_only || modified_in_pass) {
396                print_ir_after_pass(ir, self.lookup_registered_pass(pass).unwrap());
397            }
398
399            modified |= modified_in_pass;
400        }
401
402        if print_opts.r#final {
403            print_initial_or_final_ir(ir, "Final");
404        }
405
406        Ok(modified)
407    }
408
    /// Get a reference to the registered pass called `name`.
    ///
    /// Returns `None` if no pass with that name has been registered.
    pub fn lookup_registered_pass(&self, name: &str) -> Option<&Pass> {
        self.passes.get(name)
    }
413
414    pub fn help_text(&self) -> String {
415        let summary = self
416            .passes
417            .iter()
418            .map(|(name, pass)| format!("  {name:16} - {}", pass.descr))
419            .collect::<Vec<_>>()
420            .join("\n");
421
422        format!("Valid pass names are:\n\n{summary}",)
423    }
424}
425
/// An ordered collection of passes.
/// Entries can themselves be nested groups.
#[derive(Default)]
pub struct PassGroup(Vec<PassOrGroup>);

/// A single entry of a [PassGroup]: either one pass, referred to by name,
/// or a nested group of passes.
pub enum PassOrGroup {
    Pass(&'static str),
    Group(PassGroup),
}

impl PassGroup {
    // Produce the pass names of this group and all nested groups as one
    // ordered list (depth-first, in insertion order).
    fn flatten_pass_group(&self) -> Vec<&'static str> {
        fn collect_into(group: &PassGroup, flat: &mut Vec<&'static str>) {
            group.0.iter().for_each(|entry| match entry {
                PassOrGroup::Pass(name) => flat.push(*name),
                PassOrGroup::Group(nested) => collect_into(nested, flat),
            });
        }

        let mut flat = Vec::new();
        collect_into(self, &mut flat);
        flat
    }

    /// Add the pass called `pass` at the end of this group.
    pub fn append_pass(&mut self, pass: &'static str) {
        let entry = PassOrGroup::Pass(pass);
        self.0.push(entry);
    }

    /// Add `group` as a nested group at the end of this group.
    pub fn append_group(&mut self, group: PassGroup) {
        let entry = PassOrGroup::Group(group);
        self.0.push(entry);
    }
}
463
464/// A convenience utility to register known passes.
465pub fn register_known_passes(pm: &mut PassManager) {
466    // Analysis passes.
467    pm.register(create_postorder_pass());
468    pm.register(create_dominators_pass());
469    pm.register(create_dom_fronts_pass());
470    pm.register(create_escaped_symbols_pass());
471    pm.register(create_module_printer_pass());
472    pm.register(create_module_verifier_pass());
473    // Optimization passes.
474    pm.register(create_arg_pointee_mutability_tagger_pass());
475    pm.register(create_fn_dedup_release_profile_pass());
476    pm.register(create_fn_dedup_debug_profile_pass());
477    pm.register(create_mem2reg_pass());
478    pm.register(create_sroa_pass());
479    pm.register(create_fn_inline_pass());
480    pm.register(create_const_folding_pass());
481    pm.register(create_ccp_pass());
482    pm.register(create_simplify_cfg_pass());
483    pm.register(create_globals_dce_pass());
484    pm.register(create_dce_pass());
485    pm.register(create_cse_pass());
486    pm.register(create_arg_demotion_pass());
487    pm.register(create_const_demotion_pass());
488    pm.register(create_ret_demotion_pass());
489    pm.register(create_misc_demotion_pass());
490    pm.register(create_memcpyopt_pass());
491}
492
493pub fn create_o1_pass_group() -> PassGroup {
494    // Create a create_ccp_passo specify which passes we want to run now.
495    let mut o1 = PassGroup::default();
496    // Configure to run our passes.
497    o1.append_pass(MEM2REG_NAME);
498    o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
499    o1.append_pass(FN_INLINE_NAME);
500    o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
501    o1.append_pass(SIMPLIFY_CFG_NAME);
502    o1.append_pass(GLOBALS_DCE_NAME);
503    o1.append_pass(DCE_NAME);
504    o1.append_pass(FN_INLINE_NAME);
505    o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
506    o1.append_pass(CCP_NAME);
507    o1.append_pass(CONST_FOLDING_NAME);
508    o1.append_pass(SIMPLIFY_CFG_NAME);
509    o1.append_pass(CSE_NAME);
510    o1.append_pass(CONST_FOLDING_NAME);
511    o1.append_pass(SIMPLIFY_CFG_NAME);
512    o1.append_pass(GLOBALS_DCE_NAME);
513    o1.append_pass(DCE_NAME);
514    o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
515
516    o1
517}
518
519/// Utility to insert a pass after every pass in the given group `pg`.
520/// It preserves the `pg` group's structure. This means if `pg` has subgroups
521/// and those have subgroups, the resulting [PassGroup] will have the
522/// same subgroups, but with the `pass` inserted after every pass in every
523/// subgroup, as well as all passes outside of any groups.
524pub fn insert_after_each(pg: PassGroup, pass: &'static str) -> PassGroup {
525    fn insert_after_each_rec(pg: PassGroup, pass: &'static str) -> Vec<PassOrGroup> {
526        pg.0.into_iter()
527            .flat_map(|p_o_g| match p_o_g {
528                PassOrGroup::Group(group) => vec![PassOrGroup::Group(PassGroup(
529                    insert_after_each_rec(group, pass),
530                ))],
531                PassOrGroup::Pass(_) => vec![p_o_g, PassOrGroup::Pass(pass)],
532            })
533            .collect()
534    }
535
536    PassGroup(insert_after_each_rec(pg, pass))
537}