1use crate::lcnf::{LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfVarId};
6use std::collections::{HashMap, HashSet};
7
8use super::types::{
9 BetaReductionPass, ConstantFoldingPass, CopyPropagationPass, DeadCodeEliminationPass,
10 ExprSizeEstimator, IdentityEliminationPass, InlineCostEstimator, OPAnalysisCache,
11 OPConstantFoldingHelper, OPDepGraph, OPDominatorTree, OPLivenessInfo, OPPassConfig,
12 OPPassPhase, OPPassRegistry, OPPassStats, OPWorklist, PassDependency, PassManager, PassStats,
13 PgoHints, StrengthReductionPass, UnreachableCodeEliminationPass,
14};
15use std::fmt;
16
17pub trait OptPass: fmt::Debug {
19 fn name(&self) -> &str;
21 fn run_pass(&mut self, decls: &mut [LcnfFunDecl]) -> usize;
23 fn is_enabled(&self) -> bool {
25 true
26 }
27 fn dependencies(&self) -> Vec<&str> {
29 Vec::new()
30 }
31}
32pub fn substitute_var_in_expr(expr: &mut LcnfExpr, from: LcnfVarId, to: LcnfVarId) {
34 let subst_arg = |a: &mut LcnfArg| {
35 if let LcnfArg::Var(v) = a {
36 if *v == from {
37 *v = to;
38 }
39 }
40 };
41 let subst_value = |val: &mut LcnfLetValue| match val {
42 LcnfLetValue::App(f, args) => {
43 subst_arg(f);
44 for a in args {
45 subst_arg(a);
46 }
47 }
48 LcnfLetValue::FVar(v) => {
49 if *v == from {
50 *v = to;
51 }
52 }
53 LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
54 for a in args {
55 subst_arg(a);
56 }
57 }
58 LcnfLetValue::Proj(_, _, v) => {
59 if *v == from {
60 *v = to;
61 }
62 }
63 LcnfLetValue::Reset(v) => {
64 if *v == from {
65 *v = to;
66 }
67 }
68 LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
69 };
70 match expr {
71 LcnfExpr::Let { value, body, .. } => {
72 subst_value(value);
73 substitute_var_in_expr(body, from, to);
74 }
75 LcnfExpr::Case {
76 scrutinee,
77 alts,
78 default,
79 ..
80 } => {
81 if *scrutinee == from {
82 *scrutinee = to;
83 }
84 for alt in alts.iter_mut() {
85 substitute_var_in_expr(&mut alt.body, from, to);
86 }
87 if let Some(def) = default {
88 substitute_var_in_expr(def, from, to);
89 }
90 }
91 LcnfExpr::Return(a) => subst_arg(a),
92 LcnfExpr::TailCall(f, args) => {
93 subst_arg(f);
94 for a in args {
95 subst_arg(a);
96 }
97 }
98 LcnfExpr::Unreachable => {}
99 }
100}
101pub fn run_all_passes(_decls: &mut Vec<LcnfFunDecl>, pgo: Option<&PgoHints>) {
103 let mut _dce = DeadCodeEliminationPass::new();
104 let mut _cp = CopyPropagationPass::new();
105 let mut _cf = ConstantFoldingPass::new();
106 let mut _beta = BetaReductionPass::new();
107 let mut _identity = IdentityEliminationPass::new();
108 let mut _unreachable = UnreachableCodeEliminationPass::new();
109 let _ = pgo;
110}
#[cfg(test)]
mod tests {
    // Unit tests for the optimisation passes, helpers and pass-management
    // types used by this module.
    use super::*;
    use crate::lcnf::{LcnfLit, LcnfType};

    /// Shorthand for building an `LcnfVarId`.
    pub(super) fn vid(n: u64) -> LcnfVarId {
        LcnfVarId(n)
    }

    /// Builds a parameterless, non-recursive, non-lifted `Nat` function
    /// named `name` whose body is `body`.
    pub(super) fn mk_fun_decl(name: &str, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 0,
        }
    }

    /// Builds `let x{id} : Nat := value; body`.
    pub(super) fn mk_let(id: u64, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: vid(id),
            name: format!("x{}", id),
            ty: LcnfType::Nat,
            value,
            body: Box::new(body),
        }
    }

    #[test]
    pub(super) fn test_constant_folding_pass() {
        let mut pass = ConstantFoldingPass::new();
        assert_eq!(pass.folds_performed, 0);
        assert_eq!(pass.try_fold_nat_op("add", 3, 4), Some(7));
        assert_eq!(pass.try_fold_nat_op("sub", 5, 3), Some(2));
        assert_eq!(pass.try_fold_nat_op("mul", 2, 6), Some(12));
        assert_eq!(pass.try_fold_nat_op("div", 10, 2), Some(5));
        assert_eq!(pass.try_fold_nat_op("div", 10, 0), None);
        assert_eq!(pass.try_fold_nat_op("mod", 10, 3), Some(1));
        assert_eq!(pass.try_fold_nat_op("mod", 10, 0), None);
        assert_eq!(pass.try_fold_nat_op("min", 3, 7), Some(3));
        assert_eq!(pass.try_fold_nat_op("max", 3, 7), Some(7));
        assert_eq!(pass.try_fold_nat_op("pow", 2, 10), Some(1024));
        assert_eq!(pass.try_fold_nat_op("and", 0xFF, 0x0F), Some(0x0F));
        assert_eq!(pass.try_fold_nat_op("or", 0xF0, 0x0F), Some(0xFF));
        assert_eq!(pass.try_fold_nat_op("xor", 0xFF, 0xFF), Some(0));
        assert_eq!(pass.try_fold_nat_op("shl", 1, 3), Some(8));
        assert_eq!(pass.try_fold_nat_op("shr", 16, 2), Some(4));
        assert_eq!(pass.try_fold_nat_op("unknown", 1, 2), None);
    }

    #[test]
    pub(super) fn test_constant_folding_bool_ops() {
        let pass = ConstantFoldingPass::new();
        assert_eq!(pass.try_fold_bool_op("and", true, false), Some(false));
        assert_eq!(pass.try_fold_bool_op("or", true, false), Some(true));
        assert_eq!(pass.try_fold_bool_op("xor", true, true), Some(false));
        assert_eq!(pass.try_fold_bool_op("eq", true, true), Some(true));
        assert_eq!(pass.try_fold_bool_op("ne", true, false), Some(true));
        assert_eq!(pass.try_fold_bool_op("bad", true, false), None);
    }

    #[test]
    pub(super) fn test_constant_folding_cmp_ops() {
        let pass = ConstantFoldingPass::new();
        assert_eq!(pass.try_fold_cmp("eq", 5, 5), Some(true));
        assert_eq!(pass.try_fold_cmp("ne", 5, 5), Some(false));
        assert_eq!(pass.try_fold_cmp("lt", 3, 5), Some(true));
        assert_eq!(pass.try_fold_cmp("le", 5, 5), Some(true));
        assert_eq!(pass.try_fold_cmp("gt", 5, 3), Some(true));
        assert_eq!(pass.try_fold_cmp("ge", 3, 5), Some(false));
        assert_eq!(pass.try_fold_cmp("bad", 1, 2), None);
    }

    #[test]
    pub(super) fn test_constant_folding_run() {
        let mut pass = ConstantFoldingPass::new();
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42)));
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.folds_performed, 0);
    }

    #[test]
    pub(super) fn test_constant_folding_debug() {
        let pass = ConstantFoldingPass::new();
        let s = format!("{:?}", pass);
        assert!(s.contains("ConstantFoldingPass"));
    }

    #[test]
    pub(super) fn test_dead_code_elimination_pass() {
        let mut pass = DeadCodeEliminationPass::new();
        assert_eq!(pass.removed, 0);
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(99)),
                LcnfExpr::Return(LcnfArg::Var(vid(1))),
            ),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert!(pass.removed > 0, "expected dead let to be removed");
    }

    #[test]
    pub(super) fn test_dead_code_elimination_debug() {
        let pass = DeadCodeEliminationPass::new();
        let s = format!("{:?}", pass);
        assert!(s.contains("DeadCodeEliminationPass"));
    }

    #[test]
    pub(super) fn test_copy_propagation_pass() {
        let mut pass = CopyPropagationPass::new();
        assert_eq!(pass.substitutions, 0);
        let body = mk_let(
            1,
            LcnfLetValue::FVar(vid(0)),
            LcnfExpr::Return(LcnfArg::Var(vid(1))),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert!(pass.substitutions > 0, "expected copy to be propagated");
    }

    #[test]
    pub(super) fn test_copy_propagation_debug() {
        let pass = CopyPropagationPass::new();
        let s = format!("{:?}", pass);
        assert!(s.contains("CopyPropagationPass"));
    }

    #[test]
    pub(super) fn test_beta_reduction_pass() {
        let mut pass = BetaReductionPass::new();
        assert_eq!(pass.reductions, 0);
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.reductions, 0);
        let body2 = LcnfExpr::TailCall(LcnfArg::Lit(LcnfLit::Nat(0)), vec![]);
        let mut decls2 = vec![mk_fun_decl("g", body2)];
        pass.run(&mut decls2);
        assert_eq!(pass.reductions, 1);
    }

    #[test]
    pub(super) fn test_beta_reduction_debug() {
        let pass = BetaReductionPass::new();
        let s = format!("{:?}", pass);
        assert!(s.contains("BetaReductionPass"));
    }

    #[test]
    pub(super) fn test_identity_elimination() {
        let mut pass = IdentityEliminationPass::new();
        let body = mk_let(
            0,
            LcnfLetValue::FVar(vid(0)),
            LcnfExpr::Return(LcnfArg::Var(vid(0))),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.eliminated, 1);
        assert!(matches!(decls[0].body, LcnfExpr::Return(_)));
    }

    #[test]
    pub(super) fn test_identity_elimination_no_self_ref() {
        let mut pass = IdentityEliminationPass::new();
        let body = mk_let(
            1,
            LcnfLetValue::FVar(vid(0)),
            LcnfExpr::Return(LcnfArg::Var(vid(1))),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.eliminated, 0);
    }

    #[test]
    pub(super) fn test_strength_reduction_power_of_two() {
        assert!(StrengthReductionPass::is_power_of_two(1));
        assert!(StrengthReductionPass::is_power_of_two(2));
        assert!(StrengthReductionPass::is_power_of_two(4));
        assert!(StrengthReductionPass::is_power_of_two(1024));
        assert!(!StrengthReductionPass::is_power_of_two(0));
        assert!(!StrengthReductionPass::is_power_of_two(3));
        assert!(!StrengthReductionPass::is_power_of_two(6));
    }

    #[test]
    pub(super) fn test_strength_reduction_log2() {
        assert_eq!(StrengthReductionPass::log2_exact(1), Some(0));
        assert_eq!(StrengthReductionPass::log2_exact(2), Some(1));
        assert_eq!(StrengthReductionPass::log2_exact(8), Some(3));
        assert_eq!(StrengthReductionPass::log2_exact(1024), Some(10));
        assert_eq!(StrengthReductionPass::log2_exact(0), None);
        assert_eq!(StrengthReductionPass::log2_exact(3), None);
    }

    #[test]
    pub(super) fn test_strength_reduction_is_mask() {
        assert!(StrengthReductionPass::is_mask(1));
        assert!(StrengthReductionPass::is_mask(3));
        assert!(StrengthReductionPass::is_mask(7));
        assert!(StrengthReductionPass::is_mask(0xFF));
        assert!(!StrengthReductionPass::is_mask(0));
        assert!(!StrengthReductionPass::is_mask(5));
    }

    #[test]
    pub(super) fn test_strength_reduction_bit_ops() {
        assert_eq!(StrengthReductionPass::ctz(8), 3);
        assert_eq!(StrengthReductionPass::ctz(0), 64);
        assert_eq!(StrengthReductionPass::clz(1), 63);
        assert_eq!(StrengthReductionPass::popcount(0xFF), 8);
        assert_eq!(StrengthReductionPass::popcount(0), 0);
    }

    #[test]
    pub(super) fn test_unreachable_code_elimination() {
        let mut pass = UnreachableCodeEliminationPass::new();
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            LcnfExpr::Unreachable,
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert_eq!(pass.eliminated, 1);
        assert!(matches!(decls[0].body, LcnfExpr::Unreachable));
    }

    #[test]
    pub(super) fn test_unreachable_nested() {
        let mut pass = UnreachableCodeEliminationPass::new();
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(1, LcnfLetValue::Lit(LcnfLit::Nat(2)), LcnfExpr::Unreachable),
        );
        let mut decls = vec![mk_fun_decl("f", body)];
        pass.run(&mut decls);
        assert!(pass.eliminated >= 2);
    }

    #[test]
    pub(super) fn test_expr_size_count_lets() {
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                LcnfExpr::Return(LcnfArg::Var(vid(1))),
            ),
        );
        assert_eq!(ExprSizeEstimator::count_lets(&body), 2);
    }

    #[test]
    pub(super) fn test_expr_size_count_cases() {
        let body = LcnfExpr::Case {
            scrutinee: vid(0),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![],
            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))))),
        };
        assert_eq!(ExprSizeEstimator::count_cases(&body), 1);
    }

    #[test]
    pub(super) fn test_expr_size_complexity() {
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            LcnfExpr::Return(LcnfArg::Var(vid(0))),
        );
        assert_eq!(ExprSizeEstimator::complexity(&body), 1);
    }

    #[test]
    pub(super) fn test_expr_size_max_depth() {
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                LcnfExpr::Return(LcnfArg::Var(vid(1))),
            ),
        );
        assert_eq!(ExprSizeEstimator::max_depth(&body), 2);
    }

    #[test]
    pub(super) fn test_expr_size_is_trivial() {
        assert!(ExprSizeEstimator::is_trivial(&LcnfExpr::Return(
            LcnfArg::Lit(LcnfLit::Nat(0))
        )));
        assert!(ExprSizeEstimator::is_trivial(&LcnfExpr::Unreachable));
        assert!(!ExprSizeEstimator::is_trivial(&mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(0)),
            LcnfExpr::Unreachable
        )));
    }

    #[test]
    pub(super) fn test_expr_size_should_inline() {
        let small = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        assert!(ExprSizeEstimator::should_inline(&small, 5));
        let big = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                mk_let(
                    2,
                    LcnfLetValue::Lit(LcnfLit::Nat(3)),
                    mk_let(
                        3,
                        LcnfLetValue::Lit(LcnfLit::Nat(4)),
                        mk_let(
                            4,
                            LcnfLetValue::Lit(LcnfLit::Nat(5)),
                            mk_let(
                                5,
                                LcnfLetValue::Lit(LcnfLit::Nat(6)),
                                LcnfExpr::Return(LcnfArg::Var(vid(5))),
                            ),
                        ),
                    ),
                ),
            ),
        );
        assert!(!ExprSizeEstimator::should_inline(&big, 3));
    }

    #[test]
    pub(super) fn test_expr_size_var_refs() {
        let body = mk_let(
            1,
            LcnfLetValue::FVar(vid(0)),
            LcnfExpr::Return(LcnfArg::Var(vid(1))),
        );
        assert_eq!(ExprSizeEstimator::count_var_refs(&body), 2);
    }

    #[test]
    pub(super) fn test_pgo_hints() {
        let mut hints = PgoHints::new();
        assert!(!hints.is_hot("foo"));
        assert!(!hints.should_inline("foo"));
        hints.mark_hot("foo");
        hints.mark_hot("bar");
        hints.mark_hot("foo");
        assert!(hints.is_hot("foo"));
        assert!(hints.is_hot("bar"));
        assert_eq!(hints.hot_functions.len(), 2);
        hints.mark_inline("baz");
        assert!(hints.should_inline("baz"));
        assert!(!hints.should_inline("qux"));
    }

    #[test]
    pub(super) fn test_pgo_hints_cold() {
        let mut hints = PgoHints::new();
        hints.mark_cold("cold_fn");
        assert!(hints.is_cold("cold_fn"));
        assert!(!hints.is_cold("other"));
    }

    #[test]
    pub(super) fn test_pgo_hints_total() {
        let mut hints = PgoHints::new();
        hints.mark_hot("a");
        hints.mark_cold("b");
        hints.mark_inline("c");
        hints.record_call("d", 10);
        assert_eq!(hints.total_hints(), 4);
    }

    #[test]
    pub(super) fn test_pgo_hints_classify() {
        let mut hints = PgoHints::new();
        hints.mark_hot("h");
        hints.mark_cold("c");
        assert_eq!(hints.classify("h"), "hot");
        assert_eq!(hints.classify("c"), "cold");
        assert_eq!(hints.classify("other"), "normal");
    }

    #[test]
    pub(super) fn test_pgo_hints_merge() {
        let mut h1 = PgoHints::new();
        h1.mark_hot("a");
        h1.record_call("f", 5);
        let mut h2 = PgoHints::new();
        h2.mark_hot("b");
        h2.mark_cold("c");
        h2.record_call("f", 3);
        h1.merge(&h2);
        assert!(h1.is_hot("a"));
        assert!(h1.is_hot("b"));
        assert!(h1.is_cold("c"));
        assert_eq!(h1.call_count("f"), 8);
    }

    #[test]
    pub(super) fn test_pgo_hints_call_count() {
        let mut hints = PgoHints::new();
        hints.record_call("f", 10);
        hints.record_call("f", 5);
        assert_eq!(hints.call_count("f"), 15);
        assert_eq!(hints.call_count("g"), 0);
    }

    #[test]
    pub(super) fn test_inline_cost_estimator_trivial() {
        let est = InlineCostEstimator::default();
        let decl = mk_fun_decl("f", LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))));
        assert!(est.should_inline(&decl, None));
    }

    #[test]
    pub(super) fn test_inline_cost_estimator_with_pgo() {
        let est = InlineCostEstimator::default();
        let body = mk_let(
            0,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            mk_let(
                1,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                mk_let(
                    2,
                    LcnfLetValue::Lit(LcnfLit::Nat(3)),
                    mk_let(
                        3,
                        LcnfLetValue::Lit(LcnfLit::Nat(4)),
                        LcnfExpr::Return(LcnfArg::Var(vid(3))),
                    ),
                ),
            ),
        );
        let decl = mk_fun_decl("medium", body);
        let mut pgo = PgoHints::new();
        pgo.mark_inline("medium");
        assert!(est.should_inline(&decl, Some(&pgo)));
    }

    #[test]
    pub(super) fn test_inline_cost_recursive_penalty() {
        let est = InlineCostEstimator::default();
        let mut decl = mk_fun_decl("rec", LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))));
        decl.is_recursive = true;
        let cost = est.cost(&decl);
        assert_eq!(cost, 10);
    }

    #[test]
    pub(super) fn test_pass_manager_new() {
        let pm = PassManager::new();
        assert_eq!(pm.num_passes(), 0);
        assert_eq!(pm.max_iterations, 10);
    }

    #[test]
    pub(super) fn test_pass_manager_add_pass() {
        let mut pm = PassManager::new();
        pm.add_pass("dce");
        pm.add_pass("cp");
        pm.add_pass("dce");
        assert_eq!(pm.num_passes(), 2);
    }

    #[test]
    pub(super) fn test_pass_manager_record_run() {
        let mut pm = PassManager::new();
        pm.add_pass("dce");
        pm.record_run("dce", 5, 100);
        let stats = pm.get_stats("dce").expect("stats should exist");
        assert_eq!(stats.run_count, 1);
        assert_eq!(stats.total_changes, 5);
    }

    #[test]
    pub(super) fn test_pass_manager_topological_order() {
        let mut pm = PassManager::new();
        pm.add_pass("beta");
        pm.add_pass("dce");
        pm.add_pass("cp");
        pm.add_dependency("dce", "cp");
        pm.add_dependency("cp", "beta");
        let order = pm.topological_order().expect("no cycle");
        let beta_pos = order
            .iter()
            .position(|n| n == "beta")
            .expect("beta_pos position should exist");
        let cp_pos = order
            .iter()
            .position(|n| n == "cp")
            .expect("cp_pos position should exist");
        let dce_pos = order
            .iter()
            .position(|n| n == "dce")
            .expect("dce_pos position should exist");
        assert!(beta_pos < cp_pos);
        assert!(cp_pos < dce_pos);
    }

    #[test]
    pub(super) fn test_pass_manager_cycle_detection() {
        let mut pm = PassManager::new();
        pm.add_pass("a");
        pm.add_pass("b");
        pm.add_dependency("a", "b");
        pm.add_dependency("b", "a");
        assert!(pm.has_cycle());
        assert!(pm.topological_order().is_none());
    }

    #[test]
    pub(super) fn test_pass_manager_no_cycle() {
        let mut pm = PassManager::new();
        pm.add_pass("a");
        pm.add_pass("b");
        pm.add_dependency("b", "a");
        assert!(!pm.has_cycle());
    }

    #[test]
    pub(super) fn test_pass_manager_total_changes() {
        let mut pm = PassManager::new();
        pm.add_pass("a");
        pm.add_pass("b");
        pm.record_run("a", 3, 0);
        pm.record_run("b", 7, 0);
        assert_eq!(pm.total_changes(), 10);
        assert_eq!(pm.total_runs(), 2);
    }

    #[test]
    pub(super) fn test_pass_stats_display() {
        let mut stats = PassStats::new("test_pass");
        stats.record_run(5, 100);
        stats.record_run(3, 50);
        let s = format!("{}", stats);
        assert!(s.contains("test_pass"));
        assert!(s.contains("runs=2"));
        assert!(s.contains("changes=8"));
    }

    #[test]
    pub(super) fn test_pass_stats_avg() {
        let mut stats = PassStats::new("avg_test");
        stats.record_run(10, 0);
        stats.record_run(20, 0);
        assert!((stats.avg_changes() - 15.0).abs() < 0.001);
    }

    #[test]
    pub(super) fn test_pass_stats_empty_avg() {
        let stats = PassStats::new("empty");
        assert_eq!(stats.avg_changes(), 0.0);
    }

    #[test]
    pub(super) fn test_pass_dependency_display() {
        let dep = PassDependency::new("b", "a");
        assert_eq!(format!("{}", dep), "a -> b");
    }

    #[test]
    pub(super) fn test_substitute_var_in_return() {
        let mut expr = LcnfExpr::Return(LcnfArg::Var(vid(1)));
        substitute_var_in_expr(&mut expr, vid(1), vid(2));
        assert_eq!(expr, LcnfExpr::Return(LcnfArg::Var(vid(2))));
    }

    #[test]
    pub(super) fn test_substitute_var_in_tailcall() {
        let mut expr = LcnfExpr::TailCall(
            LcnfArg::Var(vid(1)),
            vec![LcnfArg::Var(vid(1)), LcnfArg::Lit(LcnfLit::Nat(0))],
        );
        substitute_var_in_expr(&mut expr, vid(1), vid(2));
        // Compare the whole expression so a variant change cannot make the
        // test pass vacuously, and so the literal argument is checked too.
        assert_eq!(
            expr,
            LcnfExpr::TailCall(
                LcnfArg::Var(vid(2)),
                vec![LcnfArg::Var(vid(2)), LcnfArg::Lit(LcnfLit::Nat(0))],
            )
        );
    }

    #[test]
    pub(super) fn test_substitute_var_in_case() {
        let mut expr = LcnfExpr::Case {
            scrutinee: vid(1),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![],
            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Var(vid(1))))),
        };
        substitute_var_in_expr(&mut expr, vid(1), vid(2));
        // Whole-expression comparison: fails loudly if the variant changes.
        assert_eq!(
            expr,
            LcnfExpr::Case {
                scrutinee: vid(2),
                scrutinee_ty: LcnfType::Nat,
                alts: vec![],
                default: Some(Box::new(LcnfExpr::Return(LcnfArg::Var(vid(2))))),
            }
        );
    }

    #[test]
    pub(super) fn test_substitute_var_in_let() {
        // Both the let-bound value and the let body must be rewritten.
        let mut expr = mk_let(
            3,
            LcnfLetValue::FVar(vid(1)),
            LcnfExpr::Return(LcnfArg::Var(vid(1))),
        );
        substitute_var_in_expr(&mut expr, vid(1), vid(2));
        assert_eq!(
            expr,
            mk_let(
                3,
                LcnfLetValue::FVar(vid(2)),
                LcnfExpr::Return(LcnfArg::Var(vid(2))),
            )
        );
    }

    #[test]
    pub(super) fn test_run_all_passes() {
        let mut hints = PgoHints::new();
        hints.mark_hot("main");
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let mut decls = vec![mk_fun_decl("main", body)];
        run_all_passes(&mut decls, Some(&hints));
        run_all_passes(&mut decls, None);
        // A body that is already a literal return has nothing to optimise.
        assert_eq!(decls.len(), 1);
        assert_eq!(
            decls[0].body,
            LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))
        );
    }

    #[test]
    pub(super) fn test_opt_pass_trait_constant_folding() {
        let mut pass = ConstantFoldingPass::new();
        assert_eq!(pass.name(), "constant_folding");
        assert!(pass.is_enabled());
        assert!(pass.dependencies().is_empty());
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let mut decls = vec![mk_fun_decl("f", body)];
        let changes = pass.run_pass(&mut decls);
        assert_eq!(changes, 0);
    }

    #[test]
    pub(super) fn test_opt_pass_trait_dce() {
        let pass = DeadCodeEliminationPass::new();
        assert_eq!(pass.name(), "dead_code_elimination");
    }

    #[test]
    pub(super) fn test_opt_pass_trait_cp() {
        let pass = CopyPropagationPass::new();
        assert_eq!(pass.name(), "copy_propagation");
    }

    #[test]
    pub(super) fn test_opt_pass_trait_beta() {
        let pass = BetaReductionPass::new();
        assert_eq!(pass.name(), "beta_reduction");
    }

    #[test]
    pub(super) fn test_opt_pass_trait_identity() {
        let pass = IdentityEliminationPass::new();
        assert_eq!(pass.name(), "identity_elimination");
    }

    #[test]
    pub(super) fn test_opt_pass_trait_unreachable() {
        let pass = UnreachableCodeEliminationPass::new();
        assert_eq!(pass.name(), "unreachable_code_elimination");
    }

    #[test]
    pub(super) fn test_pass_debug_impls() {
        let cf = ConstantFoldingPass::new();
        let dce = DeadCodeEliminationPass::new();
        let cp = CopyPropagationPass::new();
        let beta = BetaReductionPass::new();
        let id = IdentityEliminationPass::new();
        let sr = StrengthReductionPass::new();
        let uce = UnreachableCodeEliminationPass::new();
        assert!(format!("{:?}", cf).contains("ConstantFolding"));
        assert!(format!("{:?}", dce).contains("DeadCode"));
        assert!(format!("{:?}", cp).contains("CopyPropagation"));
        assert!(format!("{:?}", beta).contains("BetaReduction"));
        assert!(format!("{:?}", id).contains("Identity"));
        assert!(format!("{:?}", sr).contains("StrengthReduction"));
        assert!(format!("{:?}", uce).contains("Unreachable"));
    }
}
#[cfg(test)]
mod op_infra_tests {
    // Tests for the generic optimisation infrastructure types:
    // pass configs, stats and registry, analysis cache, worklist, dominator
    // tree, liveness info, constant-folding helper, and dependency graph.
    // (Renamed from `OP_infra_tests` to satisfy the `non_snake_case` lint.)
    use super::*;

    #[test]
    pub(super) fn test_pass_config() {
        let config = OPPassConfig::new("test_pass", OPPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = OPPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }

    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = OPPassRegistry::new();
        reg.register(OPPassConfig::new("pass_a", OPPassPhase::Analysis));
        reg.register(OPPassConfig::new("pass_b", OPPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = OPAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        // One hit ("key1") and one miss ("key2").
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        // Invalidation marks the entry stale but keeps it in the cache.
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }

    #[test]
    pub(super) fn test_worklist() {
        let mut wl = OPWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        // A duplicate push is rejected.
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = OPDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }

    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = OPLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(OPConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(OPConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(OPConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            OPConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(OPConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = OPDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}
844}