1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Spanned, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5use crate::verify_law::canonical_spec_ref;
6
/// Diagnostic for a function that still contains calls to members of its own
/// recursion group which were not rewritten into tail calls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NonTailRecursionWarning {
    /// Name of the function the warning is about.
    pub fn_name: String,
    /// Line of the function definition itself (not of a callsite).
    pub line: usize,
    /// Number of remaining non-tail recursive callsites; the collector in this
    /// module sets it to `callsite_lines.len()`.
    pub recursive_calls: usize,
    /// Line of each remaining non-tail recursive call, in traversal order.
    /// One entry is pushed per call, so several calls on the same source line
    /// produce repeated entries.
    pub callsite_lines: Vec<usize>,
    /// Human-readable warning text.
    pub message: String,
}
16
17pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
18 collect_non_tail_recursion_warnings_in(items, None)
19}
20
21pub fn collect_non_tail_recursion_warnings_with_sigs(
22 items: &[TopLevel],
23 fn_sigs: &crate::verify_law::FnSigMap,
24) -> Vec<NonTailRecursionWarning> {
25 collect_non_tail_recursion_warnings_in(items, Some(fn_sigs))
26}
27
28fn collect_non_tail_recursion_warnings_in(
29 items: &[TopLevel],
30 fn_sigs: Option<&crate::verify_law::FnSigMap>,
31) -> Vec<NonTailRecursionWarning> {
32 let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
33 for scc in call_graph::find_tco_groups(items) {
34 for name in &scc {
35 fn_to_scc.insert(name.clone(), scc.clone());
36 }
37 }
38 let spec_fns = collect_canonical_spec_functions(items, fn_sigs);
39
40 let mut warnings = Vec::new();
41 for item in items {
42 let TopLevel::FnDef(fd) = item else {
43 continue;
44 };
45 if spec_fns.contains(&fd.name) {
46 continue;
47 }
48 let Some(scc_members) = fn_to_scc.get(&fd.name) else {
49 continue;
50 };
51 let callsite_lines: Vec<usize> =
52 collect_non_tail_recursive_call_lines_body(&fd.body, scc_members)
53 .into_iter()
54 .filter(|&ln| ln >= fd.line)
55 .collect();
56 if callsite_lines.is_empty() {
57 continue;
58 }
59 let recursive_calls = callsite_lines.len();
60 warnings.push(NonTailRecursionWarning {
61 fn_name: fd.name.clone(),
62 line: fd.line,
63 recursive_calls,
64 callsite_lines,
65 message: format!(
66 "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
67 fd.name, recursive_calls
68 ),
69 });
70 }
71 warnings
72}
73
74fn collect_canonical_spec_functions(
75 items: &[TopLevel],
76 fn_sigs: Option<&crate::verify_law::FnSigMap>,
77) -> HashSet<String> {
78 let Some(fn_sigs) = fn_sigs else {
79 return HashSet::new();
80 };
81
82 items
83 .iter()
84 .filter_map(|item| match item {
85 TopLevel::Verify(v) => match &v.kind {
86 crate::ast::VerifyKind::Law(law) => canonical_spec_ref(&v.fn_name, law, fn_sigs)
87 .map(|spec_ref| spec_ref.spec_fn_name),
88 crate::ast::VerifyKind::Cases => None,
89 },
90 _ => None,
91 })
92 .collect()
93}
94
95fn collect_non_tail_recursive_call_lines_body(
96 body: &FnBody,
97 recursive: &HashSet<String>,
98) -> Vec<usize> {
99 let mut lines = Vec::new();
100 for stmt in body.stmts() {
101 collect_non_tail_recursive_call_lines_stmt(stmt, recursive, &mut lines);
102 }
103 lines
104}
105
106fn collect_non_tail_recursive_call_lines_stmt(
107 stmt: &Stmt,
108 recursive: &HashSet<String>,
109 out: &mut Vec<usize>,
110) {
111 match stmt {
112 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
113 collect_non_tail_recursive_call_lines_expr(expr, recursive, out);
114 }
115 }
116}
117
118fn collect_non_tail_recursive_call_lines_expr(
119 expr: &Spanned<Expr>,
120 recursive: &HashSet<String>,
121 out: &mut Vec<usize>,
122) {
123 match &expr.node {
124 Expr::FnCall(func, args) => {
125 if let Some(callee) = dotted_name(func.as_ref())
126 && recursive.contains(&callee)
127 {
128 out.push(expr.line);
129 }
130 collect_non_tail_recursive_call_lines_expr(func, recursive, out);
131 for arg in args {
132 collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
133 }
134 }
135 Expr::TailCall(boxed) => {
136 for arg in &boxed.1 {
137 collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
138 }
139 }
140 Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
141 collect_non_tail_recursive_call_lines_expr(obj, recursive, out);
142 }
143 Expr::BinOp(_, left, right) => {
144 collect_non_tail_recursive_call_lines_expr(left, recursive, out);
145 collect_non_tail_recursive_call_lines_expr(right, recursive, out);
146 }
147 Expr::Match { subject, arms } => {
148 collect_non_tail_recursive_call_lines_expr(subject, recursive, out);
149 for arm in arms {
150 collect_non_tail_recursive_call_lines_expr(&arm.body, recursive, out);
151 }
152 }
153 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
154 for item in items {
155 collect_non_tail_recursive_call_lines_expr(item, recursive, out);
156 }
157 }
158 Expr::MapLiteral(entries) => {
159 for (key, value) in entries {
160 collect_non_tail_recursive_call_lines_expr(key, recursive, out);
161 collect_non_tail_recursive_call_lines_expr(value, recursive, out);
162 }
163 }
164 Expr::Constructor(_, maybe_arg) => {
165 if let Some(arg) = maybe_arg.as_deref() {
166 collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
167 }
168 }
169 Expr::InterpolatedStr(parts) => {
170 for part in parts {
171 if let StrPart::Parsed(expr) = part {
172 collect_non_tail_recursive_call_lines_expr(expr, recursive, out);
173 }
174 }
175 }
176 Expr::RecordCreate { fields, .. } => {
177 for (_, val) in fields {
178 collect_non_tail_recursive_call_lines_expr(val, recursive, out);
179 }
180 }
181 Expr::RecordUpdate { base, updates, .. } => {
182 collect_non_tail_recursive_call_lines_expr(base, recursive, out);
183 for (_, val) in updates {
184 collect_non_tail_recursive_call_lines_expr(val, recursive, out);
185 }
186 }
187 _ => {}
188 }
189}
190
191fn dotted_name(expr: &Spanned<Expr>) -> Option<String> {
192 match &expr.node {
193 Expr::Ident(name) => Some(name.clone()),
194 Expr::Attr(base, field) => {
195 let mut prefix = dotted_name(base)?;
196 prefix.push('.');
197 prefix.push_str(field);
198 Some(prefix)
199 }
200 _ => None,
201 }
202}
203
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    use crate::types::checker::run_type_check_full;
    use crate::{parser::Parser, tco};

    use super::*;

    /// Lexes and parses `src` into top-level items, panicking on any lex or
    /// parse error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    // Both recursive calls in `fib` are operands of `+`, so neither is in tail
    // position. After TCO both remain as plain calls, giving one warning that
    // counts two callsites.
    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].fn_name, "fib");
        assert_eq!(warnings[0].recursive_calls, 2);
        assert_eq!(
            warnings[0].message,
            "non-tail recursion in 'fib' — 2 recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec"
        );
    }

    // An accumulator-style self call in tail position: after TCO no plain
    // recursive call remains, so nothing is reported.
    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    // Mutual recursion where each cross call is in tail position: neither
    // function should be reported.
    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    // `fibSpec` is non-tail recursive but is the spec side of
    // `verify fib law fibSpec`. With type-checker signatures available, it
    // is exempt, so no warning is expected.
    #[test]
    fn skips_canonical_spec_functions() {
        let src = r#"
fn fib(n: Int) -> Int
    fibSpec(n)

fn fibSpec(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fibSpec(n - 1) + fibSpec(n - 2)

verify fib law fibSpec
    given n: Int = [0, 1, 2, 3]
    fib(n) => fibSpec(n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);
        let tc = run_type_check_full(&items, None);

        let warnings = collect_non_tail_recursion_warnings_with_sigs(&items, &tc.fn_sigs);
        assert!(
            warnings.is_empty(),
            "expected spec function warning to be suppressed, got {warnings:?}"
        );
    }
}