1use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
10
11mod codegen;
12mod scc;
13
14pub use codegen::ordered_fn_components;
15
16pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
23 let graph = build_call_graph(items);
24 let user_fns = user_fn_names(items);
25 recursive_sccs(&graph, &user_fns)
26 .into_iter()
27 .map(|scc| scc.into_iter().collect())
28 .collect()
29}
30
31pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
34 let graph = build_call_graph(items);
35 let user_fns = user_fn_names(items);
36 let mut recursive = HashSet::new();
37 for scc in recursive_sccs(&graph, &user_fns) {
38 for name in scc {
39 recursive.insert(name);
40 }
41 }
42 recursive
43}
44
45pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
47 let graph = build_call_graph(items);
48 let mut out = HashMap::new();
49 for item in items {
50 if let TopLevel::FnDef(fd) = item {
51 let mut callees = graph
52 .get(&fd.name)
53 .cloned()
54 .unwrap_or_default()
55 .into_iter()
56 .collect::<Vec<_>>();
57 callees.sort();
58 out.insert(fd.name.clone(), callees);
59 }
60 }
61 out
62}
63
64pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
74 let graph = build_call_graph(items);
75 let user_fns = user_fn_names(items);
76 let sccs = recursive_sccs(&graph, &user_fns);
77 let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
78 for scc in sccs {
79 let members: HashSet<String> = scc.iter().cloned().collect();
80 for name in scc {
81 scc_members.insert(name, members.clone());
82 }
83 }
84
85 let mut out = HashMap::new();
86 for item in items {
87 if let TopLevel::FnDef(fd) = item {
88 let mut count = 0usize;
89 if let Some(members) = scc_members.get(&fd.name) {
90 count_recursive_calls_body(&fd.body, members, &mut count);
91 }
92 out.insert(fd.name.clone(), count);
93 }
94 }
95 out
96}
97
98pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
101 let graph = build_call_graph(items);
102 let user_fns = user_fn_names(items);
103 let mut sccs = recursive_sccs(&graph, &user_fns);
104 for scc in &mut sccs {
105 scc.sort();
106 }
107 sccs.sort_by(|a, b| a.first().cmp(&b.first()));
108
109 let mut out = HashMap::new();
110 for (idx, scc) in sccs.into_iter().enumerate() {
111 let id = idx + 1;
112 for name in scc {
113 out.insert(name, id);
114 }
115 }
116 out
117}
118
119fn collect_codegen_deps_body(body: &FnBody, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
120 match body {
121 FnBody::Expr(e) => collect_codegen_deps_expr(e, fn_names, out),
122 FnBody::Block(stmts) => {
123 for s in stmts {
124 match s {
125 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
126 collect_codegen_deps_expr(e, fn_names, out)
127 }
128 }
129 }
130 }
131 }
132}
133
134fn collect_codegen_deps_expr(expr: &Expr, fn_names: &HashSet<String>, out: &mut HashSet<String>) {
135 walk_expr(expr, &mut |node| match node {
136 Expr::FnCall(func, args) => {
137 if let Some(callee) = expr_to_dotted_name(func.as_ref())
138 && fn_names.contains(&callee)
139 {
140 out.insert(callee);
141 }
142 for arg in args {
143 if let Some(qname) = expr_to_dotted_name(arg)
145 && fn_names.contains(&qname)
146 {
147 out.insert(qname);
148 }
149 }
150 }
151 Expr::TailCall(boxed) => {
152 if fn_names.contains(&boxed.0) {
153 out.insert(boxed.0.clone());
154 }
155 }
156 _ => {}
157 });
158}
159
160fn expr_to_dotted_name(expr: &Expr) -> Option<String> {
161 match expr {
162 Expr::Ident(name) => Some(name.clone()),
163 Expr::Attr(obj, field) => {
164 let head = expr_to_dotted_name(obj)?;
165 Some(format!("{}.{}", head, field))
166 }
167 _ => None,
168 }
169}
170
/// Visits `expr` and every sub-expression reachable from it, in pre-order.
///
/// `visit` is called on a node before any of its children. Children are walked in the
/// order their arms list them below. Only expression children are followed:
/// - `TailCall`: only the argument list (`boxed.1`) is walked; the target (`boxed.0`)
///   is a name, not an expression.
/// - `Match`: the subject and each arm's `body` are walked; other arm fields, and any
///   fields elided by `..`, are not.
/// - `InterpolatedStr`: only `StrPart::Parsed` parts are walked.
/// - `Literal`, `Ident` and `Resolved` are leaves.
///
/// The match is exhaustive, so adding an `Expr` variant forces this walker to be updated.
fn walk_expr(expr: &Expr, visit: &mut impl FnMut(&Expr)) {
    visit(expr);
    match expr {
        Expr::FnCall(func, args) => {
            // The function position is walked too, so calls nested inside it are seen.
            walk_expr(func, visit);
            for arg in args {
                walk_expr(arg, visit);
            }
        }
        Expr::TailCall(boxed) => {
            for arg in &boxed.1 {
                walk_expr(arg, visit);
            }
        }
        Expr::Attr(obj, _) => walk_expr(obj, visit),
        Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
            walk_expr(l, visit);
            walk_expr(r, visit);
        }
        Expr::Match { subject, arms, .. } => {
            walk_expr(subject, visit);
            for arm in arms {
                walk_expr(&arm.body, visit);
            }
        }
        Expr::List(items) | Expr::Tuple(items) => {
            for item in items {
                walk_expr(item, visit);
            }
        }
        Expr::MapLiteral(entries) => {
            // Key before value for each entry.
            for (k, v) in entries {
                walk_expr(k, visit);
                walk_expr(v, visit);
            }
        }
        Expr::Constructor(_, maybe) => {
            if let Some(inner) = maybe {
                walk_expr(inner, visit);
            }
        }
        Expr::ErrorProp(inner) => walk_expr(inner, visit),
        Expr::InterpolatedStr(parts) => {
            for part in parts {
                if let StrPart::Parsed(e) = part {
                    walk_expr(e, visit);
                }
            }
        }
        Expr::RecordCreate { fields, .. } => {
            for (_, e) in fields {
                walk_expr(e, visit);
            }
        }
        Expr::RecordUpdate { base, updates, .. } => {
            // The base record is walked before the updated field values.
            walk_expr(base, visit);
            for (_, e) in updates {
                walk_expr(e, visit);
            }
        }
        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => {}
    }
}
234
235fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
240 let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
241 for item in items {
242 if let TopLevel::FnDef(fd) = item {
243 let mut callees = HashSet::new();
244 collect_callees_body(&fd.body, &mut callees);
245 graph.insert(fd.name.clone(), callees);
246 }
247 }
248 graph
249}
250
251fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
252 items
253 .iter()
254 .filter_map(|item| {
255 if let TopLevel::FnDef(fd) = item {
256 Some(fd.name.clone())
257 } else {
258 None
259 }
260 })
261 .collect()
262}
263
264fn recursive_sccs(
265 graph: &HashMap<String, HashSet<String>>,
266 user_fns: &HashSet<String>,
267) -> Vec<Vec<String>> {
268 let mut names = user_fns.iter().cloned().collect::<Vec<_>>();
269 names.sort();
270
271 let mut adj: HashMap<String, Vec<String>> = HashMap::new();
272 for name in &names {
273 let mut deps = graph
274 .get(name)
275 .cloned()
276 .unwrap_or_default()
277 .into_iter()
278 .filter(|callee| user_fns.contains(callee))
279 .collect::<Vec<_>>();
280 deps.sort();
281 adj.insert(name.clone(), deps);
282 }
283
284 scc::tarjan_sccs(&names, &adj)
285 .into_iter()
286 .filter(|scc| is_recursive_scc(scc, graph))
287 .collect()
288}
289
/// Decides whether a strongly connected component represents recursion.
///
/// A component with two or more members is mutual recursion. A single-member component
/// is recursive only if that function lists itself among its callees in `graph`. An empty
/// slice is never recursive.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [only] => graph
            .get(only)
            .is_some_and(|callees| callees.contains(only)),
        _ => true,
    }
}
301
302pub(crate) fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
303 match body {
304 FnBody::Expr(e) => collect_callees_expr(e, callees),
305 FnBody::Block(stmts) => {
306 for s in stmts {
307 collect_callees_stmt(s, callees);
308 }
309 }
310 }
311}
312
313fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
314 match body {
315 FnBody::Expr(e) => count_recursive_calls_expr(e, recursive, out),
316 FnBody::Block(stmts) => {
317 for s in stmts {
318 count_recursive_calls_stmt(s, recursive, out);
319 }
320 }
321 }
322}
323
324fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
325 match stmt {
326 Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
327 }
328}
329
330fn count_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>, out: &mut usize) {
331 match expr {
332 Expr::FnCall(func, args) => {
333 match func.as_ref() {
334 Expr::Ident(name) => {
335 if recursive.contains(name) {
336 *out += 1;
337 }
338 }
339 Expr::Attr(obj, member) => {
340 if let Expr::Ident(ns) = obj.as_ref() {
341 let q = format!("{}.{}", ns, member);
342 if recursive.contains(&q) {
343 *out += 1;
344 }
345 } else {
346 count_recursive_calls_expr(obj, recursive, out);
347 }
348 }
349 other => count_recursive_calls_expr(other, recursive, out),
350 }
351 for arg in args {
352 count_recursive_calls_expr(arg, recursive, out);
353 }
354 }
355 Expr::TailCall(boxed) => {
356 if recursive.contains(&boxed.0) {
357 *out += 1;
358 }
359 for arg in &boxed.1 {
360 count_recursive_calls_expr(arg, recursive, out);
361 }
362 }
363 Expr::Literal(_) | Expr::Resolved(_) | Expr::Ident(_) => {}
364 Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
365 Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
366 count_recursive_calls_expr(l, recursive, out);
367 count_recursive_calls_expr(r, recursive, out);
368 }
369 Expr::Match {
370 subject: scrutinee,
371 arms,
372 ..
373 } => {
374 count_recursive_calls_expr(scrutinee, recursive, out);
375 for arm in arms {
376 count_recursive_calls_expr(&arm.body, recursive, out);
377 }
378 }
379 Expr::List(elems) | Expr::Tuple(elems) => {
380 for e in elems {
381 count_recursive_calls_expr(e, recursive, out);
382 }
383 }
384 Expr::MapLiteral(entries) => {
385 for (k, v) in entries {
386 count_recursive_calls_expr(k, recursive, out);
387 count_recursive_calls_expr(v, recursive, out);
388 }
389 }
390 Expr::Constructor(_, arg) => {
391 if let Some(a) = arg {
392 count_recursive_calls_expr(a, recursive, out);
393 }
394 }
395 Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
396 Expr::InterpolatedStr(parts) => {
397 for part in parts {
398 if let crate::ast::StrPart::Parsed(expr) = part {
399 count_recursive_calls_expr(expr, recursive, out);
400 }
401 }
402 }
403 Expr::RecordCreate { fields, .. } => {
404 for (_, e) in fields {
405 count_recursive_calls_expr(e, recursive, out);
406 }
407 }
408 Expr::RecordUpdate { base, updates, .. } => {
409 count_recursive_calls_expr(base, recursive, out);
410 for (_, e) in updates {
411 count_recursive_calls_expr(e, recursive, out);
412 }
413 }
414 }
415}
416
417fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
418 match stmt {
419 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
420 collect_callees_expr(e, callees);
421 }
422 }
423}
424
425fn collect_callees_expr(expr: &Expr, callees: &mut HashSet<String>) {
426 walk_expr(expr, &mut |node| match node {
427 Expr::FnCall(func, _) => {
428 if let Some(callee) = expr_to_dotted_name(func.as_ref()) {
429 callees.insert(callee);
430 }
431 }
432 Expr::TailCall(boxed) => {
433 callees.insert(boxed.0.clone());
434 }
435 _ => {}
436 });
437}
438
439#[cfg(test)]
440mod tests;