Skip to main content

aver/
call_graph.rs

//! Call-graph analysis and Tarjan's SCC algorithm.
//!
//! Given a parsed program, builds a directed graph of function calls
//! and finds strongly-connected components.  A function is *recursive*
//! if it belongs to an SCC with a cycle (size > 1, or size 1 with a
//! self-edge).
7use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
10
11mod codegen;
12mod scc;
13
14pub use codegen::ordered_fn_components;
15
16// ---------------------------------------------------------------------------
17// Public API
18// ---------------------------------------------------------------------------
19
20/// Returns the SCC groups that contain cycles (self or mutual recursion).
21/// Each group is a `HashSet<String>` of function names in the SCC.
22pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
23    let graph = build_call_graph(items);
24    let user_fns = user_fn_names(items);
25    recursive_sccs(&graph, &user_fns)
26        .into_iter()
27        .map(|scc| scc.into_iter().collect())
28        .collect()
29}
30
31/// Returns the set of user-defined function names that are recursive
32/// (directly or mutually).
33pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
34    let graph = build_call_graph(items);
35    let user_fns = user_fn_names(items);
36    let mut recursive = HashSet::new();
37    for scc in recursive_sccs(&graph, &user_fns) {
38        for name in scc {
39            recursive.insert(name);
40        }
41    }
42    recursive
43}
44
45/// Direct call summary per user-defined function (unique + sorted).
46pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
47    let graph = build_call_graph(items);
48    let mut out = HashMap::new();
49    for item in items {
50        if let TopLevel::FnDef(fd) = item {
51            let mut callees = graph
52                .get(&fd.name)
53                .cloned()
54                .unwrap_or_default()
55                .into_iter()
56                .collect::<Vec<_>>();
57            callees.sort();
58            out.insert(fd.name.clone(), callees);
59        }
60    }
61    out
62}
63
64/// Count recursive callsites per user-defined function, scoped to caller SCC.
65///
66/// Callsite definition:
67/// - one syntactic `FnCall` or `TailCall` node in the function body,
68/// - whose callee is a user-defined function in the same recursive SCC
69///   as the caller.
70///
71/// This is a syntactic metric over AST nodes (not dynamic execution count,
72/// not CFG edges), so it stays stable across control-flow rewrites.
73pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
74    let graph = build_call_graph(items);
75    let user_fns = user_fn_names(items);
76    let sccs = recursive_sccs(&graph, &user_fns);
77    let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
78    for scc in sccs {
79        let members: HashSet<String> = scc.iter().cloned().collect();
80        for name in scc {
81            scc_members.insert(name, members.clone());
82        }
83    }
84
85    let mut out = HashMap::new();
86    for item in items {
87        if let TopLevel::FnDef(fd) = item {
88            let mut count = 0usize;
89            if let Some(members) = scc_members.get(&fd.name) {
90                count_recursive_calls_body(&fd.body, members, &mut count);
91            }
92            out.insert(fd.name.clone(), count);
93        }
94    }
95    out
96}
97
98/// Deterministic recursive SCC id per function (1-based).
99/// Non-recursive functions are absent from the returned map.
100pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
101    let graph = build_call_graph(items);
102    let user_fns = user_fn_names(items);
103    let mut sccs = recursive_sccs(&graph, &user_fns);
104    for scc in &mut sccs {
105        scc.sort();
106    }
107    sccs.sort_by(|a, b| a.first().cmp(&b.first()));
108
109    let mut out = HashMap::new();
110    for (idx, scc) in sccs.into_iter().enumerate() {
111        let id = idx + 1;
112        for name in scc {
113            out.insert(name, id);
114        }
115    }
116    out
117}
118
/// Resolves a possibly module-qualified name to a known function name.
///
/// * If `name` is itself in `fn_names`, it is returned as is.
/// * Otherwise the longest entry of `module_prefixes` that `name` starts with
///   (immediately followed by a `.`) is stripped together with that dot, and
///   the remainder is returned when it is in `fn_names`.
///
/// Returns `None` when no prefix matches or the stripped remainder is not a
/// known function. Only the longest matching prefix is tried.
fn canonical_codegen_dep(
    name: &str,
    fn_names: &HashSet<String>,
    module_prefixes: &HashSet<String>,
) -> Option<String> {
    if fn_names.contains(name) {
        return Some(name.to_owned());
    }

    // The longest matching prefix is the one leaving the shortest remainder.
    let bare = module_prefixes
        .iter()
        .filter_map(|prefix| name.strip_prefix(prefix.as_str())?.strip_prefix('.'))
        .min_by_key(|rest| rest.len())?;

    if fn_names.contains(bare) {
        Some(bare.to_owned())
    } else {
        None
    }
}
142
143fn collect_codegen_deps_body(
144    body: &FnBody,
145    fn_names: &HashSet<String>,
146    module_prefixes: &HashSet<String>,
147    out: &mut HashSet<String>,
148) {
149    for s in body.stmts() {
150        match s {
151            Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
152                collect_codegen_deps_expr(e, fn_names, module_prefixes, out)
153            }
154        }
155    }
156}
157
158fn collect_codegen_deps_expr(
159    expr: &Expr,
160    fn_names: &HashSet<String>,
161    module_prefixes: &HashSet<String>,
162    out: &mut HashSet<String>,
163) {
164    walk_expr(expr, &mut |node| match node {
165        Expr::FnCall(func, args) => {
166            if let Some(callee) = expr_to_dotted_name(func.as_ref())
167                && let Some(canonical) = canonical_codegen_dep(&callee, fn_names, module_prefixes)
168            {
169                out.insert(canonical);
170            }
171            for arg in args {
172                // function-as-value dependency, e.g. apply(f, x)
173                if let Some(qname) = expr_to_dotted_name(arg)
174                    && let Some(canonical) =
175                        canonical_codegen_dep(&qname, fn_names, module_prefixes)
176                {
177                    out.insert(canonical);
178                }
179            }
180        }
181        Expr::TailCall(boxed) => {
182            if fn_names.contains(&boxed.0) {
183                out.insert(boxed.0.clone());
184            }
185        }
186        _ => {}
187    });
188}
189
190fn expr_to_dotted_name(expr: &Expr) -> Option<String> {
191    match expr {
192        Expr::Ident(name) => Some(name.clone()),
193        Expr::Attr(obj, field) => {
194            let head = expr_to_dotted_name(obj)?;
195            Some(format!("{}.{}", head, field))
196        }
197        _ => None,
198    }
199}
200
201fn walk_expr(expr: &Expr, visit: &mut impl FnMut(&Expr)) {
202    visit(expr);
203    match expr {
204        Expr::FnCall(func, args) => {
205            walk_expr(func, visit);
206            for arg in args {
207                walk_expr(arg, visit);
208            }
209        }
210        Expr::TailCall(boxed) => {
211            for arg in &boxed.1 {
212                walk_expr(arg, visit);
213            }
214        }
215        Expr::Attr(obj, _) => walk_expr(obj, visit),
216        Expr::BinOp(_, l, r) => {
217            walk_expr(l, visit);
218            walk_expr(r, visit);
219        }
220        Expr::Match { subject, arms, .. } => {
221            walk_expr(subject, visit);
222            for arm in arms {
223                walk_expr(&arm.body, visit);
224            }
225        }
226        Expr::List(items) | Expr::Tuple(items) => {
227            for item in items {
228                walk_expr(item, visit);
229            }
230        }
231        Expr::MapLiteral(entries) => {
232            for (k, v) in entries {
233                walk_expr(k, visit);
234                walk_expr(v, visit);
235            }
236        }
237        Expr::Constructor(_, maybe) => {
238            if let Some(inner) = maybe {
239                walk_expr(inner, visit);
240            }
241        }
242        Expr::ErrorProp(inner) => walk_expr(inner, visit),
243        Expr::InterpolatedStr(parts) => {
244            for part in parts {
245                if let StrPart::Parsed(e) = part {
246                    walk_expr(e, visit);
247                }
248            }
249        }
250        Expr::RecordCreate { fields, .. } => {
251            for (_, e) in fields {
252                walk_expr(e, visit);
253            }
254        }
255        Expr::RecordUpdate { base, updates, .. } => {
256            walk_expr(base, visit);
257            for (_, e) in updates {
258                walk_expr(e, visit);
259            }
260        }
261        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => {}
262    }
263}
264
265// ---------------------------------------------------------------------------
266// Call graph construction
267// ---------------------------------------------------------------------------
268
269fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
270    let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
271    for item in items {
272        if let TopLevel::FnDef(fd) = item {
273            let mut callees = HashSet::new();
274            collect_callees_body(&fd.body, &mut callees);
275            graph.insert(fd.name.clone(), callees);
276        }
277    }
278    graph
279}
280
281fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
282    items
283        .iter()
284        .filter_map(|item| {
285            if let TopLevel::FnDef(fd) = item {
286                Some(fd.name.clone())
287            } else {
288                None
289            }
290        })
291        .collect()
292}
293
294fn recursive_sccs(
295    graph: &HashMap<String, HashSet<String>>,
296    user_fns: &HashSet<String>,
297) -> Vec<Vec<String>> {
298    let mut names = user_fns.iter().cloned().collect::<Vec<_>>();
299    names.sort();
300
301    let mut adj: HashMap<String, Vec<String>> = HashMap::new();
302    for name in &names {
303        let mut deps = graph
304            .get(name)
305            .cloned()
306            .unwrap_or_default()
307            .into_iter()
308            .filter(|callee| user_fns.contains(callee))
309            .collect::<Vec<_>>();
310        deps.sort();
311        adj.insert(name.clone(), deps);
312    }
313
314    scc::tarjan_sccs(&names, &adj)
315        .into_iter()
316        .filter(|scc| is_recursive_scc(scc, graph))
317        .collect()
318}
319
/// Tells whether a strongly-connected component contains a cycle.
///
/// A component with more than one member always does. A single-member
/// component does only if that function lists itself among its callees in
/// `graph`. An empty component never does.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [only] => graph
            .get(only)
            .is_some_and(|callees| callees.contains(only)),
        _ => true,
    }
}
331
332pub(crate) fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
333    for s in body.stmts() {
334        collect_callees_stmt(s, callees);
335    }
336}
337
338fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
339    for s in body.stmts() {
340        count_recursive_calls_stmt(s, recursive, out);
341    }
342}
343
344fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
345    match stmt {
346        Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
347    }
348}
349
350fn count_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>, out: &mut usize) {
351    match expr {
352        Expr::FnCall(func, args) => {
353            match func.as_ref() {
354                Expr::Ident(name) => {
355                    if recursive.contains(name) {
356                        *out += 1;
357                    }
358                }
359                Expr::Attr(obj, member) => {
360                    if let Expr::Ident(ns) = obj.as_ref() {
361                        let q = format!("{}.{}", ns, member);
362                        if recursive.contains(&q) {
363                            *out += 1;
364                        }
365                    } else {
366                        count_recursive_calls_expr(obj, recursive, out);
367                    }
368                }
369                other => count_recursive_calls_expr(other, recursive, out),
370            }
371            for arg in args {
372                count_recursive_calls_expr(arg, recursive, out);
373            }
374        }
375        Expr::TailCall(boxed) => {
376            if recursive.contains(&boxed.0) {
377                *out += 1;
378            }
379            for arg in &boxed.1 {
380                count_recursive_calls_expr(arg, recursive, out);
381            }
382        }
383        Expr::Literal(_) | Expr::Resolved(_) | Expr::Ident(_) => {}
384        Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
385        Expr::BinOp(_, l, r) => {
386            count_recursive_calls_expr(l, recursive, out);
387            count_recursive_calls_expr(r, recursive, out);
388        }
389        Expr::Match {
390            subject: scrutinee,
391            arms,
392            ..
393        } => {
394            count_recursive_calls_expr(scrutinee, recursive, out);
395            for arm in arms {
396                count_recursive_calls_expr(&arm.body, recursive, out);
397            }
398        }
399        Expr::List(elems) | Expr::Tuple(elems) => {
400            for e in elems {
401                count_recursive_calls_expr(e, recursive, out);
402            }
403        }
404        Expr::MapLiteral(entries) => {
405            for (k, v) in entries {
406                count_recursive_calls_expr(k, recursive, out);
407                count_recursive_calls_expr(v, recursive, out);
408            }
409        }
410        Expr::Constructor(_, arg) => {
411            if let Some(a) = arg {
412                count_recursive_calls_expr(a, recursive, out);
413            }
414        }
415        Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
416        Expr::InterpolatedStr(parts) => {
417            for part in parts {
418                if let crate::ast::StrPart::Parsed(expr) = part {
419                    count_recursive_calls_expr(expr, recursive, out);
420                }
421            }
422        }
423        Expr::RecordCreate { fields, .. } => {
424            for (_, e) in fields {
425                count_recursive_calls_expr(e, recursive, out);
426            }
427        }
428        Expr::RecordUpdate { base, updates, .. } => {
429            count_recursive_calls_expr(base, recursive, out);
430            for (_, e) in updates {
431                count_recursive_calls_expr(e, recursive, out);
432            }
433        }
434    }
435}
436
437fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
438    match stmt {
439        Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
440            collect_callees_expr(e, callees);
441        }
442    }
443}
444
445fn collect_callees_expr(expr: &Expr, callees: &mut HashSet<String>) {
446    walk_expr(expr, &mut |node| match node {
447        Expr::FnCall(func, _) => {
448            if let Some(callee) = expr_to_dotted_name(func.as_ref()) {
449                callees.insert(callee);
450            }
451        }
452        Expr::TailCall(boxed) => {
453            callees.insert(boxed.0.clone());
454        }
455        _ => {}
456    });
457}
458
459#[cfg(test)]
460mod tests;