Skip to main content

oxilean_codegen/opt_join/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::*;
6use std::collections::{HashMap, HashSet};
7
8use super::types::{
9    CallSiteInfo, JoinPointConfig, JoinPointOptimizer, JoinPointStats, OJAnalysisCache,
10    OJConstantFoldingHelper, OJDepGraph, OJDominatorTree, OJLivenessInfo, OJPassConfig,
11    OJPassPhase, OJPassRegistry, OJPassStats, OJWorklist, OJoinConfig, OJoinDiagCollector,
12    OJoinDiagMsg, OJoinEmitStats, OJoinEventLog, OJoinFeatures, OJoinIdGen, OJoinIncrKey,
13    OJoinNameScope, OJoinPassTiming, OJoinProfiler, OJoinSourceBuffer, OJoinVersion, TailUse,
14};
15
16/// Analyze tail position usage of variables in an expression
17pub(super) fn analyze_tail_uses(expr: &LcnfExpr, tail: bool) -> HashMap<LcnfVarId, TailUse> {
18    let mut uses: HashMap<LcnfVarId, TailUse> = HashMap::new();
19    match expr {
20        LcnfExpr::Let {
21            value, body, id, ..
22        } => {
23            collect_value_uses(value, &mut uses, false);
24            let body_uses = analyze_tail_uses(body, tail);
25            for (var, use_kind) in body_uses {
26                if var != *id {
27                    let current = uses.entry(var).or_insert(TailUse::Unused);
28                    *current = current.merge(&use_kind);
29                }
30            }
31        }
32        LcnfExpr::Case {
33            scrutinee,
34            alts,
35            default,
36            ..
37        } => {
38            let current = uses.entry(*scrutinee).or_insert(TailUse::Unused);
39            *current = current.merge(&TailUse::NonTail);
40            for alt in alts {
41                let alt_uses = analyze_tail_uses(&alt.body, tail);
42                for (var, use_kind) in alt_uses {
43                    let current = uses.entry(var).or_insert(TailUse::Unused);
44                    *current = current.merge(&use_kind);
45                }
46            }
47            if let Some(def) = default {
48                let def_uses = analyze_tail_uses(def, tail);
49                for (var, use_kind) in def_uses {
50                    let current = uses.entry(var).or_insert(TailUse::Unused);
51                    *current = current.merge(&use_kind);
52                }
53            }
54        }
55        LcnfExpr::Return(arg) => {
56            if let LcnfArg::Var(v) = arg {
57                let use_kind = if tail {
58                    TailUse::TailOnly
59                } else {
60                    TailUse::NonTail
61                };
62                let current = uses.entry(*v).or_insert(TailUse::Unused);
63                *current = current.merge(&use_kind);
64            }
65        }
66        LcnfExpr::TailCall(func, args) => {
67            if let LcnfArg::Var(v) = func {
68                let use_kind = if tail {
69                    TailUse::TailOnly
70                } else {
71                    TailUse::NonTail
72                };
73                let current = uses.entry(*v).or_insert(TailUse::Unused);
74                *current = current.merge(&use_kind);
75            }
76            for arg in args {
77                if let LcnfArg::Var(v) = arg {
78                    let current = uses.entry(*v).or_insert(TailUse::Unused);
79                    *current = current.merge(&TailUse::NonTail);
80                }
81            }
82        }
83        LcnfExpr::Unreachable => {}
84    }
85    uses
86}
87/// Collect variable uses from a let-value
88pub(super) fn collect_value_uses(
89    value: &LcnfLetValue,
90    uses: &mut HashMap<LcnfVarId, TailUse>,
91    _tail: bool,
92) {
93    let vars = extract_value_vars(value);
94    for v in vars {
95        let current = uses.entry(v).or_insert(TailUse::Unused);
96        *current = current.merge(&TailUse::NonTail);
97    }
98}
99/// Extract all variable references from a let-value
100pub(super) fn extract_value_vars(value: &LcnfLetValue) -> Vec<LcnfVarId> {
101    let mut vars = Vec::new();
102    match value {
103        LcnfLetValue::App(func, args) => {
104            if let LcnfArg::Var(v) = func {
105                vars.push(*v);
106            }
107            for a in args {
108                if let LcnfArg::Var(v) = a {
109                    vars.push(*v);
110                }
111            }
112        }
113        LcnfLetValue::Proj(_, _, v) => {
114            vars.push(*v);
115        }
116        LcnfLetValue::Ctor(_, _, args) => {
117            for a in args {
118                if let LcnfArg::Var(v) = a {
119                    vars.push(*v);
120                }
121            }
122        }
123        LcnfLetValue::FVar(v) => {
124            vars.push(*v);
125        }
126        LcnfLetValue::Lit(_)
127        | LcnfLetValue::Erased
128        | LcnfLetValue::Reset(_)
129        | LcnfLetValue::Reuse(_, _, _, _) => {}
130    }
131    vars
132}
133/// Analyze all call sites in a function body
134pub(super) fn analyze_call_sites(
135    expr: &LcnfExpr,
136    caller: &str,
137    in_tail: bool,
138) -> Vec<CallSiteInfo> {
139    let mut sites = Vec::new();
140    match expr {
141        LcnfExpr::Let { value, body, .. } => {
142            if let LcnfLetValue::App(func, args) = value {
143                let callee_var = if let LcnfArg::Var(v) = func {
144                    Some(*v)
145                } else {
146                    None
147                };
148                sites.push(CallSiteInfo {
149                    caller: caller.to_string(),
150                    is_tail: false,
151                    arg_count: args.len(),
152                    callee_var,
153                });
154            }
155            sites.extend(analyze_call_sites(body, caller, in_tail));
156        }
157        LcnfExpr::Case { alts, default, .. } => {
158            for alt in alts {
159                sites.extend(analyze_call_sites(&alt.body, caller, in_tail));
160            }
161            if let Some(def) = default {
162                sites.extend(analyze_call_sites(def, caller, in_tail));
163            }
164        }
165        LcnfExpr::TailCall(func, args) => {
166            let callee_var = if let LcnfArg::Var(v) = func {
167                Some(*v)
168            } else {
169                None
170            };
171            sites.push(CallSiteInfo {
172                caller: caller.to_string(),
173                is_tail: in_tail,
174                arg_count: args.len(),
175                callee_var,
176            });
177        }
178        LcnfExpr::Return(_) | LcnfExpr::Unreachable => {}
179    }
180    sites
181}
182/// Collect all variable IDs that are used (referenced) in an expression
183pub(super) fn collect_used_vars(expr: &LcnfExpr) -> HashSet<LcnfVarId> {
184    let mut used = HashSet::new();
185    collect_used_vars_inner(expr, &mut used);
186    used
187}
188pub(super) fn collect_used_vars_inner(expr: &LcnfExpr, used: &mut HashSet<LcnfVarId>) {
189    match expr {
190        LcnfExpr::Let { value, body, .. } => {
191            collect_value_used_vars(value, used);
192            collect_used_vars_inner(body, used);
193        }
194        LcnfExpr::Case {
195            scrutinee,
196            alts,
197            default,
198            ..
199        } => {
200            used.insert(*scrutinee);
201            for alt in alts {
202                collect_used_vars_inner(&alt.body, used);
203            }
204            if let Some(def) = default {
205                collect_used_vars_inner(def, used);
206            }
207        }
208        LcnfExpr::Return(arg) => {
209            if let LcnfArg::Var(v) = arg {
210                used.insert(*v);
211            }
212        }
213        LcnfExpr::TailCall(func, args) => {
214            if let LcnfArg::Var(v) = func {
215                used.insert(*v);
216            }
217            for a in args {
218                if let LcnfArg::Var(v) = a {
219                    used.insert(*v);
220                }
221            }
222        }
223        LcnfExpr::Unreachable => {}
224    }
225}
226pub(super) fn collect_value_used_vars(value: &LcnfLetValue, used: &mut HashSet<LcnfVarId>) {
227    match value {
228        LcnfLetValue::App(func, args) => {
229            if let LcnfArg::Var(v) = func {
230                used.insert(*v);
231            }
232            for a in args {
233                if let LcnfArg::Var(v) = a {
234                    used.insert(*v);
235                }
236            }
237        }
238        LcnfLetValue::Proj(_, _, v) => {
239            used.insert(*v);
240        }
241        LcnfLetValue::Ctor(_, _, args) => {
242            for a in args {
243                if let LcnfArg::Var(v) = a {
244                    used.insert(*v);
245                }
246            }
247        }
248        LcnfLetValue::FVar(v) => {
249            used.insert(*v);
250        }
251        LcnfLetValue::Lit(_)
252        | LcnfLetValue::Erased
253        | LcnfLetValue::Reset(_)
254        | LcnfLetValue::Reuse(_, _, _, _) => {}
255    }
256}
257/// Check whether a let-value is pure (no side effects)
258pub(super) fn is_pure_value(value: &LcnfLetValue) -> bool {
259    match value {
260        LcnfLetValue::Lit(_)
261        | LcnfLetValue::Erased
262        | LcnfLetValue::FVar(_)
263        | LcnfLetValue::Proj(_, _, _)
264        | LcnfLetValue::Ctor(_, _, _) => true,
265        LcnfLetValue::App(_, _) | LcnfLetValue::Reset(_) | LcnfLetValue::Reuse(_, _, _, _) => false,
266    }
267}
268/// Check whether an expression references a given variable
269pub(super) fn expr_uses_var(expr: &LcnfExpr, var: LcnfVarId) -> bool {
270    match expr {
271        LcnfExpr::Let {
272            id, value, body, ..
273        } => value_uses_var(value, var) || (*id != var && expr_uses_var(body, var)),
274        LcnfExpr::Case {
275            scrutinee,
276            alts,
277            default,
278            ..
279        } => {
280            *scrutinee == var
281                || alts.iter().any(|alt| expr_uses_var(&alt.body, var))
282                || default.as_ref().is_some_and(|d| expr_uses_var(d, var))
283        }
284        LcnfExpr::Return(arg) => matches!(arg, LcnfArg::Var(v) if * v == var),
285        LcnfExpr::TailCall(func, args) => {
286            matches!(func, LcnfArg::Var(v) if * v == var)
287                || args
288                    .iter()
289                    .any(|a| matches!(a, LcnfArg::Var(v) if * v == var))
290        }
291        LcnfExpr::Unreachable => false,
292    }
293}
294/// Check whether a let-value references a given variable
295pub(super) fn value_uses_var(value: &LcnfLetValue, var: LcnfVarId) -> bool {
296    match value {
297        LcnfLetValue::App(func, args) => {
298            matches!(func, LcnfArg::Var(v) if * v == var)
299                || args
300                    .iter()
301                    .any(|a| matches!(a, LcnfArg::Var(v) if * v == var))
302        }
303        LcnfLetValue::Proj(_, _, v) => *v == var,
304        LcnfLetValue::Ctor(_, _, args) => args
305            .iter()
306            .any(|a| matches!(a, LcnfArg::Var(v) if * v == var)),
307        LcnfLetValue::FVar(v) => *v == var,
308        LcnfLetValue::Lit(_)
309        | LcnfLetValue::Erased
310        | LcnfLetValue::Reset(_)
311        | LcnfLetValue::Reuse(_, _, _, _) => false,
312    }
313}
314/// Count the number of LCNF instructions in an expression
315pub(super) fn count_instructions(expr: &LcnfExpr) -> usize {
316    match expr {
317        LcnfExpr::Let { body, .. } => 1 + count_instructions(body),
318        LcnfExpr::Case { alts, default, .. } => {
319            let alts_size: usize = alts.iter().map(|a| count_instructions(&a.body)).sum();
320            let def_size = default.as_ref().map(|d| count_instructions(d)).unwrap_or(0);
321            1 + alts_size + def_size
322        }
323        LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable => 1,
324    }
325}
326/// Compute the call graph from a set of declarations
327pub(super) fn compute_call_graph(decls: &[LcnfFunDecl]) -> HashMap<String, HashSet<String>> {
328    let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
329    let decl_names: HashSet<&str> = decls.iter().map(|d| d.name.as_str()).collect();
330    for decl in decls {
331        let mut callees = HashSet::new();
332        collect_callees(&decl.body, &decl_names, &mut callees);
333        graph.insert(decl.name.clone(), callees);
334    }
335    graph
336}
337/// Collect all callee function names from an expression.
338///
339/// Requires a var-to-name map to resolve function references; starts empty
340/// and propagates name info through FVar copies encountered in let-bindings.
341pub(super) fn collect_callees(
342    expr: &LcnfExpr,
343    known_fns: &HashSet<&str>,
344    callees: &mut HashSet<String>,
345) {
346    let ctx: HashMap<LcnfVarId, String> = HashMap::new();
347    collect_callees_ctx(expr, known_fns, callees, &ctx);
348}
349/// Inner helper that carries a var→name context for name propagation.
350pub(super) fn collect_callees_ctx(
351    expr: &LcnfExpr,
352    known_fns: &HashSet<&str>,
353    callees: &mut HashSet<String>,
354    ctx: &HashMap<LcnfVarId, String>,
355) {
356    match expr {
357        LcnfExpr::Let {
358            id, value, body, ..
359        } => {
360            if let LcnfLetValue::App(LcnfArg::Var(v), _) = value {
361                if let Some(name) = ctx.get(v) {
362                    if known_fns.contains(name.as_str()) {
363                        callees.insert(name.clone());
364                    }
365                }
366            }
367            if let LcnfLetValue::FVar(v) = value {
368                if let Some(name) = ctx.get(v).cloned() {
369                    let mut extended = ctx.clone();
370                    extended.insert(*id, name);
371                    collect_callees_ctx(body, known_fns, callees, &extended);
372                    return;
373                }
374            }
375            collect_callees_ctx(body, known_fns, callees, ctx);
376        }
377        LcnfExpr::Case { alts, default, .. } => {
378            for alt in alts {
379                collect_callees_ctx(&alt.body, known_fns, callees, ctx);
380            }
381            if let Some(def) = default {
382                collect_callees_ctx(def, known_fns, callees, ctx);
383            }
384        }
385        LcnfExpr::TailCall(LcnfArg::Var(v), _) => {
386            if let Some(name) = ctx.get(v) {
387                if known_fns.contains(name.as_str()) {
388                    callees.insert(name.clone());
389                }
390            }
391        }
392        _ => {}
393    }
394}
395/// Find self-recursive tail calls in a function
396pub(super) fn find_self_recursive_tail_calls(
397    expr: &LcnfExpr,
398    fn_name: &str,
399    var_to_name: &HashMap<LcnfVarId, String>,
400) -> Vec<LcnfVarId> {
401    let mut self_calls = Vec::new();
402    match expr {
403        LcnfExpr::Let { body, .. } => {
404            self_calls.extend(find_self_recursive_tail_calls(body, fn_name, var_to_name));
405        }
406        LcnfExpr::Case { alts, default, .. } => {
407            for alt in alts {
408                self_calls.extend(find_self_recursive_tail_calls(
409                    &alt.body,
410                    fn_name,
411                    var_to_name,
412                ));
413            }
414            if let Some(def) = default {
415                self_calls.extend(find_self_recursive_tail_calls(def, fn_name, var_to_name));
416            }
417        }
418        LcnfExpr::TailCall(LcnfArg::Var(v), _) => {
419            if let Some(name) = var_to_name.get(v) {
420                if name == fn_name {
421                    self_calls.push(*v);
422                }
423            }
424        }
425        _ => {}
426    }
427    self_calls
428}
429/// Determine whether a function is a join point candidate
430/// (all call sites are in tail position)
431pub(super) fn is_join_point_candidate(callee_id: LcnfVarId, call_sites: &[CallSiteInfo]) -> bool {
432    let relevant: Vec<&CallSiteInfo> = call_sites
433        .iter()
434        .filter(|cs| cs.callee_var == Some(callee_id))
435        .collect();
436    if relevant.is_empty() {
437        return false;
438    }
439    relevant.iter().all(|cs| cs.is_tail)
440}
441/// Main entry point: optimize join points in a module
442pub fn optimize_join_points(module: &mut LcnfModule, config: &JoinPointConfig) {
443    let mut optimizer = JoinPointOptimizer::new(config.clone());
444    for decl in &mut module.fun_decls {
445        optimizer.optimize_decl(decl);
446    }
447    if config.eliminate_dead_joins {
448        eliminate_dead_functions(module);
449    }
450}
451/// Remove function declarations that are never referenced
452pub(super) fn eliminate_dead_functions(module: &mut LcnfModule) {
453    if module.fun_decls.len() <= 1 {
454        return;
455    }
456    let call_graph = compute_call_graph(&module.fun_decls);
457    let mut reachable: HashSet<String> = HashSet::new();
458    let mut worklist: Vec<String> = module
459        .fun_decls
460        .iter()
461        .filter(|d| !d.is_lifted)
462        .map(|d| d.name.clone())
463        .collect();
464    while let Some(fn_name) = worklist.pop() {
465        if reachable.insert(fn_name.clone()) {
466            if let Some(callees) = call_graph.get(&fn_name) {
467                for callee in callees {
468                    if !reachable.contains(callee) {
469                        worklist.push(callee.clone());
470                    }
471                }
472            }
473        }
474    }
475    module.fun_decls.retain(|d| reachable.contains(&d.name));
476}
477/// Create a join point from a let-binding that is only used in tail position
478pub(super) fn create_join_point(
479    join_id: LcnfVarId,
480    params: Vec<LcnfParam>,
481    body: LcnfExpr,
482    ret_type: LcnfType,
483) -> LcnfFunDecl {
484    let cost = count_instructions(&body);
485    LcnfFunDecl {
486        name: format!("_join_{}", join_id.0),
487        original_name: None,
488        params,
489        ret_type,
490        body,
491        is_recursive: false,
492        is_lifted: true,
493        inline_cost: cost,
494    }
495}
496/// Convert a tail-recursive function to use a loop with join points
497pub(super) fn convert_to_loop(decl: &mut LcnfFunDecl) -> bool {
498    if !decl.is_recursive {
499        return false;
500    }
501    let var_to_name: HashMap<LcnfVarId, String> = HashMap::new();
502    let self_calls = find_self_recursive_tail_calls(&decl.body, &decl.name, &var_to_name);
503    !self_calls.is_empty()
504}
505#[cfg(test)]
506mod tests {
507    use super::*;
508    pub(super) fn make_var(n: u64) -> LcnfVarId {
509        LcnfVarId(n)
510    }
511    pub(super) fn make_param(n: u64, name: &str) -> LcnfParam {
512        LcnfParam {
513            id: LcnfVarId(n),
514            name: name.to_string(),
515            ty: LcnfType::Nat,
516            erased: false,
517            borrowed: false,
518        }
519    }
520    pub(super) fn make_simple_let(id: u64, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
521        LcnfExpr::Let {
522            id: LcnfVarId(id),
523            name: format!("x{}", id),
524            ty: LcnfType::Nat,
525            value,
526            body: Box::new(body),
527        }
528    }
529    pub(super) fn make_simple_decl(name: &str, body: LcnfExpr) -> LcnfFunDecl {
530        LcnfFunDecl {
531            name: name.to_string(),
532            original_name: None,
533            params: vec![make_param(0, "arg0")],
534            ret_type: LcnfType::Nat,
535            body,
536            is_recursive: false,
537            is_lifted: false,
538            inline_cost: 1,
539        }
540    }
541    #[test]
542    pub(super) fn test_config_default() {
543        let config = JoinPointConfig::default();
544        assert_eq!(config.max_join_size, 10);
545        assert!(config.inline_small_joins);
546        assert!(config.detect_tail_calls);
547        assert!(config.enable_contification);
548    }
549    #[test]
550    pub(super) fn test_stats_default() {
551        let stats = JoinPointStats::default();
552        assert_eq!(stats.total_changes(), 0);
553    }
554    #[test]
555    pub(super) fn test_tail_use_merge() {
556        assert_eq!(TailUse::Unused.merge(&TailUse::TailOnly), TailUse::TailOnly);
557        assert_eq!(
558            TailUse::TailOnly.merge(&TailUse::TailOnly),
559            TailUse::TailOnly
560        );
561        assert_eq!(TailUse::TailOnly.merge(&TailUse::NonTail), TailUse::Mixed);
562        assert_eq!(TailUse::NonTail.merge(&TailUse::NonTail), TailUse::NonTail);
563    }
564    #[test]
565    pub(super) fn test_is_pure_value() {
566        assert!(is_pure_value(&LcnfLetValue::Lit(LcnfLit::Nat(42))));
567        assert!(is_pure_value(&LcnfLetValue::Erased));
568        assert!(is_pure_value(&LcnfLetValue::FVar(make_var(0))));
569        assert!(is_pure_value(&LcnfLetValue::Proj(
570            "foo".into(),
571            0,
572            make_var(0)
573        )));
574        assert!(!is_pure_value(&LcnfLetValue::App(
575            LcnfArg::Var(make_var(0)),
576            vec![]
577        )));
578    }
579    #[test]
580    pub(super) fn test_count_instructions() {
581        let ret = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
582        assert_eq!(count_instructions(&ret), 1);
583        let let_expr = make_simple_let(
584            1,
585            LcnfLetValue::Lit(LcnfLit::Nat(42)),
586            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
587        );
588        assert_eq!(count_instructions(&let_expr), 2);
589    }
590    #[test]
591    pub(super) fn test_collect_used_vars() {
592        let expr = make_simple_let(
593            1,
594            LcnfLetValue::FVar(make_var(0)),
595            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
596        );
597        let used = collect_used_vars(&expr);
598        assert!(used.contains(&make_var(0)));
599        assert!(used.contains(&make_var(1)));
600    }
601    #[test]
602    pub(super) fn test_expr_uses_var() {
603        let expr = LcnfExpr::Return(LcnfArg::Var(make_var(5)));
604        assert!(expr_uses_var(&expr, make_var(5)));
605        assert!(!expr_uses_var(&expr, make_var(6)));
606    }
607    #[test]
608    pub(super) fn test_value_uses_var() {
609        let val = LcnfLetValue::App(LcnfArg::Var(make_var(1)), vec![LcnfArg::Var(make_var(2))]);
610        assert!(value_uses_var(&val, make_var(1)));
611        assert!(value_uses_var(&val, make_var(2)));
612        assert!(!value_uses_var(&val, make_var(3)));
613    }
614    #[test]
615    pub(super) fn test_extract_value_vars() {
616        let val = LcnfLetValue::App(
617            LcnfArg::Var(make_var(1)),
618            vec![LcnfArg::Var(make_var(2)), LcnfArg::Lit(LcnfLit::Nat(0))],
619        );
620        let vars = extract_value_vars(&val);
621        assert_eq!(vars.len(), 2);
622        assert!(vars.contains(&make_var(1)));
623        assert!(vars.contains(&make_var(2)));
624    }
625    #[test]
626    pub(super) fn test_detect_tail_calls() {
627        let mut expr = make_simple_let(
628            1,
629            LcnfLetValue::App(LcnfArg::Var(make_var(10)), vec![LcnfArg::Var(make_var(0))]),
630            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
631        );
632        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
633        optimizer.detect_tail_calls_in_expr(&mut expr, "test");
634        assert!(matches!(expr, LcnfExpr::TailCall(_, _)));
635        assert_eq!(optimizer.stats.tail_calls_detected, 1);
636    }
637    #[test]
638    pub(super) fn test_dead_join_elimination() {
639        let mut expr = make_simple_let(
640            1,
641            LcnfLetValue::Lit(LcnfLit::Nat(42)),
642            make_simple_let(
643                2,
644                LcnfLetValue::Lit(LcnfLit::Nat(100)),
645                LcnfExpr::Return(LcnfArg::Var(make_var(2))),
646            ),
647        );
648        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
649        optimizer.eliminate_dead_joins(&mut expr);
650        assert!(matches!(expr, LcnfExpr::Let { id, .. } if id == make_var(2)));
651    }
652    #[test]
653    pub(super) fn test_optimize_join_points_full() {
654        let body = make_simple_let(
655            1,
656            LcnfLetValue::Lit(LcnfLit::Nat(42)),
657            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
658        );
659        let decl = make_simple_decl("test_fn", body);
660        let mut module = LcnfModule {
661            fun_decls: vec![decl],
662            extern_decls: vec![],
663            name: "test_mod".to_string(),
664            metadata: LcnfModuleMetadata::default(),
665        };
666        let config = JoinPointConfig::default();
667        optimize_join_points(&mut module, &config);
668        assert_eq!(module.fun_decls.len(), 1);
669    }
670    #[test]
671    pub(super) fn test_value_size() {
672        let optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
673        assert_eq!(optimizer.value_size(&LcnfLetValue::Lit(LcnfLit::Nat(0))), 1);
674        assert_eq!(optimizer.value_size(&LcnfLetValue::Erased), 1);
675        assert_eq!(
676            optimizer.value_size(&LcnfLetValue::App(
677                LcnfArg::Var(make_var(0)),
678                vec![LcnfArg::Var(make_var(1)), LcnfArg::Var(make_var(2))]
679            )),
680            3
681        );
682    }
683    #[test]
684    pub(super) fn test_analyze_tail_uses_return() {
685        let expr = LcnfExpr::Return(LcnfArg::Var(make_var(5)));
686        let uses = analyze_tail_uses(&expr, true);
687        assert_eq!(uses.get(&make_var(5)), Some(&TailUse::TailOnly));
688    }
689    #[test]
690    pub(super) fn test_analyze_tail_uses_non_tail() {
691        let expr = LcnfExpr::Return(LcnfArg::Var(make_var(5)));
692        let uses = analyze_tail_uses(&expr, false);
693        assert_eq!(uses.get(&make_var(5)), Some(&TailUse::NonTail));
694    }
695    #[test]
696    pub(super) fn test_call_site_analysis() {
697        let body = make_simple_let(
698            1,
699            LcnfLetValue::App(LcnfArg::Var(make_var(10)), vec![LcnfArg::Var(make_var(0))]),
700            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
701        );
702        let sites = analyze_call_sites(&body, "test_fn", true);
703        assert_eq!(sites.len(), 1);
704        assert!(!sites[0].is_tail);
705        assert_eq!(sites[0].arg_count, 1);
706    }
707    #[test]
708    pub(super) fn test_compute_call_graph() {
709        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
710        let decl1 = make_simple_decl("foo", body.clone());
711        let decl2 = make_simple_decl("bar", body);
712        let graph = compute_call_graph(&[decl1, decl2]);
713        assert!(graph.contains_key("foo"));
714        assert!(graph.contains_key("bar"));
715    }
716    #[test]
717    pub(super) fn test_create_join_point() {
718        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
719        let jp = create_join_point(make_var(100), vec![make_param(1, "p")], body, LcnfType::Nat);
720        assert_eq!(jp.name, "_join_100");
721        assert!(jp.is_lifted);
722    }
723    #[test]
724    pub(super) fn test_convert_to_loop_non_recursive() {
725        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
726        let mut decl = make_simple_decl("test", body);
727        decl.is_recursive = false;
728        assert!(!convert_to_loop(&mut decl));
729    }
730    #[test]
731    pub(super) fn test_join_point_candidate() {
732        let sites = vec![
733            CallSiteInfo {
734                caller: "f".to_string(),
735                is_tail: true,
736                arg_count: 1,
737                callee_var: Some(make_var(5)),
738            },
739            CallSiteInfo {
740                caller: "g".to_string(),
741                is_tail: true,
742                arg_count: 1,
743                callee_var: Some(make_var(5)),
744            },
745        ];
746        assert!(is_join_point_candidate(make_var(5), &sites));
747        let mixed_sites = vec![
748            CallSiteInfo {
749                caller: "f".to_string(),
750                is_tail: true,
751                arg_count: 1,
752                callee_var: Some(make_var(5)),
753            },
754            CallSiteInfo {
755                caller: "g".to_string(),
756                is_tail: false,
757                arg_count: 1,
758                callee_var: Some(make_var(5)),
759            },
760        ];
761        assert!(!is_join_point_candidate(make_var(5), &mixed_sites));
762    }
763    #[test]
764    pub(super) fn test_optimizer_fresh_id() {
765        let mut opt = JoinPointOptimizer::new(JoinPointConfig::default());
766        let id1 = opt.fresh_id();
767        let id2 = opt.fresh_id();
768        assert_ne!(id1, id2);
769    }
770    #[test]
771    pub(super) fn test_case_tail_call_detection() {
772        let mut expr = LcnfExpr::Case {
773            scrutinee: make_var(0),
774            scrutinee_ty: LcnfType::Nat,
775            alts: vec![
776                LcnfAlt {
777                    ctor_name: "True".to_string(),
778                    ctor_tag: 0,
779                    params: vec![],
780                    body: make_simple_let(
781                        5,
782                        LcnfLetValue::App(
783                            LcnfArg::Var(make_var(10)),
784                            vec![LcnfArg::Var(make_var(1))],
785                        ),
786                        LcnfExpr::Return(LcnfArg::Var(make_var(5))),
787                    ),
788                },
789                LcnfAlt {
790                    ctor_name: "False".to_string(),
791                    ctor_tag: 1,
792                    params: vec![],
793                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
794                },
795            ],
796            default: None,
797        };
798        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
799        optimizer.detect_tail_calls_in_expr(&mut expr, "test");
800        assert_eq!(optimizer.stats.tail_calls_detected, 1);
801        if let LcnfExpr::Case { alts, .. } = &expr {
802            assert!(matches!(alts[0].body, LcnfExpr::TailCall(_, _)));
803        }
804    }
805    #[test]
806    pub(super) fn test_nested_dead_elimination() {
807        let mut expr = make_simple_let(
808            1,
809            LcnfLetValue::Lit(LcnfLit::Nat(1)),
810            make_simple_let(
811                2,
812                LcnfLetValue::Lit(LcnfLit::Nat(2)),
813                make_simple_let(
814                    3,
815                    LcnfLetValue::Lit(LcnfLit::Nat(3)),
816                    LcnfExpr::Return(LcnfArg::Var(make_var(3))),
817                ),
818            ),
819        );
820        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
821        optimizer.eliminate_dead_joins(&mut expr);
822        assert!(matches!(& expr, LcnfExpr::Let { id, .. } if * id == make_var(3)));
823    }
824    #[test]
825    pub(super) fn test_unreachable_count() {
826        let expr = LcnfExpr::Unreachable;
827        assert_eq!(count_instructions(&expr), 1);
828    }
829    #[test]
830    pub(super) fn test_tail_call_count() {
831        let expr = LcnfExpr::TailCall(LcnfArg::Var(make_var(0)), vec![LcnfArg::Var(make_var(1))]);
832        assert_eq!(count_instructions(&expr), 1);
833    }
834    #[test]
835    pub(super) fn test_find_small_joins() {
836        let expr = make_simple_let(
837            1,
838            LcnfLetValue::Lit(LcnfLit::Nat(42)),
839            make_simple_let(
840                2,
841                LcnfLetValue::FVar(make_var(1)),
842                LcnfExpr::Return(LcnfArg::Var(make_var(2))),
843            ),
844        );
845        let optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
846        let joins = optimizer.find_small_joins(&expr);
847        assert!(joins.contains_key(&make_var(1)));
848        assert!(joins.contains_key(&make_var(2)));
849    }
850    #[test]
851    pub(super) fn test_inline_small_joins() {
852        let mut expr = make_simple_let(
853            1,
854            LcnfLetValue::Lit(LcnfLit::Nat(42)),
855            make_simple_let(
856                2,
857                LcnfLetValue::FVar(make_var(1)),
858                LcnfExpr::Return(LcnfArg::Var(make_var(2))),
859            ),
860        );
861        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
862        optimizer.inline_small_joins(&mut expr);
863    }
864    #[test]
865    pub(super) fn test_case_instruction_count() {
866        let expr = LcnfExpr::Case {
867            scrutinee: make_var(0),
868            scrutinee_ty: LcnfType::Nat,
869            alts: vec![
870                LcnfAlt {
871                    ctor_name: "A".to_string(),
872                    ctor_tag: 0,
873                    params: vec![],
874                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))),
875                },
876                LcnfAlt {
877                    ctor_name: "B".to_string(),
878                    ctor_tag: 1,
879                    params: vec![],
880                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(2))),
881                },
882            ],
883            default: None,
884        };
885        assert_eq!(count_instructions(&expr), 3);
886    }
887    #[test]
888    pub(super) fn test_full_pipeline_with_case() {
889        let body = LcnfExpr::Case {
890            scrutinee: make_var(0),
891            scrutinee_ty: LcnfType::Nat,
892            alts: vec![LcnfAlt {
893                ctor_name: "Zero".to_string(),
894                ctor_tag: 0,
895                params: vec![],
896                body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
897            }],
898            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))))),
899        };
900        let decl = make_simple_decl("test_case", body);
901        let mut module = LcnfModule {
902            fun_decls: vec![decl],
903            extern_decls: vec![],
904            name: "test_mod".to_string(),
905            metadata: LcnfModuleMetadata::default(),
906        };
907        let config = JoinPointConfig::default();
908        optimize_join_points(&mut module, &config);
909        assert_eq!(module.fun_decls.len(), 1);
910    }
911    #[test]
912    pub(super) fn test_find_self_recursive_tail_calls() {
913        let mut var_map = HashMap::new();
914        var_map.insert(make_var(10), "my_fn".to_string());
915        let expr = LcnfExpr::TailCall(LcnfArg::Var(make_var(10)), vec![LcnfArg::Var(make_var(0))]);
916        let calls = find_self_recursive_tail_calls(&expr, "my_fn", &var_map);
917        assert_eq!(calls.len(), 1);
918        assert_eq!(calls[0], make_var(10));
919    }
920    #[test]
921    pub(super) fn test_collect_callees_empty() {
922        let expr = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
923        let known: HashSet<&str> = HashSet::new();
924        let mut callees = HashSet::new();
925        collect_callees(&expr, &known, &mut callees);
926        assert!(callees.is_empty());
927    }
928    #[test]
929    pub(super) fn test_multiple_iterations() {
930        let body = make_simple_let(
931            1,
932            LcnfLetValue::Lit(LcnfLit::Nat(1)),
933            make_simple_let(
934                2,
935                LcnfLetValue::Lit(LcnfLit::Nat(2)),
936                make_simple_let(
937                    3,
938                    LcnfLetValue::App(LcnfArg::Var(make_var(10)), vec![LcnfArg::Var(make_var(2))]),
939                    LcnfExpr::Return(LcnfArg::Var(make_var(3))),
940                ),
941            ),
942        );
943        let mut decl = make_simple_decl("multi_iter", body);
944        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig {
945            max_iterations: 10,
946            ..JoinPointConfig::default()
947        });
948        optimizer.optimize_decl(&mut decl);
949        assert!(optimizer.stats.iterations > 0);
950    }
951}
/// Unit tests for the small `OJoin*` support types: config, source buffer,
/// name scope, diagnostics, id generation, incremental keys, profiler, event
/// log, version, feature set and emit statistics.
#[cfg(test)]
mod tests_ojoin_extra {
    use super::*;
    /// `get` returns stored values, `get_bool` accepts `"true"`, and `get_int`
    /// yields `None` for a non-numeric value. `len` counts the two stored keys.
    #[test]
    pub(super) fn test_ojoin_config() {
        let mut cfg = OJoinConfig::new();
        cfg.set("mode", "release");
        cfg.set("verbose", "true");
        assert_eq!(cfg.get("mode"), Some("release"));
        assert!(cfg.get_bool("verbose"));
        assert!(cfg.get_int("mode").is_none());
        assert_eq!(cfg.len(), 2);
    }
    /// Lines pushed after `indent()` carry a four-space prefix.
    /// `reset()` leaves the buffer empty.
    #[test]
    pub(super) fn test_ojoin_source_buffer() {
        let mut buf = OJoinSourceBuffer::new();
        buf.push_line("fn main() {");
        buf.indent();
        buf.push_line("println!(\"hello\");");
        buf.dedent();
        buf.push_line("}");
        assert!(buf.as_str().contains("fn main()"));
        assert!(buf.as_str().contains("    println!"));
        assert_eq!(buf.line_count(), 3);
        buf.reset();
        assert!(buf.is_empty());
    }
    /// Redeclaring a name in the same scope is rejected, and push/pop adjust
    /// `depth`. A name declared before the push still counts after the pop.
    #[test]
    pub(super) fn test_ojoin_name_scope() {
        let mut scope = OJoinNameScope::new();
        assert!(scope.declare("x"));
        assert!(!scope.declare("x"));
        assert!(scope.is_declared("x"));
        let scope = scope.push_scope();
        assert_eq!(scope.depth(), 1);
        let mut scope = scope.pop_scope();
        assert_eq!(scope.depth(), 0);
        scope.declare("y");
        assert_eq!(scope.len(), 2);
    }
    /// Warnings and errors are reported separately, and `clear` empties the collector.
    #[test]
    pub(super) fn test_ojoin_diag_collector() {
        let mut col = OJoinDiagCollector::new();
        col.emit(OJoinDiagMsg::warning("pass_a", "slow"));
        col.emit(OJoinDiagMsg::error("pass_b", "fatal"));
        assert!(col.has_errors());
        assert_eq!(col.errors().len(), 1);
        assert_eq!(col.warnings().len(), 1);
        col.clear();
        assert!(col.is_empty());
    }
    /// Ids start at 0 and increase by one. `skip(10)` jumps past ten ids,
    /// and `reset` rewinds the generator to 0.
    #[test]
    pub(super) fn test_ojoin_id_gen() {
        // `gen` is a reserved keyword from Rust 2024 onward, so it is not used
        // as a binding name here; this keeps the test edition-proof.
        let mut ids = OJoinIdGen::new();
        assert_eq!(ids.next_id(), 0);
        assert_eq!(ids.next_id(), 1);
        ids.skip(10);
        assert_eq!(ids.next_id(), 12);
        ids.reset();
        assert_eq!(ids.peek_next(), 0);
    }
    /// Keys built from identical components match.
    /// A differing first component breaks the match.
    #[test]
    pub(super) fn test_ojoin_incr_key() {
        let k1 = OJoinIncrKey::new(100, 200);
        let k2 = OJoinIncrKey::new(100, 200);
        let k3 = OJoinIncrKey::new(999, 200);
        assert!(k1.matches(&k2));
        assert!(!k1.matches(&k3));
    }
    /// The profiler sums elapsed time across passes and reports the slowest one.
    /// Exactly one of the two sample passes counts as profitable.
    #[test]
    pub(super) fn test_ojoin_profiler() {
        let mut p = OJoinProfiler::new();
        p.record(OJoinPassTiming::new("pass_a", 1000, 50, 200, 100));
        p.record(OJoinPassTiming::new("pass_b", 500, 30, 100, 200));
        assert_eq!(p.total_elapsed_us(), 1500);
        assert_eq!(
            p.slowest_pass()
                .expect("slowest pass should exist")
                .pass_name,
            "pass_a"
        );
        assert_eq!(p.profitable_passes().len(), 1);
    }
    /// A log created with capacity 3 drops its oldest entry when a fourth is pushed.
    #[test]
    pub(super) fn test_ojoin_event_log() {
        let mut log = OJoinEventLog::new(3);
        log.push("event1");
        log.push("event2");
        log.push("event3");
        assert_eq!(log.len(), 3);
        log.push("event4");
        assert_eq!(log.len(), 3);
        assert_eq!(
            log.iter()
                .next()
                .expect("iterator should have next element"),
            "event2"
        );
    }
    /// A pre-release tag makes a version unstable and is rendered after `-`.
    /// Equal versions are compatible; a different major version is not.
    #[test]
    pub(super) fn test_ojoin_version() {
        let v = OJoinVersion::new(1, 2, 3).with_pre("alpha");
        assert!(!v.is_stable());
        assert_eq!(format!("{}", v), "1.2.3-alpha");
        let stable = OJoinVersion::new(2, 0, 0);
        assert!(stable.is_stable());
        assert!(stable.is_compatible_with(&OJoinVersion::new(2, 0, 0)));
        assert!(!stable.is_compatible_with(&OJoinVersion::new(3, 0, 0)));
    }
    /// Features can be enabled and disabled.
    /// Union and intersection combine two feature sets.
    #[test]
    pub(super) fn test_ojoin_features() {
        let mut f = OJoinFeatures::new();
        f.enable("sse2");
        f.enable("avx2");
        assert!(f.is_enabled("sse2"));
        assert!(!f.is_enabled("avx512"));
        f.disable("avx2");
        assert!(!f.is_enabled("avx2"));
        let mut g = OJoinFeatures::new();
        g.enable("sse2");
        g.enable("neon");
        let union = f.union(&g);
        assert!(union.is_enabled("sse2") && union.is_enabled("neon"));
        let inter = f.intersection(&g);
        assert!(inter.is_enabled("sse2"));
    }
    /// 50 000 bytes over 100 ms give a throughput of about 500 000 bytes/s.
    /// `Display` includes the `bytes=` counter.
    #[test]
    pub(super) fn test_ojoin_emit_stats() {
        let mut s = OJoinEmitStats::new();
        s.bytes_emitted = 50_000;
        s.items_emitted = 500;
        s.elapsed_ms = 100;
        assert!(s.is_clean());
        assert!((s.throughput_bps() - 500_000.0).abs() < 1.0);
        let disp = format!("{}", s);
        assert!(disp.contains("bytes=50000"));
    }
}
/// Unit tests for the generic pass infrastructure used by the join-point
/// optimizer: pass config/stats/registry, analysis cache, worklist, dominator
/// tree, liveness, constant folding and dependency graph.
#[cfg(test)]
// The module keeps its generated name so existing `cargo test OJ_infra_tests`
// filters keep working. The allow silences rustc's `non_snake_case` warning
// for that name, which would otherwise fail `-D warnings` builds.
#[allow(non_snake_case)]
mod OJ_infra_tests {
    use super::*;
    /// A new pass config is enabled by default.
    /// The `Transformation` phase is modifying and named "transformation".
    #[test]
    pub(super) fn test_pass_config() {
        let config = OJPassConfig::new("test_pass", OJPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }
    /// Two recorded runs with 10 and 20 changes average 15 changes per run.
    /// The summary reports "Runs: 2/2".
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = OJPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }
    /// Disabled passes are registered but not counted as enabled.
    /// Per-pass stats can be updated and read back by name.
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = OJPassRegistry::new();
        reg.register(OJPassConfig::new("pass_a", OJPassPhase::Analysis));
        reg.register(OJPassConfig::new("pass_b", OJPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }
    /// One hit and one miss give a 50% hit rate. An invalidated entry is
    /// marked invalid but still counts toward `size`.
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = OJAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }
    /// Pushing a queued item again is rejected. After pushing 1 then 2, `pop`
    /// yields 1 and removes it from the membership set.
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = OJWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }
    /// Dominance follows the immediate-dominator chain 3 -> 1 -> 0. Node 2
    /// (a sibling branch) does not dominate 3, and every node dominates itself.
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = OJDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }
    /// `add_def` and `add_use` record the variable in the given block's def/use sets.
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = OJLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }
    /// Checked folding: division by zero yields `None`.
    /// The bitwise helpers return plain values.
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(OJConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(OJConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(OJConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            OJConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(OJConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }
    /// The acyclic graph 1 -> 2 -> 3 (plus 1 -> 3) sorts so that every
    /// dependency precedes its dependents.
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = OJDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Map each node to its position in the topological order.
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}