1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5use crate::verify_law::canonical_spec_ref;
6
/// A diagnostic for a function whose recursion is not fully tail-recursive.
///
/// Produced by `collect_non_tail_recursion_warnings` and
/// `collect_non_tail_recursion_warnings_with_sigs`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NonTailRecursionWarning {
    /// Name of the offending function definition.
    pub fn_name: String,
    /// Line of the function definition (copied from the definition's `line` field).
    pub line: usize,
    /// Number of callsites into the function's recursion group that are still
    /// ordinary (non-`TailCall`) calls.
    pub recursive_calls: usize,
    /// Human-readable warning text, including the function name and callsite count.
    pub message: String,
}
14
15pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
16 collect_non_tail_recursion_warnings_in(items, None)
17}
18
19pub fn collect_non_tail_recursion_warnings_with_sigs(
20 items: &[TopLevel],
21 fn_sigs: &crate::verify_law::FnSigMap,
22) -> Vec<NonTailRecursionWarning> {
23 collect_non_tail_recursion_warnings_in(items, Some(fn_sigs))
24}
25
26fn collect_non_tail_recursion_warnings_in(
27 items: &[TopLevel],
28 fn_sigs: Option<&crate::verify_law::FnSigMap>,
29) -> Vec<NonTailRecursionWarning> {
30 let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
31 for scc in call_graph::find_tco_groups(items) {
32 for name in &scc {
33 fn_to_scc.insert(name.clone(), scc.clone());
34 }
35 }
36 let spec_fns = collect_canonical_spec_functions(items, fn_sigs);
37
38 let mut warnings = Vec::new();
39 for item in items {
40 let TopLevel::FnDef(fd) = item else {
41 continue;
42 };
43 if spec_fns.contains(&fd.name) {
44 continue;
45 }
46 let Some(scc_members) = fn_to_scc.get(&fd.name) else {
47 continue;
48 };
49 let recursive_calls = count_non_tail_recursive_calls_body(&fd.body, scc_members);
50 if recursive_calls == 0 {
51 continue;
52 }
53 warnings.push(NonTailRecursionWarning {
54 fn_name: fd.name.clone(),
55 line: fd.line,
56 recursive_calls,
57 message: format!(
58 "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
59 fd.name, recursive_calls
60 ),
61 });
62 }
63 warnings
64}
65
66fn collect_canonical_spec_functions(
67 items: &[TopLevel],
68 fn_sigs: Option<&crate::verify_law::FnSigMap>,
69) -> HashSet<String> {
70 let Some(fn_sigs) = fn_sigs else {
71 return HashSet::new();
72 };
73
74 items
75 .iter()
76 .filter_map(|item| match item {
77 TopLevel::Verify(v) => match &v.kind {
78 crate::ast::VerifyKind::Law(law) => canonical_spec_ref(&v.fn_name, law, fn_sigs)
79 .map(|spec_ref| spec_ref.spec_fn_name),
80 crate::ast::VerifyKind::Cases => None,
81 },
82 _ => None,
83 })
84 .collect()
85}
86
87fn count_non_tail_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>) -> usize {
88 body.stmts()
89 .iter()
90 .map(|stmt| count_non_tail_recursive_calls_stmt(stmt, recursive))
91 .sum()
92}
93
94fn count_non_tail_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>) -> usize {
95 match stmt {
96 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
97 count_non_tail_recursive_calls_expr(expr, recursive)
98 }
99 }
100}
101
102fn count_non_tail_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>) -> usize {
103 match expr {
104 Expr::FnCall(func, args) => {
105 let mut count = 0;
106 if let Some(callee) = dotted_name(func.as_ref())
107 && recursive.contains(&callee)
108 {
109 count += 1;
110 }
111 count
112 + count_non_tail_recursive_calls_expr(func, recursive)
113 + args
114 .iter()
115 .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
116 .sum::<usize>()
117 }
118 Expr::TailCall(boxed) => boxed
119 .1
120 .iter()
121 .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
122 .sum(),
123 Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
124 count_non_tail_recursive_calls_expr(obj, recursive)
125 }
126 Expr::BinOp(_, left, right) => {
127 count_non_tail_recursive_calls_expr(left, recursive)
128 + count_non_tail_recursive_calls_expr(right, recursive)
129 }
130 Expr::Match { subject, arms, .. } => {
131 count_non_tail_recursive_calls_expr(subject, recursive)
132 + arms
133 .iter()
134 .map(|arm| count_non_tail_recursive_calls_expr(&arm.body, recursive))
135 .sum::<usize>()
136 }
137 Expr::List(items) | Expr::Tuple(items) => items
138 .iter()
139 .map(|item| count_non_tail_recursive_calls_expr(item, recursive))
140 .sum(),
141 Expr::MapLiteral(entries) => entries
142 .iter()
143 .map(|(key, value)| {
144 count_non_tail_recursive_calls_expr(key, recursive)
145 + count_non_tail_recursive_calls_expr(value, recursive)
146 })
147 .sum(),
148 Expr::Constructor(_, maybe_arg) => maybe_arg
149 .as_deref()
150 .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
151 .unwrap_or(0),
152 Expr::InterpolatedStr(parts) => parts
153 .iter()
154 .map(|part| match part {
155 StrPart::Literal(_) => 0,
156 StrPart::Parsed(expr) => count_non_tail_recursive_calls_expr(expr, recursive),
157 })
158 .sum(),
159 Expr::RecordCreate { fields, .. } => fields
160 .iter()
161 .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
162 .sum(),
163 Expr::RecordUpdate { base, updates, .. } => {
164 count_non_tail_recursive_calls_expr(base, recursive)
165 + updates
166 .iter()
167 .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
168 .sum::<usize>()
169 }
170 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => 0,
171 }
172}
173
174fn dotted_name(expr: &Expr) -> Option<String> {
175 match expr {
176 Expr::Ident(name) => Some(name.clone()),
177 Expr::Attr(base, field) => {
178 let mut prefix = dotted_name(base)?;
179 prefix.push('.');
180 prefix.push_str(field);
181 Some(prefix)
182 }
183 _ => None,
184 }
185}
186
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    use crate::types::checker::run_type_check_full;
    use crate::{parser::Parser, tco};

    use super::*;

    /// Lexes and parses `src` into top-level items, panicking on any lex or parse error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    // Tree recursion: neither `fib` call is in tail position, so both callsites
    // must be reported after the TCO pass, with the exact message text.
    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].fn_name, "fib");
        assert_eq!(warnings[0].recursive_calls, 2);
        assert_eq!(
            warnings[0].message,
            "non-tail recursion in 'fib' — 2 recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec"
        );
    }

    // Accumulator-style self recursion: the only recursive call is in tail
    // position, so nothing should remain after TCO.
    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    // Mutual recursion between two functions, with every cross call in tail position.
    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    // `fibSpec` is tree-recursive, but it is the spec side of a `verify ... law`
    // block. With type-checker signatures supplied, it must be exempted.
    #[test]
    fn skips_canonical_spec_functions() {
        let src = r#"
fn fib(n: Int) -> Int
    fibSpec(n)

fn fibSpec(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fibSpec(n - 1) + fibSpec(n - 2)

verify fib law fibSpec
    given n: Int = [0, 1, 2, 3]
    fib(n) => fibSpec(n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);
        let tc = run_type_check_full(&items, None);

        let warnings = collect_non_tail_recursion_warnings_with_sigs(&items, &tc.fn_sigs);
        assert!(
            warnings.is_empty(),
            "expected spec function warning to be suppressed, got {warnings:?}"
        );
    }
}