Skip to main content

oxilean_codegen/lcnf/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use std::collections::{HashMap, HashSet};
6
7use super::types::{
8    CostModel, DefinitionSite, FreeVarCollector, LcnfAlt, LcnfArg, LcnfExpr, LcnfFunDecl,
9    LcnfLetValue, LcnfModule, LcnfParam, LcnfType, LcnfVarId, PrettyConfig, Substitution,
10    UsageCounter, ValidationError,
11};
12
/// Table from a source-level name to the mangled name used for it in LCNF.
///
/// Keys are the original names and values are the mangled names, both as owned strings.
pub type NameMap = HashMap<String, String>;
15/// Immutable visitor trait for LCNF expressions.
16pub trait LcnfVisitor {
17    fn visit_expr(&mut self, expr: &LcnfExpr) {
18        walk_expr(self, expr);
19    }
20    fn visit_let_value(&mut self, val: &LcnfLetValue) {
21        walk_let_value(self, val);
22    }
23    fn visit_arg(&mut self, _arg: &LcnfArg) {}
24    fn visit_type(&mut self, _ty: &LcnfType) {}
25    fn visit_alt(&mut self, alt: &LcnfAlt) {
26        walk_alt(self, alt);
27    }
28    fn visit_fun_decl(&mut self, decl: &LcnfFunDecl) {
29        walk_fun_decl(self, decl);
30    }
31    fn visit_param(&mut self, _param: &LcnfParam) {}
32}
33/// Recursively walk children of an expression.
34pub fn walk_expr<V: LcnfVisitor + ?Sized>(visitor: &mut V, expr: &LcnfExpr) {
35    match expr {
36        LcnfExpr::Let {
37            ty, value, body, ..
38        } => {
39            visitor.visit_type(ty);
40            visitor.visit_let_value(value);
41            visitor.visit_expr(body);
42        }
43        LcnfExpr::Case {
44            scrutinee_ty,
45            alts,
46            default,
47            ..
48        } => {
49            visitor.visit_type(scrutinee_ty);
50            for alt in alts {
51                visitor.visit_alt(alt);
52            }
53            if let Some(def) = default {
54                visitor.visit_expr(def);
55            }
56        }
57        LcnfExpr::Return(arg) => visitor.visit_arg(arg),
58        LcnfExpr::Unreachable => {}
59        LcnfExpr::TailCall(func, args) => {
60            visitor.visit_arg(func);
61            for arg in args {
62                visitor.visit_arg(arg);
63            }
64        }
65    }
66}
67/// Recursively walk children of a let-bound value.
68pub fn walk_let_value<V: LcnfVisitor + ?Sized>(visitor: &mut V, val: &LcnfLetValue) {
69    match val {
70        LcnfLetValue::App(func, args) => {
71            visitor.visit_arg(func);
72            for arg in args {
73                visitor.visit_arg(arg);
74            }
75        }
76        LcnfLetValue::Proj(..) => {}
77        LcnfLetValue::Ctor(_, _, args) => {
78            for arg in args {
79                visitor.visit_arg(arg);
80            }
81        }
82        LcnfLetValue::Lit(_)
83        | LcnfLetValue::Erased
84        | LcnfLetValue::FVar(_)
85        | LcnfLetValue::Reset(_)
86        | LcnfLetValue::Reuse(_, _, _, _) => {}
87    }
88}
89/// Recursively walk children of a case alternative.
90pub fn walk_alt<V: LcnfVisitor + ?Sized>(visitor: &mut V, alt: &LcnfAlt) {
91    for param in &alt.params {
92        visitor.visit_param(param);
93    }
94    visitor.visit_expr(&alt.body);
95}
96/// Recursively walk children of a function declaration.
97pub fn walk_fun_decl<V: LcnfVisitor + ?Sized>(visitor: &mut V, decl: &LcnfFunDecl) {
98    for param in &decl.params {
99        visitor.visit_param(param);
100    }
101    visitor.visit_type(&decl.ret_type);
102    visitor.visit_expr(&decl.body);
103}
104/// Mutable visitor trait for in-place mutation of LCNF expressions.
105pub trait LcnfMutVisitor {
106    fn visit_expr_mut(&mut self, expr: &mut LcnfExpr) {
107        walk_expr_mut(self, expr);
108    }
109    fn visit_let_value_mut(&mut self, val: &mut LcnfLetValue) {
110        walk_let_value_mut(self, val);
111    }
112    fn visit_arg_mut(&mut self, _arg: &mut LcnfArg) {}
113    fn visit_type_mut(&mut self, _ty: &mut LcnfType) {}
114    fn visit_alt_mut(&mut self, alt: &mut LcnfAlt) {
115        walk_alt_mut(self, alt);
116    }
117    fn visit_fun_decl_mut(&mut self, decl: &mut LcnfFunDecl) {
118        walk_fun_decl_mut(self, decl);
119    }
120    fn visit_param_mut(&mut self, _param: &mut LcnfParam) {}
121}
122/// Walk children of an expression mutably.
123pub fn walk_expr_mut<V: LcnfMutVisitor + ?Sized>(visitor: &mut V, expr: &mut LcnfExpr) {
124    match expr {
125        LcnfExpr::Let {
126            ty, value, body, ..
127        } => {
128            visitor.visit_type_mut(ty);
129            visitor.visit_let_value_mut(value);
130            visitor.visit_expr_mut(body);
131        }
132        LcnfExpr::Case {
133            scrutinee_ty,
134            alts,
135            default,
136            ..
137        } => {
138            visitor.visit_type_mut(scrutinee_ty);
139            for alt in alts {
140                visitor.visit_alt_mut(alt);
141            }
142            if let Some(def) = default {
143                visitor.visit_expr_mut(def);
144            }
145        }
146        LcnfExpr::Return(arg) => visitor.visit_arg_mut(arg),
147        LcnfExpr::Unreachable => {}
148        LcnfExpr::TailCall(func, args) => {
149            visitor.visit_arg_mut(func);
150            for arg in args {
151                visitor.visit_arg_mut(arg);
152            }
153        }
154    }
155}
156/// Walk children of a let-bound value mutably.
157pub fn walk_let_value_mut<V: LcnfMutVisitor + ?Sized>(visitor: &mut V, val: &mut LcnfLetValue) {
158    match val {
159        LcnfLetValue::App(func, args) => {
160            visitor.visit_arg_mut(func);
161            for arg in args {
162                visitor.visit_arg_mut(arg);
163            }
164        }
165        LcnfLetValue::Proj(..) => {}
166        LcnfLetValue::Ctor(_, _, args) => {
167            for arg in args {
168                visitor.visit_arg_mut(arg);
169            }
170        }
171        LcnfLetValue::Lit(_)
172        | LcnfLetValue::Erased
173        | LcnfLetValue::FVar(_)
174        | LcnfLetValue::Reset(_)
175        | LcnfLetValue::Reuse(_, _, _, _) => {}
176    }
177}
178/// Walk children of a case alternative mutably.
179pub fn walk_alt_mut<V: LcnfMutVisitor + ?Sized>(visitor: &mut V, alt: &mut LcnfAlt) {
180    for param in &mut alt.params {
181        visitor.visit_param_mut(param);
182    }
183    visitor.visit_expr_mut(&mut alt.body);
184}
185/// Walk children of a function declaration mutably.
186pub fn walk_fun_decl_mut<V: LcnfMutVisitor + ?Sized>(visitor: &mut V, decl: &mut LcnfFunDecl) {
187    for param in &mut decl.params {
188        visitor.visit_param_mut(param);
189    }
190    visitor.visit_type_mut(&mut decl.ret_type);
191    visitor.visit_expr_mut(&mut decl.body);
192}
193/// Bottom-up transformation trait (folder).
194pub trait LcnfFolder {
195    fn fold_expr(&mut self, expr: LcnfExpr) -> LcnfExpr {
196        match expr {
197            LcnfExpr::Let {
198                id,
199                name,
200                ty,
201                value,
202                body,
203            } => {
204                let new_value = self.fold_let_value(value);
205                let new_body = self.fold_expr(*body);
206                LcnfExpr::Let {
207                    id,
208                    name,
209                    ty,
210                    value: new_value,
211                    body: Box::new(new_body),
212                }
213            }
214            LcnfExpr::Case {
215                scrutinee,
216                scrutinee_ty,
217                alts,
218                default,
219            } => {
220                let new_alts = alts
221                    .into_iter()
222                    .map(|alt| {
223                        let new_body = self.fold_expr(alt.body);
224                        LcnfAlt {
225                            ctor_name: alt.ctor_name,
226                            ctor_tag: alt.ctor_tag,
227                            params: alt.params,
228                            body: new_body,
229                        }
230                    })
231                    .collect();
232                let new_default = default.map(|d| Box::new(self.fold_expr(*d)));
233                LcnfExpr::Case {
234                    scrutinee,
235                    scrutinee_ty,
236                    alts: new_alts,
237                    default: new_default,
238                }
239            }
240            other => other,
241        }
242    }
243    fn fold_let_value(&mut self, val: LcnfLetValue) -> LcnfLetValue {
244        val
245    }
246}
247/// Collect all free variable IDs in an expression.
248pub fn free_vars(expr: &LcnfExpr) -> HashSet<LcnfVarId> {
249    let mut collector = FreeVarCollector::new();
250    collector.collect_expr(expr);
251    collector.free
252}
253/// Collect all let-bound variable IDs in an expression.
254pub fn bound_vars(expr: &LcnfExpr) -> HashSet<LcnfVarId> {
255    let mut result = HashSet::new();
256    collect_bound_vars(expr, &mut result);
257    result
258}
259pub(super) fn collect_bound_vars(expr: &LcnfExpr, result: &mut HashSet<LcnfVarId>) {
260    match expr {
261        LcnfExpr::Let { id, body, .. } => {
262            result.insert(*id);
263            collect_bound_vars(body, result);
264        }
265        LcnfExpr::Case { alts, default, .. } => {
266            for alt in alts {
267                for param in &alt.params {
268                    result.insert(param.id);
269                }
270                collect_bound_vars(&alt.body, result);
271            }
272            if let Some(def) = default {
273                collect_bound_vars(def, result);
274            }
275        }
276        LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => {}
277    }
278}
279/// Collect all variable IDs appearing in an expression (free + bound).
280pub fn all_vars(expr: &LcnfExpr) -> HashSet<LcnfVarId> {
281    let mut result = free_vars(expr);
282    result.extend(bound_vars(expr));
283    result
284}
285/// Count how many times each variable is referenced in an expression.
286pub fn usage_counts(expr: &LcnfExpr) -> HashMap<LcnfVarId, usize> {
287    let mut counter = UsageCounter::new();
288    counter.count_expr(expr);
289    counter.counts
290}
291/// Check whether all variables are used at most once.
292pub fn is_linear(expr: &LcnfExpr) -> bool {
293    usage_counts(expr).values().all(|&c| c <= 1)
294}
295/// Collect all definition sites in an expression.
296pub fn definition_sites(expr: &LcnfExpr) -> Vec<DefinitionSite> {
297    let mut sites = Vec::new();
298    collect_definition_sites(expr, 0, &mut sites);
299    sites
300}
301pub(super) fn collect_definition_sites(
302    expr: &LcnfExpr,
303    depth: usize,
304    sites: &mut Vec<DefinitionSite>,
305) {
306    match expr {
307        LcnfExpr::Let {
308            id, name, ty, body, ..
309        } => {
310            sites.push(DefinitionSite {
311                var: *id,
312                name: name.clone(),
313                ty: ty.clone(),
314                depth,
315            });
316            collect_definition_sites(body, depth + 1, sites);
317        }
318        LcnfExpr::Case { alts, default, .. } => {
319            for alt in alts {
320                for param in &alt.params {
321                    sites.push(DefinitionSite {
322                        var: param.id,
323                        name: param.name.clone(),
324                        ty: param.ty.clone(),
325                        depth: depth + 1,
326                    });
327                }
328                collect_definition_sites(&alt.body, depth + 1, sites);
329            }
330            if let Some(def) = default {
331                collect_definition_sites(def, depth + 1, sites);
332            }
333        }
334        LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => {}
335    }
336}
337/// Apply a substitution to an argument.
338pub fn substitute_arg(arg: &LcnfArg, subst: &Substitution) -> LcnfArg {
339    if let LcnfArg::Var(id) = arg {
340        if let Some(replacement) = subst.get(id) {
341            return replacement.clone();
342        }
343    }
344    arg.clone()
345}
346/// Apply a substitution to a let-bound value.
347pub fn substitute_let_value(val: &LcnfLetValue, subst: &Substitution) -> LcnfLetValue {
348    match val {
349        LcnfLetValue::App(func, args) => LcnfLetValue::App(
350            substitute_arg(func, subst),
351            args.iter().map(|a| substitute_arg(a, subst)).collect(),
352        ),
353        LcnfLetValue::Proj(name, idx, var) => {
354            if let Some(LcnfArg::Var(new_var)) = subst.get(var) {
355                LcnfLetValue::Proj(name.clone(), *idx, *new_var)
356            } else {
357                val.clone()
358            }
359        }
360        LcnfLetValue::Ctor(name, tag, args) => LcnfLetValue::Ctor(
361            name.clone(),
362            *tag,
363            args.iter().map(|a| substitute_arg(a, subst)).collect(),
364        ),
365        LcnfLetValue::FVar(id) => {
366            if let Some(LcnfArg::Var(new_id)) = subst.get(id) {
367                LcnfLetValue::FVar(*new_id)
368            } else {
369                val.clone()
370            }
371        }
372        LcnfLetValue::Lit(_)
373        | LcnfLetValue::Erased
374        | LcnfLetValue::Reset(_)
375        | LcnfLetValue::Reuse(_, _, _, _) => val.clone(),
376    }
377}
378/// Apply a substitution to an expression.
379pub fn substitute_expr(expr: &LcnfExpr, subst: &Substitution) -> LcnfExpr {
380    match expr {
381        LcnfExpr::Let {
382            id,
383            name,
384            ty,
385            value,
386            body,
387        } => {
388            let new_value = substitute_let_value(value, subst);
389            let mut inner_subst = subst.clone();
390            inner_subst.0.remove(id);
391            LcnfExpr::Let {
392                id: *id,
393                name: name.clone(),
394                ty: ty.clone(),
395                value: new_value,
396                body: Box::new(substitute_expr(body, &inner_subst)),
397            }
398        }
399        LcnfExpr::Case {
400            scrutinee,
401            scrutinee_ty,
402            alts,
403            default,
404        } => {
405            let new_scrutinee = if let Some(LcnfArg::Var(new_id)) = subst.get(scrutinee) {
406                *new_id
407            } else {
408                *scrutinee
409            };
410            let new_alts = alts
411                .iter()
412                .map(|alt| {
413                    let mut inner_subst = subst.clone();
414                    for param in &alt.params {
415                        inner_subst.0.remove(&param.id);
416                    }
417                    LcnfAlt {
418                        ctor_name: alt.ctor_name.clone(),
419                        ctor_tag: alt.ctor_tag,
420                        params: alt.params.clone(),
421                        body: substitute_expr(&alt.body, &inner_subst),
422                    }
423                })
424                .collect();
425            let new_default = default
426                .as_ref()
427                .map(|d| Box::new(substitute_expr(d, subst)));
428            LcnfExpr::Case {
429                scrutinee: new_scrutinee,
430                scrutinee_ty: scrutinee_ty.clone(),
431                alts: new_alts,
432                default: new_default,
433            }
434        }
435        LcnfExpr::Return(arg) => LcnfExpr::Return(substitute_arg(arg, subst)),
436        LcnfExpr::Unreachable => LcnfExpr::Unreachable,
437        LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(
438            substitute_arg(func, subst),
439            args.iter().map(|a| substitute_arg(a, subst)).collect(),
440        ),
441    }
442}
443/// Rename variables according to the given mapping.
444pub fn rename_vars(expr: &LcnfExpr, rename: &HashMap<LcnfVarId, LcnfVarId>) -> LcnfExpr {
445    let subst = Substitution(
446        rename
447            .iter()
448            .map(|(old, new)| (*old, LcnfArg::Var(*new)))
449            .collect(),
450    );
451    rename_expr_inner(expr, rename, &subst)
452}
453pub(super) fn rename_expr_inner(
454    expr: &LcnfExpr,
455    rename: &HashMap<LcnfVarId, LcnfVarId>,
456    subst: &Substitution,
457) -> LcnfExpr {
458    match expr {
459        LcnfExpr::Let {
460            id,
461            name,
462            ty,
463            value,
464            body,
465        } => {
466            let new_id = rename.get(id).copied().unwrap_or(*id);
467            LcnfExpr::Let {
468                id: new_id,
469                name: name.clone(),
470                ty: ty.clone(),
471                value: substitute_let_value(value, subst),
472                body: Box::new(rename_expr_inner(body, rename, subst)),
473            }
474        }
475        LcnfExpr::Case {
476            scrutinee,
477            scrutinee_ty,
478            alts,
479            default,
480        } => {
481            let new_scrutinee = rename.get(scrutinee).copied().unwrap_or(*scrutinee);
482            let new_alts = alts
483                .iter()
484                .map(|alt| {
485                    let new_params: Vec<LcnfParam> = alt
486                        .params
487                        .iter()
488                        .map(|p| LcnfParam {
489                            id: rename.get(&p.id).copied().unwrap_or(p.id),
490                            name: p.name.clone(),
491                            ty: p.ty.clone(),
492                            erased: p.erased,
493                            borrowed: false,
494                        })
495                        .collect();
496                    LcnfAlt {
497                        ctor_name: alt.ctor_name.clone(),
498                        ctor_tag: alt.ctor_tag,
499                        params: new_params,
500                        body: rename_expr_inner(&alt.body, rename, subst),
501                    }
502                })
503                .collect();
504            let new_default = default
505                .as_ref()
506                .map(|d| Box::new(rename_expr_inner(d, rename, subst)));
507            LcnfExpr::Case {
508                scrutinee: new_scrutinee,
509                scrutinee_ty: scrutinee_ty.clone(),
510                alts: new_alts,
511                default: new_default,
512            }
513        }
514        LcnfExpr::Return(arg) => LcnfExpr::Return(substitute_arg(arg, subst)),
515        LcnfExpr::Unreachable => LcnfExpr::Unreachable,
516        LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(
517            substitute_arg(func, subst),
518            args.iter().map(|a| substitute_arg(a, subst)).collect(),
519        ),
520    }
521}
522/// Check structural equality up to variable renaming (alpha-equivalence).
523pub fn alpha_equiv(e1: &LcnfExpr, e2: &LcnfExpr) -> bool {
524    let mut l2r: HashMap<LcnfVarId, LcnfVarId> = HashMap::new();
525    let mut r2l: HashMap<LcnfVarId, LcnfVarId> = HashMap::new();
526    alpha_equiv_expr(e1, e2, &mut l2r, &mut r2l)
527}
528pub(super) fn alpha_equiv_var(
529    v1: LcnfVarId,
530    v2: LcnfVarId,
531    l2r: &HashMap<LcnfVarId, LcnfVarId>,
532    r2l: &HashMap<LcnfVarId, LcnfVarId>,
533) -> bool {
534    match (l2r.get(&v1), r2l.get(&v2)) {
535        (Some(&mapped), Some(&mapped_back)) => mapped == v2 && mapped_back == v1,
536        (None, None) => v1 == v2,
537        _ => false,
538    }
539}
540pub(super) fn alpha_equiv_arg(
541    a1: &LcnfArg,
542    a2: &LcnfArg,
543    l2r: &HashMap<LcnfVarId, LcnfVarId>,
544    r2l: &HashMap<LcnfVarId, LcnfVarId>,
545) -> bool {
546    match (a1, a2) {
547        (LcnfArg::Var(v1), LcnfArg::Var(v2)) => alpha_equiv_var(*v1, *v2, l2r, r2l),
548        (LcnfArg::Lit(l1), LcnfArg::Lit(l2)) => l1 == l2,
549        (LcnfArg::Erased, LcnfArg::Erased) => true,
550        (LcnfArg::Type(t1), LcnfArg::Type(t2)) => t1 == t2,
551        _ => false,
552    }
553}
554pub(super) fn alpha_equiv_let_value(
555    v1: &LcnfLetValue,
556    v2: &LcnfLetValue,
557    l2r: &HashMap<LcnfVarId, LcnfVarId>,
558    r2l: &HashMap<LcnfVarId, LcnfVarId>,
559) -> bool {
560    match (v1, v2) {
561        (LcnfLetValue::App(f1, a1), LcnfLetValue::App(f2, a2)) => {
562            alpha_equiv_arg(f1, f2, l2r, r2l)
563                && a1.len() == a2.len()
564                && a1
565                    .iter()
566                    .zip(a2.iter())
567                    .all(|(x, y)| alpha_equiv_arg(x, y, l2r, r2l))
568        }
569        (LcnfLetValue::Proj(n1, i1, var1), LcnfLetValue::Proj(n2, i2, var2)) => {
570            n1 == n2 && i1 == i2 && alpha_equiv_var(*var1, *var2, l2r, r2l)
571        }
572        (LcnfLetValue::Ctor(n1, t1, a1), LcnfLetValue::Ctor(n2, t2, a2)) => {
573            n1 == n2
574                && t1 == t2
575                && a1.len() == a2.len()
576                && a1
577                    .iter()
578                    .zip(a2.iter())
579                    .all(|(x, y)| alpha_equiv_arg(x, y, l2r, r2l))
580        }
581        (LcnfLetValue::Lit(l1), LcnfLetValue::Lit(l2)) => l1 == l2,
582        (LcnfLetValue::Erased, LcnfLetValue::Erased) => true,
583        (LcnfLetValue::FVar(id1), LcnfLetValue::FVar(id2)) => alpha_equiv_var(*id1, *id2, l2r, r2l),
584        _ => false,
585    }
586}
587#[allow(clippy::too_many_arguments)]
588pub(super) fn alpha_equiv_expr(
589    e1: &LcnfExpr,
590    e2: &LcnfExpr,
591    l2r: &mut HashMap<LcnfVarId, LcnfVarId>,
592    r2l: &mut HashMap<LcnfVarId, LcnfVarId>,
593) -> bool {
594    match (e1, e2) {
595        (
596            LcnfExpr::Let {
597                id: id1,
598                ty: ty1,
599                value: val1,
600                body: body1,
601                ..
602            },
603            LcnfExpr::Let {
604                id: id2,
605                ty: ty2,
606                value: val2,
607                body: body2,
608                ..
609            },
610        ) => {
611            if ty1 != ty2 || !alpha_equiv_let_value(val1, val2, l2r, r2l) {
612                return false;
613            }
614            l2r.insert(*id1, *id2);
615            r2l.insert(*id2, *id1);
616            let result = alpha_equiv_expr(body1, body2, l2r, r2l);
617            l2r.remove(id1);
618            r2l.remove(id2);
619            result
620        }
621        (
622            LcnfExpr::Case {
623                scrutinee: s1,
624                scrutinee_ty: st1,
625                alts: alts1,
626                default: def1,
627            },
628            LcnfExpr::Case {
629                scrutinee: s2,
630                scrutinee_ty: st2,
631                alts: alts2,
632                default: def2,
633            },
634        ) => {
635            if !alpha_equiv_var(*s1, *s2, l2r, r2l) || st1 != st2 || alts1.len() != alts2.len() {
636                return false;
637            }
638            for (a1, a2) in alts1.iter().zip(alts2.iter()) {
639                if a1.ctor_name != a2.ctor_name
640                    || a1.ctor_tag != a2.ctor_tag
641                    || a1.params.len() != a2.params.len()
642                {
643                    return false;
644                }
645                for (p1, p2) in a1.params.iter().zip(a2.params.iter()) {
646                    l2r.insert(p1.id, p2.id);
647                    r2l.insert(p2.id, p1.id);
648                }
649                let ok = alpha_equiv_expr(&a1.body, &a2.body, l2r, r2l);
650                for (p1, p2) in a1.params.iter().zip(a2.params.iter()) {
651                    l2r.remove(&p1.id);
652                    r2l.remove(&p2.id);
653                }
654                if !ok {
655                    return false;
656                }
657            }
658            match (def1, def2) {
659                (Some(d1), Some(d2)) => alpha_equiv_expr(d1, d2, l2r, r2l),
660                (None, None) => true,
661                _ => false,
662            }
663        }
664        (LcnfExpr::Return(a1), LcnfExpr::Return(a2)) => alpha_equiv_arg(a1, a2, l2r, r2l),
665        (LcnfExpr::Unreachable, LcnfExpr::Unreachable) => true,
666        (LcnfExpr::TailCall(f1, a1), LcnfExpr::TailCall(f2, a2)) => {
667            alpha_equiv_arg(f1, f2, l2r, r2l)
668                && a1.len() == a2.len()
669                && a1
670                    .iter()
671                    .zip(a2.iter())
672                    .all(|(x, y)| alpha_equiv_arg(x, y, l2r, r2l))
673        }
674        _ => false,
675    }
676}
677/// Count the number of AST nodes in an expression.
678pub fn expr_size(expr: &LcnfExpr) -> usize {
679    match expr {
680        LcnfExpr::Let { value, body, .. } => 1 + let_value_size(value) + expr_size(body),
681        LcnfExpr::Case { alts, default, .. } => {
682            let alt_size: usize = alts.iter().map(|a| 1 + expr_size(&a.body)).sum();
683            let def_size = default.as_ref().map_or(0, |d| expr_size(d));
684            1 + alt_size + def_size
685        }
686        LcnfExpr::Return(_) => 1,
687        LcnfExpr::Unreachable => 1,
688        LcnfExpr::TailCall(_, args) => 1 + args.len(),
689    }
690}
691pub(super) fn let_value_size(val: &LcnfLetValue) -> usize {
692    match val {
693        LcnfLetValue::App(_, args) => 1 + args.len(),
694        LcnfLetValue::Proj(..) => 1,
695        LcnfLetValue::Ctor(_, _, args) => 1 + args.len(),
696        LcnfLetValue::Lit(_)
697        | LcnfLetValue::Erased
698        | LcnfLetValue::FVar(_)
699        | LcnfLetValue::Reset(_)
700        | LcnfLetValue::Reuse(_, _, _, _) => 1,
701    }
702}
703/// Compute the maximum nesting depth of an expression.
704pub fn expr_depth(expr: &LcnfExpr) -> usize {
705    match expr {
706        LcnfExpr::Let { body, .. } => 1 + expr_depth(body),
707        LcnfExpr::Case { alts, default, .. } => {
708            let max_alt = alts.iter().map(|a| expr_depth(&a.body)).max().unwrap_or(0);
709            let def_depth = default.as_ref().map_or(0, |d| expr_depth(d));
710            1 + max_alt.max(def_depth)
711        }
712        LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => 1,
713    }
714}
715/// Compute a heuristic inlining cost for a function declaration.
716pub fn compute_inline_cost(decl: &LcnfFunDecl) -> usize {
717    let base = expr_size(&decl.body);
718    let depth_penalty = expr_depth(&decl.body);
719    let branch_penalty = count_branches(&decl.body) * 5;
720    let recursive_penalty = if decl.is_recursive { 100 } else { 0 };
721    let param_bonus = if decl.params.len() <= 2 {
722        0
723    } else {
724        decl.params.len() * 2
725    };
726    base + depth_penalty + branch_penalty + recursive_penalty + param_bonus
727}
728/// Estimate the runtime cost of an expression under the given cost model.
729pub fn estimate_runtime_cost(expr: &LcnfExpr, model: &CostModel) -> u64 {
730    match expr {
731        LcnfExpr::Let { value, body, .. } => {
732            let val_cost = match value {
733                LcnfLetValue::App(..) | LcnfLetValue::Ctor(..) => model.app_cost,
734                LcnfLetValue::Proj(..) | LcnfLetValue::Lit(_) | LcnfLetValue::FVar(_) => {
735                    model.let_cost
736                }
737                LcnfLetValue::Erased | LcnfLetValue::Reset(_) | LcnfLetValue::Reuse(_, _, _, _) => {
738                    0
739                }
740            };
741            model.let_cost + val_cost + estimate_runtime_cost(body, model)
742        }
743        LcnfExpr::Case { alts, default, .. } => {
744            let max_alt_cost = alts
745                .iter()
746                .map(|a| estimate_runtime_cost(&a.body, model))
747                .max()
748                .unwrap_or(0);
749            let def_cost = default
750                .as_ref()
751                .map_or(0, |d| estimate_runtime_cost(d, model));
752            model.case_cost
753                + model.branch_penalty * (alts.len() as u64)
754                + max_alt_cost.max(def_cost)
755        }
756        LcnfExpr::Return(_) => model.return_cost,
757        LcnfExpr::Unreachable => 0,
758        LcnfExpr::TailCall(_, args) => model.app_cost + (args.len() as u64),
759    }
760}
761/// Estimate the number of heap allocations (from constructor applications).
762pub fn count_allocations(expr: &LcnfExpr) -> usize {
763    match expr {
764        LcnfExpr::Let { value, body, .. } => {
765            let alloc = match value {
766                LcnfLetValue::Ctor(_, _, args) if !args.is_empty() => 1,
767                _ => 0,
768            };
769            alloc + count_allocations(body)
770        }
771        LcnfExpr::Case { alts, default, .. } => {
772            let alt_allocs: usize = alts.iter().map(|a| count_allocations(&a.body)).sum();
773            let def_allocs = default.as_ref().map_or(0, |d| count_allocations(d));
774            alt_allocs + def_allocs
775        }
776        LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => 0,
777    }
778}
779/// Count the number of case splits in an expression.
780pub fn count_branches(expr: &LcnfExpr) -> usize {
781    match expr {
782        LcnfExpr::Let { body, .. } => count_branches(body),
783        LcnfExpr::Case { alts, default, .. } => {
784            let inner: usize = alts.iter().map(|a| count_branches(&a.body)).sum();
785            let def_branches = default.as_ref().map_or(0, |d| count_branches(d));
786            1 + inner + def_branches
787        }
788        LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => 0,
789    }
790}
791/// Validate an expression with respect to a set of bound variables.
792pub fn validate_expr(expr: &LcnfExpr, bound: &HashSet<LcnfVarId>) -> Result<(), ValidationError> {
793    match expr {
794        LcnfExpr::Let {
795            id, value, body, ..
796        } => {
797            validate_let_value(value, bound)?;
798            let mut new_bound = bound.clone();
799            if !new_bound.insert(*id) {
800                return Err(ValidationError::DuplicateBinding(*id));
801            }
802            validate_expr(body, &new_bound)
803        }
804        LcnfExpr::Case {
805            scrutinee,
806            alts,
807            default,
808            ..
809        } => {
810            if !bound.contains(scrutinee) {
811                return Err(ValidationError::UnboundVariable(*scrutinee));
812            }
813            if alts.is_empty() && default.is_none() {
814                return Err(ValidationError::EmptyCase);
815            }
816            for alt in alts {
817                let mut alt_bound = bound.clone();
818                for param in &alt.params {
819                    if !alt_bound.insert(param.id) {
820                        return Err(ValidationError::DuplicateBinding(param.id));
821                    }
822                }
823                validate_expr(&alt.body, &alt_bound)?;
824            }
825            if let Some(def) = default {
826                validate_expr(def, bound)?;
827            }
828            Ok(())
829        }
830        LcnfExpr::Return(arg) => validate_arg_bound(arg, bound),
831        LcnfExpr::Unreachable => Ok(()),
832        LcnfExpr::TailCall(func, args) => {
833            validate_arg_bound(func, bound)?;
834            for arg in args {
835                validate_arg_bound(arg, bound)?;
836            }
837            Ok(())
838        }
839    }
840}
841pub(super) fn validate_arg_bound(
842    arg: &LcnfArg,
843    bound: &HashSet<LcnfVarId>,
844) -> Result<(), ValidationError> {
845    if let LcnfArg::Var(id) = arg {
846        if !bound.contains(id) {
847            return Err(ValidationError::UnboundVariable(*id));
848        }
849    }
850    Ok(())
851}
852pub(super) fn validate_let_value(
853    val: &LcnfLetValue,
854    bound: &HashSet<LcnfVarId>,
855) -> Result<(), ValidationError> {
856    match val {
857        LcnfLetValue::App(func, args) => {
858            validate_arg_bound(func, bound)?;
859            for arg in args {
860                validate_arg_bound(arg, bound)?;
861            }
862            Ok(())
863        }
864        LcnfLetValue::Proj(_, _, var) => {
865            if !bound.contains(var) {
866                Err(ValidationError::UnboundVariable(*var))
867            } else {
868                Ok(())
869            }
870        }
871        LcnfLetValue::Ctor(_, _, args) => {
872            for arg in args {
873                validate_arg_bound(arg, bound)?;
874            }
875            Ok(())
876        }
877        LcnfLetValue::FVar(id) => {
878            if !bound.contains(id) {
879                Err(ValidationError::UnboundVariable(*id))
880            } else {
881                Ok(())
882            }
883        }
884        LcnfLetValue::Lit(_)
885        | LcnfLetValue::Erased
886        | LcnfLetValue::Reset(_)
887        | LcnfLetValue::Reuse(_, _, _, _) => Ok(()),
888    }
889}
890/// Validate a function declaration.
891pub fn validate_fun_decl(decl: &LcnfFunDecl) -> Result<(), ValidationError> {
892    let mut bound = HashSet::new();
893    for param in &decl.params {
894        if !bound.insert(param.id) {
895            return Err(ValidationError::DuplicateBinding(param.id));
896        }
897    }
898    validate_expr(&decl.body, &bound)
899}
900/// Validate an entire module, collecting all errors.
901pub fn validate_module(module: &LcnfModule) -> Result<(), Vec<ValidationError>> {
902    let mut errors = Vec::new();
903    for decl in &module.fun_decls {
904        if let Err(e) = validate_fun_decl(decl) {
905            errors.push(e);
906        }
907    }
908    if errors.is_empty() {
909        Ok(())
910    } else {
911        Err(errors)
912    }
913}
914/// Check that the ANF invariant holds: all arguments are atomic.
915pub fn check_anf_invariant(expr: &LcnfExpr) -> bool {
916    match expr {
917        LcnfExpr::Let { value, body, .. } => {
918            check_let_value_anf(value) && check_anf_invariant(body)
919        }
920        LcnfExpr::Case { alts, default, .. } => {
921            alts.iter().all(|a| check_anf_invariant(&a.body))
922                && default.as_ref().is_none_or(|d| check_anf_invariant(d))
923        }
924        LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => true,
925    }
926}
927pub(super) fn check_let_value_anf(val: &LcnfLetValue) -> bool {
928    match val {
929        LcnfLetValue::App(func, args) => is_atomic_arg(func) && args.iter().all(is_atomic_arg),
930        LcnfLetValue::Ctor(_, _, args) => args.iter().all(is_atomic_arg),
931        _ => true,
932    }
933}
934pub(super) fn is_atomic_arg(arg: &LcnfArg) -> bool {
935    matches!(
936        arg,
937        LcnfArg::Var(_) | LcnfArg::Lit(_) | LcnfArg::Erased | LcnfArg::Type(_)
938    )
939}
940/// Pretty-print an expression to a string.
941pub fn pretty_print_expr(expr: &LcnfExpr, config: &PrettyConfig) -> String {
942    let mut output = String::new();
943    pp_expr(&mut output, expr, config, 0);
944    output
945}
946pub(super) fn pp_indent(output: &mut String, config: &PrettyConfig, level: usize) {
947    for _ in 0..level * config.indent {
948        output.push(' ');
949    }
950}
951pub(super) fn pp_arg(output: &mut String, arg: &LcnfArg, config: &PrettyConfig) {
952    match arg {
953        LcnfArg::Var(id) => output.push_str(&id.to_string()),
954        LcnfArg::Lit(lit) => output.push_str(&lit.to_string()),
955        LcnfArg::Erased => {
956            if config.show_erased {
957                output.push('â—»');
958            } else {
959                output.push('_');
960            }
961        }
962        LcnfArg::Type(ty) => {
963            if config.show_types {
964                output.push('@');
965                output.push_str(&ty.to_string());
966            } else {
967                output.push('_');
968            }
969        }
970    }
971}
972pub(super) fn pp_let_value(output: &mut String, val: &LcnfLetValue, config: &PrettyConfig) {
973    match val {
974        LcnfLetValue::App(func, args) => {
975            pp_arg(output, func, config);
976            output.push('(');
977            for (i, a) in args.iter().enumerate() {
978                if i > 0 {
979                    output.push_str(", ");
980                }
981                pp_arg(output, a, config);
982            }
983            output.push(')');
984        }
985        LcnfLetValue::Proj(name, idx, var) => {
986            output.push_str(&format!("{}.{} {}", name, idx, var));
987        }
988        LcnfLetValue::Ctor(name, tag, args) => {
989            output.push_str(&format!("{}#{}", name, tag));
990            if !args.is_empty() {
991                output.push('(');
992                for (i, a) in args.iter().enumerate() {
993                    if i > 0 {
994                        output.push_str(", ");
995                    }
996                    pp_arg(output, a, config);
997                }
998                output.push(')');
999            }
1000        }
1001        LcnfLetValue::Lit(lit) => output.push_str(&lit.to_string()),
1002        LcnfLetValue::Erased => output.push_str("erased"),
1003        LcnfLetValue::FVar(id) => output.push_str(&id.to_string()),
1004        LcnfLetValue::Reset(var) => output.push_str(&format!("reset({})", var)),
1005        LcnfLetValue::Reuse(slot, name, tag, _) => {
1006            output.push_str(&format!("reuse({}, {}#{})", slot, name, tag))
1007        }
1008    }
1009}
1010pub(super) fn pp_expr(output: &mut String, expr: &LcnfExpr, config: &PrettyConfig, level: usize) {
1011    match expr {
1012        LcnfExpr::Let {
1013            id,
1014            name,
1015            ty,
1016            value,
1017            body,
1018        } => {
1019            pp_indent(output, config, level);
1020            output.push_str("let ");
1021            output.push_str(&id.to_string());
1022            if !name.is_empty() {
1023                output.push_str(&format!(" ({})", name));
1024            }
1025            if config.show_types {
1026                output.push_str(&format!(" : {}", ty));
1027            }
1028            output.push_str(" := ");
1029            pp_let_value(output, value, config);
1030            output.push('\n');
1031            pp_expr(output, body, config, level);
1032        }
1033        LcnfExpr::Case {
1034            scrutinee,
1035            scrutinee_ty,
1036            alts,
1037            default,
1038        } => {
1039            pp_indent(output, config, level);
1040            output.push_str(&format!("case {}", scrutinee));
1041            if config.show_types {
1042                output.push_str(&format!(" : {}", scrutinee_ty));
1043            }
1044            output.push_str(" of\n");
1045            for alt in alts {
1046                pp_indent(output, config, level + 1);
1047                output.push_str(&format!("| {}#{}", alt.ctor_name, alt.ctor_tag));
1048                for p in &alt.params {
1049                    output.push_str(&format!(" {}", p.id));
1050                }
1051                output.push_str(" =>\n");
1052                pp_expr(output, &alt.body, config, level + 2);
1053            }
1054            if let Some(def) = default {
1055                pp_indent(output, config, level + 1);
1056                output.push_str("| _ =>\n");
1057                pp_expr(output, def, config, level + 2);
1058            }
1059        }
1060        LcnfExpr::Return(arg) => {
1061            pp_indent(output, config, level);
1062            output.push_str("return ");
1063            pp_arg(output, arg, config);
1064            output.push('\n');
1065        }
1066        LcnfExpr::Unreachable => {
1067            pp_indent(output, config, level);
1068            output.push_str("unreachable\n");
1069        }
1070        LcnfExpr::TailCall(func, args) => {
1071            pp_indent(output, config, level);
1072            output.push_str("tailcall ");
1073            pp_arg(output, func, config);
1074            output.push('(');
1075            for (i, a) in args.iter().enumerate() {
1076                if i > 0 {
1077                    output.push_str(", ");
1078                }
1079                pp_arg(output, a, config);
1080            }
1081            output.push_str(")\n");
1082        }
1083    }
1084}
1085/// Pretty-print a function declaration.
1086pub fn pretty_print_fun_decl(decl: &LcnfFunDecl, config: &PrettyConfig) -> String {
1087    let mut output = String::new();
1088    output.push_str("def ");
1089    output.push_str(&decl.name);
1090    output.push('(');
1091    for (i, param) in decl.params.iter().enumerate() {
1092        if i > 0 {
1093            output.push_str(", ");
1094        }
1095        output.push_str(&format!("{}", param.id));
1096        if !param.name.is_empty() {
1097            output.push_str(&format!(" ({})", param.name));
1098        }
1099        if config.show_types {
1100            output.push_str(&format!(" : {}", param.ty));
1101        }
1102    }
1103    output.push(')');
1104    if config.show_types {
1105        output.push_str(&format!(" : {}", decl.ret_type));
1106    }
1107    if decl.is_recursive {
1108        output.push_str(" [rec]");
1109    }
1110    if decl.is_lifted {
1111        output.push_str(" [lifted]");
1112    }
1113    output.push_str(" :=\n");
1114    pp_expr(&mut output, &decl.body, config, 1);
1115    output
1116}
1117/// Pretty-print an entire module.
1118pub fn pretty_print_module(module: &LcnfModule, config: &PrettyConfig) -> String {
1119    let mut output = String::new();
1120    output.push_str(&format!("-- module {}\n", module.name));
1121    output.push_str(&format!(
1122        "-- {} decls, {} externs\n\n",
1123        module.fun_decls.len(),
1124        module.extern_decls.len()
1125    ));
1126    for decl in &module.extern_decls {
1127        output.push_str("extern ");
1128        output.push_str(&decl.name);
1129        output.push('(');
1130        for (i, param) in decl.params.iter().enumerate() {
1131            if i > 0 {
1132                output.push_str(", ");
1133            }
1134            if config.show_types {
1135                output.push_str(&format!("{} : {}", param.id, param.ty));
1136            } else {
1137                output.push_str(&format!("{}", param.id));
1138            }
1139        }
1140        output.push(')');
1141        if config.show_types {
1142            output.push_str(&format!(" : {}", decl.ret_type));
1143        }
1144        output.push('\n');
1145    }
1146    if !module.extern_decls.is_empty() {
1147        output.push('\n');
1148    }
1149    for decl in &module.fun_decls {
1150        output.push_str(&pretty_print_fun_decl(decl, config));
1151        output.push('\n');
1152    }
1153    output
1154}
1155/// Inline a single let binding by substituting its value into all uses.
1156pub fn inline_let(expr: LcnfExpr, var: LcnfVarId) -> LcnfExpr {
1157    match expr {
1158        LcnfExpr::Let {
1159            id,
1160            name,
1161            ty,
1162            value,
1163            body,
1164        } if id == var => {
1165            if let Some(arg) = let_value_to_arg(&value) {
1166                let mut subst = Substitution::new();
1167                subst.insert(id, arg);
1168                substitute_expr(&body, &subst)
1169            } else {
1170                LcnfExpr::Let {
1171                    id,
1172                    name,
1173                    ty,
1174                    value,
1175                    body,
1176                }
1177            }
1178        }
1179        LcnfExpr::Let {
1180            id,
1181            name,
1182            ty,
1183            value,
1184            body,
1185        } => LcnfExpr::Let {
1186            id,
1187            name,
1188            ty,
1189            value,
1190            body: Box::new(inline_let(*body, var)),
1191        },
1192        LcnfExpr::Case {
1193            scrutinee,
1194            scrutinee_ty,
1195            alts,
1196            default,
1197        } => LcnfExpr::Case {
1198            scrutinee,
1199            scrutinee_ty,
1200            alts: alts
1201                .into_iter()
1202                .map(|a| LcnfAlt {
1203                    ctor_name: a.ctor_name,
1204                    ctor_tag: a.ctor_tag,
1205                    params: a.params,
1206                    body: inline_let(a.body, var),
1207                })
1208                .collect(),
1209            default: default.map(|d| Box::new(inline_let(*d, var))),
1210        },
1211        other => other,
1212    }
1213}
1214/// Try to convert a let value to an atomic argument for inlining.
1215pub(super) fn let_value_to_arg(val: &LcnfLetValue) -> Option<LcnfArg> {
1216    match val {
1217        LcnfLetValue::Lit(lit) => Some(LcnfArg::Lit(lit.clone())),
1218        LcnfLetValue::Erased => Some(LcnfArg::Erased),
1219        LcnfLetValue::FVar(id) => Some(LcnfArg::Var(*id)),
1220        _ => None,
1221    }
1222}
1223/// Flatten nested let chains.
1224pub fn flatten_lets(expr: LcnfExpr) -> LcnfExpr {
1225    let mut bindings: Vec<(LcnfVarId, String, LcnfType, LcnfLetValue)> = Vec::new();
1226    let terminal = collect_lets(expr, &mut bindings);
1227    let mut result = flatten_lets_in_terminal(terminal);
1228    for (id, name, ty, value) in bindings.into_iter().rev() {
1229        result = LcnfExpr::Let {
1230            id,
1231            name,
1232            ty,
1233            value,
1234            body: Box::new(result),
1235        };
1236    }
1237    result
1238}
1239pub(super) fn collect_lets(
1240    expr: LcnfExpr,
1241    bindings: &mut Vec<(LcnfVarId, String, LcnfType, LcnfLetValue)>,
1242) -> LcnfExpr {
1243    match expr {
1244        LcnfExpr::Let {
1245            id,
1246            name,
1247            ty,
1248            value,
1249            body,
1250        } => {
1251            bindings.push((id, name, ty, value));
1252            collect_lets(*body, bindings)
1253        }
1254        other => other,
1255    }
1256}
1257pub(super) fn flatten_lets_in_terminal(expr: LcnfExpr) -> LcnfExpr {
1258    match expr {
1259        LcnfExpr::Case {
1260            scrutinee,
1261            scrutinee_ty,
1262            alts,
1263            default,
1264        } => LcnfExpr::Case {
1265            scrutinee,
1266            scrutinee_ty,
1267            alts: alts
1268                .into_iter()
1269                .map(|a| LcnfAlt {
1270                    ctor_name: a.ctor_name,
1271                    ctor_tag: a.ctor_tag,
1272                    params: a.params,
1273                    body: flatten_lets(a.body),
1274                })
1275                .collect(),
1276            default: default.map(|d| Box::new(flatten_lets(*d))),
1277        },
1278        other => other,
1279    }
1280}
1281/// Simplify a case with a single alternative into a let chain.
1282pub fn simplify_trivial_case(expr: LcnfExpr) -> LcnfExpr {
1283    match expr {
1284        LcnfExpr::Case {
1285            scrutinee,
1286            alts,
1287            default: None,
1288            ..
1289        } if alts.len() == 1 => {
1290            let alt = alts.into_iter().next().expect(
1291                "alts has exactly one element; guaranteed by pattern guard alts.len() == 1",
1292            );
1293            let mut result = simplify_trivial_case(alt.body);
1294            for (idx, param) in alt.params.iter().enumerate().rev() {
1295                result = LcnfExpr::Let {
1296                    id: param.id,
1297                    name: param.name.clone(),
1298                    ty: param.ty.clone(),
1299                    value: LcnfLetValue::Proj(alt.ctor_name.clone(), idx as u32, scrutinee),
1300                    body: Box::new(result),
1301                };
1302            }
1303            result
1304        }
1305        LcnfExpr::Let {
1306            id,
1307            name,
1308            ty,
1309            value,
1310            body,
1311        } => LcnfExpr::Let {
1312            id,
1313            name,
1314            ty,
1315            value,
1316            body: Box::new(simplify_trivial_case(*body)),
1317        },
1318        LcnfExpr::Case {
1319            scrutinee,
1320            scrutinee_ty,
1321            alts,
1322            default,
1323        } => LcnfExpr::Case {
1324            scrutinee,
1325            scrutinee_ty,
1326            alts: alts
1327                .into_iter()
1328                .map(|a| LcnfAlt {
1329                    ctor_name: a.ctor_name,
1330                    ctor_tag: a.ctor_tag,
1331                    params: a.params,
1332                    body: simplify_trivial_case(a.body),
1333                })
1334                .collect(),
1335            default: default.map(|d| Box::new(simplify_trivial_case(*d))),
1336        },
1337        other => other,
1338    }
1339}
1340/// Remove unused let bindings (dead code elimination).
1341pub fn remove_unused_lets(expr: LcnfExpr) -> LcnfExpr {
1342    match expr {
1343        LcnfExpr::Let {
1344            id,
1345            name,
1346            ty,
1347            value,
1348            body,
1349        } => {
1350            let new_body = remove_unused_lets(*body);
1351            let counts = usage_counts(&new_body);
1352            if counts.get(&id).copied().unwrap_or(0) == 0 {
1353                new_body
1354            } else {
1355                LcnfExpr::Let {
1356                    id,
1357                    name,
1358                    ty,
1359                    value,
1360                    body: Box::new(new_body),
1361                }
1362            }
1363        }
1364        LcnfExpr::Case {
1365            scrutinee,
1366            scrutinee_ty,
1367            alts,
1368            default,
1369        } => LcnfExpr::Case {
1370            scrutinee,
1371            scrutinee_ty,
1372            alts: alts
1373                .into_iter()
1374                .map(|a| LcnfAlt {
1375                    ctor_name: a.ctor_name,
1376                    ctor_tag: a.ctor_tag,
1377                    params: a.params,
1378                    body: remove_unused_lets(a.body),
1379                })
1380                .collect(),
1381            default: default.map(|d| Box::new(remove_unused_lets(*d))),
1382        },
1383        other => other,
1384    }
1385}
1386/// Hoist let bindings out of case branches when the same binding
1387/// appears in all branches.
1388pub fn hoist_lets(expr: LcnfExpr) -> LcnfExpr {
1389    match expr {
1390        LcnfExpr::Let {
1391            id,
1392            name,
1393            ty,
1394            value,
1395            body,
1396        } => LcnfExpr::Let {
1397            id,
1398            name,
1399            ty,
1400            value,
1401            body: Box::new(hoist_lets(*body)),
1402        },
1403        LcnfExpr::Case {
1404            scrutinee,
1405            scrutinee_ty,
1406            alts,
1407            default,
1408        } => {
1409            let alts: Vec<LcnfAlt> = alts
1410                .into_iter()
1411                .map(|a| LcnfAlt {
1412                    ctor_name: a.ctor_name,
1413                    ctor_tag: a.ctor_tag,
1414                    params: a.params,
1415                    body: hoist_lets(a.body),
1416                })
1417                .collect();
1418            let default = default.map(|d| Box::new(hoist_lets(*d)));
1419            if alts.len() < 2 || default.is_some() {
1420                return LcnfExpr::Case {
1421                    scrutinee,
1422                    scrutinee_ty,
1423                    alts,
1424                    default,
1425                };
1426            }
1427            let first_let = match &alts[0].body {
1428                LcnfExpr::Let {
1429                    name, ty, value, ..
1430                } => Some((name.clone(), ty.clone(), value.clone())),
1431                _ => None,
1432            };
1433            if let Some((common_name, common_ty, common_value)) = first_let {
1434                let all_same = alts.iter().all(|a| {
1435                    matches!(
1436                        & a.body, LcnfExpr::Let { name, ty, value, .. } if * name ==
1437                        common_name && * ty == common_ty && * value == common_value
1438                    )
1439                });
1440                if all_same {
1441                    let hoisted_id = match &alts[0].body {
1442                        LcnfExpr::Let { id, .. } => *id,
1443                        _ => unreachable!(),
1444                    };
1445                    let new_alts: Vec<LcnfAlt> = alts
1446                        .into_iter()
1447                        .map(|a| {
1448                            let inner_body = match a.body {
1449                                LcnfExpr::Let { id, body, .. } => {
1450                                    if id != hoisted_id {
1451                                        let mut subst = Substitution::new();
1452                                        subst.insert(id, LcnfArg::Var(hoisted_id));
1453                                        substitute_expr(&body, &subst)
1454                                    } else {
1455                                        *body
1456                                    }
1457                                }
1458                                other => other,
1459                            };
1460                            LcnfAlt {
1461                                ctor_name: a.ctor_name,
1462                                ctor_tag: a.ctor_tag,
1463                                params: a.params,
1464                                body: inner_body,
1465                            }
1466                        })
1467                        .collect();
1468                    return LcnfExpr::Let {
1469                        id: hoisted_id,
1470                        name: common_name,
1471                        ty: common_ty,
1472                        value: common_value,
1473                        body: Box::new(LcnfExpr::Case {
1474                            scrutinee,
1475                            scrutinee_ty,
1476                            alts: new_alts,
1477                            default: None,
1478                        }),
1479                    };
1480                }
1481            }
1482            LcnfExpr::Case {
1483                scrutinee,
1484                scrutinee_ty,
1485                alts,
1486                default,
1487            }
1488        }
1489        other => other,
1490    }
1491}