1use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, Spanned, Stmt, StrPart, TopLevel};
10
11mod codegen;
12mod scc;
13
14pub use codegen::{ordered_fn_components, tailcall_scc_components};
15
16pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
23 let graph = build_call_graph(items);
24 let user_fns = user_fn_names(items);
25 recursive_sccs(&graph, &user_fns)
26 .into_iter()
27 .map(|scc| scc.into_iter().collect())
28 .collect()
29}
30
31pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
34 let graph = build_call_graph(items);
35 let user_fns = user_fn_names(items);
36 let mut recursive = HashSet::new();
37 for scc in recursive_sccs(&graph, &user_fns) {
38 for name in scc {
39 recursive.insert(name);
40 }
41 }
42 recursive
43}
44
45pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
47 let graph = build_call_graph(items);
48 let mut out = HashMap::new();
49 for item in items {
50 if let TopLevel::FnDef(fd) = item {
51 let mut callees = graph
52 .get(&fd.name)
53 .cloned()
54 .unwrap_or_default()
55 .into_iter()
56 .collect::<Vec<_>>();
57 callees.sort();
58 out.insert(fd.name.clone(), callees);
59 }
60 }
61 out
62}
63
64pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
74 let graph = build_call_graph(items);
75 let user_fns = user_fn_names(items);
76 let sccs = recursive_sccs(&graph, &user_fns);
77 let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
78 for scc in sccs {
79 let members: HashSet<String> = scc.iter().cloned().collect();
80 for name in scc {
81 scc_members.insert(name, members.clone());
82 }
83 }
84
85 let mut out = HashMap::new();
86 for item in items {
87 if let TopLevel::FnDef(fd) = item {
88 let mut count = 0usize;
89 if let Some(members) = scc_members.get(&fd.name) {
90 count_recursive_calls_body(&fd.body, members, &mut count);
91 }
92 out.insert(fd.name.clone(), count);
93 }
94 }
95 out
96}
97
98pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
101 let graph = build_call_graph(items);
102 let user_fns = user_fn_names(items);
103 let mut sccs = recursive_sccs(&graph, &user_fns);
104 for scc in &mut sccs {
105 scc.sort();
106 }
107 sccs.sort_by(|a, b| a.first().cmp(&b.first()));
108
109 let mut out = HashMap::new();
110 for (idx, scc) in sccs.into_iter().enumerate() {
111 let id = idx + 1;
112 for name in scc {
113 out.insert(name, id);
114 }
115 }
116 out
117}
118
/// Resolves a possibly module-qualified name to a function listed in `fn_names`.
///
/// Resolution proceeds in three steps:
/// - If `name` itself is in `fn_names`, it is returned unchanged.
/// - Otherwise, the longest entry `p` of `module_prefixes` for which `name`
///   starts with `"p."` is stripped off.
/// - The remainder is returned if it is in `fn_names`.
///
/// Only the longest matching prefix is tried. Returns `None` when no prefix
/// matches or the stripped name is unknown.
fn canonical_codegen_dep(
    name: &str,
    fn_names: &HashSet<String>,
    module_prefixes: &HashSet<String>,
) -> Option<String> {
    if fn_names.contains(name) {
        return Some(name.to_string());
    }

    let longest_prefix = module_prefixes
        .iter()
        .filter(|prefix| {
            name.strip_prefix(prefix.as_str())
                .is_some_and(|rest| rest.starts_with('.'))
        })
        .max_by_key(|prefix| prefix.len())?;

    // Slicing is on a char boundary: the prefix matched and is followed by an
    // ASCII '.', which is one byte.
    let bare = &name[longest_prefix.len() + 1..];
    if fn_names.contains(bare) {
        Some(bare.to_string())
    } else {
        None
    }
}
142
143fn collect_codegen_deps_body(
144 body: &FnBody,
145 fn_names: &HashSet<String>,
146 module_prefixes: &HashSet<String>,
147 out: &mut HashSet<String>,
148) {
149 for s in body.stmts() {
150 match s {
151 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
152 collect_codegen_deps_expr(e, fn_names, module_prefixes, out)
153 }
154 }
155 }
156}
157
158fn collect_codegen_deps_expr(
159 expr: &Spanned<Expr>,
160 fn_names: &HashSet<String>,
161 module_prefixes: &HashSet<String>,
162 out: &mut HashSet<String>,
163) {
164 walk_expr(expr, &mut |node| match node {
165 Expr::FnCall(func, args) => {
166 if let Some(callee) = expr_to_dotted_name(func.as_ref())
167 && let Some(canonical) = canonical_codegen_dep(&callee, fn_names, module_prefixes)
168 {
169 out.insert(canonical);
170 }
171 for arg in args {
172 if let Some(qname) = expr_to_dotted_name(arg)
174 && let Some(canonical) =
175 canonical_codegen_dep(&qname, fn_names, module_prefixes)
176 {
177 out.insert(canonical);
178 }
179 }
180 }
181 Expr::TailCall(boxed) if fn_names.contains(&boxed.target) => {
182 out.insert(boxed.target.clone());
183 }
184 _ => {}
185 });
186}
187
188fn expr_to_dotted_name(expr: &Spanned<Expr>) -> Option<String> {
189 match &expr.node {
190 Expr::Ident(name) => Some(name.clone()),
191 Expr::Attr(obj, field) => {
192 let head = expr_to_dotted_name(obj)?;
193 Some(format!("{}.{}", head, field))
194 }
195 _ => None,
196 }
197}
198
199fn walk_expr(expr: &Spanned<Expr>, visit: &mut impl FnMut(&Expr)) {
200 visit(&expr.node);
201 match &expr.node {
202 Expr::FnCall(func, args) => {
203 walk_expr(func, visit);
204 for arg in args {
205 walk_expr(arg, visit);
206 }
207 }
208 Expr::TailCall(boxed) => {
209 for arg in &boxed.args {
210 walk_expr(arg, visit);
211 }
212 }
213 Expr::Attr(obj, _) => walk_expr(obj, visit),
214 Expr::BinOp(_, l, r) => {
215 walk_expr(l, visit);
216 walk_expr(r, visit);
217 }
218 Expr::Match { subject, arms, .. } => {
219 walk_expr(subject, visit);
220 for arm in arms {
221 walk_expr(&arm.body, visit);
222 }
223 }
224 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
225 for item in items {
226 walk_expr(item, visit);
227 }
228 }
229 Expr::MapLiteral(entries) => {
230 for (k, v) in entries {
231 walk_expr(k, visit);
232 walk_expr(v, visit);
233 }
234 }
235 Expr::Constructor(_, maybe) => {
236 if let Some(inner) = maybe {
237 walk_expr(inner, visit);
238 }
239 }
240 Expr::ErrorProp(inner) => walk_expr(inner, visit),
241 Expr::InterpolatedStr(parts) => {
242 for part in parts {
243 if let StrPart::Parsed(e) = part {
244 walk_expr(e, visit);
245 }
246 }
247 }
248 Expr::RecordCreate { fields, .. } => {
249 for (_, e) in fields {
250 walk_expr(e, visit);
251 }
252 }
253 Expr::RecordUpdate { base, updates, .. } => {
254 walk_expr(base, visit);
255 for (_, e) in updates {
256 walk_expr(e, visit);
257 }
258 }
259 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } => {}
260 }
261}
262
263fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
268 let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
269 for item in items {
270 if let TopLevel::FnDef(fd) = item {
271 let mut callees = HashSet::new();
272 collect_callees_body(&fd.body, &mut callees);
273 graph.insert(fd.name.clone(), callees);
274 }
275 }
276 graph
277}
278
279fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
280 items
281 .iter()
282 .filter_map(|item| {
283 if let TopLevel::FnDef(fd) = item {
284 Some(fd.name.clone())
285 } else {
286 None
287 }
288 })
289 .collect()
290}
291
292fn recursive_sccs(
293 graph: &HashMap<String, HashSet<String>>,
294 user_fns: &HashSet<String>,
295) -> Vec<Vec<String>> {
296 let mut names = user_fns.iter().cloned().collect::<Vec<_>>();
297 names.sort();
298
299 let mut adj: HashMap<String, Vec<String>> = HashMap::new();
300 for name in &names {
301 let mut deps = graph
302 .get(name)
303 .cloned()
304 .unwrap_or_default()
305 .into_iter()
306 .filter(|callee| user_fns.contains(callee))
307 .collect::<Vec<_>>();
308 deps.sort();
309 adj.insert(name.clone(), deps);
310 }
311
312 scc::tarjan_sccs(&names, &adj)
313 .into_iter()
314 .filter(|scc| is_recursive_scc(scc, graph))
315 .collect()
316}
317
/// Decides whether a strongly connected component represents recursion.
///
/// A component with two or more members is a cycle, so it is always recursive.
/// A single-member component is recursive only if that function lists itself
/// among its callees in `graph`. An empty component is never recursive.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [only] => graph
            .get(only)
            .is_some_and(|callees| callees.contains(only)),
        _ => true,
    }
}
329
330pub(crate) fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
331 for s in body.stmts() {
332 collect_callees_stmt(s, callees);
333 }
334}
335
336fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
337 for s in body.stmts() {
338 count_recursive_calls_stmt(s, recursive, out);
339 }
340}
341
342fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
343 match stmt {
344 Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
345 }
346}
347
348fn count_recursive_calls_expr(expr: &Spanned<Expr>, recursive: &HashSet<String>, out: &mut usize) {
349 match &expr.node {
350 Expr::FnCall(func, args) => {
351 match &func.node {
352 Expr::Ident(name) => {
353 if recursive.contains(name) {
354 *out += 1;
355 }
356 }
357 Expr::Attr(obj, member) => {
358 if let Expr::Ident(ns) = &obj.node {
359 let q = format!("{}.{}", ns, member);
360 if recursive.contains(&q) {
361 *out += 1;
362 }
363 } else {
364 count_recursive_calls_expr(obj, recursive, out);
365 }
366 }
367 _ => count_recursive_calls_expr(func, recursive, out),
368 }
369 for arg in args {
370 count_recursive_calls_expr(arg, recursive, out);
371 }
372 }
373 Expr::TailCall(boxed) => {
374 if recursive.contains(&boxed.target) {
375 *out += 1;
376 }
377 for arg in &boxed.args {
378 count_recursive_calls_expr(arg, recursive, out);
379 }
380 }
381 Expr::Literal(_) | Expr::Resolved { .. } | Expr::Ident(_) => {}
382 Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
383 Expr::BinOp(_, l, r) => {
384 count_recursive_calls_expr(l, recursive, out);
385 count_recursive_calls_expr(r, recursive, out);
386 }
387 Expr::Match {
388 subject: scrutinee,
389 arms,
390 ..
391 } => {
392 count_recursive_calls_expr(scrutinee, recursive, out);
393 for arm in arms {
394 count_recursive_calls_expr(&arm.body, recursive, out);
395 }
396 }
397 Expr::List(elems) | Expr::Tuple(elems) | Expr::IndependentProduct(elems, _) => {
398 for e in elems {
399 count_recursive_calls_expr(e, recursive, out);
400 }
401 }
402 Expr::MapLiteral(entries) => {
403 for (k, v) in entries {
404 count_recursive_calls_expr(k, recursive, out);
405 count_recursive_calls_expr(v, recursive, out);
406 }
407 }
408 Expr::Constructor(_, arg) => {
409 if let Some(a) = arg {
410 count_recursive_calls_expr(a, recursive, out);
411 }
412 }
413 Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
414 Expr::InterpolatedStr(parts) => {
415 for part in parts {
416 if let crate::ast::StrPart::Parsed(expr) = part {
417 count_recursive_calls_expr(expr, recursive, out);
418 }
419 }
420 }
421 Expr::RecordCreate { fields, .. } => {
422 for (_, e) in fields {
423 count_recursive_calls_expr(e, recursive, out);
424 }
425 }
426 Expr::RecordUpdate { base, updates, .. } => {
427 count_recursive_calls_expr(base, recursive, out);
428 for (_, e) in updates {
429 count_recursive_calls_expr(e, recursive, out);
430 }
431 }
432 }
433}
434
435fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
436 match stmt {
437 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
438 collect_callees_expr(e, callees);
439 }
440 }
441}
442
443fn collect_callees_expr(expr: &Spanned<Expr>, callees: &mut HashSet<String>) {
444 walk_expr(expr, &mut |node| match node {
445 Expr::FnCall(func, _) => {
446 if let Some(callee) = expr_to_dotted_name(func.as_ref()) {
447 callees.insert(callee);
448 }
449 }
450 Expr::TailCall(boxed) => {
451 callees.insert(boxed.target.clone());
452 }
453 _ => {}
454 });
455}
456
457#[cfg(test)]
458mod tests;