1use crate::{
2 create_arg_demotion_pass, create_arg_pointee_mutability_tagger_pass, create_ccp_pass,
3 create_const_demotion_pass, create_const_folding_pass, create_cse_pass, create_dce_pass,
4 create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass,
5 create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass,
6 create_fn_inline_pass, create_globals_dce_pass, create_mem2reg_pass, create_memcpyopt_pass,
7 create_misc_demotion_pass, create_module_printer_pass, create_module_verifier_pass,
8 create_postorder_pass, create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass,
9 Context, Function, IrError, Module, ARG_DEMOTION_NAME, ARG_POINTEE_MUTABILITY_TAGGER_NAME,
10 CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, CSE_NAME, DCE_NAME,
11 FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, GLOBALS_DCE_NAME,
12 MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME,
13 SROA_NAME,
14};
15use downcast_rs::{impl_downcast, Downcast};
16use rustc_hash::FxHashMap;
17use std::{
18 any::{type_name, TypeId},
19 collections::{hash_map, HashSet},
20};
21
22pub trait AnalysisResultT: Downcast {}
24impl_downcast!(AnalysisResultT);
25pub type AnalysisResult = Box<dyn AnalysisResultT>;
26
27pub trait PassScope {
29 fn get_arena_idx(&self) -> slotmap::DefaultKey;
30}
31impl PassScope for Module {
32 fn get_arena_idx(&self) -> slotmap::DefaultKey {
33 self.0
34 }
35}
36impl PassScope for Function {
37 fn get_arena_idx(&self) -> slotmap::DefaultKey {
38 self.0
39 }
40}
41
42#[derive(Clone)]
44pub enum PassMutability<S: PassScope> {
45 Analysis(fn(&Context, analyses: &AnalysisResults, S) -> Result<AnalysisResult, IrError>),
47 Transform(fn(&mut Context, analyses: &AnalysisResults, S) -> Result<bool, IrError>),
49}
50
51#[derive(Clone)]
53pub enum ScopedPass {
54 ModulePass(PassMutability<Module>),
55 FunctionPass(PassMutability<Function>),
56}
57
58pub struct Pass {
60 pub name: &'static str,
62 pub descr: &'static str,
64 pub deps: Vec<&'static str>,
66 pub runner: ScopedPass,
68}
69
70impl Pass {
71 pub fn is_analysis(&self) -> bool {
72 match &self.runner {
73 ScopedPass::ModulePass(pm) => matches!(pm, PassMutability::Analysis(_)),
74 ScopedPass::FunctionPass(pm) => matches!(pm, PassMutability::Analysis(_)),
75 }
76 }
77
78 pub fn is_transform(&self) -> bool {
79 !self.is_analysis()
80 }
81
82 pub fn is_module_pass(&self) -> bool {
83 matches!(self.runner, ScopedPass::ModulePass(_))
84 }
85
86 pub fn is_function_pass(&self) -> bool {
87 matches!(self.runner, ScopedPass::FunctionPass(_))
88 }
89}
90
91#[derive(Default)]
92pub struct AnalysisResults {
93 results: FxHashMap<(TypeId, (TypeId, slotmap::DefaultKey)), AnalysisResult>,
95 name_typeid_map: FxHashMap<&'static str, TypeId>,
96}
97
98impl AnalysisResults {
99 pub fn get_analysis_result<T: AnalysisResultT, S: PassScope + 'static>(&self, scope: S) -> &T {
102 self.results
103 .get(&(
104 TypeId::of::<T>(),
105 (TypeId::of::<S>(), scope.get_arena_idx()),
106 ))
107 .unwrap_or_else(|| {
108 panic!(
109 "Internal error. Analysis result {} unavailable for {} with idx {:?}",
110 type_name::<T>(),
111 type_name::<S>(),
112 scope.get_arena_idx()
113 )
114 })
115 .downcast_ref()
116 .expect("AnalysisResult: Incorrect type")
117 }
118
119 fn is_analysis_result_available<S: PassScope + 'static>(
121 &self,
122 name: &'static str,
123 scope: S,
124 ) -> bool {
125 self.name_typeid_map
126 .get(name)
127 .and_then(|result_typeid| {
128 self.results
129 .get(&(*result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())))
130 })
131 .is_some()
132 }
133
134 fn add_result<S: PassScope + 'static>(
136 &mut self,
137 name: &'static str,
138 scope: S,
139 result: AnalysisResult,
140 ) {
141 let result_typeid = (*result).type_id();
142 self.results.insert(
143 (result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())),
144 result,
145 );
146 self.name_typeid_map.insert(name, result_typeid);
147 }
148
149 fn invalidate_all_results_at_scope<S: PassScope + 'static>(&mut self, scope: S) {
151 self.results
152 .retain(|(_result_typeid, (scope_typeid, scope_idx)), _v| {
153 (*scope_typeid, *scope_idx) != (TypeId::of::<S>(), scope.get_arena_idx())
154 });
155 }
156}
157
/// Options controlling when `PassManager::run_with_print` dumps the IR.
#[derive(Debug)]
pub struct PrintPassesOpts {
    /// Dump the IR once before any pass runs.
    pub initial: bool,
    /// Dump the IR once after all passes have run.
    pub r#final: bool,
    /// When dumping after a pass, only do so if that pass reported a modification.
    pub modified_only: bool,
    /// Names of the passes after which the IR should be dumped.
    pub passes: HashSet<String>,
}
171
172#[derive(Default)]
173pub struct PassManager {
174 passes: FxHashMap<&'static str, Pass>,
175 analyses: AnalysisResults,
176}
177
178impl PassManager {
179 pub const OPTIMIZATION_PASSES: [&'static str; 14] = [
180 FN_INLINE_NAME,
181 SIMPLIFY_CFG_NAME,
182 SROA_NAME,
183 DCE_NAME,
184 GLOBALS_DCE_NAME,
185 FN_DEDUP_RELEASE_PROFILE_NAME,
186 FN_DEDUP_DEBUG_PROFILE_NAME,
187 MEM2REG_NAME,
188 MEMCPYOPT_NAME,
189 CONST_FOLDING_NAME,
190 ARG_DEMOTION_NAME,
191 CONST_DEMOTION_NAME,
192 RET_DEMOTION_NAME,
193 MISC_DEMOTION_NAME,
194 ];
195
196 pub fn register(&mut self, pass: Pass) -> &'static str {
198 for dep in &pass.deps {
199 if let Some(dep_t) = self.lookup_registered_pass(dep) {
200 if dep_t.is_transform() {
201 panic!(
202 "Pass {} cannot depend on a transformation pass {}",
203 pass.name, dep
204 );
205 }
206 if pass.is_function_pass() && dep_t.is_module_pass() {
207 panic!(
208 "Function pass {} cannot depend on module pass {}",
209 pass.name, dep
210 );
211 }
212 } else {
213 panic!(
214 "Pass {} depends on a (yet) unregistered pass {}",
215 pass.name, dep
216 );
217 }
218 }
219 let pass_name = pass.name;
220 match self.passes.entry(pass.name) {
221 hash_map::Entry::Occupied(_) => {
222 panic!("Trying to register an already registered pass");
223 }
224 hash_map::Entry::Vacant(entry) => {
225 entry.insert(pass);
226 }
227 }
228 pass_name
229 }
230
231 fn actually_run(&mut self, ir: &mut Context, pass: &'static str) -> Result<bool, IrError> {
232 let mut modified = false;
233
234 fn run_module_pass(
235 pm: &mut PassManager,
236 ir: &mut Context,
237 pass: &'static str,
238 module: Module,
239 ) -> Result<bool, IrError> {
240 let mut modified = false;
241 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
242 for dep in pass_t.deps.clone() {
243 let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
244 assert!(dep_t.is_analysis());
246 match dep_t.runner {
247 ScopedPass::ModulePass(_) => {
248 if !pm.analyses.is_analysis_result_available(dep, module) {
249 run_module_pass(pm, ir, dep, module)?;
250 }
251 }
252 ScopedPass::FunctionPass(_) => {
253 for f in module.function_iter(ir) {
254 if !pm.analyses.is_analysis_result_available(dep, f) {
255 run_function_pass(pm, ir, dep, f)?;
256 }
257 }
258 }
259 }
260 }
261
262 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
264 let ScopedPass::ModulePass(mp) = pass_t.runner.clone() else {
265 panic!("Expected a module pass");
266 };
267 match mp {
268 PassMutability::Analysis(analysis) => {
269 let result = analysis(ir, &pm.analyses, module)?;
270 pm.analyses.add_result(pass, module, result);
271 }
272 PassMutability::Transform(transform) => {
273 if transform(ir, &pm.analyses, module)? {
274 pm.analyses.invalidate_all_results_at_scope(module);
275 for f in module.function_iter(ir) {
276 pm.analyses.invalidate_all_results_at_scope(f);
277 }
278 modified = true;
279 }
280 }
281 }
282
283 Ok(modified)
284 }
285
286 fn run_function_pass(
287 pm: &mut PassManager,
288 ir: &mut Context,
289 pass: &'static str,
290 function: Function,
291 ) -> Result<bool, IrError> {
292 let mut modified = false;
293 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
294 for dep in pass_t.deps.clone() {
295 let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
296 assert!(dep_t.is_analysis());
298 match dep_t.runner {
299 ScopedPass::ModulePass(_) => {
300 panic!("Function pass {pass} cannot depend on module pass {dep}")
301 }
302 ScopedPass::FunctionPass(_) => {
303 if !pm.analyses.is_analysis_result_available(dep, function) {
304 run_function_pass(pm, ir, dep, function)?;
305 };
306 }
307 }
308 }
309
310 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
312 let ScopedPass::FunctionPass(fp) = pass_t.runner.clone() else {
313 panic!("Expected a function pass");
314 };
315 match fp {
316 PassMutability::Analysis(analysis) => {
317 let result = analysis(ir, &pm.analyses, function)?;
318 pm.analyses.add_result(pass, function, result);
319 }
320 PassMutability::Transform(transform) => {
321 if transform(ir, &pm.analyses, function)? {
322 pm.analyses.invalidate_all_results_at_scope(function);
323 modified = true;
324 }
325 }
326 }
327
328 Ok(modified)
329 }
330
331 for m in ir.module_iter() {
332 let pass_t = self.passes.get(pass).expect("Unregistered pass");
333 let pass_runner = pass_t.runner.clone();
334 match pass_runner {
335 ScopedPass::ModulePass(_) => {
336 modified |= run_module_pass(self, ir, pass, m)?;
337 }
338 ScopedPass::FunctionPass(_) => {
339 for f in m.function_iter(ir) {
340 modified |= run_function_pass(self, ir, pass, f)?;
341 }
342 }
343 }
344 }
345 Ok(modified)
346 }
347
348 pub fn run(&mut self, ir: &mut Context, passes: &PassGroup) -> Result<bool, IrError> {
350 let mut modified = false;
351 for pass in passes.flatten_pass_group() {
352 modified |= self.actually_run(ir, pass)?;
353 }
354 Ok(modified)
355 }
356
357 pub fn run_with_print(
360 &mut self,
361 ir: &mut Context,
362 passes: &PassGroup,
363 print_opts: &PrintPassesOpts,
364 ) -> Result<bool, IrError> {
365 fn ir_is_empty(ir: &Context) -> bool {
367 ir.functions.is_empty()
368 && ir.blocks.is_empty()
369 && ir.values.is_empty()
370 && ir.local_vars.is_empty()
371 }
372
373 fn print_ir_after_pass(ir: &Context, pass: &Pass) {
374 if !ir_is_empty(ir) {
375 println!("// IR: [{}] {}", pass.name, pass.descr);
376 println!("{ir}");
377 }
378 }
379
380 fn print_initial_or_final_ir(ir: &Context, initial_or_final: &'static str) {
381 if !ir_is_empty(ir) {
382 println!("// IR: {initial_or_final}");
383 println!("{ir}");
384 }
385 }
386
387 if print_opts.initial {
388 print_initial_or_final_ir(ir, "Initial");
389 }
390
391 let mut modified = false;
392 for pass in passes.flatten_pass_group() {
393 let modified_in_pass = self.actually_run(ir, pass)?;
394
395 if print_opts.passes.contains(pass) && (!print_opts.modified_only || modified_in_pass) {
396 print_ir_after_pass(ir, self.lookup_registered_pass(pass).unwrap());
397 }
398
399 modified |= modified_in_pass;
400 }
401
402 if print_opts.r#final {
403 print_initial_or_final_ir(ir, "Final");
404 }
405
406 Ok(modified)
407 }
408
409 pub fn lookup_registered_pass(&self, name: &str) -> Option<&Pass> {
411 self.passes.get(name)
412 }
413
414 pub fn help_text(&self) -> String {
415 let summary = self
416 .passes
417 .iter()
418 .map(|(name, pass)| format!(" {name:16} - {}", pass.descr))
419 .collect::<Vec<_>>()
420 .join("\n");
421
422 format!("Valid pass names are:\n\n{summary}",)
423 }
424}
425
/// An ordered sequence of pass names, possibly containing nested groups.
#[derive(Default)]
pub struct PassGroup(Vec<PassOrGroup>);

/// One element of a [`PassGroup`]: a single pass name or a nested group.
pub enum PassOrGroup {
    /// The name of a registered pass.
    Pass(&'static str),
    /// A nested group, expanded in place when the outer group is flattened.
    Group(PassGroup),
}

impl PassGroup {
    /// Expands all nested groups and returns the pass names in execution order.
    fn flatten_pass_group(&self) -> Vec<&'static str> {
        let mut flat = Vec::new();
        // Depth-first walk with an explicit stack of in-progress iterators, one per open group.
        let mut pending = vec![self.0.iter()];
        while let Some(cursor) = pending.last_mut() {
            match cursor.next() {
                Some(PassOrGroup::Pass(name)) => flat.push(*name),
                Some(PassOrGroup::Group(nested)) => pending.push(nested.0.iter()),
                None => {
                    pending.pop();
                }
            }
        }
        flat
    }

    /// Appends a single pass, by name, to the end of the group.
    pub fn append_pass(&mut self, pass: &'static str) {
        self.0.push(PassOrGroup::Pass(pass));
    }

    /// Appends `group` as a nested group at the end of this group.
    pub fn append_group(&mut self, group: PassGroup) {
        self.0.push(PassOrGroup::Group(group));
    }
}
462}
463
464pub fn register_known_passes(pm: &mut PassManager) {
466 pm.register(create_postorder_pass());
468 pm.register(create_dominators_pass());
469 pm.register(create_dom_fronts_pass());
470 pm.register(create_escaped_symbols_pass());
471 pm.register(create_module_printer_pass());
472 pm.register(create_module_verifier_pass());
473 pm.register(create_arg_pointee_mutability_tagger_pass());
475 pm.register(create_fn_dedup_release_profile_pass());
476 pm.register(create_fn_dedup_debug_profile_pass());
477 pm.register(create_mem2reg_pass());
478 pm.register(create_sroa_pass());
479 pm.register(create_fn_inline_pass());
480 pm.register(create_const_folding_pass());
481 pm.register(create_ccp_pass());
482 pm.register(create_simplify_cfg_pass());
483 pm.register(create_globals_dce_pass());
484 pm.register(create_dce_pass());
485 pm.register(create_cse_pass());
486 pm.register(create_arg_demotion_pass());
487 pm.register(create_const_demotion_pass());
488 pm.register(create_ret_demotion_pass());
489 pm.register(create_misc_demotion_pass());
490 pm.register(create_memcpyopt_pass());
491}
492
493pub fn create_o1_pass_group() -> PassGroup {
494 let mut o1 = PassGroup::default();
496 o1.append_pass(MEM2REG_NAME);
498 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
499 o1.append_pass(FN_INLINE_NAME);
500 o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
501 o1.append_pass(SIMPLIFY_CFG_NAME);
502 o1.append_pass(GLOBALS_DCE_NAME);
503 o1.append_pass(DCE_NAME);
504 o1.append_pass(FN_INLINE_NAME);
505 o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
506 o1.append_pass(CCP_NAME);
507 o1.append_pass(CONST_FOLDING_NAME);
508 o1.append_pass(SIMPLIFY_CFG_NAME);
509 o1.append_pass(CSE_NAME);
510 o1.append_pass(CONST_FOLDING_NAME);
511 o1.append_pass(SIMPLIFY_CFG_NAME);
512 o1.append_pass(GLOBALS_DCE_NAME);
513 o1.append_pass(DCE_NAME);
514 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
515
516 o1
517}
518
519pub fn insert_after_each(pg: PassGroup, pass: &'static str) -> PassGroup {
525 fn insert_after_each_rec(pg: PassGroup, pass: &'static str) -> Vec<PassOrGroup> {
526 pg.0.into_iter()
527 .flat_map(|p_o_g| match p_o_g {
528 PassOrGroup::Group(group) => vec![PassOrGroup::Group(PassGroup(
529 insert_after_each_rec(group, pass),
530 ))],
531 PassOrGroup::Pass(_) => vec![p_o_g, PassOrGroup::Pass(pass)],
532 })
533 .collect()
534 }
535
536 PassGroup(insert_after_each_rec(pg, pass))
537}