1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5
/// A diagnostic for a function that still contains recursive callsites
/// that were not rewritten into tail calls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NonTailRecursionWarning {
    /// Name of the function the warning refers to.
    pub fn_name: String,
    /// Line of the function definition (copied from the `FnDef`).
    pub line: usize,
    /// Number of recursive callsites found outside of tail-call nodes.
    pub recursive_calls: usize,
    /// Human-readable warning text.
    pub message: String,
}
13
14pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
15 let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
16 for scc in call_graph::find_tco_groups(items) {
17 for name in &scc {
18 fn_to_scc.insert(name.clone(), scc.clone());
19 }
20 }
21
22 let mut warnings = Vec::new();
23 for item in items {
24 let TopLevel::FnDef(fd) = item else {
25 continue;
26 };
27 let Some(scc_members) = fn_to_scc.get(&fd.name) else {
28 continue;
29 };
30 let recursive_calls = count_non_tail_recursive_calls_body(&fd.body, scc_members);
31 if recursive_calls == 0 {
32 continue;
33 }
34 warnings.push(NonTailRecursionWarning {
35 fn_name: fd.name.clone(),
36 line: fd.line,
37 recursive_calls,
38 message: format!(
39 "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; consider accumulator pattern",
40 fd.name, recursive_calls
41 ),
42 });
43 }
44 warnings
45}
46
47fn count_non_tail_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>) -> usize {
48 body.stmts()
49 .iter()
50 .map(|stmt| count_non_tail_recursive_calls_stmt(stmt, recursive))
51 .sum()
52}
53
54fn count_non_tail_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>) -> usize {
55 match stmt {
56 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
57 count_non_tail_recursive_calls_expr(expr, recursive)
58 }
59 }
60}
61
62fn count_non_tail_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>) -> usize {
63 match expr {
64 Expr::FnCall(func, args) => {
65 let mut count = 0;
66 if let Some(callee) = dotted_name(func.as_ref())
67 && recursive.contains(&callee)
68 {
69 count += 1;
70 }
71 count
72 + count_non_tail_recursive_calls_expr(func, recursive)
73 + args
74 .iter()
75 .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
76 .sum::<usize>()
77 }
78 Expr::TailCall(boxed) => boxed
79 .1
80 .iter()
81 .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
82 .sum(),
83 Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
84 count_non_tail_recursive_calls_expr(obj, recursive)
85 }
86 Expr::BinOp(_, left, right) => {
87 count_non_tail_recursive_calls_expr(left, recursive)
88 + count_non_tail_recursive_calls_expr(right, recursive)
89 }
90 Expr::Match { subject, arms, .. } => {
91 count_non_tail_recursive_calls_expr(subject, recursive)
92 + arms
93 .iter()
94 .map(|arm| count_non_tail_recursive_calls_expr(&arm.body, recursive))
95 .sum::<usize>()
96 }
97 Expr::List(items) | Expr::Tuple(items) => items
98 .iter()
99 .map(|item| count_non_tail_recursive_calls_expr(item, recursive))
100 .sum(),
101 Expr::MapLiteral(entries) => entries
102 .iter()
103 .map(|(key, value)| {
104 count_non_tail_recursive_calls_expr(key, recursive)
105 + count_non_tail_recursive_calls_expr(value, recursive)
106 })
107 .sum(),
108 Expr::Constructor(_, maybe_arg) => maybe_arg
109 .as_deref()
110 .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
111 .unwrap_or(0),
112 Expr::InterpolatedStr(parts) => parts
113 .iter()
114 .map(|part| match part {
115 StrPart::Literal(_) => 0,
116 StrPart::Parsed(expr) => count_non_tail_recursive_calls_expr(expr, recursive),
117 })
118 .sum(),
119 Expr::RecordCreate { fields, .. } => fields
120 .iter()
121 .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
122 .sum(),
123 Expr::RecordUpdate { base, updates, .. } => {
124 count_non_tail_recursive_calls_expr(base, recursive)
125 + updates
126 .iter()
127 .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
128 .sum::<usize>()
129 }
130 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => 0,
131 }
132}
133
134fn dotted_name(expr: &Expr) -> Option<String> {
135 match expr {
136 Expr::Ident(name) => Some(name.clone()),
137 Expr::Attr(base, field) => {
138 let mut prefix = dotted_name(base)?;
139 prefix.push('.');
140 prefix.push_str(field);
141 Some(prefix)
142 }
143 _ => None,
144 }
145}
146
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    use crate::{parser::Parser, tco};

    use super::*;

    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    /// Parses `src`, runs the TCO pass over it, and returns the warnings.
    /// This is the pipeline every test below exercises.
    fn warnings_after_tco(src: &str) -> Vec<NonTailRecursionWarning> {
        let mut items = parse(src);
        tco::transform_program(&mut items);
        collect_non_tail_recursion_warnings(&items)
    }

    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let warnings = warnings_after_tco(src);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].fn_name, "fib");
        assert_eq!(warnings[0].recursive_calls, 2);
    }

    #[test]
    fn counts_single_non_tail_call_in_linear_recursion() {
        let src = r#"
fn sumTo(n: Int) -> Int
    match n
        0 -> 0
        _ -> n + sumTo(n - 1)
"#;
        let warnings = warnings_after_tco(src);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].fn_name, "sumTo");
        assert_eq!(warnings[0].recursive_calls, 1);
        assert!(warnings[0].message.contains("'sumTo'"));
    }

    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        assert!(warnings_after_tco(src).is_empty());
    }

    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        assert!(warnings_after_tco(src).is_empty());
    }
}