Skip to main content

aver/
tail_check.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Spanned, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5use crate::verify_law::canonical_spec_ref;
6
/// Diagnostic for a function whose recursion is not fully removed by tail-call optimization.
///
/// Produced by [`collect_non_tail_recursion_warnings`] and
/// [`collect_non_tail_recursion_warnings_with_sigs`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NonTailRecursionWarning {
    /// Name of the function that still contains non-tail recursive calls.
    pub fn_name: String,
    /// Source line of the function definition.
    pub line: usize,
    /// Number of non-tail recursive callsites.
    ///
    /// The collectors in this module set this to `callsite_lines.len()`.
    pub recursive_calls: usize,
    /// Source lines of the non-tail recursive callsites, in the order they are
    /// encountered while walking the function body.
    ///
    /// A line is listed once per callsite. It therefore repeats when several recursive
    /// calls share a line, e.g. `fib(n - 1) + fib(n - 2)`.
    pub callsite_lines: Vec<usize>,
    /// Human-readable diagnostic text.
    pub message: String,
}
16
17pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
18    collect_non_tail_recursion_warnings_in(items, None)
19}
20
21pub fn collect_non_tail_recursion_warnings_with_sigs(
22    items: &[TopLevel],
23    fn_sigs: &crate::verify_law::FnSigMap,
24) -> Vec<NonTailRecursionWarning> {
25    collect_non_tail_recursion_warnings_in(items, Some(fn_sigs))
26}
27
28fn collect_non_tail_recursion_warnings_in(
29    items: &[TopLevel],
30    fn_sigs: Option<&crate::verify_law::FnSigMap>,
31) -> Vec<NonTailRecursionWarning> {
32    let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
33    for scc in call_graph::find_tco_groups(items) {
34        for name in &scc {
35            fn_to_scc.insert(name.clone(), scc.clone());
36        }
37    }
38    let spec_fns = collect_canonical_spec_functions(items, fn_sigs);
39
40    let mut warnings = Vec::new();
41    for item in items {
42        let TopLevel::FnDef(fd) = item else {
43            continue;
44        };
45        if spec_fns.contains(&fd.name) {
46            continue;
47        }
48        let Some(scc_members) = fn_to_scc.get(&fd.name) else {
49            continue;
50        };
51        let callsite_lines: Vec<usize> =
52            collect_non_tail_recursive_call_lines_body(&fd.body, scc_members)
53                .into_iter()
54                .filter(|&ln| ln >= fd.line)
55                .collect();
56        if callsite_lines.is_empty() {
57            continue;
58        }
59        let recursive_calls = callsite_lines.len();
60        warnings.push(NonTailRecursionWarning {
61            fn_name: fd.name.clone(),
62            line: fd.line,
63            recursive_calls,
64            callsite_lines,
65            message: format!(
66                "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
67                fd.name, recursive_calls
68            ),
69        });
70    }
71    warnings
72}
73
74fn collect_canonical_spec_functions(
75    items: &[TopLevel],
76    fn_sigs: Option<&crate::verify_law::FnSigMap>,
77) -> HashSet<String> {
78    let Some(fn_sigs) = fn_sigs else {
79        return HashSet::new();
80    };
81
82    items
83        .iter()
84        .filter_map(|item| match item {
85            TopLevel::Verify(v) => match &v.kind {
86                crate::ast::VerifyKind::Law(law) => canonical_spec_ref(&v.fn_name, law, fn_sigs)
87                    .map(|spec_ref| spec_ref.spec_fn_name),
88                crate::ast::VerifyKind::Cases => None,
89            },
90            _ => None,
91        })
92        .collect()
93}
94
95fn collect_non_tail_recursive_call_lines_body(
96    body: &FnBody,
97    recursive: &HashSet<String>,
98) -> Vec<usize> {
99    let mut lines = Vec::new();
100    for stmt in body.stmts() {
101        collect_non_tail_recursive_call_lines_stmt(stmt, recursive, &mut lines);
102    }
103    lines
104}
105
106fn collect_non_tail_recursive_call_lines_stmt(
107    stmt: &Stmt,
108    recursive: &HashSet<String>,
109    out: &mut Vec<usize>,
110) {
111    match stmt {
112        Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
113            collect_non_tail_recursive_call_lines_expr(expr, recursive, out);
114        }
115    }
116}
117
118fn collect_non_tail_recursive_call_lines_expr(
119    expr: &Spanned<Expr>,
120    recursive: &HashSet<String>,
121    out: &mut Vec<usize>,
122) {
123    match &expr.node {
124        Expr::FnCall(func, args) => {
125            if let Some(callee) = dotted_name(func.as_ref())
126                && recursive.contains(&callee)
127            {
128                out.push(expr.line);
129            }
130            collect_non_tail_recursive_call_lines_expr(func, recursive, out);
131            for arg in args {
132                collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
133            }
134        }
135        Expr::TailCall(boxed) => {
136            for arg in &boxed.1 {
137                collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
138            }
139        }
140        Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
141            collect_non_tail_recursive_call_lines_expr(obj, recursive, out);
142        }
143        Expr::BinOp(_, left, right) => {
144            collect_non_tail_recursive_call_lines_expr(left, recursive, out);
145            collect_non_tail_recursive_call_lines_expr(right, recursive, out);
146        }
147        Expr::Match { subject, arms } => {
148            collect_non_tail_recursive_call_lines_expr(subject, recursive, out);
149            for arm in arms {
150                collect_non_tail_recursive_call_lines_expr(&arm.body, recursive, out);
151            }
152        }
153        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
154            for item in items {
155                collect_non_tail_recursive_call_lines_expr(item, recursive, out);
156            }
157        }
158        Expr::MapLiteral(entries) => {
159            for (key, value) in entries {
160                collect_non_tail_recursive_call_lines_expr(key, recursive, out);
161                collect_non_tail_recursive_call_lines_expr(value, recursive, out);
162            }
163        }
164        Expr::Constructor(_, maybe_arg) => {
165            if let Some(arg) = maybe_arg.as_deref() {
166                collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
167            }
168        }
169        Expr::InterpolatedStr(parts) => {
170            for part in parts {
171                if let StrPart::Parsed(expr) = part {
172                    collect_non_tail_recursive_call_lines_expr(expr, recursive, out);
173                }
174            }
175        }
176        Expr::RecordCreate { fields, .. } => {
177            for (_, val) in fields {
178                collect_non_tail_recursive_call_lines_expr(val, recursive, out);
179            }
180        }
181        Expr::RecordUpdate { base, updates, .. } => {
182            collect_non_tail_recursive_call_lines_expr(base, recursive, out);
183            for (_, val) in updates {
184                collect_non_tail_recursive_call_lines_expr(val, recursive, out);
185            }
186        }
187        _ => {}
188    }
189}
190
191fn dotted_name(expr: &Spanned<Expr>) -> Option<String> {
192    match &expr.node {
193        Expr::Ident(name) => Some(name.clone()),
194        Expr::Attr(base, field) => {
195            let mut prefix = dotted_name(base)?;
196            prefix.push('.');
197            prefix.push_str(field);
198            Some(prefix)
199        }
200        _ => None,
201    }
202}
203
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    use crate::types::checker::run_type_check_full;
    use crate::{parser::Parser, tco};

    use super::*;

    /// Lexes and parses `src`, panicking with a short reason if either stage fails.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let token_stream = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(token_stream);
        parser.parse().expect("parse failed")
    }

    /// Parses `src` and applies the TCO pass, producing the input the collectors inspect.
    fn parse_after_tco(src: &str) -> Vec<TopLevel> {
        let mut program = parse(src);
        tco::transform_program(&mut program);
        program
    }

    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let program = parse_after_tco(src);

        let warnings = collect_non_tail_recursion_warnings(&program);
        assert_eq!(warnings.len(), 1);
        let only = &warnings[0];
        assert_eq!(only.fn_name, "fib");
        assert_eq!(only.recursive_calls, 2);
        assert_eq!(
            only.message,
            "non-tail recursion in 'fib' — 2 recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec"
        );
    }

    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let program = parse_after_tco(src);

        assert!(collect_non_tail_recursion_warnings(&program).is_empty());
    }

    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let program = parse_after_tco(src);

        assert!(collect_non_tail_recursion_warnings(&program).is_empty());
    }

    #[test]
    fn skips_canonical_spec_functions() {
        let src = r#"
fn fib(n: Int) -> Int
    fibSpec(n)

fn fibSpec(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fibSpec(n - 1) + fibSpec(n - 2)

verify fib law fibSpec
    given n: Int = [0, 1, 2, 3]
    fib(n) => fibSpec(n)
"#;
        let program = parse_after_tco(src);
        let checked = run_type_check_full(&program, None);

        let warnings = collect_non_tail_recursion_warnings_with_sigs(&program, &checked.fn_sigs);
        assert!(
            warnings.is_empty(),
            "expected spec function warning to be suppressed, got {warnings:?}"
        );
    }
}