1use crate::lcnf::{LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfVarId};
6use std::collections::HashMap;
7
8use super::types::{
9 LUAnalysisCache, LUConstantFoldingHelper, LUDepGraph, LUDominatorTree, LULivenessInfo,
10 LUPassConfig, LUPassPhase, LUPassRegistry, LUPassStats, LUWorklist, LoopInfo, LoopUnrollPass,
11 UnrollCandidate, UnrollConfig, UnrollFactor, UnrollReport,
12};
13
14pub fn count_var_refs(expr: &LcnfExpr, target: LcnfVarId) -> usize {
16 match expr {
17 LcnfExpr::Let {
18 id, value, body, ..
19 } => {
20 let in_value = count_var_refs_in_value(value, target);
21 let in_body = if *id == target {
22 0
23 } else {
24 count_var_refs(body, target)
25 };
26 in_value + in_body
27 }
28 LcnfExpr::Case {
29 scrutinee,
30 alts,
31 default,
32 ..
33 } => {
34 let scrutinee_count = if *scrutinee == target { 1 } else { 0 };
35 let alt_count: usize = alts
36 .iter()
37 .filter(|a| a.params.iter().all(|p| p.id != target))
38 .map(|a| count_var_refs(&a.body, target))
39 .sum();
40 let default_count = default
41 .as_ref()
42 .map(|d| count_var_refs(d, target))
43 .unwrap_or(0);
44 scrutinee_count + alt_count + default_count
45 }
46 LcnfExpr::Return(arg) | LcnfExpr::TailCall(arg, _) => {
47 if let crate::lcnf::LcnfArg::Var(id) = arg {
48 if *id == target {
49 1
50 } else {
51 0
52 }
53 } else {
54 0
55 }
56 }
57 LcnfExpr::Unreachable => 0,
58 }
59}
60pub(super) fn count_var_refs_in_value(value: &LcnfLetValue, target: LcnfVarId) -> usize {
61 let count_arg = |a: &crate::lcnf::LcnfArg| {
62 matches!(a, crate ::lcnf::LcnfArg::Var(id) if * id == target) as usize
63 };
64 match value {
65 LcnfLetValue::App(f, args) => count_arg(f) + args.iter().map(count_arg).sum::<usize>(),
66 LcnfLetValue::Proj(_, _, v) => {
67 if *v == target {
68 1
69 } else {
70 0
71 }
72 }
73 LcnfLetValue::Ctor(_, _, args) => args.iter().map(count_arg).sum(),
74 LcnfLetValue::FVar(id) => {
75 if *id == target {
76 1
77 } else {
78 0
79 }
80 }
81 LcnfLetValue::Reset(v) => {
82 if *v == target {
83 1
84 } else {
85 0
86 }
87 }
88 LcnfLetValue::Reuse(slot, _, _, args) => {
89 let s = if *slot == target { 1 } else { 0 };
90 s + args.iter().map(count_arg).sum::<usize>()
91 }
92 LcnfLetValue::Lit(_) | LcnfLetValue::Erased => 0,
93 }
94}
95pub fn estimate_expr_size(expr: &LcnfExpr) -> u64 {
97 match expr {
98 LcnfExpr::Let { body, .. } => 1 + estimate_expr_size(body),
99 LcnfExpr::Case { alts, default, .. } => {
100 let alt_sizes: u64 = alts.iter().map(|a| estimate_expr_size(&a.body)).sum();
101 let def_size = default.as_ref().map(|d| estimate_expr_size(d)).unwrap_or(0);
102 1 + alt_sizes + def_size
103 }
104 LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable => 1,
105 }
106}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lcnf::{
        LcnfAlt, LcnfArg, LcnfExpr, LcnfFunDecl, LcnfLetValue, LcnfLit, LcnfParam, LcnfType,
        LcnfVarId,
    };
    /// Builds `let v{id} : Nat := n; body`.
    pub(super) fn make_nat_lit(id: u64, n: u64, body: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: LcnfVarId(id),
            name: format!("v{}", id),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(n)),
            body: Box::new(body),
        }
    }
    /// Builds `return n` for a Nat literal.
    pub(super) fn make_return_nat(n: u64) -> LcnfExpr {
        LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(n)))
    }
    /// Builds a parameterless, non-recursive Nat-returning declaration.
    pub(super) fn make_decl(name: &str, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params: vec![],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 0,
        }
    }
    #[test]
    fn test_unroll_factor_full_has_no_numeric_factor() {
        assert_eq!(UnrollFactor::Full.factor(), None);
    }
    #[test]
    fn test_unroll_factor_partial_returns_factor() {
        assert_eq!(UnrollFactor::Partial(4).factor(), Some(4));
    }
    #[test]
    fn test_unroll_factor_jamming_has_no_numeric_factor() {
        assert_eq!(UnrollFactor::Jamming.factor(), None);
    }
    #[test]
    fn test_unroll_factor_vectorizable_returns_factor() {
        assert_eq!(UnrollFactor::Vectorizable(8).factor(), Some(8));
    }
    #[test]
    fn test_unroll_factor_names() {
        assert_eq!(UnrollFactor::Full.name(), "full");
        assert_eq!(UnrollFactor::Partial(2).name(), "partial");
        assert_eq!(UnrollFactor::Jamming.name(), "jamming");
        assert_eq!(UnrollFactor::Vectorizable(4).name(), "vectorizable");
    }
    #[test]
    fn test_loop_info_trip_count_basic() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 8, 1, vec![]);
        assert_eq!(info.trip_count, Some(8));
    }
    #[test]
    fn test_loop_info_trip_count_step2() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 8, 2, vec![]);
        assert_eq!(info.trip_count, Some(4));
    }
    #[test]
    fn test_loop_info_trip_count_non_zero_start() {
        let info = LoopInfo::new(LcnfVarId(0), 3, 15, 3, vec![]);
        assert_eq!(info.trip_count, Some(4));
    }
    #[test]
    fn test_loop_info_is_counted_when_trip_known() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 4, 1, vec![]);
        assert!(info.is_counted);
    }
    #[test]
    fn test_loop_info_priority_score_innermost_bonus() {
        let mut info = LoopInfo::new(LcnfVarId(0), 0, 8, 1, vec![]);
        info.is_innermost = true;
        let score_inner = info.priority_score();
        let mut info2 = LoopInfo::new(LcnfVarId(0), 0, 8, 1, vec![]);
        info2.is_innermost = false;
        let score_outer = info2.priority_score();
        assert!(score_inner > score_outer);
    }
    #[test]
    fn test_default_config_values() {
        let cfg = UnrollConfig::default();
        assert_eq!(cfg.max_unroll_factor, 8);
        assert_eq!(cfg.max_unrolled_size, 256);
        assert_eq!(cfg.unroll_full_threshold, 16);
        assert!(cfg.enable_vectorizable);
    }
    #[test]
    fn test_aggressive_config_larger_limits() {
        let agg = UnrollConfig::aggressive();
        let def = UnrollConfig::default();
        assert!(agg.max_unroll_factor >= def.max_unroll_factor);
        assert!(agg.max_unrolled_size >= def.max_unrolled_size);
    }
    #[test]
    fn test_conservative_config_smaller_limits() {
        let con = UnrollConfig::conservative();
        let def = UnrollConfig::default();
        assert!(con.max_unroll_factor <= def.max_unroll_factor);
        assert!(!con.enable_vectorizable);
    }
    #[test]
    fn test_compute_factor_full_for_small_trip() {
        let pass = LoopUnrollPass::default_pass();
        let info = LoopInfo::new(LcnfVarId(0), 0, 4, 1, vec![]);
        assert_eq!(pass.compute_unroll_factor(&info), UnrollFactor::Full);
    }
    #[test]
    fn test_compute_factor_partial_for_medium_trip() {
        let pass = LoopUnrollPass::default_pass();
        let mut info = LoopInfo::new(LcnfVarId(0), 0, 32, 1, vec![]);
        info.estimated_size = 10;
        let factor = pass.compute_unroll_factor(&info);
        assert_ne!(factor, UnrollFactor::Full);
    }
    #[test]
    fn test_compute_factor_vectorizable_for_divisible_trip() {
        let mut cfg = UnrollConfig::default();
        cfg.enable_vectorizable = true;
        let pass = LoopUnrollPass::new(cfg);
        let mut info = LoopInfo::new(LcnfVarId(0), 0, 32, 1, vec![]);
        info.estimated_size = 5;
        info.is_innermost = true;
        let factor = pass.compute_unroll_factor(&info);
        assert!(matches!(factor, UnrollFactor::Vectorizable(_)));
    }
    #[test]
    fn test_compute_factor_unknown_trip_gives_partial2() {
        let pass = LoopUnrollPass::default_pass();
        let info = LoopInfo {
            loop_var: LcnfVarId(0),
            start: 0,
            end: 0,
            step: 0,
            body: vec![],
            trip_count: None,
            is_innermost: true,
            is_counted: false,
            estimated_size: 10,
        };
        assert_eq!(pass.compute_unroll_factor(&info), UnrollFactor::Partial(2));
    }
    #[test]
    fn test_unroll_loop_partial_2_doubles_body() {
        let mut pass = LoopUnrollPass::default_pass();
        let body = vec![make_return_nat(0), make_return_nat(1)];
        let result = pass.unroll_loop(&body, &UnrollFactor::Partial(2));
        assert_eq!(result.len(), body.len() * 2);
    }
    #[test]
    fn test_unroll_loop_partial_4_quadruples_body() {
        let mut pass = LoopUnrollPass::default_pass();
        let body = vec![make_return_nat(42)];
        let result = pass.unroll_loop(&body, &UnrollFactor::Partial(4));
        assert_eq!(result.len(), 4);
    }
    #[test]
    fn test_unroll_loop_jamming_returns_unchanged() {
        let mut pass = LoopUnrollPass::default_pass();
        let body = vec![make_return_nat(7)];
        let result = pass.unroll_loop(&body, &UnrollFactor::Jamming);
        assert_eq!(result.len(), body.len());
    }
    #[test]
    fn test_unroll_loop_vectorizable_replicates() {
        let mut pass = LoopUnrollPass::default_pass();
        let body = vec![make_return_nat(0)];
        let result = pass.unroll_loop(&body, &UnrollFactor::Vectorizable(4));
        assert_eq!(result.len(), 4);
    }
    #[test]
    fn test_run_empty_decls() {
        let mut pass = LoopUnrollPass::default_pass();
        let mut decls: Vec<LcnfFunDecl> = vec![];
        pass.run(&mut decls);
        assert_eq!(pass.report().loops_analyzed, 0);
    }
    #[test]
    fn test_run_simple_decl_no_loops() {
        let mut pass = LoopUnrollPass::default_pass();
        let decl = make_decl("foo", make_return_nat(0));
        let mut decls = vec![decl];
        pass.run(&mut decls);
        assert_eq!(pass.report().loops_analyzed, 0);
    }
    #[test]
    fn test_run_preserves_decl_count() {
        let mut pass = LoopUnrollPass::default_pass();
        let d1 = make_decl("f1", make_return_nat(1));
        let d2 = make_decl("f2", make_return_nat(2));
        let mut decls = vec![d1, d2];
        pass.run(&mut decls);
        assert_eq!(decls.len(), 2);
    }
    #[test]
    fn test_report_merge() {
        let mut r1 = UnrollReport {
            loops_analyzed: 3,
            loops_unrolled: 2,
            full_unrolls: 1,
            partial_unrolls: 1,
            jammed_loops: 0,
            vectorizable_loops: 0,
            estimated_speedup: 1.5,
        };
        let r2 = UnrollReport {
            loops_analyzed: 7,
            loops_unrolled: 4,
            full_unrolls: 2,
            partial_unrolls: 2,
            jammed_loops: 2,
            vectorizable_loops: 0,
            estimated_speedup: 2.0,
        };
        r1.merge(&r2);
        assert_eq!(r1.loops_analyzed, 10);
        assert_eq!(r1.loops_unrolled, 6);
        assert_eq!(r1.jammed_loops, 2);
    }
    #[test]
    fn test_report_summary_contains_key_fields() {
        let r = UnrollReport {
            loops_analyzed: 5,
            loops_unrolled: 3,
            full_unrolls: 1,
            partial_unrolls: 2,
            jammed_loops: 0,
            vectorizable_loops: 0,
            estimated_speedup: 1.8,
        };
        let s = r.summary();
        assert!(s.contains("analyzed=5"));
        assert!(s.contains("unrolled=3"));
    }
    #[test]
    fn test_estimate_size_return_is_1() {
        assert_eq!(estimate_expr_size(&make_return_nat(0)), 1);
    }
    #[test]
    fn test_estimate_size_unreachable_is_1() {
        assert_eq!(estimate_expr_size(&LcnfExpr::Unreachable), 1);
    }
    #[test]
    fn test_estimate_size_let_adds_1() {
        let e = make_nat_lit(0, 42, make_return_nat(0));
        assert_eq!(estimate_expr_size(&e), 2);
    }
    #[test]
    fn test_estimate_size_chain() {
        let e = make_nat_lit(0, 1, make_nat_lit(1, 2, make_return_nat(0)));
        assert_eq!(estimate_expr_size(&e), 3);
    }
    #[test]
    fn test_count_var_refs_return() {
        let e = LcnfExpr::Return(LcnfArg::Var(LcnfVarId(5)));
        assert_eq!(count_var_refs(&e, LcnfVarId(5)), 1);
        assert_eq!(count_var_refs(&e, LcnfVarId(6)), 0);
    }
    #[test]
    fn test_count_var_refs_in_let_value() {
        let e = LcnfExpr::Let {
            id: LcnfVarId(1),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(5)),
            body: Box::new(make_return_nat(0)),
        };
        assert_eq!(count_var_refs(&e, LcnfVarId(5)), 1);
    }
    #[test]
    fn test_count_var_refs_shadowed() {
        let e = LcnfExpr::Let {
            id: LcnfVarId(5),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Erased,
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(5)))),
        };
        assert_eq!(count_var_refs(&e, LcnfVarId(5)), 0);
    }
    #[test]
    fn test_count_var_refs_let_value_counted_even_when_binder_shadows() {
        // `let x5 := x5; return x5` — the bound value still refers to the outer x5.
        let e = LcnfExpr::Let {
            id: LcnfVarId(5),
            name: "x".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::FVar(LcnfVarId(5)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(LcnfVarId(5)))),
        };
        assert_eq!(count_var_refs(&e, LcnfVarId(5)), 1);
    }
    #[test]
    fn test_count_var_refs_case_respects_alt_shadowing() {
        let target = LcnfVarId(5);
        let case_expr = LcnfExpr::Case {
            scrutinee: target,
            scrutinee_ty: LcnfType::Nat,
            alts: vec![
                LcnfAlt {
                    ctor_name: "succ".to_string(),
                    ctor_tag: 1,
                    params: vec![LcnfParam {
                        id: target,
                        name: "n".to_string(),
                        ty: LcnfType::Nat,
                        erased: false,
                        borrowed: false,
                    }],
                    body: LcnfExpr::Return(LcnfArg::Var(target)),
                },
                LcnfAlt {
                    ctor_name: "zero".to_string(),
                    ctor_tag: 0,
                    params: vec![],
                    body: LcnfExpr::Return(LcnfArg::Var(target)),
                },
            ],
            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Var(target)))),
        };
        // scrutinee (1) + unshadowed alt (1) + default (1); the alt rebinding `target` adds 0.
        assert_eq!(count_var_refs(&case_expr, target), 3);
    }
    #[test]
    fn test_count_var_refs_in_value_reset() {
        let value = LcnfLetValue::Reset(LcnfVarId(3));
        assert_eq!(count_var_refs_in_value(&value, LcnfVarId(3)), 1);
        assert_eq!(count_var_refs_in_value(&value, LcnfVarId(4)), 0);
    }
    #[test]
    fn test_candidate_is_profitable_positive_savings() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 4, 1, vec![]);
        let c = UnrollCandidate::new("f", info, UnrollFactor::Full, 10);
        assert!(c.is_profitable());
    }
    #[test]
    fn test_candidate_is_not_profitable_negative_savings() {
        let info = LoopInfo::new(LcnfVarId(0), 0, 4, 1, vec![]);
        let c = UnrollCandidate::new("f", info, UnrollFactor::Full, -5);
        assert!(!c.is_profitable());
    }
    #[test]
    fn test_case_expr_size() {
        let case_expr = LcnfExpr::Case {
            scrutinee: LcnfVarId(0),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![
                LcnfAlt {
                    ctor_name: "zero".to_string(),
                    ctor_tag: 0,
                    params: vec![],
                    body: make_return_nat(0),
                },
                LcnfAlt {
                    ctor_name: "succ".to_string(),
                    ctor_tag: 1,
                    params: vec![LcnfParam {
                        id: LcnfVarId(1),
                        name: "n".to_string(),
                        ty: LcnfType::Nat,
                        erased: false,
                        borrowed: false,
                    }],
                    body: make_return_nat(1),
                },
            ],
            default: None,
        };
        assert_eq!(estimate_expr_size(&case_expr), 3);
    }
}
430#[allow(dead_code)]
432pub trait LoopOptPass {
433 fn name(&self) -> &str;
435 fn run_pass(&mut self, decls: &mut [LcnfFunDecl]) -> UnrollReport;
437}
#[cfg(test)]
/// Tests for the shared loop-unroll pass infrastructure (config, stats, registry, caches,
/// worklist, dominators, liveness, constant folding, dependency graph).
///
/// Renamed from `LU_infra_tests` so the module satisfies the `non_snake_case` lint.
mod lu_infra_tests {
    use super::*;
    #[test]
    fn test_pass_config() {
        let config = LUPassConfig::new("test_pass", LUPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }
    #[test]
    fn test_pass_stats() {
        let mut stats = LUPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }
    #[test]
    fn test_pass_registry() {
        let mut reg = LUPassRegistry::new();
        reg.register(LUPassConfig::new("pass_a", LUPassPhase::Analysis));
        reg.register(LUPassConfig::new("pass_b", LUPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }
    #[test]
    fn test_analysis_cache() {
        let mut cache = LUAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        // One hit and one miss so far.
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        // Invalidation marks the entry stale but keeps it in the cache.
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }
    #[test]
    fn test_worklist() {
        let mut wl = LUWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }
    #[test]
    fn test_dominator_tree() {
        let mut dt = LUDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }
    #[test]
    fn test_liveness() {
        let mut liveness = LULivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }
    #[test]
    fn test_constant_folding() {
        assert_eq!(LUConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(LUConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(LUConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            LUConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(LUConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }
    #[test]
    fn test_dep_graph() {
        let mut g = LUDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Every dependency must appear before its dependents in the ordering.
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}
538}