Skip to main content

aver/
call_graph.rs

//! Call-graph analysis and Tarjan's SCC algorithm.
//!
//! Given a parsed program, builds a directed graph of function calls
//! and finds its strongly-connected components (SCCs).  A function is
//! *recursive* if it belongs to an SCC that contains a cycle: either the
//! SCC has more than one member, or it has a single member with a
//! self-edge.
7use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
10
11mod codegen;
12mod scc;
13
14pub use codegen::ordered_fn_components;
15
16// ---------------------------------------------------------------------------
17// Public API
18// ---------------------------------------------------------------------------
19
20/// Returns the SCC groups that contain cycles (self or mutual recursion).
21/// Each group is a `HashSet<String>` of function names in the SCC.
22pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
23    let graph = build_call_graph(items);
24    let user_fns = user_fn_names(items);
25    recursive_sccs(&graph, &user_fns)
26        .into_iter()
27        .map(|scc| scc.into_iter().collect())
28        .collect()
29}
30
31/// Returns the set of user-defined function names that are recursive
32/// (directly or mutually).
33pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
34    let graph = build_call_graph(items);
35    let user_fns = user_fn_names(items);
36    let mut recursive = HashSet::new();
37    for scc in recursive_sccs(&graph, &user_fns) {
38        for name in scc {
39            recursive.insert(name);
40        }
41    }
42    recursive
43}
44
45/// Direct call summary per user-defined function (unique + sorted).
46pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
47    let graph = build_call_graph(items);
48    let mut out = HashMap::new();
49    for item in items {
50        if let TopLevel::FnDef(fd) = item {
51            let mut callees = graph
52                .get(&fd.name)
53                .cloned()
54                .unwrap_or_default()
55                .into_iter()
56                .collect::<Vec<_>>();
57            callees.sort();
58            out.insert(fd.name.clone(), callees);
59        }
60    }
61    out
62}
63
64/// Count recursive callsites per user-defined function, scoped to caller SCC.
65///
66/// Callsite definition:
67/// - one syntactic `FnCall` or `TailCall` node in the function body,
68/// - whose callee is a user-defined function in the same recursive SCC
69///   as the caller.
70///
71/// This is a syntactic metric over AST nodes (not dynamic execution count,
72/// not CFG edges), so it stays stable across control-flow rewrites.
73pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
74    let graph = build_call_graph(items);
75    let user_fns = user_fn_names(items);
76    let sccs = recursive_sccs(&graph, &user_fns);
77    let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
78    for scc in sccs {
79        let members: HashSet<String> = scc.iter().cloned().collect();
80        for name in scc {
81            scc_members.insert(name, members.clone());
82        }
83    }
84
85    let mut out = HashMap::new();
86    for item in items {
87        if let TopLevel::FnDef(fd) = item {
88            let mut count = 0usize;
89            if let Some(members) = scc_members.get(&fd.name) {
90                count_recursive_calls_body(&fd.body, members, &mut count);
91            }
92            out.insert(fd.name.clone(), count);
93        }
94    }
95    out
96}
97
98/// Deterministic recursive SCC id per function (1-based).
99/// Non-recursive functions are absent from the returned map.
100pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
101    let graph = build_call_graph(items);
102    let user_fns = user_fn_names(items);
103    let mut sccs = recursive_sccs(&graph, &user_fns);
104    for scc in &mut sccs {
105        scc.sort();
106    }
107    sccs.sort_by(|a, b| a.first().cmp(&b.first()));
108
109    let mut out = HashMap::new();
110    for (idx, scc) in sccs.into_iter().enumerate() {
111        let id = idx + 1;
112        for name in scc {
113            out.insert(name, id);
114        }
115    }
116    out
117}
118
119fn collect_codegen_deps_body(body: &FnBody, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
120    match body {
121        FnBody::Expr(e) => collect_codegen_deps_expr(e, fn_names, out),
122        FnBody::Block(stmts) => {
123            for s in stmts {
124                match s {
125                    Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
126                        collect_codegen_deps_expr(e, fn_names, out)
127                    }
128                }
129            }
130        }
131    }
132}
133
134fn collect_codegen_deps_expr(expr: &Expr, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
135    walk_expr(expr, &mut |node| match node {
136        Expr::FnCall(func, args) => {
137            if let Some(callee) = expr_to_dotted_name(func.as_ref())
138                && fn_names.contains(&callee)
139            {
140                out.insert(callee);
141            }
142            for arg in args {
143                // function-as-value dependency, e.g. List.fold(xs, init, f)
144                if let Some(qname) = expr_to_dotted_name(arg)
145                    && fn_names.contains(&qname)
146                {
147                    out.insert(qname);
148                }
149            }
150        }
151        Expr::TailCall(boxed) => {
152            if fn_names.contains(&boxed.0) {
153                out.insert(boxed.0.clone());
154            }
155        }
156        _ => {}
157    });
158}
159
160fn expr_to_dotted_name(expr: &Expr) -> Option<String> {
161    match expr {
162        Expr::Ident(name) => Some(name.clone()),
163        Expr::Attr(obj, field) => {
164            let head = expr_to_dotted_name(obj)?;
165            Some(format!("{}.{}", head, field))
166        }
167        _ => None,
168    }
169}
170
171fn walk_expr(expr: &Expr, visit: &mut impl FnMut(&Expr)) {
172    visit(expr);
173    match expr {
174        Expr::FnCall(func, args) => {
175            walk_expr(func, visit);
176            for arg in args {
177                walk_expr(arg, visit);
178            }
179        }
180        Expr::TailCall(boxed) => {
181            for arg in &boxed.1 {
182                walk_expr(arg, visit);
183            }
184        }
185        Expr::Attr(obj, _) => walk_expr(obj, visit),
186        Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
187            walk_expr(l, visit);
188            walk_expr(r, visit);
189        }
190        Expr::Match { subject, arms, .. } => {
191            walk_expr(subject, visit);
192            for arm in arms {
193                walk_expr(&arm.body, visit);
194            }
195        }
196        Expr::List(items) | Expr::Tuple(items) => {
197            for item in items {
198                walk_expr(item, visit);
199            }
200        }
201        Expr::MapLiteral(entries) => {
202            for (k, v) in entries {
203                walk_expr(k, visit);
204                walk_expr(v, visit);
205            }
206        }
207        Expr::Constructor(_, maybe) => {
208            if let Some(inner) = maybe {
209                walk_expr(inner, visit);
210            }
211        }
212        Expr::ErrorProp(inner) => walk_expr(inner, visit),
213        Expr::InterpolatedStr(parts) => {
214            for part in parts {
215                if let StrPart::Parsed(e) = part {
216                    walk_expr(e, visit);
217                }
218            }
219        }
220        Expr::RecordCreate { fields, .. } => {
221            for (_, e) in fields {
222                walk_expr(e, visit);
223            }
224        }
225        Expr::RecordUpdate { base, updates, .. } => {
226            walk_expr(base, visit);
227            for (_, e) in updates {
228                walk_expr(e, visit);
229            }
230        }
231        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => {}
232    }
233}
234
235// ---------------------------------------------------------------------------
236// Call graph construction
237// ---------------------------------------------------------------------------
238
239fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
240    let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
241    for item in items {
242        if let TopLevel::FnDef(fd) = item {
243            let mut callees = HashSet::new();
244            collect_callees_body(&fd.body, &mut callees);
245            graph.insert(fd.name.clone(), callees);
246        }
247    }
248    graph
249}
250
251fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
252    items
253        .iter()
254        .filter_map(|item| {
255            if let TopLevel::FnDef(fd) = item {
256                Some(fd.name.clone())
257            } else {
258                None
259            }
260        })
261        .collect()
262}
263
264fn recursive_sccs(
265    graph: &HashMap<String, HashSet<String>>,
266    user_fns: &HashSet<String>,
267) -> Vec<Vec<String>> {
268    let mut names = user_fns.iter().cloned().collect::<Vec<_>>();
269    names.sort();
270
271    let mut adj: HashMap<String, Vec<String>> = HashMap::new();
272    for name in &names {
273        let mut deps = graph
274            .get(name)
275            .cloned()
276            .unwrap_or_default()
277            .into_iter()
278            .filter(|callee| user_fns.contains(callee))
279            .collect::<Vec<_>>();
280        deps.sort();
281        adj.insert(name.clone(), deps);
282    }
283
284    scc::tarjan_sccs(&names, &adj)
285        .into_iter()
286        .filter(|scc| is_recursive_scc(scc, graph))
287        .collect()
288}
289
/// Decides whether an SCC contains a cycle.
///
/// - An empty component is never recursive.
/// - A single-member component is recursive only when `graph` records that
///   member as one of its own callees (a self-edge).
/// - Any component with two or more members is recursive.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [only] => graph
            .get(only)
            .map_or(false, |targets| targets.contains(only)),
        _ => true,
    }
}
301
302pub(crate) fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
303    match body {
304        FnBody::Expr(e) => collect_callees_expr(e, callees),
305        FnBody::Block(stmts) => {
306            for s in stmts {
307                collect_callees_stmt(s, callees);
308            }
309        }
310    }
311}
312
313fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
314    match body {
315        FnBody::Expr(e) => count_recursive_calls_expr(e, recursive, out),
316        FnBody::Block(stmts) => {
317            for s in stmts {
318                count_recursive_calls_stmt(s, recursive, out);
319            }
320        }
321    }
322}
323
324fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
325    match stmt {
326        Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
327    }
328}
329
330fn count_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>, out: &mut usize) {
331    match expr {
332        Expr::FnCall(func, args) => {
333            match func.as_ref() {
334                Expr::Ident(name) => {
335                    if recursive.contains(name) {
336                        *out += 1;
337                    }
338                }
339                Expr::Attr(obj, member) => {
340                    if let Expr::Ident(ns) = obj.as_ref() {
341                        let q = format!("{}.{}", ns, member);
342                        if recursive.contains(&q) {
343                            *out += 1;
344                        }
345                    } else {
346                        count_recursive_calls_expr(obj, recursive, out);
347                    }
348                }
349                other => count_recursive_calls_expr(other, recursive, out),
350            }
351            for arg in args {
352                count_recursive_calls_expr(arg, recursive, out);
353            }
354        }
355        Expr::TailCall(boxed) => {
356            if recursive.contains(&boxed.0) {
357                *out += 1;
358            }
359            for arg in &boxed.1 {
360                count_recursive_calls_expr(arg, recursive, out);
361            }
362        }
363        Expr::Literal(_) | Expr::Resolved(_) | Expr::Ident(_) => {}
364        Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
365        Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
366            count_recursive_calls_expr(l, recursive, out);
367            count_recursive_calls_expr(r, recursive, out);
368        }
369        Expr::Match {
370            subject: scrutinee,
371            arms,
372            ..
373        } => {
374            count_recursive_calls_expr(scrutinee, recursive, out);
375            for arm in arms {
376                count_recursive_calls_expr(&arm.body, recursive, out);
377            }
378        }
379        Expr::List(elems) | Expr::Tuple(elems) => {
380            for e in elems {
381                count_recursive_calls_expr(e, recursive, out);
382            }
383        }
384        Expr::MapLiteral(entries) => {
385            for (k, v) in entries {
386                count_recursive_calls_expr(k, recursive, out);
387                count_recursive_calls_expr(v, recursive, out);
388            }
389        }
390        Expr::Constructor(_, arg) => {
391            if let Some(a) = arg {
392                count_recursive_calls_expr(a, recursive, out);
393            }
394        }
395        Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
396        Expr::InterpolatedStr(parts) => {
397            for part in parts {
398                if let crate::ast::StrPart::Parsed(expr) = part {
399                    count_recursive_calls_expr(expr, recursive, out);
400                }
401            }
402        }
403        Expr::RecordCreate { fields, .. } => {
404            for (_, e) in fields {
405                count_recursive_calls_expr(e, recursive, out);
406            }
407        }
408        Expr::RecordUpdate { base, updates, .. } => {
409            count_recursive_calls_expr(base, recursive, out);
410            for (_, e) in updates {
411                count_recursive_calls_expr(e, recursive, out);
412            }
413        }
414    }
415}
416
417fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
418    match stmt {
419        Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
420            collect_callees_expr(e, callees);
421        }
422    }
423}
424
425fn collect_callees_expr(expr: &Expr, callees: &mut HashSet<String>) {
426    walk_expr(expr, &mut |node| match node {
427        Expr::FnCall(func, _) => {
428            if let Some(callee) = expr_to_dotted_name(func.as_ref()) {
429                callees.insert(callee);
430            }
431        }
432        Expr::TailCall(boxed) => {
433            callees.insert(boxed.0.clone());
434        }
435        _ => {}
436    });
437}
438
439#[cfg(test)]
440mod tests;