// aver/tail_check.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5use crate::verify_law::canonical_spec_ref;
6
/// Diagnostic emitted for a function whose recursion is still not in tail
/// position after the tail-call-optimization pass has run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NonTailRecursionWarning {
    /// Name of the offending function definition.
    pub fn_name: String,
    /// Line number recorded on the function definition.
    pub line: usize,
    /// How many calls to members of the function's recursion group remain
    /// as ordinary (non-tail) calls.
    pub recursive_calls: usize,
    /// Human-readable text describing the warning.
    pub message: String,
}
14
15pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
16    collect_non_tail_recursion_warnings_in(items, None)
17}
18
19pub fn collect_non_tail_recursion_warnings_with_sigs(
20    items: &[TopLevel],
21    fn_sigs: &crate::verify_law::FnSigMap,
22) -> Vec<NonTailRecursionWarning> {
23    collect_non_tail_recursion_warnings_in(items, Some(fn_sigs))
24}
25
26fn collect_non_tail_recursion_warnings_in(
27    items: &[TopLevel],
28    fn_sigs: Option<&crate::verify_law::FnSigMap>,
29) -> Vec<NonTailRecursionWarning> {
30    let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
31    for scc in call_graph::find_tco_groups(items) {
32        for name in &scc {
33            fn_to_scc.insert(name.clone(), scc.clone());
34        }
35    }
36    let spec_fns = collect_canonical_spec_functions(items, fn_sigs);
37
38    let mut warnings = Vec::new();
39    for item in items {
40        let TopLevel::FnDef(fd) = item else {
41            continue;
42        };
43        if spec_fns.contains(&fd.name) {
44            continue;
45        }
46        let Some(scc_members) = fn_to_scc.get(&fd.name) else {
47            continue;
48        };
49        let recursive_calls = count_non_tail_recursive_calls_body(&fd.body, scc_members);
50        if recursive_calls == 0 {
51            continue;
52        }
53        warnings.push(NonTailRecursionWarning {
54            fn_name: fd.name.clone(),
55            line: fd.line,
56            recursive_calls,
57            message: format!(
58                "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
59                fd.name, recursive_calls
60            ),
61        });
62    }
63    warnings
64}
65
66fn collect_canonical_spec_functions(
67    items: &[TopLevel],
68    fn_sigs: Option<&crate::verify_law::FnSigMap>,
69) -> HashSet<String> {
70    let Some(fn_sigs) = fn_sigs else {
71        return HashSet::new();
72    };
73
74    items
75        .iter()
76        .filter_map(|item| match item {
77            TopLevel::Verify(v) => match &v.kind {
78                crate::ast::VerifyKind::Law(law) => canonical_spec_ref(&v.fn_name, law, fn_sigs)
79                    .map(|spec_ref| spec_ref.spec_fn_name),
80                crate::ast::VerifyKind::Cases => None,
81            },
82            _ => None,
83        })
84        .collect()
85}
86
87fn count_non_tail_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>) -> usize {
88    body.stmts()
89        .iter()
90        .map(|stmt| count_non_tail_recursive_calls_stmt(stmt, recursive))
91        .sum()
92}
93
94fn count_non_tail_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>) -> usize {
95    match stmt {
96        Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
97            count_non_tail_recursive_calls_expr(expr, recursive)
98        }
99    }
100}
101
102fn count_non_tail_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>) -> usize {
103    match expr {
104        Expr::FnCall(func, args) => {
105            let mut count = 0;
106            if let Some(callee) = dotted_name(func.as_ref())
107                && recursive.contains(&callee)
108            {
109                count += 1;
110            }
111            count
112                + count_non_tail_recursive_calls_expr(func, recursive)
113                + args
114                    .iter()
115                    .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
116                    .sum::<usize>()
117        }
118        Expr::TailCall(boxed) => boxed
119            .1
120            .iter()
121            .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
122            .sum(),
123        Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
124            count_non_tail_recursive_calls_expr(obj, recursive)
125        }
126        Expr::BinOp(_, left, right) => {
127            count_non_tail_recursive_calls_expr(left, recursive)
128                + count_non_tail_recursive_calls_expr(right, recursive)
129        }
130        Expr::Match { subject, arms, .. } => {
131            count_non_tail_recursive_calls_expr(subject, recursive)
132                + arms
133                    .iter()
134                    .map(|arm| count_non_tail_recursive_calls_expr(&arm.body, recursive))
135                    .sum::<usize>()
136        }
137        Expr::List(items) | Expr::Tuple(items) => items
138            .iter()
139            .map(|item| count_non_tail_recursive_calls_expr(item, recursive))
140            .sum(),
141        Expr::MapLiteral(entries) => entries
142            .iter()
143            .map(|(key, value)| {
144                count_non_tail_recursive_calls_expr(key, recursive)
145                    + count_non_tail_recursive_calls_expr(value, recursive)
146            })
147            .sum(),
148        Expr::Constructor(_, maybe_arg) => maybe_arg
149            .as_deref()
150            .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
151            .unwrap_or(0),
152        Expr::InterpolatedStr(parts) => parts
153            .iter()
154            .map(|part| match part {
155                StrPart::Literal(_) => 0,
156                StrPart::Parsed(expr) => count_non_tail_recursive_calls_expr(expr, recursive),
157            })
158            .sum(),
159        Expr::RecordCreate { fields, .. } => fields
160            .iter()
161            .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
162            .sum(),
163        Expr::RecordUpdate { base, updates, .. } => {
164            count_non_tail_recursive_calls_expr(base, recursive)
165                + updates
166                    .iter()
167                    .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
168                    .sum::<usize>()
169        }
170        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => 0,
171    }
172}
173
174fn dotted_name(expr: &Expr) -> Option<String> {
175    match expr {
176        Expr::Ident(name) => Some(name.clone()),
177        Expr::Attr(base, field) => {
178            let mut prefix = dotted_name(base)?;
179            prefix.push('.');
180            prefix.push_str(field);
181            Some(prefix)
182        }
183        _ => None,
184    }
185}
186
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    use crate::types::checker::run_type_check_full;
    use crate::{parser::Parser, tco};

    use super::*;

    /// Lexes and parses `src`, panicking if either front-end stage fails.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    /// Parses `src`, applies the TCO pass, and collects warnings without
    /// signature information.
    fn warnings_after_tco(src: &str) -> Vec<NonTailRecursionWarning> {
        let mut items = parse(src);
        tco::transform_program(&mut items);
        collect_non_tail_recursion_warnings(&items)
    }

    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let warnings = warnings_after_tco(src);
        let [warning] = warnings.as_slice() else {
            panic!("expected exactly one warning, got {warnings:?}");
        };
        assert_eq!(warning.fn_name, "fib");
        assert_eq!(warning.recursive_calls, 2);
        assert_eq!(
            warning.message,
            "non-tail recursion in 'fib' — 2 recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec"
        );
    }

    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        assert!(warnings_after_tco(src).is_empty());
    }

    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        assert!(warnings_after_tco(src).is_empty());
    }

    #[test]
    fn skips_canonical_spec_functions() {
        let src = r#"
fn fib(n: Int) -> Int
    fibSpec(n)

fn fibSpec(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fibSpec(n - 1) + fibSpec(n - 2)

verify fib law fibSpec
    given n: Int = [0, 1, 2, 3]
    fib(n) => fibSpec(n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);
        let checked = run_type_check_full(&items, None);

        let warnings = collect_non_tail_recursion_warnings_with_sigs(&items, &checked.fn_sigs);
        assert!(
            warnings.is_empty(),
            "expected spec function warning to be suppressed, got {warnings:?}"
        );
    }
}