1use crate::lcnf::*;
6use std::collections::{HashMap, HashSet};
7
8use super::types::{
9 CallSiteInfo, JoinPointConfig, JoinPointOptimizer, JoinPointStats, OJAnalysisCache,
10 OJConstantFoldingHelper, OJDepGraph, OJDominatorTree, OJLivenessInfo, OJPassConfig,
11 OJPassPhase, OJPassRegistry, OJPassStats, OJWorklist, OJoinConfig, OJoinDiagCollector,
12 OJoinDiagMsg, OJoinEmitStats, OJoinEventLog, OJoinFeatures, OJoinIdGen, OJoinIncrKey,
13 OJoinNameScope, OJoinPassTiming, OJoinProfiler, OJoinSourceBuffer, OJoinVersion, TailUse,
14};
15
16pub(super) fn analyze_tail_uses(expr: &LcnfExpr, tail: bool) -> HashMap<LcnfVarId, TailUse> {
18 let mut uses: HashMap<LcnfVarId, TailUse> = HashMap::new();
19 match expr {
20 LcnfExpr::Let {
21 value, body, id, ..
22 } => {
23 collect_value_uses(value, &mut uses, false);
24 let body_uses = analyze_tail_uses(body, tail);
25 for (var, use_kind) in body_uses {
26 if var != *id {
27 let current = uses.entry(var).or_insert(TailUse::Unused);
28 *current = current.merge(&use_kind);
29 }
30 }
31 }
32 LcnfExpr::Case {
33 scrutinee,
34 alts,
35 default,
36 ..
37 } => {
38 let current = uses.entry(*scrutinee).or_insert(TailUse::Unused);
39 *current = current.merge(&TailUse::NonTail);
40 for alt in alts {
41 let alt_uses = analyze_tail_uses(&alt.body, tail);
42 for (var, use_kind) in alt_uses {
43 let current = uses.entry(var).or_insert(TailUse::Unused);
44 *current = current.merge(&use_kind);
45 }
46 }
47 if let Some(def) = default {
48 let def_uses = analyze_tail_uses(def, tail);
49 for (var, use_kind) in def_uses {
50 let current = uses.entry(var).or_insert(TailUse::Unused);
51 *current = current.merge(&use_kind);
52 }
53 }
54 }
55 LcnfExpr::Return(arg) => {
56 if let LcnfArg::Var(v) = arg {
57 let use_kind = if tail {
58 TailUse::TailOnly
59 } else {
60 TailUse::NonTail
61 };
62 let current = uses.entry(*v).or_insert(TailUse::Unused);
63 *current = current.merge(&use_kind);
64 }
65 }
66 LcnfExpr::TailCall(func, args) => {
67 if let LcnfArg::Var(v) = func {
68 let use_kind = if tail {
69 TailUse::TailOnly
70 } else {
71 TailUse::NonTail
72 };
73 let current = uses.entry(*v).or_insert(TailUse::Unused);
74 *current = current.merge(&use_kind);
75 }
76 for arg in args {
77 if let LcnfArg::Var(v) = arg {
78 let current = uses.entry(*v).or_insert(TailUse::Unused);
79 *current = current.merge(&TailUse::NonTail);
80 }
81 }
82 }
83 LcnfExpr::Unreachable => {}
84 }
85 uses
86}
87pub(super) fn collect_value_uses(
89 value: &LcnfLetValue,
90 uses: &mut HashMap<LcnfVarId, TailUse>,
91 _tail: bool,
92) {
93 let vars = extract_value_vars(value);
94 for v in vars {
95 let current = uses.entry(v).or_insert(TailUse::Unused);
96 *current = current.merge(&TailUse::NonTail);
97 }
98}
99pub(super) fn extract_value_vars(value: &LcnfLetValue) -> Vec<LcnfVarId> {
101 let mut vars = Vec::new();
102 match value {
103 LcnfLetValue::App(func, args) => {
104 if let LcnfArg::Var(v) = func {
105 vars.push(*v);
106 }
107 for a in args {
108 if let LcnfArg::Var(v) = a {
109 vars.push(*v);
110 }
111 }
112 }
113 LcnfLetValue::Proj(_, _, v) => {
114 vars.push(*v);
115 }
116 LcnfLetValue::Ctor(_, _, args) => {
117 for a in args {
118 if let LcnfArg::Var(v) = a {
119 vars.push(*v);
120 }
121 }
122 }
123 LcnfLetValue::FVar(v) => {
124 vars.push(*v);
125 }
126 LcnfLetValue::Lit(_)
127 | LcnfLetValue::Erased
128 | LcnfLetValue::Reset(_)
129 | LcnfLetValue::Reuse(_, _, _, _) => {}
130 }
131 vars
132}
133pub(super) fn analyze_call_sites(
135 expr: &LcnfExpr,
136 caller: &str,
137 in_tail: bool,
138) -> Vec<CallSiteInfo> {
139 let mut sites = Vec::new();
140 match expr {
141 LcnfExpr::Let { value, body, .. } => {
142 if let LcnfLetValue::App(func, args) = value {
143 let callee_var = if let LcnfArg::Var(v) = func {
144 Some(*v)
145 } else {
146 None
147 };
148 sites.push(CallSiteInfo {
149 caller: caller.to_string(),
150 is_tail: false,
151 arg_count: args.len(),
152 callee_var,
153 });
154 }
155 sites.extend(analyze_call_sites(body, caller, in_tail));
156 }
157 LcnfExpr::Case { alts, default, .. } => {
158 for alt in alts {
159 sites.extend(analyze_call_sites(&alt.body, caller, in_tail));
160 }
161 if let Some(def) = default {
162 sites.extend(analyze_call_sites(def, caller, in_tail));
163 }
164 }
165 LcnfExpr::TailCall(func, args) => {
166 let callee_var = if let LcnfArg::Var(v) = func {
167 Some(*v)
168 } else {
169 None
170 };
171 sites.push(CallSiteInfo {
172 caller: caller.to_string(),
173 is_tail: in_tail,
174 arg_count: args.len(),
175 callee_var,
176 });
177 }
178 LcnfExpr::Return(_) | LcnfExpr::Unreachable => {}
179 }
180 sites
181}
182pub(super) fn collect_used_vars(expr: &LcnfExpr) -> HashSet<LcnfVarId> {
184 let mut used = HashSet::new();
185 collect_used_vars_inner(expr, &mut used);
186 used
187}
188pub(super) fn collect_used_vars_inner(expr: &LcnfExpr, used: &mut HashSet<LcnfVarId>) {
189 match expr {
190 LcnfExpr::Let { value, body, .. } => {
191 collect_value_used_vars(value, used);
192 collect_used_vars_inner(body, used);
193 }
194 LcnfExpr::Case {
195 scrutinee,
196 alts,
197 default,
198 ..
199 } => {
200 used.insert(*scrutinee);
201 for alt in alts {
202 collect_used_vars_inner(&alt.body, used);
203 }
204 if let Some(def) = default {
205 collect_used_vars_inner(def, used);
206 }
207 }
208 LcnfExpr::Return(arg) => {
209 if let LcnfArg::Var(v) = arg {
210 used.insert(*v);
211 }
212 }
213 LcnfExpr::TailCall(func, args) => {
214 if let LcnfArg::Var(v) = func {
215 used.insert(*v);
216 }
217 for a in args {
218 if let LcnfArg::Var(v) = a {
219 used.insert(*v);
220 }
221 }
222 }
223 LcnfExpr::Unreachable => {}
224 }
225}
226pub(super) fn collect_value_used_vars(value: &LcnfLetValue, used: &mut HashSet<LcnfVarId>) {
227 match value {
228 LcnfLetValue::App(func, args) => {
229 if let LcnfArg::Var(v) = func {
230 used.insert(*v);
231 }
232 for a in args {
233 if let LcnfArg::Var(v) = a {
234 used.insert(*v);
235 }
236 }
237 }
238 LcnfLetValue::Proj(_, _, v) => {
239 used.insert(*v);
240 }
241 LcnfLetValue::Ctor(_, _, args) => {
242 for a in args {
243 if let LcnfArg::Var(v) = a {
244 used.insert(*v);
245 }
246 }
247 }
248 LcnfLetValue::FVar(v) => {
249 used.insert(*v);
250 }
251 LcnfLetValue::Lit(_)
252 | LcnfLetValue::Erased
253 | LcnfLetValue::Reset(_)
254 | LcnfLetValue::Reuse(_, _, _, _) => {}
255 }
256}
257pub(super) fn is_pure_value(value: &LcnfLetValue) -> bool {
259 match value {
260 LcnfLetValue::Lit(_)
261 | LcnfLetValue::Erased
262 | LcnfLetValue::FVar(_)
263 | LcnfLetValue::Proj(_, _, _)
264 | LcnfLetValue::Ctor(_, _, _) => true,
265 LcnfLetValue::App(_, _) | LcnfLetValue::Reset(_) | LcnfLetValue::Reuse(_, _, _, _) => false,
266 }
267}
268pub(super) fn expr_uses_var(expr: &LcnfExpr, var: LcnfVarId) -> bool {
270 match expr {
271 LcnfExpr::Let {
272 id, value, body, ..
273 } => value_uses_var(value, var) || (*id != var && expr_uses_var(body, var)),
274 LcnfExpr::Case {
275 scrutinee,
276 alts,
277 default,
278 ..
279 } => {
280 *scrutinee == var
281 || alts.iter().any(|alt| expr_uses_var(&alt.body, var))
282 || default.as_ref().is_some_and(|d| expr_uses_var(d, var))
283 }
284 LcnfExpr::Return(arg) => matches!(arg, LcnfArg::Var(v) if * v == var),
285 LcnfExpr::TailCall(func, args) => {
286 matches!(func, LcnfArg::Var(v) if * v == var)
287 || args
288 .iter()
289 .any(|a| matches!(a, LcnfArg::Var(v) if * v == var))
290 }
291 LcnfExpr::Unreachable => false,
292 }
293}
294pub(super) fn value_uses_var(value: &LcnfLetValue, var: LcnfVarId) -> bool {
296 match value {
297 LcnfLetValue::App(func, args) => {
298 matches!(func, LcnfArg::Var(v) if * v == var)
299 || args
300 .iter()
301 .any(|a| matches!(a, LcnfArg::Var(v) if * v == var))
302 }
303 LcnfLetValue::Proj(_, _, v) => *v == var,
304 LcnfLetValue::Ctor(_, _, args) => args
305 .iter()
306 .any(|a| matches!(a, LcnfArg::Var(v) if * v == var)),
307 LcnfLetValue::FVar(v) => *v == var,
308 LcnfLetValue::Lit(_)
309 | LcnfLetValue::Erased
310 | LcnfLetValue::Reset(_)
311 | LcnfLetValue::Reuse(_, _, _, _) => false,
312 }
313}
314pub(super) fn count_instructions(expr: &LcnfExpr) -> usize {
316 match expr {
317 LcnfExpr::Let { body, .. } => 1 + count_instructions(body),
318 LcnfExpr::Case { alts, default, .. } => {
319 let alts_size: usize = alts.iter().map(|a| count_instructions(&a.body)).sum();
320 let def_size = default.as_ref().map(|d| count_instructions(d)).unwrap_or(0);
321 1 + alts_size + def_size
322 }
323 LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable => 1,
324 }
325}
326pub(super) fn compute_call_graph(decls: &[LcnfFunDecl]) -> HashMap<String, HashSet<String>> {
328 let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
329 let decl_names: HashSet<&str> = decls.iter().map(|d| d.name.as_str()).collect();
330 for decl in decls {
331 let mut callees = HashSet::new();
332 collect_callees(&decl.body, &decl_names, &mut callees);
333 graph.insert(decl.name.clone(), callees);
334 }
335 graph
336}
337pub(super) fn collect_callees(
342 expr: &LcnfExpr,
343 known_fns: &HashSet<&str>,
344 callees: &mut HashSet<String>,
345) {
346 let ctx: HashMap<LcnfVarId, String> = HashMap::new();
347 collect_callees_ctx(expr, known_fns, callees, &ctx);
348}
349pub(super) fn collect_callees_ctx(
351 expr: &LcnfExpr,
352 known_fns: &HashSet<&str>,
353 callees: &mut HashSet<String>,
354 ctx: &HashMap<LcnfVarId, String>,
355) {
356 match expr {
357 LcnfExpr::Let {
358 id, value, body, ..
359 } => {
360 if let LcnfLetValue::App(LcnfArg::Var(v), _) = value {
361 if let Some(name) = ctx.get(v) {
362 if known_fns.contains(name.as_str()) {
363 callees.insert(name.clone());
364 }
365 }
366 }
367 if let LcnfLetValue::FVar(v) = value {
368 if let Some(name) = ctx.get(v).cloned() {
369 let mut extended = ctx.clone();
370 extended.insert(*id, name);
371 collect_callees_ctx(body, known_fns, callees, &extended);
372 return;
373 }
374 }
375 collect_callees_ctx(body, known_fns, callees, ctx);
376 }
377 LcnfExpr::Case { alts, default, .. } => {
378 for alt in alts {
379 collect_callees_ctx(&alt.body, known_fns, callees, ctx);
380 }
381 if let Some(def) = default {
382 collect_callees_ctx(def, known_fns, callees, ctx);
383 }
384 }
385 LcnfExpr::TailCall(LcnfArg::Var(v), _) => {
386 if let Some(name) = ctx.get(v) {
387 if known_fns.contains(name.as_str()) {
388 callees.insert(name.clone());
389 }
390 }
391 }
392 _ => {}
393 }
394}
395pub(super) fn find_self_recursive_tail_calls(
397 expr: &LcnfExpr,
398 fn_name: &str,
399 var_to_name: &HashMap<LcnfVarId, String>,
400) -> Vec<LcnfVarId> {
401 let mut self_calls = Vec::new();
402 match expr {
403 LcnfExpr::Let { body, .. } => {
404 self_calls.extend(find_self_recursive_tail_calls(body, fn_name, var_to_name));
405 }
406 LcnfExpr::Case { alts, default, .. } => {
407 for alt in alts {
408 self_calls.extend(find_self_recursive_tail_calls(
409 &alt.body,
410 fn_name,
411 var_to_name,
412 ));
413 }
414 if let Some(def) = default {
415 self_calls.extend(find_self_recursive_tail_calls(def, fn_name, var_to_name));
416 }
417 }
418 LcnfExpr::TailCall(LcnfArg::Var(v), _) => {
419 if let Some(name) = var_to_name.get(v) {
420 if name == fn_name {
421 self_calls.push(*v);
422 }
423 }
424 }
425 _ => {}
426 }
427 self_calls
428}
429pub(super) fn is_join_point_candidate(callee_id: LcnfVarId, call_sites: &[CallSiteInfo]) -> bool {
432 let relevant: Vec<&CallSiteInfo> = call_sites
433 .iter()
434 .filter(|cs| cs.callee_var == Some(callee_id))
435 .collect();
436 if relevant.is_empty() {
437 return false;
438 }
439 relevant.iter().all(|cs| cs.is_tail)
440}
441pub fn optimize_join_points(module: &mut LcnfModule, config: &JoinPointConfig) {
443 let mut optimizer = JoinPointOptimizer::new(config.clone());
444 for decl in &mut module.fun_decls {
445 optimizer.optimize_decl(decl);
446 }
447 if config.eliminate_dead_joins {
448 eliminate_dead_functions(module);
449 }
450}
451pub(super) fn eliminate_dead_functions(module: &mut LcnfModule) {
453 if module.fun_decls.len() <= 1 {
454 return;
455 }
456 let call_graph = compute_call_graph(&module.fun_decls);
457 let mut reachable: HashSet<String> = HashSet::new();
458 let mut worklist: Vec<String> = module
459 .fun_decls
460 .iter()
461 .filter(|d| !d.is_lifted)
462 .map(|d| d.name.clone())
463 .collect();
464 while let Some(fn_name) = worklist.pop() {
465 if reachable.insert(fn_name.clone()) {
466 if let Some(callees) = call_graph.get(&fn_name) {
467 for callee in callees {
468 if !reachable.contains(callee) {
469 worklist.push(callee.clone());
470 }
471 }
472 }
473 }
474 }
475 module.fun_decls.retain(|d| reachable.contains(&d.name));
476}
477pub(super) fn create_join_point(
479 join_id: LcnfVarId,
480 params: Vec<LcnfParam>,
481 body: LcnfExpr,
482 ret_type: LcnfType,
483) -> LcnfFunDecl {
484 let cost = count_instructions(&body);
485 LcnfFunDecl {
486 name: format!("_join_{}", join_id.0),
487 original_name: None,
488 params,
489 ret_type,
490 body,
491 is_recursive: false,
492 is_lifted: true,
493 inline_cost: cost,
494 }
495}
496pub(super) fn convert_to_loop(decl: &mut LcnfFunDecl) -> bool {
498 if !decl.is_recursive {
499 return false;
500 }
501 let var_to_name: HashMap<LcnfVarId, String> = HashMap::new();
502 let self_calls = find_self_recursive_tail_calls(&decl.body, &decl.name, &var_to_name);
503 !self_calls.is_empty()
504}
// Unit tests for the free-standing analyses in this file and for the
// `JoinPointOptimizer` methods exercised through them.
#[cfg(test)]
mod tests {
    use super::*;
    // Builds a variable id from a raw number.
    pub(super) fn make_var(n: u64) -> LcnfVarId {
        LcnfVarId(n)
    }
    // Builds a `Nat`-typed, non-erased, non-borrowed parameter with id `n`.
    pub(super) fn make_param(n: u64, name: &str) -> LcnfParam {
        LcnfParam {
            id: LcnfVarId(n),
            name: name.to_string(),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        }
    }
    // Builds `let x<id> : Nat := value; body`.
    pub(super) fn make_simple_let(id: u64, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: LcnfVarId(id),
            name: format!("x{}", id),
            ty: LcnfType::Nat,
            value,
            body: Box::new(body),
        }
    }
    // Builds a non-recursive, non-lifted one-parameter declaration around `body`.
    pub(super) fn make_simple_decl(name: &str, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params: vec![make_param(0, "arg0")],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }
    // Pins the default `JoinPointConfig` values asserted below.
    #[test]
    pub(super) fn test_config_default() {
        let config = JoinPointConfig::default();
        assert_eq!(config.max_join_size, 10);
        assert!(config.inline_small_joins);
        assert!(config.detect_tail_calls);
        assert!(config.enable_contification);
    }
    // Fresh statistics report no changes.
    #[test]
    pub(super) fn test_stats_default() {
        let stats = JoinPointStats::default();
        assert_eq!(stats.total_changes(), 0);
    }
    // `TailUse::merge`: Unused is neutral, tail+non-tail becomes Mixed, like kinds stay.
    #[test]
    pub(super) fn test_tail_use_merge() {
        assert_eq!(TailUse::Unused.merge(&TailUse::TailOnly), TailUse::TailOnly);
        assert_eq!(
            TailUse::TailOnly.merge(&TailUse::TailOnly),
            TailUse::TailOnly
        );
        assert_eq!(TailUse::TailOnly.merge(&TailUse::NonTail), TailUse::Mixed);
        assert_eq!(TailUse::NonTail.merge(&TailUse::NonTail), TailUse::NonTail);
    }
    // Literals, erased values, aliases and projections are pure; applications are not.
    #[test]
    pub(super) fn test_is_pure_value() {
        assert!(is_pure_value(&LcnfLetValue::Lit(LcnfLit::Nat(42))));
        assert!(is_pure_value(&LcnfLetValue::Erased));
        assert!(is_pure_value(&LcnfLetValue::FVar(make_var(0))));
        assert!(is_pure_value(&LcnfLetValue::Proj(
            "foo".into(),
            0,
            make_var(0)
        )));
        assert!(!is_pure_value(&LcnfLetValue::App(
            LcnfArg::Var(make_var(0)),
            vec![]
        )));
    }
    // A return counts as one node; wrapping it in a `let` adds one.
    #[test]
    pub(super) fn test_count_instructions() {
        let ret = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        assert_eq!(count_instructions(&ret), 1);
        let let_expr = make_simple_let(
            1,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
        );
        assert_eq!(count_instructions(&let_expr), 2);
    }
    // Both the aliased variable and the returned variable are reported as used.
    #[test]
    pub(super) fn test_collect_used_vars() {
        let expr = make_simple_let(
            1,
            LcnfLetValue::FVar(make_var(0)),
            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
        );
        let used = collect_used_vars(&expr);
        assert!(used.contains(&make_var(0)));
        assert!(used.contains(&make_var(1)));
    }
    // A return uses exactly the variable it returns.
    #[test]
    pub(super) fn test_expr_uses_var() {
        let expr = LcnfExpr::Return(LcnfArg::Var(make_var(5)));
        assert!(expr_uses_var(&expr, make_var(5)));
        assert!(!expr_uses_var(&expr, make_var(6)));
    }
    // An application uses both its callee and its variable arguments.
    #[test]
    pub(super) fn test_value_uses_var() {
        let val = LcnfLetValue::App(LcnfArg::Var(make_var(1)), vec![LcnfArg::Var(make_var(2))]);
        assert!(value_uses_var(&val, make_var(1)));
        assert!(value_uses_var(&val, make_var(2)));
        assert!(!value_uses_var(&val, make_var(3)));
    }
    // Literal arguments are skipped; the callee and variable arguments are returned.
    #[test]
    pub(super) fn test_extract_value_vars() {
        let val = LcnfLetValue::App(
            LcnfArg::Var(make_var(1)),
            vec![LcnfArg::Var(make_var(2)), LcnfArg::Lit(LcnfLit::Nat(0))],
        );
        let vars = extract_value_vars(&val);
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&make_var(1)));
        assert!(vars.contains(&make_var(2)));
    }
    // `let x1 = f x0; return x1` is rewritten into a single tail call and counted once.
    #[test]
    pub(super) fn test_detect_tail_calls() {
        let mut expr = make_simple_let(
            1,
            LcnfLetValue::App(LcnfArg::Var(make_var(10)), vec![LcnfArg::Var(make_var(0))]),
            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
        );
        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
        optimizer.detect_tail_calls_in_expr(&mut expr, "test");
        assert!(matches!(expr, LcnfExpr::TailCall(_, _)));
        assert_eq!(optimizer.stats.tail_calls_detected, 1);
    }
    // The unused outer `let x1` is removed, leaving `let x2` outermost.
    #[test]
    pub(super) fn test_dead_join_elimination() {
        let mut expr = make_simple_let(
            1,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            make_simple_let(
                2,
                LcnfLetValue::Lit(LcnfLit::Nat(100)),
                LcnfExpr::Return(LcnfArg::Var(make_var(2))),
            ),
        );
        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
        optimizer.eliminate_dead_joins(&mut expr);
        assert!(matches!(expr, LcnfExpr::Let { id, .. } if id == make_var(2)));
    }
    // A single-declaration module keeps its declaration through the full pass.
    #[test]
    pub(super) fn test_optimize_join_points_full() {
        let body = make_simple_let(
            1,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
        );
        let decl = make_simple_decl("test_fn", body);
        let mut module = LcnfModule {
            fun_decls: vec![decl],
            extern_decls: vec![],
            name: "test_mod".to_string(),
            metadata: LcnfModuleMetadata::default(),
        };
        let config = JoinPointConfig::default();
        optimize_join_points(&mut module, &config);
        assert_eq!(module.fun_decls.len(), 1);
    }
    // Literals and erased values have size 1; an application with two args has size 3.
    #[test]
    pub(super) fn test_value_size() {
        let optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
        assert_eq!(optimizer.value_size(&LcnfLetValue::Lit(LcnfLit::Nat(0))), 1);
        assert_eq!(optimizer.value_size(&LcnfLetValue::Erased), 1);
        assert_eq!(
            optimizer.value_size(&LcnfLetValue::App(
                LcnfArg::Var(make_var(0)),
                vec![LcnfArg::Var(make_var(1)), LcnfArg::Var(make_var(2))]
            )),
            3
        );
    }
    // Returning a variable from tail position is a tail-only use.
    #[test]
    pub(super) fn test_analyze_tail_uses_return() {
        let expr = LcnfExpr::Return(LcnfArg::Var(make_var(5)));
        let uses = analyze_tail_uses(&expr, true);
        assert_eq!(uses.get(&make_var(5)), Some(&TailUse::TailOnly));
    }
    // The same return outside tail position is a non-tail use.
    #[test]
    pub(super) fn test_analyze_tail_uses_non_tail() {
        let expr = LcnfExpr::Return(LcnfArg::Var(make_var(5)));
        let uses = analyze_tail_uses(&expr, false);
        assert_eq!(uses.get(&make_var(5)), Some(&TailUse::NonTail));
    }
    // A let-bound application is recorded as non-tail even when `in_tail` is true.
    #[test]
    pub(super) fn test_call_site_analysis() {
        let body = make_simple_let(
            1,
            LcnfLetValue::App(LcnfArg::Var(make_var(10)), vec![LcnfArg::Var(make_var(0))]),
            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
        );
        let sites = analyze_call_sites(&body, "test_fn", true);
        assert_eq!(sites.len(), 1);
        assert!(!sites[0].is_tail);
        assert_eq!(sites[0].arg_count, 1);
    }
    // Every declaration gets a call-graph entry, even without calls.
    #[test]
    pub(super) fn test_compute_call_graph() {
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let decl1 = make_simple_decl("foo", body.clone());
        let decl2 = make_simple_decl("bar", body);
        let graph = compute_call_graph(&[decl1, decl2]);
        assert!(graph.contains_key("foo"));
        assert!(graph.contains_key("bar"));
    }
    // Join points are named `_join_<id>` and marked as lifted.
    #[test]
    pub(super) fn test_create_join_point() {
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let jp = create_join_point(make_var(100), vec![make_param(1, "p")], body, LcnfType::Nat);
        assert_eq!(jp.name, "_join_100");
        assert!(jp.is_lifted);
    }
    // Non-recursive declarations are never converted to loops.
    #[test]
    pub(super) fn test_convert_to_loop_non_recursive() {
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let mut decl = make_simple_decl("test", body);
        decl.is_recursive = false;
        assert!(!convert_to_loop(&mut decl));
    }
    // All-tail call sites qualify; a single non-tail call disqualifies.
    #[test]
    pub(super) fn test_join_point_candidate() {
        let sites = vec![
            CallSiteInfo {
                caller: "f".to_string(),
                is_tail: true,
                arg_count: 1,
                callee_var: Some(make_var(5)),
            },
            CallSiteInfo {
                caller: "g".to_string(),
                is_tail: true,
                arg_count: 1,
                callee_var: Some(make_var(5)),
            },
        ];
        assert!(is_join_point_candidate(make_var(5), &sites));
        let mixed_sites = vec![
            CallSiteInfo {
                caller: "f".to_string(),
                is_tail: true,
                arg_count: 1,
                callee_var: Some(make_var(5)),
            },
            CallSiteInfo {
                caller: "g".to_string(),
                is_tail: false,
                arg_count: 1,
                callee_var: Some(make_var(5)),
            },
        ];
        assert!(!is_join_point_candidate(make_var(5), &mixed_sites));
    }
    // Consecutive fresh ids differ.
    #[test]
    pub(super) fn test_optimizer_fresh_id() {
        let mut opt = JoinPointOptimizer::new(JoinPointConfig::default());
        let id1 = opt.fresh_id();
        let id2 = opt.fresh_id();
        assert_ne!(id1, id2);
    }
    // Tail-call detection also rewrites inside `case` alternatives.
    #[test]
    pub(super) fn test_case_tail_call_detection() {
        let mut expr = LcnfExpr::Case {
            scrutinee: make_var(0),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![
                LcnfAlt {
                    ctor_name: "True".to_string(),
                    ctor_tag: 0,
                    params: vec![],
                    body: make_simple_let(
                        5,
                        LcnfLetValue::App(
                            LcnfArg::Var(make_var(10)),
                            vec![LcnfArg::Var(make_var(1))],
                        ),
                        LcnfExpr::Return(LcnfArg::Var(make_var(5))),
                    ),
                },
                LcnfAlt {
                    ctor_name: "False".to_string(),
                    ctor_tag: 1,
                    params: vec![],
                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
                },
            ],
            default: None,
        };
        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
        optimizer.detect_tail_calls_in_expr(&mut expr, "test");
        assert_eq!(optimizer.stats.tail_calls_detected, 1);
        if let LcnfExpr::Case { alts, .. } = &expr {
            assert!(matches!(alts[0].body, LcnfExpr::TailCall(_, _)));
        }
    }
    // Two unused outer lets are both removed, leaving `let x3` outermost.
    #[test]
    pub(super) fn test_nested_dead_elimination() {
        let mut expr = make_simple_let(
            1,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            make_simple_let(
                2,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                make_simple_let(
                    3,
                    LcnfLetValue::Lit(LcnfLit::Nat(3)),
                    LcnfExpr::Return(LcnfArg::Var(make_var(3))),
                ),
            ),
        );
        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
        optimizer.eliminate_dead_joins(&mut expr);
        assert!(matches!(& expr, LcnfExpr::Let { id, .. } if * id == make_var(3)));
    }
    // `unreachable` counts as one node.
    #[test]
    pub(super) fn test_unreachable_count() {
        let expr = LcnfExpr::Unreachable;
        assert_eq!(count_instructions(&expr), 1);
    }
    // A tail call counts as one node regardless of its arguments.
    #[test]
    pub(super) fn test_tail_call_count() {
        let expr = LcnfExpr::TailCall(LcnfArg::Var(make_var(0)), vec![LcnfArg::Var(make_var(1))]);
        assert_eq!(count_instructions(&expr), 1);
    }
    // Both small let-bindings are reported as small joins.
    #[test]
    pub(super) fn test_find_small_joins() {
        let expr = make_simple_let(
            1,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            make_simple_let(
                2,
                LcnfLetValue::FVar(make_var(1)),
                LcnfExpr::Return(LcnfArg::Var(make_var(2))),
            ),
        );
        let optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
        let joins = optimizer.find_small_joins(&expr);
        assert!(joins.contains_key(&make_var(1)));
        assert!(joins.contains_key(&make_var(2)));
    }
    // Smoke test only: inlining small joins must not panic (no result is asserted).
    #[test]
    pub(super) fn test_inline_small_joins() {
        let mut expr = make_simple_let(
            1,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            make_simple_let(
                2,
                LcnfLetValue::FVar(make_var(1)),
                LcnfExpr::Return(LcnfArg::Var(make_var(2))),
            ),
        );
        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig::default());
        optimizer.inline_small_joins(&mut expr);
    }
    // A case with two returning alternatives counts 1 + 1 + 1.
    #[test]
    pub(super) fn test_case_instruction_count() {
        let expr = LcnfExpr::Case {
            scrutinee: make_var(0),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![
                LcnfAlt {
                    ctor_name: "A".to_string(),
                    ctor_tag: 0,
                    params: vec![],
                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))),
                },
                LcnfAlt {
                    ctor_name: "B".to_string(),
                    ctor_tag: 1,
                    params: vec![],
                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(2))),
                },
            ],
            default: None,
        };
        assert_eq!(count_instructions(&expr), 3);
    }
    // A module whose only function is a case with a default keeps that function.
    #[test]
    pub(super) fn test_full_pipeline_with_case() {
        let body = LcnfExpr::Case {
            scrutinee: make_var(0),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "Zero".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
            }],
            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))))),
        };
        let decl = make_simple_decl("test_case", body);
        let mut module = LcnfModule {
            fun_decls: vec![decl],
            extern_decls: vec![],
            name: "test_mod".to_string(),
            metadata: LcnfModuleMetadata::default(),
        };
        let config = JoinPointConfig::default();
        optimize_join_points(&mut module, &config);
        assert_eq!(module.fun_decls.len(), 1);
    }
    // With an explicit var→name map, a tail call to the function itself is found.
    #[test]
    pub(super) fn test_find_self_recursive_tail_calls() {
        let mut var_map = HashMap::new();
        var_map.insert(make_var(10), "my_fn".to_string());
        let expr = LcnfExpr::TailCall(LcnfArg::Var(make_var(10)), vec![LcnfArg::Var(make_var(0))]);
        let calls = find_self_recursive_tail_calls(&expr, "my_fn", &var_map);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0], make_var(10));
    }
    // A body with no calls yields no callees.
    #[test]
    pub(super) fn test_collect_callees_empty() {
        let expr = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let known: HashSet<&str> = HashSet::new();
        let mut callees = HashSet::new();
        collect_callees(&expr, &known, &mut callees);
        assert!(callees.is_empty());
    }
    // `optimize_decl` records at least one iteration when `max_iterations` allows it.
    #[test]
    pub(super) fn test_multiple_iterations() {
        let body = make_simple_let(
            1,
            LcnfLetValue::Lit(LcnfLit::Nat(1)),
            make_simple_let(
                2,
                LcnfLetValue::Lit(LcnfLit::Nat(2)),
                make_simple_let(
                    3,
                    LcnfLetValue::App(LcnfArg::Var(make_var(10)), vec![LcnfArg::Var(make_var(2))]),
                    LcnfExpr::Return(LcnfArg::Var(make_var(3))),
                ),
            ),
        );
        let mut decl = make_simple_decl("multi_iter", body);
        let mut optimizer = JoinPointOptimizer::new(JoinPointConfig {
            max_iterations: 10,
            ..JoinPointConfig::default()
        });
        optimizer.optimize_decl(&mut decl);
        assert!(optimizer.stats.iterations > 0);
    }
}
// Tests for the `OJoin*` utility types (config, source buffer, name scope, diagnostics,
// id generation, profiling, event log, versions, feature sets, emit statistics).
#[cfg(test)]
mod tests_ojoin_extra {
    use super::*;
    // String get/set, bool parsing, non-numeric int lookup, and key count.
    #[test]
    pub(super) fn test_ojoin_config() {
        let mut cfg = OJoinConfig::new();
        cfg.set("mode", "release");
        cfg.set("verbose", "true");
        assert_eq!(cfg.get("mode"), Some("release"));
        assert!(cfg.get_bool("verbose"));
        assert!(cfg.get_int("mode").is_none());
        assert_eq!(cfg.len(), 2);
    }
    // Indented lines gain leading whitespace; `reset` empties the buffer.
    #[test]
    pub(super) fn test_ojoin_source_buffer() {
        let mut buf = OJoinSourceBuffer::new();
        buf.push_line("fn main() {");
        buf.indent();
        buf.push_line("println!(\"hello\");");
        buf.dedent();
        buf.push_line("}");
        assert!(buf.as_str().contains("fn main()"));
        assert!(buf.as_str().contains(" println!"));
        assert_eq!(buf.line_count(), 3);
        buf.reset();
        assert!(buf.is_empty());
    }
    // Redeclaration is rejected; push/pop adjust depth; `len` counts both names.
    #[test]
    pub(super) fn test_ojoin_name_scope() {
        let mut scope = OJoinNameScope::new();
        assert!(scope.declare("x"));
        assert!(!scope.declare("x"));
        assert!(scope.is_declared("x"));
        let scope = scope.push_scope();
        assert_eq!(scope.depth(), 1);
        let mut scope = scope.pop_scope();
        assert_eq!(scope.depth(), 0);
        scope.declare("y");
        assert_eq!(scope.len(), 2);
    }
    // Errors and warnings are tracked separately; `clear` empties the collector.
    #[test]
    pub(super) fn test_ojoin_diag_collector() {
        let mut col = OJoinDiagCollector::new();
        col.emit(OJoinDiagMsg::warning("pass_a", "slow"));
        col.emit(OJoinDiagMsg::error("pass_b", "fatal"));
        assert!(col.has_errors());
        assert_eq!(col.errors().len(), 1);
        assert_eq!(col.warnings().len(), 1);
        col.clear();
        assert!(col.is_empty());
    }
    // Ids start at 0, `skip(10)` advances by ten, `reset` returns to 0.
    #[test]
    pub(super) fn test_ojoin_id_gen() {
        let mut gen = OJoinIdGen::new();
        assert_eq!(gen.next_id(), 0);
        assert_eq!(gen.next_id(), 1);
        gen.skip(10);
        assert_eq!(gen.next_id(), 12);
        gen.reset();
        assert_eq!(gen.peek_next(), 0);
    }
    // Keys built from equal inputs match; a different first component does not.
    #[test]
    pub(super) fn test_ojoin_incr_key() {
        let k1 = OJoinIncrKey::new(100, 200);
        let k2 = OJoinIncrKey::new(100, 200);
        let k3 = OJoinIncrKey::new(999, 200);
        assert!(k1.matches(&k2));
        assert!(!k1.matches(&k3));
    }
    // Elapsed times sum, the longest pass is the slowest, and one of two is profitable.
    #[test]
    pub(super) fn test_ojoin_profiler() {
        let mut p = OJoinProfiler::new();
        p.record(OJoinPassTiming::new("pass_a", 1000, 50, 200, 100));
        p.record(OJoinPassTiming::new("pass_b", 500, 30, 100, 200));
        assert_eq!(p.total_elapsed_us(), 1500);
        assert_eq!(
            p.slowest_pass()
                .expect("slowest pass should exist")
                .pass_name,
            "pass_a"
        );
        assert_eq!(p.profitable_passes().len(), 1);
    }
    // A capacity-3 log drops its oldest event when a fourth is pushed.
    #[test]
    pub(super) fn test_ojoin_event_log() {
        let mut log = OJoinEventLog::new(3);
        log.push("event1");
        log.push("event2");
        log.push("event3");
        assert_eq!(log.len(), 3);
        log.push("event4");
        assert_eq!(log.len(), 3);
        assert_eq!(
            log.iter()
                .next()
                .expect("iterator should have next element"),
            "event2"
        );
    }
    // Pre-release versions are unstable and display their suffix; compatibility by major.
    #[test]
    pub(super) fn test_ojoin_version() {
        let v = OJoinVersion::new(1, 2, 3).with_pre("alpha");
        assert!(!v.is_stable());
        assert_eq!(format!("{}", v), "1.2.3-alpha");
        let stable = OJoinVersion::new(2, 0, 0);
        assert!(stable.is_stable());
        assert!(stable.is_compatible_with(&OJoinVersion::new(2, 0, 0)));
        assert!(!stable.is_compatible_with(&OJoinVersion::new(3, 0, 0)));
    }
    // Enable/disable plus set union and intersection of feature flags.
    #[test]
    pub(super) fn test_ojoin_features() {
        let mut f = OJoinFeatures::new();
        f.enable("sse2");
        f.enable("avx2");
        assert!(f.is_enabled("sse2"));
        assert!(!f.is_enabled("avx512"));
        f.disable("avx2");
        assert!(!f.is_enabled("avx2"));
        let mut g = OJoinFeatures::new();
        g.enable("sse2");
        g.enable("neon");
        let union = f.union(&g);
        assert!(union.is_enabled("sse2") && union.is_enabled("neon"));
        let inter = f.intersection(&g);
        assert!(inter.is_enabled("sse2"));
    }
    // 50 000 bytes in 100 ms is 500 000 bytes/s; Display includes the byte count.
    #[test]
    pub(super) fn test_ojoin_emit_stats() {
        let mut s = OJoinEmitStats::new();
        s.bytes_emitted = 50_000;
        s.items_emitted = 500;
        s.elapsed_ms = 100;
        assert!(s.is_clean());
        assert!((s.throughput_bps() - 500_000.0).abs() < 1.0);
        let disp = format!("{}", s);
        assert!(disp.contains("bytes=50000"));
    }
}
// Tests for the generic `OJ*` pass-infrastructure helpers: pass configs and registry,
// analysis cache, worklist, dominator tree, liveness, constant folding and dependency graph.
// (Module renamed to snake_case to satisfy rustc's `non_snake_case` lint.)
#[cfg(test)]
mod oj_infra_tests {
    use super::*;
    // A new transformation-pass config is enabled and reports a modifying phase.
    #[test]
    pub(super) fn test_pass_config() {
        let config = OJPassConfig::new("test_pass", OJPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }
    // Two recorded runs average 15 changes and appear in the summary.
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = OJPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }
    // Disabled passes count toward the total but not the enabled count; stats are kept.
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = OJPassRegistry::new();
        reg.register(OJPassConfig::new("pass_a", OJPassPhase::Analysis));
        reg.register(OJPassConfig::new("pass_b", OJPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }
    // One hit and one miss give a 50% hit rate; invalidation keeps the entry but marks it.
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = OJAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }
    // Duplicate pushes are rejected and items pop in insertion order.
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = OJWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }
    // Dominance follows the idom chain and is reflexive.
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = OJDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }
    // Defs and uses are recorded per block index.
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = OJLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }
    // Folding helpers: checked add/div (div by zero is None) and bitwise ops.
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(OJConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(OJConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(OJConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            OJConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(OJConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }
    // An acyclic graph sorts topologically with every dependency before its dependent.
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = OJDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Position of each node in the sorted order.
        let pos: HashMap<u32, usize> = topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}