1use crate::{
2 create_arg_demotion_pass, create_arg_pointee_mutability_tagger_pass, create_ccp_pass,
3 create_const_demotion_pass, create_const_folding_pass, create_cse_pass, create_dce_pass,
4 create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass,
5 create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass,
6 create_fn_inline_pass, create_globals_dce_pass, create_init_aggr_lowering_pass,
7 create_mem2reg_pass, create_memcpyopt_pass, create_memcpyprop_reverse_pass,
8 create_misc_demotion_pass, create_module_printer_pass, create_module_verifier_pass,
9 create_postorder_pass, create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass,
10 Context, Function, IrError, Module, ARG_DEMOTION_NAME, ARG_POINTEE_MUTABILITY_TAGGER_NAME,
11 CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, CSE_NAME, DCE_NAME,
12 FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, GLOBALS_DCE_NAME,
13 INIT_AGGR_LOWERING_NAME, MEM2REG_NAME, MEMCPYOPT_NAME, MEMCPYPROP_REVERSE_NAME,
14 MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME, SROA_NAME,
15};
16use downcast_rs::{impl_downcast, Downcast};
17use rustc_hash::FxHashMap;
18use std::{
19 any::{type_name, TypeId},
20 collections::{hash_map, HashSet},
21};
22
23pub trait AnalysisResultT: Downcast {}
25impl_downcast!(AnalysisResultT);
26pub type AnalysisResult = Box<dyn AnalysisResultT>;
27
28pub trait PassScope {
30 fn get_arena_idx(&self) -> slotmap::DefaultKey;
31}
32impl PassScope for Module {
33 fn get_arena_idx(&self) -> slotmap::DefaultKey {
34 self.0
35 }
36}
37impl PassScope for Function {
38 fn get_arena_idx(&self) -> slotmap::DefaultKey {
39 self.0
40 }
41}
42
43#[derive(Clone)]
45pub enum PassMutability<S: PassScope> {
46 Analysis(fn(&Context, analyses: &AnalysisResults, S) -> Result<AnalysisResult, IrError>),
48 Transform(fn(&mut Context, analyses: &AnalysisResults, S) -> Result<bool, IrError>),
50}
51
52#[derive(Clone)]
54pub enum ScopedPass {
55 ModulePass(PassMutability<Module>),
56 FunctionPass(PassMutability<Function>),
57}
58
59pub struct Pass {
61 pub name: &'static str,
63 pub descr: &'static str,
65 pub deps: Vec<&'static str>,
67 pub runner: ScopedPass,
69}
70
71impl Pass {
72 pub fn is_analysis(&self) -> bool {
73 match &self.runner {
74 ScopedPass::ModulePass(pm) => matches!(pm, PassMutability::Analysis(_)),
75 ScopedPass::FunctionPass(pm) => matches!(pm, PassMutability::Analysis(_)),
76 }
77 }
78
79 pub fn is_transform(&self) -> bool {
80 !self.is_analysis()
81 }
82
83 pub fn is_module_pass(&self) -> bool {
84 matches!(self.runner, ScopedPass::ModulePass(_))
85 }
86
87 pub fn is_function_pass(&self) -> bool {
88 matches!(self.runner, ScopedPass::FunctionPass(_))
89 }
90}
91
92#[derive(Default)]
93pub struct AnalysisResults {
94 results: FxHashMap<(TypeId, (TypeId, slotmap::DefaultKey)), AnalysisResult>,
96 name_typeid_map: FxHashMap<&'static str, TypeId>,
97}
98
99impl AnalysisResults {
100 pub fn get_analysis_result<T: AnalysisResultT, S: PassScope + 'static>(&self, scope: S) -> &T {
103 self.results
104 .get(&(
105 TypeId::of::<T>(),
106 (TypeId::of::<S>(), scope.get_arena_idx()),
107 ))
108 .unwrap_or_else(|| {
109 panic!(
110 "Internal error. Analysis result {} unavailable for {} with idx {:?}",
111 type_name::<T>(),
112 type_name::<S>(),
113 scope.get_arena_idx()
114 )
115 })
116 .downcast_ref()
117 .expect("AnalysisResult: Incorrect type")
118 }
119
120 fn is_analysis_result_available<S: PassScope + 'static>(
122 &self,
123 name: &'static str,
124 scope: S,
125 ) -> bool {
126 self.name_typeid_map
127 .get(name)
128 .and_then(|result_typeid| {
129 self.results
130 .get(&(*result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())))
131 })
132 .is_some()
133 }
134
135 fn add_result<S: PassScope + 'static>(
137 &mut self,
138 name: &'static str,
139 scope: S,
140 result: AnalysisResult,
141 ) {
142 let result_typeid = (*result).type_id();
143 self.results.insert(
144 (result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())),
145 result,
146 );
147 self.name_typeid_map.insert(name, result_typeid);
148 }
149
150 fn invalidate_all_results_at_scope<S: PassScope + 'static>(&mut self, scope: S) {
152 self.results
153 .retain(|(_result_typeid, (scope_typeid, scope_idx)), _v| {
154 (*scope_typeid, *scope_idx) != (TypeId::of::<S>(), scope.get_arena_idx())
155 });
156 }
157}
158
/// Controls which IR snapshots `PassManager::run_with_print_verify` prints.
/// Nothing is printed while the IR is empty.
#[derive(Debug)]
pub struct PrintPassesOpts {
    /// Print the IR before any pass runs.
    pub initial: bool,
    /// Print the IR after all passes have run.
    pub r#final: bool,
    /// Only print after a pass listed in `passes` if that pass reported that it
    /// modified the IR.
    pub modified_only: bool,
    /// Names of the passes after which the IR is printed.
    pub passes: HashSet<String>,
}

/// Controls when `PassManager::run_with_print_verify` runs the IR verifier.
#[derive(Debug)]
pub struct VerifyPassesOpts {
    /// Verify the IR before any pass runs.
    pub initial: bool,
    /// Verify the IR after all passes have run.
    pub r#final: bool,
    /// Only verify after a pass listed in `passes` if that pass reported that it
    /// modified the IR.
    pub modified_only: bool,
    /// Names of the passes after which the IR is verified.
    pub passes: HashSet<String>,
}

/// Registry of passes plus the cache of analysis results they produce.
#[derive(Default)]
pub struct PassManager {
    // Registered passes, keyed by their unique name.
    passes: FxHashMap<&'static str, Pass>,
    // Cached analysis results; transforms invalidate them when they modify the IR.
    analyses: AnalysisResults,
}
192
193impl PassManager {
194 pub const OPTIMIZATION_PASSES: [&'static str; 16] = [
195 FN_INLINE_NAME,
196 SIMPLIFY_CFG_NAME,
197 SROA_NAME,
198 DCE_NAME,
199 GLOBALS_DCE_NAME,
200 FN_DEDUP_RELEASE_PROFILE_NAME,
201 FN_DEDUP_DEBUG_PROFILE_NAME,
202 MEM2REG_NAME,
203 MEMCPYOPT_NAME,
204 MEMCPYPROP_REVERSE_NAME,
205 CONST_FOLDING_NAME,
206 ARG_DEMOTION_NAME,
207 CONST_DEMOTION_NAME,
208 RET_DEMOTION_NAME,
209 MISC_DEMOTION_NAME,
210 INIT_AGGR_LOWERING_NAME,
211 ];
212
213 pub fn register(&mut self, pass: Pass) -> &'static str {
215 for dep in &pass.deps {
216 if let Some(dep_t) = self.lookup_registered_pass(dep) {
217 if dep_t.is_transform() {
218 panic!(
219 "Pass {} cannot depend on a transformation pass {}",
220 pass.name, dep
221 );
222 }
223 if pass.is_function_pass() && dep_t.is_module_pass() {
224 panic!(
225 "Function pass {} cannot depend on module pass {}",
226 pass.name, dep
227 );
228 }
229 } else {
230 panic!(
231 "Pass {} depends on a (yet) unregistered pass {}",
232 pass.name, dep
233 );
234 }
235 }
236 let pass_name = pass.name;
237 match self.passes.entry(pass.name) {
238 hash_map::Entry::Occupied(_) => {
239 panic!("Trying to register an already registered pass");
240 }
241 hash_map::Entry::Vacant(entry) => {
242 entry.insert(pass);
243 }
244 }
245 pass_name
246 }
247
248 fn actually_run(&mut self, ir: &mut Context, pass: &'static str) -> Result<bool, IrError> {
249 let mut modified = false;
250
251 fn run_module_pass(
252 pm: &mut PassManager,
253 ir: &mut Context,
254 pass: &'static str,
255 module: Module,
256 ) -> Result<bool, IrError> {
257 let mut modified = false;
258 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
259 for dep in pass_t.deps.clone() {
260 let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
261 assert!(dep_t.is_analysis());
263 match dep_t.runner {
264 ScopedPass::ModulePass(_) => {
265 if !pm.analyses.is_analysis_result_available(dep, module) {
266 run_module_pass(pm, ir, dep, module)?;
267 }
268 }
269 ScopedPass::FunctionPass(_) => {
270 for f in module.function_iter(ir) {
271 if !pm.analyses.is_analysis_result_available(dep, f) {
272 run_function_pass(pm, ir, dep, f)?;
273 }
274 }
275 }
276 }
277 }
278
279 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
281 let ScopedPass::ModulePass(mp) = pass_t.runner.clone() else {
282 panic!("Expected a module pass");
283 };
284 match mp {
285 PassMutability::Analysis(analysis) => {
286 let result = analysis(ir, &pm.analyses, module)?;
287 pm.analyses.add_result(pass, module, result);
288 }
289 PassMutability::Transform(transform) => {
290 if transform(ir, &pm.analyses, module)? {
291 pm.analyses.invalidate_all_results_at_scope(module);
292 for f in module.function_iter(ir) {
293 pm.analyses.invalidate_all_results_at_scope(f);
294 }
295 modified = true;
296 }
297 }
298 }
299
300 Ok(modified)
301 }
302
303 fn run_function_pass(
304 pm: &mut PassManager,
305 ir: &mut Context,
306 pass: &'static str,
307 function: Function,
308 ) -> Result<bool, IrError> {
309 let mut modified = false;
310 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
311 for dep in pass_t.deps.clone() {
312 let dep_t = pm.passes.get(dep).expect("Unregistered dependent pass");
313 assert!(dep_t.is_analysis());
315 match dep_t.runner {
316 ScopedPass::ModulePass(_) => {
317 panic!("Function pass {pass} cannot depend on module pass {dep}")
318 }
319 ScopedPass::FunctionPass(_) => {
320 if !pm.analyses.is_analysis_result_available(dep, function) {
321 run_function_pass(pm, ir, dep, function)?;
322 };
323 }
324 }
325 }
326
327 let pass_t = pm.passes.get(pass).expect("Unregistered pass");
329 let ScopedPass::FunctionPass(fp) = pass_t.runner.clone() else {
330 panic!("Expected a function pass");
331 };
332 match fp {
333 PassMutability::Analysis(analysis) => {
334 let result = analysis(ir, &pm.analyses, function)?;
335 pm.analyses.add_result(pass, function, result);
336 }
337 PassMutability::Transform(transform) => {
338 if transform(ir, &pm.analyses, function)? {
339 pm.analyses.invalidate_all_results_at_scope(function);
340 modified = true;
341 }
342 }
343 }
344
345 Ok(modified)
346 }
347
348 for m in ir.module_iter() {
349 let pass_t = self.passes.get(pass).expect("Unregistered pass");
350 let pass_runner = pass_t.runner.clone();
351 match pass_runner {
352 ScopedPass::ModulePass(_) => {
353 modified |= run_module_pass(self, ir, pass, m)?;
354 }
355 ScopedPass::FunctionPass(_) => {
356 for f in m.function_iter(ir) {
357 modified |= run_function_pass(self, ir, pass, f)?;
358 }
359 }
360 }
361 }
362 Ok(modified)
363 }
364
365 pub fn run(&mut self, ir: &mut Context, passes: &PassGroup) -> Result<bool, IrError> {
367 let mut modified = false;
368 for pass in passes.flatten_pass_group() {
369 modified |= self.actually_run(ir, pass)?;
370 }
371 Ok(modified)
372 }
373
374 pub fn run_with_print_verify(
377 &mut self,
378 ir: &mut Context,
379 passes: &PassGroup,
380 print_opts: &PrintPassesOpts,
381 verify_opts: &VerifyPassesOpts,
382 ) -> Result<bool, IrError> {
383 fn ir_is_empty(ir: &Context) -> bool {
385 ir.functions.is_empty()
386 && ir.blocks.is_empty()
387 && ir.values.is_empty()
388 && ir.local_vars.is_empty()
389 }
390
391 fn print_ir_after_pass(ir: &Context, pass: &Pass) {
392 if !ir_is_empty(ir) {
393 println!("// IR: [{}] {}", pass.name, pass.descr);
394 println!("{ir}");
395 }
396 }
397
398 fn print_initial_or_final_ir(ir: &Context, initial_or_final: &'static str) {
399 if !ir_is_empty(ir) {
400 println!("// IR: {initial_or_final}");
401 println!("{ir}");
402 }
403 }
404
405 if print_opts.initial {
406 print_initial_or_final_ir(ir, "Initial");
407 }
408
409 if verify_opts.initial {
410 ir.verify()?;
411 }
412
413 let mut modified = false;
414 for pass in passes.flatten_pass_group() {
415 let modified_in_pass = self.actually_run(ir, pass)?;
416
417 if print_opts.passes.contains(pass) && (!print_opts.modified_only || modified_in_pass) {
418 print_ir_after_pass(ir, self.lookup_registered_pass(pass).unwrap());
419 }
420
421 modified |= modified_in_pass;
422 if verify_opts.passes.contains(pass) && (!verify_opts.modified_only || modified_in_pass)
423 {
424 ir.verify()?;
425 }
426 }
427
428 if print_opts.r#final {
429 print_initial_or_final_ir(ir, "Final");
430 }
431
432 if verify_opts.r#final {
433 ir.verify()?;
434 }
435
436 Ok(modified)
437 }
438
439 pub fn lookup_registered_pass(&self, name: &str) -> Option<&Pass> {
441 self.passes.get(name)
442 }
443
444 pub fn help_text(&self) -> String {
445 let summary = self
446 .passes
447 .iter()
448 .map(|(name, pass)| format!(" {name:16} - {}", pass.descr))
449 .collect::<Vec<_>>()
450 .join("\n");
451
452 format!("Valid pass names are:\n\n{summary}",)
453 }
454}
455
/// An ordered, possibly nested, list of passes to run.
#[derive(Default)]
pub struct PassGroup(Vec<PassOrGroup>);

/// One entry of a [`PassGroup`]: a single pass name or a nested group.
pub enum PassOrGroup {
    /// A single pass, identified by its registered name.
    Pass(&'static str),
    /// A nested group whose passes are run in place of this entry.
    Group(PassGroup),
}

impl PassGroup {
    /// Flattens nested groups depth-first into the ordered list of pass names.
    fn flatten_pass_group(&self) -> Vec<&'static str> {
        let mut names = Vec::new();
        self.collect_pass_names(&mut names);
        names
    }

    /// Appends this group's pass names (recursively, in order) to `names`.
    fn collect_pass_names(&self, names: &mut Vec<&'static str>) {
        for entry in &self.0 {
            match entry {
                PassOrGroup::Pass(name) => names.push(*name),
                PassOrGroup::Group(nested) => nested.collect_pass_names(names),
            }
        }
    }

    /// Appends a single pass (by registered name) to the end of the group.
    pub fn append_pass(&mut self, pass: &'static str) {
        let entry = PassOrGroup::Pass(pass);
        self.0.push(entry);
    }

    /// Appends a nested group to the end of the group.
    pub fn append_group(&mut self, group: PassGroup) {
        let entry = PassOrGroup::Group(group);
        self.0.push(entry);
    }
}
493
494pub fn register_known_passes(pm: &mut PassManager) {
496 pm.register(create_postorder_pass());
498 pm.register(create_dominators_pass());
499 pm.register(create_dom_fronts_pass());
500 pm.register(create_escaped_symbols_pass());
501 pm.register(create_module_printer_pass());
502 pm.register(create_module_verifier_pass());
503
504 pm.register(create_init_aggr_lowering_pass());
506
507 pm.register(create_arg_pointee_mutability_tagger_pass());
509 pm.register(create_fn_dedup_release_profile_pass());
510 pm.register(create_fn_dedup_debug_profile_pass());
511 pm.register(create_mem2reg_pass());
512 pm.register(create_sroa_pass());
513 pm.register(create_fn_inline_pass());
514 pm.register(create_const_folding_pass());
515 pm.register(create_ccp_pass());
516 pm.register(create_simplify_cfg_pass());
517 pm.register(create_globals_dce_pass());
518 pm.register(create_dce_pass());
519 pm.register(create_cse_pass());
520 pm.register(create_arg_demotion_pass());
521 pm.register(create_const_demotion_pass());
522 pm.register(create_ret_demotion_pass());
523 pm.register(create_misc_demotion_pass());
524 pm.register(create_memcpyopt_pass());
525 pm.register(create_memcpyprop_reverse_pass());
526}
527
528pub fn create_o1_pass_group() -> PassGroup {
529 let mut o1 = PassGroup::default();
531 o1.append_pass(MEM2REG_NAME);
533 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
534 o1.append_pass(FN_INLINE_NAME);
535 o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
536 o1.append_pass(SIMPLIFY_CFG_NAME);
537 o1.append_pass(GLOBALS_DCE_NAME);
538 o1.append_pass(DCE_NAME);
539 o1.append_pass(FN_INLINE_NAME);
540 o1.append_pass(ARG_POINTEE_MUTABILITY_TAGGER_NAME);
541 o1.append_pass(CCP_NAME);
542 o1.append_pass(CONST_FOLDING_NAME);
543 o1.append_pass(SIMPLIFY_CFG_NAME);
544 o1.append_pass(CSE_NAME);
545 o1.append_pass(CONST_FOLDING_NAME);
546 o1.append_pass(SIMPLIFY_CFG_NAME);
547 o1.append_pass(GLOBALS_DCE_NAME);
548 o1.append_pass(DCE_NAME);
549 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
550
551 o1
552}
553
554pub fn insert_after_each(pg: PassGroup, pass: &'static str) -> PassGroup {
560 fn insert_after_each_rec(pg: PassGroup, pass: &'static str) -> Vec<PassOrGroup> {
561 pg.0.into_iter()
562 .flat_map(|p_o_g| match p_o_g {
563 PassOrGroup::Group(group) => vec![PassOrGroup::Group(PassGroup(
564 insert_after_each_rec(group, pass),
565 ))],
566 PassOrGroup::Pass(_) => vec![p_o_g, PassOrGroup::Pass(pass)],
567 })
568 .collect()
569 }
570
571 PassGroup(insert_after_each_rec(pg, pass))
572}