Skip to main content

aver/
tco.rs

//! Tail-call optimization transform pass.
//!
//! Runs after parsing, before type-checking. Uses the call-graph SCC analysis
//! to find groups of mutually-recursive functions, then rewrites tail-position
//! calls within each SCC from `FnCall` to `TailCall`.
//!
//! A call is in tail position if its result is the direct return value of the
//! function — no further computation wraps it. Specifically:
//!   - the last `Stmt::Expr` in `FnBody::Block`;
//!   - each arm body of a `match` that is itself in tail position.
11use std::collections::HashSet;
12
13use crate::ast::*;
14use crate::call_graph;
15
16/// Transform all eligible tail calls in the program.
17pub fn transform_program(items: &mut [TopLevel]) {
18    let groups = call_graph::find_tco_groups(items);
19    if groups.is_empty() {
20        return;
21    }
22
23    // Build a map: fn_name → set of SCC peers (including self)
24    let mut fn_to_scc: std::collections::HashMap<String, &HashSet<String>> =
25        std::collections::HashMap::new();
26    for group in &groups {
27        for name in group {
28            fn_to_scc.insert(name.clone(), group);
29        }
30    }
31
32    for item in items.iter_mut() {
33        if let TopLevel::FnDef(fd) = item
34            && let Some(scc_members) = fn_to_scc.get(&fd.name)
35        {
36            transform_fn(fd, scc_members);
37        }
38    }
39}
40
41fn transform_fn(fd: &mut FnDef, scc_members: &HashSet<String>) {
42    let mut body = fd.body.as_ref().clone();
43    // Only the last Stmt::Expr is in tail position
44    if let Some(expr) = body.tail_expr_mut() {
45        transform_tail_expr(expr, scc_members);
46    }
47    fd.body = std::sync::Arc::new(body);
48}
49
50/// Recursively transform an expression in tail position.
51fn transform_tail_expr(spanned: &mut Spanned<Expr>, scc_members: &HashSet<String>) {
52    match &mut spanned.node {
53        // Direct call: `f(args)` where f is Ident in SCC
54        Expr::FnCall(fn_expr, args) => {
55            if let Expr::Ident(name) = &fn_expr.node
56                && scc_members.contains(name)
57            {
58                let name = name.clone();
59                let args = std::mem::take(args);
60                spanned.node = Expr::TailCall(Box::new(TailCallData::new(name, args)));
61            }
62        }
63        // Match: each arm body is in tail position
64        Expr::Match { arms, .. } => {
65            for arm in arms {
66                transform_tail_expr(&mut arm.body, scc_members);
67            }
68        }
69        // Everything else is not a tail call
70        _ => {}
71    }
72}
73
74#[cfg(test)]
75mod tests {
76    use super::*;
77
78    fn parse(src: &str) -> Vec<TopLevel> {
79        let mut lexer = crate::lexer::Lexer::new(src);
80        let tokens = lexer.tokenize().expect("lex failed");
81        let mut parser = crate::parser::Parser::new(tokens);
82        parser.parse().expect("parse failed")
83    }
84
85    /// Helper: extract the match arms from a fn body.
86    /// The parser produces `Block([Expr(Match{subject, arms, ..})])` for indented match bodies.
87    fn extract_match_arms(fd: &FnDef) -> &[MatchArm] {
88        if let Some(Stmt::Expr(spanned)) = fd.body.stmts().last()
89            && let Expr::Match { arms, .. } = &spanned.node
90        {
91            arms
92        } else {
93            panic!("expected Match in block body, got {:?}", fd.body)
94        }
95    }
96
97    #[test]
98    fn transforms_self_tail_call() {
99        let src = r#"
100fn factorial(n: Int, acc: Int) -> Int
101    match n
102        0 -> acc
103        _ -> factorial(n - 1, acc * n)
104"#;
105        let mut items = parse(src);
106        transform_program(&mut items);
107
108        let fd = match &items[0] {
109            TopLevel::FnDef(fd) => fd,
110            _ => panic!("expected FnDef"),
111        };
112
113        let arms = extract_match_arms(fd);
114        // arm 0: literal 0 -> acc (unchanged)
115        assert!(!matches!(arms[0].body.node, Expr::TailCall(..)));
116        // arm 1: _ -> TailCall("factorial", ...)
117        match &arms[1].body.node {
118            Expr::TailCall(boxed) => {
119                let TailCallData {
120                    target: name, args, ..
121                } = boxed.as_ref();
122                assert_eq!(name, "factorial");
123                assert_eq!(args.len(), 2);
124            }
125            other => panic!("expected TailCall, got {:?}", other),
126        }
127    }
128
129    #[test]
130    fn does_not_transform_non_tail_call() {
131        let src = r#"
132fn fib(n: Int) -> Int
133    match n
134        0 -> 0
135        1 -> 1
136        _ -> fib(n - 1) + fib(n - 2)
137"#;
138        let mut items = parse(src);
139        transform_program(&mut items);
140
141        let fd = match &items[0] {
142            TopLevel::FnDef(fd) => fd,
143            _ => panic!("expected FnDef"),
144        };
145
146        let arms = extract_match_arms(fd);
147        // arm 2: _ -> fib(n-1) + fib(n-2) — BinOp, NOT TailCall
148        assert!(
149            !matches!(arms[2].body.node, Expr::TailCall(..)),
150            "fib should NOT be tail-call transformed"
151        );
152    }
153
154    #[test]
155    fn transforms_mutual_recursion() {
156        let src = r#"
157fn isEven(n: Int) -> Bool
158    match n
159        0 -> true
160        _ -> isOdd(n - 1)
161
162fn isOdd(n: Int) -> Bool
163    match n
164        0 -> false
165        _ -> isEven(n - 1)
166"#;
167        let mut items = parse(src);
168        transform_program(&mut items);
169
170        // Check isEven
171        let fd_even = match &items[0] {
172            TopLevel::FnDef(fd) => fd,
173            _ => panic!("expected FnDef"),
174        };
175        let arms_even = extract_match_arms(fd_even);
176        match &arms_even[1].body.node {
177            Expr::TailCall(boxed) => assert_eq!(boxed.target, "isOdd"),
178            other => panic!("expected TailCall to isOdd, got {:?}", other),
179        }
180
181        // Check isOdd
182        let fd_odd = match &items[1] {
183            TopLevel::FnDef(fd) => fd,
184            _ => panic!("expected FnDef"),
185        };
186        let arms_odd = extract_match_arms(fd_odd);
187        match &arms_odd[1].body.node {
188            Expr::TailCall(boxed) => assert_eq!(boxed.target, "isEven"),
189            other => panic!("expected TailCall to isEven, got {:?}", other),
190        }
191    }
192}