1use std::collections::HashSet;
12
13use crate::ast::*;
14use crate::call_graph;
15
16pub fn transform_program(items: &mut [TopLevel]) {
18 let groups = call_graph::find_tco_groups(items);
19 if groups.is_empty() {
20 return;
21 }
22
23 let mut fn_to_scc: std::collections::HashMap<String, &HashSet<String>> =
25 std::collections::HashMap::new();
26 for group in &groups {
27 for name in group {
28 fn_to_scc.insert(name.clone(), group);
29 }
30 }
31
32 for item in items.iter_mut() {
33 if let TopLevel::FnDef(fd) = item
34 && let Some(scc_members) = fn_to_scc.get(&fd.name)
35 {
36 transform_fn(fd, scc_members);
37 }
38 }
39}
40
41fn transform_fn(fd: &mut FnDef, scc_members: &HashSet<String>) {
42 let mut body = fd.body.as_ref().clone();
43 if let Some(expr) = body.tail_expr_mut() {
45 transform_tail_expr(expr, scc_members);
46 }
47 fd.body = std::sync::Arc::new(body);
48}
49
50fn transform_tail_expr(spanned: &mut Spanned<Expr>, scc_members: &HashSet<String>) {
52 match &mut spanned.node {
53 Expr::FnCall(fn_expr, args) => {
55 if let Expr::Ident(name) = &fn_expr.node
56 && scc_members.contains(name)
57 {
58 let name = name.clone();
59 let args = std::mem::take(args);
60 spanned.node = Expr::TailCall(Box::new(TailCallData::new(name, args)));
61 }
62 }
63 Expr::Match { arms, .. } => {
65 for arm in arms {
66 transform_tail_expr(&mut arm.body, scc_members);
67 }
68 }
69 _ => {}
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes and parses `src` into top-level items, panicking if either the
    /// lexer or the parser reports an error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex failed");
        let mut parser = crate::parser::Parser::new(tokens);
        parser.parse().expect("parse failed")
    }

    /// Returns the arms of the `Match` expression that forms the last statement
    /// of `fd`'s body. Panics and prints the body if that statement is not a
    /// `Match`.
    fn extract_match_arms(fd: &FnDef) -> &[MatchArm] {
        // Every test program below is a function whose body ends in a bare `match`.
        if let Some(Stmt::Expr(spanned)) = fd.body.stmts().last()
            && let Expr::Match { arms, .. } = &spanned.node
        {
            arms
        } else {
            panic!("expected Match in block body, got {:?}", fd.body)
        }
    }

    // A direct self-call in tail position is rewritten into a TailCall.
    // NOTE(review): the embedded program's indentation was reconstructed
    // (4-space steps); confirm it matches what the parser expects.
    #[test]
    fn transforms_self_tail_call() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        let fd = match &items[0] {
            TopLevel::FnDef(fd) => fd,
            _ => panic!("expected FnDef"),
        };

        let arms = extract_match_arms(fd);
        // The base-case arm is not a call and must stay untouched.
        assert!(!matches!(arms[0].body.node, Expr::TailCall(..)));
        // The recursive arm must be rewritten, keeping the callee name and
        // both arguments.
        match &arms[1].body.node {
            Expr::TailCall(boxed) => {
                let TailCallData {
                    target: name, args, ..
                } = boxed.as_ref();
                assert_eq!(name, "factorial");
                assert_eq!(args.len(), 2);
            }
            other => panic!("expected TailCall, got {:?}", other),
        }
    }

    // Recursive calls that are operands of another expression (here `+`) are
    // not in tail position and must not be rewritten.
    #[test]
    fn does_not_transform_non_tail_call() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        let fd = match &items[0] {
            TopLevel::FnDef(fd) => fd,
            _ => panic!("expected FnDef"),
        };

        let arms = extract_match_arms(fd);
        assert!(
            // The third arm holds the `fib(..) + fib(..)` expression.
            !matches!(arms[2].body.node, Expr::TailCall(..)),
            "fib should NOT be tail-call transformed"
        );
    }

    // Two functions that call each other in tail position: each call is
    // rewritten with the *other* function as its target.
    #[test]
    fn transforms_mutual_recursion() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        // isEven is the first item; its recursive arm must target isOdd.
        let fd_even = match &items[0] {
            TopLevel::FnDef(fd) => fd,
            _ => panic!("expected FnDef"),
        };
        let arms_even = extract_match_arms(fd_even);
        match &arms_even[1].body.node {
            Expr::TailCall(boxed) => assert_eq!(boxed.target, "isOdd"),
            other => panic!("expected TailCall to isOdd, got {:?}", other),
        }

        // isOdd is the second item; its recursive arm must target isEven.
        let fd_odd = match &items[1] {
            TopLevel::FnDef(fd) => fd,
            _ => panic!("expected FnDef"),
        };
        let arms_odd = extract_match_arms(fd_odd);
        match &arms_odd[1].body.node {
            Expr::TailCall(boxed) => assert_eq!(boxed.target, "isEven"),
            other => panic!("expected TailCall to isEven, got {:?}", other),
        }
    }
}