1use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
10
11mod codegen;
12mod scc;
13
14pub use codegen::ordered_fn_components;
15
16pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
23 let graph = build_call_graph(items);
24 let user_fns = user_fn_names(items);
25 recursive_sccs(&graph, &user_fns)
26 .into_iter()
27 .map(|scc| scc.into_iter().collect())
28 .collect()
29}
30
31pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
34 let graph = build_call_graph(items);
35 let user_fns = user_fn_names(items);
36 let mut recursive = HashSet::new();
37 for scc in recursive_sccs(&graph, &user_fns) {
38 for name in scc {
39 recursive.insert(name);
40 }
41 }
42 recursive
43}
44
45pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
47 let graph = build_call_graph(items);
48 let mut out = HashMap::new();
49 for item in items {
50 if let TopLevel::FnDef(fd) = item {
51 let mut callees = graph
52 .get(&fd.name)
53 .cloned()
54 .unwrap_or_default()
55 .into_iter()
56 .collect::<Vec<_>>();
57 callees.sort();
58 out.insert(fd.name.clone(), callees);
59 }
60 }
61 out
62}
63
64pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
74 let graph = build_call_graph(items);
75 let user_fns = user_fn_names(items);
76 let sccs = recursive_sccs(&graph, &user_fns);
77 let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
78 for scc in sccs {
79 let members: HashSet<String> = scc.iter().cloned().collect();
80 for name in scc {
81 scc_members.insert(name, members.clone());
82 }
83 }
84
85 let mut out = HashMap::new();
86 for item in items {
87 if let TopLevel::FnDef(fd) = item {
88 let mut count = 0usize;
89 if let Some(members) = scc_members.get(&fd.name) {
90 count_recursive_calls_body(&fd.body, members, &mut count);
91 }
92 out.insert(fd.name.clone(), count);
93 }
94 }
95 out
96}
97
98pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
101 let graph = build_call_graph(items);
102 let user_fns = user_fn_names(items);
103 let mut sccs = recursive_sccs(&graph, &user_fns);
104 for scc in &mut sccs {
105 scc.sort();
106 }
107 sccs.sort_by(|a, b| a.first().cmp(&b.first()));
108
109 let mut out = HashMap::new();
110 for (idx, scc) in sccs.into_iter().enumerate() {
111 let id = idx + 1;
112 for name in scc {
113 out.insert(name, id);
114 }
115 }
116 out
117}
118
/// Resolves a possibly module-qualified callee `name` to the bare name of a
/// function in `fn_names`.
///
/// - If `name` is itself in `fn_names`, it is returned unchanged.
/// - Otherwise the longest entry of `module_prefixes` for which `name` starts
///   with `"<prefix>."` is stripped. The remainder is returned if it is in
///   `fn_names`.
/// - Shorter prefixes are never retried. Returns `None` when no prefix matches
///   or the stripped name is unknown.
fn canonical_codegen_dep(
    name: &str,
    fn_names: &HashSet<String>,
    module_prefixes: &HashSet<String>,
) -> Option<String> {
    if fn_names.contains(name) {
        return Some(name.to_string());
    }

    // Pair each matching prefix's length with the text after "<prefix>.".
    let (_, bare) = module_prefixes
        .iter()
        .filter_map(|prefix| {
            let rest = name.strip_prefix(prefix.as_str())?.strip_prefix('.')?;
            Some((prefix.len(), rest))
        })
        .max_by_key(|&(len, _)| len)?;

    if fn_names.contains(bare) {
        Some(bare.to_owned())
    } else {
        None
    }
}
142
143fn collect_codegen_deps_body(
144 body: &FnBody,
145 fn_names: &HashSet<String>,
146 module_prefixes: &HashSet<String>,
147 out: &mut HashSet<String>,
148) {
149 for s in body.stmts() {
150 match s {
151 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
152 collect_codegen_deps_expr(e, fn_names, module_prefixes, out)
153 }
154 }
155 }
156}
157
158fn collect_codegen_deps_expr(
159 expr: &Expr,
160 fn_names: &HashSet<String>,
161 module_prefixes: &HashSet<String>,
162 out: &mut HashSet<String>,
163) {
164 walk_expr(expr, &mut |node| match node {
165 Expr::FnCall(func, args) => {
166 if let Some(callee) = expr_to_dotted_name(func.as_ref())
167 && let Some(canonical) = canonical_codegen_dep(&callee, fn_names, module_prefixes)
168 {
169 out.insert(canonical);
170 }
171 for arg in args {
172 if let Some(qname) = expr_to_dotted_name(arg)
174 && let Some(canonical) =
175 canonical_codegen_dep(&qname, fn_names, module_prefixes)
176 {
177 out.insert(canonical);
178 }
179 }
180 }
181 Expr::TailCall(boxed) => {
182 if fn_names.contains(&boxed.0) {
183 out.insert(boxed.0.clone());
184 }
185 }
186 _ => {}
187 });
188}
189
190fn expr_to_dotted_name(expr: &Expr) -> Option<String> {
191 match expr {
192 Expr::Ident(name) => Some(name.clone()),
193 Expr::Attr(obj, field) => {
194 let head = expr_to_dotted_name(obj)?;
195 Some(format!("{}.{}", head, field))
196 }
197 _ => None,
198 }
199}
200
/// Visits `expr` and every expression nested inside it, in pre-order.
///
/// `visit` is called on a node before any of its children. Children are then
/// walked in the order written for each variant below. Only `Expr` children
/// are traversed; names and other non-expression payloads are not passed to
/// `visit`.
fn walk_expr(expr: &Expr, visit: &mut impl FnMut(&Expr)) {
    visit(expr);
    match expr {
        // Callee first, then each argument left to right.
        Expr::FnCall(func, args) => {
            walk_expr(func, visit);
            for arg in args {
                walk_expr(arg, visit);
            }
        }
        // The tail-call target (`boxed.0`) is a name, not an expression, so
        // only the arguments are walked.
        Expr::TailCall(boxed) => {
            for arg in &boxed.1 {
                walk_expr(arg, visit);
            }
        }
        Expr::Attr(obj, _) => walk_expr(obj, visit),
        Expr::BinOp(_, l, r) => {
            walk_expr(l, visit);
            walk_expr(r, visit);
        }
        // Only the subject and each arm's body are walked; other arm/match
        // fields (e.g. patterns, if present) are not visited.
        Expr::Match { subject, arms, .. } => {
            walk_expr(subject, visit);
            for arm in arms {
                walk_expr(&arm.body, visit);
            }
        }
        Expr::List(items) | Expr::Tuple(items) => {
            for item in items {
                walk_expr(item, visit);
            }
        }
        // Key before value for each entry.
        Expr::MapLiteral(entries) => {
            for (k, v) in entries {
                walk_expr(k, visit);
                walk_expr(v, visit);
            }
        }
        Expr::Constructor(_, maybe) => {
            if let Some(inner) = maybe {
                walk_expr(inner, visit);
            }
        }
        Expr::ErrorProp(inner) => walk_expr(inner, visit),
        // Only `StrPart::Parsed` segments hold expressions.
        Expr::InterpolatedStr(parts) => {
            for part in parts {
                if let StrPart::Parsed(e) = part {
                    walk_expr(e, visit);
                }
            }
        }
        Expr::RecordCreate { fields, .. } => {
            for (_, e) in fields {
                walk_expr(e, visit);
            }
        }
        // Base record first, then each updated field value.
        Expr::RecordUpdate { base, updates, .. } => {
            walk_expr(base, visit);
            for (_, e) in updates {
                walk_expr(e, visit);
            }
        }
        // Leaves: no nested expressions.
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => {}
    }
}
264
265fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
270 let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
271 for item in items {
272 if let TopLevel::FnDef(fd) = item {
273 let mut callees = HashSet::new();
274 collect_callees_body(&fd.body, &mut callees);
275 graph.insert(fd.name.clone(), callees);
276 }
277 }
278 graph
279}
280
281fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
282 items
283 .iter()
284 .filter_map(|item| {
285 if let TopLevel::FnDef(fd) = item {
286 Some(fd.name.clone())
287 } else {
288 None
289 }
290 })
291 .collect()
292}
293
294fn recursive_sccs(
295 graph: &HashMap<String, HashSet<String>>,
296 user_fns: &HashSet<String>,
297) -> Vec<Vec<String>> {
298 let mut names = user_fns.iter().cloned().collect::<Vec<_>>();
299 names.sort();
300
301 let mut adj: HashMap<String, Vec<String>> = HashMap::new();
302 for name in &names {
303 let mut deps = graph
304 .get(name)
305 .cloned()
306 .unwrap_or_default()
307 .into_iter()
308 .filter(|callee| user_fns.contains(callee))
309 .collect::<Vec<_>>();
310 deps.sort();
311 adj.insert(name.clone(), deps);
312 }
313
314 scc::tarjan_sccs(&names, &adj)
315 .into_iter()
316 .filter(|scc| is_recursive_scc(scc, graph))
317 .collect()
318}
319
/// Reports whether a strongly-connected component represents recursion.
///
/// Any component with two or more members is mutually recursive. A
/// single-member component is recursive only if that function calls itself
/// according to `graph`. An empty component is never recursive.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [only] => graph
            .get(only)
            .is_some_and(|callees| callees.contains(only)),
        _ => true,
    }
}
331
332pub(crate) fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
333 for s in body.stmts() {
334 collect_callees_stmt(s, callees);
335 }
336}
337
338fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
339 for s in body.stmts() {
340 count_recursive_calls_stmt(s, recursive, out);
341 }
342}
343
344fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
345 match stmt {
346 Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
347 }
348}
349
350fn count_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>, out: &mut usize) {
351 match expr {
352 Expr::FnCall(func, args) => {
353 match func.as_ref() {
354 Expr::Ident(name) => {
355 if recursive.contains(name) {
356 *out += 1;
357 }
358 }
359 Expr::Attr(obj, member) => {
360 if let Expr::Ident(ns) = obj.as_ref() {
361 let q = format!("{}.{}", ns, member);
362 if recursive.contains(&q) {
363 *out += 1;
364 }
365 } else {
366 count_recursive_calls_expr(obj, recursive, out);
367 }
368 }
369 other => count_recursive_calls_expr(other, recursive, out),
370 }
371 for arg in args {
372 count_recursive_calls_expr(arg, recursive, out);
373 }
374 }
375 Expr::TailCall(boxed) => {
376 if recursive.contains(&boxed.0) {
377 *out += 1;
378 }
379 for arg in &boxed.1 {
380 count_recursive_calls_expr(arg, recursive, out);
381 }
382 }
383 Expr::Literal(_) | Expr::Resolved(_) | Expr::Ident(_) => {}
384 Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
385 Expr::BinOp(_, l, r) => {
386 count_recursive_calls_expr(l, recursive, out);
387 count_recursive_calls_expr(r, recursive, out);
388 }
389 Expr::Match {
390 subject: scrutinee,
391 arms,
392 ..
393 } => {
394 count_recursive_calls_expr(scrutinee, recursive, out);
395 for arm in arms {
396 count_recursive_calls_expr(&arm.body, recursive, out);
397 }
398 }
399 Expr::List(elems) | Expr::Tuple(elems) => {
400 for e in elems {
401 count_recursive_calls_expr(e, recursive, out);
402 }
403 }
404 Expr::MapLiteral(entries) => {
405 for (k, v) in entries {
406 count_recursive_calls_expr(k, recursive, out);
407 count_recursive_calls_expr(v, recursive, out);
408 }
409 }
410 Expr::Constructor(_, arg) => {
411 if let Some(a) = arg {
412 count_recursive_calls_expr(a, recursive, out);
413 }
414 }
415 Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
416 Expr::InterpolatedStr(parts) => {
417 for part in parts {
418 if let crate::ast::StrPart::Parsed(expr) = part {
419 count_recursive_calls_expr(expr, recursive, out);
420 }
421 }
422 }
423 Expr::RecordCreate { fields, .. } => {
424 for (_, e) in fields {
425 count_recursive_calls_expr(e, recursive, out);
426 }
427 }
428 Expr::RecordUpdate { base, updates, .. } => {
429 count_recursive_calls_expr(base, recursive, out);
430 for (_, e) in updates {
431 count_recursive_calls_expr(e, recursive, out);
432 }
433 }
434 }
435}
436
437fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
438 match stmt {
439 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
440 collect_callees_expr(e, callees);
441 }
442 }
443}
444
445fn collect_callees_expr(expr: &Expr, callees: &mut HashSet<String>) {
446 walk_expr(expr, &mut |node| match node {
447 Expr::FnCall(func, _) => {
448 if let Some(callee) = expr_to_dotted_name(func.as_ref()) {
449 callees.insert(callee);
450 }
451 }
452 Expr::TailCall(boxed) => {
453 callees.insert(boxed.0.clone());
454 }
455 _ => {}
456 });
457}
458
459#[cfg(test)]
460mod tests;