Skip to main content

aver/
tail_check.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::ast::{Expr, FnBody, Stmt, StrPart, TopLevel};
4use crate::call_graph;
5
/// Diagnostic for a function that still contains ordinary (non-tail) recursive
/// calls after tail-call optimization has been applied.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NonTailRecursionWarning {
    /// Name of the reported function definition.
    pub fn_name: String,
    /// Line number recorded on the function definition.
    pub line: usize,
    /// How many recursive callsites were left as plain calls.
    pub recursive_calls: usize,
    /// Human-readable warning text.
    pub message: String,
}
13
14pub fn collect_non_tail_recursion_warnings(items: &[TopLevel]) -> Vec<NonTailRecursionWarning> {
15    let mut fn_to_scc: HashMap<String, HashSet<String>> = HashMap::new();
16    for scc in call_graph::find_tco_groups(items) {
17        for name in &scc {
18            fn_to_scc.insert(name.clone(), scc.clone());
19        }
20    }
21
22    let mut warnings = Vec::new();
23    for item in items {
24        let TopLevel::FnDef(fd) = item else {
25            continue;
26        };
27        let Some(scc_members) = fn_to_scc.get(&fd.name) else {
28            continue;
29        };
30        let recursive_calls = count_non_tail_recursive_calls_body(&fd.body, scc_members);
31        if recursive_calls == 0 {
32            continue;
33        }
34        warnings.push(NonTailRecursionWarning {
35            fn_name: fd.name.clone(),
36            line: fd.line,
37            recursive_calls,
38            message: format!(
39                "non-tail recursion in '{}' — {} recursive callsite(s) remain after tail-call optimization; consider accumulator pattern",
40                fd.name, recursive_calls
41            ),
42        });
43    }
44    warnings
45}
46
47fn count_non_tail_recursive_calls_body(body: &FnBody, recursive: &HashSet<String>) -> usize {
48    body.stmts()
49        .iter()
50        .map(|stmt| count_non_tail_recursive_calls_stmt(stmt, recursive))
51        .sum()
52}
53
54fn count_non_tail_recursive_calls_stmt(stmt: &Stmt, recursive: &HashSet<String>) -> usize {
55    match stmt {
56        Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
57            count_non_tail_recursive_calls_expr(expr, recursive)
58        }
59    }
60}
61
62fn count_non_tail_recursive_calls_expr(expr: &Expr, recursive: &HashSet<String>) -> usize {
63    match expr {
64        Expr::FnCall(func, args) => {
65            let mut count = 0;
66            if let Some(callee) = dotted_name(func.as_ref())
67                && recursive.contains(&callee)
68            {
69                count += 1;
70            }
71            count
72                + count_non_tail_recursive_calls_expr(func, recursive)
73                + args
74                    .iter()
75                    .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
76                    .sum::<usize>()
77        }
78        Expr::TailCall(boxed) => boxed
79            .1
80            .iter()
81            .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
82            .sum(),
83        Expr::Attr(obj, _) | Expr::ErrorProp(obj) => {
84            count_non_tail_recursive_calls_expr(obj, recursive)
85        }
86        Expr::BinOp(_, left, right) => {
87            count_non_tail_recursive_calls_expr(left, recursive)
88                + count_non_tail_recursive_calls_expr(right, recursive)
89        }
90        Expr::Match { subject, arms, .. } => {
91            count_non_tail_recursive_calls_expr(subject, recursive)
92                + arms
93                    .iter()
94                    .map(|arm| count_non_tail_recursive_calls_expr(&arm.body, recursive))
95                    .sum::<usize>()
96        }
97        Expr::List(items) | Expr::Tuple(items) => items
98            .iter()
99            .map(|item| count_non_tail_recursive_calls_expr(item, recursive))
100            .sum(),
101        Expr::MapLiteral(entries) => entries
102            .iter()
103            .map(|(key, value)| {
104                count_non_tail_recursive_calls_expr(key, recursive)
105                    + count_non_tail_recursive_calls_expr(value, recursive)
106            })
107            .sum(),
108        Expr::Constructor(_, maybe_arg) => maybe_arg
109            .as_deref()
110            .map(|arg| count_non_tail_recursive_calls_expr(arg, recursive))
111            .unwrap_or(0),
112        Expr::InterpolatedStr(parts) => parts
113            .iter()
114            .map(|part| match part {
115                StrPart::Literal(_) => 0,
116                StrPart::Parsed(expr) => count_non_tail_recursive_calls_expr(expr, recursive),
117            })
118            .sum(),
119        Expr::RecordCreate { fields, .. } => fields
120            .iter()
121            .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
122            .sum(),
123        Expr::RecordUpdate { base, updates, .. } => {
124            count_non_tail_recursive_calls_expr(base, recursive)
125                + updates
126                    .iter()
127                    .map(|(_, expr)| count_non_tail_recursive_calls_expr(expr, recursive))
128                    .sum::<usize>()
129        }
130        Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved(_) => 0,
131    }
132}
133
134fn dotted_name(expr: &Expr) -> Option<String> {
135    match expr {
136        Expr::Ident(name) => Some(name.clone()),
137        Expr::Attr(base, field) => {
138            let mut prefix = dotted_name(base)?;
139            prefix.push('.');
140            prefix.push_str(field);
141            Some(prefix)
142        }
143        _ => None,
144    }
145}
146
#[cfg(test)]
mod tests {
    use crate::ast::TopLevel;
    use crate::{parser::Parser, tco};

    use super::*;

    /// Lexes and parses `src` into top-level items, panicking on any error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    #[test]
    fn warns_for_recursive_calls_left_after_tco() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].fn_name, "fib");
        assert_eq!(warnings[0].recursive_calls, 2);
    }

    #[test]
    fn skips_pure_tail_recursion_after_tco() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    #[test]
    fn skips_mutual_tail_recursion_after_tco() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }

    // Each member of a mutually recursive group is reported separately, in source order.
    #[test]
    fn warns_for_each_member_of_mutual_non_tail_recursion() {
        let src = r#"
fn ping(n: Int) -> Int
    match n
        0 -> 0
        _ -> 1 + pong(n - 1)

fn pong(n: Int) -> Int
    match n
        0 -> 0
        _ -> 1 + ping(n - 1)
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        let names: Vec<&str> = warnings.iter().map(|w| w.fn_name.as_str()).collect();
        assert_eq!(names, ["ping", "pong"]);
        assert!(warnings.iter().all(|w| w.recursive_calls == 1));
    }

    #[test]
    fn skips_non_recursive_functions() {
        let src = r#"
fn sign(n: Int) -> Int
    match n
        0 -> 0
        _ -> 1
"#;
        let mut items = parse(src);
        tco::transform_program(&mut items);

        let warnings = collect_non_tail_recursion_warnings(&items);
        assert!(warnings.is_empty());
    }
}