1use std::collections::HashSet;
13
14use crate::ast::*;
15use crate::call_graph;
16
17pub fn transform_program(items: &mut [TopLevel]) {
19 let groups = call_graph::find_tco_groups(items);
20 if groups.is_empty() {
21 return;
22 }
23
24 let mut fn_to_scc: std::collections::HashMap<String, &HashSet<String>> =
26 std::collections::HashMap::new();
27 for group in &groups {
28 for name in group {
29 fn_to_scc.insert(name.clone(), group);
30 }
31 }
32
33 for item in items.iter_mut() {
34 if let TopLevel::FnDef(fd) = item {
35 if let Some(scc_members) = fn_to_scc.get(&fd.name) {
36 transform_fn(fd, scc_members);
37 }
38 }
39 }
40}
41
42fn transform_fn(fd: &mut FnDef, scc_members: &HashSet<String>) {
43 let mut body = fd.body.as_ref().clone();
44 match &mut body {
45 FnBody::Expr(expr) => {
46 transform_tail_expr(expr, scc_members);
47 }
48 FnBody::Block(stmts) => {
49 if let Some(last) = stmts.last_mut() {
51 if let Stmt::Expr(expr) = last {
52 transform_tail_expr(expr, scc_members);
53 }
54 }
55 }
56 }
57 fd.body = std::rc::Rc::new(body);
58}
59
60fn transform_tail_expr(expr: &mut Expr, scc_members: &HashSet<String>) {
62 match expr {
63 Expr::FnCall(fn_expr, args) => {
65 if let Expr::Ident(name) = fn_expr.as_ref() {
66 if scc_members.contains(name) {
67 let name = name.clone();
68 let args = std::mem::take(args);
69 *expr = Expr::TailCall(Box::new((name, args)));
70 }
71 }
72 }
73 Expr::Match { arms, .. } => {
75 for arm in arms {
76 transform_tail_expr(&mut arm.body, scc_members);
77 }
78 }
79 _ => {}
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the lexer and parser over `src`, panicking on any front-end error.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lex = crate::lexer::Lexer::new(src);
        let token_stream = lex.tokenize().expect("lex failed");
        let mut p = crate::parser::Parser::new(token_stream);
        p.parse().expect("parse failed")
    }

    /// Returns the top-level item at `index`, which must be a function definition.
    fn fn_def_at(items: &[TopLevel], index: usize) -> &FnDef {
        match &items[index] {
            TopLevel::FnDef(fd) => fd,
            _ => panic!("expected FnDef"),
        }
    }

    /// Returns the arms of the `match` that forms the function's tail.
    ///
    /// The `match` is either the whole expression body or the last statement of
    /// a block body. Panics otherwise.
    fn extract_match_arms(fd: &FnDef) -> &[MatchArm] {
        let body = fd.body.as_ref();
        if let FnBody::Block(stmts) = body {
            return match stmts.last() {
                Some(Stmt::Expr(Expr::Match { arms, .. })) => arms,
                _ => panic!("expected Match in block body, got {:?}", fd.body),
            };
        }
        match body {
            FnBody::Expr(Expr::Match { arms, .. }) => arms,
            other => panic!("expected Match body, got {:?}", other),
        }
    }

    /// Returns the callee name if `expr` is a `TailCall`, otherwise `None`.
    fn tail_call_target(expr: &Expr) -> Option<&str> {
        if let Expr::TailCall(call) = expr {
            Some(call.0.as_str())
        } else {
            None
        }
    }

    #[test]
    fn transforms_self_tail_call() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
    match n
        0 -> acc
        _ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        let arms = extract_match_arms(fn_def_at(&items, 0));
        // The base case returns a plain value and must stay as it was.
        assert!(!matches!(*arms[0].body, Expr::TailCall(..)));
        // The recursive arm is in tail position, so it becomes a TailCall.
        if let Expr::TailCall(boxed) = &*arms[1].body {
            let (callee, call_args) = &**boxed;
            assert_eq!(callee, "factorial");
            assert_eq!(call_args.len(), 2);
        } else {
            panic!("expected TailCall, got {:?}", arms[1].body);
        }
    }

    #[test]
    fn does_not_transform_non_tail_call() {
        let src = r#"
fn fib(n: Int) -> Int
    match n
        0 -> 0
        1 -> 1
        _ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        let arms = extract_match_arms(fn_def_at(&items, 0));
        // The recursive calls are operands of `+`, not in tail position.
        let recursive_arm_is_tail_call = matches!(*arms[2].body, Expr::TailCall(..));
        assert!(
            !recursive_arm_is_tail_call,
            "fib should NOT be tail-call transformed"
        );
    }

    #[test]
    fn transforms_mutual_recursion() {
        let src = r#"
fn isEven(n: Int) -> Bool
    match n
        0 -> true
        _ -> isOdd(n - 1)

fn isOdd(n: Int) -> Bool
    match n
        0 -> false
        _ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        let even_tail = &*extract_match_arms(fn_def_at(&items, 0))[1].body;
        match tail_call_target(even_tail) {
            Some(target) => assert_eq!(target, "isOdd"),
            None => panic!("expected TailCall to isOdd, got {:?}", even_tail),
        }

        let odd_tail = &*extract_match_arms(fn_def_at(&items, 1))[1].body;
        match tail_call_target(odd_tail) {
            Some(target) => assert_eq!(target, "isEven"),
            None => panic!("expected TailCall to isEven, got {:?}", odd_tail),
        }
    }
}