1use std::collections::{HashMap, HashSet};
8
9use crate::ast::{Expr, FnBody, Stmt, TopLevel};
10
11pub fn find_tco_groups(items: &[TopLevel]) -> Vec<HashSet<String>> {
18 let graph = build_call_graph(items);
19 let user_fns = user_fn_names(items);
20 recursive_sccs(&graph, &user_fns)
21 .into_iter()
22 .map(|scc| scc.into_iter().collect())
23 .collect()
24}
25
26pub fn find_recursive_fns(items: &[TopLevel]) -> HashSet<String> {
29 let graph = build_call_graph(items);
30 let user_fns = user_fn_names(items);
31 let mut recursive = HashSet::new();
32 for scc in recursive_sccs(&graph, &user_fns) {
33 for name in scc {
34 recursive.insert(name);
35 }
36 }
37 recursive
38}
39
40pub fn direct_calls(items: &[TopLevel]) -> HashMap<String, Vec<String>> {
42 let graph = build_call_graph(items);
43 let mut out = HashMap::new();
44 for item in items {
45 if let TopLevel::FnDef(fd) = item {
46 let mut callees = graph
47 .get(&fd.name)
48 .cloned()
49 .unwrap_or_default()
50 .into_iter()
51 .collect::<Vec<_>>();
52 callees.sort();
53 out.insert(fd.name.clone(), callees);
54 }
55 }
56 out
57}
58
59pub fn recursive_callsite_counts(items: &[TopLevel]) -> HashMap<String, usize> {
69 let graph = build_call_graph(items);
70 let user_fns = user_fn_names(items);
71 let sccs = recursive_sccs(&graph, &user_fns);
72 let mut scc_members: HashMap<String, HashSet<String>> = HashMap::new();
73 for scc in sccs {
74 let members: HashSet<String> = scc.iter().cloned().collect();
75 for name in scc {
76 scc_members.insert(name, members.clone());
77 }
78 }
79
80 let mut out = HashMap::new();
81 for item in items {
82 if let TopLevel::FnDef(fd) = item {
83 let mut count = 0usize;
84 if let Some(members) = scc_members.get(&fd.name) {
85 count_recursive_calls_body(&fd.body, members, &mut count);
86 }
87 out.insert(fd.name.clone(), count);
88 }
89 }
90 out
91}
92
93pub fn recursive_scc_ids(items: &[TopLevel]) -> HashMap<String, usize> {
96 let graph = build_call_graph(items);
97 let user_fns = user_fn_names(items);
98 let mut sccs = recursive_sccs(&graph, &user_fns);
99 for scc in &mut sccs {
100 scc.sort();
101 }
102 sccs.sort_by(|a, b| a.first().cmp(&b.first()));
103
104 let mut out = HashMap::new();
105 for (idx, scc) in sccs.into_iter().enumerate() {
106 let id = idx + 1;
107 for name in scc {
108 out.insert(name, id);
109 }
110 }
111 out
112}
113
114fn build_call_graph(items: &[TopLevel]) -> HashMap<String, HashSet<String>> {
119 let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
120 for item in items {
121 if let TopLevel::FnDef(fd) = item {
122 let mut callees = HashSet::new();
123 collect_callees_body(&fd.body, &mut callees);
124 graph.insert(fd.name.clone(), callees);
125 }
126 }
127 graph
128}
129
130fn user_fn_names(items: &[TopLevel]) -> HashSet<String> {
131 items
132 .iter()
133 .filter_map(|item| {
134 if let TopLevel::FnDef(fd) = item {
135 Some(fd.name.clone())
136 } else {
137 None
138 }
139 })
140 .collect()
141}
142
143fn recursive_sccs(
144 graph: &HashMap<String, HashSet<String>>,
145 user_fns: &HashSet<String>,
146) -> Vec<Vec<String>> {
147 tarjan_scc(graph, user_fns)
148 .into_iter()
149 .filter(|scc| is_recursive_scc(scc, graph))
150 .collect()
151}
152
/// Tells whether a strongly connected component represents real recursion.
///
/// A component with two or more members always does. A single-member component
/// does only when that member lists itself among its callees in `graph`. An
/// empty component never does.
fn is_recursive_scc(scc: &[String], graph: &HashMap<String, HashSet<String>>) -> bool {
    match scc {
        [] => false,
        [only] => graph
            .get(only)
            .is_some_and(|callees| callees.contains(only)),
        _ => true,
    }
}
164
165fn collect_callees_body(body: &FnBody, callees: &mut HashSet<String>) {
166 match body {
167 FnBody::Expr(e) => collect_callees_expr(e, callees),
168 FnBody::Block(stmts) => {
169 for s in stmts {
170 collect_callees_stmt(s, callees);
171 }
172 }
173 }
174}
175
176fn count_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>, out: &mut usize) {
177 match body {
178 FnBody::Expr(e) => count_recursive_calls_expr(e, recursive, out),
179 FnBody::Block(stmts) => {
180 for s in stmts {
181 count_recursive_calls_stmt(s, recursive, out);
182 }
183 }
184 }
185}
186
187fn count_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>, out: &mut usize) {
188 match stmt {
189 Stmt::Binding(_, _, e) | Stmt::Expr(e) => count_recursive_calls_expr(e, recursive, out),
190 }
191}
192
193fn count_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>, out: &mut usize) {
194 match expr {
195 Expr::FnCall(func, args) => {
196 match func.as_ref() {
197 Expr::Ident(name) => {
198 if recursive.contains(name) {
199 *out += 1;
200 }
201 }
202 Expr::Attr(obj, member) => {
203 if let Expr::Ident(ns) = obj.as_ref() {
204 let q = format!("{}.{}", ns, member);
205 if recursive.contains(&q) {
206 *out += 1;
207 }
208 } else {
209 count_recursive_calls_expr(obj, recursive, out);
210 }
211 }
212 other => count_recursive_calls_expr(other, recursive, out),
213 }
214 for arg in args {
215 count_recursive_calls_expr(arg, recursive, out);
216 }
217 }
218 Expr::TailCall(boxed) => {
219 if recursive.contains(&boxed.0) {
220 *out += 1;
221 }
222 for arg in &boxed.1 {
223 count_recursive_calls_expr(arg, recursive, out);
224 }
225 }
226 Expr::Literal(_) | Expr::Resolved(_) | Expr::Ident(_) => {}
227 Expr::Attr(obj, _) => count_recursive_calls_expr(obj, recursive, out),
228 Expr::BinOp(_, l, r) | Expr::Pipe(l, r) => {
229 count_recursive_calls_expr(l, recursive, out);
230 count_recursive_calls_expr(r, recursive, out);
231 }
232 Expr::Match {
233 subject: scrutinee,
234 arms,
235 ..
236 } => {
237 count_recursive_calls_expr(scrutinee, recursive, out);
238 for arm in arms {
239 count_recursive_calls_expr(&arm.body, recursive, out);
240 }
241 }
242 Expr::List(elems) | Expr::Tuple(elems) => {
243 for e in elems {
244 count_recursive_calls_expr(e, recursive, out);
245 }
246 }
247 Expr::MapLiteral(entries) => {
248 for (k, v) in entries {
249 count_recursive_calls_expr(k, recursive, out);
250 count_recursive_calls_expr(v, recursive, out);
251 }
252 }
253 Expr::Constructor(_, arg) => {
254 if let Some(a) = arg {
255 count_recursive_calls_expr(a, recursive, out);
256 }
257 }
258 Expr::ErrorProp(inner) => count_recursive_calls_expr(inner, recursive, out),
259 Expr::InterpolatedStr(parts) => {
260 for part in parts {
261 if let crate::ast::StrPart::Parsed(expr) = part {
262 count_recursive_calls_expr(expr, recursive, out);
263 }
264 }
265 }
266 Expr::RecordCreate { fields, .. } => {
267 for (_, e) in fields {
268 count_recursive_calls_expr(e, recursive, out);
269 }
270 }
271 Expr::RecordUpdate { base, updates, .. } => {
272 count_recursive_calls_expr(base, recursive, out);
273 for (_, e) in updates {
274 count_recursive_calls_expr(e, recursive, out);
275 }
276 }
277 }
278}
279
280fn collect_callees_stmt(stmt: &Stmt, callees: &mut HashSet<String>) {
281 match stmt {
282 Stmt::Binding(_, _, e) | Stmt::Expr(e) => {
283 collect_callees_expr(e, callees);
284 }
285 }
286}
287
288fn collect_callees_expr(expr: &Expr, callees: &mut HashSet<String>) {
289 match expr {
290 Expr::FnCall(func, args) => {
291 match func.as_ref() {
293 Expr::Ident(name) => {
294 callees.insert(name.clone());
295 }
296 Expr::Attr(obj, member) => {
297 if let Expr::Ident(ns) = obj.as_ref() {
298 callees.insert(format!("{}.{}", ns, member));
299 }
300 }
301 _ => collect_callees_expr(func, callees),
302 }
303 for arg in args {
304 collect_callees_expr(arg, callees);
305 }
306 }
307 Expr::Literal(_) | Expr::Resolved(_) => {}
308 Expr::Ident(_) => {}
309 Expr::Attr(obj, _) => collect_callees_expr(obj, callees),
310 Expr::BinOp(_, l, r) => {
311 collect_callees_expr(l, callees);
312 collect_callees_expr(r, callees);
313 }
314 Expr::Pipe(l, r) => {
315 collect_callees_expr(l, callees);
316 collect_callees_expr(r, callees);
317 }
318 Expr::Match {
319 subject: scrutinee,
320 arms,
321 ..
322 } => {
323 collect_callees_expr(scrutinee, callees);
324 for arm in arms {
325 collect_callees_expr(&arm.body, callees);
326 }
327 }
328 Expr::List(elems) => {
329 for e in elems {
330 collect_callees_expr(e, callees);
331 }
332 }
333 Expr::Tuple(items) => {
334 for item in items {
335 collect_callees_expr(item, callees);
336 }
337 }
338 Expr::MapLiteral(entries) => {
339 for (key, value) in entries {
340 collect_callees_expr(key, callees);
341 collect_callees_expr(value, callees);
342 }
343 }
344 Expr::Constructor(_, arg) => {
345 if let Some(a) = arg {
346 collect_callees_expr(a, callees);
347 }
348 }
349 Expr::ErrorProp(inner) => collect_callees_expr(inner, callees),
350 Expr::InterpolatedStr(parts) => {
351 for part in parts {
352 if let crate::ast::StrPart::Parsed(expr) = part {
353 collect_callees_expr(expr, callees);
354 }
355 }
356 }
357 Expr::RecordCreate { fields, .. } => {
358 for (_, e) in fields {
359 collect_callees_expr(e, callees);
360 }
361 }
362 Expr::RecordUpdate { base, updates, .. } => {
363 collect_callees_expr(base, callees);
364 for (_, e) in updates {
365 collect_callees_expr(e, callees);
366 }
367 }
368 Expr::TailCall(boxed) => {
369 callees.insert(boxed.0.clone());
370 for arg in &boxed.1 {
371 collect_callees_expr(arg, callees);
372 }
373 }
374 }
375}
376
/// Working state for one run of Tarjan's strongly-connected-components search.
///
/// Node names are borrowed from the graph and node set being analysed, so
/// visiting a node does not allocate; names are copied into owned `String`s
/// only when a finished component is emitted.
#[derive(Default)]
struct TarjanState<'a> {
    /// Next discovery index to hand out.
    index_counter: usize,
    /// Nodes visited but not yet assigned to a component, in visit order.
    stack: Vec<&'a str>,
    /// Same contents as `stack`, for constant-time membership checks.
    on_stack: HashSet<&'a str>,
    /// Discovery index of every visited node.
    indices: HashMap<&'a str, usize>,
    /// Lowest discovery index known to be reachable from each finished node.
    lowlinks: HashMap<&'a str, usize>,
    /// Components completed so far.
    sccs: Vec<Vec<String>>,
}

/// Splits `nodes` into strongly connected components of `graph`.
///
/// `graph` maps a node to the set of nodes it points at. Edges leading to
/// names that have no entry in `graph` are ignored. Every node in `nodes`
/// ends up in exactly one component; a node without cycles forms a component
/// of its own.
///
/// The order of the components, and of the names inside a component, depends
/// on hash iteration order and is therefore unspecified.
fn tarjan_scc(
    graph: &HashMap<String, HashSet<String>>,
    nodes: &HashSet<String>,
) -> Vec<Vec<String>> {
    let mut state = TarjanState::default();

    for node in nodes {
        if !state.indices.contains_key(node.as_str()) {
            strongconnect(node.as_str(), graph, &mut state);
        }
    }

    state.sccs
}

/// Depth-first visit of `v`; emits a component whenever `v` turns out to be its root.
///
/// Recurses once per not-yet-visited callee, so the recursion depth is bounded
/// by the longest chain of unvisited nodes reachable from `v`.
fn strongconnect<'a>(
    v: &'a str,
    graph: &'a HashMap<String, HashSet<String>>,
    state: &mut TarjanState<'a>,
) {
    let idx = state.index_counter;
    state.index_counter += 1;
    state.indices.insert(v, idx);
    state.stack.push(v);
    state.on_stack.insert(v);

    // `v`'s lowlink is only read by its caller after this function returns,
    // so it is tracked locally and stored once the callees are processed.
    let mut low = idx;

    if let Some(callees) = graph.get(v) {
        for callee in callees {
            let w = callee.as_str();
            match state.indices.get(w).copied() {
                None => {
                    // Unvisited. Names without a graph entry are not nodes
                    // of the graph and are skipped.
                    if graph.contains_key(w) {
                        strongconnect(w, graph, state);
                        low = low.min(state.lowlinks[w]);
                    }
                }
                Some(w_idx) => {
                    // Already visited: only nodes still on the stack belong
                    // to the component currently being built.
                    if state.on_stack.contains(w) {
                        low = low.min(w_idx);
                    }
                }
            }
        }
    }
    state.lowlinks.insert(v, low);

    // `v` is the root of a component: pop the stack down to and including `v`.
    if low == idx {
        let mut scc = Vec::new();
        while let Some(top) = state.stack.pop() {
            state.on_stack.remove(top);
            scc.push(top.to_string());
            if top == v {
                break;
            }
        }
        state.sccs.push(scc);
    }
}
456
// Unit tests: each test parses a small program written in the analysed
// language and checks one of the public call-graph queries against it.
#[cfg(test)]
mod tests {
    use super::*;

    // A function that calls itself must be reported as recursive.
    #[test]
    fn detects_self_recursion() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let rec = find_recursive_fns(&items);
        assert!(
            rec.contains("fib"),
            "fib should be recursive, got: {:?}",
            rec
        );
    }

    // A function that calls nothing must not be reported.
    #[test]
    fn non_recursive_fn() {
        let src = "fn double(x: Int) -> Int\n = x + x\n";
        let items = parse(src);
        let rec = find_recursive_fns(&items);
        assert!(
            rec.is_empty(),
            "double should not be recursive, got: {:?}",
            rec
        );
    }

    // Two functions that call each other are both recursive.
    #[test]
    fn mutual_recursion() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let items = parse(src);
        let rec = find_recursive_fns(&items);
        assert!(rec.contains("isEven"), "isEven should be recursive");
        assert!(rec.contains("isOdd"), "isOdd should be recursive");
    }

    // Each textual occurrence of a recursive call counts once: `fib` calls
    // itself twice in its last arm.
    #[test]
    fn recursive_callsites_count_syntactic_occurrences() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let counts = recursive_callsite_counts(&items);
        assert_eq!(counts.get("fib").copied().unwrap_or(0), 2);
    }

    // Only calls into the caller's own component count: `a` calls `fib`, which
    // is recursive but belongs to a different component, so `a` has one
    // recursive call site (the call to `b`).
    #[test]
    fn recursive_callsites_are_scoped_to_scc() {
        let src = r#"
fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1) + fib(n)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)

fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse(src);
        let counts = recursive_callsite_counts(&items);
        assert_eq!(counts.get("a").copied().unwrap_or(0), 1);
        assert_eq!(counts.get("b").copied().unwrap_or(0), 1);
        assert_eq!(counts.get("fib").copied().unwrap_or(0), 2);
    }

    // Component ids follow the smallest member name, not definition order:
    // `z` is defined first but is numbered after the `a`/`b` component.
    #[test]
    fn recursive_scc_ids_are_deterministic_by_group_name() {
        let src = r#"
fn z(n: Int) -> Int
    match n
        0 -> 0
        _ -> z(n - 1)

fn a(n: Int) -> Int
    match n
        0 -> 0
        _ -> b(n - 1)

fn b(n: Int) -> Int
    match n
        0 -> 0
        _ -> a(n - 1)
"#;
        let items = parse(src);
        let ids = recursive_scc_ids(&items);
        assert_eq!(ids.get("a").copied().unwrap_or(0), 1);
        // `a` and `b` are mutually recursive, so they share an id.
        assert_eq!(ids.get("b").copied().unwrap_or(0), 1);
        assert_eq!(ids.get("z").copied().unwrap_or(0), 2);
    }

    // Test helper: lexes and parses `src`, panicking with "lex failed" or
    // "parse failed" if either stage reports an error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = crate::parser::Parser::new(tokens);
        parser.parse().expect("parse failed")
    }
}