1use crate::lcnf::*;
6use std::collections::HashMap;
7
8use super::types::{
9 ConstantFoldReport, CopyProp, CopyPropConfig, DeadBindingElim, DeadBindingReport, InlineConfig,
10 InlineReport, InliningPass, InterferenceGraph, OptPipeline, PassKind, RegisterCoalescingHint,
11 UsedVars,
12};
13
14#[allow(dead_code)]
19pub fn collect_coalescing_hints(
20 copies: &[(LcnfVarId, LcnfVarId)],
21 ig: &InterferenceGraph,
22) -> Vec<RegisterCoalescingHint> {
23 let mut hints = Vec::new();
24 for &(src, dst) in copies {
25 let is_safe = !ig.interfere(src, dst);
26 let benefit = if is_safe { 10 } else { 1 };
27 hints.push(RegisterCoalescingHint::new(src, dst, is_safe, benefit));
28 }
29 hints.sort_by_key(|b| std::cmp::Reverse(b.benefit));
30 hints
31}
/// Unit tests for the copy-propagation pass (`CopyProp`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lcnf::{
        LcnfAlt, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfParam, LcnfType, LcnfVarId,
    };
    /// Wraps `body` in a parameterless, non-recursive `test_fn` declaration returning `Nat`.
    pub(super) fn make_decl(body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: "test_fn".to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }
    /// `let x1 = x0; return x1` collapses to `return x0`, counting one eliminated copy.
    #[test]
    pub(super) fn test_simple_fvar_copy() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(0)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(1)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig::default());
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0))));
        assert_eq!(pass.report().copies_eliminated, 1);
    }
    /// With `fold_literals: true`, a literal binding is propagated into its use:
    /// `let x0 = 42; return x0` becomes `return 42` (one copy eliminated).
    #[test]
    pub(super) fn test_literal_fold_enabled() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(42)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig {
            fold_literals: true,
            ..Default::default()
        });
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42))));
        assert_eq!(pass.report().copies_eliminated, 1);
    }
    /// With `fold_literals: false`, the literal binding is left in place and nothing is counted.
    #[test]
    pub(super) fn test_literal_fold_disabled() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(7)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig {
            fold_literals: false,
            ..Default::default()
        });
        pass.run(&mut decl);
        assert!(matches!(decl.body, LcnfExpr::Let { .. }));
        assert_eq!(pass.report().copies_eliminated, 0);
    }
    /// A two-link chain `x1 = x0; x2 = x1; return x2` resolves to `return x0`:
    /// two copies eliminated and one chain followed.
    #[test]
    pub(super) fn test_transitive_chain() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "b".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(0)),
            body: Box::new(LcnfExpr::Let {
                id: LcnfVarId(2),
                name: "a".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::FVar(LcnfVarId(1)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(2)))),
            }),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig::default());
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0))));
        assert_eq!(pass.report().copies_eliminated, 2);
        assert_eq!(pass.report().chains_followed, 1);
    }
    /// With `max_chain_depth: 1`, the same two-link chain is still expected to collapse
    /// fully to `return x0`.
    /// NOTE(review): the test name suggests the depth limit should stop propagation;
    /// confirm the intended semantics of `max_chain_depth`.
    #[test]
    pub(super) fn test_chain_depth_limit() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "b".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(0)),
            body: Box::new(LcnfExpr::Let {
                id: LcnfVarId(2),
                name: "a".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::FVar(LcnfVarId(1)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(2)))),
            }),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::new(CopyPropConfig {
            max_chain_depth: 1,
            fold_literals: true,
        });
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0))));
    }
    /// A binding whose value is an application is not a copy and must be kept.
    #[test]
    pub(super) fn test_app_not_propagated() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "r".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::App(LcnfArg::Var(LcnfVarId(0)), vec![]),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(1)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::default_pass();
        pass.run(&mut decl);
        assert!(matches!(decl.body, LcnfExpr::Let { .. }));
        assert_eq!(pass.report().copies_eliminated, 0);
    }
    /// Copies inside a case alternative are propagated: `q = p; return q` in the
    /// alternative becomes `return p` (the alt parameter), with at least one copy counted.
    #[test]
    pub(super) fn test_copy_in_case_branches() {
        let case_expr = LcnfExpr::Case {
            scrutinee: LcnfVarId(0),
            scrutinee_ty: LcnfType::Object,
            alts: vec![LcnfAlt {
                ctor_name: "A".to_string(),
                ctor_tag: 0,
                params: vec![LcnfParam {
                    id: LcnfVarId(1),
                    name: "p".to_string(),
                    ty: LcnfType::Nat,
                    erased: false,
                    borrowed: false,
                }],
                body: LcnfExpr::Let {
                    id: LcnfVarId(2),
                    name: "q".to_string(),
                    ty: LcnfType::Nat,
                    value: LcnfLetValue::FVar(LcnfVarId(1)),
                    body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(2)))),
                },
            }],
            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0))))),
        };
        let mut decl = make_decl(case_expr);
        let mut pass = CopyProp::default_pass();
        pass.run(&mut decl);
        match &decl.body {
            LcnfExpr::Case { alts, .. } => {
                assert_eq!(alts.len(), 1);
                assert_eq!(alts[0].body, LcnfExpr::Return(LcnfArg::Var(LcnfVarId(1))));
            }
            _ => panic!("Expected Case"),
        }
        assert!(pass.report().copies_eliminated >= 1);
    }
    /// An erased binding is propagated as `LcnfArg::Erased`: `let e = erased; return e`
    /// becomes `return erased`, counting one copy.
    #[test]
    pub(super) fn test_erased_copy_propagated() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "e".to_string(),
            ty: LcnfType::Erased,
            value: LcnfLetValue::Erased,
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut decl = make_decl(body);
        let mut pass = CopyProp::default_pass();
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Erased));
        assert_eq!(pass.report().copies_eliminated, 1);
    }
}
222#[allow(dead_code)]
223pub(super) fn collect_used(expr: &LcnfExpr, used: &mut UsedVars) {
224 match expr {
225 LcnfExpr::Return(arg) => collect_used_arg(arg, used),
226 LcnfExpr::Let { value, body, .. } => {
227 collect_used_value(value, used);
228 collect_used(body, used);
229 }
230 LcnfExpr::Case {
231 scrutinee,
232 scrutinee_ty: _,
233 alts,
234 default,
235 ..
236 } => {
237 used.vars.insert(*scrutinee);
238 for alt in alts {
239 collect_used(&alt.body, used);
240 }
241 if let Some(d) = default {
242 collect_used(d, used);
243 }
244 }
245 LcnfExpr::TailCall(callee, args) => {
246 collect_used_arg(callee, used);
247 for a in args {
248 collect_used_arg(a, used);
249 }
250 }
251 LcnfExpr::Unreachable => {}
252 }
253}
254#[allow(dead_code)]
255pub(super) fn collect_used_arg(arg: &LcnfArg, used: &mut UsedVars) {
256 if let LcnfArg::Var(id) = arg {
257 used.vars.insert(*id);
258 }
259}
260#[allow(dead_code)]
261pub(super) fn collect_used_value(val: &LcnfLetValue, used: &mut UsedVars) {
262 match val {
263 LcnfLetValue::FVar(id) => {
264 used.vars.insert(*id);
265 }
266 LcnfLetValue::App(callee, args) => {
267 collect_used_arg(callee, used);
268 for a in args {
269 collect_used_arg(a, used);
270 }
271 }
272 LcnfLetValue::Ctor(_, _, args) => {
273 for a in args {
274 collect_used_arg(a, used);
275 }
276 }
277 LcnfLetValue::Proj(_, _, id) => {
278 used.vars.insert(*id);
279 }
280 _ => {}
281 }
282}
/// Default inlining cost threshold. In `inline_expr_walk`, a callee is an inlining
/// candidate when its `inline_cost` does not exceed the configured threshold.
/// NOTE(review): this constant is not referenced in this file; presumably it seeds
/// `InlineConfig::threshold` — confirm in `types`.
#[allow(dead_code)]
pub const DEFAULT_INLINE_THRESHOLD: u32 = 5;
286#[allow(dead_code)]
288pub fn count_let_bindings(expr: &LcnfExpr) -> usize {
289 match expr {
290 LcnfExpr::Let { body, .. } => 1 + count_let_bindings(body),
291 LcnfExpr::Case { alts, default, .. } => {
292 let alt_sum: usize = alts.iter().map(|a| count_let_bindings(&a.body)).sum();
293 let def_sum = default.as_ref().map_or(0, |d| count_let_bindings(d));
294 alt_sum + def_sum
295 }
296 _ => 0,
297 }
298}
299#[allow(dead_code)]
301pub fn expr_depth(expr: &LcnfExpr) -> usize {
302 match expr {
303 LcnfExpr::Let { body, .. } => 1 + expr_depth(body),
304 LcnfExpr::Case { alts, default, .. } => {
305 let alt_max = alts.iter().map(|a| expr_depth(&a.body)).max().unwrap_or(0);
306 let def_max = default.as_ref().map_or(0, |d| expr_depth(d));
307 1 + alt_max.max(def_max)
308 }
309 LcnfExpr::TailCall(_, args) => args.len(),
310 _ => 0,
311 }
312}
313#[allow(dead_code)]
315pub fn has_tail_call_to(expr: &LcnfExpr, target: LcnfVarId) -> bool {
316 match expr {
317 LcnfExpr::TailCall(LcnfArg::Var(id), _) => *id == target,
318 LcnfExpr::Let { body, .. } => has_tail_call_to(body, target),
319 LcnfExpr::Case { alts, default, .. } => {
320 alts.iter().any(|a| has_tail_call_to(&a.body, target))
321 || default
322 .as_ref()
323 .is_some_and(|d| has_tail_call_to(d, target))
324 }
325 _ => false,
326 }
327}
328#[allow(dead_code)]
330pub fn collect_bound_vars(expr: &LcnfExpr, out: &mut Vec<LcnfVarId>) {
331 match expr {
332 LcnfExpr::Let { id, body, .. } => {
333 out.push(*id);
334 collect_bound_vars(body, out);
335 }
336 LcnfExpr::Case { alts, default, .. } => {
337 for alt in alts {
338 collect_bound_vars(&alt.body, out);
339 }
340 if let Some(d) = default {
341 collect_bound_vars(d, out);
342 }
343 }
344 _ => {}
345 }
346}
347
348pub(super) fn max_var_id_in_expr(expr: &LcnfExpr) -> u64 {
350 match expr {
351 LcnfExpr::Let {
352 id, value, body, ..
353 } => {
354 let mut m = id.0;
355 m = m.max(max_var_id_in_let_value(value));
356 m = m.max(max_var_id_in_expr(body));
357 m
358 }
359 LcnfExpr::Case {
360 scrutinee,
361 alts,
362 default,
363 ..
364 } => {
365 let mut m = scrutinee.0;
366 for alt in alts {
367 for p in &alt.params {
368 m = m.max(p.id.0);
369 }
370 m = m.max(max_var_id_in_expr(&alt.body));
371 }
372 if let Some(d) = default {
373 m = m.max(max_var_id_in_expr(d));
374 }
375 m
376 }
377 LcnfExpr::Return(arg) => max_var_id_in_arg(arg),
378 LcnfExpr::TailCall(func, args) => {
379 let mut m = max_var_id_in_arg(func);
380 for a in args {
381 m = m.max(max_var_id_in_arg(a));
382 }
383 m
384 }
385 LcnfExpr::Unreachable => 0,
386 }
387}
388
389pub(super) fn max_var_id_in_arg(arg: &LcnfArg) -> u64 {
390 if let LcnfArg::Var(id) = arg {
391 id.0
392 } else {
393 0
394 }
395}
396
397pub(super) fn max_var_id_in_let_value(val: &LcnfLetValue) -> u64 {
398 match val {
399 LcnfLetValue::App(func, args) => {
400 let mut m = max_var_id_in_arg(func);
401 for a in args {
402 m = m.max(max_var_id_in_arg(a));
403 }
404 m
405 }
406 LcnfLetValue::Ctor(_, _, args) | LcnfLetValue::Reuse(_, _, _, args) => {
407 args.iter().map(max_var_id_in_arg).max().unwrap_or(0)
408 }
409 LcnfLetValue::Proj(_, _, id) | LcnfLetValue::Reset(id) | LcnfLetValue::FVar(id) => id.0,
410 LcnfLetValue::Lit(_) | LcnfLetValue::Erased => 0,
411 }
412}
413
414pub(super) fn offset_var_ids(expr: LcnfExpr, offset: u64) -> LcnfExpr {
416 match expr {
417 LcnfExpr::Let {
418 id,
419 name,
420 ty,
421 value,
422 body,
423 } => LcnfExpr::Let {
424 id: LcnfVarId(id.0 + offset),
425 name,
426 ty,
427 value: offset_var_ids_in_let_value(value, offset),
428 body: Box::new(offset_var_ids(*body, offset)),
429 },
430 LcnfExpr::Case {
431 scrutinee,
432 scrutinee_ty,
433 alts,
434 default,
435 } => LcnfExpr::Case {
436 scrutinee: LcnfVarId(scrutinee.0 + offset),
437 scrutinee_ty,
438 alts: alts
439 .into_iter()
440 .map(|alt| LcnfAlt {
441 ctor_name: alt.ctor_name,
442 ctor_tag: alt.ctor_tag,
443 params: alt
444 .params
445 .into_iter()
446 .map(|p| LcnfParam {
447 id: LcnfVarId(p.id.0 + offset),
448 ..p
449 })
450 .collect(),
451 body: offset_var_ids(alt.body, offset),
452 })
453 .collect(),
454 default: default.map(|d| Box::new(offset_var_ids(*d, offset))),
455 },
456 LcnfExpr::Return(arg) => LcnfExpr::Return(offset_var_ids_in_arg(arg, offset)),
457 LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(
458 offset_var_ids_in_arg(func, offset),
459 args.into_iter()
460 .map(|a| offset_var_ids_in_arg(a, offset))
461 .collect(),
462 ),
463 LcnfExpr::Unreachable => LcnfExpr::Unreachable,
464 }
465}
466
467pub(super) fn offset_var_ids_in_arg(arg: LcnfArg, offset: u64) -> LcnfArg {
468 match arg {
469 LcnfArg::Var(id) => LcnfArg::Var(LcnfVarId(id.0 + offset)),
470 other => other,
471 }
472}
473
474pub(super) fn offset_var_ids_in_let_value(val: LcnfLetValue, offset: u64) -> LcnfLetValue {
475 match val {
476 LcnfLetValue::App(func, args) => LcnfLetValue::App(
477 offset_var_ids_in_arg(func, offset),
478 args.into_iter()
479 .map(|a| offset_var_ids_in_arg(a, offset))
480 .collect(),
481 ),
482 LcnfLetValue::Ctor(name, tag, args) => LcnfLetValue::Ctor(
483 name,
484 tag,
485 args.into_iter()
486 .map(|a| offset_var_ids_in_arg(a, offset))
487 .collect(),
488 ),
489 LcnfLetValue::Reuse(slot, name, tag, args) => LcnfLetValue::Reuse(
490 LcnfVarId(slot.0 + offset),
491 name,
492 tag,
493 args.into_iter()
494 .map(|a| offset_var_ids_in_arg(a, offset))
495 .collect(),
496 ),
497 LcnfLetValue::Proj(name, idx, var) => {
498 LcnfLetValue::Proj(name, idx, LcnfVarId(var.0 + offset))
499 }
500 LcnfLetValue::Reset(var) => LcnfLetValue::Reset(LcnfVarId(var.0 + offset)),
501 LcnfLetValue::FVar(id) => LcnfLetValue::FVar(LcnfVarId(id.0 + offset)),
502 other @ (LcnfLetValue::Lit(_) | LcnfLetValue::Erased) => other,
503 }
504}
505
506pub(super) fn inline_substitute_params(
511 body: &LcnfExpr,
512 params: &[LcnfParam],
513 args: &[LcnfArg],
514) -> LcnfExpr {
515 let subst: HashMap<LcnfVarId, LcnfArg> = params
516 .iter()
517 .zip(args.iter())
518 .map(|(p, a)| (p.id, a.clone()))
519 .collect();
520 inline_subst_expr(body, &subst)
521}
522
523pub(super) fn inline_subst_expr(expr: &LcnfExpr, subst: &HashMap<LcnfVarId, LcnfArg>) -> LcnfExpr {
524 match expr {
525 LcnfExpr::Let {
526 id,
527 name,
528 ty,
529 value,
530 body,
531 } => LcnfExpr::Let {
532 id: *id,
533 name: name.clone(),
534 ty: ty.clone(),
535 value: inline_subst_let_value(value, subst),
536 body: Box::new(inline_subst_expr(body, subst)),
537 },
538 LcnfExpr::Case {
539 scrutinee,
540 scrutinee_ty,
541 alts,
542 default,
543 } => {
544 let new_scrutinee = match subst.get(scrutinee) {
545 Some(LcnfArg::Var(v)) => *v,
546 _ => *scrutinee,
547 };
548 LcnfExpr::Case {
549 scrutinee: new_scrutinee,
550 scrutinee_ty: scrutinee_ty.clone(),
551 alts: alts
552 .iter()
553 .map(|alt| LcnfAlt {
554 ctor_name: alt.ctor_name.clone(),
555 ctor_tag: alt.ctor_tag,
556 params: alt.params.clone(),
557 body: inline_subst_expr(&alt.body, subst),
558 })
559 .collect(),
560 default: default
561 .as_ref()
562 .map(|d| Box::new(inline_subst_expr(d, subst))),
563 }
564 }
565 LcnfExpr::Return(arg) => LcnfExpr::Return(inline_subst_arg(arg, subst)),
566 LcnfExpr::TailCall(func, args) => LcnfExpr::TailCall(
567 inline_subst_arg(func, subst),
568 args.iter().map(|a| inline_subst_arg(a, subst)).collect(),
569 ),
570 LcnfExpr::Unreachable => LcnfExpr::Unreachable,
571 }
572}
573
574pub(super) fn inline_subst_let_value(
575 val: &LcnfLetValue,
576 subst: &HashMap<LcnfVarId, LcnfArg>,
577) -> LcnfLetValue {
578 match val {
579 LcnfLetValue::App(func, args) => LcnfLetValue::App(
580 inline_subst_arg(func, subst),
581 args.iter().map(|a| inline_subst_arg(a, subst)).collect(),
582 ),
583 LcnfLetValue::Ctor(name, tag, args) => LcnfLetValue::Ctor(
584 name.clone(),
585 *tag,
586 args.iter().map(|a| inline_subst_arg(a, subst)).collect(),
587 ),
588 LcnfLetValue::Reuse(slot, name, tag, args) => LcnfLetValue::Reuse(
589 *slot,
590 name.clone(),
591 *tag,
592 args.iter().map(|a| inline_subst_arg(a, subst)).collect(),
593 ),
594 LcnfLetValue::Proj(name, idx, var) => LcnfLetValue::Proj(name.clone(), *idx, *var),
595 LcnfLetValue::Reset(var) => LcnfLetValue::Reset(*var),
596 LcnfLetValue::Lit(lit) => LcnfLetValue::Lit(lit.clone()),
597 LcnfLetValue::Erased => LcnfLetValue::Erased,
598 LcnfLetValue::FVar(id) => LcnfLetValue::FVar(*id),
599 }
600}
601
602pub(super) fn inline_subst_arg(arg: &LcnfArg, subst: &HashMap<LcnfVarId, LcnfArg>) -> LcnfArg {
603 match arg {
604 LcnfArg::Var(id) => subst.get(id).cloned().unwrap_or(LcnfArg::Var(*id)),
605 LcnfArg::Lit(lit) => LcnfArg::Lit(lit.clone()),
606 LcnfArg::Erased => LcnfArg::Erased,
607 LcnfArg::Type(ty) => LcnfArg::Type(ty.clone()),
608 }
609}
610
611pub(super) fn splice_inline_result(
618 inlined: LcnfExpr,
619 outer_id: LcnfVarId,
620 outer_name: &str,
621 outer_ty: &LcnfType,
622 continuation: LcnfExpr,
623) -> LcnfExpr {
624 splice_inline_result_inner(inlined, outer_id, outer_name, outer_ty, &continuation)
625}
626
627fn splice_inline_result_inner(
628 inlined: LcnfExpr,
629 outer_id: LcnfVarId,
630 outer_name: &str,
631 outer_ty: &LcnfType,
632 continuation: &LcnfExpr,
633) -> LcnfExpr {
634 match inlined {
635 LcnfExpr::Return(val) => {
636 let let_val = match val {
638 LcnfArg::Var(id) => LcnfLetValue::FVar(id),
639 LcnfArg::Lit(lit) => LcnfLetValue::Lit(lit),
640 LcnfArg::Erased => LcnfLetValue::Erased,
641 LcnfArg::Type(_) => LcnfLetValue::Erased,
642 };
643 LcnfExpr::Let {
644 id: outer_id,
645 name: outer_name.to_string(),
646 ty: outer_ty.clone(),
647 value: let_val,
648 body: Box::new(continuation.clone()),
649 }
650 }
651 LcnfExpr::TailCall(_, _) => inlined,
652 LcnfExpr::Unreachable => LcnfExpr::Unreachable,
653 LcnfExpr::Let {
654 id,
655 name,
656 ty,
657 value,
658 body,
659 } => LcnfExpr::Let {
660 id,
661 name,
662 ty,
663 value,
664 body: Box::new(splice_inline_result_inner(
665 *body,
666 outer_id,
667 outer_name,
668 outer_ty,
669 continuation,
670 )),
671 },
672 LcnfExpr::Case {
673 scrutinee,
674 scrutinee_ty,
675 alts,
676 default,
677 } => LcnfExpr::Case {
678 scrutinee,
679 scrutinee_ty,
680 alts: alts
681 .into_iter()
682 .map(|alt| LcnfAlt {
683 ctor_name: alt.ctor_name,
684 ctor_tag: alt.ctor_tag,
685 params: alt.params,
686 body: splice_inline_result_inner(
687 alt.body,
688 outer_id,
689 outer_name,
690 outer_ty,
691 continuation,
692 ),
693 })
694 .collect(),
695 default: default.map(|d| {
696 Box::new(splice_inline_result_inner(
697 *d,
698 outer_id,
699 outer_name,
700 outer_ty,
701 continuation,
702 ))
703 }),
704 },
705 }
706}
707
708pub(super) fn inline_expr_walk(
715 expr: LcnfExpr,
716 fn_map: &HashMap<String, LcnfFunDecl>,
717 config: &InlineConfig,
718 caller_max_id: u64,
719 id_counter: &mut u64,
720 inlines_performed: &mut usize,
721) -> LcnfExpr {
722 match expr {
723 LcnfExpr::Let {
724 id,
725 name,
726 ty,
727 value,
728 body,
729 } => {
730 if let LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str(ref callee_name)), ref args) = value
732 {
733 if let Some(callee_decl) = fn_map.get(callee_name) {
734 let should_inline = {
735 if callee_decl.is_recursive && !config.inline_recursive {
736 false
737 } else {
738 callee_decl.inline_cost <= config.threshold as usize
739 }
740 };
741 if should_inline && callee_decl.params.len() == args.len() {
742 let callee_max = max_var_id_in_expr(&callee_decl.body);
744 let offset = caller_max_id + *id_counter * (callee_max + 1) + 1;
745 *id_counter += 1;
746
747 let fresh_body = offset_var_ids(callee_decl.body.clone(), offset);
749 let fresh_params: Vec<LcnfParam> = callee_decl
750 .params
751 .iter()
752 .map(|p| LcnfParam {
753 id: LcnfVarId(p.id.0 + offset),
754 ..p.clone()
755 })
756 .collect();
757
758 let substituted =
760 inline_substitute_params(&fresh_body, &fresh_params, args);
761
762 let new_body = inline_expr_walk(
764 *body,
765 fn_map,
766 config,
767 caller_max_id,
768 id_counter,
769 inlines_performed,
770 );
771
772 let spliced = splice_inline_result(substituted, id, &name, &ty, new_body);
774
775 *inlines_performed += 1;
776 return spliced;
777 }
778 }
779 }
780
781 let new_body = inline_expr_walk(
784 *body,
785 fn_map,
786 config,
787 caller_max_id,
788 id_counter,
789 inlines_performed,
790 );
791 LcnfExpr::Let {
792 id,
793 name,
794 ty,
795 value,
796 body: Box::new(new_body),
797 }
798 }
799 LcnfExpr::Case {
800 scrutinee,
801 scrutinee_ty,
802 alts,
803 default,
804 } => {
805 let new_alts = alts
806 .into_iter()
807 .map(|alt| LcnfAlt {
808 ctor_name: alt.ctor_name,
809 ctor_tag: alt.ctor_tag,
810 params: alt.params,
811 body: inline_expr_walk(
812 alt.body,
813 fn_map,
814 config,
815 caller_max_id,
816 id_counter,
817 inlines_performed,
818 ),
819 })
820 .collect();
821 let new_default = default.map(|d| {
822 Box::new(inline_expr_walk(
823 *d,
824 fn_map,
825 config,
826 caller_max_id,
827 id_counter,
828 inlines_performed,
829 ))
830 });
831 LcnfExpr::Case {
832 scrutinee,
833 scrutinee_ty,
834 alts: new_alts,
835 default: new_default,
836 }
837 }
838 terminal @ (LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable) => {
840 terminal
841 }
842 }
843}
844
/// Tests for the IR utilities in this module (`count_let_bindings`, `expr_depth`,
/// `has_tail_call_to`, `collect_bound_vars`) and for the pass/report types.
#[cfg(test)]
mod tests_extended {
    use super::*;

    /// Builds a `LcnfVarId` from a small integer.
    pub(super) fn make_var(n: u32) -> LcnfVarId {
        LcnfVarId(u64::from(n))
    }

    /// Wraps `body` in a parameterless, non-recursive `test_fn` declaration returning `Nat`.
    pub(super) fn make_simple_decl(body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: "test_fn".to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }

    /// An unused `let` is dropped, and the removal is reported.
    #[test]
    pub(super) fn test_dead_binding_removal() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(99),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(42)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
        };
        let mut decl = make_simple_decl(body);
        let mut pass = DeadBindingElim::default_pass();
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))));
        // The binding is gone, so at least one removal must have been counted.
        // The previous `>= 0` check could never fail for an unsigned count.
        assert!(pass.report().bindings_removed >= 1);
    }

    /// Two nested lets count as two bindings.
    #[test]
    pub(super) fn test_count_let_bindings() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "a".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(1)),
            body: Box::new(LcnfExpr::Let {
                id: LcnfVarId(1),
                name: "b".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::Lit(LcnfLit::Nat(2)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
            }),
        };
        assert_eq!(count_let_bindings(&body), 2);
    }

    /// A single let over a return has depth 1.
    #[test]
    pub(super) fn test_expr_depth() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "a".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(0)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
        };
        assert_eq!(expr_depth(&body), 1);
    }

    /// A tail call to `v7` matches target `v7` but not `v8`.
    #[test]
    pub(super) fn test_has_tail_call_to() {
        let target = make_var(7);
        let body = LcnfExpr::TailCall(LcnfArg::Var(target), vec![]);
        assert!(has_tail_call_to(&body, target));
        assert!(!has_tail_call_to(&body, make_var(8)));
    }

    /// The single let binder is collected.
    #[test]
    pub(super) fn test_collect_bound_vars() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(5),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(0)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(5)))),
        };
        let mut bound = vec![];
        collect_bound_vars(&body, &mut bound);
        assert_eq!(bound, vec![LcnfVarId(5)]);
    }

    /// The default pipeline propagates a copy of a function parameter.
    #[test]
    pub(super) fn test_opt_pipeline_default() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(1)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut decl = make_simple_decl(body);
        decl.params.push(LcnfParam {
            id: LcnfVarId(1),
            name: "n".to_string(),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        });
        let mut pipeline = OptPipeline::new();
        let result = pipeline.run(&mut decl);
        assert!(result.copy_prop.copies_eliminated >= 1);
    }

    /// `PassKind` displays as its variant name.
    #[test]
    pub(super) fn test_pass_kind_display() {
        assert_eq!(PassKind::CopyProp.to_string(), "CopyProp");
        assert_eq!(PassKind::DeadBinding.to_string(), "DeadBinding");
        assert_eq!(PassKind::ConstantFold.to_string(), "ConstantFold");
        assert_eq!(PassKind::Inlining.to_string(), "Inlining");
    }

    /// The default inlining pass accepts a cost-1 callee and rejects a cost-100 one.
    #[test]
    pub(super) fn test_inline_candidate() {
        let pass = InliningPass::default_pass();
        let cheap = LcnfFunDecl {
            name: "cheap".to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        };
        let expensive = LcnfFunDecl {
            inline_cost: 100,
            name: "expensive".to_string(),
            ..cheap.clone()
        };
        assert!(pass.is_inline_candidate(&cheap));
        assert!(!pass.is_inline_candidate(&expensive));
    }

    /// `CopyPropConfig`'s `Display` output mentions the type name.
    /// (Renamed from `test_dead_binding_config_display`, which did not match what it tests.)
    #[test]
    pub(super) fn test_copy_prop_config_display() {
        let cfg = CopyPropConfig::default();
        let s = format!("{}", cfg);
        assert!(s.contains("CopyPropConfig"));
    }

    /// `DeadBindingReport`'s `Display` output includes both counters.
    #[test]
    pub(super) fn test_dead_binding_report_display() {
        let r = DeadBindingReport {
            bindings_removed: 3,
            passes_run: 2,
        };
        let s = format!("{}", r);
        assert!(s.contains("removed=3"));
        assert!(s.contains("passes=2"));
    }

    /// `ConstantFoldReport`'s `Display` output includes the fold count.
    #[test]
    pub(super) fn test_constant_fold_report_display() {
        let r = ConstantFoldReport { folds_performed: 7 };
        let s = format!("{}", r);
        assert!(s.contains("folds=7"));
    }

    /// `InlineReport`'s `Display` output includes both counters.
    #[test]
    pub(super) fn test_inline_report_display() {
        let r = InlineReport {
            inlines_performed: 2,
            functions_considered: 10,
        };
        let s = format!("{}", r);
        assert!(s.contains("inlined=2"));
        assert!(s.contains("considered=10"));
    }
}
1008
/// End-to-end tests for `InliningPass`. They cover direct calls through a
/// string-literal callee, cost and recursion gating, fixpoint inlining over a list
/// of declarations, fresh variable ids, and parameter substitution.
#[cfg(test)]
mod tests_inlining_pass {
    // `InlineReport` was imported here but never used, which triggered an
    // `unused_imports` warning.
    use super::super::types::{FnMap, InlineConfig, InliningPass};
    use super::*;
    use crate::lcnf::{
        LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfParam, LcnfType, LcnfVarId,
    };

    /// Builds a `Nat`-returning declaration with the given name, params, body, cost
    /// and recursion flag.
    fn make_decl(
        name: &str,
        params: Vec<LcnfParam>,
        body: LcnfExpr,
        inline_cost: usize,
        is_recursive: bool,
    ) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params,
            ret_type: LcnfType::Nat,
            body,
            is_recursive,
            is_lifted: false,
            inline_cost,
        }
    }

    /// Builds a non-erased, non-borrowed `Nat` parameter.
    fn make_param(id: u64, name: &str) -> LcnfParam {
        LcnfParam {
            id: LcnfVarId(id),
            name: name.to_string(),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        }
    }

    /// Inlining a zero-argument callee `return 42` turns `let x = trivial()` into
    /// `let x = 42`, keeping the caller's binder id and continuation.
    #[test]
    fn inline_pass_simple_call() {
        let callee = make_decl(
            "trivial",
            vec![],
            LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42))),
            1,
            false,
        );

        let caller_body = LcnfExpr::Let {
            id: LcnfVarId(100),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str("trivial".to_string())), vec![]),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(100)))),
        };
        let mut caller = make_decl("caller", vec![], caller_body, 5, false);

        let mut fn_map: FnMap = FnMap::new();
        fn_map.insert("trivial".to_string(), callee);

        let config = InlineConfig {
            threshold: 5,
            inline_recursive: false,
        };
        let mut pass = InliningPass::new(config);
        pass.run_with_context(&mut caller, &fn_map);

        assert_eq!(pass.report().inlines_performed, 1);
        assert_eq!(pass.report().functions_considered, 1);

        let expected = LcnfExpr::Let {
            id: LcnfVarId(100),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(42)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(100)))),
        };
        assert_eq!(caller.body, expected);
    }

    /// A callee whose cost (100) exceeds the threshold (5) is left as a call.
    #[test]
    fn inline_pass_above_threshold_not_inlined() {
        let callee = make_decl(
            "big_fn",
            vec![],
            LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
            100,
            false,
        );

        let caller_body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "r".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str("big_fn".to_string())), vec![]),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut caller = make_decl("caller", vec![], caller_body, 5, false);

        let mut fn_map: FnMap = FnMap::new();
        fn_map.insert("big_fn".to_string(), callee);

        let config = InlineConfig {
            threshold: 5,
            inline_recursive: false,
        };
        let mut pass = InliningPass::new(config);
        pass.run_with_context(&mut caller, &fn_map);

        assert_eq!(pass.report().inlines_performed, 0);
        assert!(matches!(caller.body, LcnfExpr::Let { .. }));
    }

    /// A recursive callee is not inlined while `inline_recursive` is false.
    #[test]
    fn inline_pass_recursive_not_inlined_when_disabled() {
        let rec_callee = make_decl(
            "rec_fn",
            vec![],
            LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))),
            1,
            true,
        );

        let caller_body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "r".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str("rec_fn".to_string())), vec![]),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut caller = make_decl("caller", vec![], caller_body, 5, false);

        let mut fn_map: FnMap = FnMap::new();
        fn_map.insert("rec_fn".to_string(), rec_callee);

        let config = InlineConfig {
            threshold: 5,
            inline_recursive: false,
        };
        let mut pass = InliningPass::new(config);
        pass.run_with_context(&mut caller, &fn_map);

        assert_eq!(pass.report().inlines_performed, 0);
    }

    /// `run_all` over `[f, g]`, where `g` calls `f`, inlines exactly once and rewrites
    /// `g`'s body to bind `f`'s returned literal.
    #[test]
    fn inline_pass_run_all_fixpoint() {
        let f = make_decl(
            "f",
            vec![],
            LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(7))),
            1,
            false,
        );
        let g_body = LcnfExpr::Let {
            id: LcnfVarId(50),
            name: "r".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::App(LcnfArg::Lit(LcnfLit::Str("f".to_string())), vec![]),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(50)))),
        };
        let g = make_decl("g", vec![], g_body, 5, false);

        let mut decls = vec![f, g];
        let config = InlineConfig {
            threshold: 5,
            inline_recursive: false,
        };
        let mut pass = InliningPass::new(config);
        pass.run_all(&mut decls);

        assert_eq!(pass.report().inlines_performed, 1);

        let g_decl = decls.iter().find(|d| d.name == "g").expect("g not found");
        let expected_g_body = LcnfExpr::Let {
            id: LcnfVarId(50),
            name: "r".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(7)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(50)))),
        };
        assert_eq!(g_decl.body, expected_g_body);
    }

    /// Caller and callee both bind id 0. After inlining, the caller's body must
    /// contain two let binders with distinct ids.
    #[test]
    fn inline_pass_freshen_var_ids_no_collision() {
        let param_p1 = make_param(1, "p1");
        let callee_body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "t".to_string(),
            ty: LcnfType::Object,
            value: LcnfLetValue::Ctor("Pair".to_string(), 0, vec![LcnfArg::Var(LcnfVarId(1))]),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let callee = make_decl("wrap", vec![param_p1.clone()], callee_body, 2, false);

        let caller_body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "x".to_string(),
            ty: LcnfType::Object,
            value: LcnfLetValue::App(
                LcnfArg::Lit(LcnfLit::Str("wrap".to_string())),
                vec![LcnfArg::Lit(LcnfLit::Nat(5))],
            ),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut caller = make_decl("caller", vec![], caller_body, 5, false);

        let mut fn_map: FnMap = FnMap::new();
        fn_map.insert("wrap".to_string(), callee);

        let config = InlineConfig {
            threshold: 5,
            inline_recursive: false,
        };
        let mut pass = InliningPass::new(config);
        pass.run_with_context(&mut caller, &fn_map);

        assert_eq!(pass.report().inlines_performed, 1, "expected one inline");

        let mut bound = Vec::new();
        collect_bound_vars(&caller.body, &mut bound);

        assert_eq!(
            bound.len(),
            2,
            "expected 2 bound vars after inlining, got {:?}",
            bound
        );

        let ids: Vec<u64> = bound.iter().map(|v| v.0).collect();
        let unique_count = {
            let mut seen = std::collections::HashSet::new();
            ids.iter().filter(|&&id| seen.insert(id)).count()
        };
        assert_eq!(
            unique_count,
            bound.len(),
            "collision: bound var ids are not unique: {:?}",
            ids
        );
    }

    /// Inlining `identity(99)`, where `identity(p) = return p`, yields `let x = 99`.
    #[test]
    fn inline_pass_with_param_substitution() {
        let param = make_param(10, "p0");
        let callee = make_decl(
            "identity",
            vec![param.clone()],
            LcnfExpr::Return(LcnfArg::Var(param.id)),
            1,
            false,
        );

        let caller_body = LcnfExpr::Let {
            id: LcnfVarId(200),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::App(
                LcnfArg::Lit(LcnfLit::Str("identity".to_string())),
                vec![LcnfArg::Lit(LcnfLit::Nat(99))],
            ),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(200)))),
        };
        let mut caller = make_decl("caller", vec![], caller_body, 5, false);

        let mut fn_map: FnMap = FnMap::new();
        fn_map.insert("identity".to_string(), callee);

        let config = InlineConfig {
            threshold: 5,
            inline_recursive: false,
        };
        let mut pass = InliningPass::new(config);
        pass.run_with_context(&mut caller, &fn_map);

        assert_eq!(pass.report().inlines_performed, 1);

        let expected = LcnfExpr::Let {
            id: LcnfVarId(200),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(99)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(200)))),
        };
        assert_eq!(caller.body, expected);
    }
}