1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5use crate::verify_law::canonical_spec_ref;
6
/// Diagnostic describing a function that still contains recursive calls
/// which are not in tail position.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NonTailRecursionWarning {
    /// Name of the function definition the warning is about.
    pub fn_name: String,
    /// Line recorded on that function definition.
    pub line: usize,
    /// Number of recursive callsites counted in the function body.
    pub recursive_calls: usize,
    /// Fully formatted, human-readable warning text.
    pub message: String,
}
14
15pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
16 collect_non_tail_recursion_warnings_in(items, None)
17}
18
19pub fn collect_non_tail_recursion_warnings_with_sigs(
20 items: &[TopLevel],
21 fn_sigs: &crate::verify_law::FnSigMap,
22) -> Vec<NonTailRecursionWarning> {
23 collect_non_tail_recursion_warnings_in(items, Some(fn_sigs))
24}
25
26fn collect_non_tail_recursion_warnings_in(
27 items: &[TopLevel],
28 fn_sigs: Option<&crate::verify_law::FnSigMap>,
29) -> Vec<NonTailRecursionWarning> {
30 let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
31 for scc in call_graph::find_tco_groups(items) {
32 for name in &scc {
33 fn_to_scc.insert(name.clone(), scc.clone());
34 }
35 }
36 let spec_fns = collect_canonical_spec_functions(items, fn_sigs);
37
38 let mut warnings = Vec::new();
39 for item in items {
40 let TopLevel::FnDef(fd) = item else {
41 continue;
42 };
43 if spec_fns.contains(&fd.name) {
44 continue;
45 }
46 let Some(scc_members) = fn_to_scc.get(&fd.name) else {
47 continue;
48 };
49 let recursive_calls = count_non_tail_recursive_calls_body(&fd.body, scc_members);
50 if recursive_calls == 0 {
51 continue;
52 }
53 warnings.push(NonTailRecursionWarning {
54 fn_name: fd.name.clone(),
55 line: fd.line,
56 recursive_calls,
57 message: format!(
58 "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
59 fd.name, recursive_calls
60 ),
61 });
62 }
63 warnings
64}
65
66fn collect_canonical_spec_functions(
67 items: &[TopLevel],
68 fn_sigs: Option<&crate::verify_law::FnSigMap>,
69) -> HashSet<String> {
70 let Some(fn_sigs) = fn_sigs else {
71 return HashSet::new();
72 };
73
74 items.iter()
75 .filter_map(|item| match item {
76 TopLevel::Verify(v) => match &v.kind {
77 crate::ast::VerifyKind::Law(law) => canonical_spec_ref(&v.fn_name, law, fn_sigs)
78 .map(|spec_ref| spec_ref.spec_fn_name),
79 crate::ast::VerifyKind::Cases => None,
80 },
81 _ => None,
82 })
83 .collect()
84}
85
86fn count_non_tail_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>) -> usize {
87 body.stmts()
88 .iter()
89 .map(|stmt| count_non_tail_recursive_calls_stmt(stmt, recursive))
90 .sum()
91}
92
93fn count_non_tail_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>) -> usize {
94 match stmt {
95 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
96 count_non_tail_recursive_calls_expr(expr, recursive)
97 }
98 }
99}
100
101fn count_non_tail_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>) -> usize {
102 match expr {
103 Expr::FnCall(func, args) => {
104 let mut count = 0;
105 if let Some(callee) = dotted_name(func.as_ref())
106 && recursive.contains(&callee)
107 {
108 count += 1;
109 }
110 count
111 + count_non_tail_recursive_calls_expr(func, recursive)
112 + args
113 .iter()
114 .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
115 .sum::<usize>()
116 }
117 Expr::TailCall(boxed) => boxed
118 .1
119 .iter()
120 .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
121 .sum(),
122 Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
123 count_non_tail_recursive_calls_expr(obj, recursive)
124 }
125 Expr::BinOp(_, left, right) => {
126 count_non_tail_recursive_calls_expr(left, recursive)
127 + count_non_tail_recursive_calls_expr(right, recursive)
128 }
129 Expr::Match { subject, arms, .. } => {
130 count_non_tail_recursive_calls_expr(subject, recursive)
131 + arms
132 .iter()
133 .map(|arm| count_non_tail_recursive_calls_expr(&arm.body, recursive))
134 .sum::<usize>()
135 }
136 Expr::List(items) | Expr::Tuple(items) => items
137 .iter()
138 .map(|item| count_non_tail_recursive_calls_expr(item, recursive))
139 .sum(),
140 Expr::MapLiteral(entries) => entries
141 .iter()
142 .map(|(key, value)| {
143 count_non_tail_recursive_calls_expr(key, recursive)
144 + count_non_tail_recursive_calls_expr(value, recursive)
145 })
146 .sum(),
147 Expr::Constructor(_, maybe_arg) => maybe_arg
148 .as_deref()
149 .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
150 .unwrap_or(0),
151 Expr::InterpolatedStr(parts) => parts
152 .iter()
153 .map(|part| match part {
154 StrPart::Literal(_) => 0,
155 StrPart::Parsed(expr) => count_non_tail_recursive_calls_expr(expr, recursive),
156 })
157 .sum(),
158 Expr::RecordCreate { fields, .. } => fields
159 .iter()
160 .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
161 .sum(),
162 Expr::RecordUpdate { base, updates, .. } => {
163 count_non_tail_recursive_calls_expr(base, recursive)
164 + updates
165 .iter()
166 .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
167 .sum::<usize>()
168 }
169 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => 0,
170 }
171}
172
173fn dotted_name(expr: &Expr) -> Option<String> {
174 match expr {
175 Expr::Ident(name) => Some(name.clone()),
176 Expr::Attr(base, field) => {
177 let mut prefix = dotted_name(base)?;
178 prefix.push('.');
179 prefix.push_str(field);
180 Some(prefix)
181 }
182 _ => None,
183 }
184}
185
// Unit tests for the warning collectors. Each test parses a small fixture
// program, runs `tco::transform_program` on it, and then inspects the
// collected warnings.
// NOTE(review): the fixture programs presumably rely on indentation for
// nesting — keep their whitespace intact when editing; confirm against the lexer.
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    use crate::{parser::Parser, tco};
    use crate::types::checker::run_type_check_full;

    use super::*;

    /// Lexes and parses `src` into top-level items, panicking if either
    /// step fails.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    /// A doubly recursive `fib` is expected to produce one warning that
    /// reports two callsites and carries the exact message text.
    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].fn_name, "fib");
        assert_eq!(warnings[0].recursive_calls, 2);
        assert_eq!(
            warnings[0].message,
            "non-tail recursion in 'fib' — 2 recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec"
        );
    }

    /// An accumulator-style `factorial` whose only self-call is the arm
    /// result is expected to produce no warnings.
    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    /// Two functions that call each other only as arm results are expected
    /// to produce no warnings.
    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    /// `fibSpec` is doubly recursive, but it is the spec named by the
    /// `verify fib law fibSpec` item, so with signatures from the type
    /// checker supplied no warning is expected.
    #[test]
    fn skips_canonical_spec_functions() {
        let src = r#"
fn fib(n: Int) -> Int
    fibSpec(n)

fn fibSpec(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fibSpec(n - 1) + fibSpec(n - 2)

verify fib law fibSpec
    given n: Int = [0, 1, 2, 3]
    fib(n) => fibSpec(n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);
        // The signature map produced by the type checker is what enables
        // the spec exemption.
        let tc = run_type_check_full(&items, None);

        let warnings = collect_non_tail_recursion_warnings_with_sigs(&items, &tc.fn_sigs);
        assert!(warnings.is_empty(), "expected spec function warning to be suppressed, got {warnings:?}");
    }
}