Skip to main content

aver/
call_graph.rs

//! Call-graph analysis and Tarjan's SCC algorithm.
//!
//! Given a parsed program, builds a directed graph of function calls
//! and finds strongly-connected components.  A function is *recursive*
//! if it belongs to an SCC with a cycle (size > 1, or size 1 with a
//! self-edge).
7use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, FnDef, Stmt, StrPart, TopLevel};
10
11// ---------------------------------------------------------------------------
12// Public API
13// ---------------------------------------------------------------------------
14
15/// Returns the SCC groups that contain cycles (self or mutual recursion).
16/// Each group is a `HashSet<String>` of function names in the SCC.
17pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
18    let graph = build_call_graph(items);
19    let user_fns = user_fn_names(items);
20    recursive_sccs(&graph, &user_fns)
21        .into_iter()
22        .map(|scc| scc.into_iter().collect())
23        .collect()
24}
25
26/// Returns the set of user-defined function names that are recursive
27/// (directly or mutually).
28pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
29    let graph = build_call_graph(items);
30    let user_fns = user_fn_names(items);
31    let mut recursive = HashSet::new();
32    for scc in recursive_sccs(&graph, &user_fns) {
33        for name in scc {
34            recursive.insert(name);
35        }
36    }
37    recursive
38}
39
40/// Direct call summary per user-defined function (unique + sorted).
41pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
42    let graph = build_call_graph(items);
43    let mut out = HashMap::new();
44    for item in items {
45        if let TopLevel::FnDef(fd) = item {
46            let mut callees = graph
47                .get(&fd.name)
48                .cloned()
49                .unwrap_or_default()
50                .into_iter()
51                .collect::<Vec<_>>();
52            callees.sort();
53            out.insert(fd.name.clone(), callees);
54        }
55    }
56    out
57}
58
59/// Count recursive callsites per user-defined function, scoped to caller SCC.
60///
61/// Callsite definition:
62/// - one syntactic `FnCall` or `TailCall` node in the function body,
63/// - whose callee is a user-defined function in the same recursive SCC
64///   as the caller.
65///
66/// This is a syntactic metric over AST nodes (not dynamic execution count,
67/// not CFG edges), so it stays stable across control-flow rewrites.
68pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
69    let graph = build_call_graph(items);
70    let user_fns = user_fn_names(items);
71    let sccs = recursive_sccs(&graph, &user_fns);
72    let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
73    for scc in sccs {
74        let members: HashSet<String> = scc.iter().cloned().collect();
75        for name in scc {
76            scc_members.insert(name, members.clone());
77        }
78    }
79
80    let mut out = HashMap::new();
81    for item in items {
82        if let TopLevel::FnDef(fd) = item {
83            let mut count = 0usize;
84            if let Some(members) = scc_members.get(&fd.name) {
85                count_recursive_calls_body(&fd.body, members, &mut count);
86            }
87            out.insert(fd.name.clone(), count);
88        }
89    }
90    out
91}
92
93/// Deterministic recursive SCC id per function (1-based).
94/// Non-recursive functions are absent from the returned map.
95pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
96    let graph = build_call_graph(items);
97    let user_fns = user_fn_names(items);
98    let mut sccs = recursive_sccs(&graph, &user_fns);
99    for scc in &mut sccs {
100        scc.sort();
101    }
102    sccs.sort_by(|a, b| a.first().cmp(&b.first()));
103
104    let mut out = HashMap::new();
105    for (idx, scc) in sccs.into_iter().enumerate() {
106        let id = idx + 1;
107        for name in scc {
108            out.insert(name, id);
109        }
110    }
111    out
112}
113
114/// Deterministic function emission order for codegen backends.
115///
116/// Returns SCC components in callee-before-caller topological order.
117/// Each inner vector is one SCC (single function or mutual-recursive group).
118/// Function references passed as call arguments (e.g. `List.fold(xs, init, f)`)
119/// are treated as dependencies for ordering.
120pub fn ordered_fn_components<'a>(fns: &[&'a FnDef]) -> Vec<Vec<&'a FnDef>> {
121    if fns.is_empty() {
122        return vec![];
123    }
124
125    let fn_map: HashMap<String, &FnDef> = fns.iter().map(|fd| (fd.name.clone(), *fd)).collect();
126    let names: Vec<String> = fn_map.keys().cloned().collect();
127    let name_set: HashSet<String> = names.iter().cloned().collect();
128
129    let mut graph: HashMap<String, Vec<String>> = HashMap::new();
130    for fd in fns {
131        let mut deps = HashSet::new();
132        collect_codegen_deps_body(&fd.body, &name_set, &mut deps);
133        let mut sorted = deps.into_iter().collect::<Vec<_>>();
134        sorted.sort();
135        graph.insert(fd.name.clone(), sorted);
136    }
137
138    let sccs = tarjan_all_sccs(&names, &graph);
139    let mut comp_of: HashMap<String, usize> = HashMap::new();
140    for (idx, comp) in sccs.iter().enumerate() {
141        for name in comp {
142            comp_of.insert(name.clone(), idx);
143        }
144    }
145
146    let mut comp_graph: HashMap<usize, HashSet<usize>> = HashMap::new();
147    for (caller, deps) in &graph {
148        let from = comp_of[caller];
149        for callee in deps {
150            let to = comp_of[callee];
151            if from != to {
152                comp_graph.entry(from).or_default().insert(to);
153            }
154        }
155    }
156
157    let comp_order = topo_components(&sccs, &comp_graph);
158    comp_order
159        .into_iter()
160        .map(|idx| {
161            let mut group: Vec<&FnDef> = sccs[idx]
162                .iter()
163                .filter_map(|name| fn_map.get(name).copied())
164                .collect();
165            group.sort_by(|a, b| a.name.cmp(&b.name));
166            group
167        })
168        .collect()
169}
170
171fn collect_codegen_deps_body(body: &FnBody, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
172    match body {
173        FnBody::Expr(e) => collect_codegen_deps_expr(e, fn_names, out),
174        FnBody::Block(stmts) => {
175            for s in stmts {
176                match s {
177                    Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
178                        collect_codegen_deps_expr(e, fn_names, out)
179                    }
180                }
181            }
182        }
183    }
184}
185
186fn collect_codegen_deps_expr(expr: &Expr, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
187    match expr {
188        Expr::FnCall(func, args) => {
189            if let Expr::Ident(name) = func.as_ref() {
190                if fn_names.contains(name) {
191                    out.insert(name.clone());
192                }
193            }
194            if let Some(qname) = expr_to_dotted_name(func.as_ref()) {
195                if fn_names.contains(&qname) {
196                    out.insert(qname);
197                }
198            }
199
200            collect_codegen_deps_expr(func, fn_names, out);
201            for arg in args {
202                // function-as-value dependency, e.g. List.fold(xs, init, f)
203                if let Expr::Ident(name) = arg {
204                    if fn_names.contains(name) {
205                        out.insert(name.clone());
206                    }
207                }
208                if let Some(qname) = expr_to_dotted_name(arg) {
209                    if fn_names.contains(&qname) {
210                        out.insert(qname);
211                    }
212                }
213                collect_codegen_deps_expr(arg, fn_names, out);
214            }
215        }
216        Expr::TailCall(boxed) => {
217            if fn_names.contains(&boxed.0) {
218                out.insert(boxed.0.clone());
219            }
220            for arg in &boxed.1 {
221                collect_codegen_deps_expr(arg, fn_names, out);
222            }
223        }
224        Expr::Attr(obj, _) => collect_codegen_deps_expr(obj, fn_names, out),
225        Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
226            collect_codegen_deps_expr(l, fn_names, out);
227            collect_codegen_deps_expr(r, fn_names, out);
228        }
229        Expr::Match { subject, arms, .. } => {
230            collect_codegen_deps_expr(subject, fn_names, out);
231            for arm in arms {
232                collect_codegen_deps_expr(&arm.body, fn_names, out);
233            }
234        }
235        Expr::List(items) | Expr::Tuple(items) => {
236            for it in items {
237                collect_codegen_deps_expr(it, fn_names, out);
238            }
239        }
240        Expr::MapLiteral(entries) => {
241            for (k, v) in entries {
242                collect_codegen_deps_expr(k, fn_names, out);
243                collect_codegen_deps_expr(v, fn_names, out);
244            }
245        }
246        Expr::Constructor(_, maybe) => {
247            if let Some(inner) = maybe {
248                collect_codegen_deps_expr(inner, fn_names, out);
249            }
250        }
251        Expr::ErrorProp(inner) => collect_codegen_deps_expr(inner, fn_names, out),
252        Expr::InterpolatedStr(parts) => {
253            for p in parts {
254                if let StrPart::Parsed(e) = p {
255                    collect_codegen_deps_expr(e, fn_names, out);
256                }
257            }
258        }
259        Expr::RecordCreate { fields, .. } => {
260            for (_, e) in fields {
261                collect_codegen_deps_expr(e, fn_names, out);
262            }
263        }
264        Expr::RecordUpdate { base, updates, .. } => {
265            collect_codegen_deps_expr(base, fn_names, out);
266            for (_, e) in updates {
267                collect_codegen_deps_expr(e, fn_names, out);
268            }
269        }
270        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => {}
271    }
272}
273
274fn expr_to_dotted_name(expr: &Expr) -> Option<String> {
275    match expr {
276        Expr::Ident(name) => Some(name.clone()),
277        Expr::Attr(obj, field) => {
278            let head = expr_to_dotted_name(obj)?;
279            Some(format!("{}.{}", head, field))
280        }
281        _ => None,
282    }
283}
284
/// Runs Tarjan's algorithm over `graph`, starting from each of `nodes`, and
/// returns every strongly-connected component (including singletons).
///
/// Start nodes are visited in sorted order. Members inside each component
/// are sorted, and the components themselves are ordered by their first
/// (smallest) member.
fn tarjan_all_sccs(nodes: &[String], graph: &HashMap<String, Vec<String>>) -> Vec<Vec<String>> {
    /// Mutable bookkeeping for one run of the algorithm.
    #[derive(Default)]
    struct Walk {
        next_index: usize,
        index_of: HashMap<String, usize>,
        low_of: HashMap<String, usize>,
        pending: Vec<String>,
        pending_set: HashSet<String>,
        found: Vec<Vec<String>>,
    }

    impl Walk {
        /// Lowers the lowlink of `node` to `candidate` if that is smaller.
        fn lower(&mut self, node: &str, candidate: usize) {
            let current = self.low_of[node];
            self.low_of.insert(node.to_owned(), current.min(candidate));
        }

        /// Depth-first visit of `node`; emits a component when `node` is a root.
        fn visit(&mut self, node: &str, graph: &HashMap<String, Vec<String>>) {
            let my_index = self.next_index;
            self.next_index += 1;
            self.index_of.insert(node.to_owned(), my_index);
            self.low_of.insert(node.to_owned(), my_index);
            self.pending.push(node.to_owned());
            self.pending_set.insert(node.to_owned());

            for succ in graph.get(node).into_iter().flatten() {
                match self.index_of.get(succ).copied() {
                    None => {
                        // Tree edge: recurse, then inherit the successor's lowlink.
                        self.visit(succ, graph);
                        let succ_low = self.low_of[succ];
                        self.lower(node, succ_low);
                    }
                    Some(succ_index) if self.pending_set.contains(succ) => {
                        // Back edge into the current stack.
                        self.lower(node, succ_index);
                    }
                    Some(_) => {}
                }
            }

            if self.low_of[node] == self.index_of[node] {
                let mut members = Vec::new();
                while let Some(top) = self.pending.pop() {
                    self.pending_set.remove(&top);
                    let is_root = top == node;
                    members.push(top);
                    if is_root {
                        break;
                    }
                }
                members.sort();
                self.found.push(members);
            }
        }
    }

    let mut roots: Vec<&String> = nodes.iter().collect();
    roots.sort();

    let mut walk = Walk::default();
    for root in roots {
        if !walk.index_of.contains_key(root) {
            walk.visit(root, graph);
        }
    }

    let mut found = walk.found;
    found.sort_by(|left, right| left[0].cmp(&right[0]));
    found
}
350
351fn topo_components(
352    sccs: &[Vec<String>],
353    comp_graph: &HashMap<usize, HashSet<usize>>,
354) -> Vec<usize> {
355    let mut ids: Vec<usize> = (0..sccs.len()).collect();
356    ids.sort_by(|a, b| sccs[*a][0].cmp(&sccs[*b][0]));
357
358    let mut visited = HashSet::new();
359    let mut order = Vec::new();
360    for id in ids {
361        if !visited.contains(&id) {
362            topo_components_dfs(id, sccs, comp_graph, &mut visited, &mut order);
363        }
364    }
365    order
366}
367
/// Depth-first step of the component ordering.
///
/// Marks `id` as visited, recurses into its not-yet-visited successors in
/// order of their first member name, and then appends `id` to `order`.
fn topo_components_dfs(
    id: usize,
    sccs: &[Vec<String>],
    comp_graph: &HashMap<usize, HashSet<usize>>,
    visited: &mut HashSet<usize>,
    order: &mut Vec<usize>,
) {
    visited.insert(id);

    let mut successors: Vec<usize> = match comp_graph.get(&id) {
        Some(targets) => targets.iter().copied().collect(),
        None => Vec::new(),
    };
    successors.sort_by(|left, right| sccs[*left][0].cmp(&sccs[*right][0]));

    for next in successors {
        if visited.contains(&next) {
            continue;
        }
        topo_components_dfs(next, sccs, comp_graph, visited, order);
    }

    // Post-order: a component is emitted only after everything it points to.
    order.push(id);
}
388
389// ---------------------------------------------------------------------------
390// Call graph construction
391// ---------------------------------------------------------------------------
392
393fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
394    let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
395    for item in items {
396        if let TopLevel::FnDef(fd) = item {
397            let mut callees = HashSet::new();
398            collect_callees_body(&fd.body, &mut callees);
399            graph.insert(fd.name.clone(), callees);
400        }
401    }
402    graph
403}
404
405fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
406    items
407        .iter()
408        .filter_map(|item| {
409            if let TopLevel::FnDef(fd) = item {
410                Some(fd.name.clone())
411            } else {
412                None
413            }
414        })
415        .collect()
416}
417
418fn recursive_sccs(
419    graph: &HashMap<String, HashSet<String>>,
420    user_fns: &HashSet<String>,
421) -> Vec<Vec<String>> {
422    tarjan_scc(graph, user_fns)
423        .into_iter()
424        .filter(|scc| is_recursive_scc(scc, graph))
425        .collect()
426}
427
/// Tells whether an SCC contains a cycle.
///
/// A component with more than one member always does. A single-member
/// component does only when that member lists itself among its callees.
/// An empty component never does.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [only] => graph
            .get(only)
            .is_some_and(|callees| callees.contains(only)),
        _ => true,
    }
}
439
440pub(crate) fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
441    match body {
442        FnBody::Expr(e) => collect_callees_expr(e, callees),
443        FnBody::Block(stmts) => {
444            for s in stmts {
445                collect_callees_stmt(s, callees);
446            }
447        }
448    }
449}
450
451fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
452    match body {
453        FnBody::Expr(e) => count_recursive_calls_expr(e, recursive, out),
454        FnBody::Block(stmts) => {
455            for s in stmts {
456                count_recursive_calls_stmt(s, recursive, out);
457            }
458        }
459    }
460}
461
462fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
463    match stmt {
464        Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
465    }
466}
467
468fn count_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>, out: &mut usize) {
469    match expr {
470        Expr::FnCall(func, args) => {
471            match func.as_ref() {
472                Expr::Ident(name) => {
473                    if recursive.contains(name) {
474                        *out += 1;
475                    }
476                }
477                Expr::Attr(obj, member) => {
478                    if let Expr::Ident(ns) = obj.as_ref() {
479                        let q = format!("{}.{}", ns, member);
480                        if recursive.contains(&q) {
481                            *out += 1;
482                        }
483                    } else {
484                        count_recursive_calls_expr(obj, recursive, out);
485                    }
486                }
487                other => count_recursive_calls_expr(other, recursive, out),
488            }
489            for arg in args {
490                count_recursive_calls_expr(arg, recursive, out);
491            }
492        }
493        Expr::TailCall(boxed) => {
494            if recursive.contains(&boxed.0) {
495                *out += 1;
496            }
497            for arg in &boxed.1 {
498                count_recursive_calls_expr(arg, recursive, out);
499            }
500        }
501        Expr::Literal(_) | Expr::Resolved(_) | Expr::Ident(_) => {}
502        Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
503        Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
504            count_recursive_calls_expr(l, recursive, out);
505            count_recursive_calls_expr(r, recursive, out);
506        }
507        Expr::Match {
508            subject: scrutinee,
509            arms,
510            ..
511        } => {
512            count_recursive_calls_expr(scrutinee, recursive, out);
513            for arm in arms {
514                count_recursive_calls_expr(&arm.body, recursive, out);
515            }
516        }
517        Expr::List(elems) | Expr::Tuple(elems) => {
518            for e in elems {
519                count_recursive_calls_expr(e, recursive, out);
520            }
521        }
522        Expr::MapLiteral(entries) => {
523            for (k, v) in entries {
524                count_recursive_calls_expr(k, recursive, out);
525                count_recursive_calls_expr(v, recursive, out);
526            }
527        }
528        Expr::Constructor(_, arg) => {
529            if let Some(a) = arg {
530                count_recursive_calls_expr(a, recursive, out);
531            }
532        }
533        Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
534        Expr::InterpolatedStr(parts) => {
535            for part in parts {
536                if let crate::ast::StrPart::Parsed(expr) = part {
537                    count_recursive_calls_expr(expr, recursive, out);
538                }
539            }
540        }
541        Expr::RecordCreate { fields, .. } => {
542            for (_, e) in fields {
543                count_recursive_calls_expr(e, recursive, out);
544            }
545        }
546        Expr::RecordUpdate { base, updates, .. } => {
547            count_recursive_calls_expr(base, recursive, out);
548            for (_, e) in updates {
549                count_recursive_calls_expr(e, recursive, out);
550            }
551        }
552    }
553}
554
555fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
556    match stmt {
557        Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
558            collect_callees_expr(e, callees);
559        }
560    }
561}
562
563fn collect_callees_expr(expr: &Expr, callees: &mut HashSet<String>) {
564    match expr {
565        Expr::FnCall(func, args) => {
566            // Extract callee name
567            match func.as_ref() {
568                Expr::Ident(name) => {
569                    callees.insert(name.clone());
570                }
571                Expr::Attr(obj, member) => {
572                    if let Expr::Ident(ns) = obj.as_ref() {
573                        callees.insert(format!("{}.{}", ns, member));
574                    }
575                }
576                _ => collect_callees_expr(func, callees),
577            }
578            for arg in args {
579                collect_callees_expr(arg, callees);
580            }
581        }
582        Expr::Literal(_) | Expr::Resolved(_) => {}
583        Expr::Ident(_) => {}
584        Expr::Attr(obj, _) => collect_callees_expr(obj, callees),
585        Expr::BinOp(_, l, r) => {
586            collect_callees_expr(l, callees);
587            collect_callees_expr(r, callees);
588        }
589        Expr::Pipe(l, r) => {
590            collect_callees_expr(l, callees);
591            collect_callees_expr(r, callees);
592        }
593        Expr::Match {
594            subject: scrutinee,
595            arms,
596            ..
597        } => {
598            collect_callees_expr(scrutinee, callees);
599            for arm in arms {
600                collect_callees_expr(&arm.body, callees);
601            }
602        }
603        Expr::List(elems) => {
604            for e in elems {
605                collect_callees_expr(e, callees);
606            }
607        }
608        Expr::Tuple(items) => {
609            for item in items {
610                collect_callees_expr(item, callees);
611            }
612        }
613        Expr::MapLiteral(entries) => {
614            for (key, value) in entries {
615                collect_callees_expr(key, callees);
616                collect_callees_expr(value, callees);
617            }
618        }
619        Expr::Constructor(_, arg) => {
620            if let Some(a) = arg {
621                collect_callees_expr(a, callees);
622            }
623        }
624        Expr::ErrorProp(inner) => collect_callees_expr(inner, callees),
625        Expr::InterpolatedStr(parts) => {
626            for part in parts {
627                if let crate::ast::StrPart::Parsed(expr) = part {
628                    collect_callees_expr(expr, callees);
629                }
630            }
631        }
632        Expr::RecordCreate { fields, .. } => {
633            for (_, e) in fields {
634                collect_callees_expr(e, callees);
635            }
636        }
637        Expr::RecordUpdate { base, updates, .. } => {
638            collect_callees_expr(base, callees);
639            for (_, e) in updates {
640                collect_callees_expr(e, callees);
641            }
642        }
643        Expr::TailCall(boxed) => {
644            callees.insert(boxed.0.clone());
645            for arg in &boxed.1 {
646                collect_callees_expr(arg, callees);
647            }
648        }
649    }
650}
651
652// ---------------------------------------------------------------------------
653// Tarjan's SCC algorithm
654// ---------------------------------------------------------------------------
655
/// Mutable bookkeeping shared across the recursive steps of one Tarjan run.
struct TarjanState {
    /// Next discovery index to hand out.
    index_counter: usize,
    /// Discovery index of every node visited so far.
    indices: HashMap<String, usize>,
    /// Smallest discovery index known to be reachable from each visited node.
    lowlinks: HashMap<String, usize>,
    /// Nodes visited but not yet assigned to a component, in visit order.
    stack: Vec<String>,
    /// Membership mirror of `stack` for constant-time lookups.
    on_stack: HashSet<String>,
    /// Components completed so far.
    sccs: Vec<Vec<String>>,
}
664
665fn tarjan_scc(
666    graph: &HashMap<String, HashSet<String>>,
667    nodes: &HashSet<String>,
668) -> Vec<Vec<String>> {
669    let mut state = TarjanState {
670        index_counter: 0,
671        stack: Vec::new(),
672        on_stack: HashSet::new(),
673        indices: HashMap::new(),
674        lowlinks: HashMap::new(),
675        sccs: Vec::new(),
676    };
677
678    for node in nodes {
679        if !state.indices.contains_key(node) {
680            strongconnect(node, graph, &mut state);
681        }
682    }
683
684    state.sccs
685}
686
687fn strongconnect(v: &str, graph: &HashMap<String, HashSet<String>>, state: &mut TarjanState) {
688    let idx = state.index_counter;
689    state.index_counter += 1;
690    state.indices.insert(v.to_string(), idx);
691    state.lowlinks.insert(v.to_string(), idx);
692    state.stack.push(v.to_string());
693    state.on_stack.insert(v.to_string());
694
695    if let Some(callees) = graph.get(v) {
696        for w in callees {
697            if !state.indices.contains_key(w) {
698                // Only recurse into nodes that are in our function set
699                if graph.contains_key(w) {
700                    strongconnect(w, graph, state);
701                    let w_low = state.lowlinks[w];
702                    let v_low = state.lowlinks[v];
703                    if w_low < v_low {
704                        state.lowlinks.insert(v.to_string(), w_low);
705                    }
706                }
707            } else if state.on_stack.contains(w) {
708                let w_idx = state.indices[w];
709                let v_low = state.lowlinks[v];
710                if w_idx < v_low {
711                    state.lowlinks.insert(v.to_string(), w_idx);
712                }
713            }
714        }
715    }
716
717    // If v is a root node, pop the SCC
718    if state.lowlinks[v] == state.indices[v] {
719        let mut scc = Vec::new();
720        loop {
721            let w = state.stack.pop().unwrap();
722            state.on_stack.remove(&w);
723            scc.push(w.clone());
724            if w == v {
725                break;
726            }
727        }
728        state.sccs.push(scc);
729    }
730}
731
#[cfg(test)]
mod tests {
    use super::*;

    /// Doubly self-recursive Fibonacci, shared by several tests.
    const FIB_SRC: &str = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;

    /// Lexes and parses `src` into top-level items, panicking on failure.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = crate::parser::Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    /// Value stored under `name`, or 0 when the name is absent.
    fn value_of(map: &HashMap<String, usize>, name: &str) -> usize {
        map.get(name).copied().unwrap_or(0)
    }

    #[test]
    fn detects_self_recursion() {
        let rec = find_recursive_fns(&parse(FIB_SRC));
        assert!(
            rec.contains("fib"),
            "fib should be recursive, got: {:?}",
            rec
        );
    }

    #[test]
    fn non_recursive_fn() {
        let items = parse("fn double(x: Int) -> Int\n    = x + x\n");
        let rec = find_recursive_fns(&items);
        assert!(
            rec.is_empty(),
            "double should not be recursive, got: {:?}",
            rec
        );
    }

    #[test]
    fn mutual_recursion() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let rec = find_recursive_fns(&parse(src));
        assert!(rec.contains("isEven"), "isEven should be recursive");
        assert!(rec.contains("isOdd"), "isOdd should be recursive");
    }

    #[test]
    fn recursive_callsites_count_syntactic_occurrences() {
        let counts = recursive_callsite_counts(&parse(FIB_SRC));
        assert_eq!(value_of(&counts, "fib"), 2);
    }

    #[test]
    fn recursive_callsites_are_scoped_to_scc() {
        let src = r#"
fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1) + fib(n)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)

fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let counts = recursive_callsite_counts(&parse(src));
        // `a` calls `fib` too, but `fib` lives in a different SCC.
        assert_eq!(value_of(&counts, "a"), 1);
        assert_eq!(value_of(&counts, "b"), 1);
        assert_eq!(value_of(&counts, "fib"), 2);
    }

    #[test]
    fn recursive_scc_ids_are_deterministic_by_group_name() {
        let src = r#"
fn z(n: Int) -> Int
    match n
        0 -> 0
        _ -> z(n - 1)

fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)
"#;
        let ids = recursive_scc_ids(&parse(src));
        // Group {a,b} gets id=1 (min name "a"), group {z} gets id=2.
        assert_eq!(value_of(&ids, "a"), 1);
        assert_eq!(value_of(&ids, "b"), 1);
        assert_eq!(value_of(&ids, "z"), 2);
    }
}
857}