1use crate::lcnf::*;
6
7use super::types::{
8 ConstantFoldReport, CopyProp, CopyPropConfig, DeadBindingElim, DeadBindingReport, InlineReport,
9 InliningPass, InterferenceGraph, OptPipeline, PassKind, RegisterCoalescingHint, UsedVars,
10};
11
12#[allow(dead_code)]
17pub fn collect_coalescing_hints(
18 copies: &[(LcnfVarId, LcnfVarId)],
19 ig: &InterferenceGraph,
20) -> Vec<RegisterCoalescingHint> {
21 let mut hints = Vec::new();
22 for &(src, dst) in copies {
23 let is_safe = !ig.interfere(src, dst);
24 let benefit = if is_safe { 10 } else { 1 };
25 hints.push(RegisterCoalescingHint::new(src, dst, is_safe, benefit));
26 }
27 hints.sort_by(|a, b| b.benefit.cmp(&a.benefit));
28 hints
29}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lcnf::{
        LcnfAlt, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfParam, LcnfType, LcnfVarId,
    };

    /// Wraps `body` in a parameterless, non-recursive declaration named
    /// `test_fn` that returns `Nat`.
    pub(super) fn make_decl(body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: "test_fn".to_string(),
            original_name: None,
            params: Vec::new(),
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }

    /// Builds the expression `return v<n>`.
    fn ret_var(n: u64) -> LcnfExpr {
        LcnfExpr::Return(LcnfArg::Var(LcnfVarId(n)))
    }

    /// Builds `let v<id> : ty := value; rest`.
    fn bind(id: u64, name: &str, ty: LcnfType, value: LcnfLetValue, rest: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: LcnfVarId(id),
            name: name.to_string(),
            ty,
            value,
            body: Box::new(rest),
        }
    }

    /// Runs `pass` over a fresh declaration wrapping `body` and hands back
    /// the body as the pass left it.
    fn rewrite(pass: &mut CopyProp, body: LcnfExpr) -> LcnfExpr {
        let mut decl = make_decl(body);
        pass.run(&mut decl);
        decl.body
    }

    /// `let x = v0; return x` collapses to `return v0`.
    #[test]
    pub(super) fn test_simple_fvar_copy() {
        let mut pass = CopyProp::new(CopyPropConfig::default());
        let input = bind(
            1,
            "x",
            LcnfType::Nat,
            LcnfLetValue::FVar(LcnfVarId(0)),
            ret_var(1),
        );
        let out = rewrite(&mut pass, input);
        assert_eq!(out, ret_var(0));
        assert_eq!(pass.report().copies_eliminated, 1);
    }

    /// With literal folding on, a literal binding is substituted into its use.
    #[test]
    pub(super) fn test_literal_fold_enabled() {
        let mut pass = CopyProp::new(CopyPropConfig {
            fold_literals: true,
            ..Default::default()
        });
        let input = bind(
            0,
            "x",
            LcnfType::Nat,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            ret_var(0),
        );
        let out = rewrite(&mut pass, input);
        assert_eq!(out, LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(42))));
        assert_eq!(pass.report().copies_eliminated, 1);
    }

    /// With literal folding off, a literal binding is left in place.
    #[test]
    pub(super) fn test_literal_fold_disabled() {
        let mut pass = CopyProp::new(CopyPropConfig {
            fold_literals: false,
            ..Default::default()
        });
        let input = bind(
            0,
            "x",
            LcnfType::Nat,
            LcnfLetValue::Lit(LcnfLit::Nat(7)),
            ret_var(0),
        );
        let out = rewrite(&mut pass, input);
        assert!(matches!(out, LcnfExpr::Let { .. }));
        assert_eq!(pass.report().copies_eliminated, 0);
    }

    /// A two-link copy chain (`b = v0; a = b; return a`) resolves to `v0`.
    #[test]
    pub(super) fn test_transitive_chain() {
        let mut pass = CopyProp::new(CopyPropConfig::default());
        let inner = bind(
            2,
            "a",
            LcnfType::Nat,
            LcnfLetValue::FVar(LcnfVarId(1)),
            ret_var(2),
        );
        let input = bind(
            1,
            "b",
            LcnfType::Nat,
            LcnfLetValue::FVar(LcnfVarId(0)),
            inner,
        );
        let out = rewrite(&mut pass, input);
        assert_eq!(out, ret_var(0));
        assert_eq!(pass.report().copies_eliminated, 2);
        assert_eq!(pass.report().chains_followed, 1);
    }

    /// The same two-link chain still resolves to `v0` when
    /// `max_chain_depth` is 1.
    #[test]
    pub(super) fn test_chain_depth_limit() {
        let mut pass = CopyProp::new(CopyPropConfig {
            max_chain_depth: 1,
            fold_literals: true,
        });
        let inner = bind(
            2,
            "a",
            LcnfType::Nat,
            LcnfLetValue::FVar(LcnfVarId(1)),
            ret_var(2),
        );
        let input = bind(
            1,
            "b",
            LcnfType::Nat,
            LcnfLetValue::FVar(LcnfVarId(0)),
            inner,
        );
        let out = rewrite(&mut pass, input);
        assert_eq!(out, ret_var(0));
    }

    /// A binding whose value is an application is not treated as a copy.
    #[test]
    pub(super) fn test_app_not_propagated() {
        let mut pass = CopyProp::default_pass();
        let input = bind(
            1,
            "r",
            LcnfType::Nat,
            LcnfLetValue::App(LcnfArg::Var(LcnfVarId(0)), vec![]),
            ret_var(1),
        );
        let out = rewrite(&mut pass, input);
        assert!(matches!(out, LcnfExpr::Let { .. }));
        assert_eq!(pass.report().copies_eliminated, 0);
    }

    /// Copies inside a case alternative are propagated within that branch.
    #[test]
    pub(super) fn test_copy_in_case_branches() {
        let field = LcnfParam {
            id: LcnfVarId(1),
            name: "p".to_string(),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        };
        let only_alt = LcnfAlt {
            ctor_name: "A".to_string(),
            ctor_tag: 0,
            params: vec![field],
            body: bind(
                2,
                "q",
                LcnfType::Nat,
                LcnfLetValue::FVar(LcnfVarId(1)),
                ret_var(2),
            ),
        };
        let input = LcnfExpr::Case {
            scrutinee: LcnfVarId(0),
            scrutinee_ty: LcnfType::Object,
            alts: vec![only_alt],
            default: Some(Box::new(ret_var(0))),
        };

        let mut pass = CopyProp::default_pass();
        let out = rewrite(&mut pass, input);

        let LcnfExpr::Case { alts, .. } = &out else {
            panic!("Expected Case");
        };
        assert_eq!(alts.len(), 1);
        assert_eq!(alts[0].body, ret_var(1));
        assert!(pass.report().copies_eliminated >= 1);
    }

    /// A binding to the erased value is replaced by the erased argument.
    #[test]
    pub(super) fn test_erased_copy_propagated() {
        let mut pass = CopyProp::default_pass();
        let input = bind(
            0,
            "e",
            LcnfType::Erased,
            LcnfLetValue::Erased,
            ret_var(0),
        );
        let out = rewrite(&mut pass, input);
        assert_eq!(out, LcnfExpr::Return(LcnfArg::Erased));
        assert_eq!(pass.report().copies_eliminated, 1);
    }
}
220#[allow(dead_code)]
221pub(super) fn collect_used(expr: &LcnfExpr, used: &mut UsedVars) {
222 match expr {
223 LcnfExpr::Return(arg) => collect_used_arg(arg, used),
224 LcnfExpr::Let { value, body, .. } => {
225 collect_used_value(value, used);
226 collect_used(body, used);
227 }
228 LcnfExpr::Case {
229 scrutinee,
230 scrutinee_ty: _,
231 alts,
232 default,
233 ..
234 } => {
235 used.vars.insert(*scrutinee);
236 for alt in alts {
237 collect_used(&alt.body, used);
238 }
239 if let Some(d) = default {
240 collect_used(d, used);
241 }
242 }
243 LcnfExpr::TailCall(callee, args) => {
244 collect_used_arg(callee, used);
245 for a in args {
246 collect_used_arg(a, used);
247 }
248 }
249 LcnfExpr::Unreachable => {}
250 }
251}
252#[allow(dead_code)]
253pub(super) fn collect_used_arg(arg: &LcnfArg, used: &mut UsedVars) {
254 if let LcnfArg::Var(id) = arg {
255 used.vars.insert(*id);
256 }
257}
258#[allow(dead_code)]
259pub(super) fn collect_used_value(val: &LcnfLetValue, used: &mut UsedVars) {
260 match val {
261 LcnfLetValue::FVar(id) => {
262 used.vars.insert(*id);
263 }
264 LcnfLetValue::App(callee, args) => {
265 collect_used_arg(callee, used);
266 for a in args {
267 collect_used_arg(a, used);
268 }
269 }
270 LcnfLetValue::Ctor(_, _, args) => {
271 for a in args {
272 collect_used_arg(a, used);
273 }
274 }
275 LcnfLetValue::Proj(_, _, id) => {
276 used.vars.insert(*id);
277 }
278 _ => {}
279 }
280}
/// Default inline threshold.
///
/// NOTE(review): nothing in the visible part of this file reads this
/// constant; presumably it is the cost bound used when deciding whether a
/// function is an inline candidate — confirm against `InliningPass`.
#[allow(dead_code)]
pub const DEFAULT_INLINE_THRESHOLD: u32 = 5;
284#[allow(dead_code)]
286pub fn count_let_bindings(expr: &LcnfExpr) -> usize {
287 match expr {
288 LcnfExpr::Let { body, .. } => 1 + count_let_bindings(body),
289 LcnfExpr::Case { alts, default, .. } => {
290 let alt_sum: usize = alts.iter().map(|a| count_let_bindings(&a.body)).sum();
291 let def_sum = default.as_ref().map_or(0, |d| count_let_bindings(d));
292 alt_sum + def_sum
293 }
294 _ => 0,
295 }
296}
297#[allow(dead_code)]
299pub fn expr_depth(expr: &LcnfExpr) -> usize {
300 match expr {
301 LcnfExpr::Let { body, .. } => 1 + expr_depth(body),
302 LcnfExpr::Case { alts, default, .. } => {
303 let alt_max = alts.iter().map(|a| expr_depth(&a.body)).max().unwrap_or(0);
304 let def_max = default.as_ref().map_or(0, |d| expr_depth(d));
305 1 + alt_max.max(def_max)
306 }
307 LcnfExpr::TailCall(_, args) => args.len(),
308 _ => 0,
309 }
310}
311#[allow(dead_code)]
313pub fn has_tail_call_to(expr: &LcnfExpr, target: LcnfVarId) -> bool {
314 match expr {
315 LcnfExpr::TailCall(LcnfArg::Var(id), _) => *id == target,
316 LcnfExpr::Let { body, .. } => has_tail_call_to(body, target),
317 LcnfExpr::Case { alts, default, .. } => {
318 alts.iter().any(|a| has_tail_call_to(&a.body, target))
319 || default
320 .as_ref()
321 .is_some_and(|d| has_tail_call_to(d, target))
322 }
323 _ => false,
324 }
325}
326#[allow(dead_code)]
328pub fn collect_bound_vars(expr: &LcnfExpr, out: &mut Vec<LcnfVarId>) {
329 match expr {
330 LcnfExpr::Let { id, body, .. } => {
331 out.push(*id);
332 collect_bound_vars(body, out);
333 }
334 LcnfExpr::Case { alts, default, .. } => {
335 for alt in alts {
336 collect_bound_vars(&alt.body, out);
337 }
338 if let Some(d) = default {
339 collect_bound_vars(d, out);
340 }
341 }
342 _ => {}
343 }
344}
#[cfg(test)]
mod tests_extended {
    use super::*;

    /// Builds a variable id from a small integer.
    pub(super) fn make_var(n: u32) -> LcnfVarId {
        LcnfVarId(u64::from(n))
    }

    /// Wraps `body` in a parameterless, non-recursive declaration named
    /// `test_fn` that returns `Nat`.
    pub(super) fn make_simple_decl(body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: "test_fn".to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }

    /// A binding that is never read is dropped, and the report records it.
    #[test]
    pub(super) fn test_dead_binding_removal() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(99),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(42)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
        };
        let mut decl = make_simple_decl(body);
        let mut pass = DeadBindingElim::default_pass();
        pass.run(&mut decl);
        assert_eq!(decl.body, LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))));
        // The body check above proves the binding was removed, so the report
        // must count at least one removal. (`>= 0` would hold for any count
        // and therefore verified nothing.)
        assert!(pass.report().bindings_removed >= 1);
    }

    /// Two nested lets are counted as two bindings.
    #[test]
    pub(super) fn test_count_let_bindings() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "a".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(1)),
            body: Box::new(LcnfExpr::Let {
                id: LcnfVarId(1),
                name: "b".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::Lit(LcnfLit::Nat(2)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
            }),
        };
        assert_eq!(count_let_bindings(&body), 2);
    }

    /// A single let over a return has depth 1.
    #[test]
    pub(super) fn test_expr_depth() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "a".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(0)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)))),
        };
        assert_eq!(expr_depth(&body), 1);
    }

    /// A direct tail call matches its own callee and no other variable.
    #[test]
    pub(super) fn test_has_tail_call_to() {
        let target = make_var(7);
        let body = LcnfExpr::TailCall(LcnfArg::Var(target), vec![]);
        assert!(has_tail_call_to(&body, target));
        assert!(!has_tail_call_to(&body, make_var(8)));
    }

    /// The id of a single let binding is collected.
    #[test]
    pub(super) fn test_collect_bound_vars() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(5),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(0)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(5)))),
        };
        let mut bound = vec![];
        collect_bound_vars(&body, &mut bound);
        assert_eq!(bound, vec![LcnfVarId(5)]);
    }

    /// The default pipeline eliminates a copy of a function parameter.
    #[test]
    pub(super) fn test_opt_pipeline_default() {
        let body = LcnfExpr::Let {
            id: LcnfVarId(0),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(1)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(0)))),
        };
        let mut decl = make_simple_decl(body);
        decl.params.push(LcnfParam {
            id: LcnfVarId(1),
            name: "n".to_string(),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        });
        let mut pipeline = OptPipeline::new();
        let result = pipeline.run(&mut decl);
        assert!(result.copy_prop.copies_eliminated >= 1);
    }

    /// Each pass kind displays as its variant name.
    #[test]
    pub(super) fn test_pass_kind_display() {
        assert_eq!(PassKind::CopyProp.to_string(), "CopyProp");
        assert_eq!(PassKind::DeadBinding.to_string(), "DeadBinding");
        assert_eq!(PassKind::ConstantFold.to_string(), "ConstantFold");
        assert_eq!(PassKind::Inlining.to_string(), "Inlining");
    }

    /// With the default inlining pass, a declaration of cost 1 is an inline
    /// candidate and an otherwise identical one of cost 100 is not.
    #[test]
    pub(super) fn test_inline_candidate() {
        let pass = InliningPass::default_pass();
        let cheap = LcnfFunDecl {
            name: "cheap".to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        };
        let expensive = LcnfFunDecl {
            inline_cost: 100,
            name: "expensive".to_string(),
            ..cheap.clone()
        };
        assert!(pass.is_inline_candidate(&cheap));
        assert!(!pass.is_inline_candidate(&expensive));
    }

    /// The rendered default `CopyPropConfig` mentions its own type name.
    ///
    /// NOTE(review): despite the test name, this exercises the `Display`
    /// impl of `CopyPropConfig`, not a dead-binding configuration type.
    #[test]
    pub(super) fn test_dead_binding_config_display() {
        let cfg = CopyPropConfig::default();
        let s = format!("{}", cfg);
        assert!(s.contains("CopyPropConfig"));
    }

    /// The dead-binding report renders both of its counters.
    #[test]
    pub(super) fn test_dead_binding_report_display() {
        let r = DeadBindingReport {
            bindings_removed: 3,
            passes_run: 2,
        };
        let s = format!("{}", r);
        assert!(s.contains("removed=3"));
        assert!(s.contains("passes=2"));
    }

    /// The constant-fold report renders its fold counter.
    #[test]
    pub(super) fn test_constant_fold_report_display() {
        let r = ConstantFoldReport { folds_performed: 7 };
        let s = format!("{}", r);
        assert!(s.contains("folds=7"));
    }

    /// The inline report renders both of its counters.
    #[test]
    pub(super) fn test_inline_report_display() {
        let r = InlineReport {
            inlines_performed: 2,
            functions_considered: 10,
        };
        let s = format!("{}", r);
        assert!(s.contains("inlined=2"));
        assert!(s.contains("considered=10"));
    }
}