Skip to main content

oxilean_codegen/opt_dce/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::*;
6use std::collections::{HashMap, HashSet, VecDeque};
7
8use super::types::{
9    ConstValue, DCEAnalysisCache, DCEConstantFoldingHelper, DCEDepGraph, DCEDominatorTree,
10    DCELivenessInfo, DCEPassConfig, DCEPassPhase, DCEPassRegistry, DCEPassStats, DCEWorklist,
11    DceConfig, DceStats, UsageInfo,
12};
13
14/// Collect usage information for every variable referenced in `expr`.
15///
16/// This is an occurrence-analysis pass similar to GHC's.  It walks the
17/// expression tree once, counting references for each `LcnfVarId` and
18/// noting whether any use escapes or is inside a loop.
19pub fn collect_usage_info(expr: &LcnfExpr) -> HashMap<LcnfVarId, UsageInfo> {
20    let mut info: HashMap<LcnfVarId, UsageInfo> = HashMap::new();
21    collect_usage_expr(expr, &mut info, false);
22    info
23}
24/// Internal recursive walker for usage analysis.
25pub(super) fn collect_usage_expr(
26    expr: &LcnfExpr,
27    info: &mut HashMap<LcnfVarId, UsageInfo>,
28    in_loop: bool,
29) {
30    match expr {
31        LcnfExpr::Let { value, body, .. } => {
32            collect_usage_value(value, info, in_loop);
33            collect_usage_expr(body, info, in_loop);
34        }
35        LcnfExpr::Case {
36            scrutinee,
37            alts,
38            default,
39            ..
40        } => {
41            record_use(info, *scrutinee, in_loop, false);
42            for alt in alts {
43                collect_usage_expr(&alt.body, info, in_loop);
44            }
45            if let Some(def) = default {
46                collect_usage_expr(def, info, in_loop);
47            }
48        }
49        LcnfExpr::Return(arg) => {
50            record_arg_use(info, arg, in_loop, false);
51        }
52        LcnfExpr::TailCall(func, args) => {
53            record_arg_use(info, func, in_loop, false);
54            for a in args {
55                record_arg_use(info, a, in_loop, false);
56            }
57        }
58        LcnfExpr::Unreachable => {}
59    }
60}
61/// Record uses from a let-bound value.
62pub(super) fn collect_usage_value(
63    value: &LcnfLetValue,
64    info: &mut HashMap<LcnfVarId, UsageInfo>,
65    in_loop: bool,
66) {
67    match value {
68        LcnfLetValue::App(func, args) => {
69            record_arg_use(info, func, in_loop, false);
70            for a in args {
71                record_arg_use(info, a, in_loop, false);
72            }
73        }
74        LcnfLetValue::Proj(_, _, v) => {
75            record_use(info, *v, in_loop, false);
76        }
77        LcnfLetValue::Ctor(_, _, args) => {
78            for a in args {
79                record_arg_use(info, a, in_loop, true);
80            }
81        }
82        LcnfLetValue::FVar(v) => {
83            record_use(info, *v, in_loop, false);
84        }
85        LcnfLetValue::Lit(_)
86        | LcnfLetValue::Erased
87        | LcnfLetValue::Reset(_)
88        | LcnfLetValue::Reuse(_, _, _, _) => {}
89    }
90}
91/// Increment the use count for a variable, with optional flags.
92pub(super) fn record_use(
93    info: &mut HashMap<LcnfVarId, UsageInfo>,
94    var: LcnfVarId,
95    in_loop: bool,
96    escaping: bool,
97) {
98    let entry = info.entry(var).or_default();
99    entry.add_use();
100    if in_loop {
101        entry.mark_in_loop();
102    }
103    if escaping {
104        entry.mark_escaping();
105    }
106}
107/// Record a use for an argument (only `LcnfArg::Var` contributes).
108pub(super) fn record_arg_use(
109    info: &mut HashMap<LcnfVarId, UsageInfo>,
110    arg: &LcnfArg,
111    in_loop: bool,
112    escaping: bool,
113) {
114    if let LcnfArg::Var(v) = arg {
115        record_use(info, *v, in_loop, escaping);
116    }
117}
118/// Count total variable references in an expression (quick version
119/// that does not distinguish escaping / loop context).
120pub(super) fn count_refs(expr: &LcnfExpr) -> HashMap<LcnfVarId, usize> {
121    let mut counts: HashMap<LcnfVarId, usize> = HashMap::new();
122    count_refs_expr(expr, &mut counts);
123    counts
124}
125pub(super) fn count_refs_expr(expr: &LcnfExpr, counts: &mut HashMap<LcnfVarId, usize>) {
126    match expr {
127        LcnfExpr::Let { value, body, .. } => {
128            count_refs_value(value, counts);
129            count_refs_expr(body, counts);
130        }
131        LcnfExpr::Case {
132            scrutinee,
133            alts,
134            default,
135            ..
136        } => {
137            *counts.entry(*scrutinee).or_insert(0) += 1;
138            for alt in alts {
139                count_refs_expr(&alt.body, counts);
140            }
141            if let Some(d) = default {
142                count_refs_expr(d, counts);
143            }
144        }
145        LcnfExpr::Return(arg) => {
146            count_refs_arg(arg, counts);
147        }
148        LcnfExpr::TailCall(func, args) => {
149            count_refs_arg(func, counts);
150            for a in args {
151                count_refs_arg(a, counts);
152            }
153        }
154        LcnfExpr::Unreachable => {}
155    }
156}
157pub(super) fn count_refs_value(value: &LcnfLetValue, counts: &mut HashMap<LcnfVarId, usize>) {
158    match value {
159        LcnfLetValue::App(func, args) => {
160            count_refs_arg(func, counts);
161            for a in args {
162                count_refs_arg(a, counts);
163            }
164        }
165        LcnfLetValue::Proj(_, _, v) => {
166            *counts.entry(*v).or_insert(0) += 1;
167        }
168        LcnfLetValue::Ctor(_, _, args) => {
169            for a in args {
170                count_refs_arg(a, counts);
171            }
172        }
173        LcnfLetValue::FVar(v) => {
174            *counts.entry(*v).or_insert(0) += 1;
175        }
176        LcnfLetValue::Lit(_)
177        | LcnfLetValue::Erased
178        | LcnfLetValue::Reset(_)
179        | LcnfLetValue::Reuse(_, _, _, _) => {}
180    }
181}
182pub(super) fn count_refs_arg(arg: &LcnfArg, counts: &mut HashMap<LcnfVarId, usize>) {
183    if let LcnfArg::Var(v) = arg {
184        *counts.entry(*v).or_insert(0) += 1;
185    }
186}
187/// Remove let-bindings whose bound variable is never used in the
188/// continuation body.  Only pure (side-effect-free) bindings are removed;
189/// applications are conservatively kept because they may diverge or have
190/// side-effects.
191///
192/// This is a single bottom-up pass; run it inside a fixed-point loop if
193/// earlier passes may create new dead code.
194pub fn eliminate_dead_lets(expr: &LcnfExpr) -> LcnfExpr {
195    match expr {
196        LcnfExpr::Let {
197            id,
198            name,
199            ty,
200            value,
201            body,
202        } => {
203            let new_body = eliminate_dead_lets(body);
204            let refs = count_refs(&new_body);
205            let used = refs.get(id).copied().unwrap_or(0) > 0;
206            if !used && is_pure_let_value(value) {
207                new_body
208            } else {
209                LcnfExpr::Let {
210                    id: *id,
211                    name: name.clone(),
212                    ty: ty.clone(),
213                    value: value.clone(),
214                    body: Box::new(new_body),
215                }
216            }
217        }
218        LcnfExpr::Case {
219            scrutinee,
220            scrutinee_ty,
221            alts,
222            default,
223        } => {
224            let new_alts: Vec<LcnfAlt> = alts
225                .iter()
226                .map(|alt| LcnfAlt {
227                    ctor_name: alt.ctor_name.clone(),
228                    ctor_tag: alt.ctor_tag,
229                    params: alt.params.clone(),
230                    body: eliminate_dead_lets(&alt.body),
231                })
232                .collect();
233            let new_default = default.as_ref().map(|d| Box::new(eliminate_dead_lets(d)));
234            LcnfExpr::Case {
235                scrutinee: *scrutinee,
236                scrutinee_ty: scrutinee_ty.clone(),
237                alts: new_alts,
238                default: new_default,
239            }
240        }
241        other => other.clone(),
242    }
243}
244/// Returns `true` if the let-value is guaranteed pure (no side-effects,
245/// no divergence).  We conservatively treat function application as
246/// impure since the callee may diverge or perform IO.
247pub(super) fn is_pure_let_value(value: &LcnfLetValue) -> bool {
248    matches!(
249        value,
250        LcnfLetValue::Lit(_)
251            | LcnfLetValue::Erased
252            | LcnfLetValue::FVar(_)
253            | LcnfLetValue::Proj(_, _, _)
254            | LcnfLetValue::Ctor(_, _, _)
255    )
256}
257/// Propagate literal constants: when `let x = <lit>`, replace every
258/// occurrence of `Var(x)` with `Lit(<lit>)` in the continuation and
259/// remove the binding when it becomes dead.
260///
261/// Only literal values (`LcnfLetValue::Lit`) are propagated; constructor
262/// constants are handled by `fold_known_case` instead.
263pub fn propagate_constants(expr: &LcnfExpr) -> LcnfExpr {
264    propagate_constants_env(expr, &HashMap::new())
265}
266/// Propagate constants with an environment mapping variables to their
267/// known literal values.
268pub(super) fn propagate_constants_env(
269    expr: &LcnfExpr,
270    env: &HashMap<LcnfVarId, LcnfLit>,
271) -> LcnfExpr {
272    match expr {
273        LcnfExpr::Let {
274            id,
275            name,
276            ty,
277            value,
278            body,
279        } => {
280            if let LcnfLetValue::Lit(lit) = value {
281                let mut new_env = env.clone();
282                new_env.insert(*id, lit.clone());
283                let new_body = propagate_constants_env(body, &new_env);
284                let refs = count_refs(&new_body);
285                if refs.get(id).copied().unwrap_or(0) == 0 {
286                    new_body
287                } else {
288                    LcnfExpr::Let {
289                        id: *id,
290                        name: name.clone(),
291                        ty: ty.clone(),
292                        value: value.clone(),
293                        body: Box::new(new_body),
294                    }
295                }
296            } else {
297                let new_value = subst_value_constants(value, env);
298                let new_body = propagate_constants_env(body, env);
299                LcnfExpr::Let {
300                    id: *id,
301                    name: name.clone(),
302                    ty: ty.clone(),
303                    value: new_value,
304                    body: Box::new(new_body),
305                }
306            }
307        }
308        LcnfExpr::Case {
309            scrutinee,
310            scrutinee_ty,
311            alts,
312            default,
313        } => {
314            let new_alts: Vec<LcnfAlt> = alts
315                .iter()
316                .map(|alt| LcnfAlt {
317                    ctor_name: alt.ctor_name.clone(),
318                    ctor_tag: alt.ctor_tag,
319                    params: alt.params.clone(),
320                    body: propagate_constants_env(&alt.body, env),
321                })
322                .collect();
323            let new_default = default
324                .as_ref()
325                .map(|d| Box::new(propagate_constants_env(d, env)));
326            LcnfExpr::Case {
327                scrutinee: *scrutinee,
328                scrutinee_ty: scrutinee_ty.clone(),
329                alts: new_alts,
330                default: new_default,
331            }
332        }
333        LcnfExpr::Return(arg) => LcnfExpr::Return(subst_arg_constant(arg, env)),
334        LcnfExpr::TailCall(func, args) => {
335            let new_func = subst_arg_constant(func, env);
336            let new_args: Vec<LcnfArg> = args.iter().map(|a| subst_arg_constant(a, env)).collect();
337            LcnfExpr::TailCall(new_func, new_args)
338        }
339        LcnfExpr::Unreachable => LcnfExpr::Unreachable,
340    }
341}
342/// Substitute known constant literals inside a `LcnfLetValue`.
343pub(super) fn subst_value_constants(
344    value: &LcnfLetValue,
345    env: &HashMap<LcnfVarId, LcnfLit>,
346) -> LcnfLetValue {
347    match value {
348        LcnfLetValue::App(func, args) => {
349            let new_func = subst_arg_constant(func, env);
350            let new_args: Vec<LcnfArg> = args.iter().map(|a| subst_arg_constant(a, env)).collect();
351            LcnfLetValue::App(new_func, new_args)
352        }
353        LcnfLetValue::Ctor(name, tag, args) => {
354            let new_args: Vec<LcnfArg> = args.iter().map(|a| subst_arg_constant(a, env)).collect();
355            LcnfLetValue::Ctor(name.clone(), *tag, new_args)
356        }
357        LcnfLetValue::FVar(v) => {
358            if let Some(lit) = env.get(v) {
359                LcnfLetValue::Lit(lit.clone())
360            } else {
361                value.clone()
362            }
363        }
364        LcnfLetValue::Proj(name, idx, v) => LcnfLetValue::Proj(name.clone(), *idx, *v),
365        other => other.clone(),
366    }
367}
368/// Replace a variable argument with its known literal if available.
369pub(super) fn subst_arg_constant(arg: &LcnfArg, env: &HashMap<LcnfVarId, LcnfLit>) -> LcnfArg {
370    match arg {
371        LcnfArg::Var(v) => {
372            if let Some(lit) = env.get(v) {
373                LcnfArg::Lit(lit.clone())
374            } else {
375                arg.clone()
376            }
377        }
378        other => other.clone(),
379    }
380}
381/// Propagate copies: when `let x = y` (i.e. `LcnfLetValue::FVar(y)`),
382/// replace every use of `x` with `y` in the continuation and drop the
383/// binding.
384///
385/// This is particularly effective after lambda lifting and join point
386/// optimization which often introduce trivial copy bindings.
387pub fn propagate_copies(expr: &LcnfExpr) -> LcnfExpr {
388    propagate_copies_env(expr, &HashMap::new())
389}
390/// Copy propagation with an accumulated substitution environment.
391pub(super) fn propagate_copies_env(
392    expr: &LcnfExpr,
393    env: &HashMap<LcnfVarId, LcnfVarId>,
394) -> LcnfExpr {
395    match expr {
396        LcnfExpr::Let {
397            id,
398            name,
399            ty,
400            value,
401            body,
402        } => {
403            if let LcnfLetValue::FVar(src) = value {
404                let resolved = resolve_copy(env, *src);
405                let mut new_env = env.clone();
406                new_env.insert(*id, resolved);
407                let new_body = propagate_copies_env(body, &new_env);
408                let refs = count_refs(&new_body);
409                if refs.get(id).copied().unwrap_or(0) == 0 {
410                    new_body
411                } else {
412                    LcnfExpr::Let {
413                        id: *id,
414                        name: name.clone(),
415                        ty: ty.clone(),
416                        value: LcnfLetValue::FVar(resolved),
417                        body: Box::new(new_body),
418                    }
419                }
420            } else {
421                let new_value = subst_value_copies(value, env);
422                let new_body = propagate_copies_env(body, env);
423                LcnfExpr::Let {
424                    id: *id,
425                    name: name.clone(),
426                    ty: ty.clone(),
427                    value: new_value,
428                    body: Box::new(new_body),
429                }
430            }
431        }
432        LcnfExpr::Case {
433            scrutinee,
434            scrutinee_ty,
435            alts,
436            default,
437        } => {
438            let resolved_scrutinee = resolve_copy(env, *scrutinee);
439            let new_alts: Vec<LcnfAlt> = alts
440                .iter()
441                .map(|alt| LcnfAlt {
442                    ctor_name: alt.ctor_name.clone(),
443                    ctor_tag: alt.ctor_tag,
444                    params: alt.params.clone(),
445                    body: propagate_copies_env(&alt.body, env),
446                })
447                .collect();
448            let new_default = default
449                .as_ref()
450                .map(|d| Box::new(propagate_copies_env(d, env)));
451            LcnfExpr::Case {
452                scrutinee: resolved_scrutinee,
453                scrutinee_ty: scrutinee_ty.clone(),
454                alts: new_alts,
455                default: new_default,
456            }
457        }
458        LcnfExpr::Return(arg) => LcnfExpr::Return(subst_arg_copy(arg, env)),
459        LcnfExpr::TailCall(func, args) => {
460            let new_func = subst_arg_copy(func, env);
461            let new_args: Vec<LcnfArg> = args.iter().map(|a| subst_arg_copy(a, env)).collect();
462            LcnfExpr::TailCall(new_func, new_args)
463        }
464        LcnfExpr::Unreachable => LcnfExpr::Unreachable,
465    }
466}
467/// Follow a chain of copy substitutions to find the ultimate source.
468/// Detects cycles via a visited set.
469pub(super) fn resolve_copy(env: &HashMap<LcnfVarId, LcnfVarId>, mut var: LcnfVarId) -> LcnfVarId {
470    let mut visited = HashSet::new();
471    while let Some(&target) = env.get(&var) {
472        if !visited.insert(var) {
473            break;
474        }
475        var = target;
476    }
477    var
478}
479/// Substitute copy-renamed variables inside a let-value.
480pub(super) fn subst_value_copies(
481    value: &LcnfLetValue,
482    env: &HashMap<LcnfVarId, LcnfVarId>,
483) -> LcnfLetValue {
484    match value {
485        LcnfLetValue::App(func, args) => {
486            let new_func = subst_arg_copy(func, env);
487            let new_args: Vec<LcnfArg> = args.iter().map(|a| subst_arg_copy(a, env)).collect();
488            LcnfLetValue::App(new_func, new_args)
489        }
490        LcnfLetValue::Ctor(name, tag, args) => {
491            let new_args: Vec<LcnfArg> = args.iter().map(|a| subst_arg_copy(a, env)).collect();
492            LcnfLetValue::Ctor(name.clone(), *tag, new_args)
493        }
494        LcnfLetValue::Proj(name, idx, v) => {
495            LcnfLetValue::Proj(name.clone(), *idx, resolve_copy(env, *v))
496        }
497        LcnfLetValue::FVar(v) => LcnfLetValue::FVar(resolve_copy(env, *v)),
498        other => other.clone(),
499    }
500}
501/// Replace a variable argument with its copy-target if available.
502pub(super) fn subst_arg_copy(arg: &LcnfArg, env: &HashMap<LcnfVarId, LcnfVarId>) -> LcnfArg {
503    match arg {
504        LcnfArg::Var(v) => LcnfArg::Var(resolve_copy(env, *v)),
505        other => other.clone(),
506    }
507}
508/// Eliminate case alternatives that are statically unreachable.
509///
510/// Currently detects three patterns:
511/// 1. Any alternative whose body is `Unreachable` is removed.
512/// 2. If the default is `Unreachable`, it is removed.
513/// 3. If after trimming there are no alternatives left but a default
514///    exists, the case is replaced with the default body.  If there are
515///    no alternatives and no default, the result is `Unreachable`.
516pub fn eliminate_unreachable_alts(expr: &LcnfExpr) -> LcnfExpr {
517    match expr {
518        LcnfExpr::Let {
519            id,
520            name,
521            ty,
522            value,
523            body,
524        } => LcnfExpr::Let {
525            id: *id,
526            name: name.clone(),
527            ty: ty.clone(),
528            value: value.clone(),
529            body: Box::new(eliminate_unreachable_alts(body)),
530        },
531        LcnfExpr::Case {
532            scrutinee,
533            scrutinee_ty,
534            alts,
535            default,
536        } => {
537            let mut new_alts: Vec<LcnfAlt> = alts
538                .iter()
539                .map(|alt| LcnfAlt {
540                    ctor_name: alt.ctor_name.clone(),
541                    ctor_tag: alt.ctor_tag,
542                    params: alt.params.clone(),
543                    body: eliminate_unreachable_alts(&alt.body),
544                })
545                .collect();
546            let mut new_default = default
547                .as_ref()
548                .map(|d| Box::new(eliminate_unreachable_alts(d)));
549            new_alts.retain(|alt| !matches!(alt.body, LcnfExpr::Unreachable));
550            if let Some(ref d) = new_default {
551                if matches!(d.as_ref(), LcnfExpr::Unreachable) {
552                    new_default = None;
553                }
554            }
555            if new_alts.is_empty() {
556                if let Some(def) = new_default {
557                    return *def;
558                }
559                return LcnfExpr::Unreachable;
560            }
561            LcnfExpr::Case {
562                scrutinee: *scrutinee,
563                scrutinee_ty: scrutinee_ty.clone(),
564                alts: new_alts,
565                default: new_default,
566            }
567        }
568        other => other.clone(),
569    }
570}
571/// Fold a case expression when the scrutinee is a known constructor.
572///
573/// When we can determine (via a local constant environment) that the
574/// scrutinee variable was bound to a specific constructor, we select the
575/// matching alternative and substitute the constructor's fields for the
576/// alt's parameters.
577pub fn fold_known_case(expr: &LcnfExpr) -> LcnfExpr {
578    fold_known_case_env(expr, &HashMap::new())
579}
580/// Known-case folding with an environment mapping variables to their
581/// known `ConstValue`.
582pub(super) fn fold_known_case_env(
583    expr: &LcnfExpr,
584    env: &HashMap<LcnfVarId, ConstValue>,
585) -> LcnfExpr {
586    match expr {
587        LcnfExpr::Let {
588            id,
589            name,
590            ty,
591            value,
592            body,
593        } => {
594            let mut new_env = env.clone();
595            match value {
596                LcnfLetValue::Ctor(ctor_name, tag, args) => {
597                    new_env.insert(*id, ConstValue::Ctor(ctor_name.clone(), *tag, args.clone()));
598                }
599                LcnfLetValue::Lit(lit) => {
600                    new_env.insert(*id, ConstValue::Lit(lit.clone()));
601                }
602                _ => {}
603            }
604            let new_body = fold_known_case_env(body, &new_env);
605            LcnfExpr::Let {
606                id: *id,
607                name: name.clone(),
608                ty: ty.clone(),
609                value: value.clone(),
610                body: Box::new(new_body),
611            }
612        }
613        LcnfExpr::Case {
614            scrutinee,
615            scrutinee_ty,
616            alts,
617            default,
618        } => {
619            if let Some(ConstValue::Ctor(_, known_tag, ctor_args)) = env.get(scrutinee) {
620                if let Some(matching_alt) = alts.iter().find(|a| a.ctor_tag == *known_tag) {
621                    let mut result = matching_alt.body.clone();
622                    result = substitute_alt_params(&result, &matching_alt.params, ctor_args);
623                    return fold_known_case_env(&result, env);
624                }
625                if let Some(def) = default {
626                    return fold_known_case_env(def, env);
627                }
628                return LcnfExpr::Unreachable;
629            }
630            let new_alts: Vec<LcnfAlt> = alts
631                .iter()
632                .map(|alt| {
633                    let mut branch_env = env.clone();
634                    branch_env.insert(
635                        *scrutinee,
636                        ConstValue::Ctor(
637                            alt.ctor_name.clone(),
638                            alt.ctor_tag,
639                            alt.params.iter().map(|p| LcnfArg::Var(p.id)).collect(),
640                        ),
641                    );
642                    LcnfAlt {
643                        ctor_name: alt.ctor_name.clone(),
644                        ctor_tag: alt.ctor_tag,
645                        params: alt.params.clone(),
646                        body: fold_known_case_env(&alt.body, &branch_env),
647                    }
648                })
649                .collect();
650            let new_default = default
651                .as_ref()
652                .map(|d| Box::new(fold_known_case_env(d, env)));
653            LcnfExpr::Case {
654                scrutinee: *scrutinee,
655                scrutinee_ty: scrutinee_ty.clone(),
656                alts: new_alts,
657                default: new_default,
658            }
659        }
660        LcnfExpr::Return(arg) => LcnfExpr::Return(arg.clone()),
661        LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(func.clone(), args.clone()),
662        LcnfExpr::Unreachable => LcnfExpr::Unreachable,
663    }
664}
665/// Substitute constructor field values for the alt-bound parameters in
666/// an expression.  For each parameter `p_i` in `params`, if `ctor_args`
667/// has a matching argument at position `i`, we replace `Var(p_i.id)`
668/// with that argument throughout `expr`.
669pub(super) fn substitute_alt_params(
670    expr: &LcnfExpr,
671    params: &[LcnfParam],
672    ctor_args: &[LcnfArg],
673) -> LcnfExpr {
674    let mut subst: HashMap<LcnfVarId, LcnfArg> = HashMap::new();
675    for (param, arg) in params.iter().zip(ctor_args.iter()) {
676        subst.insert(param.id, arg.clone());
677    }
678    apply_arg_subst(expr, &subst)
679}
680/// Apply an argument-level substitution throughout an expression.
681pub(super) fn apply_arg_subst(expr: &LcnfExpr, subst: &HashMap<LcnfVarId, LcnfArg>) -> LcnfExpr {
682    match expr {
683        LcnfExpr::Let {
684            id,
685            name,
686            ty,
687            value,
688            body,
689        } => {
690            let new_value = apply_value_subst(value, subst);
691            let mut inner_subst = subst.clone();
692            inner_subst.remove(id);
693            let new_body = apply_arg_subst(body, &inner_subst);
694            LcnfExpr::Let {
695                id: *id,
696                name: name.clone(),
697                ty: ty.clone(),
698                value: new_value,
699                body: Box::new(new_body),
700            }
701        }
702        LcnfExpr::Case {
703            scrutinee,
704            scrutinee_ty,
705            alts,
706            default,
707        } => {
708            let new_scrutinee = resolve_var_subst(subst, *scrutinee);
709            let new_alts: Vec<LcnfAlt> = alts
710                .iter()
711                .map(|alt| {
712                    let mut alt_subst = subst.clone();
713                    for p in &alt.params {
714                        alt_subst.remove(&p.id);
715                    }
716                    LcnfAlt {
717                        ctor_name: alt.ctor_name.clone(),
718                        ctor_tag: alt.ctor_tag,
719                        params: alt.params.clone(),
720                        body: apply_arg_subst(&alt.body, &alt_subst),
721                    }
722                })
723                .collect();
724            let new_default = default
725                .as_ref()
726                .map(|d| Box::new(apply_arg_subst(d, subst)));
727            LcnfExpr::Case {
728                scrutinee: new_scrutinee,
729                scrutinee_ty: scrutinee_ty.clone(),
730                alts: new_alts,
731                default: new_default,
732            }
733        }
734        LcnfExpr::Return(arg) => LcnfExpr::Return(do_subst_arg(arg, subst)),
735        LcnfExpr::TailCall(func, args) => {
736            let new_func = do_subst_arg(func, subst);
737            let new_args: Vec<LcnfArg> = args.iter().map(|a| do_subst_arg(a, subst)).collect();
738            LcnfExpr::TailCall(new_func, new_args)
739        }
740        LcnfExpr::Unreachable => LcnfExpr::Unreachable,
741    }
742}
743/// Apply substitution to a let-value.
744pub(super) fn apply_value_subst(
745    value: &LcnfLetValue,
746    subst: &HashMap<LcnfVarId, LcnfArg>,
747) -> LcnfLetValue {
748    match value {
749        LcnfLetValue::App(func, args) => {
750            let new_func = do_subst_arg(func, subst);
751            let new_args: Vec<LcnfArg> = args.iter().map(|a| do_subst_arg(a, subst)).collect();
752            LcnfLetValue::App(new_func, new_args)
753        }
754        LcnfLetValue::Ctor(name, tag, args) => {
755            let new_args: Vec<LcnfArg> = args.iter().map(|a| do_subst_arg(a, subst)).collect();
756            LcnfLetValue::Ctor(name.clone(), *tag, new_args)
757        }
758        LcnfLetValue::Proj(name, idx, v) => {
759            let resolved = resolve_var_subst(subst, *v);
760            LcnfLetValue::Proj(name.clone(), *idx, resolved)
761        }
762        LcnfLetValue::FVar(v) => {
763            if let Some(replacement) = subst.get(v) {
764                match replacement {
765                    LcnfArg::Var(new_v) => LcnfLetValue::FVar(*new_v),
766                    LcnfArg::Lit(lit) => LcnfLetValue::Lit(lit.clone()),
767                    _ => value.clone(),
768                }
769            } else {
770                value.clone()
771            }
772        }
773        other => other.clone(),
774    }
775}
776/// Substitute an argument via the substitution map.
777pub(super) fn do_subst_arg(arg: &LcnfArg, subst: &HashMap<LcnfVarId, LcnfArg>) -> LcnfArg {
778    match arg {
779        LcnfArg::Var(v) => {
780            if let Some(replacement) = subst.get(v) {
781                replacement.clone()
782            } else {
783                arg.clone()
784            }
785        }
786        other => other.clone(),
787    }
788}
789/// Resolve a variable through a substitution map; returns the original
790/// if not present, or the target variable if the substitution maps to a Var.
791pub(super) fn resolve_var_subst(subst: &HashMap<LcnfVarId, LcnfArg>, var: LcnfVarId) -> LcnfVarId {
792    if let Some(LcnfArg::Var(target)) = subst.get(&var) {
793        *target
794    } else {
795        var
796    }
797}
798/// Remove function declarations from `module` that are not reachable from
799/// the given `roots`.  A function is reachable if it is named in `roots`
800/// or transitively called by a reachable function.
801///
802/// The call-graph is approximated conservatively: any mention of a
803/// function name in the body of another function counts as a reference.
804pub fn eliminate_dead_functions(module: &LcnfModule, roots: &[String]) -> LcnfModule {
805    let name_to_idx: HashMap<&str, usize> = module
806        .fun_decls
807        .iter()
808        .enumerate()
809        .map(|(i, d)| (d.name.as_str(), i))
810        .collect();
811    let mut adj: Vec<HashSet<usize>> = vec![HashSet::new(); module.fun_decls.len()];
812    for (i, decl) in module.fun_decls.iter().enumerate() {
813        let mentioned = collect_mentioned_names(&decl.body);
814        for name in &mentioned {
815            if let Some(&j) = name_to_idx.get(name.as_str()) {
816                adj[i].insert(j);
817            }
818        }
819    }
820    let mut reachable: HashSet<usize> = HashSet::new();
821    let mut queue: VecDeque<usize> = VecDeque::new();
822    for root in roots {
823        if let Some(&idx) = name_to_idx.get(root.as_str()) {
824            if reachable.insert(idx) {
825                queue.push_back(idx);
826            }
827        }
828    }
829    for (i, decl) in module.fun_decls.iter().enumerate() {
830        if !decl.is_lifted && reachable.insert(i) {
831            queue.push_back(i);
832        }
833    }
834    while let Some(idx) = queue.pop_front() {
835        for &callee in &adj[idx] {
836            if reachable.insert(callee) {
837                queue.push_back(callee);
838            }
839        }
840    }
841    let kept_decls: Vec<LcnfFunDecl> = module
842        .fun_decls
843        .iter()
844        .enumerate()
845        .filter(|(i, _)| reachable.contains(i))
846        .map(|(_, d)| d.clone())
847        .collect();
848    let eliminated_count = module.fun_decls.len() - kept_decls.len();
849    LcnfModule {
850        fun_decls: kept_decls,
851        extern_decls: module.extern_decls.clone(),
852        name: module.name.clone(),
853        metadata: LcnfModuleMetadata {
854            decl_count: module.metadata.decl_count.saturating_sub(eliminated_count),
855            ..module.metadata.clone()
856        },
857    }
858}
859/// Collect all constructor / function names mentioned in an expression.
860/// This is a conservative over-approximation used for reachability.
861pub(super) fn collect_mentioned_names(expr: &LcnfExpr) -> HashSet<String> {
862    let mut names = HashSet::new();
863    collect_names_inner(expr, &mut names);
864    names
865}
866pub(super) fn collect_names_inner(expr: &LcnfExpr, names: &mut HashSet<String>) {
867    match expr {
868        LcnfExpr::Let { value, body, .. } => {
869            collect_names_value(value, names);
870            collect_names_inner(body, names);
871        }
872        LcnfExpr::Case { alts, default, .. } => {
873            for alt in alts {
874                names.insert(alt.ctor_name.clone());
875                collect_names_inner(&alt.body, names);
876            }
877            if let Some(d) = default {
878                collect_names_inner(d, names);
879            }
880        }
881        LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable => {}
882    }
883}
884pub(super) fn collect_names_value(value: &LcnfLetValue, names: &mut HashSet<String>) {
885    match value {
886        LcnfLetValue::Ctor(name, _, _) => {
887            names.insert(name.clone());
888        }
889        LcnfLetValue::Proj(name, _, _) => {
890            names.insert(name.clone());
891        }
892        LcnfLetValue::App(_, _)
893        | LcnfLetValue::FVar(_)
894        | LcnfLetValue::Lit(_)
895        | LcnfLetValue::Erased
896        | LcnfLetValue::Reset(_)
897        | LcnfLetValue::Reuse(_, _, _, _) => {}
898    }
899}
900/// Run the complete DCE + constant propagation pipeline on a module.
901///
902/// The optimizer runs the enabled passes in a fixed-point loop:
903///   1. Constant propagation
904///   2. Copy propagation
905///   3. Known case folding
906///   4. Dead let elimination
907///   5. Unreachable alt elimination
908///
909/// After the fixed point, dead function elimination is applied once
910/// (interprocedural).
911///
912/// Returns the optimized module and accumulated statistics.
913pub fn optimize_dce(module: &LcnfModule, config: &DceConfig) -> (LcnfModule, DceStats) {
914    let mut stats = DceStats::default();
915    let mut result = module.clone();
916    for decl in &mut result.fun_decls {
917        let fn_stats = optimize_function_body(&mut decl.body, config);
918        stats.merge(&fn_stats);
919    }
920    let roots: Vec<String> = result
921        .fun_decls
922        .iter()
923        .filter(|d| !d.is_lifted)
924        .map(|d| d.name.clone())
925        .collect();
926    let before_count = result.fun_decls.len();
927    result = eliminate_dead_functions(&result, &roots);
928    stats.functions_eliminated += before_count.saturating_sub(result.fun_decls.len());
929    (result, stats)
930}
931/// Optimize a single function body using intraprocedural passes.
932pub(super) fn optimize_function_body(body: &mut LcnfExpr, config: &DceConfig) -> DceStats {
933    let mut total_stats = DceStats::default();
934    for _iteration in 0..config.max_iterations {
935        total_stats.iterations += 1;
936        let before = count_let_bindings(body);
937        if config.propagate_constants {
938            *body = propagate_constants(body);
939        }
940        if config.propagate_copies {
941            *body = propagate_copies(body);
942        }
943        if config.fold_known_calls {
944            *body = fold_known_case(body);
945        }
946        if config.eliminate_unused_lets {
947            *body = eliminate_dead_lets(body);
948        }
949        if config.eliminate_unreachable_alts {
950            *body = eliminate_unreachable_alts(body);
951        }
952        let after = count_let_bindings(body);
953        let eliminated = before.saturating_sub(after);
954        total_stats.lets_eliminated += eliminated;
955        if eliminated == 0 {
956            break;
957        }
958    }
959    total_stats
960}
961/// Count the number of let-bindings in an expression (used for convergence
962/// checking in the fixed-point loop).
963pub(super) fn count_let_bindings(expr: &LcnfExpr) -> usize {
964    match expr {
965        LcnfExpr::Let { body, .. } => 1 + count_let_bindings(body),
966        LcnfExpr::Case { alts, default, .. } => {
967            let alt_count: usize = alts.iter().map(|a| count_let_bindings(&a.body)).sum();
968            let def_count = default.as_ref().map(|d| count_let_bindings(d)).unwrap_or(0);
969            alt_count + def_count
970        }
971        LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable => 0,
972    }
973}
974#[cfg(test)]
975mod tests {
976    use super::*;
977    pub(super) fn vid(n: u64) -> LcnfVarId {
978        LcnfVarId(n)
979    }
980    pub(super) fn mk_param(n: u64, name: &str) -> LcnfParam {
981        LcnfParam {
982            id: vid(n),
983            name: name.to_string(),
984            ty: LcnfType::Nat,
985            erased: false,
986            borrowed: false,
987        }
988    }
989    pub(super) fn mk_let(id: u64, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
990        LcnfExpr::Let {
991            id: vid(id),
992            name: format!("x{}", id),
993            ty: LcnfType::Nat,
994            value,
995            body: Box::new(body),
996        }
997    }
998    pub(super) fn mk_decl(name: &str, body: LcnfExpr) -> LcnfFunDecl {
999        LcnfFunDecl {
1000            name: name.to_string(),
1001            original_name: None,
1002            params: vec![mk_param(0, "a")],
1003            ret_type: LcnfType::Nat,
1004            body,
1005            is_recursive: false,
1006            is_lifted: false,
1007            inline_cost: 1,
1008        }
1009    }
1010    pub(super) fn mk_module(decls: Vec<LcnfFunDecl>) -> LcnfModule {
1011        LcnfModule {
1012            fun_decls: decls,
1013            extern_decls: vec![],
1014            name: "test_mod".to_string(),
1015            metadata: LcnfModuleMetadata::default(),
1016        }
1017    }
1018    #[test]
1019    pub(super) fn test_config_default() {
1020        let cfg = DceConfig::default();
1021        assert!(cfg.eliminate_unused_lets);
1022        assert!(cfg.eliminate_unreachable_alts);
1023        assert!(cfg.propagate_constants);
1024        assert!(cfg.propagate_copies);
1025        assert!(cfg.fold_known_calls);
1026        assert_eq!(cfg.max_iterations, 10);
1027    }
1028    #[test]
1029    pub(super) fn test_config_display() {
1030        let cfg = DceConfig::default();
1031        let s = cfg.to_string();
1032        assert!(s.contains("unused_lets=true"));
1033        assert!(s.contains("max_iter=10"));
1034    }
1035    #[test]
1036    pub(super) fn test_stats_default() {
1037        let stats = DceStats::default();
1038        assert_eq!(stats.total_changes(), 0);
1039    }
1040    #[test]
1041    pub(super) fn test_stats_merge() {
1042        let mut a = DceStats {
1043            lets_eliminated: 3,
1044            ..Default::default()
1045        };
1046        let b = DceStats {
1047            lets_eliminated: 2,
1048            constants_propagated: 1,
1049            ..Default::default()
1050        };
1051        a.merge(&b);
1052        assert_eq!(a.lets_eliminated, 5);
1053        assert_eq!(a.constants_propagated, 1);
1054    }
1055    #[test]
1056    pub(super) fn test_stats_display() {
1057        let stats = DceStats {
1058            lets_eliminated: 7,
1059            ..Default::default()
1060        };
1061        let s = stats.to_string();
1062        assert!(s.contains("lets_elim=7"));
1063    }
1064    #[test]
1065    pub(super) fn test_const_value_lit() {
1066        let cv = ConstValue::Lit(LcnfLit::Nat(42));
1067        assert!(cv.is_known());
1068        assert_eq!(cv.as_lit(), Some(&LcnfLit::Nat(42)));
1069        assert!(cv.as_ctor().is_none());
1070    }
1071    #[test]
1072    pub(super) fn test_const_value_ctor() {
1073        let cv = ConstValue::Ctor("Nil".to_string(), 0, vec![]);
1074        assert!(cv.is_known());
1075        assert!(cv.as_lit().is_none());
1076        let (name, tag, args) = cv.as_ctor().expect("expected Some/Ok value");
1077        assert_eq!(name, "Nil");
1078        assert_eq!(tag, 0);
1079        assert!(args.is_empty());
1080    }
1081    #[test]
1082    pub(super) fn test_const_value_unknown() {
1083        let cv = ConstValue::Unknown;
1084        assert!(!cv.is_known());
1085    }
1086    #[test]
1087    pub(super) fn test_const_value_display() {
1088        assert!(ConstValue::Unknown.to_string().contains("unknown"));
1089        let lit = ConstValue::Lit(LcnfLit::Nat(99));
1090        assert!(lit.to_string().contains("99"));
1091    }
1092    #[test]
1093    pub(super) fn test_usage_info_basic() {
1094        let mut u = UsageInfo::new();
1095        assert!(u.is_dead());
1096        assert!(!u.is_once());
1097        u.add_use();
1098        assert!(!u.is_dead());
1099        assert!(u.is_once());
1100        u.add_use();
1101        assert!(!u.is_once());
1102        assert_eq!(u.use_count, 2);
1103    }
1104    #[test]
1105    pub(super) fn test_usage_info_flags() {
1106        let mut u = UsageInfo::new();
1107        assert!(!u.is_escaping);
1108        assert!(!u.is_in_loop);
1109        u.mark_escaping();
1110        assert!(u.is_escaping);
1111        u.mark_in_loop();
1112        assert!(u.is_in_loop);
1113    }
1114    #[test]
1115    pub(super) fn test_usage_info_display() {
1116        let u = UsageInfo {
1117            use_count: 3,
1118            is_escaping: true,
1119            is_in_loop: false,
1120        };
1121        let s = u.to_string();
1122        assert!(s.contains("uses=3"));
1123        assert!(s.contains("escaping=true"));
1124    }
1125    #[test]
1126    pub(super) fn test_collect_usage_simple_return() {
1127        let expr = LcnfExpr::Return(LcnfArg::Var(vid(5)));
1128        let info = collect_usage_info(&expr);
1129        assert_eq!(
1130            info.get(&vid(5))
1131                .expect("value should be present in map")
1132                .use_count,
1133            1
1134        );
1135    }
1136    #[test]
1137    pub(super) fn test_collect_usage_let_chain() {
1138        let expr = mk_let(
1139            1,
1140            LcnfLetValue::Lit(LcnfLit::Nat(42)),
1141            mk_let(
1142                2,
1143                LcnfLetValue::FVar(vid(1)),
1144                LcnfExpr::Return(LcnfArg::Var(vid(2))),
1145            ),
1146        );
1147        let info = collect_usage_info(&expr);
1148        assert_eq!(
1149            info.get(&vid(1))
1150                .expect("value should be present in map")
1151                .use_count,
1152            1
1153        );
1154        assert_eq!(
1155            info.get(&vid(2))
1156                .expect("value should be present in map")
1157                .use_count,
1158            1
1159        );
1160    }
1161    #[test]
1162    pub(super) fn test_collect_usage_ctor_escaping() {
1163        let expr = mk_let(
1164            1,
1165            LcnfLetValue::Ctor("Cons".into(), 1, vec![LcnfArg::Var(vid(0))]),
1166            LcnfExpr::Return(LcnfArg::Var(vid(1))),
1167        );
1168        let info = collect_usage_info(&expr);
1169        assert!(
1170            info.get(&vid(0))
1171                .expect("value should be present")
1172                .is_escaping
1173        );
1174    }
1175    #[test]
1176    pub(super) fn test_collect_usage_tail_call() {
1177        let expr = LcnfExpr::TailCall(
1178            LcnfArg::Var(vid(10)),
1179            vec![LcnfArg::Var(vid(0)), LcnfArg::Var(vid(1))],
1180        );
1181        let info = collect_usage_info(&expr);
1182        assert_eq!(
1183            info.get(&vid(10))
1184                .expect("value should be present in map")
1185                .use_count,
1186            1
1187        );
1188        assert_eq!(
1189            info.get(&vid(0))
1190                .expect("value should be present in map")
1191                .use_count,
1192            1
1193        );
1194        assert_eq!(
1195            info.get(&vid(1))
1196                .expect("value should be present in map")
1197                .use_count,
1198            1
1199        );
1200    }
1201    #[test]
1202    pub(super) fn test_eliminate_dead_lets_simple() {
1203        let expr = mk_let(
1204            1,
1205            LcnfLetValue::Lit(LcnfLit::Nat(42)),
1206            LcnfExpr::Return(LcnfArg::Var(vid(0))),
1207        );
1208        let result = eliminate_dead_lets(&expr);
1209        assert!(matches!(result, LcnfExpr::Return(LcnfArg::Var(v)) if v == vid(0)));
1210    }
1211    #[test]
1212    pub(super) fn test_eliminate_dead_lets_keeps_used() {
1213        let expr = mk_let(
1214            1,
1215            LcnfLetValue::Lit(LcnfLit::Nat(42)),
1216            LcnfExpr::Return(LcnfArg::Var(vid(1))),
1217        );
1218        let result = eliminate_dead_lets(&expr);
1219        assert!(matches!(result, LcnfExpr::Let { id, .. } if id == vid(1)));
1220    }
1221    #[test]
1222    pub(super) fn test_eliminate_dead_lets_chain() {
1223        let expr = mk_let(
1224            1,
1225            LcnfLetValue::Lit(LcnfLit::Nat(1)),
1226            mk_let(
1227                2,
1228                LcnfLetValue::Lit(LcnfLit::Nat(2)),
1229                mk_let(
1230                    3,
1231                    LcnfLetValue::Lit(LcnfLit::Nat(3)),
1232                    LcnfExpr::Return(LcnfArg::Var(vid(3))),
1233                ),
1234            ),
1235        );
1236        let result = eliminate_dead_lets(&expr);
1237        assert!(matches!(& result, LcnfExpr::Let { id, .. } if * id == vid(3)));
1238    }
1239    #[test]
1240    pub(super) fn test_eliminate_dead_lets_keeps_app() {
1241        let expr = mk_let(
1242            1,
1243            LcnfLetValue::App(LcnfArg::Var(vid(10)), vec![LcnfArg::Var(vid(0))]),
1244            LcnfExpr::Return(LcnfArg::Var(vid(0))),
1245        );
1246        let result = eliminate_dead_lets(&expr);
1247        assert!(matches!(result, LcnfExpr::Let { .. }));
1248    }
1249    #[test]
1250    pub(super) fn test_eliminate_dead_lets_in_case() {
1251        let expr = LcnfExpr::Case {
1252            scrutinee: vid(0),
1253            scrutinee_ty: LcnfType::Nat,
1254            alts: vec![
1255                LcnfAlt {
1256                    ctor_name: "A".into(),
1257                    ctor_tag: 0,
1258                    params: vec![],
1259                    body: mk_let(
1260                        5,
1261                        LcnfLetValue::Lit(LcnfLit::Nat(99)),
1262                        LcnfExpr::Return(LcnfArg::Var(vid(0))),
1263                    ),
1264                },
1265                LcnfAlt {
1266                    ctor_name: "B".into(),
1267                    ctor_tag: 1,
1268                    params: vec![],
1269                    body: LcnfExpr::Return(LcnfArg::Var(vid(0))),
1270                },
1271            ],
1272            default: None,
1273        };
1274        let result = eliminate_dead_lets(&expr);
1275        if let LcnfExpr::Case { alts, .. } = &result {
1276            assert!(matches!(&alts[0].body, LcnfExpr::Return(_)));
1277        } else {
1278            panic!("expected Case");
1279        }
1280    }
1281    #[test]
1282    pub(super) fn test_propagate_constants_simple() {
1283        let expr = mk_let(
1284            1,
1285            LcnfLetValue::Lit(LcnfLit::Nat(42)),
1286            LcnfExpr::Return(LcnfArg::Var(vid(1))),
1287        );
1288        let result = propagate_constants(&expr);
1289        assert!(matches!(
1290            result,
1291            LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42)))
1292        ));
1293    }
1294    #[test]
1295    pub(super) fn test_propagate_constants_in_app() {
1296        let expr = mk_let(
1297            1,
1298            LcnfLetValue::Lit(LcnfLit::Nat(10)),
1299            mk_let(
1300                2,
1301                LcnfLetValue::App(LcnfArg::Var(vid(99)), vec![LcnfArg::Var(vid(1))]),
1302                LcnfExpr::Return(LcnfArg::Var(vid(2))),
1303            ),
1304        );
1305        let result = propagate_constants(&expr);
1306        if let LcnfExpr::Let { value, .. } = &result {
1307            if let LcnfLetValue::App(_, args) = value {
1308                assert!(matches!(args[0], LcnfArg::Lit(LcnfLit::Nat(10))));
1309            } else {
1310                panic!("expected App");
1311            }
1312        } else {
1313            panic!("expected Let");
1314        }
1315    }
1316    #[test]
1317    pub(super) fn test_propagate_constants_in_tail_call() {
1318        let expr = mk_let(
1319            1,
1320            LcnfLetValue::Lit(LcnfLit::Nat(5)),
1321            LcnfExpr::TailCall(LcnfArg::Var(vid(99)), vec![LcnfArg::Var(vid(1))]),
1322        );
1323        let result = propagate_constants(&expr);
1324        if let LcnfExpr::TailCall(_, args) = &result {
1325            assert!(matches!(args[0], LcnfArg::Lit(LcnfLit::Nat(5))));
1326        } else {
1327            panic!("expected TailCall");
1328        }
1329    }
1330    #[test]
1331    pub(super) fn test_propagate_copies_simple() {
1332        let expr = mk_let(
1333            2,
1334            LcnfLetValue::FVar(vid(1)),
1335            LcnfExpr::Return(LcnfArg::Var(vid(2))),
1336        );
1337        let result = propagate_copies(&expr);
1338        assert!(matches!(result, LcnfExpr::Return(LcnfArg::Var(v)) if v == vid(1)));
1339    }
1340    #[test]
1341    pub(super) fn test_propagate_copies_transitive() {
1342        let expr = mk_let(
1343            2,
1344            LcnfLetValue::FVar(vid(1)),
1345            mk_let(
1346                3,
1347                LcnfLetValue::FVar(vid(2)),
1348                LcnfExpr::Return(LcnfArg::Var(vid(3))),
1349            ),
1350        );
1351        let result = propagate_copies(&expr);
1352        assert!(matches!(result, LcnfExpr::Return(LcnfArg::Var(v)) if v == vid(1)));
1353    }
1354    #[test]
1355    pub(super) fn test_propagate_copies_in_case_scrutinee() {
1356        let expr = mk_let(
1357            2,
1358            LcnfLetValue::FVar(vid(1)),
1359            LcnfExpr::Case {
1360                scrutinee: vid(2),
1361                scrutinee_ty: LcnfType::Nat,
1362                alts: vec![LcnfAlt {
1363                    ctor_name: "Zero".into(),
1364                    ctor_tag: 0,
1365                    params: vec![],
1366                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
1367                }],
1368                default: None,
1369            },
1370        );
1371        let result = propagate_copies(&expr);
1372        if let LcnfExpr::Case { scrutinee, .. } = &result {
1373            assert_eq!(*scrutinee, vid(1));
1374        } else {
1375            panic!("expected Case");
1376        }
1377    }
1378    #[test]
1379    pub(super) fn test_propagate_copies_in_value() {
1380        let expr = mk_let(
1381            2,
1382            LcnfLetValue::FVar(vid(1)),
1383            mk_let(
1384                3,
1385                LcnfLetValue::Ctor(
1386                    "Pair".into(),
1387                    0,
1388                    vec![LcnfArg::Var(vid(2)), LcnfArg::Var(vid(2))],
1389                ),
1390                LcnfExpr::Return(LcnfArg::Var(vid(3))),
1391            ),
1392        );
1393        let result = propagate_copies(&expr);
1394        if let LcnfExpr::Let { value, .. } = &result {
1395            if let LcnfLetValue::Ctor(_, _, args) = value {
1396                assert!(matches!(args[0], LcnfArg::Var(v) if v == vid(1)));
1397                assert!(matches!(args[1], LcnfArg::Var(v) if v == vid(1)));
1398            } else {
1399                panic!("expected Ctor");
1400            }
1401        } else {
1402            panic!("expected Let");
1403        }
1404    }
1405    #[test]
1406    pub(super) fn test_eliminate_unreachable_alts_removes_unreachable_default() {
1407        let expr = LcnfExpr::Case {
1408            scrutinee: vid(0),
1409            scrutinee_ty: LcnfType::Nat,
1410            alts: vec![LcnfAlt {
1411                ctor_name: "Zero".into(),
1412                ctor_tag: 0,
1413                params: vec![],
1414                body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
1415            }],
1416            default: Some(Box::new(LcnfExpr::Unreachable)),
1417        };
1418        let result = eliminate_unreachable_alts(&expr);
1419        if let LcnfExpr::Case { default, .. } = &result {
1420            assert!(default.is_none());
1421        } else {
1422            panic!("expected Case");
1423        }
1424    }
1425    #[test]
1426    pub(super) fn test_eliminate_unreachable_alts_removes_unreachable_alt() {
1427        let expr = LcnfExpr::Case {
1428            scrutinee: vid(0),
1429            scrutinee_ty: LcnfType::Nat,
1430            alts: vec![
1431                LcnfAlt {
1432                    ctor_name: "Zero".into(),
1433                    ctor_tag: 0,
1434                    params: vec![],
1435                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
1436                },
1437                LcnfAlt {
1438                    ctor_name: "Dead".into(),
1439                    ctor_tag: 99,
1440                    params: vec![],
1441                    body: LcnfExpr::Unreachable,
1442                },
1443            ],
1444            default: None,
1445        };
1446        let result = eliminate_unreachable_alts(&expr);
1447        if let LcnfExpr::Case { alts, .. } = &result {
1448            assert_eq!(alts.len(), 1);
1449            assert_eq!(alts[0].ctor_name, "Zero");
1450        } else {
1451            panic!("expected Case");
1452        }
1453    }
1454    #[test]
1455    pub(super) fn test_eliminate_unreachable_alts_inline_default_when_no_alts() {
1456        let expr = LcnfExpr::Case {
1457            scrutinee: vid(0),
1458            scrutinee_ty: LcnfType::Nat,
1459            alts: vec![LcnfAlt {
1460                ctor_name: "Dead".into(),
1461                ctor_tag: 0,
1462                params: vec![],
1463                body: LcnfExpr::Unreachable,
1464            }],
1465            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(77))))),
1466        };
1467        let result = eliminate_unreachable_alts(&expr);
1468        assert!(matches!(
1469            result,
1470            LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(77)))
1471        ));
1472    }
1473    #[test]
1474    pub(super) fn test_eliminate_unreachable_alts_all_dead() {
1475        let expr = LcnfExpr::Case {
1476            scrutinee: vid(0),
1477            scrutinee_ty: LcnfType::Nat,
1478            alts: vec![LcnfAlt {
1479                ctor_name: "X".into(),
1480                ctor_tag: 0,
1481                params: vec![],
1482                body: LcnfExpr::Unreachable,
1483            }],
1484            default: None,
1485        };
1486        let result = eliminate_unreachable_alts(&expr);
1487        assert!(matches!(result, LcnfExpr::Unreachable));
1488    }
1489    #[test]
1490    pub(super) fn test_fold_known_case_simple() {
1491        let expr = mk_let(
1492            1,
1493            LcnfLetValue::Ctor("Nil".into(), 0, vec![]),
1494            LcnfExpr::Case {
1495                scrutinee: vid(1),
1496                scrutinee_ty: LcnfType::Ctor("List".into(), vec![LcnfType::Nat]),
1497                alts: vec![
1498                    LcnfAlt {
1499                        ctor_name: "Nil".into(),
1500                        ctor_tag: 0,
1501                        params: vec![],
1502                        body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
1503                    },
1504                    LcnfAlt {
1505                        ctor_name: "Cons".into(),
1506                        ctor_tag: 1,
1507                        params: vec![mk_param(10, "h"), mk_param(11, "t")],
1508                        body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))),
1509                    },
1510                ],
1511                default: None,
1512            },
1513        );
1514        let result = fold_known_case(&expr);
1515        if let LcnfExpr::Let { body, .. } = &result {
1516            assert!(matches!(
1517                body.as_ref(),
1518                LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))
1519            ));
1520        } else {
1521            panic!("expected Let wrapping the folded result");
1522        }
1523    }
1524    #[test]
1525    pub(super) fn test_fold_known_case_with_params() {
1526        let expr = mk_let(
1527            1,
1528            LcnfLetValue::Ctor(
1529                "Cons".into(),
1530                1,
1531                vec![LcnfArg::Var(vid(10)), LcnfArg::Var(vid(11))],
1532            ),
1533            LcnfExpr::Case {
1534                scrutinee: vid(1),
1535                scrutinee_ty: LcnfType::Ctor("List".into(), vec![LcnfType::Nat]),
1536                alts: vec![
1537                    LcnfAlt {
1538                        ctor_name: "Nil".into(),
1539                        ctor_tag: 0,
1540                        params: vec![],
1541                        body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
1542                    },
1543                    LcnfAlt {
1544                        ctor_name: "Cons".into(),
1545                        ctor_tag: 1,
1546                        params: vec![mk_param(20, "h"), mk_param(21, "t")],
1547                        body: LcnfExpr::Return(LcnfArg::Var(vid(20))),
1548                    },
1549                ],
1550                default: None,
1551            },
1552        );
1553        let result = fold_known_case(&expr);
1554        if let LcnfExpr::Let { body, .. } = &result {
1555            assert!(
1556                matches!(body.as_ref(), LcnfExpr::Return(LcnfArg::Var(v)) if * v ==
1557                vid(10))
1558            );
1559        } else {
1560            panic!("expected Let wrapping folded result");
1561        }
1562    }
1563    #[test]
1564    pub(super) fn test_fold_known_case_falls_to_default() {
1565        let expr = mk_let(
1566            1,
1567            LcnfLetValue::Ctor("Nil".into(), 0, vec![]),
1568            LcnfExpr::Case {
1569                scrutinee: vid(1),
1570                scrutinee_ty: LcnfType::Nat,
1571                alts: vec![LcnfAlt {
1572                    ctor_name: "Cons".into(),
1573                    ctor_tag: 1,
1574                    params: vec![],
1575                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))),
1576                }],
1577                default: Some(Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(99))))),
1578            },
1579        );
1580        let result = fold_known_case(&expr);
1581        if let LcnfExpr::Let { body, .. } = &result {
1582            assert!(matches!(
1583                body.as_ref(),
1584                LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(99)))
1585            ));
1586        } else {
1587            panic!("expected Let wrapping default");
1588        }
1589    }
1590    #[test]
1591    pub(super) fn test_eliminate_dead_functions_keeps_roots() {
1592        let decls = vec![
1593            mk_decl("main", LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
1594            mk_decl("helper", LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1)))),
1595        ];
1596        let module = mk_module(decls);
1597        let result = eliminate_dead_functions(&module, &["main".to_string()]);
1598        assert_eq!(result.fun_decls.len(), 2);
1599    }
1600    #[test]
1601    pub(super) fn test_eliminate_dead_functions_removes_lifted_unreachable() {
1602        let mut lifted = mk_decl(
1603            "lifted_helper",
1604            LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))),
1605        );
1606        lifted.is_lifted = true;
1607        let decls = vec![
1608            mk_decl("main", LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
1609            lifted,
1610        ];
1611        let module = mk_module(decls);
1612        let result = eliminate_dead_functions(&module, &["main".to_string()]);
1613        assert_eq!(result.fun_decls.len(), 1);
1614        assert_eq!(result.fun_decls[0].name, "main");
1615    }
1616    #[test]
1617    pub(super) fn test_optimize_dce_full_pipeline() {
1618        let body = mk_let(
1619            1,
1620            LcnfLetValue::Lit(LcnfLit::Nat(42)),
1621            mk_let(
1622                2,
1623                LcnfLetValue::FVar(vid(1)),
1624                mk_let(
1625                    3,
1626                    LcnfLetValue::Lit(LcnfLit::Nat(99)),
1627                    LcnfExpr::Return(LcnfArg::Var(vid(2))),
1628                ),
1629            ),
1630        );
1631        let module = mk_module(vec![mk_decl("test", body)]);
1632        let config = DceConfig::default();
1633        let (result, stats) = optimize_dce(&module, &config);
1634        assert_eq!(result.fun_decls.len(), 1);
1635        let final_body = &result.fun_decls[0].body;
1636        assert!(
1637            matches!(final_body, LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42)))),
1638            "expected return 42, got: {:?}",
1639            final_body,
1640        );
1641        assert!(stats.lets_eliminated > 0);
1642    }
1643    #[test]
1644    pub(super) fn test_optimize_dce_no_passes() {
1645        let body = mk_let(
1646            1,
1647            LcnfLetValue::Lit(LcnfLit::Nat(42)),
1648            LcnfExpr::Return(LcnfArg::Var(vid(0))),
1649        );
1650        let module = mk_module(vec![mk_decl("test", body)]);
1651        let config = DceConfig {
1652            eliminate_unused_lets: false,
1653            eliminate_unreachable_alts: false,
1654            propagate_constants: false,
1655            propagate_copies: false,
1656            fold_known_calls: false,
1657            max_iterations: 10,
1658        };
1659        let (result, _stats) = optimize_dce(&module, &config);
1660        assert!(matches!(result.fun_decls[0].body, LcnfExpr::Let { .. }));
1661    }
1662    #[test]
1663    pub(super) fn test_count_let_bindings() {
1664        let expr = mk_let(
1665            1,
1666            LcnfLetValue::Lit(LcnfLit::Nat(1)),
1667            mk_let(
1668                2,
1669                LcnfLetValue::Lit(LcnfLit::Nat(2)),
1670                LcnfExpr::Return(LcnfArg::Var(vid(2))),
1671            ),
1672        );
1673        assert_eq!(count_let_bindings(&expr), 2);
1674    }
1675    #[test]
1676    pub(super) fn test_count_let_bindings_terminal() {
1677        let expr = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
1678        assert_eq!(count_let_bindings(&expr), 0);
1679    }
1680    #[test]
1681    pub(super) fn test_count_let_bindings_case() {
1682        let expr = LcnfExpr::Case {
1683            scrutinee: vid(0),
1684            scrutinee_ty: LcnfType::Nat,
1685            alts: vec![LcnfAlt {
1686                ctor_name: "A".into(),
1687                ctor_tag: 0,
1688                params: vec![],
1689                body: mk_let(
1690                    5,
1691                    LcnfLetValue::Lit(LcnfLit::Nat(5)),
1692                    LcnfExpr::Return(LcnfArg::Var(vid(5))),
1693                ),
1694            }],
1695            default: Some(Box::new(mk_let(
1696                6,
1697                LcnfLetValue::Lit(LcnfLit::Nat(6)),
1698                LcnfExpr::Return(LcnfArg::Var(vid(6))),
1699            ))),
1700        };
1701        assert_eq!(count_let_bindings(&expr), 2);
1702    }
1703    #[test]
1704    pub(super) fn test_resolve_copy_chain() {
1705        let mut env = HashMap::new();
1706        env.insert(vid(3), vid(2));
1707        env.insert(vid(2), vid(1));
1708        assert_eq!(resolve_copy(&env, vid(3)), vid(1));
1709    }
1710    #[test]
1711    pub(super) fn test_resolve_copy_cycle() {
1712        let mut env = HashMap::new();
1713        env.insert(vid(1), vid(2));
1714        env.insert(vid(2), vid(1));
1715        let _ = resolve_copy(&env, vid(1));
1716    }
1717    #[test]
1718    pub(super) fn test_resolve_copy_identity() {
1719        let env: HashMap<LcnfVarId, LcnfVarId> = HashMap::new();
1720        assert_eq!(resolve_copy(&env, vid(7)), vid(7));
1721    }
1722    #[test]
1723    pub(super) fn test_is_pure_let_value() {
1724        assert!(is_pure_let_value(&LcnfLetValue::Lit(LcnfLit::Nat(0))));
1725        assert!(is_pure_let_value(&LcnfLetValue::Erased));
1726        assert!(is_pure_let_value(&LcnfLetValue::FVar(vid(0))));
1727        assert!(is_pure_let_value(&LcnfLetValue::Proj(
1728            "S".into(),
1729            0,
1730            vid(0)
1731        )));
1732        assert!(is_pure_let_value(&LcnfLetValue::Ctor(
1733            "X".into(),
1734            0,
1735            vec![]
1736        )));
1737        assert!(!is_pure_let_value(&LcnfLetValue::App(
1738            LcnfArg::Var(vid(0)),
1739            vec![]
1740        )));
1741    }
1742    #[test]
1743    pub(super) fn test_count_refs_multiple_uses() {
1744        let expr = mk_let(
1745            1,
1746            LcnfLetValue::App(
1747                LcnfArg::Var(vid(99)),
1748                vec![LcnfArg::Var(vid(0)), LcnfArg::Var(vid(0))],
1749            ),
1750            LcnfExpr::Return(LcnfArg::Var(vid(1))),
1751        );
1752        let refs = count_refs(&expr);
1753        assert_eq!(refs.get(&vid(0)).copied().unwrap_or(0), 2);
1754        assert_eq!(refs.get(&vid(1)).copied().unwrap_or(0), 1);
1755        assert_eq!(refs.get(&vid(99)).copied().unwrap_or(0), 1);
1756    }
1757    #[test]
1758    pub(super) fn test_const_prop_then_dead_let() {
1759        let expr = mk_let(
1760            1,
1761            LcnfLetValue::Lit(LcnfLit::Nat(10)),
1762            mk_let(
1763                2,
1764                LcnfLetValue::App(LcnfArg::Var(vid(99)), vec![LcnfArg::Var(vid(1))]),
1765                LcnfExpr::Return(LcnfArg::Var(vid(2))),
1766            ),
1767        );
1768        let after_const = propagate_constants(&expr);
1769        let after_dce = eliminate_dead_lets(&after_const);
1770        if let LcnfExpr::Let { id, .. } = &after_dce {
1771            assert_eq!(*id, vid(2), "only x2 should remain");
1772        } else {
1773            panic!("expected a Let for x2");
1774        }
1775    }
1776    #[test]
1777    pub(super) fn test_copy_prop_then_dead_let() {
1778        let expr = mk_let(
1779            2,
1780            LcnfLetValue::FVar(vid(0)),
1781            mk_let(
1782                3,
1783                LcnfLetValue::App(LcnfArg::Var(vid(99)), vec![LcnfArg::Var(vid(2))]),
1784                LcnfExpr::Return(LcnfArg::Var(vid(3))),
1785            ),
1786        );
1787        let after_copy = propagate_copies(&expr);
1788        let after_dce = eliminate_dead_lets(&after_copy);
1789        if let LcnfExpr::Let { id, value, .. } = &after_dce {
1790            assert_eq!(*id, vid(3));
1791            if let LcnfLetValue::App(_, args) = value {
1792                assert!(matches!(args[0], LcnfArg::Var(v) if v == vid(0)));
1793            } else {
1794                panic!("expected App");
1795            }
1796        } else {
1797            panic!("expected Let for x3");
1798        }
1799    }
1800    #[test]
1801    pub(super) fn test_tail_call_not_affected_by_dce() {
1802        let expr = LcnfExpr::TailCall(
1803            LcnfArg::Var(vid(10)),
1804            vec![LcnfArg::Var(vid(0)), LcnfArg::Var(vid(1))],
1805        );
1806        let result = eliminate_dead_lets(&expr);
1807        assert!(matches!(result, LcnfExpr::TailCall(_, _)));
1808        let result2 = propagate_constants(&expr);
1809        assert!(matches!(result2, LcnfExpr::TailCall(_, _)));
1810    }
1811    #[test]
1812    pub(super) fn test_unreachable_preserved() {
1813        let expr = LcnfExpr::Unreachable;
1814        assert!(matches!(eliminate_dead_lets(&expr), LcnfExpr::Unreachable));
1815        assert!(matches!(propagate_constants(&expr), LcnfExpr::Unreachable));
1816        assert!(matches!(propagate_copies(&expr), LcnfExpr::Unreachable));
1817        assert!(matches!(fold_known_case(&expr), LcnfExpr::Unreachable));
1818        assert!(matches!(
1819            eliminate_unreachable_alts(&expr),
1820            LcnfExpr::Unreachable
1821        ));
1822    }
1823}
/// Unit tests for the DCE pass infrastructure types: pass configuration,
/// statistics, registry, analysis cache, worklist, dominator tree, liveness,
/// constant folding and the dependency graph.
///
/// Named in snake_case so rustc's `non_snake_case` lint does not fire.
#[cfg(test)]
mod dce_infra_tests {
    use super::*;

    /// A freshly built config is enabled and reports its phase correctly.
    #[test]
    pub(super) fn test_pass_config() {
        let config = DCEPassConfig::new("test_pass", DCEPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    /// Two recorded runs (10 and 20 changes) average 15 changes per run.
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = DCEPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }

    /// Disabled passes count toward the total but not toward the enabled count.
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = DCEPassRegistry::new();
        reg.register(DCEPassConfig::new("pass_a", DCEPassPhase::Analysis));
        reg.register(DCEPassConfig::new("pass_b", DCEPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    /// One hit plus one miss gives a 50% hit rate. Invalidation marks the entry
    /// invalid without removing it from the cache.
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = DCEAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }

    /// Pushing a duplicate is rejected, and popping yields items in insertion order.
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = DCEWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    /// Dominance is reflexive and follows the immediate-dominator chain.
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = DCEDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }

    /// Defs and uses are recorded per block index.
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = DCELivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    /// Folding helpers compute the expected values, and division by zero is not folded.
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(DCEConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(DCEConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(DCEConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            DCEConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(DCEConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    /// For an acyclic graph, the topological order places every node after its
    /// predecessors.
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = DCEDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Map each node to its position in the topological order.
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}
1924}