1use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, FnDef, Stmt, StrPart, TopLevel};
10
11pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
18 let graph = build_call_graph(items);
19 let user_fns = user_fn_names(items);
20 recursive_sccs(&graph, &user_fns)
21 .into_iter()
22 .map(|scc| scc.into_iter().collect())
23 .collect()
24}
25
26pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
29 let graph = build_call_graph(items);
30 let user_fns = user_fn_names(items);
31 let mut recursive = HashSet::new();
32 for scc in recursive_sccs(&graph, &user_fns) {
33 for name in scc {
34 recursive.insert(name);
35 }
36 }
37 recursive
38}
39
40pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
42 let graph = build_call_graph(items);
43 let mut out = HashMap::new();
44 for item in items {
45 if let TopLevel::FnDef(fd) = item {
46 let mut callees = graph
47 .get(&fd.name)
48 .cloned()
49 .unwrap_or_default()
50 .into_iter()
51 .collect::<Vec<_>>();
52 callees.sort();
53 out.insert(fd.name.clone(), callees);
54 }
55 }
56 out
57}
58
59pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
69 let graph = build_call_graph(items);
70 let user_fns = user_fn_names(items);
71 let sccs = recursive_sccs(&graph, &user_fns);
72 let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
73 for scc in sccs {
74 let members: HashSet<String> = scc.iter().cloned().collect();
75 for name in scc {
76 scc_members.insert(name, members.clone());
77 }
78 }
79
80 let mut out = HashMap::new();
81 for item in items {
82 if let TopLevel::FnDef(fd) = item {
83 let mut count = 0usize;
84 if let Some(members) = scc_members.get(&fd.name) {
85 count_recursive_calls_body(&fd.body, members, &mut count);
86 }
87 out.insert(fd.name.clone(), count);
88 }
89 }
90 out
91}
92
93pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
96 let graph = build_call_graph(items);
97 let user_fns = user_fn_names(items);
98 let mut sccs = recursive_sccs(&graph, &user_fns);
99 for scc in &mut sccs {
100 scc.sort();
101 }
102 sccs.sort_by(|a, b| a.first().cmp(&b.first()));
103
104 let mut out = HashMap::new();
105 for (idx, scc) in sccs.into_iter().enumerate() {
106 let id = idx + 1;
107 for name in scc {
108 out.insert(name, id);
109 }
110 }
111 out
112}
113
114pub fn ordered_fn_components<'a>(fns: &[&'a FnDef]) -> Vec<Vec<&'a FnDef>> {
121 if fns.is_empty() {
122 return vec![];
123 }
124
125 let fn_map: HashMap<String, &FnDef> = fns.iter().map(|fd| (fd.name.clone(), *fd)).collect();
126 let names: Vec<String> = fn_map.keys().cloned().collect();
127 let name_set: HashSet<String> = names.iter().cloned().collect();
128
129 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
130 for fd in fns {
131 let mut deps = HashSet::new();
132 collect_codegen_deps_body(&fd.body, &name_set, &mut deps);
133 let mut sorted = deps.into_iter().collect::<Vec<_>>();
134 sorted.sort();
135 graph.insert(fd.name.clone(), sorted);
136 }
137
138 let sccs = tarjan_all_sccs(&names, &graph);
139 let mut comp_of: HashMap<String, usize> = HashMap::new();
140 for (idx, comp) in sccs.iter().enumerate() {
141 for name in comp {
142 comp_of.insert(name.clone(), idx);
143 }
144 }
145
146 let mut comp_graph: HashMap<usize, HashSet<usize>> = HashMap::new();
147 for (caller, deps) in &graph {
148 let from = comp_of[caller];
149 for callee in deps {
150 let to = comp_of[callee];
151 if from != to {
152 comp_graph.entry(from).or_default().insert(to);
153 }
154 }
155 }
156
157 let comp_order = topo_components(&sccs, &comp_graph);
158 comp_order
159 .into_iter()
160 .map(|idx| {
161 let mut group: Vec<&FnDef> = sccs[idx]
162 .iter()
163 .filter_map(|name| fn_map.get(name).copied())
164 .collect();
165 group.sort_by(|a, b| a.name.cmp(&b.name));
166 group
167 })
168 .collect()
169}
170
171fn collect_codegen_deps_body(body: &FnBody, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
172 match body {
173 FnBody::Expr(e) => collect_codegen_deps_expr(e, fn_names, out),
174 FnBody::Block(stmts) => {
175 for s in stmts {
176 match s {
177 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
178 collect_codegen_deps_expr(e, fn_names, out)
179 }
180 }
181 }
182 }
183 }
184}
185
186fn collect_codegen_deps_expr(expr: &Expr, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
187 match expr {
188 Expr::FnCall(func, args) => {
189 if let Expr::Ident(name) = func.as_ref() {
190 if fn_names.contains(name) {
191 out.insert(name.clone());
192 }
193 }
194 if let Some(qname) = expr_to_dotted_name(func.as_ref()) {
195 if fn_names.contains(&qname) {
196 out.insert(qname);
197 }
198 }
199
200 collect_codegen_deps_expr(func, fn_names, out);
201 for arg in args {
202 if let Expr::Ident(name) = arg {
204 if fn_names.contains(name) {
205 out.insert(name.clone());
206 }
207 }
208 if let Some(qname) = expr_to_dotted_name(arg) {
209 if fn_names.contains(&qname) {
210 out.insert(qname);
211 }
212 }
213 collect_codegen_deps_expr(arg, fn_names, out);
214 }
215 }
216 Expr::TailCall(boxed) => {
217 if fn_names.contains(&boxed.0) {
218 out.insert(boxed.0.clone());
219 }
220 for arg in &boxed.1 {
221 collect_codegen_deps_expr(arg, fn_names, out);
222 }
223 }
224 Expr::Attr(obj, _) => collect_codegen_deps_expr(obj, fn_names, out),
225 Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
226 collect_codegen_deps_expr(l, fn_names, out);
227 collect_codegen_deps_expr(r, fn_names, out);
228 }
229 Expr::Match { subject, arms, .. } => {
230 collect_codegen_deps_expr(subject, fn_names, out);
231 for arm in arms {
232 collect_codegen_deps_expr(&arm.body, fn_names, out);
233 }
234 }
235 Expr::List(items) | Expr::Tuple(items) => {
236 for it in items {
237 collect_codegen_deps_expr(it, fn_names, out);
238 }
239 }
240 Expr::MapLiteral(entries) => {
241 for (k, v) in entries {
242 collect_codegen_deps_expr(k, fn_names, out);
243 collect_codegen_deps_expr(v, fn_names, out);
244 }
245 }
246 Expr::Constructor(_, maybe) => {
247 if let Some(inner) = maybe {
248 collect_codegen_deps_expr(inner, fn_names, out);
249 }
250 }
251 Expr::ErrorProp(inner) => collect_codegen_deps_expr(inner, fn_names, out),
252 Expr::InterpolatedStr(parts) => {
253 for p in parts {
254 if let StrPart::Parsed(e) = p {
255 collect_codegen_deps_expr(e, fn_names, out);
256 }
257 }
258 }
259 Expr::RecordCreate { fields, .. } => {
260 for (_, e) in fields {
261 collect_codegen_deps_expr(e, fn_names, out);
262 }
263 }
264 Expr::RecordUpdate { base, updates, .. } => {
265 collect_codegen_deps_expr(base, fn_names, out);
266 for (_, e) in updates {
267 collect_codegen_deps_expr(e, fn_names, out);
268 }
269 }
270 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => {}
271 }
272}
273
274fn expr_to_dotted_name(expr: &Expr) -> Option<String> {
275 match expr {
276 Expr::Ident(name) => Some(name.clone()),
277 Expr::Attr(obj, field) => {
278 let head = expr_to_dotted_name(obj)?;
279 Some(format!("{}.{}", head, field))
280 }
281 _ => None,
282 }
283}
284
/// Computes every strongly connected component of `graph` using Tarjan's algorithm.
///
/// Traversal starts from each entry of `nodes` in sorted order. Neighbours reached
/// through `graph` are visited even if they are not listed in `nodes`. Every component
/// is sorted by name, and components are ordered by their first (smallest) member, so
/// the result is deterministic.
fn tarjan_all_sccs(nodes: &[String], graph: &HashMap<String, Vec<String>>) -> Vec<Vec<String>> {
    /// Mutable state for one run of Tarjan's algorithm over `graph`.
    struct Walker<'g> {
        graph: &'g HashMap<String, Vec<String>>,
        /// Next discovery index to hand out.
        counter: usize,
        /// Discovery index of every visited node.
        discovered: HashMap<String, usize>,
        /// Smallest discovery index known to be reachable from each node.
        low: HashMap<String, usize>,
        /// Visited nodes whose component has not been emitted yet.
        pending: Vec<String>,
        /// Set view of `pending` for constant-time membership checks.
        pending_set: HashSet<String>,
        /// Finished components, each sorted by name.
        found: Vec<Vec<String>>,
    }

    impl<'g> Walker<'g> {
        fn visit(&mut self, v: &str) {
            let idx = self.counter;
            self.counter += 1;
            self.discovered.insert(v.to_string(), idx);
            self.low.insert(v.to_string(), idx);
            self.pending.push(v.to_string());
            self.pending_set.insert(v.to_string());

            // Copy the graph reference out so iterating it does not borrow `self`.
            let graph = self.graph;
            for w in graph.get(v).into_iter().flatten() {
                let seen = self.discovered.get(w).copied();
                let candidate = match seen {
                    None => {
                        self.visit(w);
                        self.low[w]
                    }
                    Some(w_idx) if self.pending_set.contains(w) => w_idx,
                    Some(_) => continue,
                };
                if let Some(low_v) = self.low.get_mut(v) {
                    *low_v = (*low_v).min(candidate);
                }
            }

            // `v` is the root of a component: pop everything above (and including) it.
            if self.low[v] == self.discovered[v] {
                let mut component = Vec::new();
                while let Some(w) = self.pending.pop() {
                    self.pending_set.remove(&w);
                    let is_root = w == v;
                    component.push(w);
                    if is_root {
                        break;
                    }
                }
                component.sort();
                self.found.push(component);
            }
        }
    }

    let mut roots: Vec<&String> = nodes.iter().collect();
    roots.sort();

    let mut walker = Walker {
        graph,
        counter: 0,
        discovered: HashMap::new(),
        low: HashMap::new(),
        pending: Vec::new(),
        pending_set: HashSet::new(),
        found: Vec::new(),
    };
    for root in roots {
        if !walker.discovered.contains_key(root) {
            walker.visit(root);
        }
    }

    let mut components = walker.found;
    components.sort_by(|a, b| a[0].cmp(&b[0]));
    components
}
350
351fn topo_components(
352 sccs: &[Vec<String>],
353 comp_graph: &HashMap<usize, HashSet<usize>>,
354) -> Vec<usize> {
355 let mut ids: Vec<usize> = (0..sccs.len()).collect();
356 ids.sort_by(|a, b| sccs[*a][0].cmp(&sccs[*b][0]));
357
358 let mut visited = HashSet::new();
359 let mut order = Vec::new();
360 for id in ids {
361 if !visited.contains(&id) {
362 topo_components_dfs(id, sccs, comp_graph, &mut visited, &mut order);
363 }
364 }
365 order
366}
367
/// Depth-first visit of component `id`; pushes `id` onto `order` after all of its
/// not-yet-visited successors (post-order).
///
/// Successors are explored in order of their first member name. Each SCC referenced
/// through `comp_graph` must be non-empty.
fn topo_components_dfs(
    id: usize,
    sccs: &[Vec<String>],
    comp_graph: &HashMap<usize, HashSet<usize>>,
    visited: &mut HashSet<usize>,
    order: &mut Vec<usize>,
) {
    visited.insert(id);

    let mut successors: Vec<usize> = match comp_graph.get(&id) {
        Some(targets) => targets.iter().copied().collect(),
        None => Vec::new(),
    };
    successors.sort_by_key(|&s| sccs[s][0].as_str());

    for next in successors {
        if visited.contains(&next) {
            continue;
        }
        topo_components_dfs(next, sccs, comp_graph, visited, order);
    }
    order.push(id);
}
388
389fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
394 let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
395 for item in items {
396 if let TopLevel::FnDef(fd) = item {
397 let mut callees = HashSet::new();
398 collect_callees_body(&fd.body, &mut callees);
399 graph.insert(fd.name.clone(), callees);
400 }
401 }
402 graph
403}
404
405fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
406 items
407 .iter()
408 .filter_map(|item| {
409 if let TopLevel::FnDef(fd) = item {
410 Some(fd.name.clone())
411 } else {
412 None
413 }
414 })
415 .collect()
416}
417
418fn recursive_sccs(
419 graph: &HashMap<String, HashSet<String>>,
420 user_fns: &HashSet<String>,
421) -> Vec<Vec<String>> {
422 tarjan_scc(graph, user_fns)
423 .into_iter()
424 .filter(|scc| is_recursive_scc(scc, graph))
425 .collect()
426}
427
/// Decides whether an SCC represents recursion.
///
/// An SCC with two or more members is a cycle of mutual calls. A single-member SCC is
/// recursive only if that function lists itself among its callees in `graph`. An empty
/// slice is never recursive.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [only] => graph
            .get(only)
            .is_some_and(|callees| callees.contains(only)),
        _ => true,
    }
}
439
440pub(crate) fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
441 match body {
442 FnBody::Expr(e) => collect_callees_expr(e, callees),
443 FnBody::Block(stmts) => {
444 for s in stmts {
445 collect_callees_stmt(s, callees);
446 }
447 }
448 }
449}
450
451fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
452 match body {
453 FnBody::Expr(e) => count_recursive_calls_expr(e, recursive, out),
454 FnBody::Block(stmts) => {
455 for s in stmts {
456 count_recursive_calls_stmt(s, recursive, out);
457 }
458 }
459 }
460}
461
462fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
463 match stmt {
464 Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
465 }
466}
467
468fn count_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>, out: &mut usize) {
469 match expr {
470 Expr::FnCall(func, args) => {
471 match func.as_ref() {
472 Expr::Ident(name) => {
473 if recursive.contains(name) {
474 *out += 1;
475 }
476 }
477 Expr::Attr(obj, member) => {
478 if let Expr::Ident(ns) = obj.as_ref() {
479 let q = format!("{}.{}", ns, member);
480 if recursive.contains(&q) {
481 *out += 1;
482 }
483 } else {
484 count_recursive_calls_expr(obj, recursive, out);
485 }
486 }
487 other => count_recursive_calls_expr(other, recursive, out),
488 }
489 for arg in args {
490 count_recursive_calls_expr(arg, recursive, out);
491 }
492 }
493 Expr::TailCall(boxed) => {
494 if recursive.contains(&boxed.0) {
495 *out += 1;
496 }
497 for arg in &boxed.1 {
498 count_recursive_calls_expr(arg, recursive, out);
499 }
500 }
501 Expr::Literal(_) | Expr::Resolved(_) | Expr::Ident(_) => {}
502 Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
503 Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
504 count_recursive_calls_expr(l, recursive, out);
505 count_recursive_calls_expr(r, recursive, out);
506 }
507 Expr::Match {
508 subject: scrutinee,
509 arms,
510 ..
511 } => {
512 count_recursive_calls_expr(scrutinee, recursive, out);
513 for arm in arms {
514 count_recursive_calls_expr(&arm.body, recursive, out);
515 }
516 }
517 Expr::List(elems) | Expr::Tuple(elems) => {
518 for e in elems {
519 count_recursive_calls_expr(e, recursive, out);
520 }
521 }
522 Expr::MapLiteral(entries) => {
523 for (k, v) in entries {
524 count_recursive_calls_expr(k, recursive, out);
525 count_recursive_calls_expr(v, recursive, out);
526 }
527 }
528 Expr::Constructor(_, arg) => {
529 if let Some(a) = arg {
530 count_recursive_calls_expr(a, recursive, out);
531 }
532 }
533 Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
534 Expr::InterpolatedStr(parts) => {
535 for part in parts {
536 if let crate::ast::StrPart::Parsed(expr) = part {
537 count_recursive_calls_expr(expr, recursive, out);
538 }
539 }
540 }
541 Expr::RecordCreate { fields, .. } => {
542 for (_, e) in fields {
543 count_recursive_calls_expr(e, recursive, out);
544 }
545 }
546 Expr::RecordUpdate { base, updates, .. } => {
547 count_recursive_calls_expr(base, recursive, out);
548 for (_, e) in updates {
549 count_recursive_calls_expr(e, recursive, out);
550 }
551 }
552 }
553}
554
555fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
556 match stmt {
557 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
558 collect_callees_expr(e, callees);
559 }
560 }
561}
562
563fn collect_callees_expr(expr: &Expr, callees: &mut HashSet<String>) {
564 match expr {
565 Expr::FnCall(func, args) => {
566 match func.as_ref() {
568 Expr::Ident(name) => {
569 callees.insert(name.clone());
570 }
571 Expr::Attr(obj, member) => {
572 if let Expr::Ident(ns) = obj.as_ref() {
573 callees.insert(format!("{}.{}", ns, member));
574 }
575 }
576 _ => collect_callees_expr(func, callees),
577 }
578 for arg in args {
579 collect_callees_expr(arg, callees);
580 }
581 }
582 Expr::Literal(_) | Expr::Resolved(_) => {}
583 Expr::Ident(_) => {}
584 Expr::Attr(obj, _) => collect_callees_expr(obj, callees),
585 Expr::BinOp(_, l, r) => {
586 collect_callees_expr(l, callees);
587 collect_callees_expr(r, callees);
588 }
589 Expr::Pipe(l, r) => {
590 collect_callees_expr(l, callees);
591 collect_callees_expr(r, callees);
592 }
593 Expr::Match {
594 subject: scrutinee,
595 arms,
596 ..
597 } => {
598 collect_callees_expr(scrutinee, callees);
599 for arm in arms {
600 collect_callees_expr(&arm.body, callees);
601 }
602 }
603 Expr::List(elems) => {
604 for e in elems {
605 collect_callees_expr(e, callees);
606 }
607 }
608 Expr::Tuple(items) => {
609 for item in items {
610 collect_callees_expr(item, callees);
611 }
612 }
613 Expr::MapLiteral(entries) => {
614 for (key, value) in entries {
615 collect_callees_expr(key, callees);
616 collect_callees_expr(value, callees);
617 }
618 }
619 Expr::Constructor(_, arg) => {
620 if let Some(a) = arg {
621 collect_callees_expr(a, callees);
622 }
623 }
624 Expr::ErrorProp(inner) => collect_callees_expr(inner, callees),
625 Expr::InterpolatedStr(parts) => {
626 for part in parts {
627 if let crate::ast::StrPart::Parsed(expr) = part {
628 collect_callees_expr(expr, callees);
629 }
630 }
631 }
632 Expr::RecordCreate { fields, .. } => {
633 for (_, e) in fields {
634 collect_callees_expr(e, callees);
635 }
636 }
637 Expr::RecordUpdate { base, updates, .. } => {
638 collect_callees_expr(base, callees);
639 for (_, e) in updates {
640 collect_callees_expr(e, callees);
641 }
642 }
643 Expr::TailCall(boxed) => {
644 callees.insert(boxed.0.clone());
645 for arg in &boxed.1 {
646 collect_callees_expr(arg, callees);
647 }
648 }
649 }
650}
651
/// Bookkeeping for the depth-first traversal in [`tarjan_scc`].
struct TarjanState {
    /// Next discovery index to hand out.
    next_index: usize,
    /// Discovery index assigned to each visited node.
    indices: HashMap<String, usize>,
    /// Smallest discovery index known to be reachable from each node.
    lowlinks: HashMap<String, usize>,
    /// Visited nodes whose component has not been emitted yet, in visit order.
    stack: Vec<String>,
    /// Set view of `stack` for constant-time membership checks.
    on_stack: HashSet<String>,
    /// Completed components, in the order Tarjan's algorithm emits them.
    sccs: Vec<Vec<String>>,
}

/// Computes strongly connected components of `graph` with Tarjan's algorithm,
/// starting a traversal from every entry of `nodes`.
///
/// Edges to names that have no entry in `graph` (e.g. builtins) are skipped. A node in
/// `nodes` without an entry in `graph` still yields its own singleton component. The
/// order of components, and of members inside each component, follows `HashSet`
/// iteration order and is not specified.
fn tarjan_scc(
    graph: &HashMap<String, HashSet<String>>,
    nodes: &HashSet<String>,
) -> Vec<Vec<String>> {
    let mut state = TarjanState {
        next_index: 0,
        indices: HashMap::new(),
        lowlinks: HashMap::new(),
        stack: Vec::new(),
        on_stack: HashSet::new(),
        sccs: Vec::new(),
    };

    for root in nodes {
        if state.indices.contains_key(root) {
            continue;
        }
        strongconnect(root, graph, &mut state);
    }

    state.sccs
}

/// Recursive step of Tarjan's algorithm for node `v`.
fn strongconnect(v: &str, graph: &HashMap<String, HashSet<String>>, state: &mut TarjanState) {
    let discovered = state.next_index;
    state.next_index += 1;
    state.indices.insert(v.to_string(), discovered);
    state.lowlinks.insert(v.to_string(), discovered);
    state.stack.push(v.to_string());
    state.on_stack.insert(v.to_string());

    for w in graph.get(v).into_iter().flatten() {
        let seen = state.indices.get(w).copied();
        let candidate = match seen {
            // Unvisited callee with a definition: descend, then inherit its lowlink.
            None if graph.contains_key(w) => {
                strongconnect(w, graph, state);
                state.lowlinks[w]
            }
            // Callee without an entry in `graph`: never part of a component here.
            None => continue,
            // Edge back into the component currently being built.
            Some(w_index) if state.on_stack.contains(w) => w_index,
            // Already assigned to a finished component.
            Some(_) => continue,
        };
        if let Some(low_v) = state.lowlinks.get_mut(v) {
            if candidate < *low_v {
                *low_v = candidate;
            }
        }
    }

    // `v` is a component root: pop the stack down to and including `v`.
    if state.lowlinks[v] == state.indices[v] {
        let mut component = Vec::new();
        loop {
            let w = state
                .stack
                .pop()
                .expect("Tarjan invariant: the component root is still on the stack");
            state.on_stack.remove(&w);
            let reached_root = w == v;
            component.push(w);
            if reached_root {
                break;
            }
        }
        state.sccs.push(component);
    }
}
731
// Unit tests for the call-graph analyses in this module. Each test lexes and parses a
// small source program with the crate's own lexer/parser, then inspects the results.
#[cfg(test)]
mod tests {
    use super::*;

    // A function that calls itself is reported as recursive.
    #[test]
    fn detects_self_recursion() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let rec = find_recursive_fns(&items);
        assert!(
            rec.contains("fib"),
            "fib should be recursive, got: {:?}",
            rec
        );
    }

    // A function with no self or mutual calls is not reported.
    #[test]
    fn non_recursive_fn() {
        let src = "fn double(x: Int) -> Int\n = x + x\n";
        let items = parse(src);
        let rec = find_recursive_fns(&items);
        assert!(
            rec.is_empty(),
            "double should not be recursive, got: {:?}",
            rec
        );
    }

    // Functions that recurse only through each other are both reported.
    #[test]
    fn mutual_recursion() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let items = parse(src);
        let rec = find_recursive_fns(&items);
        assert!(rec.contains("isEven"), "isEven should be recursive");
        assert!(rec.contains("isOdd"), "isOdd should be recursive");
    }

    // Call sites are counted syntactically: two `fib(...)` occurrences give 2.
    #[test]
    fn recursive_callsites_count_syntactic_occurrences() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let counts = recursive_callsite_counts(&items);
        assert_eq!(counts.get("fib").copied().unwrap_or(0), 2);
    }

    // Only calls into the caller's own SCC count: `a`'s call to `fib` is ignored.
    #[test]
    fn recursive_callsites_are_scoped_to_scc() {
        let src = r#"
fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1) + fib(n)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)

fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let counts = recursive_callsite_counts(&items);
        assert_eq!(counts.get("a").copied().unwrap_or(0), 1);
        assert_eq!(counts.get("b").copied().unwrap_or(0), 1);
        assert_eq!(counts.get("fib").copied().unwrap_or(0), 2);
    }

    // Ids follow the sorted group names, not declaration order: {a, b} gets 1 and
    // {z} gets 2 even though `z` is declared first.
    #[test]
    fn recursive_scc_ids_are_deterministic_by_group_name() {
        let src = r#"
fn z(n: Int) -> Int
    match n
        0 -> 0
        _ -> z(n - 1)

fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)
"#;
        let items = parse(src);
        let ids = recursive_scc_ids(&items);
        assert_eq!(ids.get("a").copied().unwrap_or(0), 1);
        // `a` and `b` share an SCC, so they share an id.
        assert_eq!(ids.get("b").copied().unwrap_or(0), 1);
        assert_eq!(ids.get("z").copied().unwrap_or(0), 2);
    }

    // Lexes and parses `src` into top-level items, panicking on any error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = crate::parser::Parser::new(tokens);
        parser.parse().expect("parse failed")
    }
}
857}