1use std::collections::{HashMap, HashSet};
6
7use super::types::{
8 CostModel, DefinitionSite, FreeVarCollector, LcnfAlt, LcnfArg, LcnfExpr, LcnfFunDecl,
9 LcnfLetValue, LcnfModule, LcnfParam, LcnfType, LcnfVarId, PrettyConfig, Substitution,
10 UsageCounter, ValidationError,
11};
12
/// Mapping from one name to another, stored as owned strings.
///
/// NOTE(review): not referenced in the visible part of this file; presumably a
/// source-name -> target-name renaming table — confirm against callers.
pub type NameMap = HashMap<String, String>;
15pub trait LcnfVisitor {
17 fn visit_expr(&mut self, expr: &LcnfExpr) {
18 walk_expr(self, expr);
19 }
20 fn visit_let_value(&mut self, val: &LcnfLetValue) {
21 walk_let_value(self, val);
22 }
23 fn visit_arg(&mut self, _arg: &LcnfArg) {}
24 fn visit_type(&mut self, _ty: &LcnfType) {}
25 fn visit_alt(&mut self, alt: &LcnfAlt) {
26 walk_alt(self, alt);
27 }
28 fn visit_fun_decl(&mut self, decl: &LcnfFunDecl) {
29 walk_fun_decl(self, decl);
30 }
31 fn visit_param(&mut self, _param: &LcnfParam) {}
32}
33pub fn walk_expr<V: LcnfVisitor + ?Sized>(visitor: &mut V, expr: &LcnfExpr) {
35 match expr {
36 LcnfExpr::Let {
37 ty, value, body, ..
38 } => {
39 visitor.visit_type(ty);
40 visitor.visit_let_value(value);
41 visitor.visit_expr(body);
42 }
43 LcnfExpr::Case {
44 scrutinee_ty,
45 alts,
46 default,
47 ..
48 } => {
49 visitor.visit_type(scrutinee_ty);
50 for alt in alts {
51 visitor.visit_alt(alt);
52 }
53 if let Some(def) = default {
54 visitor.visit_expr(def);
55 }
56 }
57 LcnfExpr::Return(arg) => visitor.visit_arg(arg),
58 LcnfExpr::Unreachable => {}
59 LcnfExpr::TailCall(func, args) => {
60 visitor.visit_arg(func);
61 for arg in args {
62 visitor.visit_arg(arg);
63 }
64 }
65 }
66}
67pub fn walk_let_value<V: LcnfVisitor + ?Sized>(visitor: &mut V, val: &LcnfLetValue) {
69 match val {
70 LcnfLetValue::App(func, args) => {
71 visitor.visit_arg(func);
72 for arg in args {
73 visitor.visit_arg(arg);
74 }
75 }
76 LcnfLetValue::Proj(..) => {}
77 LcnfLetValue::Ctor(_, _, args) => {
78 for arg in args {
79 visitor.visit_arg(arg);
80 }
81 }
82 LcnfLetValue::Lit(_)
83 | LcnfLetValue::Erased
84 | LcnfLetValue::FVar(_)
85 | LcnfLetValue::Reset(_)
86 | LcnfLetValue::Reuse(_, _, _, _) => {}
87 }
88}
89pub fn walk_alt<V: LcnfVisitor + ?Sized>(visitor: &mut V, alt: &LcnfAlt) {
91 for param in &alt.params {
92 visitor.visit_param(param);
93 }
94 visitor.visit_expr(&alt.body);
95}
96pub fn walk_fun_decl<V: LcnfVisitor + ?Sized>(visitor: &mut V, decl: &LcnfFunDecl) {
98 for param in &decl.params {
99 visitor.visit_param(param);
100 }
101 visitor.visit_type(&decl.ret_type);
102 visitor.visit_expr(&decl.body);
103}
104pub trait LcnfMutVisitor {
106 fn visit_expr_mut(&mut self, expr: &mut LcnfExpr) {
107 walk_expr_mut(self, expr);
108 }
109 fn visit_let_value_mut(&mut self, val: &mut LcnfLetValue) {
110 walk_let_value_mut(self, val);
111 }
112 fn visit_arg_mut(&mut self, _arg: &mut LcnfArg) {}
113 fn visit_type_mut(&mut self, _ty: &mut LcnfType) {}
114 fn visit_alt_mut(&mut self, alt: &mut LcnfAlt) {
115 walk_alt_mut(self, alt);
116 }
117 fn visit_fun_decl_mut(&mut self, decl: &mut LcnfFunDecl) {
118 walk_fun_decl_mut(self, decl);
119 }
120 fn visit_param_mut(&mut self, _param: &mut LcnfParam) {}
121}
122pub fn walk_expr_mut<V: LcnfMutVisitor + ?Sized>(visitor: &mut V, expr: &mut LcnfExpr) {
124 match expr {
125 LcnfExpr::Let {
126 ty, value, body, ..
127 } => {
128 visitor.visit_type_mut(ty);
129 visitor.visit_let_value_mut(value);
130 visitor.visit_expr_mut(body);
131 }
132 LcnfExpr::Case {
133 scrutinee_ty,
134 alts,
135 default,
136 ..
137 } => {
138 visitor.visit_type_mut(scrutinee_ty);
139 for alt in alts {
140 visitor.visit_alt_mut(alt);
141 }
142 if let Some(def) = default {
143 visitor.visit_expr_mut(def);
144 }
145 }
146 LcnfExpr::Return(arg) => visitor.visit_arg_mut(arg),
147 LcnfExpr::Unreachable => {}
148 LcnfExpr::TailCall(func, args) => {
149 visitor.visit_arg_mut(func);
150 for arg in args {
151 visitor.visit_arg_mut(arg);
152 }
153 }
154 }
155}
156pub fn walk_let_value_mut<V: LcnfMutVisitor + ?Sized>(visitor: &mut V, val: &mut LcnfLetValue) {
158 match val {
159 LcnfLetValue::App(func, args) => {
160 visitor.visit_arg_mut(func);
161 for arg in args {
162 visitor.visit_arg_mut(arg);
163 }
164 }
165 LcnfLetValue::Proj(..) => {}
166 LcnfLetValue::Ctor(_, _, args) => {
167 for arg in args {
168 visitor.visit_arg_mut(arg);
169 }
170 }
171 LcnfLetValue::Lit(_)
172 | LcnfLetValue::Erased
173 | LcnfLetValue::FVar(_)
174 | LcnfLetValue::Reset(_)
175 | LcnfLetValue::Reuse(_, _, _, _) => {}
176 }
177}
178pub fn walk_alt_mut<V: LcnfMutVisitor + ?Sized>(visitor: &mut V, alt: &mut LcnfAlt) {
180 for param in &mut alt.params {
181 visitor.visit_param_mut(param);
182 }
183 visitor.visit_expr_mut(&mut alt.body);
184}
185pub fn walk_fun_decl_mut<V: LcnfMutVisitor + ?Sized>(visitor: &mut V, decl: &mut LcnfFunDecl) {
187 for param in &mut decl.params {
188 visitor.visit_param_mut(param);
189 }
190 visitor.visit_type_mut(&mut decl.ret_type);
191 visitor.visit_expr_mut(&mut decl.body);
192}
193pub trait LcnfFolder {
195 fn fold_expr(&mut self, expr: LcnfExpr) -> LcnfExpr {
196 match expr {
197 LcnfExpr::Let {
198 id,
199 name,
200 ty,
201 value,
202 body,
203 } => {
204 let new_value = self.fold_let_value(value);
205 let new_body = self.fold_expr(*body);
206 LcnfExpr::Let {
207 id,
208 name,
209 ty,
210 value: new_value,
211 body: Box::new(new_body),
212 }
213 }
214 LcnfExpr::Case {
215 scrutinee,
216 scrutinee_ty,
217 alts,
218 default,
219 } => {
220 let new_alts = alts
221 .into_iter()
222 .map(|alt| {
223 let new_body = self.fold_expr(alt.body);
224 LcnfAlt {
225 ctor_name: alt.ctor_name,
226 ctor_tag: alt.ctor_tag,
227 params: alt.params,
228 body: new_body,
229 }
230 })
231 .collect();
232 let new_default = default.map(|d| Box::new(self.fold_expr(*d)));
233 LcnfExpr::Case {
234 scrutinee,
235 scrutinee_ty,
236 alts: new_alts,
237 default: new_default,
238 }
239 }
240 other => other,
241 }
242 }
243 fn fold_let_value(&mut self, val: LcnfLetValue) -> LcnfLetValue {
244 val
245 }
246}
247pub fn free_vars(expr: &LcnfExpr) -> HashSet<LcnfVarId> {
249 let mut collector = FreeVarCollector::new();
250 collector.collect_expr(expr);
251 collector.free
252}
253pub fn bound_vars(expr: &LcnfExpr) -> HashSet<LcnfVarId> {
255 let mut result = HashSet::new();
256 collect_bound_vars(expr, &mut result);
257 result
258}
259pub(super) fn collect_bound_vars(expr: &LcnfExpr, result: &mut HashSet<LcnfVarId>) {
260 match expr {
261 LcnfExpr::Let { id, body, .. } => {
262 result.insert(*id);
263 collect_bound_vars(body, result);
264 }
265 LcnfExpr::Case { alts, default, .. } => {
266 for alt in alts {
267 for param in &alt.params {
268 result.insert(param.id);
269 }
270 collect_bound_vars(&alt.body, result);
271 }
272 if let Some(def) = default {
273 collect_bound_vars(def, result);
274 }
275 }
276 LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => {}
277 }
278}
279pub fn all_vars(expr: &LcnfExpr) -> HashSet<LcnfVarId> {
281 let mut result = free_vars(expr);
282 result.extend(bound_vars(expr));
283 result
284}
285pub fn usage_counts(expr: &LcnfExpr) -> HashMap<LcnfVarId, usize> {
287 let mut counter = UsageCounter::new();
288 counter.count_expr(expr);
289 counter.counts
290}
291pub fn is_linear(expr: &LcnfExpr) -> bool {
293 usage_counts(expr).values().all(|&c| c <= 1)
294}
295pub fn definition_sites(expr: &LcnfExpr) -> Vec<DefinitionSite> {
297 let mut sites = Vec::new();
298 collect_definition_sites(expr, 0, &mut sites);
299 sites
300}
301pub(super) fn collect_definition_sites(
302 expr: &LcnfExpr,
303 depth: usize,
304 sites: &mut Vec<DefinitionSite>,
305) {
306 match expr {
307 LcnfExpr::Let {
308 id, name, ty, body, ..
309 } => {
310 sites.push(DefinitionSite {
311 var: *id,
312 name: name.clone(),
313 ty: ty.clone(),
314 depth,
315 });
316 collect_definition_sites(body, depth + 1, sites);
317 }
318 LcnfExpr::Case { alts, default, .. } => {
319 for alt in alts {
320 for param in &alt.params {
321 sites.push(DefinitionSite {
322 var: param.id,
323 name: param.name.clone(),
324 ty: param.ty.clone(),
325 depth: depth + 1,
326 });
327 }
328 collect_definition_sites(&alt.body, depth + 1, sites);
329 }
330 if let Some(def) = default {
331 collect_definition_sites(def, depth + 1, sites);
332 }
333 }
334 LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => {}
335 }
336}
337pub fn substitute_arg(arg: &LcnfArg, subst: &Substitution) -> LcnfArg {
339 if let LcnfArg::Var(id) = arg {
340 if let Some(replacement) = subst.get(id) {
341 return replacement.clone();
342 }
343 }
344 arg.clone()
345}
346pub fn substitute_let_value(val: &LcnfLetValue, subst: &Substitution) -> LcnfLetValue {
348 match val {
349 LcnfLetValue::App(func, args) => LcnfLetValue::App(
350 substitute_arg(func, subst),
351 args.iter().map(|a| substitute_arg(a, subst)).collect(),
352 ),
353 LcnfLetValue::Proj(name, idx, var) => {
354 if let Some(LcnfArg::Var(new_var)) = subst.get(var) {
355 LcnfLetValue::Proj(name.clone(), *idx, *new_var)
356 } else {
357 val.clone()
358 }
359 }
360 LcnfLetValue::Ctor(name, tag, args) => LcnfLetValue::Ctor(
361 name.clone(),
362 *tag,
363 args.iter().map(|a| substitute_arg(a, subst)).collect(),
364 ),
365 LcnfLetValue::FVar(id) => {
366 if let Some(LcnfArg::Var(new_id)) = subst.get(id) {
367 LcnfLetValue::FVar(*new_id)
368 } else {
369 val.clone()
370 }
371 }
372 LcnfLetValue::Lit(_)
373 | LcnfLetValue::Erased
374 | LcnfLetValue::Reset(_)
375 | LcnfLetValue::Reuse(_, _, _, _) => val.clone(),
376 }
377}
378pub fn substitute_expr(expr: &LcnfExpr, subst: &Substitution) -> LcnfExpr {
380 match expr {
381 LcnfExpr::Let {
382 id,
383 name,
384 ty,
385 value,
386 body,
387 } => {
388 let new_value = substitute_let_value(value, subst);
389 let mut inner_subst = subst.clone();
390 inner_subst.0.remove(id);
391 LcnfExpr::Let {
392 id: *id,
393 name: name.clone(),
394 ty: ty.clone(),
395 value: new_value,
396 body: Box::new(substitute_expr(body, &inner_subst)),
397 }
398 }
399 LcnfExpr::Case {
400 scrutinee,
401 scrutinee_ty,
402 alts,
403 default,
404 } => {
405 let new_scrutinee = if let Some(LcnfArg::Var(new_id)) = subst.get(scrutinee) {
406 *new_id
407 } else {
408 *scrutinee
409 };
410 let new_alts = alts
411 .iter()
412 .map(|alt| {
413 let mut inner_subst = subst.clone();
414 for param in &alt.params {
415 inner_subst.0.remove(¶m.id);
416 }
417 LcnfAlt {
418 ctor_name: alt.ctor_name.clone(),
419 ctor_tag: alt.ctor_tag,
420 params: alt.params.clone(),
421 body: substitute_expr(&alt.body, &inner_subst),
422 }
423 })
424 .collect();
425 let new_default = default
426 .as_ref()
427 .map(|d| Box::new(substitute_expr(d, subst)));
428 LcnfExpr::Case {
429 scrutinee: new_scrutinee,
430 scrutinee_ty: scrutinee_ty.clone(),
431 alts: new_alts,
432 default: new_default,
433 }
434 }
435 LcnfExpr::Return(arg) => LcnfExpr::Return(substitute_arg(arg, subst)),
436 LcnfExpr::Unreachable => LcnfExpr::Unreachable,
437 LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(
438 substitute_arg(func, subst),
439 args.iter().map(|a| substitute_arg(a, subst)).collect(),
440 ),
441 }
442}
443pub fn rename_vars(expr: &LcnfExpr, rename: &HashMap<LcnfVarId, LcnfVarId>) -> LcnfExpr {
445 let subst = Substitution(
446 rename
447 .iter()
448 .map(|(old, new)| (*old, LcnfArg::Var(*new)))
449 .collect(),
450 );
451 rename_expr_inner(expr, rename, &subst)
452}
453pub(super) fn rename_expr_inner(
454 expr: &LcnfExpr,
455 rename: &HashMap<LcnfVarId, LcnfVarId>,
456 subst: &Substitution,
457) -> LcnfExpr {
458 match expr {
459 LcnfExpr::Let {
460 id,
461 name,
462 ty,
463 value,
464 body,
465 } => {
466 let new_id = rename.get(id).copied().unwrap_or(*id);
467 LcnfExpr::Let {
468 id: new_id,
469 name: name.clone(),
470 ty: ty.clone(),
471 value: substitute_let_value(value, subst),
472 body: Box::new(rename_expr_inner(body, rename, subst)),
473 }
474 }
475 LcnfExpr::Case {
476 scrutinee,
477 scrutinee_ty,
478 alts,
479 default,
480 } => {
481 let new_scrutinee = rename.get(scrutinee).copied().unwrap_or(*scrutinee);
482 let new_alts = alts
483 .iter()
484 .map(|alt| {
485 let new_params: Vec<LcnfParam> = alt
486 .params
487 .iter()
488 .map(|p| LcnfParam {
489 id: rename.get(&p.id).copied().unwrap_or(p.id),
490 name: p.name.clone(),
491 ty: p.ty.clone(),
492 erased: p.erased,
493 borrowed: false,
494 })
495 .collect();
496 LcnfAlt {
497 ctor_name: alt.ctor_name.clone(),
498 ctor_tag: alt.ctor_tag,
499 params: new_params,
500 body: rename_expr_inner(&alt.body, rename, subst),
501 }
502 })
503 .collect();
504 let new_default = default
505 .as_ref()
506 .map(|d| Box::new(rename_expr_inner(d, rename, subst)));
507 LcnfExpr::Case {
508 scrutinee: new_scrutinee,
509 scrutinee_ty: scrutinee_ty.clone(),
510 alts: new_alts,
511 default: new_default,
512 }
513 }
514 LcnfExpr::Return(arg) => LcnfExpr::Return(substitute_arg(arg, subst)),
515 LcnfExpr::Unreachable => LcnfExpr::Unreachable,
516 LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(
517 substitute_arg(func, subst),
518 args.iter().map(|a| substitute_arg(a, subst)).collect(),
519 ),
520 }
521}
522pub fn alpha_equiv(e1: &LcnfExpr, e2: &LcnfExpr) -> bool {
524 let mut l2r: HashMap<LcnfVarId, LcnfVarId> = HashMap::new();
525 let mut r2l: HashMap<LcnfVarId, LcnfVarId> = HashMap::new();
526 alpha_equiv_expr(e1, e2, &mut l2r, &mut r2l)
527}
528pub(super) fn alpha_equiv_var(
529 v1: LcnfVarId,
530 v2: LcnfVarId,
531 l2r: &HashMap<LcnfVarId, LcnfVarId>,
532 r2l: &HashMap<LcnfVarId, LcnfVarId>,
533) -> bool {
534 match (l2r.get(&v1), r2l.get(&v2)) {
535 (Some(&mapped), Some(&mapped_back)) => mapped == v2 && mapped_back == v1,
536 (None, None) => v1 == v2,
537 _ => false,
538 }
539}
540pub(super) fn alpha_equiv_arg(
541 a1: &LcnfArg,
542 a2: &LcnfArg,
543 l2r: &HashMap<LcnfVarId, LcnfVarId>,
544 r2l: &HashMap<LcnfVarId, LcnfVarId>,
545) -> bool {
546 match (a1, a2) {
547 (LcnfArg::Var(v1), LcnfArg::Var(v2)) => alpha_equiv_var(*v1, *v2, l2r, r2l),
548 (LcnfArg::Lit(l1), LcnfArg::Lit(l2)) => l1 == l2,
549 (LcnfArg::Erased, LcnfArg::Erased) => true,
550 (LcnfArg::Type(t1), LcnfArg::Type(t2)) => t1 == t2,
551 _ => false,
552 }
553}
554pub(super) fn alpha_equiv_let_value(
555 v1: &LcnfLetValue,
556 v2: &LcnfLetValue,
557 l2r: &HashMap<LcnfVarId, LcnfVarId>,
558 r2l: &HashMap<LcnfVarId, LcnfVarId>,
559) -> bool {
560 match (v1, v2) {
561 (LcnfLetValue::App(f1, a1), LcnfLetValue::App(f2, a2)) => {
562 alpha_equiv_arg(f1, f2, l2r, r2l)
563 && a1.len() == a2.len()
564 && a1
565 .iter()
566 .zip(a2.iter())
567 .all(|(x, y)| alpha_equiv_arg(x, y, l2r, r2l))
568 }
569 (LcnfLetValue::Proj(n1, i1, var1), LcnfLetValue::Proj(n2, i2, var2)) => {
570 n1 == n2 && i1 == i2 && alpha_equiv_var(*var1, *var2, l2r, r2l)
571 }
572 (LcnfLetValue::Ctor(n1, t1, a1), LcnfLetValue::Ctor(n2, t2, a2)) => {
573 n1 == n2
574 && t1 == t2
575 && a1.len() == a2.len()
576 && a1
577 .iter()
578 .zip(a2.iter())
579 .all(|(x, y)| alpha_equiv_arg(x, y, l2r, r2l))
580 }
581 (LcnfLetValue::Lit(l1), LcnfLetValue::Lit(l2)) => l1 == l2,
582 (LcnfLetValue::Erased, LcnfLetValue::Erased) => true,
583 (LcnfLetValue::FVar(id1), LcnfLetValue::FVar(id2)) => alpha_equiv_var(*id1, *id2, l2r, r2l),
584 _ => false,
585 }
586}
587#[allow(clippy::too_many_arguments)]
588pub(super) fn alpha_equiv_expr(
589 e1: &LcnfExpr,
590 e2: &LcnfExpr,
591 l2r: &mut HashMap<LcnfVarId, LcnfVarId>,
592 r2l: &mut HashMap<LcnfVarId, LcnfVarId>,
593) -> bool {
594 match (e1, e2) {
595 (
596 LcnfExpr::Let {
597 id: id1,
598 ty: ty1,
599 value: val1,
600 body: body1,
601 ..
602 },
603 LcnfExpr::Let {
604 id: id2,
605 ty: ty2,
606 value: val2,
607 body: body2,
608 ..
609 },
610 ) => {
611 if ty1 != ty2 || !alpha_equiv_let_value(val1, val2, l2r, r2l) {
612 return false;
613 }
614 l2r.insert(*id1, *id2);
615 r2l.insert(*id2, *id1);
616 let result = alpha_equiv_expr(body1, body2, l2r, r2l);
617 l2r.remove(id1);
618 r2l.remove(id2);
619 result
620 }
621 (
622 LcnfExpr::Case {
623 scrutinee: s1,
624 scrutinee_ty: st1,
625 alts: alts1,
626 default: def1,
627 },
628 LcnfExpr::Case {
629 scrutinee: s2,
630 scrutinee_ty: st2,
631 alts: alts2,
632 default: def2,
633 },
634 ) => {
635 if !alpha_equiv_var(*s1, *s2, l2r, r2l) || st1 != st2 || alts1.len() != alts2.len() {
636 return false;
637 }
638 for (a1, a2) in alts1.iter().zip(alts2.iter()) {
639 if a1.ctor_name != a2.ctor_name
640 || a1.ctor_tag != a2.ctor_tag
641 || a1.params.len() != a2.params.len()
642 {
643 return false;
644 }
645 for (p1, p2) in a1.params.iter().zip(a2.params.iter()) {
646 l2r.insert(p1.id, p2.id);
647 r2l.insert(p2.id, p1.id);
648 }
649 let ok = alpha_equiv_expr(&a1.body, &a2.body, l2r, r2l);
650 for (p1, p2) in a1.params.iter().zip(a2.params.iter()) {
651 l2r.remove(&p1.id);
652 r2l.remove(&p2.id);
653 }
654 if !ok {
655 return false;
656 }
657 }
658 match (def1, def2) {
659 (Some(d1), Some(d2)) => alpha_equiv_expr(d1, d2, l2r, r2l),
660 (None, None) => true,
661 _ => false,
662 }
663 }
664 (LcnfExpr::Return(a1), LcnfExpr::Return(a2)) => alpha_equiv_arg(a1, a2, l2r, r2l),
665 (LcnfExpr::Unreachable, LcnfExpr::Unreachable) => true,
666 (LcnfExpr::TailCall(f1, a1), LcnfExpr::TailCall(f2, a2)) => {
667 alpha_equiv_arg(f1, f2, l2r, r2l)
668 && a1.len() == a2.len()
669 && a1
670 .iter()
671 .zip(a2.iter())
672 .all(|(x, y)| alpha_equiv_arg(x, y, l2r, r2l))
673 }
674 _ => false,
675 }
676}
677pub fn expr_size(expr: &LcnfExpr) -> usize {
679 match expr {
680 LcnfExpr::Let { value, body, .. } => 1 + let_value_size(value) + expr_size(body),
681 LcnfExpr::Case { alts, default, .. } => {
682 let alt_size: usize = alts.iter().map(|a| 1 + expr_size(&a.body)).sum();
683 let def_size = default.as_ref().map_or(0, |d| expr_size(d));
684 1 + alt_size + def_size
685 }
686 LcnfExpr::Return(_) => 1,
687 LcnfExpr::Unreachable => 1,
688 LcnfExpr::TailCall(_, args) => 1 + args.len(),
689 }
690}
691pub(super) fn let_value_size(val: &LcnfLetValue) -> usize {
692 match val {
693 LcnfLetValue::App(_, args) => 1 + args.len(),
694 LcnfLetValue::Proj(..) => 1,
695 LcnfLetValue::Ctor(_, _, args) => 1 + args.len(),
696 LcnfLetValue::Lit(_)
697 | LcnfLetValue::Erased
698 | LcnfLetValue::FVar(_)
699 | LcnfLetValue::Reset(_)
700 | LcnfLetValue::Reuse(_, _, _, _) => 1,
701 }
702}
703pub fn expr_depth(expr: &LcnfExpr) -> usize {
705 match expr {
706 LcnfExpr::Let { body, .. } => 1 + expr_depth(body),
707 LcnfExpr::Case { alts, default, .. } => {
708 let max_alt = alts.iter().map(|a| expr_depth(&a.body)).max().unwrap_or(0);
709 let def_depth = default.as_ref().map_or(0, |d| expr_depth(d));
710 1 + max_alt.max(def_depth)
711 }
712 LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => 1,
713 }
714}
715pub fn compute_inline_cost(decl: &LcnfFunDecl) -> usize {
717 let base = expr_size(&decl.body);
718 let depth_penalty = expr_depth(&decl.body);
719 let branch_penalty = count_branches(&decl.body) * 5;
720 let recursive_penalty = if decl.is_recursive { 100 } else { 0 };
721 let param_bonus = if decl.params.len() <= 2 {
722 0
723 } else {
724 decl.params.len() * 2
725 };
726 base + depth_penalty + branch_penalty + recursive_penalty + param_bonus
727}
728pub fn estimate_runtime_cost(expr: &LcnfExpr, model: &CostModel) -> u64 {
730 match expr {
731 LcnfExpr::Let { value, body, .. } => {
732 let val_cost = match value {
733 LcnfLetValue::App(..) | LcnfLetValue::Ctor(..) => model.app_cost,
734 LcnfLetValue::Proj(..) | LcnfLetValue::Lit(_) | LcnfLetValue::FVar(_) => {
735 model.let_cost
736 }
737 LcnfLetValue::Erased | LcnfLetValue::Reset(_) | LcnfLetValue::Reuse(_, _, _, _) => {
738 0
739 }
740 };
741 model.let_cost + val_cost + estimate_runtime_cost(body, model)
742 }
743 LcnfExpr::Case { alts, default, .. } => {
744 let max_alt_cost = alts
745 .iter()
746 .map(|a| estimate_runtime_cost(&a.body, model))
747 .max()
748 .unwrap_or(0);
749 let def_cost = default
750 .as_ref()
751 .map_or(0, |d| estimate_runtime_cost(d, model));
752 model.case_cost
753 + model.branch_penalty * (alts.len() as u64)
754 + max_alt_cost.max(def_cost)
755 }
756 LcnfExpr::Return(_) => model.return_cost,
757 LcnfExpr::Unreachable => 0,
758 LcnfExpr::TailCall(_, args) => model.app_cost + (args.len() as u64),
759 }
760}
761pub fn count_allocations(expr: &LcnfExpr) -> usize {
763 match expr {
764 LcnfExpr::Let { value, body, .. } => {
765 let alloc = match value {
766 LcnfLetValue::Ctor(_, _, args) if !args.is_empty() => 1,
767 _ => 0,
768 };
769 alloc + count_allocations(body)
770 }
771 LcnfExpr::Case { alts, default, .. } => {
772 let alt_allocs: usize = alts.iter().map(|a| count_allocations(&a.body)).sum();
773 let def_allocs = default.as_ref().map_or(0, |d| count_allocations(d));
774 alt_allocs + def_allocs
775 }
776 LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => 0,
777 }
778}
779pub fn count_branches(expr: &LcnfExpr) -> usize {
781 match expr {
782 LcnfExpr::Let { body, .. } => count_branches(body),
783 LcnfExpr::Case { alts, default, .. } => {
784 let inner: usize = alts.iter().map(|a| count_branches(&a.body)).sum();
785 let def_branches = default.as_ref().map_or(0, |d| count_branches(d));
786 1 + inner + def_branches
787 }
788 LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => 0,
789 }
790}
791pub fn validate_expr(expr: &LcnfExpr, bound: &HashSet<LcnfVarId>) -> Result<(), ValidationError> {
793 match expr {
794 LcnfExpr::Let {
795 id, value, body, ..
796 } => {
797 validate_let_value(value, bound)?;
798 let mut new_bound = bound.clone();
799 if !new_bound.insert(*id) {
800 return Err(ValidationError::DuplicateBinding(*id));
801 }
802 validate_expr(body, &new_bound)
803 }
804 LcnfExpr::Case {
805 scrutinee,
806 alts,
807 default,
808 ..
809 } => {
810 if !bound.contains(scrutinee) {
811 return Err(ValidationError::UnboundVariable(*scrutinee));
812 }
813 if alts.is_empty() && default.is_none() {
814 return Err(ValidationError::EmptyCase);
815 }
816 for alt in alts {
817 let mut alt_bound = bound.clone();
818 for param in &alt.params {
819 if !alt_bound.insert(param.id) {
820 return Err(ValidationError::DuplicateBinding(param.id));
821 }
822 }
823 validate_expr(&alt.body, &alt_bound)?;
824 }
825 if let Some(def) = default {
826 validate_expr(def, bound)?;
827 }
828 Ok(())
829 }
830 LcnfExpr::Return(arg) => validate_arg_bound(arg, bound),
831 LcnfExpr::Unreachable => Ok(()),
832 LcnfExpr::TailCall(func, args) => {
833 validate_arg_bound(func, bound)?;
834 for arg in args {
835 validate_arg_bound(arg, bound)?;
836 }
837 Ok(())
838 }
839 }
840}
841pub(super) fn validate_arg_bound(
842 arg: &LcnfArg,
843 bound: &HashSet<LcnfVarId>,
844) -> Result<(), ValidationError> {
845 if let LcnfArg::Var(id) = arg {
846 if !bound.contains(id) {
847 return Err(ValidationError::UnboundVariable(*id));
848 }
849 }
850 Ok(())
851}
852pub(super) fn validate_let_value(
853 val: &LcnfLetValue,
854 bound: &HashSet<LcnfVarId>,
855) -> Result<(), ValidationError> {
856 match val {
857 LcnfLetValue::App(func, args) => {
858 validate_arg_bound(func, bound)?;
859 for arg in args {
860 validate_arg_bound(arg, bound)?;
861 }
862 Ok(())
863 }
864 LcnfLetValue::Proj(_, _, var) => {
865 if !bound.contains(var) {
866 Err(ValidationError::UnboundVariable(*var))
867 } else {
868 Ok(())
869 }
870 }
871 LcnfLetValue::Ctor(_, _, args) => {
872 for arg in args {
873 validate_arg_bound(arg, bound)?;
874 }
875 Ok(())
876 }
877 LcnfLetValue::FVar(id) => {
878 if !bound.contains(id) {
879 Err(ValidationError::UnboundVariable(*id))
880 } else {
881 Ok(())
882 }
883 }
884 LcnfLetValue::Lit(_)
885 | LcnfLetValue::Erased
886 | LcnfLetValue::Reset(_)
887 | LcnfLetValue::Reuse(_, _, _, _) => Ok(()),
888 }
889}
890pub fn validate_fun_decl(decl: &LcnfFunDecl) -> Result<(), ValidationError> {
892 let mut bound = HashSet::new();
893 for param in &decl.params {
894 if !bound.insert(param.id) {
895 return Err(ValidationError::DuplicateBinding(param.id));
896 }
897 }
898 validate_expr(&decl.body, &bound)
899}
900pub fn validate_module(module: &LcnfModule) -> Result<(), Vec<ValidationError>> {
902 let mut errors = Vec::new();
903 for decl in &module.fun_decls {
904 if let Err(e) = validate_fun_decl(decl) {
905 errors.push(e);
906 }
907 }
908 if errors.is_empty() {
909 Ok(())
910 } else {
911 Err(errors)
912 }
913}
914pub fn check_anf_invariant(expr: &LcnfExpr) -> bool {
916 match expr {
917 LcnfExpr::Let { value, body, .. } => {
918 check_let_value_anf(value) && check_anf_invariant(body)
919 }
920 LcnfExpr::Case { alts, default, .. } => {
921 alts.iter().all(|a| check_anf_invariant(&a.body))
922 && default.as_ref().is_none_or(|d| check_anf_invariant(d))
923 }
924 LcnfExpr::Return(_) | LcnfExpr::Unreachable | LcnfExpr::TailCall(..) => true,
925 }
926}
927pub(super) fn check_let_value_anf(val: &LcnfLetValue) -> bool {
928 match val {
929 LcnfLetValue::App(func, args) => is_atomic_arg(func) && args.iter().all(is_atomic_arg),
930 LcnfLetValue::Ctor(_, _, args) => args.iter().all(is_atomic_arg),
931 _ => true,
932 }
933}
934pub(super) fn is_atomic_arg(arg: &LcnfArg) -> bool {
935 matches!(
936 arg,
937 LcnfArg::Var(_) | LcnfArg::Lit(_) | LcnfArg::Erased | LcnfArg::Type(_)
938 )
939}
940pub fn pretty_print_expr(expr: &LcnfExpr, config: &PrettyConfig) -> String {
942 let mut output = String::new();
943 pp_expr(&mut output, expr, config, 0);
944 output
945}
946pub(super) fn pp_indent(output: &mut String, config: &PrettyConfig, level: usize) {
947 for _ in 0..level * config.indent {
948 output.push(' ');
949 }
950}
951pub(super) fn pp_arg(output: &mut String, arg: &LcnfArg, config: &PrettyConfig) {
952 match arg {
953 LcnfArg::Var(id) => output.push_str(&id.to_string()),
954 LcnfArg::Lit(lit) => output.push_str(&lit.to_string()),
955 LcnfArg::Erased => {
956 if config.show_erased {
957 output.push('â—»');
958 } else {
959 output.push('_');
960 }
961 }
962 LcnfArg::Type(ty) => {
963 if config.show_types {
964 output.push('@');
965 output.push_str(&ty.to_string());
966 } else {
967 output.push('_');
968 }
969 }
970 }
971}
972pub(super) fn pp_let_value(output: &mut String, val: &LcnfLetValue, config: &PrettyConfig) {
973 match val {
974 LcnfLetValue::App(func, args) => {
975 pp_arg(output, func, config);
976 output.push('(');
977 for (i, a) in args.iter().enumerate() {
978 if i > 0 {
979 output.push_str(", ");
980 }
981 pp_arg(output, a, config);
982 }
983 output.push(')');
984 }
985 LcnfLetValue::Proj(name, idx, var) => {
986 output.push_str(&format!("{}.{} {}", name, idx, var));
987 }
988 LcnfLetValue::Ctor(name, tag, args) => {
989 output.push_str(&format!("{}#{}", name, tag));
990 if !args.is_empty() {
991 output.push('(');
992 for (i, a) in args.iter().enumerate() {
993 if i > 0 {
994 output.push_str(", ");
995 }
996 pp_arg(output, a, config);
997 }
998 output.push(')');
999 }
1000 }
1001 LcnfLetValue::Lit(lit) => output.push_str(&lit.to_string()),
1002 LcnfLetValue::Erased => output.push_str("erased"),
1003 LcnfLetValue::FVar(id) => output.push_str(&id.to_string()),
1004 LcnfLetValue::Reset(var) => output.push_str(&format!("reset({})", var)),
1005 LcnfLetValue::Reuse(slot, name, tag, _) => {
1006 output.push_str(&format!("reuse({}, {}#{})", slot, name, tag))
1007 }
1008 }
1009}
1010pub(super) fn pp_expr(output: &mut String, expr: &LcnfExpr, config: &PrettyConfig, level: usize) {
1011 match expr {
1012 LcnfExpr::Let {
1013 id,
1014 name,
1015 ty,
1016 value,
1017 body,
1018 } => {
1019 pp_indent(output, config, level);
1020 output.push_str("let ");
1021 output.push_str(&id.to_string());
1022 if !name.is_empty() {
1023 output.push_str(&format!(" ({})", name));
1024 }
1025 if config.show_types {
1026 output.push_str(&format!(" : {}", ty));
1027 }
1028 output.push_str(" := ");
1029 pp_let_value(output, value, config);
1030 output.push('\n');
1031 pp_expr(output, body, config, level);
1032 }
1033 LcnfExpr::Case {
1034 scrutinee,
1035 scrutinee_ty,
1036 alts,
1037 default,
1038 } => {
1039 pp_indent(output, config, level);
1040 output.push_str(&format!("case {}", scrutinee));
1041 if config.show_types {
1042 output.push_str(&format!(" : {}", scrutinee_ty));
1043 }
1044 output.push_str(" of\n");
1045 for alt in alts {
1046 pp_indent(output, config, level + 1);
1047 output.push_str(&format!("| {}#{}", alt.ctor_name, alt.ctor_tag));
1048 for p in &alt.params {
1049 output.push_str(&format!(" {}", p.id));
1050 }
1051 output.push_str(" =>\n");
1052 pp_expr(output, &alt.body, config, level + 2);
1053 }
1054 if let Some(def) = default {
1055 pp_indent(output, config, level + 1);
1056 output.push_str("| _ =>\n");
1057 pp_expr(output, def, config, level + 2);
1058 }
1059 }
1060 LcnfExpr::Return(arg) => {
1061 pp_indent(output, config, level);
1062 output.push_str("return ");
1063 pp_arg(output, arg, config);
1064 output.push('\n');
1065 }
1066 LcnfExpr::Unreachable => {
1067 pp_indent(output, config, level);
1068 output.push_str("unreachable\n");
1069 }
1070 LcnfExpr::TailCall(func, args) => {
1071 pp_indent(output, config, level);
1072 output.push_str("tailcall ");
1073 pp_arg(output, func, config);
1074 output.push('(');
1075 for (i, a) in args.iter().enumerate() {
1076 if i > 0 {
1077 output.push_str(", ");
1078 }
1079 pp_arg(output, a, config);
1080 }
1081 output.push_str(")\n");
1082 }
1083 }
1084}
1085pub fn pretty_print_fun_decl(decl: &LcnfFunDecl, config: &PrettyConfig) -> String {
1087 let mut output = String::new();
1088 output.push_str("def ");
1089 output.push_str(&decl.name);
1090 output.push('(');
1091 for (i, param) in decl.params.iter().enumerate() {
1092 if i > 0 {
1093 output.push_str(", ");
1094 }
1095 output.push_str(&format!("{}", param.id));
1096 if !param.name.is_empty() {
1097 output.push_str(&format!(" ({})", param.name));
1098 }
1099 if config.show_types {
1100 output.push_str(&format!(" : {}", param.ty));
1101 }
1102 }
1103 output.push(')');
1104 if config.show_types {
1105 output.push_str(&format!(" : {}", decl.ret_type));
1106 }
1107 if decl.is_recursive {
1108 output.push_str(" [rec]");
1109 }
1110 if decl.is_lifted {
1111 output.push_str(" [lifted]");
1112 }
1113 output.push_str(" :=\n");
1114 pp_expr(&mut output, &decl.body, config, 1);
1115 output
1116}
1117pub fn pretty_print_module(module: &LcnfModule, config: &PrettyConfig) -> String {
1119 let mut output = String::new();
1120 output.push_str(&format!("-- module {}\n", module.name));
1121 output.push_str(&format!(
1122 "-- {} decls, {} externs\n\n",
1123 module.fun_decls.len(),
1124 module.extern_decls.len()
1125 ));
1126 for decl in &module.extern_decls {
1127 output.push_str("extern ");
1128 output.push_str(&decl.name);
1129 output.push('(');
1130 for (i, param) in decl.params.iter().enumerate() {
1131 if i > 0 {
1132 output.push_str(", ");
1133 }
1134 if config.show_types {
1135 output.push_str(&format!("{} : {}", param.id, param.ty));
1136 } else {
1137 output.push_str(&format!("{}", param.id));
1138 }
1139 }
1140 output.push(')');
1141 if config.show_types {
1142 output.push_str(&format!(" : {}", decl.ret_type));
1143 }
1144 output.push('\n');
1145 }
1146 if !module.extern_decls.is_empty() {
1147 output.push('\n');
1148 }
1149 for decl in &module.fun_decls {
1150 output.push_str(&pretty_print_fun_decl(decl, config));
1151 output.push('\n');
1152 }
1153 output
1154}
1155pub fn inline_let(expr: LcnfExpr, var: LcnfVarId) -> LcnfExpr {
1157 match expr {
1158 LcnfExpr::Let {
1159 id,
1160 name,
1161 ty,
1162 value,
1163 body,
1164 } if id == var => {
1165 if let Some(arg) = let_value_to_arg(&value) {
1166 let mut subst = Substitution::new();
1167 subst.insert(id, arg);
1168 substitute_expr(&body, &subst)
1169 } else {
1170 LcnfExpr::Let {
1171 id,
1172 name,
1173 ty,
1174 value,
1175 body,
1176 }
1177 }
1178 }
1179 LcnfExpr::Let {
1180 id,
1181 name,
1182 ty,
1183 value,
1184 body,
1185 } => LcnfExpr::Let {
1186 id,
1187 name,
1188 ty,
1189 value,
1190 body: Box::new(inline_let(*body, var)),
1191 },
1192 LcnfExpr::Case {
1193 scrutinee,
1194 scrutinee_ty,
1195 alts,
1196 default,
1197 } => LcnfExpr::Case {
1198 scrutinee,
1199 scrutinee_ty,
1200 alts: alts
1201 .into_iter()
1202 .map(|a| LcnfAlt {
1203 ctor_name: a.ctor_name,
1204 ctor_tag: a.ctor_tag,
1205 params: a.params,
1206 body: inline_let(a.body, var),
1207 })
1208 .collect(),
1209 default: default.map(|d| Box::new(inline_let(*d, var))),
1210 },
1211 other => other,
1212 }
1213}
1214pub(super) fn let_value_to_arg(val: &LcnfLetValue) -> Option<LcnfArg> {
1216 match val {
1217 LcnfLetValue::Lit(lit) => Some(LcnfArg::Lit(lit.clone())),
1218 LcnfLetValue::Erased => Some(LcnfArg::Erased),
1219 LcnfLetValue::FVar(id) => Some(LcnfArg::Var(*id)),
1220 _ => None,
1221 }
1222}
1223pub fn flatten_lets(expr: LcnfExpr) -> LcnfExpr {
1225 let mut bindings: Vec<(LcnfVarId, String, LcnfType, LcnfLetValue)> = Vec::new();
1226 let terminal = collect_lets(expr, &mut bindings);
1227 let mut result = flatten_lets_in_terminal(terminal);
1228 for (id, name, ty, value) in bindings.into_iter().rev() {
1229 result = LcnfExpr::Let {
1230 id,
1231 name,
1232 ty,
1233 value,
1234 body: Box::new(result),
1235 };
1236 }
1237 result
1238}
1239pub(super) fn collect_lets(
1240 expr: LcnfExpr,
1241 bindings: &mut Vec<(LcnfVarId, String, LcnfType, LcnfLetValue)>,
1242) -> LcnfExpr {
1243 match expr {
1244 LcnfExpr::Let {
1245 id,
1246 name,
1247 ty,
1248 value,
1249 body,
1250 } => {
1251 bindings.push((id, name, ty, value));
1252 collect_lets(*body, bindings)
1253 }
1254 other => other,
1255 }
1256}
1257pub(super) fn flatten_lets_in_terminal(expr: LcnfExpr) -> LcnfExpr {
1258 match expr {
1259 LcnfExpr::Case {
1260 scrutinee,
1261 scrutinee_ty,
1262 alts,
1263 default,
1264 } => LcnfExpr::Case {
1265 scrutinee,
1266 scrutinee_ty,
1267 alts: alts
1268 .into_iter()
1269 .map(|a| LcnfAlt {
1270 ctor_name: a.ctor_name,
1271 ctor_tag: a.ctor_tag,
1272 params: a.params,
1273 body: flatten_lets(a.body),
1274 })
1275 .collect(),
1276 default: default.map(|d| Box::new(flatten_lets(*d))),
1277 },
1278 other => other,
1279 }
1280}
1281pub fn simplify_trivial_case(expr: LcnfExpr) -> LcnfExpr {
1283 match expr {
1284 LcnfExpr::Case {
1285 scrutinee,
1286 alts,
1287 default: None,
1288 ..
1289 } if alts.len() == 1 => {
1290 let alt = alts.into_iter().next().expect(
1291 "alts has exactly one element; guaranteed by pattern guard alts.len() == 1",
1292 );
1293 let mut result = simplify_trivial_case(alt.body);
1294 for (idx, param) in alt.params.iter().enumerate().rev() {
1295 result = LcnfExpr::Let {
1296 id: param.id,
1297 name: param.name.clone(),
1298 ty: param.ty.clone(),
1299 value: LcnfLetValue::Proj(alt.ctor_name.clone(), idx as u32, scrutinee),
1300 body: Box::new(result),
1301 };
1302 }
1303 result
1304 }
1305 LcnfExpr::Let {
1306 id,
1307 name,
1308 ty,
1309 value,
1310 body,
1311 } => LcnfExpr::Let {
1312 id,
1313 name,
1314 ty,
1315 value,
1316 body: Box::new(simplify_trivial_case(*body)),
1317 },
1318 LcnfExpr::Case {
1319 scrutinee,
1320 scrutinee_ty,
1321 alts,
1322 default,
1323 } => LcnfExpr::Case {
1324 scrutinee,
1325 scrutinee_ty,
1326 alts: alts
1327 .into_iter()
1328 .map(|a| LcnfAlt {
1329 ctor_name: a.ctor_name,
1330 ctor_tag: a.ctor_tag,
1331 params: a.params,
1332 body: simplify_trivial_case(a.body),
1333 })
1334 .collect(),
1335 default: default.map(|d| Box::new(simplify_trivial_case(*d))),
1336 },
1337 other => other,
1338 }
1339}
1340pub fn remove_unused_lets(expr: LcnfExpr) -> LcnfExpr {
1342 match expr {
1343 LcnfExpr::Let {
1344 id,
1345 name,
1346 ty,
1347 value,
1348 body,
1349 } => {
1350 let new_body = remove_unused_lets(*body);
1351 let counts = usage_counts(&new_body);
1352 if counts.get(&id).copied().unwrap_or(0) == 0 {
1353 new_body
1354 } else {
1355 LcnfExpr::Let {
1356 id,
1357 name,
1358 ty,
1359 value,
1360 body: Box::new(new_body),
1361 }
1362 }
1363 }
1364 LcnfExpr::Case {
1365 scrutinee,
1366 scrutinee_ty,
1367 alts,
1368 default,
1369 } => LcnfExpr::Case {
1370 scrutinee,
1371 scrutinee_ty,
1372 alts: alts
1373 .into_iter()
1374 .map(|a| LcnfAlt {
1375 ctor_name: a.ctor_name,
1376 ctor_tag: a.ctor_tag,
1377 params: a.params,
1378 body: remove_unused_lets(a.body),
1379 })
1380 .collect(),
1381 default: default.map(|d| Box::new(remove_unused_lets(*d))),
1382 },
1383 other => other,
1384 }
1385}
1386pub fn hoist_lets(expr: LcnfExpr) -> LcnfExpr {
1389 match expr {
1390 LcnfExpr::Let {
1391 id,
1392 name,
1393 ty,
1394 value,
1395 body,
1396 } => LcnfExpr::Let {
1397 id,
1398 name,
1399 ty,
1400 value,
1401 body: Box::new(hoist_lets(*body)),
1402 },
1403 LcnfExpr::Case {
1404 scrutinee,
1405 scrutinee_ty,
1406 alts,
1407 default,
1408 } => {
1409 let alts: Vec<LcnfAlt> = alts
1410 .into_iter()
1411 .map(|a| LcnfAlt {
1412 ctor_name: a.ctor_name,
1413 ctor_tag: a.ctor_tag,
1414 params: a.params,
1415 body: hoist_lets(a.body),
1416 })
1417 .collect();
1418 let default = default.map(|d| Box::new(hoist_lets(*d)));
1419 if alts.len() < 2 || default.is_some() {
1420 return LcnfExpr::Case {
1421 scrutinee,
1422 scrutinee_ty,
1423 alts,
1424 default,
1425 };
1426 }
1427 let first_let = match &alts[0].body {
1428 LcnfExpr::Let {
1429 name, ty, value, ..
1430 } => Some((name.clone(), ty.clone(), value.clone())),
1431 _ => None,
1432 };
1433 if let Some((common_name, common_ty, common_value)) = first_let {
1434 let all_same = alts.iter().all(|a| {
1435 matches!(
1436 & a.body, LcnfExpr::Let { name, ty, value, .. } if * name ==
1437 common_name && * ty == common_ty && * value == common_value
1438 )
1439 });
1440 if all_same {
1441 let hoisted_id = match &alts[0].body {
1442 LcnfExpr::Let { id, .. } => *id,
1443 _ => unreachable!(),
1444 };
1445 let new_alts: Vec<LcnfAlt> = alts
1446 .into_iter()
1447 .map(|a| {
1448 let inner_body = match a.body {
1449 LcnfExpr::Let { id, body, .. } => {
1450 if id != hoisted_id {
1451 let mut subst = Substitution::new();
1452 subst.insert(id, LcnfArg::Var(hoisted_id));
1453 substitute_expr(&body, &subst)
1454 } else {
1455 *body
1456 }
1457 }
1458 other => other,
1459 };
1460 LcnfAlt {
1461 ctor_name: a.ctor_name,
1462 ctor_tag: a.ctor_tag,
1463 params: a.params,
1464 body: inner_body,
1465 }
1466 })
1467 .collect();
1468 return LcnfExpr::Let {
1469 id: hoisted_id,
1470 name: common_name,
1471 ty: common_ty,
1472 value: common_value,
1473 body: Box::new(LcnfExpr::Case {
1474 scrutinee,
1475 scrutinee_ty,
1476 alts: new_alts,
1477 default: None,
1478 }),
1479 };
1480 }
1481 }
1482 LcnfExpr::Case {
1483 scrutinee,
1484 scrutinee_ty,
1485 alts,
1486 default,
1487 }
1488 }
1489 other => other,
1490 }
1491}