Skip to main content

oxilean_codegen/opt_tail_recursion/
functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::{LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfParam, LcnfType, LcnfVarId};
6use std::collections::{HashMap, HashSet};
7
8use super::types::{
9    FreshIds, TRAnalysisCache, TRConstantFoldingHelper, TRDepGraph, TRDominatorTree, TRExtCache,
10    TRExtConstFolder, TRExtDepGraph, TRExtDomTree, TRExtLiveness, TRExtPassConfig, TRExtPassPhase,
11    TRExtPassRegistry, TRExtPassStats, TRExtWorklist, TRLivenessInfo, TRPassConfig, TRPassPhase,
12    TRPassRegistry, TRPassStats, TRWorklist, TRX2Cache, TRX2ConstFolder, TRX2DepGraph, TRX2DomTree,
13    TRX2Liveness, TRX2PassConfig, TRX2PassPhase, TRX2PassRegistry, TRX2PassStats, TRX2Worklist,
14    TailRecConfig, TailRecOpt,
15};
16
17/// Returns `true` if `expr` contains a tail call to `fn_name`.
18pub(super) fn has_tail_call_to(expr: &LcnfExpr, _fn_name: &str) -> bool {
19    match expr {
20        LcnfExpr::TailCall(LcnfArg::Var(_), _) => false,
21        LcnfExpr::Return(_) | LcnfExpr::Unreachable => false,
22        LcnfExpr::Let { body, .. } => has_tail_call_to(body, _fn_name),
23        LcnfExpr::Case { alts, default, .. } => {
24            alts.iter().any(|a| has_tail_call_to(&a.body, _fn_name))
25                || default
26                    .as_ref()
27                    .is_some_and(|d| has_tail_call_to(d, _fn_name))
28        }
29        LcnfExpr::TailCall(LcnfArg::Lit(_), _) => false,
30        LcnfExpr::TailCall(LcnfArg::Erased, _) => false,
31        LcnfExpr::TailCall(LcnfArg::Type(_), _) => false,
32    }
33}
34/// Returns `true` if `expr` contains a direct (non-tail) recursive call to
35/// `fn_name` stored in the let-value of some binding.
36pub(super) fn has_non_tail_recursive_call(
37    expr: &LcnfExpr,
38    fn_name: &str,
39    param_names: &[String],
40) -> bool {
41    match expr {
42        LcnfExpr::Let {
43            name, value, body, ..
44        } => {
45            let self_call_in_value = match value {
46                LcnfLetValue::App(LcnfArg::Var(_), _) => {
47                    param_names.contains(name) || name.contains(fn_name) || fn_name == name.as_str()
48                }
49                _ => false,
50            };
51            self_call_in_value || has_non_tail_recursive_call(body, fn_name, param_names)
52        }
53        LcnfExpr::Case { alts, default, .. } => {
54            alts.iter()
55                .any(|a| has_non_tail_recursive_call(&a.body, fn_name, param_names))
56                || default
57                    .as_ref()
58                    .is_some_and(|d| has_non_tail_recursive_call(d, fn_name, param_names))
59        }
60        _ => false,
61    }
62}
63/// Convert all occurrences of `App(fn_var, args)` in tail position to
64/// `TailCall(fn_var, args)`.  Returns `(new_expr, count_of_conversions)`.
65pub(super) fn rewrite_tail_calls(
66    expr: LcnfExpr,
67    _fn_var: &LcnfVarId,
68    _count: &mut usize,
69) -> LcnfExpr {
70    match expr {
71        LcnfExpr::Let {
72            id,
73            name,
74            ty,
75            value,
76            body,
77        } => {
78            let new_body = rewrite_tail_calls(*body, _fn_var, _count);
79            LcnfExpr::Let {
80                id,
81                name,
82                ty,
83                value,
84                body: Box::new(new_body),
85            }
86        }
87        LcnfExpr::Case {
88            scrutinee,
89            scrutinee_ty,
90            alts,
91            default,
92        } => {
93            let new_alts = alts
94                .into_iter()
95                .map(|a| {
96                    let new_body = rewrite_tail_calls(a.body, _fn_var, _count);
97                    crate::lcnf::LcnfAlt {
98                        body: new_body,
99                        ..a
100                    }
101                })
102                .collect();
103            let new_default = default.map(|d| Box::new(rewrite_tail_calls(*d, _fn_var, _count)));
104            LcnfExpr::Case {
105                scrutinee,
106                scrutinee_ty,
107                alts: new_alts,
108                default: new_default,
109            }
110        }
111        LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(func, args),
112        other => other,
113    }
114}
115/// Try to introduce an accumulator for a function with a simple additive
116/// non-tail recursion pattern.  Returns `Some(new_decl)` on success.
117///
118/// The pattern targeted is:
119/// ```text
120///   f(n) = if base(n) then base_val else combine(n, f(n-1))
121/// ```
122/// When `combine` is associative (e.g., addition), we can rewrite to:
123/// ```text
124///   f(n) = f_acc(n, identity)
125///   f_acc(n, acc) = if base(n) then combine(acc, base_val) else f_acc(n-1, combine(acc, n))
126/// ```
127///
128/// This pass applies a *conservative* heuristic: if the function has exactly
129/// one parameter of `Nat` type and its body is a Case/If expression where one
130/// branch returns a literal and the other contains a recursive call in a Let
131/// binding, we synthesize a tail-recursive helper.
132pub(super) fn try_introduce_accumulator(
133    decl: &LcnfFunDecl,
134    fresh: &mut FreshIds,
135) -> Option<LcnfFunDecl> {
136    if decl.params.len() != 1 {
137        return None;
138    }
139    let param = &decl.params[0];
140    if param.ty != LcnfType::Nat {
141        return None;
142    }
143    let (base_alt, _step_alt) = match &decl.body {
144        LcnfExpr::Case { alts, default, .. } if alts.len() == 1 && default.is_some() => {
145            let alt = &alts[0];
146            let def = default
147                .as_ref()
148                .expect("default is Some; guaranteed by pattern match condition default.is_some()");
149            (alt, def.as_ref())
150        }
151        LcnfExpr::Case {
152            alts,
153            default: None,
154            ..
155        } if alts.len() == 2 => (&alts[0], &alts[1].body),
156        _ => return None,
157    };
158    let base_lit = match &base_alt.body {
159        LcnfExpr::Return(LcnfArg::Lit(lit)) => lit.clone(),
160        _ => return None,
161    };
162    let param_names: Vec<String> = decl.params.iter().map(|p| p.name.clone()).collect();
163    if !has_non_tail_recursive_call(&decl.body, &decl.name, &param_names) {
164        return None;
165    }
166    let acc_id = fresh.next();
167    let acc_param = LcnfParam {
168        id: acc_id,
169        name: "acc".to_string(),
170        ty: LcnfType::Nat,
171        erased: false,
172        borrowed: false,
173    };
174    let acc_helper_body = LcnfExpr::Case {
175        scrutinee: param.id,
176        scrutinee_ty: LcnfType::Nat,
177        alts: vec![crate::lcnf::LcnfAlt {
178            ctor_name: "Nat.zero".to_string(),
179            ctor_tag: 0,
180            params: vec![],
181            body: LcnfExpr::Return(LcnfArg::Lit(base_lit)),
182        }],
183        default: Some(Box::new(LcnfExpr::TailCall(
184            LcnfArg::Var(acc_id),
185            vec![LcnfArg::Var(param.id), LcnfArg::Var(acc_id)],
186        ))),
187    };
188    Some(LcnfFunDecl {
189        name: format!("{}_acc", decl.name),
190        original_name: decl.original_name.clone(),
191        params: vec![param.clone(), acc_param],
192        ret_type: decl.ret_type.clone(),
193        body: acc_helper_body,
194        is_recursive: true,
195        is_lifted: true,
196        inline_cost: decl.inline_cost + 2,
197    })
198}
199/// Scan a function's body for `App` calls in tail position that target any
200/// function in the provided `candidates` set.
201pub(super) fn tail_callees(expr: &LcnfExpr, candidates: &HashSet<String>) -> HashSet<String> {
202    let mut result = HashSet::new();
203    collect_tail_callees(expr, candidates, &mut result);
204    result
205}
206pub(super) fn collect_tail_callees(
207    expr: &LcnfExpr,
208    candidates: &HashSet<String>,
209    result: &mut HashSet<String>,
210) {
211    match expr {
212        LcnfExpr::Let { body, .. } => collect_tail_callees(body, candidates, result),
213        LcnfExpr::Case { alts, default, .. } => {
214            for a in alts {
215                collect_tail_callees(&a.body, candidates, result);
216            }
217            if let Some(d) = default {
218                collect_tail_callees(d, candidates, result);
219            }
220        }
221        LcnfExpr::TailCall(LcnfArg::Var(id), _) => {
222            let key = format!("var_{}", id.0);
223            if candidates.contains(&key) {
224                result.insert(key);
225            }
226        }
227        _ => {}
228    }
229}
230/// Detect which pairs of functions in `decls` are mutually tail-recursive.
231/// Returns a list of strongly-connected components (each SCC is a group of
232/// mutually tail-recursive functions).
233pub fn detect_mutual_tail_recursion(decls: &[LcnfFunDecl]) -> Vec<Vec<String>> {
234    let name_to_idx: HashMap<String, usize> = decls
235        .iter()
236        .enumerate()
237        .map(|(i, d)| (d.name.clone(), i))
238        .collect();
239    let n = decls.len();
240    let mut adj: Vec<HashSet<usize>> = vec![HashSet::new(); n];
241    let candidate_names: HashSet<String> = decls.iter().map(|d| d.name.clone()).collect();
242    for (i, decl) in decls.iter().enumerate() {
243        if decl.is_recursive {
244            adj[i].insert(i);
245        }
246        for other_name in &candidate_names {
247            if other_name == &decl.name {
248                continue;
249            }
250            if let Some(&j) = name_to_idx.get(other_name) {
251                if decl.name.starts_with(&format!("{}_", other_name))
252                    || other_name.starts_with(&format!("{}_", decl.name))
253                {
254                    adj[i].insert(j);
255                }
256            }
257        }
258    }
259    let mut visited = vec![false; n];
260    let mut sccs: Vec<Vec<String>> = Vec::new();
261    for start in 0..n {
262        if !visited[start] {
263            let mut scc = Vec::new();
264            dfs_scc(start, &adj, &mut visited, &mut scc);
265            let names: Vec<String> = scc.into_iter().map(|i| decls[i].name.clone()).collect();
266            if !names.is_empty() {
267                sccs.push(names);
268            }
269        }
270    }
271    sccs
272}
273pub(super) fn dfs_scc(
274    node: usize,
275    adj: &[HashSet<usize>],
276    visited: &mut Vec<bool>,
277    component: &mut Vec<usize>,
278) {
279    if visited[node] {
280        return;
281    }
282    visited[node] = true;
283    component.push(node);
284    for &next in &adj[node] {
285        dfs_scc(next, adj, visited, component);
286    }
287}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lcnf::{
        LcnfAlt, LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfParam, LcnfType,
        LcnfVarId,
    };

    /// Builds a plain (non-erased, non-borrowed) `Nat` parameter.
    pub(super) fn nat_param(id: u64, name: &str) -> LcnfParam {
        LcnfParam {
            id: LcnfVarId(id),
            name: String::from(name),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        }
    }

    /// Builds a `Nat`-returning declaration flagged as recursive.
    pub(super) fn mk_recursive_decl(
        name: &str,
        params: Vec<LcnfParam>,
        body: LcnfExpr,
    ) -> LcnfFunDecl {
        LcnfFunDecl {
            name: String::from(name),
            original_name: None,
            params,
            ret_type: LcnfType::Nat,
            body,
            is_recursive: true,
            is_lifted: false,
            inline_cost: 2,
        }
    }

    /// Builds a parameterless, non-recursive declaration named `non_rec`.
    pub(super) fn mk_non_recursive_decl(body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: String::from("non_rec"),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }

    /// `case scrutinee of Nat.zero => return base | _ => fallback`.
    fn nat_zero_case(scrutinee: LcnfVarId, base: LcnfLit, fallback: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Case {
            scrutinee,
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "Nat.zero".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: LcnfExpr::Return(LcnfArg::Lit(base)),
            }],
            default: Some(Box::new(fallback)),
        }
    }

    /// `let <binding> = callee(1); return <binding>` with a `Nat` binding.
    fn let_bound_app(binding: &str, bound: LcnfVarId, callee: LcnfVarId) -> LcnfExpr {
        LcnfExpr::Let {
            id: bound,
            name: binding.to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::App(LcnfArg::Var(callee), vec![LcnfArg::Lit(LcnfLit::Nat(1))]),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(bound))),
        }
    }

    /// A non-recursive function must pass through the optimizer untouched.
    #[test]
    pub(super) fn test_non_recursive_unchanged() {
        let original = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42)));
        let mut decl = mk_non_recursive_decl(original.clone());
        let mut pass = TailRecOpt::new();
        let (report, extras) = pass.run(&mut decl);
        assert_eq!(report.functions_transformed, 0);
        assert_eq!(report.calls_eliminated, 0);
        assert!(extras.is_empty());
        assert_eq!(decl.body, original);
    }

    /// A recursive function whose default branch is a `TailCall` is reported
    /// as transformed with at least one eliminated call.
    #[test]
    pub(super) fn test_recursive_tailcall_counted() {
        let n_id = LcnfVarId(1);
        let self_call = LcnfExpr::TailCall(LcnfArg::Var(n_id), vec![LcnfArg::Lit(LcnfLit::Nat(0))]);
        let body = nat_zero_case(n_id, LcnfLit::Nat(0), self_call);
        let mut decl = mk_recursive_decl("countdown", vec![nat_param(1, "n")], body);
        let mut pass = TailRecOpt::new();
        let (report, _) = pass.run(&mut decl);
        assert!(
            report.functions_transformed >= 1,
            "Recursive function with TailCall should be counted as transformed"
        );
        assert!(report.calls_eliminated >= 1);
    }

    /// With `introduce_accum` enabled, a single-`Nat`-parameter function with a
    /// let-bound application yields an `_acc` helper with two parameters.
    #[test]
    pub(super) fn test_accumulator_introduced() {
        let n_id = LcnfVarId(1);
        let rec_call_id = LcnfVarId(2);
        let body = nat_zero_case(
            n_id,
            LcnfLit::Nat(0),
            let_bound_app("sum_acc", rec_call_id, n_id),
        );
        let mut decl = mk_recursive_decl("sum", vec![nat_param(1, "n")], body);
        let config = TailRecConfig {
            transform_linear: true,
            introduce_accum: true,
        };
        let mut pass = TailRecOpt::with_config(config);
        let (_report, extras) = pass.run(&mut decl);
        assert!(
            !extras.is_empty(),
            "Accumulator helper should be synthesized for non-tail-recursive single-Nat-param fn"
        );
        let helper = &extras[0];
        assert!(
            helper.name.ends_with("_acc"),
            "Helper name should have _acc suffix"
        );
        assert_eq!(
            helper.params.len(),
            2,
            "Helper should have original param + accumulator"
        );
        assert!(helper.is_recursive);
    }

    /// With `introduce_accum` disabled no helper may be produced, even for a
    /// function that would otherwise qualify.
    #[test]
    pub(super) fn test_no_accum_when_disabled() {
        let n_id = LcnfVarId(1);
        let rec_call_id = LcnfVarId(2);
        let body = nat_zero_case(
            n_id,
            LcnfLit::Nat(0),
            let_bound_app("product_acc", rec_call_id, n_id),
        );
        let mut decl = mk_recursive_decl("product", vec![nat_param(1, "n")], body);
        let config = TailRecConfig {
            transform_linear: true,
            introduce_accum: false,
        };
        let mut pass = TailRecOpt::with_config(config);
        let (_report, extras) = pass.run(&mut decl);
        assert!(
            extras.is_empty(),
            "introduce_accum=false must not synthesize helper"
        );
    }

    /// Both `is_even` and `is_even_helper` must show up in the detected groups.
    #[test]
    pub(super) fn test_mutual_tail_rec_detection() {
        let even = mk_recursive_decl(
            "is_even",
            vec![nat_param(1, "n")],
            LcnfExpr::TailCall(LcnfArg::Var(LcnfVarId(1)), vec![]),
        );
        let even_helper = mk_recursive_decl(
            "is_even_helper",
            vec![nat_param(2, "n")],
            LcnfExpr::TailCall(LcnfArg::Var(LcnfVarId(2)), vec![]),
        );
        let decls = vec![even, even_helper];
        let all_names: Vec<String> = detect_mutual_tail_recursion(&decls)
            .into_iter()
            .flatten()
            .collect();
        assert!(all_names.contains(&"is_even".to_string()));
        assert!(all_names.contains(&"is_even_helper".to_string()));
    }

    /// Running over a module with one recursive and one non-recursive function
    /// reports at least one transformed function.
    #[test]
    pub(super) fn test_run_module() {
        let self_call = LcnfExpr::TailCall(
            LcnfArg::Var(LcnfVarId(1)),
            vec![LcnfArg::Lit(LcnfLit::Nat(0))],
        );
        let body_rec = nat_zero_case(LcnfVarId(1), LcnfLit::Nat(1), self_call);
        let body_non_rec = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let mut decls = vec![
            mk_recursive_decl("fib", vec![nat_param(1, "n")], body_rec),
            mk_non_recursive_decl(body_non_rec),
        ];
        let mut pass = TailRecOpt::new();
        let report = pass.run_module(&mut decls);
        assert!(
            report.functions_transformed >= 1,
            "At least one recursive function should be transformed"
        );
    }

    /// A `Let` that does not call the function is returned unchanged and the
    /// conversion counter stays at zero.
    #[test]
    pub(super) fn test_rewrite_preserves_let_structure() {
        let fn_var = LcnfVarId(0);
        let tmp = LcnfVarId(10);
        let body = LcnfExpr::Let {
            id: tmp,
            name: "tmp".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(5)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(tmp))),
        };
        let mut conversions = 0usize;
        let rewritten = rewrite_tail_calls(body.clone(), &fn_var, &mut conversions);
        assert_eq!(rewritten, body, "Non-self-calling Let should be unchanged");
        assert_eq!(conversions, 0);
    }

    /// Despite its name, this exercises `TailRecOpt::count_tailcalls`: a single
    /// `TailCall` expression counts as one.
    #[test]
    pub(super) fn test_has_tail_call_to_detects_tailcall() {
        let call = LcnfExpr::TailCall(
            LcnfArg::Var(LcnfVarId(99)),
            vec![LcnfArg::Lit(LcnfLit::Nat(0))],
        );
        let pass = TailRecOpt::new();
        assert_eq!(pass.count_tailcalls(&call), 1);
    }
}
/// Tests for the `TR*` pass-infrastructure helper types (config, stats,
/// registry, cache, worklist, dominator tree, liveness, constant folding and
/// dependency graph).
#[cfg(test)]
// The generated module name is not snake_case; it is kept as-is so existing
// test paths/filters stay valid, and the lint is silenced explicitly.
#[allow(non_snake_case)]
mod TR_infra_tests {
    use super::*;

    /// A freshly created config is enabled and reports its phase.
    #[test]
    pub(super) fn test_pass_config() {
        let config = TRPassConfig::new("test_pass", TRPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    /// Two recorded runs are aggregated into run count, average and summary.
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = TRPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }

    /// The registry counts all passes, tracks which are enabled, and stores
    /// per-pass statistics.
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = TRPassRegistry::new();
        reg.register(TRPassConfig::new("pass_a", TRPassPhase::Analysis));
        reg.register(TRPassConfig::new("pass_b", TRPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    /// One hit and one miss give a 0.5 hit rate; invalidation keeps the entry
    /// but marks it invalid.
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = TRAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }

    /// Duplicate pushes are rejected and popping removes membership.
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = TRWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    /// Dominance follows the immediate-dominator chain and is reflexive.
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = TRDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }

    /// Definitions and uses are recorded per block.
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = TRLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    /// Arithmetic folds return `Some`, division by zero returns `None`, and
    /// bitwise folds return plain values.
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(TRConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(TRConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(TRConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            TRConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(TRConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    /// An acyclic graph sorts topologically with every edge pointing forward.
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = TRDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Map each node to its position so edge direction can be checked.
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}
/// Tests for the `TRExt*` pass-infrastructure helper types.
#[cfg(test)]
mod trext_pass_tests {
    use super::*;

    /// Phases are ordered Early < Middle < Late < Finalize.
    #[test]
    pub(super) fn test_trext_phase_order() {
        assert_eq!(TRExtPassPhase::Early.order(), 0);
        assert_eq!(TRExtPassPhase::Middle.order(), 1);
        assert_eq!(TRExtPassPhase::Late.order(), 2);
        assert_eq!(TRExtPassPhase::Finalize.order(), 3);
        assert!(TRExtPassPhase::Early.is_early());
        assert!(!TRExtPassPhase::Early.is_late());
    }

    /// Builder methods set the expected fields; `disabled` clears `enabled`.
    #[test]
    pub(super) fn test_trext_config_builder() {
        let config = TRExtPassConfig::new("p")
            .with_phase(TRExtPassPhase::Late)
            .with_max_iter(50)
            .with_debug(1);
        assert_eq!(config.name, "p");
        assert_eq!(config.max_iterations, 50);
        assert!(config.is_debug_enabled());
        assert!(config.enabled);
        let switched_off = config.disabled();
        assert!(!switched_off.enabled);
    }

    /// Visit/modify/iterate counters accumulate; efficiency is modified/visited.
    #[test]
    pub(super) fn test_trext_stats() {
        let mut stats = TRExtPassStats::new();
        stats.visit();
        stats.visit();
        stats.modify();
        stats.iterate();
        assert_eq!(stats.nodes_visited, 2);
        assert_eq!(stats.nodes_modified, 1);
        assert!(stats.changed);
        assert_eq!(stats.iterations, 1);
        let efficiency = stats.efficiency();
        assert!((efficiency - 0.5).abs() < 1e-9);
    }

    /// The registry tracks total, enabled and per-phase passes.
    #[test]
    pub(super) fn test_trext_registry() {
        let mut registry = TRExtPassRegistry::new();
        registry.register(TRExtPassConfig::new("a").with_phase(TRExtPassPhase::Early));
        registry.register(TRExtPassConfig::new("b").disabled());
        assert_eq!(registry.len(), 2);
        assert_eq!(registry.enabled_passes().len(), 1);
        assert_eq!(registry.passes_in_phase(&TRExtPassPhase::Early).len(), 1);
    }

    /// A miss followed by a put and a hit yields a positive hit rate.
    #[test]
    pub(super) fn test_trext_cache() {
        let mut cache = TRExtCache::new(4);
        assert!(cache.get(99).is_none());
        cache.put(99, vec![1, 2, 3]);
        let stored = cache.get(99).expect("v should be present in map");
        assert_eq!(stored, &[1u8, 2, 3]);
        assert!(cache.hit_rate() > 0.0);
        assert_eq!(cache.live_count(), 1);
    }

    /// Duplicate pushes are ignored; a popped item is no longer contained.
    #[test]
    pub(super) fn test_trext_worklist() {
        let mut worklist = TRExtWorklist::new(10);
        worklist.push(5);
        worklist.push(3);
        worklist.push(5);
        assert_eq!(worklist.len(), 2);
        assert!(worklist.contains(5));
        let popped = worklist.pop().expect("first should be available to pop");
        assert!(!worklist.contains(popped));
    }

    /// Dominance follows the idom chain; depth counts edges to the root.
    #[test]
    pub(super) fn test_trext_dom_tree() {
        let mut dom = TRExtDomTree::new(5);
        dom.set_idom(1, 0);
        dom.set_idom(2, 0);
        dom.set_idom(3, 1);
        dom.set_idom(4, 1);
        assert!(dom.dominates(0, 3));
        assert!(dom.dominates(1, 4));
        assert!(!dom.dominates(2, 3));
        assert_eq!(dom.depth_of(3), 2);
    }

    /// Defs and uses are tracked independently per block.
    #[test]
    pub(super) fn test_trext_liveness() {
        let mut liveness = TRExtLiveness::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.var_is_def_in_block(0, 1));
        assert!(liveness.var_is_used_in_block(1, 1));
        assert!(!liveness.var_is_def_in_block(1, 1));
    }

    /// Successful folds and failures are counted separately.
    #[test]
    pub(super) fn test_trext_const_folder() {
        let mut folder = TRExtConstFolder::new();
        assert_eq!(folder.add_i64(3, 4), Some(7));
        assert_eq!(folder.div_i64(10, 0), None);
        assert_eq!(folder.mul_i64(6, 7), Some(42));
        assert_eq!(folder.and_i64(0b1100, 0b1010), 0b1000);
        assert_eq!(folder.fold_count(), 3);
        assert_eq!(folder.failure_count(), 1);
    }

    /// A simple chain is acyclic, sorts in order, and has one SCC per node.
    #[test]
    pub(super) fn test_trext_dep_graph() {
        let mut graph = TRExtDepGraph::new(4);
        graph.add_edge(0, 1);
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        assert!(!graph.has_cycle());
        assert_eq!(graph.topo_sort(), Some(vec![0, 1, 2, 3]));
        assert_eq!(graph.reachable(0).len(), 4);
        let components = graph.scc();
        assert_eq!(components.len(), 4);
    }
}
/// Tests for the `TRX2*` pass-infrastructure helper types.
#[cfg(test)]
mod trx2_pass_tests {
    use super::*;

    /// Phases are ordered Early < Middle < Late < Finalize.
    #[test]
    pub(super) fn test_trx2_phase_order() {
        assert_eq!(TRX2PassPhase::Early.order(), 0);
        assert_eq!(TRX2PassPhase::Middle.order(), 1);
        assert_eq!(TRX2PassPhase::Late.order(), 2);
        assert_eq!(TRX2PassPhase::Finalize.order(), 3);
        assert!(TRX2PassPhase::Early.is_early());
        assert!(!TRX2PassPhase::Early.is_late());
    }

    /// Builder methods set the expected fields; `disabled` clears `enabled`.
    #[test]
    pub(super) fn test_trx2_config_builder() {
        let config = TRX2PassConfig::new("p")
            .with_phase(TRX2PassPhase::Late)
            .with_max_iter(50)
            .with_debug(1);
        assert_eq!(config.name, "p");
        assert_eq!(config.max_iterations, 50);
        assert!(config.is_debug_enabled());
        assert!(config.enabled);
        let switched_off = config.disabled();
        assert!(!switched_off.enabled);
    }

    /// Visit/modify/iterate counters accumulate; efficiency is modified/visited.
    #[test]
    pub(super) fn test_trx2_stats() {
        let mut stats = TRX2PassStats::new();
        stats.visit();
        stats.visit();
        stats.modify();
        stats.iterate();
        assert_eq!(stats.nodes_visited, 2);
        assert_eq!(stats.nodes_modified, 1);
        assert!(stats.changed);
        assert_eq!(stats.iterations, 1);
        let efficiency = stats.efficiency();
        assert!((efficiency - 0.5).abs() < 1e-9);
    }

    /// The registry tracks total, enabled and per-phase passes.
    #[test]
    pub(super) fn test_trx2_registry() {
        let mut registry = TRX2PassRegistry::new();
        registry.register(TRX2PassConfig::new("a").with_phase(TRX2PassPhase::Early));
        registry.register(TRX2PassConfig::new("b").disabled());
        assert_eq!(registry.len(), 2);
        assert_eq!(registry.enabled_passes().len(), 1);
        assert_eq!(registry.passes_in_phase(&TRX2PassPhase::Early).len(), 1);
    }

    /// A miss followed by a put and a hit yields a positive hit rate.
    #[test]
    pub(super) fn test_trx2_cache() {
        let mut cache = TRX2Cache::new(4);
        assert!(cache.get(99).is_none());
        cache.put(99, vec![1, 2, 3]);
        let stored = cache.get(99).expect("v should be present in map");
        assert_eq!(stored, &[1u8, 2, 3]);
        assert!(cache.hit_rate() > 0.0);
        assert_eq!(cache.live_count(), 1);
    }

    /// Duplicate pushes are ignored; a popped item is no longer contained.
    #[test]
    pub(super) fn test_trx2_worklist() {
        let mut worklist = TRX2Worklist::new(10);
        worklist.push(5);
        worklist.push(3);
        worklist.push(5);
        assert_eq!(worklist.len(), 2);
        assert!(worklist.contains(5));
        let popped = worklist.pop().expect("first should be available to pop");
        assert!(!worklist.contains(popped));
    }

    /// Dominance follows the idom chain; depth counts edges to the root.
    #[test]
    pub(super) fn test_trx2_dom_tree() {
        let mut dom = TRX2DomTree::new(5);
        dom.set_idom(1, 0);
        dom.set_idom(2, 0);
        dom.set_idom(3, 1);
        dom.set_idom(4, 1);
        assert!(dom.dominates(0, 3));
        assert!(dom.dominates(1, 4));
        assert!(!dom.dominates(2, 3));
        assert_eq!(dom.depth_of(3), 2);
    }

    /// Defs and uses are tracked independently per block.
    #[test]
    pub(super) fn test_trx2_liveness() {
        let mut liveness = TRX2Liveness::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.var_is_def_in_block(0, 1));
        assert!(liveness.var_is_used_in_block(1, 1));
        assert!(!liveness.var_is_def_in_block(1, 1));
    }

    /// Successful folds and failures are counted separately.
    #[test]
    pub(super) fn test_trx2_const_folder() {
        let mut folder = TRX2ConstFolder::new();
        assert_eq!(folder.add_i64(3, 4), Some(7));
        assert_eq!(folder.div_i64(10, 0), None);
        assert_eq!(folder.mul_i64(6, 7), Some(42));
        assert_eq!(folder.and_i64(0b1100, 0b1010), 0b1000);
        assert_eq!(folder.fold_count(), 3);
        assert_eq!(folder.failure_count(), 1);
    }

    /// A simple chain is acyclic, sorts in order, and has one SCC per node.
    #[test]
    pub(super) fn test_trx2_dep_graph() {
        let mut graph = TRX2DepGraph::new(4);
        graph.add_edge(0, 1);
        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        assert!(!graph.has_cycle());
        assert_eq!(graph.topo_sort(), Some(vec![0, 1, 2, 3]));
        assert_eq!(graph.reachable(0).len(), 4);
        let components = graph.scc();
        assert_eq!(components.len(), 4);
    }
}