Skip to main content

aver/
tco.rs

//! Tail-call optimization transform pass.
//!
//! Runs after parsing, before type-checking. Uses the call graph's SCC analysis
//! to find groups of mutually recursive functions, then rewrites tail-position
//! calls within each SCC from `FnCall` to `TailCall`.
//!
//! A call is in tail position if its result is the direct return value of the
//! function — no further computation wraps it. Specifically:
//!   - The expression body of `FnBody::Expr`
//!   - The final statement of `FnBody::Block`, when that statement is a `Stmt::Expr`
//!   - Each arm body of a `match` in tail position
//!
//! Only calls whose callee is a plain identifier naming a member of the caller's
//! SCC are rewritten; calls to functions outside that SCC remain `FnCall`.
12use std::collections::HashSet;
13
14use crate::ast::*;
15use crate::call_graph;
16
17/// Transform all eligible tail calls in the program.
18pub fn transform_program(items: &mut [TopLevel]) {
19    let groups = call_graph::find_tco_groups(items);
20    if groups.is_empty() {
21        return;
22    }
23
24    // Build a map: fn_name → set of SCC peers (including self)
25    let mut fn_to_scc: std::collections::HashMap<String, &HashSet<String>> =
26        std::collections::HashMap::new();
27    for group in &groups {
28        for name in group {
29            fn_to_scc.insert(name.clone(), group);
30        }
31    }
32
33    for item in items.iter_mut() {
34        if let TopLevel::FnDef(fd) = item {
35            if let Some(scc_members) = fn_to_scc.get(&fd.name) {
36                transform_fn(fd, scc_members);
37            }
38        }
39    }
40}
41
42fn transform_fn(fd: &mut FnDef, scc_members: &HashSet<String>) {
43    let mut body = fd.body.as_ref().clone();
44    match &mut body {
45        FnBody::Expr(expr) => {
46            transform_tail_expr(expr, scc_members);
47        }
48        FnBody::Block(stmts) => {
49            // Only the last Stmt::Expr is in tail position
50            if let Some(last) = stmts.last_mut() {
51                if let Stmt::Expr(expr) = last {
52                    transform_tail_expr(expr, scc_members);
53                }
54            }
55        }
56    }
57    fd.body = std::rc::Rc::new(body);
58}
59
60/// Recursively transform an expression in tail position.
61fn transform_tail_expr(expr: &mut Expr, scc_members: &HashSet<String>) {
62    match expr {
63        // Direct call: `f(args)` where f is Ident in SCC
64        Expr::FnCall(fn_expr, args) => {
65            if let Expr::Ident(name) = fn_expr.as_ref() {
66                if scc_members.contains(name) {
67                    let name = name.clone();
68                    let args = std::mem::take(args);
69                    *expr = Expr::TailCall(Box::new((name, args)));
70                }
71            }
72        }
73        // Match: each arm body is in tail position
74        Expr::Match { arms, .. } => {
75            for arm in arms {
76                transform_tail_expr(&mut arm.body, scc_members);
77            }
78        }
79        // Everything else is not a tail call
80        _ => {}
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// Lex and parse `src` into top-level items, panicking on any front-end error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = crate::parser::Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    /// Helper: return the function definition at `index`, panicking if that item
    /// is not a `TopLevel::FnDef`.
    fn fn_def_at(items: &[TopLevel], index: usize) -> &FnDef {
        match &items[index] {
            TopLevel::FnDef(fd) => fd,
            _ => panic!("expected FnDef at index {}", index),
        }
    }

    /// Helper: extract the match arms from a fn body.
    /// The parser produces `Block([Expr(Match{subject, arms, ..})])` for indented
    /// match bodies; a bare `Expr(Match{..})` body is accepted as well.
    fn extract_match_arms(fd: &FnDef) -> &[MatchArm] {
        match fd.body.as_ref() {
            FnBody::Expr(Expr::Match { arms, .. }) => arms,
            FnBody::Block(stmts) => {
                if let Some(Stmt::Expr(Expr::Match { arms, .. })) = stmts.last() {
                    arms
                } else {
                    panic!("expected Match in block body, got {:?}", fd.body)
                }
            }
            other => panic!("expected Match body, got {:?}", other),
        }
    }

    /// Helper: assert that `expr` is a `TailCall` whose target is `expected`.
    fn assert_tail_call_to(expr: &Expr, expected: &str) {
        match expr {
            Expr::TailCall(boxed) => assert_eq!(boxed.0, expected),
            other => panic!("expected TailCall to {}, got {:?}", expected, other),
        }
    }

    #[test]
    fn transforms_self_tail_call() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        let arms = extract_match_arms(fn_def_at(&items, 0));
        // arm 0: literal 0 -> acc (unchanged)
        assert!(!matches!(*arms[0].body, Expr::TailCall(..)));
        // arm 1: _ -> TailCall("factorial", [n - 1, acc * n])
        match &*arms[1].body {
            Expr::TailCall(boxed) => {
                let (name, args) = boxed.as_ref();
                assert_eq!(name, "factorial");
                assert_eq!(args.len(), 2);
            }
            other => panic!("expected TailCall, got {:?}", other),
        }
    }

    #[test]
    fn does_not_transform_non_tail_call() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        let arms = extract_match_arms(fn_def_at(&items, 0));
        // arm 2: _ -> fib(n-1) + fib(n-2) — BinOp, NOT TailCall
        assert!(
            !matches!(*arms[2].body, Expr::TailCall(..)),
            "fib should NOT be tail-call transformed"
        );
    }

    #[test]
    fn does_not_transform_tail_call_outside_scc() {
        let src = r#"
fn helper(n: Int) -> Int
    match n
        0 -> 0
        _ -> n

fn countdown(n: Int) -> Int
    match n
        0 -> helper(n)
        _ -> countdown(n - 1)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        let arms = extract_match_arms(fn_def_at(&items, 1));
        // arm 0: `helper` is not in countdown's SCC, so the call stays an ordinary call
        assert!(
            !matches!(*arms[0].body, Expr::TailCall(..)),
            "call to a function outside the SCC should NOT be tail-call transformed"
        );
        // arm 1: self-recursion in the same function is still rewritten
        assert_tail_call_to(&arms[1].body, "countdown");
    }

    #[test]
    fn transforms_mutual_recursion() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        // isEven's recursive arm tail-calls isOdd
        let arms_even = extract_match_arms(fn_def_at(&items, 0));
        assert_tail_call_to(&arms_even[1].body, "isOdd");

        // isOdd's recursive arm tail-calls isEven
        let arms_odd = extract_match_arms(fn_def_at(&items, 1));
        assert_tail_call_to(&arms_odd[1].body, "isEven");
    }
}