1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Spanned, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5#[cfg(feature = "runtime")]
6use crate::verify_law::canonical_spec_ref;
7
/// Diagnostic for a function whose body still contains recursive calls that
/// the tail-call optimization pass did not turn into tail calls.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NonTailRecursionWarning {
    /// Name of the offending function.
    pub fn_name: String,
    /// Line of the function definition.
    pub line: usize,
    /// Number of remaining recursive callsites (this module sets it to
    /// `callsite_lines.len()`).
    pub recursive_calls: usize,
    /// Line of each remaining recursive callsite, in body traversal order.
    /// A line is recorded once per call, so it may repeat.
    pub callsite_lines: Vec<usize>,
    /// Human-readable warning text.
    pub message: String,
}
17
18pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
19 let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
20 for scc in call_graph::find_tco_groups(items) {
21 for name in &scc {
22 fn_to_scc.insert(name.clone(), scc.clone());
23 }
24 }
25
26 let mut warnings = Vec::new();
27 for item in items {
28 let TopLevel::FnDef(fd) = item else {
29 continue;
30 };
31 let Some(scc_members) = fn_to_scc.get(&fd.name) else {
32 continue;
33 };
34 let callsite_lines: Vec<usize> =
35 collect_non_tail_recursive_call_lines_body(&fd.body, scc_members)
36 .into_iter()
37 .filter(|&ln| ln >= fd.line)
38 .collect();
39 if callsite_lines.is_empty() {
40 continue;
41 }
42 let recursive_calls = callsite_lines.len();
43 warnings.push(NonTailRecursionWarning {
44 fn_name: fd.name.clone(),
45 line: fd.line,
46 recursive_calls,
47 callsite_lines,
48 message: format!(
49 "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
50 fd.name, recursive_calls
51 ),
52 });
53 }
54 warnings
55}
56
/// Like `collect_non_tail_recursion_warnings`, but uses `fn_sigs` to resolve
/// canonical `verify ... law` spec functions.
///
/// Those spec functions are exempt from the warning.
#[cfg(feature = "runtime")]
pub fn collect_non_tail_recursion_warnings_with_sigs(
    items: &[TopLevel],
    fn_sigs: &crate::verify_law::FnSigMap,
) -> Vec<NonTailRecursionWarning> {
    let sigs = Some(fn_sigs);
    collect_non_tail_recursion_warnings_in(items, sigs)
}
64
/// Shared implementation of the spec-aware non-tail-recursion lint.
///
/// Follows the same rules as `collect_non_tail_recursion_warnings`, with one
/// difference. When `fn_sigs` is `Some`, functions named as the spec side of
/// a canonical `verify ... law` declaration (resolved by
/// `collect_canonical_spec_functions`) are never reported. With `None`, no
/// function is exempt.
#[cfg(feature = "runtime")]
fn collect_non_tail_recursion_warnings_in(
    items: &[TopLevel],
    fn_sigs: Option<&crate::verify_law::FnSigMap>,
) -> Vec<NonTailRecursionWarning> {
    // Own each group once and index it by member name. Cloning the whole group
    // for every member would be quadratic in group size.
    let sccs: Vec<HashSet<String>> = call_graph::find_tco_groups(items).into_iter().collect();
    let fn_to_scc: HashMap<&str, &HashSet<String>> = sccs
        .iter()
        .flat_map(|scc| scc.iter().map(move |name| (name.as_str(), scc)))
        .collect();
    let spec_fns = collect_canonical_spec_functions(items, fn_sigs);

    let mut warnings = Vec::new();
    for item in items {
        let TopLevel::FnDef(fd) = item else {
            continue;
        };
        // Spec functions may recurse freely; they serve as reference models.
        if spec_fns.contains(&fd.name) {
            continue;
        }
        let Some(scc_members) = fn_to_scc.get(fd.name.as_str()) else {
            continue;
        };
        // Callsites reported before the definition line are discarded.
        // NOTE(review): presumably these are synthetic nodes; confirm.
        let callsite_lines: Vec<usize> =
            collect_non_tail_recursive_call_lines_body(&fd.body, scc_members)
                .into_iter()
                .filter(|&ln| ln >= fd.line)
                .collect();
        if callsite_lines.is_empty() {
            continue;
        }
        let recursive_calls = callsite_lines.len();
        warnings.push(NonTailRecursionWarning {
            fn_name: fd.name.clone(),
            line: fd.line,
            recursive_calls,
            callsite_lines,
            message: format!(
                "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
                fd.name, recursive_calls
            ),
        });
    }
    warnings
}
111
/// Collects the names of functions used as the canonical spec side of a
/// `verify ... law` declaration.
///
/// Each law is resolved through `canonical_spec_ref`; `verify ... cases`
/// blocks are ignored. Returns an empty set when `fn_sigs` is `None`.
#[cfg(feature = "runtime")]
fn collect_canonical_spec_functions(
    items: &[TopLevel],
    fn_sigs: Option<&crate::verify_law::FnSigMap>,
) -> HashSet<String> {
    let Some(sigs) = fn_sigs else {
        return HashSet::new();
    };

    let mut specs = HashSet::new();
    for item in items {
        let TopLevel::Verify(verify) = item else {
            continue;
        };
        let law = match &verify.kind {
            crate::ast::VerifyKind::Law(law) => law,
            crate::ast::VerifyKind::Cases => continue,
        };
        if let Some(spec_ref) = canonical_spec_ref(&verify.fn_name, law, sigs) {
            specs.insert(spec_ref.spec_fn_name);
        }
    }
    specs
}
133
134fn collect_non_tail_recursive_call_lines_body(
135 body: &FnBody,
136 recursive: &HashSet<String>,
137) -> Vec<usize> {
138 let mut lines = Vec::new();
139 for stmt in body.stmts() {
140 collect_non_tail_recursive_call_lines_stmt(stmt, recursive, &mut lines);
141 }
142 lines
143}
144
145fn collect_non_tail_recursive_call_lines_stmt(
146 stmt: &Stmt,
147 recursive: &HashSet<String>,
148 out: &mut Vec<usize>,
149) {
150 match stmt {
151 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
152 collect_non_tail_recursive_call_lines_expr(expr, recursive, out);
153 }
154 }
155}
156
157fn collect_non_tail_recursive_call_lines_expr(
158 expr: &Spanned<Expr>,
159 recursive: &HashSet<String>,
160 out: &mut Vec<usize>,
161) {
162 match &expr.node {
163 Expr::FnCall(func, args) => {
164 if let Some(callee) = dotted_name(func.as_ref())
165 && recursive.contains(&callee)
166 {
167 out.push(expr.line);
168 }
169 collect_non_tail_recursive_call_lines_expr(func, recursive, out);
170 for arg in args {
171 collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
172 }
173 }
174 Expr::TailCall(boxed) => {
175 for arg in &boxed.args {
176 collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
177 }
178 }
179 Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
180 collect_non_tail_recursive_call_lines_expr(obj, recursive, out);
181 }
182 Expr::BinOp(_, left, right) => {
183 collect_non_tail_recursive_call_lines_expr(left, recursive, out);
184 collect_non_tail_recursive_call_lines_expr(right, recursive, out);
185 }
186 Expr::Match { subject, arms } => {
187 collect_non_tail_recursive_call_lines_expr(subject, recursive, out);
188 for arm in arms {
189 collect_non_tail_recursive_call_lines_expr(&arm.body, recursive, out);
190 }
191 }
192 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
193 for item in items {
194 collect_non_tail_recursive_call_lines_expr(item, recursive, out);
195 }
196 }
197 Expr::MapLiteral(entries) => {
198 for (key, value) in entries {
199 collect_non_tail_recursive_call_lines_expr(key, recursive, out);
200 collect_non_tail_recursive_call_lines_expr(value, recursive, out);
201 }
202 }
203 Expr::Constructor(_, maybe_arg) => {
204 if let Some(arg) = maybe_arg.as_deref() {
205 collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
206 }
207 }
208 Expr::InterpolatedStr(parts) => {
209 for part in parts {
210 if let StrPart::Parsed(expr) = part {
211 collect_non_tail_recursive_call_lines_expr(expr, recursive, out);
212 }
213 }
214 }
215 Expr::RecordCreate { fields, .. } => {
216 for (_, val) in fields {
217 collect_non_tail_recursive_call_lines_expr(val, recursive, out);
218 }
219 }
220 Expr::RecordUpdate { base, updates, .. } => {
221 collect_non_tail_recursive_call_lines_expr(base, recursive, out);
222 for (_, val) in updates {
223 collect_non_tail_recursive_call_lines_expr(val, recursive, out);
224 }
225 }
226 _ => {}
227 }
228}
229
230fn dotted_name(expr: &Spanned<Expr>) -> Option<String> {
231 match &expr.node {
232 Expr::Ident(name) => Some(name.clone()),
233 Expr::Attr(base, field) => {
234 let mut prefix = dotted_name(base)?;
235 prefix.push('.');
236 prefix.push_str(field);
237 Some(prefix)
238 }
239 _ => None,
240 }
241}
242
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    // Only the spec-aware test needs the typechecker mode, and that test is
    // itself gated on the `runtime` feature.
    #[cfg(feature = "runtime")]
    use crate::ir::TypecheckMode;
    use crate::parser::Parser;

    use super::*;

    /// Lexes and parses `src`, panicking on any front-end error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        crate::ir::pipeline::tco(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].fn_name, "fib");
        assert_eq!(warnings[0].recursive_calls, 2);
        // Both recursive calls sit on the same source line, at or after the
        // definition line.
        assert_eq!(warnings[0].callsite_lines.len(), 2);
        assert_eq!(warnings[0].callsite_lines[0], warnings[0].callsite_lines[1]);
        assert!(warnings[0].callsite_lines[0] >= warnings[0].line);
        assert_eq!(
            warnings[0].message,
            "non-tail recursion in 'fib' — 2 recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec"
        );
    }

    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        crate::ir::pipeline::tco(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        crate::ir::pipeline::tco(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    // `collect_non_tail_recursion_warnings_with_sigs` only exists with the
    // `runtime` feature, so this test must carry the same gate or the test
    // build fails without it.
    #[cfg(feature = "runtime")]
    #[test]
    fn skips_canonical_spec_functions() {
        let src = r#"
fn fib(n: Int) -> Int
    fibSpec(n)

fn fibSpec(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fibSpec(n - 1) + fibSpec(n - 2)

verify fib law fibSpec
    given n: Int = [0, 1, 2, 3]
    fib(n) => fibSpec(n)
"#;
        let mut items = parse(src);
        crate::ir::pipeline::tco(&mut items);
        let tc = crate::ir::pipeline::typecheck(&items, &TypecheckMode::Full { base_dir: None });

        let warnings = collect_non_tail_recursion_warnings_with_sigs(&items, &tc.fn_sigs);
        assert!(
            warnings.is_empty(),
            "expected spec function warning to be suppressed, got {warnings:?}"
        );
    }
}