use std::collections::HashSet;
use crate::ast::*;
use crate::call_graph;
/// Marks tail calls between functions of the same call group as `Expr::TailCall`.
///
/// `call_graph::find_tco_groups` supplies the groups of function names to act
/// on. Every top-level `FnDef` whose name appears in some group has its
/// tail-position calls to members of that group (itself included) rewritten in
/// place. Non-function items, and functions that belong to no group, are left
/// untouched. If there are no groups at all, the program is not walked.
pub fn transform_program(items: &mut [TopLevel]) {
    let groups = call_graph::find_tco_groups(items);
    if groups.is_empty() {
        return;
    }

    // Reverse index: function name -> the group that contains it. If a name
    // were listed in several groups, the last one encountered wins.
    let group_of: std::collections::HashMap<String, &HashSet<String>> = groups
        .iter()
        .flat_map(|group| group.iter().map(move |member| (member.clone(), group)))
        .collect();

    for item in items.iter_mut() {
        let TopLevel::FnDef(fd) = item else {
            continue;
        };
        if let Some(members) = group_of.get(&fd.name) {
            transform_fn(fd, members);
        }
    }
}
fn transform_fn(fd: &mut FnDef, scc_members: &HashSet<String>) {
let mut body = fd.body.as_ref().clone();
if let Some(expr) = body.tail_expr_mut() {
transform_tail_expr(expr, scc_members);
}
fd.body = std::sync::Arc::new(body);
}
/// Rewrites `spanned` in place when it is a call to one of `scc_members`.
///
/// The caller guarantees that `spanned` is in tail position. A call qualifies
/// only if its callee is a bare `Expr::Ident` naming a member of the set. It is
/// replaced by `Expr::TailCall` carrying the callee name and the original
/// argument list, which is moved out rather than cloned. Every arm body of a
/// `match` is itself in tail position, so the rewrite recurses into each arm.
/// All other expression kinds are left unchanged.
fn transform_tail_expr(spanned: &mut Spanned<Expr>, scc_members: &HashSet<String>) {
    let rewritten = match &mut spanned.node {
        Expr::Match { arms, .. } => {
            for arm in arms {
                transform_tail_expr(&mut arm.body, scc_members);
            }
            None
        }
        Expr::FnCall(callee, call_args) => match &callee.node {
            Expr::Ident(target) if scc_members.contains(target) => {
                let target = target.clone();
                Some(Expr::TailCall(Box::new((target, std::mem::take(call_args)))))
            }
            _ => None,
        },
        _ => None,
    };
    // Assign after the match so the borrows of `spanned.node` have ended.
    if let Some(tail_call) = rewritten {
        spanned.node = tail_call;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes and parses `src`, panicking if either stage fails.
    fn parse(src: &str) -> Vec<TopLevel> {
        let mut lex = crate::lexer::Lexer::new(src);
        let toks = lex.tokenize().expect("lex failed");
        crate::parser::Parser::new(toks)
            .parse()
            .expect("parse failed")
    }

    /// Returns the function definition at `items[idx]`; panics for any other item.
    fn fn_def_at(items: &[TopLevel], idx: usize) -> &FnDef {
        let TopLevel::FnDef(fd) = &items[idx] else {
            panic!("expected FnDef")
        };
        fd
    }

    /// Returns the arms of the `match` that forms the last statement of `fd`'s body.
    fn extract_match_arms(fd: &FnDef) -> &[MatchArm] {
        let Some(Stmt::Expr(spanned)) = fd.body.stmts().last() else {
            panic!("expected Match in block body, got {:?}", fd.body)
        };
        let Expr::Match { arms, .. } = &spanned.node else {
            panic!("expected Match in block body, got {:?}", fd.body)
        };
        arms
    }

    #[test]
    fn transforms_self_tail_call() {
        let src = r#"
fn factorial(n: Int, acc: Int) -> Int
match n
0 -> acc
_ -> factorial(n - 1, acc * n)
"#;
        let mut items = parse(src);
        transform_program(&mut items);
        let arms = extract_match_arms(fn_def_at(&items, 0));

        // The base case is a plain value, not a call.
        assert!(!matches!(arms[0].body.node, Expr::TailCall(..)));

        let Expr::TailCall(boxed) = &arms[1].body.node else {
            panic!("expected TailCall, got {:?}", arms[1].body.node)
        };
        let (name, args) = boxed.as_ref();
        assert_eq!(name, "factorial");
        assert_eq!(args.len(), 2);
    }

    #[test]
    fn does_not_transform_non_tail_call() {
        let src = r#"
fn fib(n: Int) -> Int
match n
0 -> 0
1 -> 1
_ -> fib(n - 1) + fib(n - 2)
"#;
        let mut items = parse(src);
        transform_program(&mut items);
        // The recursive calls sit under `+`, so neither one is in tail position.
        let recursive_arm = &extract_match_arms(fn_def_at(&items, 0))[2];
        assert!(
            !matches!(recursive_arm.body.node, Expr::TailCall(..)),
            "fib should NOT be tail-call transformed"
        );
    }

    #[test]
    fn transforms_mutual_recursion() {
        let src = r#"
fn isEven(n: Int) -> Bool
match n
0 -> true
_ -> isOdd(n - 1)
fn isOdd(n: Int) -> Bool
match n
0 -> false
_ -> isEven(n - 1)
"#;
        let mut items = parse(src);
        transform_program(&mut items);

        // Each function's recursive arm must tail-call the other one.
        for (idx, callee) in [(0, "isOdd"), (1, "isEven")] {
            let arms = extract_match_arms(fn_def_at(&items, idx));
            match &arms[1].body.node {
                Expr::TailCall(boxed) => assert_eq!(boxed.0, callee),
                other => panic!("expected TailCall to {}, got {:?}", callee, other),
            }
        }
    }
}