aver/call_graph/
codegen.rs1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, FnDef, Stmt};
4
5use super::collect_codegen_deps_body;
6use super::scc::{tarjan_sccs, topo_components};
7
8pub fn ordered_fn_components<'a>(
15 fns: &[&'a FnDef],
16 module_prefixes: &HashSet<String>,
17) -> Vec<Vec<&'a FnDef>> {
18 if fns.is_empty() {
19 return vec![];
20 }
21
22 let fn_map: HashMap<String, &FnDef> = fns.iter().map(|fd| (fd.name.clone(), *fd)).collect();
23 let names: Vec<String> = fn_map.keys().cloned().collect();
24 let name_set: HashSet<String> = names.iter().cloned().collect();
25
26 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
27 for fd in fns {
28 let mut deps = HashSet::new();
29 collect_codegen_deps_body(&fd.body, &name_set, module_prefixes, &mut deps);
30 let mut sorted = deps.into_iter().collect::<Vec<_>>();
31 sorted.sort();
32 graph.insert(fd.name.clone(), sorted);
33 }
34
35 let sccs = tarjan_sccs(&names, &graph);
36 let mut comp_of: HashMap<String, usize> = HashMap::new();
37 for (idx, comp) in sccs.iter().enumerate() {
38 for name in comp {
39 comp_of.insert(name.clone(), idx);
40 }
41 }
42
43 let mut comp_graph: HashMap<usize, HashSet<usize>> = HashMap::new();
44 for (caller, deps) in &graph {
45 let from = comp_of[caller];
46 for callee in deps {
47 let to = comp_of[callee];
48 if from != to {
49 comp_graph.entry(from).or_default().insert(to);
50 }
51 }
52 }
53
54 let comp_order = topo_components(&sccs, &comp_graph);
55 comp_order
56 .into_iter()
57 .map(|idx| {
58 let mut group: Vec<&FnDef> = sccs[idx]
59 .iter()
60 .filter_map(|name| fn_map.get(name).copied())
61 .collect();
62 group.sort_by(|a, b| a.name.cmp(&b.name));
63 group
64 })
65 .collect()
66}
67
68pub fn tailcall_scc_components<'a>(fns: &[&'a FnDef]) -> Vec<Vec<&'a FnDef>> {
73 if fns.is_empty() {
74 return vec![];
75 }
76
77 let fn_map: HashMap<String, &FnDef> = fns.iter().map(|fd| (fd.name.clone(), *fd)).collect();
78 let names: Vec<String> = fn_map.keys().cloned().collect();
79 let name_set: HashSet<String> = names.iter().cloned().collect();
80
81 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
82 for fd in fns {
83 let mut deps = HashSet::new();
84 collect_tailcall_deps_body(&fd.body, &name_set, &mut deps);
85 let mut sorted = deps.into_iter().collect::<Vec<_>>();
86 sorted.sort();
87 graph.insert(fd.name.clone(), sorted);
88 }
89
90 tarjan_sccs(&names, &graph)
91 .into_iter()
92 .filter(|comp| comp.len() > 1)
93 .map(|comp| {
94 let mut group: Vec<&FnDef> = comp
95 .iter()
96 .filter_map(|name| fn_map.get(name).copied())
97 .collect();
98 group.sort_by(|a, b| a.name.cmp(&b.name));
99 group
100 })
101 .collect()
102}
103
104fn collect_tailcall_deps_body(
105 body: &FnBody,
106 fn_names: &HashSet<String>,
107 out: &mut HashSet<String>,
108) {
109 for stmt in body.stmts() {
110 match stmt {
111 Stmt::Expr(expr) | Stmt::Binding(_, _, expr) => {
112 collect_tailcall_deps_expr(expr, fn_names, out);
113 }
114 }
115 }
116}
117
118fn collect_tailcall_deps_expr(expr: &Expr, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
119 match expr {
120 Expr::TailCall(boxed) => {
121 let (target, _) = boxed.as_ref();
122 if fn_names.contains(target) {
123 out.insert(target.clone());
124 }
125 }
126 Expr::Match { arms, .. } => {
127 for arm in arms {
128 collect_tailcall_deps_expr(&arm.body, fn_names, out);
129 }
130 }
131 _ => {}
132 }
133}