Skip to main content

aver/
tail_check.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5use crate::verify_law::canonical_spec_ref;
6
/// Diagnostic reported for a function that still contains recursive
/// callsites after tail-call optimization.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NonTailRecursionWarning {
    /// Name of the offending function definition.
    pub fn_name: String,
    /// Line of the function definition (copied from the definition's `line`).
    pub line: usize,
    /// How many non-tail callsites into the function's recursion group remain.
    pub recursive_calls: usize,
    /// Human-readable warning text.
    pub message: String,
}
14
15pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
16    collect_non_tail_recursion_warnings_in(items, None)
17}
18
19pub fn collect_non_tail_recursion_warnings_with_sigs(
20    items: &[TopLevel],
21    fn_sigs: &crate::verify_law::FnSigMap,
22) -> Vec<NonTailRecursionWarning> {
23    collect_non_tail_recursion_warnings_in(items, Some(fn_sigs))
24}
25
26fn collect_non_tail_recursion_warnings_in(
27    items: &[TopLevel],
28    fn_sigs: Option<&crate::verify_law::FnSigMap>,
29) -> Vec<NonTailRecursionWarning> {
30    let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
31    for scc in call_graph::find_tco_groups(items) {
32        for name in &scc {
33            fn_to_scc.insert(name.clone(), scc.clone());
34        }
35    }
36    let spec_fns = collect_canonical_spec_functions(items, fn_sigs);
37
38    let mut warnings = Vec::new();
39    for item in items {
40        let TopLevel::FnDef(fd) = item else {
41            continue;
42        };
43        if spec_fns.contains(&fd.name) {
44            continue;
45        }
46        let Some(scc_members) = fn_to_scc.get(&fd.name) else {
47            continue;
48        };
49        let recursive_calls = count_non_tail_recursive_calls_body(&fd.body, scc_members);
50        if recursive_calls == 0 {
51            continue;
52        }
53        warnings.push(NonTailRecursionWarning {
54            fn_name: fd.name.clone(),
55            line: fd.line,
56            recursive_calls,
57            message: format!(
58                "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec",
59                fd.name, recursive_calls
60            ),
61        });
62    }
63    warnings
64}
65
66fn collect_canonical_spec_functions(
67    items: &[TopLevel],
68    fn_sigs: Option<&crate::verify_law::FnSigMap>,
69) -> HashSet<String> {
70    let Some(fn_sigs) = fn_sigs else {
71        return HashSet::new();
72    };
73
74    items.iter()
75        .filter_map(|item| match item {
76            TopLevel::Verify(v) => match &v.kind {
77                crate::ast::VerifyKind::Law(law) => canonical_spec_ref(&v.fn_name, law, fn_sigs)
78                    .map(|spec_ref| spec_ref.spec_fn_name),
79                crate::ast::VerifyKind::Cases => None,
80            },
81            _ => None,
82        })
83        .collect()
84}
85
86fn count_non_tail_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>) -> usize {
87    body.stmts()
88        .iter()
89        .map(|stmt| count_non_tail_recursive_calls_stmt(stmt, recursive))
90        .sum()
91}
92
93fn count_non_tail_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>) -> usize {
94    match stmt {
95        Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
96            count_non_tail_recursive_calls_expr(expr, recursive)
97        }
98    }
99}
100
101fn count_non_tail_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>) -> usize {
102    match expr {
103        Expr::FnCall(func, args) => {
104            let mut count = 0;
105            if let Some(callee) = dotted_name(func.as_ref())
106                && recursive.contains(&callee)
107            {
108                count += 1;
109            }
110            count
111                + count_non_tail_recursive_calls_expr(func, recursive)
112                + args
113                    .iter()
114                    .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
115                    .sum::<usize>()
116        }
117        Expr::TailCall(boxed) => boxed
118            .1
119            .iter()
120            .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
121            .sum(),
122        Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
123            count_non_tail_recursive_calls_expr(obj, recursive)
124        }
125        Expr::BinOp(_, left, right) => {
126            count_non_tail_recursive_calls_expr(left, recursive)
127                + count_non_tail_recursive_calls_expr(right, recursive)
128        }
129        Expr::Match { subject, arms, .. } => {
130            count_non_tail_recursive_calls_expr(subject, recursive)
131                + arms
132                    .iter()
133                    .map(|arm| count_non_tail_recursive_calls_expr(&arm.body, recursive))
134                    .sum::<usize>()
135        }
136        Expr::List(items) | Expr::Tuple(items) => items
137            .iter()
138            .map(|item| count_non_tail_recursive_calls_expr(item, recursive))
139            .sum(),
140        Expr::MapLiteral(entries) => entries
141            .iter()
142            .map(|(key, value)| {
143                count_non_tail_recursive_calls_expr(key, recursive)
144                    + count_non_tail_recursive_calls_expr(value, recursive)
145            })
146            .sum(),
147        Expr::Constructor(_, maybe_arg) => maybe_arg
148            .as_deref()
149            .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
150            .unwrap_or(0),
151        Expr::InterpolatedStr(parts) => parts
152            .iter()
153            .map(|part| match part {
154                StrPart::Literal(_) => 0,
155                StrPart::Parsed(expr) => count_non_tail_recursive_calls_expr(expr, recursive),
156            })
157            .sum(),
158        Expr::RecordCreate { fields, .. } => fields
159            .iter()
160            .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
161            .sum(),
162        Expr::RecordUpdate { base, updates, .. } => {
163            count_non_tail_recursive_calls_expr(base, recursive)
164                + updates
165                    .iter()
166                    .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
167                    .sum::<usize>()
168        }
169        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => 0,
170    }
171}
172
173fn dotted_name(expr: &Expr) -> Option<String> {
174    match expr {
175        Expr::Ident(name) => Some(name.clone()),
176        Expr::Attr(base, field) => {
177            let mut prefix = dotted_name(base)?;
178            prefix.push('.');
179            prefix.push_str(field);
180            Some(prefix)
181        }
182        _ => None,
183    }
184}
185
/// Tests run real Aver source through the lexer, parser and TCO pass, then
/// inspect the warnings produced by this module.
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    use crate::{parser::Parser, tco};
    use crate::types::checker::run_type_check_full;

    use super::*;

    /// Lexes and parses `src`, panicking on any front-end error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    /// Naive `fib` adds the results of two self-calls, so neither call is in
    /// tail position; both must be reported after TCO, along with the exact
    /// warning text.
    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].fn_name, "fib");
        assert_eq!(warnings[0].recursive_calls, 2);
        assert_eq!(
            warnings[0].message,
            "non-tail recursion in 'fib' — 2 recursive callsite(s) remain after tail-call optimization; rewrite it to tail recursion or make it a spec"
        );
    }

    /// An accumulator-style self-recursive function has only a tail call, so
    /// nothing should be reported once TCO has run.
    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    /// Two functions that call each other only in tail position form one
    /// group; no warning is expected for either.
    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    /// `fibSpec` is non-tail recursive, but it is the spec named by the
    /// `verify fib law fibSpec` block. With type-checker signatures available
    /// its warning must be suppressed.
    #[test]
    fn skips_canonical_spec_functions() {
        let src = r#"
fn fib(n: Int) -> Int
    fibSpec(n)

fn fibSpec(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fibSpec(n - 1) + fibSpec(n - 2)

verify fib law fibSpec
    given n: Int = [0, 1, 2, 3]
    fib(n) => fibSpec(n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);
        let tc = run_type_check_full(&items, None);

        let warnings = collect_non_tail_recursion_warnings_with_sigs(&items, &tc.fn_sigs);
        assert!(warnings.is_empty(), "expected spec function warning to be suppressed, got {warnings:?}");
    }
}