Skip to main content

aver/
tail_check.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Spanned, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5#[cfg(feature = "runtime")]
6use crate::verify_law::canonical_spec_ref;
7
/// A function that still contains recursive calls in non-tail position after
/// the tail-call optimization pass has run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NonTailRecursionWarning {
    /// Name of the offending function.
    pub fn_name: String,
    /// Source line of the function definition.
    pub line: usize,
    /// Number of non-tail recursive callsites (always `callsite_lines.len()`).
    pub recursive_calls: usize,
    /// Source lines of the non-tail recursive callsites.
    pub callsite_lines: Vec<usize>,
    /// Human-readable diagnostic text for this warning.
    pub message: String,
}
17
18pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
19    let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
20    for scc in call_graph::find_tco_groups(items) {
21        for name in &scc {
22            fn_to_scc.insert(name.clone(), scc.clone());
23        }
24    }
25
26    let mut warnings = Vec::new();
27    for item in items {
28        let TopLevel::FnDef(fd) = item else {
29            continue;
30        };
31        let Some(scc_members) = fn_to_scc.get(&fd.name) else {
32            continue;
33        };
34        let callsite_lines: Vec<usize> =
35            collect_non_tail_recursive_call_lines_body(&fd.body, scc_members)
36                .into_iter()
37                .filter(|&ln| ln >= fd.line)
38                .collect();
39        if callsite_lines.is_empty() {
40            continue;
41        }
42        let recursive_calls = callsite_lines.len();
43        warnings.push(NonTailRecursionWarning {
44            fn_name: fd.name.clone(),
45            line: fd.line,
46            recursive_calls,
47            callsite_lines,
48            message: format!(
49                "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
50                fd.name, recursive_calls
51            ),
52        });
53    }
54    warnings
55}
56
/// Like [`collect_non_tail_recursion_warnings`], but additionally exempts any
/// function that `canonical_spec_ref` resolves (via `fn_sigs`) as the spec
/// side of a `verify ... law` block.
#[cfg(feature = "runtime")]
pub fn collect_non_tail_recursion_warnings_with_sigs(
    items: &[TopLevel],
    fn_sigs: &crate::verify_law::FnSigMap,
) -> Vec<NonTailRecursionWarning> {
    let sigs = Some(fn_sigs);
    collect_non_tail_recursion_warnings_in(items, sigs)
}
64
/// Core of the signature-aware collector.
///
/// Behaves like `collect_non_tail_recursion_warnings`, except that functions
/// returned by `collect_canonical_spec_functions` are skipped. That set is
/// empty when `fn_sigs` is `None`.
#[cfg(feature = "runtime")]
fn collect_non_tail_recursion_warnings_in(
    items: &[TopLevel],
    fn_sigs: Option<&crate::verify_law::FnSigMap>,
) -> Vec<NonTailRecursionWarning> {
    // Own each group once and index it by member name. Mapping every member to
    // an owned clone of its group costs O(k^2) for a group of k functions.
    let groups: Vec<HashSet<String>> = call_graph::find_tco_groups(items).into_iter().collect();
    let fn_to_scc: HashMap<&str, &HashSet<String>> = groups
        .iter()
        .flat_map(|scc| scc.iter().map(move |name| (name.as_str(), scc)))
        .collect();
    let spec_fns = collect_canonical_spec_functions(items, fn_sigs);

    let mut warnings = Vec::new();
    for item in items {
        let TopLevel::FnDef(fd) = item else {
            continue;
        };
        // Spec functions are allowed to be written in the naive recursive form.
        if spec_fns.contains(&fd.name) {
            continue;
        }
        let Some(&scc_members) = fn_to_scc.get(fd.name.as_str()) else {
            continue;
        };
        let callsite_lines: Vec<usize> =
            collect_non_tail_recursive_call_lines_body(&fd.body, scc_members)
                .into_iter()
                .filter(|&ln| ln >= fd.line)
                .collect();
        if callsite_lines.is_empty() {
            continue;
        }
        let recursive_calls = callsite_lines.len();
        warnings.push(NonTailRecursionWarning {
            fn_name: fd.name.clone(),
            line: fd.line,
            recursive_calls,
            callsite_lines,
            message: format!(
                "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
                fd.name, recursive_calls
            ),
        });
    }
    warnings
}
111
/// Names of the functions that appear as canonical specs in `verify ... law`
/// blocks, as resolved by `canonical_spec_ref` against `fn_sigs`.
///
/// Returns an empty set when no signatures are available. `verify ... cases`
/// blocks never contribute.
#[cfg(feature = "runtime")]
fn collect_canonical_spec_functions(
    items: &[TopLevel],
    fn_sigs: Option<&crate::verify_law::FnSigMap>,
) -> HashSet<String> {
    let mut spec_names = HashSet::new();
    let Some(sigs) = fn_sigs else {
        return spec_names;
    };

    for item in items {
        let TopLevel::Verify(v) = item else {
            continue;
        };
        let law = match &v.kind {
            crate::ast::VerifyKind::Law(law) => law,
            crate::ast::VerifyKind::Cases => continue,
        };
        if let Some(spec_ref) = canonical_spec_ref(&v.fn_name, law, sigs) {
            spec_names.insert(spec_ref.spec_fn_name);
        }
    }
    spec_names
}
133
134fn collect_non_tail_recursive_call_lines_body(
135    body: &FnBody,
136    recursive: &HashSet<String>,
137) -> Vec<usize> {
138    let mut lines = Vec::new();
139    for stmt in body.stmts() {
140        collect_non_tail_recursive_call_lines_stmt(stmt, recursive, &mut lines);
141    }
142    lines
143}
144
145fn collect_non_tail_recursive_call_lines_stmt(
146    stmt: &Stmt,
147    recursive: &HashSet<String>,
148    out: &mut Vec<usize>,
149) {
150    match stmt {
151        Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
152            collect_non_tail_recursive_call_lines_expr(expr, recursive, out);
153        }
154    }
155}
156
157fn collect_non_tail_recursive_call_lines_expr(
158    expr: &Spanned<Expr>,
159    recursive: &HashSet<String>,
160    out: &mut Vec<usize>,
161) {
162    match &expr.node {
163        Expr::FnCall(func, args) => {
164            if let Some(callee) = dotted_name(func.as_ref())
165                && recursive.contains(&callee)
166            {
167                out.push(expr.line);
168            }
169            collect_non_tail_recursive_call_lines_expr(func, recursive, out);
170            for arg in args {
171                collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
172            }
173        }
174        Expr::TailCall(boxed) => {
175            for arg in &boxed.args {
176                collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
177            }
178        }
179        Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
180            collect_non_tail_recursive_call_lines_expr(obj, recursive, out);
181        }
182        Expr::BinOp(_, left, right) => {
183            collect_non_tail_recursive_call_lines_expr(left, recursive, out);
184            collect_non_tail_recursive_call_lines_expr(right, recursive, out);
185        }
186        Expr::Match { subject, arms } => {
187            collect_non_tail_recursive_call_lines_expr(subject, recursive, out);
188            for arm in arms {
189                collect_non_tail_recursive_call_lines_expr(&arm.body, recursive, out);
190            }
191        }
192        Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
193            for item in items {
194                collect_non_tail_recursive_call_lines_expr(item, recursive, out);
195            }
196        }
197        Expr::MapLiteral(entries) => {
198            for (key, value) in entries {
199                collect_non_tail_recursive_call_lines_expr(key, recursive, out);
200                collect_non_tail_recursive_call_lines_expr(value, recursive, out);
201            }
202        }
203        Expr::Constructor(_, maybe_arg) => {
204            if let Some(arg) = maybe_arg.as_deref() {
205                collect_non_tail_recursive_call_lines_expr(arg, recursive, out);
206            }
207        }
208        Expr::InterpolatedStr(parts) => {
209            for part in parts {
210                if let StrPart::Parsed(expr) = part {
211                    collect_non_tail_recursive_call_lines_expr(expr, recursive, out);
212                }
213            }
214        }
215        Expr::RecordCreate { fields, .. } => {
216            for (_, val) in fields {
217                collect_non_tail_recursive_call_lines_expr(val, recursive, out);
218            }
219        }
220        Expr::RecordUpdate { base, updates, .. } => {
221            collect_non_tail_recursive_call_lines_expr(base, recursive, out);
222            for (_, val) in updates {
223                collect_non_tail_recursive_call_lines_expr(val, recursive, out);
224            }
225        }
226        _ => {}
227    }
228}
229
230fn dotted_name(expr: &Spanned<Expr>) -> Option<String> {
231    match &expr.node {
232        Expr::Ident(name) => Some(name.clone()),
233        Expr::Attr(base, field) => {
234            let mut prefix = dotted_name(base)?;
235            prefix.push('.');
236            prefix.push_str(field);
237            Some(prefix)
238        }
239        _ => None,
240    }
241}
242
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    // Only the `runtime`-gated spec test below needs this; gating the import
    // avoids an unused-import warning when that feature is off.
    #[cfg(feature = "runtime")]
    use crate::ir::TypecheckMode;
    use crate::parser::Parser;

    use super::*;

    /// Lexes and parses `src`, panicking if either stage fails.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    /// Parses `src` and runs the TCO pass over it. The warnings describe what
    /// remains after tail-call optimization.
    fn parse_with_tco(src: &str) -> Vec<TopLevel> {
        let mut items = parse(src);
        crate::ir::pipeline::tco(&mut items);
        items
    }

    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let items = parse_with_tco(src);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].fn_name, "fib");
        assert_eq!(warnings[0].recursive_calls, 2);
        assert_eq!(
            warnings[0].message,
            "non-tail recursion in 'fib' — 2 recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec"
        );
    }

    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let items = parse_with_tco(src);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let items = parse_with_tco(src);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    // `collect_non_tail_recursion_warnings_with_sigs` is only compiled with the
    // `runtime` feature, so this test must be gated the same way. Otherwise the
    // whole test module fails to build without that feature.
    #[cfg(feature = "runtime")]
    #[test]
    fn skips_canonical_spec_functions() {
        let src = r#"
fn fib(n: Int) -> Int
    fibSpec(n)

fn fibSpec(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fibSpec(n - 1) + fibSpec(n - 2)

verify fib law fibSpec
    given n: Int = [0, 1, 2, 3]
    fib(n) => fibSpec(n)
"#;
        let items = parse_with_tco(src);
        let tc = crate::ir::pipeline::typecheck(&items, &TypecheckMode::Full { base_dir: None });

        let warnings = collect_non_tail_recursion_warnings_with_sigs(&items, &tc.fn_sigs);
        assert!(
            warnings.is_empty(),
            "expected spec function warning to be suppressed, got {warnings:?}"
        );
    }
}