Skip to main content

aver/call_graph/
codegen.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, FnDef, Stmt};
4
5use super::collect_codegen_deps_body;
6use super::scc::{tarjan_sccs, topo_components};
7
8/// Deterministic function emission order for codegen backends.
9///
10/// Returns SCC components in callee-before-caller topological order.
11/// Each inner vector is one SCC (single function or mutual-recursive group).
12/// Function references passed as call arguments (e.g. `apply(f, x)`)
13/// are treated as dependencies for ordering.
14pub fn ordered_fn_components<'a>(
15    fns: &[&'a FnDef],
16    module_prefixes: &HashSet<String>,
17) -> Vec<Vec<&'a FnDef>> {
18    if fns.is_empty() {
19        return vec![];
20    }
21
22    let fn_map: HashMap<String, &FnDef> = fns.iter().map(|fd| (fd.name.clone(), *fd)).collect();
23    let names: Vec<String> = fn_map.keys().cloned().collect();
24    let name_set: HashSet<String> = names.iter().cloned().collect();
25
26    let mut graph: HashMap<String, Vec<String>> = HashMap::new();
27    for fd in fns {
28        let mut deps = HashSet::new();
29        collect_codegen_deps_body(&fd.body, &name_set, module_prefixes, &mut deps);
30        let mut sorted = deps.into_iter().collect::<Vec<_>>();
31        sorted.sort();
32        graph.insert(fd.name.clone(), sorted);
33    }
34
35    let sccs = tarjan_sccs(&names, &graph);
36    let mut comp_of: HashMap<String, usize> = HashMap::new();
37    for (idx, comp) in sccs.iter().enumerate() {
38        for name in comp {
39            comp_of.insert(name.clone(), idx);
40        }
41    }
42
43    let mut comp_graph: HashMap<usize, HashSet<usize>> = HashMap::new();
44    for (caller, deps) in &graph {
45        let from = comp_of[caller];
46        for callee in deps {
47            let to = comp_of[callee];
48            if from != to {
49                comp_graph.entry(from).or_default().insert(to);
50            }
51        }
52    }
53
54    let comp_order = topo_components(&sccs, &comp_graph);
55    comp_order
56        .into_iter()
57        .map(|idx| {
58            let mut group: Vec<&FnDef> = sccs[idx]
59                .iter()
60                .filter_map(|name| fn_map.get(name).copied())
61                .collect();
62            group.sort_by(|a, b| a.name.cmp(&b.name));
63            group
64        })
65        .collect()
66}
67
68/// Tail-call SCCs for mutual-trampoline codegen.
69///
70/// Returns only components with more than one function, sorted deterministically
71/// by the first function name in each SCC.
72pub fn tailcall_scc_components<'a>(fns: &[&'a FnDef]) -> Vec<Vec<&'a FnDef>> {
73    if fns.is_empty() {
74        return vec![];
75    }
76
77    let fn_map: HashMap<String, &FnDef> = fns.iter().map(|fd| (fd.name.clone(), *fd)).collect();
78    let names: Vec<String> = fn_map.keys().cloned().collect();
79    let name_set: HashSet<String> = names.iter().cloned().collect();
80
81    let mut graph: HashMap<String, Vec<String>> = HashMap::new();
82    for fd in fns {
83        let mut deps = HashSet::new();
84        collect_tailcall_deps_body(&fd.body, &name_set, &mut deps);
85        let mut sorted = deps.into_iter().collect::<Vec<_>>();
86        sorted.sort();
87        graph.insert(fd.name.clone(), sorted);
88    }
89
90    tarjan_sccs(&names, &graph)
91        .into_iter()
92        .filter(|comp| comp.len() > 1)
93        .map(|comp| {
94            let mut group: Vec<&FnDef> = comp
95                .iter()
96                .filter_map(|name| fn_map.get(name).copied())
97                .collect();
98            group.sort_by(|a, b| a.name.cmp(&b.name));
99            group
100        })
101        .collect()
102}
103
104fn collect_tailcall_deps_body(
105    body: &FnBody,
106    fn_names: &HashSet<String>,
107    out: &mut HashSet<String>,
108) {
109    for stmt in body.stmts() {
110        match stmt {
111            Stmt::Expr(expr) | Stmt::Binding(_, _, expr) => {
112                collect_tailcall_deps_expr(expr, fn_names, out);
113            }
114        }
115    }
116}
117
118fn collect_tailcall_deps_expr(expr: &Expr, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
119    match expr {
120        Expr::TailCall(boxed) => {
121            let (target, _) = boxed.as_ref();
122            if fn_names.contains(target) {
123                out.insert(target.clone());
124            }
125        }
126        Expr::Match { arms, .. } => {
127            for arm in arms {
128                collect_tailcall_deps_expr(&arm.body, fn_names, out);
129            }
130        }
131        _ => {}
132    }
133}