Skip to main content

oxilean_codegen/opt_beta_eta/
functions.rs

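//! Beta/eta optimisation entry point and helper passes over LCNF expressions.
//!
//! Besides `run_beta_eta`, this module provides use counting, dead-let
//! elimination, case-of-known-constructor, variable substitution,
//! alpha-renaming, copy propagation, literal recording and size metrics.
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)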
4
5use crate::lcnf::{LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfParam, LcnfVarId};
6use std::collections::HashMap;
7
8use super::types::{
9    ArityMap, BetaEtaConfig, BetaEtaPass, BetaEtaReport, CtorEnv, ExtendedPassConfig,
10    ExtendedPassReport, FreshIdGen, KnownValue, LetBinding, LitEnv, ModuleOptStats, OptHint,
11    ParamUsageSummary,
12};
13
14/// Run the beta/eta pass with default configuration on a function declaration.
15pub fn run_beta_eta(decl: &mut LcnfFunDecl) -> BetaEtaReport {
16    let mut pass = BetaEtaPass::default();
17    pass.run(decl);
18    pass.report
19}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lcnf::{
        LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfParam, LcnfType, LcnfVarId,
    };

    /// Shorthand for a variable id.
    pub(super) fn var(n: u64) -> LcnfVarId {
        LcnfVarId(n)
    }

    /// Object-typed parameter that is neither erased nor borrowed.
    pub(super) fn param(id: u64, name: &str) -> LcnfParam {
        LcnfParam {
            id: var(id),
            name: name.to_string(),
            ty: LcnfType::Object,
            erased: false,
            borrowed: false,
        }
    }

    /// Non-recursive, non-lifted declaration returning an object.
    pub(super) fn make_decl(name: &str, params: Vec<LcnfParam>, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params,
            ret_type: LcnfType::Object,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 0,
        }
    }

    /// `return v<id>`
    fn ret_var(id: u64) -> LcnfExpr {
        LcnfExpr::Return(LcnfArg::Var(var(id)))
    }

    /// `v<callee>(v<operands[0]>, v<operands[1]>, ...)` as a let value.
    fn apply(callee: u64, operands: &[u64]) -> LcnfLetValue {
        LcnfLetValue::App(
            LcnfArg::Var(var(callee)),
            operands.iter().map(|&op| LcnfArg::Var(var(op))).collect(),
        )
    }

    /// `let v<id> : ty := value; body`
    fn let_in(id: u64, name: &str, ty: LcnfType, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: var(id),
            name: name.into(),
            ty,
            value,
            body: Box::new(body),
        }
    }

    /// Wraps `body` in a declaration, runs the pass and returns both results.
    fn optimize(
        name: &str,
        params: Vec<LcnfParam>,
        body: LcnfExpr,
    ) -> (LcnfFunDecl, BetaEtaReport) {
        let mut decl = make_decl(name, params, body);
        let report = run_beta_eta(&mut decl);
        (decl, report)
    }

    #[test]
    pub(super) fn test_trivial_return_unchanged() {
        let (decl, report) = optimize("id", vec![param(0, "x")], ret_var(0));
        assert_eq!(report.beta_reductions, 0);
        assert_eq!(report.eta_reductions, 0);
        assert!(matches!(decl.body, LcnfExpr::Return(LcnfArg::Var(_))));
    }

    #[test]
    pub(super) fn test_beta_copy_propagation() {
        let body = let_in(
            1,
            "copy",
            LcnfType::Object,
            LcnfLetValue::FVar(var(0)),
            ret_var(1),
        );
        let (_, report) = optimize("f", vec![param(0, "x")], body);
        assert_eq!(report.beta_reductions, 1);
    }

    #[test]
    pub(super) fn test_eta_reduction_wrapper() {
        let body = let_in(10, "r", LcnfType::Object, apply(99, &[0, 1]), ret_var(10));
        let (decl, report) = optimize("wrapper", vec![param(0, "a"), param(1, "b")], body);
        assert_eq!(report.eta_reductions, 1);
        assert!(matches!(decl.body, LcnfExpr::TailCall(_, _)));
    }

    #[test]
    pub(super) fn test_eta_no_reduction_wrong_args() {
        // Arguments are forwarded in swapped order, so this is not a plain wrapper.
        let body = let_in(10, "r", LcnfType::Object, apply(99, &[1, 0]), ret_var(10));
        let (_, report) = optimize("f", vec![param(0, "a"), param(1, "b")], body);
        assert_eq!(report.eta_reductions, 0);
    }

    #[test]
    pub(super) fn test_curried_opportunity_counted() {
        let inner = let_in(6, "r", LcnfType::Object, apply(5, &[1]), ret_var(6));
        let body = let_in(5, "t", LcnfType::Object, apply(99, &[0]), inner);
        let (_, report) = optimize("curried", vec![param(0, "a"), param(1, "b")], body);
        assert_eq!(report.curried_opportunities, 1);
    }

    #[test]
    pub(super) fn test_beta_no_reduction_on_lit() {
        let body = let_in(
            1,
            "x",
            LcnfType::Nat,
            LcnfLetValue::Lit(crate::lcnf::LcnfLit::Nat(42)),
            ret_var(1),
        );
        let (_, report) = optimize("const", vec![], body);
        assert_eq!(report.beta_reductions, 0);
        assert_eq!(report.eta_reductions, 0);
    }

    #[test]
    pub(super) fn test_eta_single_param() {
        let body = let_in(2, "r", LcnfType::Object, apply(50, &[0]), ret_var(2));
        let (decl, report) = optimize("wrap1", vec![param(0, "x")], body);
        assert_eq!(report.eta_reductions, 1);
        assert!(matches!(decl.body, LcnfExpr::TailCall(_, _)));
    }

    #[test]
    pub(super) fn test_fvar_chain() {
        let inner = let_in(
            2,
            "b",
            LcnfType::Object,
            LcnfLetValue::FVar(var(1)),
            ret_var(2),
        );
        let body = let_in(1, "a", LcnfType::Object, LcnfLetValue::FVar(var(0)), inner);
        let (_, report) = optimize("chain", vec![param(0, "x")], body);
        assert_eq!(report.beta_reductions, 2);
    }
}
175/// Count how many times each variable is used in an expression.
176#[allow(dead_code)]
177pub fn count_uses(expr: &LcnfExpr, uses: &mut HashMap<LcnfVarId, usize>) {
178    match expr {
179        LcnfExpr::Let { value, body, .. } => {
180            count_uses_in_value(value, uses);
181            count_uses(body, uses);
182        }
183        LcnfExpr::Case {
184            scrutinee,
185            alts,
186            default,
187            ..
188        } => {
189            *uses.entry(*scrutinee).or_insert(0) += 1;
190            for alt in alts {
191                count_uses(&alt.body, uses);
192            }
193            if let Some(def) = default {
194                count_uses(def, uses);
195            }
196        }
197        LcnfExpr::Return(arg) => count_uses_in_arg(arg, uses),
198        LcnfExpr::TailCall(func, args) => {
199            count_uses_in_arg(func, uses);
200            for a in args {
201                count_uses_in_arg(a, uses);
202            }
203        }
204        LcnfExpr::Unreachable => {}
205    }
206}
207#[allow(dead_code)]
208pub(super) fn count_uses_in_value(value: &LcnfLetValue, uses: &mut HashMap<LcnfVarId, usize>) {
209    match value {
210        LcnfLetValue::App(func, args) => {
211            count_uses_in_arg(func, uses);
212            for a in args {
213                count_uses_in_arg(a, uses);
214            }
215        }
216        LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
217            for a in args {
218                count_uses_in_arg(a, uses);
219            }
220        }
221        LcnfLetValue::FVar(id) | LcnfLetValue::Reset(id) => {
222            *uses.entry(*id).or_insert(0) += 1;
223        }
224        LcnfLetValue::Proj(_, _, id) => {
225            *uses.entry(*id).or_insert(0) += 1;
226        }
227        LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
228    }
229}
230#[allow(dead_code)]
231pub(super) fn count_uses_in_arg(arg: &LcnfArg, uses: &mut HashMap<LcnfVarId, usize>) {
232    if let LcnfArg::Var(id) = arg {
233        *uses.entry(*id).or_insert(0) += 1;
234    }
235}
236/// Remove `let x = v` bindings where `x` is never used and `v` has no side effects.
237#[allow(dead_code)]
238pub fn dead_let_elim(expr: LcnfExpr, report: &mut ExtendedPassReport) -> LcnfExpr {
239    let mut uses: HashMap<LcnfVarId, usize> = HashMap::new();
240    count_uses(&expr, &mut uses);
241    dead_let_elim_inner(expr, &uses, report)
242}
243pub(super) fn dead_let_elim_inner(
244    expr: LcnfExpr,
245    uses: &HashMap<LcnfVarId, usize>,
246    report: &mut ExtendedPassReport,
247) -> LcnfExpr {
248    match expr {
249        LcnfExpr::Let {
250            id,
251            name,
252            ty,
253            value,
254            body,
255        } => {
256            let use_count = uses.get(&id).copied().unwrap_or(0);
257            let _is_pure = is_pure_value(&value);
258            if use_count == 0 && is_pure_value(&value) {
259                report.dead_lets_eliminated += 1;
260                dead_let_elim_inner(*body, uses, report)
261            } else {
262                LcnfExpr::Let {
263                    id,
264                    name,
265                    ty,
266                    value,
267                    body: Box::new(dead_let_elim_inner(*body, uses, report)),
268                }
269            }
270        }
271        LcnfExpr::Case {
272            scrutinee,
273            scrutinee_ty,
274            alts,
275            default,
276        } => {
277            let new_alts = alts
278                .into_iter()
279                .map(|alt| {
280                    let new_body = dead_let_elim_inner(alt.body, uses, report);
281                    crate::lcnf::LcnfAlt {
282                        ctor_name: alt.ctor_name,
283                        ctor_tag: alt.ctor_tag,
284                        params: alt.params,
285                        body: new_body,
286                    }
287                })
288                .collect();
289            let new_default = default.map(|d| Box::new(dead_let_elim_inner(*d, uses, report)));
290            LcnfExpr::Case {
291                scrutinee,
292                scrutinee_ty,
293                alts: new_alts,
294                default: new_default,
295            }
296        }
297        other => other,
298    }
299}
300#[allow(dead_code)]
301pub(super) fn is_pure_value(value: &LcnfLetValue) -> bool {
302    matches!(
303        value,
304        LcnfLetValue::Lit(_)
305            | LcnfLetValue::Erased
306            | LcnfLetValue::FVar(_)
307            | LcnfLetValue::Ctor(_, _, _)
308            | LcnfLetValue::Proj(_, _, _)
309    )
310}
311/// Float let-bindings out of case alternative arms.
312#[allow(dead_code)]
313pub fn let_float(expr: LcnfExpr, report: &mut ExtendedPassReport) -> LcnfExpr {
314    let_float_inner(expr, report, 0)
315}
316pub(super) fn let_float_inner(
317    expr: LcnfExpr,
318    report: &mut ExtendedPassReport,
319    depth: usize,
320) -> LcnfExpr {
321    if depth > 64 {
322        return expr;
323    }
324    match expr {
325        LcnfExpr::Let {
326            id,
327            name,
328            ty,
329            value,
330            body,
331        } => {
332            let new_body = let_float_inner(*body, report, depth + 1);
333            LcnfExpr::Let {
334                id,
335                name,
336                ty,
337                value,
338                body: Box::new(new_body),
339            }
340        }
341        LcnfExpr::Case {
342            scrutinee,
343            scrutinee_ty,
344            alts,
345            default,
346        } => {
347            let new_alts: Vec<crate::lcnf::LcnfAlt> = alts
348                .into_iter()
349                .map(|alt| {
350                    let new_body = let_float_inner(alt.body, report, depth + 1);
351                    crate::lcnf::LcnfAlt {
352                        ctor_name: alt.ctor_name,
353                        ctor_tag: alt.ctor_tag,
354                        params: alt.params,
355                        body: new_body,
356                    }
357                })
358                .collect();
359            let new_default = default.map(|d| Box::new(let_float_inner(*d, report, depth + 1)));
360            LcnfExpr::Case {
361                scrutinee,
362                scrutinee_ty,
363                alts: new_alts,
364                default: new_default,
365            }
366        }
367        other => other,
368    }
369}
370/// Eliminate `case x of { K ... -> body }` when `x` is known to be `K`.
371#[allow(dead_code)]
372pub fn case_of_known_ctor(
373    expr: LcnfExpr,
374    env: &CtorEnv,
375    report: &mut ExtendedPassReport,
376) -> LcnfExpr {
377    match expr {
378        LcnfExpr::Let {
379            id,
380            name,
381            ty,
382            value,
383            body,
384        } => {
385            let mut new_env = env.clone();
386            if let LcnfLetValue::Ctor(ref cname, tag, _) = value {
387                new_env.record(id, cname.clone(), tag as u16);
388            }
389            let new_body = case_of_known_ctor(*body, &new_env, report);
390            LcnfExpr::Let {
391                id,
392                name,
393                ty,
394                value,
395                body: Box::new(new_body),
396            }
397        }
398        LcnfExpr::Case {
399            scrutinee,
400            scrutinee_ty,
401            alts,
402            default,
403        } => {
404            if let Some((known_name, known_tag)) = env.get(&scrutinee) {
405                let matching_alt = alts
406                    .iter()
407                    .find(|a| &a.ctor_name == known_name && a.ctor_tag == u32::from(*known_tag));
408                if let Some(alt) = matching_alt {
409                    report.case_of_known_ctor_elims += 1;
410                    return case_of_known_ctor(alt.body.clone(), env, report);
411                } else if let Some(def) = default {
412                    report.case_of_known_ctor_elims += 1;
413                    return case_of_known_ctor(*def, env, report);
414                }
415            }
416            let new_alts = alts
417                .into_iter()
418                .map(|alt| {
419                    let new_body = case_of_known_ctor(alt.body, env, report);
420                    crate::lcnf::LcnfAlt {
421                        ctor_name: alt.ctor_name,
422                        ctor_tag: alt.ctor_tag,
423                        params: alt.params,
424                        body: new_body,
425                    }
426                })
427                .collect();
428            let new_default = default.map(|d| Box::new(case_of_known_ctor(*d, env, report)));
429            LcnfExpr::Case {
430                scrutinee,
431                scrutinee_ty,
432                alts: new_alts,
433                default: new_default,
434            }
435        }
436        other => other,
437    }
438}
439/// Flatten a let-chain into a vector of bindings plus a terminal expression.
440#[allow(dead_code)]
441pub fn flatten_let_chain(expr: &LcnfExpr) -> (Vec<LetBinding>, &LcnfExpr) {
442    let mut bindings = Vec::new();
443    let mut cur = expr;
444    loop {
445        match cur {
446            LcnfExpr::Let {
447                id,
448                name,
449                ty,
450                value,
451                body,
452            } => {
453                bindings.push(LetBinding {
454                    id: *id,
455                    name: name.clone(),
456                    ty: ty.clone(),
457                    value: value.clone(),
458                });
459                cur = body;
460            }
461            other => return (bindings, other),
462        }
463    }
464}
465/// Reconstruct an expression from a flattened let-chain and a terminal.
466#[allow(dead_code)]
467pub fn rebuild_let_chain(bindings: Vec<LetBinding>, terminal: LcnfExpr) -> LcnfExpr {
468    bindings
469        .into_iter()
470        .rev()
471        .fold(terminal, |body, b| LcnfExpr::Let {
472            id: b.id,
473            name: b.name,
474            ty: b.ty,
475            value: b.value,
476            body: Box::new(body),
477        })
478}
479/// Estimate the inline cost of an expression.
480#[allow(dead_code)]
481pub fn inline_cost(expr: &LcnfExpr) -> usize {
482    match expr {
483        LcnfExpr::Let { body, .. } => 1 + inline_cost(body),
484        LcnfExpr::Case { alts, default, .. } => {
485            let alt_cost: usize = alts.iter().map(|a| 1 + inline_cost(&a.body)).sum();
486            let def_cost = default.as_ref().map(|d| inline_cost(d)).unwrap_or(0);
487            1 + alt_cost + def_cost
488        }
489        LcnfExpr::TailCall(_, args) => args.len(),
490        LcnfExpr::Return(_) => 0,
491        LcnfExpr::Unreachable => 0,
492    }
493}
494/// Substitute all occurrences of `from_id` with `to_arg` in `expr`.
495#[allow(dead_code)]
496pub fn subst_var_in_expr(expr: LcnfExpr, from_id: LcnfVarId, to_arg: &LcnfArg) -> LcnfExpr {
497    match expr {
498        LcnfExpr::Let {
499            id,
500            name,
501            ty,
502            mut value,
503            body,
504        } => {
505            subst_var_in_value_mut(&mut value, from_id, to_arg);
506            let new_body = if id == from_id {
507                *body
508            } else {
509                subst_var_in_expr(*body, from_id, to_arg)
510            };
511            LcnfExpr::Let {
512                id,
513                name,
514                ty,
515                value,
516                body: Box::new(new_body),
517            }
518        }
519        LcnfExpr::Case {
520            scrutinee,
521            scrutinee_ty,
522            alts,
523            default,
524        } => {
525            let new_scrutinee = if scrutinee == from_id {
526                if let LcnfArg::Var(new_id) = to_arg {
527                    *new_id
528                } else {
529                    scrutinee
530                }
531            } else {
532                scrutinee
533            };
534            let new_alts = alts
535                .into_iter()
536                .map(|alt| {
537                    let bound_in_alt = alt.params.iter().any(|p| p.id == from_id);
538                    let new_body = if bound_in_alt {
539                        alt.body
540                    } else {
541                        subst_var_in_expr(alt.body, from_id, to_arg)
542                    };
543                    crate::lcnf::LcnfAlt {
544                        ctor_name: alt.ctor_name,
545                        ctor_tag: alt.ctor_tag,
546                        params: alt.params,
547                        body: new_body,
548                    }
549                })
550                .collect();
551            let new_default = default.map(|d| Box::new(subst_var_in_expr(*d, from_id, to_arg)));
552            LcnfExpr::Case {
553                scrutinee: new_scrutinee,
554                scrutinee_ty,
555                alts: new_alts,
556                default: new_default,
557            }
558        }
559        LcnfExpr::Return(mut arg) => {
560            subst_var_in_arg_mut(&mut arg, from_id, to_arg);
561            LcnfExpr::Return(arg)
562        }
563        LcnfExpr::TailCall(mut func, mut args) => {
564            subst_var_in_arg_mut(&mut func, from_id, to_arg);
565            for a in args.iter_mut() {
566                subst_var_in_arg_mut(a, from_id, to_arg);
567            }
568            LcnfExpr::TailCall(func, args)
569        }
570        LcnfExpr::Unreachable => LcnfExpr::Unreachable,
571    }
572}
573pub(super) fn subst_var_in_arg_mut(arg: &mut LcnfArg, from_id: LcnfVarId, to_arg: &LcnfArg) {
574    if let LcnfArg::Var(id) = arg {
575        if *id == from_id {
576            *arg = to_arg.clone();
577        }
578    }
579}
580pub(super) fn subst_var_in_value_mut(
581    value: &mut LcnfLetValue,
582    from_id: LcnfVarId,
583    to_arg: &LcnfArg,
584) {
585    match value {
586        LcnfLetValue::App(func, args) => {
587            subst_var_in_arg_mut(func, from_id, to_arg);
588            for a in args.iter_mut() {
589                subst_var_in_arg_mut(a, from_id, to_arg);
590            }
591        }
592        LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
593            for a in args.iter_mut() {
594                subst_var_in_arg_mut(a, from_id, to_arg);
595            }
596        }
597        LcnfLetValue::FVar(id) => {
598            if *id == from_id {
599                if let LcnfArg::Var(new_id) = to_arg {
600                    *id = *new_id;
601                }
602            }
603        }
604        LcnfLetValue::Proj(_, _, id) | LcnfLetValue::Reset(id) => {
605            if *id == from_id {
606                if let LcnfArg::Var(new_id) = to_arg {
607                    *id = *new_id;
608                }
609            }
610        }
611        LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
612    }
613}
614/// Alpha-rename all bound variables in `expr` using `gen` to produce fresh IDs.
615#[allow(dead_code)]
616pub fn alpha_rename(expr: LcnfExpr, gen: &mut FreshIdGen) -> LcnfExpr {
617    let mut rename_map: HashMap<LcnfVarId, LcnfVarId> = HashMap::new();
618    alpha_rename_inner(expr, gen, &mut rename_map)
619}
620pub(super) fn alpha_rename_inner(
621    expr: LcnfExpr,
622    gen: &mut FreshIdGen,
623    rename: &mut HashMap<LcnfVarId, LcnfVarId>,
624) -> LcnfExpr {
625    match expr {
626        LcnfExpr::Let {
627            id,
628            name,
629            ty,
630            value,
631            body,
632        } => {
633            let new_value = rename_in_value(value, rename);
634            let new_id = gen.fresh();
635            rename.insert(id, new_id);
636            let new_body = alpha_rename_inner(*body, gen, rename);
637            LcnfExpr::Let {
638                id: new_id,
639                name,
640                ty,
641                value: new_value,
642                body: Box::new(new_body),
643            }
644        }
645        LcnfExpr::Case {
646            scrutinee,
647            scrutinee_ty,
648            alts,
649            default,
650        } => {
651            let new_scrutinee = rename.get(&scrutinee).copied().unwrap_or(scrutinee);
652            let new_alts = alts
653                .into_iter()
654                .map(|alt| {
655                    let mut child_rename = rename.clone();
656                    let new_params: Vec<crate::lcnf::LcnfParam> = alt
657                        .params
658                        .into_iter()
659                        .map(|p| {
660                            let new_pid = gen.fresh();
661                            child_rename.insert(p.id, new_pid);
662                            crate::lcnf::LcnfParam {
663                                id: new_pid,
664                                name: p.name,
665                                ty: p.ty,
666                                erased: p.erased,
667                                borrowed: p.borrowed,
668                            }
669                        })
670                        .collect();
671                    let new_body = alpha_rename_inner(alt.body, gen, &mut child_rename);
672                    crate::lcnf::LcnfAlt {
673                        ctor_name: alt.ctor_name,
674                        ctor_tag: alt.ctor_tag,
675                        params: new_params,
676                        body: new_body,
677                    }
678                })
679                .collect();
680            let new_default = default.map(|d| {
681                let mut child_rename = rename.clone();
682                Box::new(alpha_rename_inner(*d, gen, &mut child_rename))
683            });
684            LcnfExpr::Case {
685                scrutinee: new_scrutinee,
686                scrutinee_ty,
687                alts: new_alts,
688                default: new_default,
689            }
690        }
691        LcnfExpr::Return(arg) => LcnfExpr::Return(rename_in_arg(arg, rename)),
692        LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(
693            rename_in_arg(func, rename),
694            args.into_iter().map(|a| rename_in_arg(a, rename)).collect(),
695        ),
696        LcnfExpr::Unreachable => LcnfExpr::Unreachable,
697    }
698}
699pub(super) fn rename_in_arg(arg: LcnfArg, rename: &HashMap<LcnfVarId, LcnfVarId>) -> LcnfArg {
700    match arg {
701        LcnfArg::Var(id) => LcnfArg::Var(rename.get(&id).copied().unwrap_or(id)),
702        other => other,
703    }
704}
705pub(super) fn rename_in_value(
706    value: LcnfLetValue,
707    rename: &HashMap<LcnfVarId, LcnfVarId>,
708) -> LcnfLetValue {
709    match value {
710        LcnfLetValue::App(func, args) => LcnfLetValue::App(
711            rename_in_arg(func, rename),
712            args.into_iter().map(|a| rename_in_arg(a, rename)).collect(),
713        ),
714        LcnfLetValue::Ctor(name, tag, args) => LcnfLetValue::Ctor(
715            name,
716            tag,
717            args.into_iter().map(|a| rename_in_arg(a, rename)).collect(),
718        ),
719        LcnfLetValue::Reuse(slot, name, tag, args) => LcnfLetValue::Reuse(
720            slot,
721            name,
722            tag,
723            args.into_iter().map(|a| rename_in_arg(a, rename)).collect(),
724        ),
725        LcnfLetValue::FVar(id) => LcnfLetValue::FVar(rename.get(&id).copied().unwrap_or(id)),
726        LcnfLetValue::Proj(field, ty, id) => {
727            LcnfLetValue::Proj(field, ty, rename.get(&id).copied().unwrap_or(id))
728        }
729        LcnfLetValue::Reset(id) => LcnfLetValue::Reset(rename.get(&id).copied().unwrap_or(id)),
730        other => other,
731    }
732}
733/// Full copy-propagation pass.
734#[allow(dead_code)]
735pub fn full_copy_propagation(expr: &mut LcnfExpr) -> usize {
736    let mut env: HashMap<LcnfVarId, LcnfVarId> = HashMap::new();
737    full_copy_prop_inner(expr, &mut env)
738}
739pub(super) fn full_copy_prop_inner(
740    expr: &mut LcnfExpr,
741    env: &mut HashMap<LcnfVarId, LcnfVarId>,
742) -> usize {
743    let mut count = 0;
744    match expr {
745        LcnfExpr::Let {
746            id, value, body, ..
747        } => {
748            full_copy_prop_value(value, env);
749            if let LcnfLetValue::FVar(src) = value {
750                let canonical = resolve_chain(*src, env);
751                env.insert(*id, canonical);
752                count += 1;
753            }
754            count += full_copy_prop_inner(body, env);
755        }
756        LcnfExpr::Case {
757            scrutinee,
758            alts,
759            default,
760            ..
761        } => {
762            if let Some(c) = env.get(scrutinee) {
763                *scrutinee = *c;
764                count += 1;
765            }
766            for alt in alts.iter_mut() {
767                let mut child_env = env.clone();
768                count += full_copy_prop_inner(&mut alt.body, &mut child_env);
769            }
770            if let Some(def) = default {
771                let mut child_env = env.clone();
772                count += full_copy_prop_inner(def, &mut child_env);
773            }
774        }
775        LcnfExpr::Return(arg) => {
776            if let LcnfArg::Var(id) = arg {
777                if let Some(c) = env.get(id) {
778                    *id = *c;
779                    count += 1;
780                }
781            }
782        }
783        LcnfExpr::TailCall(func, args) => {
784            if let LcnfArg::Var(id) = func {
785                if let Some(c) = env.get(id) {
786                    *id = *c;
787                    count += 1;
788                }
789            }
790            for a in args.iter_mut() {
791                if let LcnfArg::Var(id) = a {
792                    if let Some(c) = env.get(id) {
793                        *id = *c;
794                        count += 1;
795                    }
796                }
797            }
798        }
799        LcnfExpr::Unreachable => {}
800    }
801    count
802}
803pub(super) fn resolve_chain(id: LcnfVarId, env: &HashMap<LcnfVarId, LcnfVarId>) -> LcnfVarId {
804    let mut cur = id;
805    for _ in 0..64 {
806        if let Some(&next) = env.get(&cur) {
807            if next == cur {
808                break;
809            }
810            cur = next;
811        } else {
812            break;
813        }
814    }
815    cur
816}
817pub(super) fn full_copy_prop_value(value: &mut LcnfLetValue, env: &HashMap<LcnfVarId, LcnfVarId>) {
818    match value {
819        LcnfLetValue::App(func, args) => {
820            if let LcnfArg::Var(id) = func {
821                if let Some(c) = env.get(id) {
822                    *id = *c;
823                }
824            }
825            for a in args.iter_mut() {
826                if let LcnfArg::Var(id) = a {
827                    if let Some(c) = env.get(id) {
828                        *id = *c;
829                    }
830                }
831            }
832        }
833        LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
834            for a in args.iter_mut() {
835                if let LcnfArg::Var(id) = a {
836                    if let Some(c) = env.get(id) {
837                        *id = *c;
838                    }
839                }
840            }
841        }
842        LcnfLetValue::FVar(id) => {
843            if let Some(c) = env.get(id) {
844                *id = *c;
845            }
846        }
847        LcnfLetValue::Proj(_, _, id) | LcnfLetValue::Reset(id) => {
848            if let Some(c) = env.get(id) {
849                *id = *c;
850            }
851        }
852        LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
853    }
854}
855/// Propagate literal values through the expression.
856#[allow(dead_code)]
857pub fn lit_propagate(expr: &LcnfExpr, env: &mut LitEnv) {
858    match expr {
859        LcnfExpr::Let {
860            id, value, body, ..
861        } => {
862            match value {
863                LcnfLetValue::Lit(crate::lcnf::LcnfLit::Nat(n)) => {
864                    env.record_nat(*id, *n);
865                }
866                LcnfLetValue::Lit(crate::lcnf::LcnfLit::Str(s)) => {
867                    env.record_str(*id, s.clone());
868                }
869                _ => {}
870            }
871            lit_propagate(body, env);
872        }
873        LcnfExpr::Case { alts, default, .. } => {
874            for alt in alts {
875                let mut child_env = env.clone();
876                lit_propagate(&alt.body, &mut child_env);
877            }
878            if let Some(def) = default {
879                let mut child_env = env.clone();
880                lit_propagate(def, &mut child_env);
881            }
882        }
883        LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable => {}
884    }
885}
886/// Compute the maximum nesting depth of an expression.
887#[allow(dead_code)]
888pub fn max_depth(expr: &LcnfExpr) -> usize {
889    match expr {
890        LcnfExpr::Let { body, .. } => 1 + max_depth(body),
891        LcnfExpr::Case { alts, default, .. } => {
892            let alt_max = alts
893                .iter()
894                .map(|a| 1 + max_depth(&a.body))
895                .max()
896                .unwrap_or(0);
897            let def_max = default.as_ref().map(|d| max_depth(d)).unwrap_or(0);
898            alt_max.max(def_max)
899        }
900        _ => 0,
901    }
902}
903/// Count the total number of let-bindings in an expression.
904#[allow(dead_code)]
905pub fn count_lets(expr: &LcnfExpr) -> usize {
906    match expr {
907        LcnfExpr::Let { body, .. } => 1 + count_lets(body),
908        LcnfExpr::Case { alts, default, .. } => {
909            alts.iter().map(|a| count_lets(&a.body)).sum::<usize>()
910                + default.as_ref().map(|d| count_lets(d)).unwrap_or(0)
911        }
912        _ => 0,
913    }
914}
915/// Count the total number of case expressions in an expression.
916#[allow(dead_code)]
917pub fn count_cases(expr: &LcnfExpr) -> usize {
918    match expr {
919        LcnfExpr::Let { body, .. } => count_cases(body),
920        LcnfExpr::Case { alts, default, .. } => {
921            1 + alts.iter().map(|a| count_cases(&a.body)).sum::<usize>()
922                + default.as_ref().map(|d| count_cases(d)).unwrap_or(0)
923        }
924        _ => 0,
925    }
926}
927/// Count the total number of tail calls in an expression.
928#[allow(dead_code)]
929pub fn count_tail_calls(expr: &LcnfExpr) -> usize {
930    match expr {
931        LcnfExpr::Let { body, .. } => count_tail_calls(body),
932        LcnfExpr::Case { alts, default, .. } => {
933            alts.iter()
934                .map(|a| count_tail_calls(&a.body))
935                .sum::<usize>()
936                + default.as_ref().map(|d| count_tail_calls(d)).unwrap_or(0)
937        }
938        LcnfExpr::TailCall(_, _) => 1,
939        _ => 0,
940    }
941}
942/// Collect all variables that are reachable from an expression.
943#[allow(dead_code)]
944pub fn collect_reachable(expr: &LcnfExpr, reachable: &mut std::collections::HashSet<LcnfVarId>) {
945    match expr {
946        LcnfExpr::Let { value, body, .. } => {
947            collect_reachable_in_value(value, reachable);
948            collect_reachable(body, reachable);
949        }
950        LcnfExpr::Case {
951            scrutinee,
952            alts,
953            default,
954            ..
955        } => {
956            reachable.insert(*scrutinee);
957            for alt in alts {
958                collect_reachable(&alt.body, reachable);
959            }
960            if let Some(def) = default {
961                collect_reachable(def, reachable);
962            }
963        }
964        LcnfExpr::Return(arg) => {
965            if let LcnfArg::Var(id) = arg {
966                reachable.insert(*id);
967            }
968        }
969        LcnfExpr::TailCall(func, args) => {
970            if let LcnfArg::Var(id) = func {
971                reachable.insert(*id);
972            }
973            for a in args {
974                if let LcnfArg::Var(id) = a {
975                    reachable.insert(*id);
976                }
977            }
978        }
979        LcnfExpr::Unreachable => {}
980    }
981}
982pub(super) fn collect_reachable_in_value(
983    value: &LcnfLetValue,
984    reachable: &mut std::collections::HashSet<LcnfVarId>,
985) {
986    match value {
987        LcnfLetValue::App(func, args) => {
988            if let LcnfArg::Var(id) = func {
989                reachable.insert(*id);
990            }
991            for a in args {
992                if let LcnfArg::Var(id) = a {
993                    reachable.insert(*id);
994                }
995            }
996        }
997        LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
998            for a in args {
999                if let LcnfArg::Var(id) = a {
1000                    reachable.insert(*id);
1001                }
1002            }
1003        }
1004        LcnfLetValue::FVar(id) | LcnfLetValue::Reset(id) => {
1005            reachable.insert(*id);
1006        }
1007        LcnfLetValue::Proj(_, _, id) => {
1008            reachable.insert(*id);
1009        }
1010        LcnfLetValue::Lit(_) | LcnfLetValue::Erased => {}
1011    }
1012}
1013/// Run dead-let elimination on a function declaration.
1014#[allow(dead_code)]
1015pub fn run_dead_let_elim(decl: &mut LcnfFunDecl) -> ExtendedPassReport {
1016    let mut report = ExtendedPassReport::default();
1017    let body = std::mem::replace(&mut decl.body, LcnfExpr::Unreachable);
1018    decl.body = dead_let_elim(body, &mut report);
1019    report
1020}
1021/// Run case-of-known-constructor elimination on a function declaration.
1022#[allow(dead_code)]
1023pub fn run_case_of_known_ctor(decl: &mut LcnfFunDecl) -> ExtendedPassReport {
1024    let mut report = ExtendedPassReport::default();
1025    let env = CtorEnv::new();
1026    let body = std::mem::replace(&mut decl.body, LcnfExpr::Unreachable);
1027    decl.body = case_of_known_ctor(body, &env, &mut report);
1028    report
1029}
1030/// A peephole rule: matches a pattern and returns a replacement, or `None`.
1031pub type PeepholeRule = fn(&LcnfLetValue) -> Option<LcnfLetValue>;
1032/// Run a list of peephole rules over all let-values in an expression.
1033#[allow(dead_code)]
1034pub fn peephole_pass(expr: &mut LcnfExpr, rules: &[PeepholeRule]) -> usize {
1035    let mut count = 0;
1036    match expr {
1037        LcnfExpr::Let { value, body, .. } => {
1038            for rule in rules {
1039                if let Some(new_val) = rule(value) {
1040                    *value = new_val;
1041                    count += 1;
1042                    break;
1043                }
1044            }
1045            count += peephole_pass(body, rules);
1046        }
1047        LcnfExpr::Case { alts, default, .. } => {
1048            for alt in alts.iter_mut() {
1049                count += peephole_pass(&mut alt.body, rules);
1050            }
1051            if let Some(def) = default {
1052                count += peephole_pass(def, rules);
1053            }
1054        }
1055        _ => {}
1056    }
1057    count
1058}
1059/// Peephole rule: `App(f, [])` -> `Erased`.
1060#[allow(dead_code)]
1061pub fn rule_nullary_app_to_erased(value: &LcnfLetValue) -> Option<LcnfLetValue> {
1062    if let LcnfLetValue::App(_, args) = value {
1063        if args.is_empty() {
1064            return Some(LcnfLetValue::Erased);
1065        }
1066    }
1067    None
1068}
1069/// Run all beta/eta + extended passes in a fixed-point loop.
1070#[allow(dead_code)]
1071pub fn run_optimizer(
1072    decl: &mut LcnfFunDecl,
1073    beta_cfg: BetaEtaConfig,
1074    ext_cfg: ExtendedPassConfig,
1075    max_iterations: usize,
1076) -> (BetaEtaReport, ExtendedPassReport) {
1077    let mut total_beta = BetaEtaReport::default();
1078    let mut total_ext = ExtendedPassReport::default();
1079    for _iter in 0..max_iterations {
1080        let before_lets = count_lets(&decl.body);
1081        let before_cases = count_cases(&decl.body);
1082        let mut beta_pass = BetaEtaPass::new(beta_cfg.clone());
1083        beta_pass.run(decl);
1084        total_beta.beta_reductions += beta_pass.report.beta_reductions;
1085        total_beta.eta_reductions += beta_pass.report.eta_reductions;
1086        total_beta.curried_opportunities += beta_pass.report.curried_opportunities;
1087        if ext_cfg.do_dead_let {
1088            let r = run_dead_let_elim(decl);
1089            total_ext.dead_lets_eliminated += r.dead_lets_eliminated;
1090        }
1091        if ext_cfg.do_case_of_known_ctor {
1092            let r = run_case_of_known_ctor(decl);
1093            total_ext.case_of_known_ctor_elims += r.case_of_known_ctor_elims;
1094        }
1095        if ext_cfg.do_let_float {
1096            let body = std::mem::replace(&mut decl.body, LcnfExpr::Unreachable);
1097            let mut r = ExtendedPassReport::default();
1098            decl.body = let_float(body, &mut r);
1099            total_ext.lets_floated += r.lets_floated;
1100        }
1101        let after_lets = count_lets(&decl.body);
1102        let after_cases = count_cases(&decl.body);
1103        if after_lets == before_lets && after_cases == before_cases {
1104            break;
1105        }
1106    }
1107    (total_beta, total_ext)
1108}
1109/// Compute which parameters of a function are used in its body.
1110#[allow(dead_code)]
1111pub fn param_usage_summary(decl: &LcnfFunDecl) -> ParamUsageSummary {
1112    let mut uses: HashMap<LcnfVarId, usize> = HashMap::new();
1113    count_uses(&decl.body, &mut uses);
1114    let used = decl
1115        .params
1116        .iter()
1117        .map(|p| uses.get(&p.id).copied().unwrap_or(0) > 0)
1118        .collect();
1119    ParamUsageSummary {
1120        func_name: decl.name.clone(),
1121        used,
1122    }
1123}
1124/// Collect optimization hints from a function declaration.
1125#[allow(dead_code)]
1126pub fn collect_hints(decl: &LcnfFunDecl) -> Vec<OptHint> {
1127    let mut hints = Vec::new();
1128    collect_hints_expr(&decl.body, &mut hints, &HashMap::new());
1129    let cost = inline_cost(&decl.body);
1130    if cost <= 5 {
1131        hints.push(OptHint::InlineCandidate {
1132            func_name: decl.name.clone(),
1133            cost,
1134        });
1135    }
1136    hints
1137}
1138pub(super) fn collect_hints_expr(
1139    expr: &LcnfExpr,
1140    hints: &mut Vec<OptHint>,
1141    id_to_name: &HashMap<LcnfVarId, String>,
1142) {
1143    match expr {
1144        LcnfExpr::Let {
1145            id,
1146            name,
1147            value: LcnfLetValue::App(func, args),
1148            body,
1149            ..
1150        } if args.len() == 1 => {
1151            if let LcnfExpr::Let {
1152                value: LcnfLetValue::App(LcnfArg::Var(callee), _),
1153                ..
1154            } = body.as_ref()
1155            {
1156                if callee == id {
1157                    let outer = if let LcnfArg::Var(fid) = func {
1158                        id_to_name
1159                            .get(fid)
1160                            .cloned()
1161                            .unwrap_or_else(|| format!("_x{}", fid.0))
1162                    } else {
1163                        "unknown".into()
1164                    };
1165                    hints.push(OptHint::MergeCurriedApp {
1166                        intermediate: *id,
1167                        outer_func: outer,
1168                    });
1169                }
1170            }
1171            let mut child_map = id_to_name.clone();
1172            child_map.insert(*id, name.clone());
1173            collect_hints_expr(body, hints, &child_map);
1174        }
1175        LcnfExpr::Let { id, name, body, .. } => {
1176            let mut child_map = id_to_name.clone();
1177            child_map.insert(*id, name.clone());
1178            collect_hints_expr(body, hints, &child_map);
1179        }
1180        LcnfExpr::Case { alts, default, .. } => {
1181            for alt in alts {
1182                collect_hints_expr(&alt.body, hints, id_to_name);
1183            }
1184            if let Some(def) = default {
1185                collect_hints_expr(def, hints, id_to_name);
1186            }
1187        }
1188        _ => {}
1189    }
1190}
1191/// Produce a compact human-readable representation of an `LcnfExpr`.
1192#[allow(dead_code)]
1193pub fn pp_expr(expr: &LcnfExpr) -> String {
1194    match expr {
1195        LcnfExpr::Let {
1196            id,
1197            name,
1198            value,
1199            body,
1200            ..
1201        } => {
1202            format!(
1203                "let {}:{} = {};\n{}",
1204                id,
1205                name,
1206                pp_value(value),
1207                pp_expr(body)
1208            )
1209        }
1210        LcnfExpr::Case {
1211            scrutinee,
1212            alts,
1213            default,
1214            ..
1215        } => {
1216            let mut s = format!("case {} of {{\n", scrutinee);
1217            for alt in alts {
1218                s.push_str(&format!(
1219                    "  | {} -> {}\n",
1220                    alt.ctor_name,
1221                    pp_expr(&alt.body)
1222                ));
1223            }
1224            if let Some(def) = default {
1225                s.push_str(&format!("  | _ -> {}\n", pp_expr(def)));
1226            }
1227            s.push('}');
1228            s
1229        }
1230        LcnfExpr::Return(arg) => format!("return {}", pp_arg(arg)),
1231        LcnfExpr::TailCall(func, args) => {
1232            let arg_strs: Vec<String> = args.iter().map(pp_arg).collect();
1233            format!("tailcall {}({})", pp_arg(func), arg_strs.join(", "))
1234        }
1235        LcnfExpr::Unreachable => "unreachable".into(),
1236    }
1237}
1238pub(super) fn pp_value(value: &LcnfLetValue) -> String {
1239    match value {
1240        LcnfLetValue::App(func, args) => {
1241            let arg_strs: Vec<String> = args.iter().map(pp_arg).collect();
1242            format!("App({}, [{}])", pp_arg(func), arg_strs.join(", "))
1243        }
1244        LcnfLetValue::Ctor(name, tag, args) => {
1245            let arg_strs: Vec<String> = args.iter().map(pp_arg).collect();
1246            format!("Ctor({}, {}, [{}])", name, tag, arg_strs.join(", "))
1247        }
1248        LcnfLetValue::Reuse(slot, name, tag, args) => {
1249            let arg_strs: Vec<String> = args.iter().map(pp_arg).collect();
1250            format!(
1251                "Reuse({}, {}, {}, [{}])",
1252                slot,
1253                name,
1254                tag,
1255                arg_strs.join(", ")
1256            )
1257        }
1258        LcnfLetValue::FVar(id) => format!("FVar({})", id),
1259        LcnfLetValue::Proj(field, ty, id) => format!("Proj({}, {}, {})", field, ty, id),
1260        LcnfLetValue::Reset(id) => format!("Reset({})", id),
1261        LcnfLetValue::Lit(lit) => format!("Lit({:?})", lit),
1262        LcnfLetValue::Erased => "Erased".into(),
1263    }
1264}
1265pub(super) fn pp_arg(arg: &LcnfArg) -> String {
1266    match arg {
1267        LcnfArg::Var(id) => format!("{}", id),
1268        LcnfArg::Type(ty) => format!("Type({})", ty),
1269        LcnfArg::Lit(lit) => format!("Lit({:?})", lit),
1270        LcnfArg::Erased => "Erased".to_string(),
1271    }
1272}
1273/// Run the full optimizer over every function in a module.
1274#[allow(dead_code)]
1275pub fn run_module_optimizer(
1276    decls: &mut Vec<LcnfFunDecl>,
1277    beta_cfg: BetaEtaConfig,
1278    ext_cfg: ExtendedPassConfig,
1279    max_iterations: usize,
1280) -> ModuleOptStats {
1281    let mut stats = ModuleOptStats::default();
1282    for decl in decls.iter_mut() {
1283        let (beta, ext) = run_optimizer(decl, beta_cfg.clone(), ext_cfg.clone(), max_iterations);
1284        stats.total_beta += beta.beta_reductions;
1285        stats.total_eta += beta.eta_reductions;
1286        stats.total_dead_lets += ext.dead_lets_eliminated;
1287        stats.total_cokc += ext.case_of_known_ctor_elims;
1288        stats.functions_processed += 1;
1289    }
1290    stats
1291}
1292/// Count trailing erased args in applications.
1293#[allow(dead_code)]
1294pub fn count_trailing_erased_args(expr: &LcnfExpr) -> usize {
1295    match expr {
1296        LcnfExpr::Let {
1297            value: LcnfLetValue::App(_, args),
1298            body,
1299            ..
1300        } => {
1301            let trailing = args
1302                .iter()
1303                .rev()
1304                .take_while(|a| matches!(a, LcnfArg::Type(_)))
1305                .count();
1306            trailing + count_trailing_erased_args(body)
1307        }
1308        LcnfExpr::Let { body, .. } => count_trailing_erased_args(body),
1309        LcnfExpr::Case { alts, default, .. } => {
1310            alts.iter()
1311                .map(|a| count_trailing_erased_args(&a.body))
1312                .sum::<usize>()
1313                + default
1314                    .as_ref()
1315                    .map(|d| count_trailing_erased_args(d))
1316                    .unwrap_or(0)
1317        }
1318        _ => 0,
1319    }
1320}
#[cfg(test)]
mod extended_tests {
    use super::*;
    use crate::lcnf::{
        LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfParam, LcnfType, LcnfVarId,
    };
    /// Wrap a raw number as a variable id.
    pub(super) fn var(n: u64) -> LcnfVarId {
        LcnfVarId(n)
    }
    /// Build an object-typed parameter that is neither erased nor borrowed.
    pub(super) fn param(id: u64, name: &str) -> LcnfParam {
        LcnfParam {
            id: var(id),
            name: name.to_owned(),
            ty: LcnfType::Object,
            erased: false,
            borrowed: false,
        }
    }
    /// Build a non-recursive, non-lifted declaration returning an object.
    pub(super) fn make_decl(name: &str, params: Vec<LcnfParam>, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_owned(),
            original_name: None,
            params,
            ret_type: LcnfType::Object,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 0,
        }
    }
    /// `let <id>:<name> = <value>; <body>` with an object-typed binding.
    fn let_in(id: u64, name: &str, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: var(id),
            name: name.into(),
            ty: LcnfType::Object,
            value,
            body: Box::new(body),
        }
    }
    /// `return <id>`.
    fn ret_var(id: u64) -> LcnfExpr {
        LcnfExpr::Return(LcnfArg::Var(var(id)))
    }
    /// `let 0:a = Erased; let 1:b = Erased; return 1`.
    fn two_erased_lets() -> LcnfExpr {
        let_in(
            0,
            "a",
            LcnfLetValue::Erased,
            let_in(1, "b", LcnfLetValue::Erased, ret_var(1)),
        )
    }
    #[test]
    pub(super) fn test_count_uses_simple() {
        let expr = ret_var(0);
        let mut uses = HashMap::new();
        count_uses(&expr, &mut uses);
        assert_eq!(uses.get(&var(0)), Some(&1));
    }
    #[test]
    pub(super) fn test_dead_let_elim_removes_unused_fvar() {
        let body = let_in(1, "x", LcnfLetValue::FVar(var(0)), ret_var(0));
        let mut report = ExtendedPassReport::default();
        let result = dead_let_elim(body, &mut report);
        assert_eq!(report.dead_lets_eliminated, 1);
        assert!(matches!(result, LcnfExpr::Return(LcnfArg::Var(_))));
    }
    #[test]
    pub(super) fn test_dead_let_elim_keeps_used_binding() {
        let body = let_in(1, "x", LcnfLetValue::FVar(var(0)), ret_var(1));
        let mut report = ExtendedPassReport::default();
        let _result = dead_let_elim(body, &mut report);
        assert_eq!(report.dead_lets_eliminated, 0);
    }
    #[test]
    pub(super) fn test_inline_cost_single_let() {
        let body = let_in(0, "x", LcnfLetValue::Erased, ret_var(0));
        assert_eq!(inline_cost(&body), 1);
    }
    #[test]
    pub(super) fn test_flatten_let_chain() {
        let expr = two_erased_lets();
        let (bindings, terminal) = flatten_let_chain(&expr);
        assert_eq!(bindings.len(), 2);
        assert!(matches!(terminal, LcnfExpr::Return(_)));
    }
    #[test]
    pub(super) fn test_rebuild_let_chain_roundtrip() {
        let original = let_in(0, "a", LcnfLetValue::Erased, ret_var(0));
        let (bindings, terminal) = flatten_let_chain(&original);
        let rebuilt = rebuild_let_chain(bindings, terminal.clone());
        assert!(matches!(rebuilt, LcnfExpr::Let { id, .. } if id == var(0)));
    }
    #[test]
    pub(super) fn test_fresh_id_gen_sequential() {
        let mut gen = FreshIdGen::new(100);
        for expected in [100, 101, 102] {
            assert_eq!(gen.fresh(), LcnfVarId(expected));
        }
    }
    #[test]
    pub(super) fn test_alpha_rename_changes_ids() {
        let expr = let_in(0, "x", LcnfLetValue::Erased, ret_var(0));
        let mut gen = FreshIdGen::new(1000);
        let renamed = alpha_rename(expr, &mut gen);
        match &renamed {
            LcnfExpr::Let { id, body, .. } => {
                assert_eq!(*id, LcnfVarId(1000));
                match body.as_ref() {
                    LcnfExpr::Return(LcnfArg::Var(ret_id)) => {
                        assert_eq!(*ret_id, LcnfVarId(1000));
                    }
                    _ => panic!("expected Return(Var(1000))"),
                }
            }
            _ => panic!("expected Let"),
        }
    }
    #[test]
    pub(super) fn test_subst_var_in_expr_replaces_return() {
        let result = subst_var_in_expr(ret_var(0), var(0), &LcnfArg::Var(var(99)));
        assert!(matches!(result, LcnfExpr::Return(LcnfArg::Var(id)) if id == var(99)));
    }
    #[test]
    pub(super) fn test_subst_var_no_change_wrong_id() {
        let result = subst_var_in_expr(ret_var(5), var(0), &LcnfArg::Var(var(99)));
        assert!(matches!(result, LcnfExpr::Return(LcnfArg::Var(id)) if id == var(5)));
    }
    #[test]
    pub(super) fn test_max_depth_nested_lets() {
        assert_eq!(max_depth(&two_erased_lets()), 2);
    }
    #[test]
    pub(super) fn test_count_lets() {
        assert_eq!(count_lets(&two_erased_lets()), 2);
    }
    #[test]
    pub(super) fn test_count_cases_zero() {
        assert_eq!(count_cases(&ret_var(0)), 0);
    }
    #[test]
    pub(super) fn test_collect_reachable() {
        let expr = ret_var(42);
        let mut reachable = std::collections::HashSet::new();
        collect_reachable(&expr, &mut reachable);
        assert!(reachable.contains(&var(42)));
    }
    #[test]
    pub(super) fn test_ctor_env_record_and_get() {
        let mut env = CtorEnv::new();
        env.record(var(5), "Cons".into(), 1);
        assert_eq!(env.get(&var(5)), Some(&("Cons".into(), 1u16)));
        assert_eq!(env.get(&var(6)), None);
    }
    #[test]
    pub(super) fn test_case_of_known_ctor_eliminates() {
        let case_on_x = LcnfExpr::Case {
            scrutinee: var(1),
            scrutinee_ty: LcnfType::Object,
            alts: vec![
                crate::lcnf::LcnfAlt {
                    ctor_name: "True".into(),
                    ctor_tag: 0,
                    params: vec![],
                    body: ret_var(0),
                },
                crate::lcnf::LcnfAlt {
                    ctor_name: "False".into(),
                    ctor_tag: 1,
                    params: vec![],
                    body: LcnfExpr::Unreachable,
                },
            ],
            default: None,
        };
        let body = let_in(
            1,
            "x",
            LcnfLetValue::Ctor("True".into(), 0, vec![]),
            case_on_x,
        );
        let env = CtorEnv::new();
        let mut report = ExtendedPassReport::default();
        let _result = case_of_known_ctor(body, &env, &mut report);
        assert_eq!(report.case_of_known_ctor_elims, 1);
    }
    #[test]
    pub(super) fn test_full_copy_propagation() {
        let mut expr = let_in(
            1,
            "x",
            LcnfLetValue::FVar(var(0)),
            let_in(2, "y", LcnfLetValue::FVar(var(1)), ret_var(2)),
        );
        let propagated = full_copy_propagation(&mut expr);
        assert!(propagated > 0);
    }
    #[test]
    pub(super) fn test_run_optimizer_no_panic() {
        let mut decl = make_decl("simple", vec![param(0, "x")], ret_var(0));
        let (beta, ext) = run_optimizer(
            &mut decl,
            BetaEtaConfig::default(),
            ExtendedPassConfig::default(),
            3,
        );
        let _ = (beta, ext);
    }
    #[test]
    pub(super) fn test_lit_env_record_and_get() {
        let mut env = LitEnv::new();
        env.record_nat(var(0), 42);
        env.record_str(var(1), "hello".into());
        assert_eq!(env.get(&var(0)), Some(&KnownValue::Nat(42)));
        assert_eq!(env.get(&var(1)), Some(&KnownValue::Str("hello".into())));
        assert_eq!(env.get(&var(2)), None);
    }
    #[test]
    pub(super) fn test_peephole_nullary_app_to_erased() {
        let mut expr = let_in(
            0,
            "x",
            LcnfLetValue::App(LcnfArg::Var(var(99)), vec![]),
            ret_var(0),
        );
        let rules: Vec<PeepholeRule> = vec![rule_nullary_app_to_erased];
        let n = peephole_pass(&mut expr, &rules);
        assert_eq!(n, 1);
        if let LcnfExpr::Let { value, .. } = &expr {
            assert!(matches!(value, LcnfLetValue::Erased));
        }
    }
    #[test]
    pub(super) fn test_run_dead_let_elim_wrapper() {
        let body = let_in(1, "unused", LcnfLetValue::FVar(var(0)), ret_var(0));
        let mut decl = make_decl("f", vec![param(0, "x")], body);
        let report = run_dead_let_elim(&mut decl);
        assert_eq!(report.dead_lets_eliminated, 1);
    }
    #[test]
    pub(super) fn test_count_tail_calls_in_case() {
        let expr = LcnfExpr::Case {
            scrutinee: var(0),
            scrutinee_ty: LcnfType::Object,
            alts: vec![
                crate::lcnf::LcnfAlt {
                    ctor_name: "A".into(),
                    ctor_tag: 0,
                    params: vec![],
                    body: LcnfExpr::TailCall(LcnfArg::Var(var(1)), vec![]),
                },
                crate::lcnf::LcnfAlt {
                    ctor_name: "B".into(),
                    ctor_tag: 1,
                    params: vec![],
                    body: LcnfExpr::TailCall(LcnfArg::Var(var(2)), vec![]),
                },
            ],
            default: None,
        };
        assert_eq!(count_tail_calls(&expr), 2);
    }
    #[test]
    pub(super) fn test_is_pure_value_lit() {
        let lit = LcnfLetValue::Lit(crate::lcnf::LcnfLit::Nat(0));
        assert!(is_pure_value(&lit));
    }
    #[test]
    pub(super) fn test_is_pure_value_app_not_pure() {
        let app = LcnfLetValue::App(LcnfArg::Var(var(0)), vec![]);
        assert!(!is_pure_value(&app));
    }
    #[test]
    pub(super) fn test_arity_map_from_decls() {
        let decls = vec![
            make_decl("f", vec![param(0, "x"), param(1, "y")], ret_var(0)),
            make_decl("g", vec![param(2, "z")], ret_var(2)),
        ];
        let am = ArityMap::from_decls(&decls);
        assert_eq!(am.get("f"), Some(2));
        assert_eq!(am.get("g"), Some(1));
        assert_eq!(am.get("unknown"), None);
    }
    #[test]
    pub(super) fn test_param_usage_summary_used() {
        let decl = make_decl("f", vec![param(0, "x"), param(1, "y")], ret_var(0));
        let summary = param_usage_summary(&decl);
        assert_eq!(summary.used, vec![true, false]);
    }
    #[test]
    pub(super) fn test_collect_hints_inline_candidate() {
        let decl = make_decl("tiny", vec![param(0, "x")], ret_var(0));
        let hints = collect_hints(&decl);
        assert!(hints
            .iter()
            .any(|h| matches!(h, OptHint::InlineCandidate { .. })));
    }
    #[test]
    pub(super) fn test_pp_expr_return() {
        let s = pp_expr(&ret_var(0));
        assert!(s.contains("return"));
    }
    #[test]
    pub(super) fn test_pp_expr_tailcall() {
        let expr = LcnfExpr::TailCall(LcnfArg::Var(var(1)), vec![LcnfArg::Var(var(0))]);
        let s = pp_expr(&expr);
        assert!(s.contains("tailcall"));
    }
    #[test]
    pub(super) fn test_pp_expr_unreachable() {
        assert_eq!(pp_expr(&LcnfExpr::Unreachable), "unreachable");
    }
    #[test]
    pub(super) fn test_module_optimizer_no_panic() {
        let mut decls = vec![
            make_decl("f", vec![param(0, "x")], ret_var(0)),
            make_decl("g", vec![param(1, "y")], ret_var(1)),
        ];
        let stats = run_module_optimizer(
            &mut decls,
            BetaEtaConfig::default(),
            ExtendedPassConfig::default(),
            3,
        );
        assert_eq!(stats.functions_processed, 2);
    }
    #[test]
    pub(super) fn test_count_trailing_erased_args_zero() {
        assert_eq!(count_trailing_erased_args(&ret_var(0)), 0);
    }
    #[test]
    pub(super) fn test_extended_config_defaults() {
        let cfg = ExtendedPassConfig::default();
        assert!(cfg.do_let_float);
        assert!(cfg.do_case_of_case);
        assert!(cfg.do_dead_let);
        assert_eq!(cfg.max_case_of_case, 8);
    }
    #[test]
    pub(super) fn test_beta_eta_config_defaults() {
        let cfg = BetaEtaConfig::default();
        assert!(cfg.do_eta);
        assert!(cfg.do_beta);
        assert_eq!(cfg.max_depth, 256);
    }
}