Skip to main content

aver/
call_graph.rs

//! Call-graph analysis and Tarjan's SCC algorithm.
//!
//! Given a parsed program, builds a directed graph of function calls
//! and finds strongly-connected components.  A function is *recursive*
//! if it belongs to an SCC with a cycle (size > 1, or size 1 with a
//! self-edge).
7use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, Stmt, TopLevel};
10
11// ---------------------------------------------------------------------------
12// Public API
13// ---------------------------------------------------------------------------
14
15/// Returns the SCC groups that contain cycles (self or mutual recursion).
16/// Each group is a `HashSet<String>` of function names in the SCC.
17pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
18    let graph = build_call_graph(items);
19    let user_fns = user_fn_names(items);
20    recursive_sccs(&graph, &user_fns)
21        .into_iter()
22        .map(|scc| scc.into_iter().collect())
23        .collect()
24}
25
26/// Returns the set of user-defined function names that are recursive
27/// (directly or mutually).
28pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
29    let graph = build_call_graph(items);
30    let user_fns = user_fn_names(items);
31    let mut recursive = HashSet::new();
32    for scc in recursive_sccs(&graph, &user_fns) {
33        for name in scc {
34            recursive.insert(name);
35        }
36    }
37    recursive
38}
39
40/// Direct call summary per user-defined function (unique + sorted).
41pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
42    let graph = build_call_graph(items);
43    let mut out = HashMap::new();
44    for item in items {
45        if let TopLevel::FnDef(fd) = item {
46            let mut callees = graph
47                .get(&fd.name)
48                .cloned()
49                .unwrap_or_default()
50                .into_iter()
51                .collect::<Vec<_>>();
52            callees.sort();
53            out.insert(fd.name.clone(), callees);
54        }
55    }
56    out
57}
58
59/// Count recursive callsites per user-defined function, scoped to caller SCC.
60///
61/// Callsite definition:
62/// - one syntactic `FnCall` or `TailCall` node in the function body,
63/// - whose callee is a user-defined function in the same recursive SCC
64///   as the caller.
65///
66/// This is a syntactic metric over AST nodes (not dynamic execution count,
67/// not CFG edges), so it stays stable across control-flow rewrites.
68pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
69    let graph = build_call_graph(items);
70    let user_fns = user_fn_names(items);
71    let sccs = recursive_sccs(&graph, &user_fns);
72    let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
73    for scc in sccs {
74        let members: HashSet<String> = scc.iter().cloned().collect();
75        for name in scc {
76            scc_members.insert(name, members.clone());
77        }
78    }
79
80    let mut out = HashMap::new();
81    for item in items {
82        if let TopLevel::FnDef(fd) = item {
83            let mut count = 0usize;
84            if let Some(members) = scc_members.get(&fd.name) {
85                count_recursive_calls_body(&fd.body, members, &mut count);
86            }
87            out.insert(fd.name.clone(), count);
88        }
89    }
90    out
91}
92
93/// Deterministic recursive SCC id per function (1-based).
94/// Non-recursive functions are absent from the returned map.
95pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
96    let graph = build_call_graph(items);
97    let user_fns = user_fn_names(items);
98    let mut sccs = recursive_sccs(&graph, &user_fns);
99    for scc in &mut sccs {
100        scc.sort();
101    }
102    sccs.sort_by(|a, b| a.first().cmp(&b.first()));
103
104    let mut out = HashMap::new();
105    for (idx, scc) in sccs.into_iter().enumerate() {
106        let id = idx + 1;
107        for name in scc {
108            out.insert(name, id);
109        }
110    }
111    out
112}
113
114// ---------------------------------------------------------------------------
115// Call graph construction
116// ---------------------------------------------------------------------------
117
118fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
119    let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
120    for item in items {
121        if let TopLevel::FnDef(fd) = item {
122            let mut callees = HashSet::new();
123            collect_callees_body(&fd.body, &mut callees);
124            graph.insert(fd.name.clone(), callees);
125        }
126    }
127    graph
128}
129
130fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
131    items
132        .iter()
133        .filter_map(|item| {
134            if let TopLevel::FnDef(fd) = item {
135                Some(fd.name.clone())
136            } else {
137                None
138            }
139        })
140        .collect()
141}
142
143fn recursive_sccs(
144    graph: &HashMap<String, HashSet<String>>,
145    user_fns: &HashSet<String>,
146) -> Vec<Vec<String>> {
147    tarjan_scc(graph, user_fns)
148        .into_iter()
149        .filter(|scc| is_recursive_scc(scc, graph))
150        .collect()
151}
152
/// Reports whether a strongly-connected component contains a cycle.
///
/// Any component with two or more members is cyclic by definition. A
/// single-member component is cyclic only if that function lists itself
/// among its callees in `graph`.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [single] => graph
            .get(single)
            .is_some_and(|callees| callees.contains(single)),
        _ => true,
    }
}
164
165fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
166    match body {
167        FnBody::Expr(e) => collect_callees_expr(e, callees),
168        FnBody::Block(stmts) => {
169            for s in stmts {
170                collect_callees_stmt(s, callees);
171            }
172        }
173    }
174}
175
176fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
177    match body {
178        FnBody::Expr(e) => count_recursive_calls_expr(e, recursive, out),
179        FnBody::Block(stmts) => {
180            for s in stmts {
181                count_recursive_calls_stmt(s, recursive, out);
182            }
183        }
184    }
185}
186
187fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
188    match stmt {
189        Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
190    }
191}
192
193fn count_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>, out: &mut usize) {
194    match expr {
195        Expr::FnCall(func, args) => {
196            match func.as_ref() {
197                Expr::Ident(name) => {
198                    if recursive.contains(name) {
199                        *out += 1;
200                    }
201                }
202                Expr::Attr(obj, member) => {
203                    if let Expr::Ident(ns) = obj.as_ref() {
204                        let q = format!("{}.{}", ns, member);
205                        if recursive.contains(&q) {
206                            *out += 1;
207                        }
208                    } else {
209                        count_recursive_calls_expr(obj, recursive, out);
210                    }
211                }
212                other => count_recursive_calls_expr(other, recursive, out),
213            }
214            for arg in args {
215                count_recursive_calls_expr(arg, recursive, out);
216            }
217        }
218        Expr::TailCall(boxed) => {
219            if recursive.contains(&boxed.0) {
220                *out += 1;
221            }
222            for arg in &boxed.1 {
223                count_recursive_calls_expr(arg, recursive, out);
224            }
225        }
226        Expr::Literal(_) | Expr::Resolved(_) | Expr::Ident(_) => {}
227        Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
228        Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
229            count_recursive_calls_expr(l, recursive, out);
230            count_recursive_calls_expr(r, recursive, out);
231        }
232        Expr::Match {
233            subject: scrutinee,
234            arms,
235            ..
236        } => {
237            count_recursive_calls_expr(scrutinee, recursive, out);
238            for arm in arms {
239                count_recursive_calls_expr(&arm.body, recursive, out);
240            }
241        }
242        Expr::List(elems) | Expr::Tuple(elems) => {
243            for e in elems {
244                count_recursive_calls_expr(e, recursive, out);
245            }
246        }
247        Expr::MapLiteral(entries) => {
248            for (k, v) in entries {
249                count_recursive_calls_expr(k, recursive, out);
250                count_recursive_calls_expr(v, recursive, out);
251            }
252        }
253        Expr::Constructor(_, arg) => {
254            if let Some(a) = arg {
255                count_recursive_calls_expr(a, recursive, out);
256            }
257        }
258        Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
259        Expr::InterpolatedStr(parts) => {
260            for part in parts {
261                if let crate::ast::StrPart::Parsed(expr) = part {
262                    count_recursive_calls_expr(expr, recursive, out);
263                }
264            }
265        }
266        Expr::RecordCreate { fields, .. } => {
267            for (_, e) in fields {
268                count_recursive_calls_expr(e, recursive, out);
269            }
270        }
271        Expr::RecordUpdate { base, updates, .. } => {
272            count_recursive_calls_expr(base, recursive, out);
273            for (_, e) in updates {
274                count_recursive_calls_expr(e, recursive, out);
275            }
276        }
277    }
278}
279
280fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
281    match stmt {
282        Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
283            collect_callees_expr(e, callees);
284        }
285    }
286}
287
288fn collect_callees_expr(expr: &Expr, callees: &mut HashSet<String>) {
289    match expr {
290        Expr::FnCall(func, args) => {
291            // Extract callee name
292            match func.as_ref() {
293                Expr::Ident(name) => {
294                    callees.insert(name.clone());
295                }
296                Expr::Attr(obj, member) => {
297                    if let Expr::Ident(ns) = obj.as_ref() {
298                        callees.insert(format!("{}.{}", ns, member));
299                    }
300                }
301                _ => collect_callees_expr(func, callees),
302            }
303            for arg in args {
304                collect_callees_expr(arg, callees);
305            }
306        }
307        Expr::Literal(_) | Expr::Resolved(_) => {}
308        Expr::Ident(_) => {}
309        Expr::Attr(obj, _) => collect_callees_expr(obj, callees),
310        Expr::BinOp(_, l, r) => {
311            collect_callees_expr(l, callees);
312            collect_callees_expr(r, callees);
313        }
314        Expr::Pipe(l, r) => {
315            collect_callees_expr(l, callees);
316            collect_callees_expr(r, callees);
317        }
318        Expr::Match {
319            subject: scrutinee,
320            arms,
321            ..
322        } => {
323            collect_callees_expr(scrutinee, callees);
324            for arm in arms {
325                collect_callees_expr(&arm.body, callees);
326            }
327        }
328        Expr::List(elems) => {
329            for e in elems {
330                collect_callees_expr(e, callees);
331            }
332        }
333        Expr::Tuple(items) => {
334            for item in items {
335                collect_callees_expr(item, callees);
336            }
337        }
338        Expr::MapLiteral(entries) => {
339            for (key, value) in entries {
340                collect_callees_expr(key, callees);
341                collect_callees_expr(value, callees);
342            }
343        }
344        Expr::Constructor(_, arg) => {
345            if let Some(a) = arg {
346                collect_callees_expr(a, callees);
347            }
348        }
349        Expr::ErrorProp(inner) => collect_callees_expr(inner, callees),
350        Expr::InterpolatedStr(parts) => {
351            for part in parts {
352                if let crate::ast::StrPart::Parsed(expr) = part {
353                    collect_callees_expr(expr, callees);
354                }
355            }
356        }
357        Expr::RecordCreate { fields, .. } => {
358            for (_, e) in fields {
359                collect_callees_expr(e, callees);
360            }
361        }
362        Expr::RecordUpdate { base, updates, .. } => {
363            collect_callees_expr(base, callees);
364            for (_, e) in updates {
365                collect_callees_expr(e, callees);
366            }
367        }
368        Expr::TailCall(boxed) => {
369            callees.insert(boxed.0.clone());
370            for arg in &boxed.1 {
371                collect_callees_expr(arg, callees);
372            }
373        }
374    }
375}
376
377// ---------------------------------------------------------------------------
378// Tarjan's SCC algorithm
379// ---------------------------------------------------------------------------
380
/// Mutable bookkeeping shared by every `strongconnect` call during one run
/// of Tarjan's algorithm.
struct TarjanState {
    /// Next DFS discovery index to hand out.
    index_counter: usize,
    /// Visited nodes not yet assigned to a finished SCC, in visit order.
    stack: Vec<String>,
    /// Membership mirror of `stack`, for constant-time "is on stack" checks.
    on_stack: HashSet<String>,
    /// DFS discovery index of each visited node.
    indices: HashMap<String, usize>,
    /// Smallest discovery index known to be reachable from each node through
    /// its DFS subtree plus edges to nodes still on the stack.
    lowlinks: HashMap<String, usize>,
    /// Completed SCCs, in the order they were popped.
    sccs: Vec<Vec<String>>,
}

/// Computes the strongly-connected components reachable from `nodes`.
///
/// Edges are followed only into callees that are keys of `graph`; callee
/// names that are not graph nodes are ignored.
///
/// Roots, and callees inside `strongconnect`, are visited in sorted name
/// order. Without this, `HashSet` iteration order (randomized per instance)
/// would leak into the result. With it, both the order of the returned
/// components and the member order within each component are reproducible
/// for the same input.
///
/// Components are returned in the order Tarjan's algorithm pops them: a
/// component is popped only after every component reachable from it.
fn tarjan_scc(
    graph: &HashMap<String, HashSet<String>>,
    nodes: &HashSet<String>,
) -> Vec<Vec<String>> {
    let mut state = TarjanState {
        index_counter: 0,
        stack: Vec::new(),
        on_stack: HashSet::new(),
        indices: HashMap::new(),
        lowlinks: HashMap::new(),
        sccs: Vec::new(),
    };

    let mut roots: Vec<&String> = nodes.iter().collect();
    roots.sort();
    for node in roots {
        if !state.indices.contains_key(node) {
            strongconnect(node, graph, &mut state);
        }
    }

    state.sccs
}

/// Tarjan's DFS step: visits `v`, explores its callees, and pops a finished
/// SCC onto `state.sccs` when `v` turns out to be the component's root.
fn strongconnect(v: &str, graph: &HashMap<String, HashSet<String>>, state: &mut TarjanState) {
    let idx = state.index_counter;
    state.index_counter += 1;
    state.indices.insert(v.to_string(), idx);
    state.lowlinks.insert(v.to_string(), idx);
    state.stack.push(v.to_string());
    state.on_stack.insert(v.to_string());

    if let Some(callees) = graph.get(v) {
        // Sorted so traversal order (and therefore output order) is
        // reproducible regardless of the set's hash seed.
        let mut ordered: Vec<&String> = callees.iter().collect();
        ordered.sort();
        for w in ordered {
            if !state.indices.contains_key(w) {
                // Only recurse into callees that are nodes of the graph
                // (other names, e.g. builtins, have no outgoing edges here).
                if graph.contains_key(w) {
                    strongconnect(w, graph, state);
                    let w_low = state.lowlinks[w];
                    let v_low = state.lowlinks[v];
                    if w_low < v_low {
                        state.lowlinks.insert(v.to_string(), w_low);
                    }
                }
            } else if state.on_stack.contains(w) {
                // Edge to a node in the current DFS path: it is in v's SCC.
                let w_idx = state.indices[w];
                let v_low = state.lowlinks[v];
                if w_idx < v_low {
                    state.lowlinks.insert(v.to_string(), w_idx);
                }
            }
        }
    }

    // If v is a root node, pop the SCC.
    if state.lowlinks[v] == state.indices[v] {
        let mut scc = Vec::new();
        loop {
            let w = state
                .stack
                .pop()
                .expect("Tarjan invariant: v is still on the stack until popped here");
            state.on_stack.remove(&w);
            let is_root = w == v;
            scc.push(w);
            if is_root {
                break;
            }
        }
        state.sccs.push(scc);
    }
}
456
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detects_self_recursion() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let rec = find_recursive_fns(&items);
        assert!(
            rec.contains("fib"),
            "fib should be recursive, got: {:?}",
            rec
        );
    }

    #[test]
    fn non_recursive_fn() {
        let src = "fn double(x: Int) -> Int\n    = x + x\n";
        let items = parse(src);
        let rec = find_recursive_fns(&items);
        assert!(
            rec.is_empty(),
            "double should not be recursive, got: {:?}",
            rec
        );
    }

    #[test]
    fn mutual_recursion() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let items = parse(src);
        let rec = find_recursive_fns(&items);
        assert!(rec.contains("isEven"), "isEven should be recursive");
        assert!(rec.contains("isOdd"), "isOdd should be recursive");
    }

    #[test]
    fn recursive_callsites_count_syntactic_occurrences() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let counts = recursive_callsite_counts(&items);
        assert_eq!(counts.get("fib").copied().unwrap_or(0), 2);
    }

    #[test]
    fn recursive_callsites_are_scoped_to_scc() {
        let src = r#"
fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1) + fib(n)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)

fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let counts = recursive_callsite_counts(&items);
        assert_eq!(counts.get("a").copied().unwrap_or(0), 1);
        assert_eq!(counts.get("b").copied().unwrap_or(0), 1);
        assert_eq!(counts.get("fib").copied().unwrap_or(0), 2);
    }

    #[test]
    fn recursive_scc_ids_are_deterministic_by_group_name() {
        let src = r#"
fn z(n: Int) -> Int
    match n
        0 -> 0
        _ -> z(n - 1)

fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)
"#;
        let items = parse(src);
        let ids = recursive_scc_ids(&items);
        // Group {a,b} gets id=1 (min name "a"), group {z} gets id=2.
        assert_eq!(ids.get("a").copied().unwrap_or(0), 1);
        assert_eq!(ids.get("b").copied().unwrap_or(0), 1);
        assert_eq!(ids.get("z").copied().unwrap_or(0), 2);
    }

    #[test]
    fn direct_calls_are_unique_and_sorted() {
        let src = r#"
fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1) + fib(n)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)

fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let calls = direct_calls(&items);
        assert_eq!(
            calls.get("a"),
            Some(&vec!["b".to_string(), "fib".to_string()])
        );
        assert_eq!(calls.get("b"), Some(&vec!["a".to_string()]));
        // `fib` calls itself twice but is listed once.
        assert_eq!(calls.get("fib"), Some(&vec!["fib".to_string()]));
    }

    #[test]
    fn tco_groups_cover_self_and_mutual_recursion() {
        let src = r#"
fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1) + fib(n)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)

fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let groups = find_tco_groups(&items);
        let ab: HashSet<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
        let fib: HashSet<String> = ["fib"].iter().map(|s| s.to_string()).collect();
        assert_eq!(groups.len(), 2, "got: {:?}", groups);
        assert!(groups.contains(&ab), "missing {{a, b}} in {:?}", groups);
        assert!(groups.contains(&fib), "missing {{fib}} in {:?}", groups);
    }

    #[test]
    fn non_recursive_fn_has_no_scc_id_and_zero_callsites() {
        let src = "fn double(x: Int) -> Int\n    = x + x\n";
        let items = parse(src);
        assert!(recursive_scc_ids(&items).is_empty());
        assert_eq!(
            recursive_callsite_counts(&items).get("double").copied(),
            Some(0)
        );
    }

    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = crate::parser::Parser::new(tokens);
        parser.parse().expect("parse failed")
    }
}