Skip to main content

aver/
call_graph.rs

//! Call-graph analysis and Tarjan's SCC algorithm.
//!
//! Given a parsed program, builds a directed graph of function calls
//! and finds strongly-connected components.  A function is *recursive*
//! if it belongs to an SCC with a cycle (size > 1, or size 1 with a
//! self-edge).
7use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, Spanned, Stmt, StrPart, TopLevel};
10
11mod codegen;
12mod scc;
13
14pub use codegen::{ordered_fn_components, tailcall_scc_components};
15
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------
19
20/// Returns the SCC groups that contain cycles (self or mutual recursion).
21/// Each group is a `HashSet<String>` of function names in the SCC.
22pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
23    let graph = build_call_graph(items);
24    let user_fns = user_fn_names(items);
25    recursive_sccs(&graph, &user_fns)
26        .into_iter()
27        .map(|scc| scc.into_iter().collect())
28        .collect()
29}
30
31/// Returns the set of user-defined function names that are recursive
32/// (directly or mutually).
33pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
34    let graph = build_call_graph(items);
35    let user_fns = user_fn_names(items);
36    let mut recursive = HashSet::new();
37    for scc in recursive_sccs(&graph, &user_fns) {
38        for name in scc {
39            recursive.insert(name);
40        }
41    }
42    recursive
43}
44
45/// Direct call summary per user-defined function (unique + sorted).
46pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
47    let graph = build_call_graph(items);
48    let mut out = HashMap::new();
49    for item in items {
50        if let TopLevel::FnDef(fd) = item {
51            let mut callees = graph
52                .get(&fd.name)
53                .cloned()
54                .unwrap_or_default()
55                .into_iter()
56                .collect::<Vec<_>>();
57            callees.sort();
58            out.insert(fd.name.clone(), callees);
59        }
60    }
61    out
62}
63
64/// Count recursive callsites per user-defined function, scoped to caller SCC.
65///
66/// Callsite definition:
67/// - one syntactic `FnCall` or `TailCall` node in the function body,
68/// - whose callee is a user-defined function in the same recursive SCC
69///   as the caller.
70///
71/// This is a syntactic metric over AST nodes (not dynamic execution count,
72/// not CFG edges), so it stays stable across control-flow rewrites.
73pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
74    let graph = build_call_graph(items);
75    let user_fns = user_fn_names(items);
76    let sccs = recursive_sccs(&graph, &user_fns);
77    let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
78    for scc in sccs {
79        let members: HashSet<String> = scc.iter().cloned().collect();
80        for name in scc {
81            scc_members.insert(name, members.clone());
82        }
83    }
84
85    let mut out = HashMap::new();
86    for item in items {
87        if let TopLevel::FnDef(fd) = item {
88            let mut count = 0usize;
89            if let Some(members) = scc_members.get(&fd.name) {
90                count_recursive_calls_body(&fd.body, members, &mut count);
91            }
92            out.insert(fd.name.clone(), count);
93        }
94    }
95    out
96}
97
98/// Deterministic recursive SCC id per function (1-based).
99/// Non-recursive functions are absent from the returned map.
100pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
101    let graph = build_call_graph(items);
102    let user_fns = user_fn_names(items);
103    let mut sccs = recursive_sccs(&graph, &user_fns);
104    for scc in &mut sccs {
105        scc.sort();
106    }
107    sccs.sort_by(|a, b| a.first().cmp(&b.first()));
108
109    let mut out = HashMap::new();
110    for (idx, scc) in sccs.into_iter().enumerate() {
111        let id = idx + 1;
112        for name in scc {
113            out.insert(name, id);
114        }
115    }
116    out
117}
118
/// Maps a called (possibly module-qualified) name to the bare user-defined
/// function name it refers to, if any.
///
/// A name already present in `fn_names` is returned unchanged. Otherwise the
/// longest `prefix` from `module_prefixes` for which `name` starts with
/// `"{prefix}."` is stripped, and the remainder is returned only if it is in
/// `fn_names`. Shorter prefixes are not tried when the longest match's
/// remainder is unknown.
fn canonical_codegen_dep(
    name: &str,
    fn_names: &HashSet<String>,
    module_prefixes: &HashSet<String>,
) -> Option<String> {
    if fn_names.contains(name) {
        return Some(name.to_owned());
    }

    // Two matching prefixes of equal length are the same string, so the
    // longest match (and therefore its remainder) is unique.
    let (_, bare) = module_prefixes
        .iter()
        .filter_map(|prefix| {
            let rest = name.strip_prefix(prefix.as_str())?.strip_prefix('.')?;
            Some((prefix.len(), rest))
        })
        .max_by_key(|&(len, _)| len)?;

    if fn_names.contains(bare) {
        Some(bare.to_owned())
    } else {
        None
    }
}
142
143fn collect_codegen_deps_body(
144    body: &FnBody,
145    fn_names: &HashSet<String>,
146    module_prefixes: &HashSet<String>,
147    out: &mut HashSet<String>,
148) {
149    for s in body.stmts() {
150        match s {
151            Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
152                collect_codegen_deps_expr(e, fn_names, module_prefixes, out)
153            }
154        }
155    }
156}
157
158fn collect_codegen_deps_expr(
159    expr: &Spanned<Expr>,
160    fn_names: &HashSet<String>,
161    module_prefixes: &HashSet<String>,
162    out: &mut HashSet<String>,
163) {
164    walk_expr(expr, &mut |node| match node {
165        Expr::FnCall(func, args) => {
166            if let Some(callee) = expr_to_dotted_name(func.as_ref())
167                && let Some(canonical) = canonical_codegen_dep(&callee, fn_names, module_prefixes)
168            {
169                out.insert(canonical);
170            }
171            for arg in args {
172                // function-as-value dependency, e.g. apply(f, x)
173                if let Some(qname) = expr_to_dotted_name(arg)
174                    && let Some(canonical) =
175                        canonical_codegen_dep(&qname, fn_names, module_prefixes)
176                {
177                    out.insert(canonical);
178                }
179            }
180        }
181        Expr::TailCall(boxed) if fn_names.contains(&boxed.target) => {
182            out.insert(boxed.target.clone());
183        }
184        _ => {}
185    });
186}
187
188fn expr_to_dotted_name(expr: &Spanned<Expr>) -> Option<String> {
189    match &expr.node {
190        Expr::Ident(name) => Some(name.clone()),
191        Expr::Attr(obj, field) => {
192            let head = expr_to_dotted_name(obj)?;
193            Some(format!("{}.{}", head, field))
194        }
195        _ => None,
196    }
197}
198
199fn walk_expr(expr: &Spanned<Expr>, visit: &mut impl FnMut(&Expr)) {
200    visit(&expr.node);
201    match &expr.node {
202        Expr::FnCall(func, args) => {
203            walk_expr(func, visit);
204            for arg in args {
205                walk_expr(arg, visit);
206            }
207        }
208        Expr::TailCall(boxed) => {
209            for arg in &boxed.args {
210                walk_expr(arg, visit);
211            }
212        }
213        Expr::Attr(obj, _) => walk_expr(obj, visit),
214        Expr::BinOp(_, l, r) => {
215            walk_expr(l, visit);
216            walk_expr(r, visit);
217        }
218        Expr::Match { subject, arms, .. } => {
219            walk_expr(subject, visit);
220            for arm in arms {
221                walk_expr(&arm.body, visit);
222            }
223        }
224        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
225            for item in items {
226                walk_expr(item, visit);
227            }
228        }
229        Expr::MapLiteral(entries) => {
230            for (k, v) in entries {
231                walk_expr(k, visit);
232                walk_expr(v, visit);
233            }
234        }
235        Expr::Constructor(_, maybe) => {
236            if let Some(inner) = maybe {
237                walk_expr(inner, visit);
238            }
239        }
240        Expr::ErrorProp(inner) => walk_expr(inner, visit),
241        Expr::InterpolatedStr(parts) => {
242            for part in parts {
243                if let StrPart::Parsed(e) = part {
244                    walk_expr(e, visit);
245                }
246            }
247        }
248        Expr::RecordCreate { fields, .. } => {
249            for (_, e) in fields {
250                walk_expr(e, visit);
251            }
252        }
253        Expr::RecordUpdate { base, updates, .. } => {
254            walk_expr(base, visit);
255            for (_, e) in updates {
256                walk_expr(e, visit);
257            }
258        }
259        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
260    }
261}
262
// ---------------------------------------------------------------------------
// Call graph construction
// ---------------------------------------------------------------------------
266
267fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
268    let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
269    for item in items {
270        if let TopLevel::FnDef(fd) = item {
271            let mut callees = HashSet::new();
272            collect_callees_body(&fd.body, &mut callees);
273            graph.insert(fd.name.clone(), callees);
274        }
275    }
276    graph
277}
278
279fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
280    items
281        .iter()
282        .filter_map(|item| {
283            if let TopLevel::FnDef(fd) = item {
284                Some(fd.name.clone())
285            } else {
286                None
287            }
288        })
289        .collect()
290}
291
292fn recursive_sccs(
293    graph: &HashMap<String, HashSet<String>>,
294    user_fns: &HashSet<String>,
295) -> Vec<Vec<String>> {
296    let mut names = user_fns.iter().cloned().collect::<Vec<_>>();
297    names.sort();
298
299    let mut adj: HashMap<String, Vec<String>> = HashMap::new();
300    for name in &names {
301        let mut deps = graph
302            .get(name)
303            .cloned()
304            .unwrap_or_default()
305            .into_iter()
306            .filter(|callee| user_fns.contains(callee))
307            .collect::<Vec<_>>();
308        deps.sort();
309        adj.insert(name.clone(), deps);
310    }
311
312    scc::tarjan_sccs(&names, &adj)
313        .into_iter()
314        .filter(|scc| is_recursive_scc(scc, graph))
315        .collect()
316}
317
/// Returns `true` when an SCC contains a cycle: it has two or more members,
/// or its single member has a self-edge in `graph`. An empty SCC is not
/// recursive.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [only] => graph
            .get(only)
            .is_some_and(|callees| callees.contains(only)),
        _ => true,
    }
}
329
330pub(crate) fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
331    for s in body.stmts() {
332        collect_callees_stmt(s, callees);
333    }
334}
335
336fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
337    for s in body.stmts() {
338        count_recursive_calls_stmt(s, recursive, out);
339    }
340}
341
342fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
343    match stmt {
344        Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
345    }
346}
347
348fn count_recursive_calls_expr(expr: &Spanned<Expr>, recursive: &HashSet<String>, out: &mut usize) {
349    match &expr.node {
350        Expr::FnCall(func, args) => {
351            match &func.node {
352                Expr::Ident(name) => {
353                    if recursive.contains(name) {
354                        *out += 1;
355                    }
356                }
357                Expr::Attr(obj, member) => {
358                    if let Expr::Ident(ns) = &obj.node {
359                        let q = format!("{}.{}", ns, member);
360                        if recursive.contains(&q) {
361                            *out += 1;
362                        }
363                    } else {
364                        count_recursive_calls_expr(obj, recursive, out);
365                    }
366                }
367                _ => count_recursive_calls_expr(func, recursive, out),
368            }
369            for arg in args {
370                count_recursive_calls_expr(arg, recursive, out);
371            }
372        }
373        Expr::TailCall(boxed) => {
374            if recursive.contains(&boxed.target) {
375                *out += 1;
376            }
377            for arg in &boxed.args {
378                count_recursive_calls_expr(arg, recursive, out);
379            }
380        }
381        Expr::Literal(_) | Expr::Resolved { .. } | Expr::Ident(_) => {}
382        Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
383        Expr::BinOp(_, l, r) => {
384            count_recursive_calls_expr(l, recursive, out);
385            count_recursive_calls_expr(r, recursive, out);
386        }
387        Expr::Match {
388            subject: scrutinee,
389            arms,
390            ..
391        } => {
392            count_recursive_calls_expr(scrutinee, recursive, out);
393            for arm in arms {
394                count_recursive_calls_expr(&arm.body, recursive, out);
395            }
396        }
397        Expr::List(elems) | Expr::Tuple(elems) | Expr::IndependentProduct(elems, _) => {
398            for e in elems {
399                count_recursive_calls_expr(e, recursive, out);
400            }
401        }
402        Expr::MapLiteral(entries) => {
403            for (k, v) in entries {
404                count_recursive_calls_expr(k, recursive, out);
405                count_recursive_calls_expr(v, recursive, out);
406            }
407        }
408        Expr::Constructor(_, arg) => {
409            if let Some(a) = arg {
410                count_recursive_calls_expr(a, recursive, out);
411            }
412        }
413        Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
414        Expr::InterpolatedStr(parts) => {
415            for part in parts {
416                if let crate::ast::StrPart::Parsed(expr) = part {
417                    count_recursive_calls_expr(expr, recursive, out);
418                }
419            }
420        }
421        Expr::RecordCreate { fields, .. } => {
422            for (_, e) in fields {
423                count_recursive_calls_expr(e, recursive, out);
424            }
425        }
426        Expr::RecordUpdate { base, updates, .. } => {
427            count_recursive_calls_expr(base, recursive, out);
428            for (_, e) in updates {
429                count_recursive_calls_expr(e, recursive, out);
430            }
431        }
432    }
433}
434
435fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
436    match stmt {
437        Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
438            collect_callees_expr(e, callees);
439        }
440    }
441}
442
443fn collect_callees_expr(expr: &Spanned<Expr>, callees: &mut HashSet<String>) {
444    walk_expr(expr, &mut |node| match node {
445        Expr::FnCall(func, _) => {
446            if let Some(callee) = expr_to_dotted_name(func.as_ref()) {
447                callees.insert(callee);
448            }
449        }
450        Expr::TailCall(boxed) => {
451            callees.insert(boxed.target.clone());
452        }
453        _ => {}
454    });
455}
456
457#[cfg(test)]
458mod tests;