use crate::ast::{BinOp, Expr, Program, Stmt, UnaryOp};
/// Configuration for the AST optimizer: one switch per optimization pass.
///
/// `Optimizer::new()` (and `Default`) enables every pass; construct the
/// struct directly to run only a subset.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Optimizer {
    /// Rewrite self tail calls in top-level functions into loops.
    pub tail_recursion: bool,
    /// Evaluate arithmetic and unary operations whose operands are literals.
    pub constant_folding: bool,
    /// Drop `while false` loops and prune `if` branches with literal conditions.
    pub dead_code_elimination: bool,
}
impl Optimizer {
    /// Creates an optimizer with every pass enabled.
    pub fn new() -> Self {
        Optimizer {
            tail_recursion: true,
            constant_folding: true,
            dead_code_elimination: true,
        }
    }

    /// Returns an optimized copy of `program`; the input is not modified.
    ///
    /// Enabled passes run in a fixed order: constant folding, then dead-code
    /// elimination (so literal conditions produced by folding, e.g. a negated
    /// boolean literal, can be pruned), then tail-recursion elimination.
    pub fn optimize_program(&self, program: &Program) -> Program {
        let mut optimized = program.clone();
        if self.constant_folding {
            optimized = self.fold_constants(optimized);
        }
        if self.dead_code_elimination {
            optimized = self.eliminate_dead_code(optimized);
        }
        if self.tail_recursion {
            optimized = self.optimize_tail_recursion(optimized);
        }
        optimized
    }

    // ----- constant folding -------------------------------------------------

    /// Constant-folding pass over the top-level statements.
    fn fold_constants(&self, program: Program) -> Program {
        program.into_iter().map(|stmt| self.fold_stmt(stmt)).collect()
    }

    /// Folds every statement of a nested block.
    fn fold_body(&self, body: Vec<Stmt>) -> Vec<Stmt> {
        body.into_iter().map(|stmt| self.fold_stmt(stmt)).collect()
    }

    /// Folds the expressions held by `stmt`, recursing into nested blocks.
    /// Statement kinds not listed here are returned unchanged.
    fn fold_stmt(&self, stmt: Stmt) -> Stmt {
        match stmt {
            Stmt::Set { name, value } => Stmt::Set {
                name,
                value: self.fold_expr(value),
            },
            Stmt::FuncDef { name, params, body } => Stmt::FuncDef {
                name,
                params,
                body: self.fold_body(body),
            },
            Stmt::GeneratorDef { name, params, body } => Stmt::GeneratorDef {
                name,
                params,
                body: self.fold_body(body),
            },
            Stmt::Return(expr) => Stmt::Return(self.fold_expr(expr)),
            Stmt::Yield(expr) => Stmt::Yield(self.fold_expr(expr)),
            Stmt::While { condition, body } => Stmt::While {
                condition: self.fold_expr(condition),
                body: self.fold_body(body),
            },
            Stmt::For {
                var,
                iterable,
                body,
            } => Stmt::For {
                var,
                iterable: self.fold_expr(iterable),
                body: self.fold_body(body),
            },
            Stmt::ForIndexed {
                index_var,
                value_var,
                iterable,
                body,
            } => Stmt::ForIndexed {
                index_var,
                value_var,
                iterable: self.fold_expr(iterable),
                body: self.fold_body(body),
            },
            Stmt::Expression(expr) => Stmt::Expression(self.fold_expr(expr)),
            other => other,
        }
    }

    /// Folds `expr` bottom-up.
    ///
    /// * Binary ops on two number literals are evaluated with
    ///   [`Self::eval_const_binary`] when it returns a value.
    /// * Unary minus on a number literal is negated; `Not` on a number literal
    ///   yields `true` exactly when the number is `0.0`; `Not` on a boolean
    ///   literal is inverted.
    /// * `if` expressions have their condition and every branch folded.
    ///
    /// Everything else is rebuilt with folded children or returned unchanged.
    fn fold_expr(&self, expr: Expr) -> Expr {
        match expr {
            Expr::Binary { left, op, right } => {
                let left = self.fold_expr(*left);
                let right = self.fold_expr(*right);
                if let (Expr::Number(l), Expr::Number(r)) = (&left, &right)
                    && let Some(result) = Self::eval_const_binary(*l, &op, *r)
                {
                    return Expr::Number(result);
                }
                Expr::Binary {
                    left: Box::new(left),
                    op,
                    right: Box::new(right),
                }
            }
            Expr::Unary { op, expr } => {
                let expr = self.fold_expr(*expr);
                if let Expr::Number(n) = expr {
                    match op {
                        UnaryOp::Minus => return Expr::Number(-n),
                        UnaryOp::Not => return Expr::Boolean(n == 0.0),
                    }
                }
                if let (UnaryOp::Not, Expr::Boolean(b)) = (&op, &expr) {
                    return Expr::Boolean(!b);
                }
                Expr::Unary {
                    op,
                    expr: Box::new(expr),
                }
            }
            Expr::Call { func, args } => Expr::Call {
                func: Box::new(self.fold_expr(*func)),
                args: args.into_iter().map(|e| self.fold_expr(e)).collect(),
            },
            Expr::Array(elements) => {
                Expr::Array(elements.into_iter().map(|e| self.fold_expr(e)).collect())
            }
            Expr::Index { object, index } => Expr::Index {
                object: Box::new(self.fold_expr(*object)),
                index: Box::new(self.fold_expr(*index)),
            },
            // Previously `if` fell through unchanged, so nothing inside an
            // `if` (condition or branches) was ever folded.
            Expr::If {
                condition,
                then_branch,
                elif_branches,
                else_branch,
            } => Expr::If {
                condition: Box::new(self.fold_expr(*condition)),
                then_branch: self.fold_body(then_branch),
                elif_branches: elif_branches
                    .into_iter()
                    .map(|(cond, body)| (self.fold_expr(cond), self.fold_body(body)))
                    .collect(),
                else_branch: else_branch.map(|body| self.fold_body(body)),
            },
            other => other,
        }
    }

    /// Evaluates `left op right` for the arithmetic operators.
    ///
    /// Returns `None` for every other operator and for division or modulo by
    /// zero, leaving those expressions for the runtime to evaluate.
    fn eval_const_binary(left: f64, op: &BinOp, right: f64) -> Option<f64> {
        match op {
            BinOp::Add => Some(left + right),
            BinOp::Subtract => Some(left - right),
            BinOp::Multiply => Some(left * right),
            BinOp::Divide if right != 0.0 => Some(left / right),
            BinOp::Modulo if right != 0.0 => Some(left % right),
            _ => None,
        }
    }

    // ----- dead-code elimination --------------------------------------------

    /// Dead-code elimination pass over the top-level statements.
    fn eliminate_dead_code(&self, program: Program) -> Program {
        program
            .into_iter()
            .filter_map(|stmt| self.eliminate_dead_stmt(stmt))
            .collect()
    }

    /// Applies dead-code elimination to every statement of a nested block.
    fn eliminate_dead_body(&self, body: Vec<Stmt>) -> Vec<Stmt> {
        body.into_iter()
            .filter_map(|stmt| self.eliminate_dead_stmt(stmt))
            .collect()
    }

    /// Returns `None` when `stmt` can be removed outright (`while false`),
    /// otherwise the statement with dead code pruned from nested blocks and
    /// from `if` expressions used as statements.
    fn eliminate_dead_stmt(&self, stmt: Stmt) -> Option<Stmt> {
        let kept = match stmt {
            // A loop whose condition is literally `false` never runs its body.
            Stmt::While {
                condition: Expr::Boolean(false),
                ..
            } => return None,
            Stmt::While { condition, body } => Stmt::While {
                condition,
                body: self.eliminate_dead_body(body),
            },
            Stmt::For {
                var,
                iterable,
                body,
            } => Stmt::For {
                var,
                iterable,
                body: self.eliminate_dead_body(body),
            },
            Stmt::ForIndexed {
                index_var,
                value_var,
                iterable,
                body,
            } => Stmt::ForIndexed {
                index_var,
                value_var,
                iterable,
                body: self.eliminate_dead_body(body),
            },
            Stmt::FuncDef { name, params, body } => Stmt::FuncDef {
                name,
                params,
                body: self.eliminate_dead_body(body),
            },
            Stmt::GeneratorDef { name, params, body } => Stmt::GeneratorDef {
                name,
                params,
                body: self.eliminate_dead_body(body),
            },
            Stmt::Expression(expr) => Stmt::Expression(self.eliminate_dead_expr(expr)),
            other => other,
        };
        Some(kept)
    }

    /// Prunes `if` branches whose conditions are boolean literals.
    ///
    /// * `if true` keeps only its then-branch.
    /// * `if false` promotes its first `elif` to the primary condition (and is
    ///   re-pruned); with no `elif` it keeps only the else-branch, or becomes
    ///   `Expr::Null` when there is none.
    ///   NOTE(review): this assumes an `if` with no taken branch evaluates to
    ///   null, which the pass has always relied on; confirm against the
    ///   interpreter.
    /// * Otherwise `elif false` branches are dropped, and an `elif true` turns
    ///   into the else-branch (later branches are unreachable).
    ///
    /// Surviving branches are pruned recursively. Non-`if` expressions are
    /// returned unchanged.
    fn eliminate_dead_expr(&self, expr: Expr) -> Expr {
        match expr {
            Expr::If {
                condition,
                then_branch,
                elif_branches,
                else_branch,
            } => match *condition {
                Expr::Boolean(true) => Self::always_taken(self.eliminate_dead_body(then_branch)),
                Expr::Boolean(false) => {
                    let mut elifs = elif_branches.into_iter();
                    match elifs.next() {
                        // The first `elif` becomes the primary test; run the
                        // pruning again in case its condition is a literal too.
                        Some((next_condition, next_branch)) => self.eliminate_dead_expr(Expr::If {
                            condition: Box::new(next_condition),
                            then_branch: next_branch,
                            elif_branches: elifs.collect(),
                            else_branch,
                        }),
                        None => match else_branch {
                            Some(else_body) => {
                                Self::always_taken(self.eliminate_dead_body(else_body))
                            }
                            None => Expr::Null,
                        },
                    }
                }
                condition => {
                    let mut else_branch = else_branch;
                    let mut live_elifs = Vec::with_capacity(elif_branches.len());
                    for (elif_condition, elif_body) in elif_branches {
                        match self.eliminate_dead_expr(elif_condition) {
                            Expr::Boolean(false) => {}
                            Expr::Boolean(true) => {
                                else_branch = Some(elif_body);
                                break;
                            }
                            live => live_elifs.push((live, self.eliminate_dead_body(elif_body))),
                        }
                    }
                    Expr::If {
                        condition: Box::new(condition),
                        then_branch: self.eliminate_dead_body(then_branch),
                        elif_branches: live_elifs,
                        else_branch: else_branch.map(|body| self.eliminate_dead_body(body)),
                    }
                }
            },
            other => other,
        }
    }

    /// Wraps a block that is known to run in `if true`: an expression cannot
    /// be replaced by a list of statements, so the block keeps its own `if`.
    fn always_taken(body: Vec<Stmt>) -> Expr {
        Expr::If {
            condition: Box::new(Expr::Boolean(true)),
            then_branch: body,
            elif_branches: Vec::new(),
            else_branch: None,
        }
    }

    // ----- tail-recursion elimination ---------------------------------------

    /// Tail-recursion pass. Only top-level `FuncDef`s are considered; nested
    /// and generator functions are left as they are.
    fn optimize_tail_recursion(&self, program: Program) -> Program {
        program
            .into_iter()
            .map(|stmt| self.optimize_tail_recursive_stmt(stmt))
            .collect()
    }

    /// Rewrites a function definition into loop form when
    /// [`Self::is_tail_recursive`] holds; any other statement is returned
    /// unchanged.
    fn optimize_tail_recursive_stmt(&self, stmt: Stmt) -> Stmt {
        match stmt {
            Stmt::FuncDef { name, params, body } if self.is_tail_recursive(&name, &body) => {
                let body = Self::convert_tail_recursion_to_loop(&name, &params, body);
                Stmt::FuncDef { name, params, body }
            }
            other => other,
        }
    }

    /// Returns `true` when `body` ends in a self-call in tail position.
    ///
    /// Tail position is the last statement being `return func_name(..)`, or
    /// the last statement being an `if` expression with at least one branch
    /// that (recursively) ends that way. Self-calls anywhere else (mid-body,
    /// inside loops, nested in other expressions) are deliberately ignored:
    /// the loop rewrite cannot skip the statements that would follow them.
    fn is_tail_recursive(&self, func_name: &str, body: &[Stmt]) -> bool {
        Self::ends_with_self_tail_call(func_name, body)
    }

    /// Implementation of [`Self::is_tail_recursive`].
    fn ends_with_self_tail_call(func_name: &str, body: &[Stmt]) -> bool {
        match body.last() {
            Some(Stmt::Return(expr)) => Self::is_self_call(func_name, expr),
            Some(Stmt::Expression(Expr::If {
                then_branch,
                elif_branches,
                else_branch,
                ..
            })) => {
                Self::ends_with_self_tail_call(func_name, then_branch)
                    || elif_branches
                        .iter()
                        .any(|(_, branch)| Self::ends_with_self_tail_call(func_name, branch))
                    || else_branch
                        .as_ref()
                        .is_some_and(|branch| Self::ends_with_self_tail_call(func_name, branch))
            }
            _ => false,
        }
    }

    /// True when `expr` is a call whose callee is the identifier `func_name`.
    fn is_self_call(func_name: &str, expr: &Expr) -> bool {
        matches!(expr, Expr::Call { func, .. } if Self::names_function(func_name, func))
    }

    /// True when `callee` is the bare identifier `func_name`.
    fn names_function(func_name: &str, callee: &Expr) -> bool {
        matches!(callee, Expr::Identifier(name) if name == func_name)
    }

    /// Name of the flag variable that keeps the generated loop running.
    fn loop_flag_name() -> String {
        "_loop_continue".to_string()
    }

    /// Name of the temporary that stages the next value of `param`.
    fn loop_temp_name(param: &str) -> String {
        format!("_loop_{}", param)
    }

    /// Turns a tail-recursive body into:
    ///
    /// ```text
    /// _loop_<p> = <p>            (for every parameter)
    /// _loop_continue = true
    /// while _loop_continue {
    ///     _loop_continue = false
    ///     <body, tail self-calls replaced by parameter rebinding
    ///      followed by `_loop_continue = true`>
    /// }
    /// ```
    ///
    /// Clearing the flag at the start of each iteration makes a path that
    /// falls off the end of the body leave the loop instead of spinning
    /// forever. Non-tail `return`s leave the function directly.
    ///
    /// NOTE(review): assumes `Set` inside a nested block (loop / `if` branch)
    /// assigns the function-level variable rather than a block-local one;
    /// confirm against the interpreter's scoping rules.
    fn convert_tail_recursion_to_loop(
        func_name: &str,
        params: &[String],
        body: Vec<Stmt>,
    ) -> Vec<Stmt> {
        let mut new_body: Vec<Stmt> = params
            .iter()
            .map(|param| Stmt::Set {
                name: Self::loop_temp_name(param),
                value: Expr::Identifier(param.clone()),
            })
            .collect();
        new_body.push(Stmt::Set {
            name: Self::loop_flag_name(),
            value: Expr::Boolean(true),
        });

        let mut loop_body = vec![Stmt::Set {
            name: Self::loop_flag_name(),
            value: Expr::Boolean(false),
        }];
        loop_body.extend(Self::rewrite_tail_position(func_name, params, body));

        new_body.push(Stmt::While {
            condition: Expr::Identifier(Self::loop_flag_name()),
            body: loop_body,
        });
        new_body
    }

    /// Rewrites the statement in tail position of `body`, mirroring
    /// [`Self::ends_with_self_tail_call`]. Earlier statements are untouched.
    fn rewrite_tail_position(func_name: &str, params: &[String], mut body: Vec<Stmt>) -> Vec<Stmt> {
        let Some(last) = body.pop() else {
            return body;
        };
        match last {
            Stmt::Return(Expr::Call { func, args }) if Self::names_function(func_name, &func) => {
                body.extend(Self::rebind_params(params, args));
            }
            Stmt::Expression(Expr::If {
                condition,
                then_branch,
                elif_branches,
                else_branch,
            }) => body.push(Stmt::Expression(Expr::If {
                condition,
                then_branch: Self::rewrite_tail_position(func_name, params, then_branch),
                elif_branches: elif_branches
                    .into_iter()
                    .map(|(cond, branch)| {
                        (cond, Self::rewrite_tail_position(func_name, params, branch))
                    })
                    .collect(),
                else_branch: else_branch
                    .map(|branch| Self::rewrite_tail_position(func_name, params, branch)),
            })),
            other => body.push(other),
        }
        body
    }

    /// Statements replacing `return func_name(args..)`: assign parameters
    /// their new values and request another loop iteration.
    ///
    /// Parameters without a matching argument keep their current value
    /// (their staging temporary still holds it).
    fn rebind_params(params: &[String], args: Vec<Expr>) -> Vec<Stmt> {
        let mut stmts = Vec::with_capacity(2 * params.len() + 1);
        // Stage every new value before assigning any parameter, so an
        // argument that reads another parameter (e.g. `f(b, a)`) still sees
        // the values of the current call.
        stmts.extend(params.iter().zip(args).map(|(param, arg)| Stmt::Set {
            name: Self::loop_temp_name(param),
            value: arg,
        }));
        stmts.extend(params.iter().map(|param| Stmt::Set {
            name: param.clone(),
            value: Expr::Identifier(Self::loop_temp_name(param)),
        }));
        stmts.push(Stmt::Set {
            name: Self::loop_flag_name(),
            value: Expr::Boolean(true),
        });
        stmts
    }
}
impl Default for Optimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Expr::Identifier` shorthand.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    /// `Expr::Number` shorthand.
    fn num(value: f64) -> Expr {
        Expr::Number(value)
    }

    /// `left op right` as an `Expr::Binary`.
    fn binary(left: Expr, op: BinOp, right: Expr) -> Expr {
        Expr::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// `callee(args..)` where the callee is a plain identifier.
    fn call(callee: &str, args: Vec<Expr>) -> Expr {
        Expr::Call {
            func: Box::new(ident(callee)),
            args,
        }
    }

    /// `return factorial(n - 1, acc * n)`: an accumulator-style tail call.
    fn factorial_tail_call() -> Stmt {
        Stmt::Return(call(
            "factorial",
            vec![
                binary(ident("n"), BinOp::Subtract, num(1.0)),
                binary(ident("acc"), BinOp::Multiply, ident("n")),
            ],
        ))
    }

    #[test]
    fn test_constant_folding() {
        let folded = Optimizer::new().fold_expr(binary(num(2.0), BinOp::Add, num(3.0)));
        assert_eq!(folded, num(5.0));
    }

    #[test]
    fn test_dead_code_elimination() {
        let never_runs = Stmt::While {
            condition: Expr::Boolean(false),
            body: vec![Stmt::Set {
                name: "x".to_string(),
                value: num(10.0),
            }],
        };
        assert!(Optimizer::new().eliminate_dead_stmt(never_runs).is_none());
    }

    #[test]
    fn test_tail_recursion_detection() {
        let body = [factorial_tail_call()];
        assert!(Optimizer::new().is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_non_tail_recursion_detection() {
        let recursive_call = call(
            "factorial",
            vec![binary(ident("n"), BinOp::Subtract, num(1.0))],
        );
        let body = [Stmt::Return(binary(
            ident("n"),
            BinOp::Multiply,
            recursive_call,
        ))];
        assert!(!Optimizer::new().is_tail_recursive("factorial", &body));
    }

    #[test]
    fn test_tail_recursion_in_if() {
        let sum_step = call(
            "sum",
            vec![
                binary(ident("n"), BinOp::Subtract, num(1.0)),
                binary(ident("acc"), BinOp::Add, ident("n")),
            ],
        );
        let body = [Stmt::Expression(Expr::If {
            condition: Box::new(binary(ident("n"), BinOp::LessEqual, num(0.0))),
            then_branch: vec![Stmt::Return(ident("acc"))],
            elif_branches: vec![],
            else_branch: Some(vec![Stmt::Return(sum_step)]),
        })];
        assert!(Optimizer::new().is_tail_recursive("sum", &body));
    }

    #[test]
    fn test_tail_recursion_optimization_transform() {
        let func_def = Stmt::FuncDef {
            name: "factorial".to_string(),
            params: vec!["n".to_string(), "acc".to_string()],
            body: vec![factorial_tail_call()],
        };
        let Stmt::FuncDef { body, .. } = Optimizer::new().optimize_tail_recursive_stmt(func_def)
        else {
            panic!("Expected FuncDef");
        };
        assert!(
            body.len() >= 3,
            "Expected at least 3 statements, got {}",
            body.len()
        );
        assert!(
            matches!(body.last(), Some(Stmt::While { .. })),
            "Expected While loop at the end of optimized function body"
        );
    }
}