1use crate::lcnf::{LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfParam, LcnfType, LcnfVarId};
6use std::collections::{HashMap, HashSet};
7
8use super::types::{
9 FreshIds, TRAnalysisCache, TRConstantFoldingHelper, TRDepGraph, TRDominatorTree, TRExtCache,
10 TRExtConstFolder, TRExtDepGraph, TRExtDomTree, TRExtLiveness, TRExtPassConfig, TRExtPassPhase,
11 TRExtPassRegistry, TRExtPassStats, TRExtWorklist, TRLivenessInfo, TRPassConfig, TRPassPhase,
12 TRPassRegistry, TRPassStats, TRWorklist, TRX2Cache, TRX2ConstFolder, TRX2DepGraph, TRX2DomTree,
13 TRX2Liveness, TRX2PassConfig, TRX2PassPhase, TRX2PassRegistry, TRX2PassStats, TRX2Worklist,
14 TailRecConfig, TailRecOpt,
15};
16
17pub(super) fn has_tail_call_to(expr: &LcnfExpr, _fn_name: &str) -> bool {
19 match expr {
20 LcnfExpr::TailCall(LcnfArg::Var(_), _) => false,
21 LcnfExpr::Return(_) | LcnfExpr::Unreachable => false,
22 LcnfExpr::Let { body, .. } => has_tail_call_to(body, _fn_name),
23 LcnfExpr::Case { alts, default, .. } => {
24 alts.iter().any(|a| has_tail_call_to(&a.body, _fn_name))
25 || default
26 .as_ref()
27 .is_some_and(|d| has_tail_call_to(d, _fn_name))
28 }
29 LcnfExpr::TailCall(LcnfArg::Lit(_), _) => false,
30 LcnfExpr::TailCall(LcnfArg::Erased, _) => false,
31 LcnfExpr::TailCall(LcnfArg::Type(_), _) => false,
32 }
33}
34pub(super) fn has_non_tail_recursive_call(
37 expr: &LcnfExpr,
38 fn_name: &str,
39 param_names: &[String],
40) -> bool {
41 match expr {
42 LcnfExpr::Let {
43 name, value, body, ..
44 } => {
45 let self_call_in_value = match value {
46 LcnfLetValue::App(LcnfArg::Var(_), _) => {
47 param_names.contains(name) || name.contains(fn_name) || fn_name == name.as_str()
48 }
49 _ => false,
50 };
51 self_call_in_value || has_non_tail_recursive_call(body, fn_name, param_names)
52 }
53 LcnfExpr::Case { alts, default, .. } => {
54 alts.iter()
55 .any(|a| has_non_tail_recursive_call(&a.body, fn_name, param_names))
56 || default
57 .as_ref()
58 .is_some_and(|d| has_non_tail_recursive_call(d, fn_name, param_names))
59 }
60 _ => false,
61 }
62}
63pub(super) fn rewrite_tail_calls(
66 expr: LcnfExpr,
67 _fn_var: &LcnfVarId,
68 _count: &mut usize,
69) -> LcnfExpr {
70 match expr {
71 LcnfExpr::Let {
72 id,
73 name,
74 ty,
75 value,
76 body,
77 } => {
78 let new_body = rewrite_tail_calls(*body, _fn_var, _count);
79 LcnfExpr::Let {
80 id,
81 name,
82 ty,
83 value,
84 body: Box::new(new_body),
85 }
86 }
87 LcnfExpr::Case {
88 scrutinee,
89 scrutinee_ty,
90 alts,
91 default,
92 } => {
93 let new_alts = alts
94 .into_iter()
95 .map(|a| {
96 let new_body = rewrite_tail_calls(a.body, _fn_var, _count);
97 crate::lcnf::LcnfAlt {
98 body: new_body,
99 ..a
100 }
101 })
102 .collect();
103 let new_default = default.map(|d| Box::new(rewrite_tail_calls(*d, _fn_var, _count)));
104 LcnfExpr::Case {
105 scrutinee,
106 scrutinee_ty,
107 alts: new_alts,
108 default: new_default,
109 }
110 }
111 LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(func, args),
112 other => other,
113 }
114}
115pub(super) fn try_introduce_accumulator(
133 decl: &LcnfFunDecl,
134 fresh: &mut FreshIds,
135) -> Option<LcnfFunDecl> {
136 if decl.params.len() != 1 {
137 return None;
138 }
139 let param = &decl.params[0];
140 if param.ty != LcnfType::Nat {
141 return None;
142 }
143 let (base_alt, _step_alt) = match &decl.body {
144 LcnfExpr::Case { alts, default, .. } if alts.len() == 1 && default.is_some() => {
145 let alt = &alts[0];
146 let def = default
147 .as_ref()
148 .expect("default is Some; guaranteed by pattern match condition default.is_some()");
149 (alt, def.as_ref())
150 }
151 LcnfExpr::Case {
152 alts,
153 default: None,
154 ..
155 } if alts.len() == 2 => (&alts[0], &alts[1].body),
156 _ => return None,
157 };
158 let base_lit = match &base_alt.body {
159 LcnfExpr::Return(LcnfArg::Lit(lit)) => lit.clone(),
160 _ => return None,
161 };
162 let param_names: Vec<String> = decl.params.iter().map(|p| p.name.clone()).collect();
163 if !has_non_tail_recursive_call(&decl.body, &decl.name, ¶m_names) {
164 return None;
165 }
166 let acc_id = fresh.next();
167 let acc_param = LcnfParam {
168 id: acc_id,
169 name: "acc".to_string(),
170 ty: LcnfType::Nat,
171 erased: false,
172 borrowed: false,
173 };
174 let acc_helper_body = LcnfExpr::Case {
175 scrutinee: param.id,
176 scrutinee_ty: LcnfType::Nat,
177 alts: vec![crate::lcnf::LcnfAlt {
178 ctor_name: "Nat.zero".to_string(),
179 ctor_tag: 0,
180 params: vec![],
181 body: LcnfExpr::Return(LcnfArg::Lit(base_lit)),
182 }],
183 default: Some(Box::new(LcnfExpr::TailCall(
184 LcnfArg::Var(acc_id),
185 vec![LcnfArg::Var(param.id), LcnfArg::Var(acc_id)],
186 ))),
187 };
188 Some(LcnfFunDecl {
189 name: format!("{}_acc", decl.name),
190 original_name: decl.original_name.clone(),
191 params: vec![param.clone(), acc_param],
192 ret_type: decl.ret_type.clone(),
193 body: acc_helper_body,
194 is_recursive: true,
195 is_lifted: true,
196 inline_cost: decl.inline_cost + 2,
197 })
198}
199pub(super) fn tail_callees(expr: &LcnfExpr, candidates: &HashSet<String>) -> HashSet<String> {
202 let mut result = HashSet::new();
203 collect_tail_callees(expr, candidates, &mut result);
204 result
205}
206pub(super) fn collect_tail_callees(
207 expr: &LcnfExpr,
208 candidates: &HashSet<String>,
209 result: &mut HashSet<String>,
210) {
211 match expr {
212 LcnfExpr::Let { body, .. } => collect_tail_callees(body, candidates, result),
213 LcnfExpr::Case { alts, default, .. } => {
214 for a in alts {
215 collect_tail_callees(&a.body, candidates, result);
216 }
217 if let Some(d) = default {
218 collect_tail_callees(d, candidates, result);
219 }
220 }
221 LcnfExpr::TailCall(LcnfArg::Var(id), _) => {
222 let key = format!("var_{}", id.0);
223 if candidates.contains(&key) {
224 result.insert(key);
225 }
226 }
227 _ => {}
228 }
229}
230pub fn detect_mutual_tail_recursion(decls: &[LcnfFunDecl]) -> Vec<Vec<String>> {
234 let name_to_idx: HashMap<String, usize> = decls
235 .iter()
236 .enumerate()
237 .map(|(i, d)| (d.name.clone(), i))
238 .collect();
239 let n = decls.len();
240 let mut adj: Vec<HashSet<usize>> = vec![HashSet::new(); n];
241 let candidate_names: HashSet<String> = decls.iter().map(|d| d.name.clone()).collect();
242 for (i, decl) in decls.iter().enumerate() {
243 if decl.is_recursive {
244 adj[i].insert(i);
245 }
246 for other_name in &candidate_names {
247 if other_name == &decl.name {
248 continue;
249 }
250 if let Some(&j) = name_to_idx.get(other_name) {
251 if decl.name.starts_with(&format!("{}_", other_name))
252 || other_name.starts_with(&format!("{}_", decl.name))
253 {
254 adj[i].insert(j);
255 }
256 }
257 }
258 }
259 let mut visited = vec![false; n];
260 let mut sccs: Vec<Vec<String>> = Vec::new();
261 for start in 0..n {
262 if !visited[start] {
263 let mut scc = Vec::new();
264 dfs_scc(start, &adj, &mut visited, &mut scc);
265 let names: Vec<String> = scc.into_iter().map(|i| decls[i].name.clone()).collect();
266 if !names.is_empty() {
267 sccs.push(names);
268 }
269 }
270 }
271 sccs
272}
273pub(super) fn dfs_scc(
274 node: usize,
275 adj: &[HashSet<usize>],
276 visited: &mut Vec<bool>,
277 component: &mut Vec<usize>,
278) {
279 if visited[node] {
280 return;
281 }
282 visited[node] = true;
283 component.push(node);
284 for &next in &adj[node] {
285 dfs_scc(next, adj, visited, component);
286 }
287}
#[cfg(test)]
mod tests {
    // Unit tests for the tail-recursion helpers and the `TailRecOpt` driver.
    use super::*;
    use crate::lcnf::{
        LcnfAlt, LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfParam, LcnfType,
        LcnfVarId,
    };
    // Builds a non-erased, non-borrowed `Nat` parameter with the given id and name.
    pub(super) fn nat_param(id: u64, name: &str) -> LcnfParam {
        LcnfParam {
            id: LcnfVarId(id),
            name: name.to_string(),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        }
    }
    // Builds a recursive, non-lifted declaration returning `Nat` (inline cost 2).
    pub(super) fn mk_recursive_decl(
        name: &str,
        params: Vec<LcnfParam>,
        body: LcnfExpr,
    ) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params,
            ret_type: LcnfType::Nat,
            body,
            is_recursive: true,
            is_lifted: false,
            inline_cost: 2,
        }
    }
    // Builds a parameterless, non-recursive declaration named `non_rec` (inline cost 1).
    pub(super) fn mk_non_recursive_decl(body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: "non_rec".to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }
    // A non-recursive declaration must come out of the pass untouched.
    #[test]
    pub(super) fn test_non_recursive_unchanged() {
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42)));
        let mut decl = mk_non_recursive_decl(body.clone());
        let mut pass = TailRecOpt::new();
        let (report, extras) = pass.run(&mut decl);
        assert_eq!(report.functions_transformed, 0);
        assert_eq!(report.calls_eliminated, 0);
        assert!(extras.is_empty());
        assert_eq!(decl.body, body);
    }
    // A recursive declaration containing a `TailCall` is reported as transformed.
    #[test]
    pub(super) fn test_recursive_tailcall_counted() {
        let n_id = LcnfVarId(1);
        let body = LcnfExpr::Case {
            scrutinee: n_id,
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "Nat.zero".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
            }],
            default: Some(Box::new(LcnfExpr::TailCall(
                LcnfArg::Var(n_id),
                vec![LcnfArg::Lit(LcnfLit::Nat(0))],
            ))),
        };
        let mut decl = mk_recursive_decl("countdown", vec![nat_param(1, "n")], body);
        let mut pass = TailRecOpt::new();
        let (report, _) = pass.run(&mut decl);
        assert!(
            report.functions_transformed >= 1,
            "Recursive function with TailCall should be counted as transformed"
        );
        assert!(report.calls_eliminated >= 1);
    }
    // A let-bound self call in a single-`Nat` function yields a `_acc` helper.
    #[test]
    pub(super) fn test_accumulator_introduced() {
        let n_id = LcnfVarId(1);
        let rec_call_id = LcnfVarId(2);
        let body = LcnfExpr::Case {
            scrutinee: n_id,
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "Nat.zero".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
            }],
            default: Some(Box::new(LcnfExpr::Let {
                id: rec_call_id,
                name: "sum_acc".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::App(LcnfArg::Var(n_id), vec![LcnfArg::Lit(LcnfLit::Nat(1))]),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(rec_call_id))),
            })),
        };
        let mut decl = mk_recursive_decl("sum", vec![nat_param(1, "n")], body);
        let mut pass = TailRecOpt::with_config(TailRecConfig {
            transform_linear: true,
            introduce_accum: true,
        });
        let (_report, extras) = pass.run(&mut decl);
        assert!(
            !extras.is_empty(),
            "Accumulator helper should be synthesized for non-tail-recursive single-Nat-param fn"
        );
        let helper = &extras[0];
        assert!(
            helper.name.ends_with("_acc"),
            "Helper name should have _acc suffix"
        );
        assert_eq!(
            helper.params.len(),
            2,
            "Helper should have original param + accumulator"
        );
        assert!(helper.is_recursive);
    }
    // With `introduce_accum: false` no helper may be synthesized.
    #[test]
    pub(super) fn test_no_accum_when_disabled() {
        let n_id = LcnfVarId(1);
        let rec_call_id = LcnfVarId(2);
        let body = LcnfExpr::Case {
            scrutinee: n_id,
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "Nat.zero".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
            }],
            default: Some(Box::new(LcnfExpr::Let {
                id: rec_call_id,
                name: "product_acc".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::App(LcnfArg::Var(n_id), vec![LcnfArg::Lit(LcnfLit::Nat(1))]),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(rec_call_id))),
            })),
        };
        let mut decl = mk_recursive_decl("product", vec![nat_param(1, "n")], body);
        let mut pass = TailRecOpt::with_config(TailRecConfig {
            transform_linear: true,
            introduce_accum: false,
        });
        let (_report, extras) = pass.run(&mut decl);
        assert!(
            extras.is_empty(),
            "introduce_accum=false must not synthesize helper"
        );
    }
    // `is_even` / `is_even_helper` share an underscore-prefixed name and must both be grouped.
    #[test]
    pub(super) fn test_mutual_tail_rec_detection() {
        let decl_a = mk_recursive_decl(
            "is_even",
            vec![nat_param(1, "n")],
            LcnfExpr::TailCall(LcnfArg::Var(LcnfVarId(1)), vec![]),
        );
        let decl_b = mk_recursive_decl(
            "is_even_helper",
            vec![nat_param(2, "n")],
            LcnfExpr::TailCall(LcnfArg::Var(LcnfVarId(2)), vec![]),
        );
        let decls = vec![decl_a, decl_b];
        let sccs = detect_mutual_tail_recursion(&decls);
        let all_names: Vec<String> = sccs.into_iter().flatten().collect();
        assert!(all_names.contains(&"is_even".to_string()));
        assert!(all_names.contains(&"is_even_helper".to_string()));
    }
    // Module-level driver: at least the recursive declaration is transformed.
    #[test]
    pub(super) fn test_run_module() {
        let body_rec = LcnfExpr::Case {
            scrutinee: LcnfVarId(1),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "Nat.zero".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))),
            }],
            default: Some(Box::new(LcnfExpr::TailCall(
                LcnfArg::Var(LcnfVarId(1)),
                vec![LcnfArg::Lit(LcnfLit::Nat(0))],
            ))),
        };
        let body_non_rec = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let mut decls = vec![
            mk_recursive_decl("fib", vec![nat_param(1, "n")], body_rec),
            mk_non_recursive_decl(body_non_rec),
        ];
        let mut pass = TailRecOpt::new();
        let report = pass.run_module(&mut decls);
        assert!(
            report.functions_transformed >= 1,
            "At least one recursive function should be transformed"
        );
    }
    // `rewrite_tail_calls` leaves a call-free `Let` unchanged and counts nothing.
    #[test]
    pub(super) fn test_rewrite_preserves_let_structure() {
        let fn_var = LcnfVarId(0);
        let body = LcnfExpr::Let {
            id: LcnfVarId(10),
            name: "tmp".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(5)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(10)))),
        };
        let mut count = 0usize;
        let result = rewrite_tail_calls(body.clone(), &fn_var, &mut count);
        assert_eq!(result, body, "Non-self-calling Let should be unchanged");
        assert_eq!(count, 0);
    }
    // `TailRecOpt::count_tailcalls` counts a lone `TailCall` once.
    // (Previously misnamed `test_has_tail_call_to_detects_tailcall`.)
    #[test]
    pub(super) fn test_count_tailcalls_detects_tailcall() {
        let expr = LcnfExpr::TailCall(
            LcnfArg::Var(LcnfVarId(99)),
            vec![LcnfArg::Lit(LcnfLit::Nat(0))],
        );
        let pass = TailRecOpt::new();
        assert_eq!(pass.count_tailcalls(&expr), 1);
    }
    // The naming heuristic: a let-bound application counts as a self call only
    // when the bound name contains the function name (or is a parameter name).
    #[test]
    pub(super) fn test_non_tail_recursive_call_heuristic() {
        let call = LcnfExpr::Let {
            id: LcnfVarId(2),
            name: "sum_acc".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::App(LcnfArg::Var(LcnfVarId(1)), vec![]),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(2)))),
        };
        assert!(has_non_tail_recursive_call(&call, "sum", &[]));
        assert!(!has_non_tail_recursive_call(&call, "product", &[]));
        assert!(has_non_tail_recursive_call(
            &call,
            "product",
            &["sum_acc".to_string()]
        ));
        let leaf = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        assert!(!has_non_tail_recursive_call(&leaf, "sum", &[]));
    }
}
#[cfg(test)]
mod tr_infra_tests {
    // Tests for the generic `TR*` pass infrastructure types.
    // Renamed from `TR_infra_tests` to satisfy the `non_snake_case` lint.
    use super::*;
    // A freshly built config is enabled; the `Transformation` phase is modifying.
    #[test]
    pub(super) fn test_pass_config() {
        let config = TRPassConfig::new("test_pass", TRPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }
    // Two recorded runs: averaged changes, success rate, and summary text.
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = TRPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }
    // Registration counts, enabled filtering, and per-pass stats updates.
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = TRPassRegistry::new();
        reg.register(TRPassConfig::new("pass_a", TRPassPhase::Analysis));
        reg.register(TRPassConfig::new("pass_b", TRPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }
    // One hit plus one miss gives a 0.5 hit rate; invalidation keeps the entry.
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = TRAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }
    // Duplicate pushes are rejected; pop yields the first pushed item.
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = TRWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }
    // Dominance follows the idom chain and is reflexive.
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = TRDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }
    // Defs and uses are recorded per block.
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = TRLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }
    // Folding helpers, including division by zero yielding `None`.
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(TRConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(TRConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(TRConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            TRConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(TRConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }
    // Dependency queries, acyclicity, and a topological order respecting every edge.
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = TRDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}
#[cfg(test)]
mod trext_pass_tests {
    // Tests for the `TRExt*` pass infrastructure types.
    use super::*;
    // Phase ordinals run 0..=3 in Early, Middle, Late, Finalize order.
    #[test]
    pub(super) fn test_trext_phase_order() {
        assert_eq!(TRExtPassPhase::Early.order(), 0);
        assert_eq!(TRExtPassPhase::Middle.order(), 1);
        assert_eq!(TRExtPassPhase::Late.order(), 2);
        assert_eq!(TRExtPassPhase::Finalize.order(), 3);
        assert!(TRExtPassPhase::Early.is_early());
        assert!(!TRExtPassPhase::Early.is_late());
    }
    // Builder methods set their fields; `disabled()` clears `enabled`.
    #[test]
    pub(super) fn test_trext_config_builder() {
        let c = TRExtPassConfig::new("p")
            .with_phase(TRExtPassPhase::Late)
            .with_max_iter(50)
            .with_debug(1);
        assert_eq!(c.name, "p");
        assert_eq!(c.max_iterations, 50);
        assert!(c.is_debug_enabled());
        assert!(c.enabled);
        let c2 = c.disabled();
        assert!(!c2.enabled);
    }
    // Two visits and one modification give an efficiency of 0.5.
    #[test]
    pub(super) fn test_trext_stats() {
        let mut s = TRExtPassStats::new();
        s.visit();
        s.visit();
        s.modify();
        s.iterate();
        assert_eq!(s.nodes_visited, 2);
        assert_eq!(s.nodes_modified, 1);
        assert!(s.changed);
        assert_eq!(s.iterations, 1);
        let e = s.efficiency();
        assert!((e - 0.5).abs() < 1e-9);
    }
    // Disabled passes are excluded from `enabled_passes`; phase filtering works.
    #[test]
    pub(super) fn test_trext_registry() {
        let mut r = TRExtPassRegistry::new();
        r.register(TRExtPassConfig::new("a").with_phase(TRExtPassPhase::Early));
        r.register(TRExtPassConfig::new("b").disabled());
        assert_eq!(r.len(), 2);
        assert_eq!(r.enabled_passes().len(), 1);
        assert_eq!(r.passes_in_phase(&TRExtPassPhase::Early).len(), 1);
    }
    // Miss then put then hit; the call order matters for `hit_rate`.
    #[test]
    pub(super) fn test_trext_cache() {
        let mut c = TRExtCache::new(4);
        assert!(c.get(99).is_none());
        c.put(99, vec![1, 2, 3]);
        let v = c.get(99).expect("v should be present in map");
        assert_eq!(v, &[1u8, 2, 3]);
        assert!(c.hit_rate() > 0.0);
        assert_eq!(c.live_count(), 1);
    }
    // Duplicate pushes are ignored; a popped item is no longer contained.
    #[test]
    pub(super) fn test_trext_worklist() {
        let mut w = TRExtWorklist::new(10);
        w.push(5);
        w.push(3);
        w.push(5);
        assert_eq!(w.len(), 2);
        assert!(w.contains(5));
        let first = w.pop().expect("first should be available to pop");
        assert!(!w.contains(first));
    }
    // Dominance queries and depth along the idom chain 0 -> 1 -> 3.
    #[test]
    pub(super) fn test_trext_dom_tree() {
        let mut dt = TRExtDomTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        dt.set_idom(4, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 4));
        assert!(!dt.dominates(2, 3));
        assert_eq!(dt.depth_of(3), 2);
    }
    // Defs and uses are tracked per block and do not leak between blocks.
    #[test]
    pub(super) fn test_trext_liveness() {
        let mut lv = TRExtLiveness::new(3);
        lv.add_def(0, 1);
        lv.add_use(1, 1);
        assert!(lv.var_is_def_in_block(0, 1));
        assert!(lv.var_is_used_in_block(1, 1));
        assert!(!lv.var_is_def_in_block(1, 1));
    }
    // Three successful folds and one failure (division by zero) are counted.
    #[test]
    pub(super) fn test_trext_const_folder() {
        let mut cf = TRExtConstFolder::new();
        assert_eq!(cf.add_i64(3, 4), Some(7));
        assert_eq!(cf.div_i64(10, 0), None);
        assert_eq!(cf.mul_i64(6, 7), Some(42));
        assert_eq!(cf.and_i64(0b1100, 0b1010), 0b1000);
        assert_eq!(cf.fold_count(), 3);
        assert_eq!(cf.failure_count(), 1);
    }
    // A 4-node chain: acyclic, linear topological order, one SCC per node.
    #[test]
    pub(super) fn test_trext_dep_graph() {
        let mut g = TRExtDepGraph::new(4);
        g.add_edge(0, 1);
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        assert!(!g.has_cycle());
        assert_eq!(g.topo_sort(), Some(vec![0, 1, 2, 3]));
        assert_eq!(g.reachable(0).len(), 4);
        let sccs = g.scc();
        assert_eq!(sccs.len(), 4);
    }
}
#[cfg(test)]
mod trx2_pass_tests {
    // Tests for the `TRX2*` pass infrastructure types; these mirror `trext_pass_tests`.
    use super::*;
    // Phase ordinals run 0..=3 in Early, Middle, Late, Finalize order.
    #[test]
    pub(super) fn test_trx2_phase_order() {
        assert_eq!(TRX2PassPhase::Early.order(), 0);
        assert_eq!(TRX2PassPhase::Middle.order(), 1);
        assert_eq!(TRX2PassPhase::Late.order(), 2);
        assert_eq!(TRX2PassPhase::Finalize.order(), 3);
        assert!(TRX2PassPhase::Early.is_early());
        assert!(!TRX2PassPhase::Early.is_late());
    }
    // Builder methods set their fields; `disabled()` clears `enabled`.
    #[test]
    pub(super) fn test_trx2_config_builder() {
        let c = TRX2PassConfig::new("p")
            .with_phase(TRX2PassPhase::Late)
            .with_max_iter(50)
            .with_debug(1);
        assert_eq!(c.name, "p");
        assert_eq!(c.max_iterations, 50);
        assert!(c.is_debug_enabled());
        assert!(c.enabled);
        let c2 = c.disabled();
        assert!(!c2.enabled);
    }
    // Two visits and one modification give an efficiency of 0.5.
    #[test]
    pub(super) fn test_trx2_stats() {
        let mut s = TRX2PassStats::new();
        s.visit();
        s.visit();
        s.modify();
        s.iterate();
        assert_eq!(s.nodes_visited, 2);
        assert_eq!(s.nodes_modified, 1);
        assert!(s.changed);
        assert_eq!(s.iterations, 1);
        let e = s.efficiency();
        assert!((e - 0.5).abs() < 1e-9);
    }
    // Disabled passes are excluded from `enabled_passes`; phase filtering works.
    #[test]
    pub(super) fn test_trx2_registry() {
        let mut r = TRX2PassRegistry::new();
        r.register(TRX2PassConfig::new("a").with_phase(TRX2PassPhase::Early));
        r.register(TRX2PassConfig::new("b").disabled());
        assert_eq!(r.len(), 2);
        assert_eq!(r.enabled_passes().len(), 1);
        assert_eq!(r.passes_in_phase(&TRX2PassPhase::Early).len(), 1);
    }
    // Miss then put then hit; the call order matters for `hit_rate`.
    #[test]
    pub(super) fn test_trx2_cache() {
        let mut c = TRX2Cache::new(4);
        assert!(c.get(99).is_none());
        c.put(99, vec![1, 2, 3]);
        let v = c.get(99).expect("v should be present in map");
        assert_eq!(v, &[1u8, 2, 3]);
        assert!(c.hit_rate() > 0.0);
        assert_eq!(c.live_count(), 1);
    }
    // Duplicate pushes are ignored; a popped item is no longer contained.
    #[test]
    pub(super) fn test_trx2_worklist() {
        let mut w = TRX2Worklist::new(10);
        w.push(5);
        w.push(3);
        w.push(5);
        assert_eq!(w.len(), 2);
        assert!(w.contains(5));
        let first = w.pop().expect("first should be available to pop");
        assert!(!w.contains(first));
    }
    // Dominance queries and depth along the idom chain 0 -> 1 -> 3.
    #[test]
    pub(super) fn test_trx2_dom_tree() {
        let mut dt = TRX2DomTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        dt.set_idom(4, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 4));
        assert!(!dt.dominates(2, 3));
        assert_eq!(dt.depth_of(3), 2);
    }
    // Defs and uses are tracked per block and do not leak between blocks.
    #[test]
    pub(super) fn test_trx2_liveness() {
        let mut lv = TRX2Liveness::new(3);
        lv.add_def(0, 1);
        lv.add_use(1, 1);
        assert!(lv.var_is_def_in_block(0, 1));
        assert!(lv.var_is_used_in_block(1, 1));
        assert!(!lv.var_is_def_in_block(1, 1));
    }
    // Three successful folds and one failure (division by zero) are counted.
    #[test]
    pub(super) fn test_trx2_const_folder() {
        let mut cf = TRX2ConstFolder::new();
        assert_eq!(cf.add_i64(3, 4), Some(7));
        assert_eq!(cf.div_i64(10, 0), None);
        assert_eq!(cf.mul_i64(6, 7), Some(42));
        assert_eq!(cf.and_i64(0b1100, 0b1010), 0b1000);
        assert_eq!(cf.fold_count(), 3);
        assert_eq!(cf.failure_count(), 1);
    }
    // A 4-node chain: acyclic, linear topological order, one SCC per node.
    #[test]
    pub(super) fn test_trx2_dep_graph() {
        let mut g = TRX2DepGraph::new(4);
        g.add_edge(0, 1);
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        assert!(!g.has_cycle());
        assert_eq!(g.topo_sort(), Some(vec![0, 1, 2, 3]));
        assert_eq!(g.reachable(0).len(), 4);
        let sccs = g.scc();
        assert_eq!(sccs.len(), 4);
    }
}