use crate::ast::{Block, Expr, FunctionDecl, Literal, Stmt};
use std::collections::HashMap;
/// Common-subexpression eliminator for trees of binary operators.
///
/// Every `BinaryOp` it visits is bound to a `_cse_N` temporary, keyed by the
/// `Debug` rendering of its (already rewritten) operands and its operator.
/// The key -> temporary table lives on the optimizer and is never cleared, so
/// an expression seen in an earlier call to [`CSEOptimizer::optimize_expr`] is
/// replaced by the existing temporary without a new binding being emitted.
/// NOTE(review): nothing invalidates an entry when an operand variable is
/// reassigned; presumably callers create a fresh optimizer at such points —
/// confirm.
pub struct CSEOptimizer {
    /// Expression key -> name of the temporary that holds its value.
    expressions: HashMap<String, String>,
    /// Temporaries allocated so far; also the numeric suffix of the next one.
    temp_counter: u32,
}
impl CSEOptimizer {
    /// Creates an optimizer with an empty expression table.
    pub fn new() -> Self {
        CSEOptimizer {
            temp_counter: 0,
            expressions: HashMap::new(),
        }
    }

    /// Rewrites `expr`, returning the replacement expression plus the
    /// `(temporary, expression)` bindings newly introduced by this call, in
    /// the order they were created (operands before the expressions using them).
    pub fn optimize_expr(&mut self, expr: &Expr) -> (Expr, Vec<(String, Expr)>) {
        let mut bindings = Vec::new();
        let rewritten = self.optimize_expr_internal(expr, &mut bindings);
        (rewritten, bindings)
    }

    /// Recursive worker: appends fresh bindings to `temps` and returns the
    /// expression that should stand in for `expr`.
    fn optimize_expr_internal(&mut self, expr: &Expr, temps: &mut Vec<(String, Expr)>) -> Expr {
        let (operator, left, right) = match expr {
            Expr::BinaryOp {
                operator,
                left,
                right,
            } => (operator, left, right),
            // Leaves (and any non-binary node) are kept as-is.
            other => return other.clone(),
        };

        // Rewrite operands first so the key reflects their simplified form.
        let lhs = self.optimize_expr_internal(left, temps);
        let rhs = self.optimize_expr_internal(right, temps);
        let key = format!("{:?} {:?} {:?}", lhs, operator, rhs);

        if let Some(existing) = self.expressions.get(&key) {
            return Expr::Variable(existing.clone());
        }

        let name = format!("_cse_{}", self.temp_counter);
        self.temp_counter += 1;
        self.expressions.insert(key, name.clone());
        let bound = Expr::BinaryOp {
            operator: operator.clone(),
            left: Box::new(lhs),
            right: Box::new(rhs),
        };
        temps.push((name.clone(), bound));
        Expr::Variable(name)
    }
}
impl Default for CSEOptimizer {
    fn default() -> Self {
        Self::new()
    }
}
pub struct FunctionInliner {
inline_threshold: usize,
inlined_count: usize,
}
impl FunctionInliner {
pub fn new(threshold: usize) -> Self {
Self {
inline_threshold: threshold,
inlined_count: 0,
}
}
pub fn should_inline(&self, func: &FunctionDecl) -> bool {
let stmt_count = func.block.statements.len();
stmt_count <= self.inline_threshold
}
pub fn inline_call(&mut self, func: &FunctionDecl, args: &[Expr]) -> Vec<Stmt> {
self.inlined_count += 1;
let mut stmts = Vec::new();
for (i, param) in func.parameters.iter().enumerate() {
if i < args.len() {
stmts.push(Stmt::Assignment {
target: param.name.clone(),
value: args[i].clone(),
});
}
}
stmts.extend(func.block.statements.clone());
stmts
}
pub fn inlined_count(&self) -> usize {
self.inlined_count
}
}
pub struct LoopOptimizer {
unroll_factor: usize,
}
impl LoopOptimizer {
pub fn new(unroll_factor: usize) -> Self {
Self { unroll_factor }
}
pub fn optimize_loop(&self, stmt: &Stmt) -> Stmt {
match stmt {
Stmt::For {
var_name,
start,
end,
body,
..
} => {
if let (Expr::Literal(Literal::Integer(s)), Expr::Literal(Literal::Integer(e))) =
(start, end)
{
let iterations = (e - s + 1).unsigned_abs() as usize;
if iterations <= self.unroll_factor {
return self.unroll_for_loop(var_name, *s, *e, body);
}
}
stmt.clone()
}
Stmt::While { condition, body } => self.hoist_invariants(condition, body),
_ => stmt.clone(),
}
}
fn unroll_for_loop(&self, var_name: &str, start: i64, end: i64, body: &[Stmt]) -> Stmt {
let mut unrolled = Vec::new();
for i in start..=end {
for stmt in body {
unrolled.push(self.substitute_var(stmt, var_name, i));
}
}
Stmt::Block(Block::with_statements(unrolled))
}
fn substitute_var(&self, stmt: &Stmt, var_name: &str, value: i64) -> Stmt {
match stmt {
Stmt::Assignment {
target,
value: expr,
} => Stmt::Assignment {
target: target.clone(),
value: self.substitute_expr(expr, var_name, value),
},
_ => stmt.clone(),
}
}
fn substitute_expr(&self, expr: &Expr, var_name: &str, value: i64) -> Expr {
match expr {
Expr::Variable(name) if name == var_name => Expr::Literal(Literal::Integer(value)),
Expr::BinaryOp {
operator,
left,
right,
} => Expr::BinaryOp {
operator: operator.clone(),
left: Box::new(self.substitute_expr(left, var_name, value)),
right: Box::new(self.substitute_expr(right, var_name, value)),
},
_ => expr.clone(),
}
}
fn hoist_invariants(&self, condition: &Expr, body: &[Stmt]) -> Stmt {
Stmt::While {
condition: condition.clone(),
body: body.to_vec(),
}
}
}
impl Default for LoopOptimizer {
fn default() -> Self {
Self::new(4)
}
}
/// Detects functions whose final statement is a call to themselves.
///
/// NOTE(review): no rewrite is performed — `optimize_tail_call` returns an
/// unchanged copy of its input — so `optimized_count` counts detected
/// tail-call candidates, not transformed ones.
pub struct TailCallOptimizer {
    /// Number of functions seen whose last top-level statement is a self-call.
    optimized_count: usize,
}
impl TailCallOptimizer {
    /// Creates a detector with a zeroed counter.
    pub fn new() -> Self {
        TailCallOptimizer { optimized_count: 0 }
    }

    /// Returns `true` when `stmt` is a procedure call to `func_name`.
    ///
    /// Only `Stmt::ProcedureCall` is recognised; a self-call appearing in any
    /// other statement form does not count.
    pub fn is_tail_call(&self, stmt: &Stmt, func_name: &str) -> bool {
        match stmt {
            Stmt::ProcedureCall { name, .. } => name == func_name,
            _ => false,
        }
    }

    /// Bumps the counter if `func`'s body ends in a self-call and returns a
    /// clone of `func`.
    pub fn optimize_tail_call(&mut self, func: &FunctionDecl) -> FunctionDecl {
        let ends_in_self_call = matches!(
            func.block.statements.last(),
            Some(last) if self.is_tail_call(last, &func.name)
        );
        if ends_in_self_call {
            self.optimized_count += 1;
        }
        func.clone()
    }

    /// Number of self-call tail positions detected so far.
    pub fn optimized_count(&self) -> usize {
        self.optimized_count
    }
}
impl Default for TailCallOptimizer {
    fn default() -> Self {
        Self::new()
    }
}
/// Strength reduction: replaces multiplication (and, when provably safe,
/// integer division) by a positive power-of-two literal with a bit shift.
pub struct StrengthReducer;
impl StrengthReducer {
    /// Rewrites `expr` bottom-up, reducing every eligible `BinaryOp`.
    ///
    /// * `x * 2^k` becomes `x shl k`.
    /// * `x div 2^k` becomes `x shr k` only when `x` is known to be
    ///   non-negative. `div` truncates toward zero while an arithmetic right
    ///   shift rounds toward negative infinity, so for negative `x` the two
    ///   disagree (`-7 div 2 = -3`, but `-7 shr 1 = -4`).
    ///
    /// Non-`BinaryOp` expressions are returned as clones; the operands of a
    /// `BinaryOp` are always optimized recursively.
    pub fn optimize(&self, expr: &Expr) -> Expr {
        match expr {
            Expr::BinaryOp {
                operator,
                left,
                right,
            } => {
                let left_opt = self.optimize(left);
                let right_opt = self.optimize(right);
                if let Some(shift) = Self::power_of_two_exponent(&right_opt) {
                    let shift_op = match operator.as_str() {
                        "*" => Some("shl"),
                        "div" if Self::is_known_non_negative(&left_opt) => Some("shr"),
                        _ => None,
                    };
                    if let Some(shift_op) = shift_op {
                        return Expr::BinaryOp {
                            operator: shift_op.to_string(),
                            left: Box::new(left_opt),
                            right: Box::new(Expr::Literal(Literal::Integer(shift))),
                        };
                    }
                }
                Expr::BinaryOp {
                    operator: operator.clone(),
                    left: Box::new(left_opt),
                    right: Box::new(right_opt),
                }
            }
            _ => expr.clone(),
        }
    }

    /// Returns `k` when `expr` is the integer literal `2^k`, otherwise `None`.
    fn power_of_two_exponent(expr: &Expr) -> Option<i64> {
        match expr {
            // `trailing_zeros` gives the exponent exactly, with no float round-trip.
            Expr::Literal(Literal::Integer(n)) if *n > 0 && (*n & (*n - 1)) == 0 => {
                Some(i64::from(n.trailing_zeros()))
            }
            _ => None,
        }
    }

    /// Conservative check that `expr` can never evaluate to a negative value.
    /// Currently only non-negative integer literals qualify.
    fn is_known_non_negative(expr: &Expr) -> bool {
        matches!(expr, Expr::Literal(Literal::Integer(n)) if *n >= 0)
    }
}
/// Bundles the individual optimization passes behind one entry point.
///
/// In this file only two passes are applied by this type: the loop optimizer
/// via [`AdvancedOptimizer::optimize_stmt`] and the strength reducer via
/// [`AdvancedOptimizer::optimize_expr`]. The CSE, inlining and tail-call
/// components are read only by [`AdvancedOptimizer::stats`].
pub struct AdvancedOptimizer {
    cse: CSEOptimizer,
    inliner: FunctionInliner,
    loop_opt: LoopOptimizer,
    tail_call: TailCallOptimizer,
    strength_reducer: StrengthReducer,
}
impl AdvancedOptimizer {
    /// Creates the optimizer with an inlining threshold of 10 top-level
    /// statements and an unroll limit of 4 iterations.
    pub fn new() -> Self {
        Self {
            cse: CSEOptimizer::new(),
            inliner: FunctionInliner::new(10),
            loop_opt: LoopOptimizer::new(4),
            tail_call: TailCallOptimizer::new(),
            strength_reducer: StrengthReducer,
        }
    }

    /// Runs the loop optimizer over `stmt`.
    pub fn optimize_stmt(&mut self, stmt: &Stmt) -> Stmt {
        self.loop_opt.optimize_loop(stmt)
    }

    /// Runs strength reduction over `expr`.
    pub fn optimize_expr(&mut self, expr: &Expr) -> Expr {
        self.strength_reducer.optimize(expr)
    }

    /// Snapshot of the counters kept by the CSE, inlining and tail-call components.
    pub fn stats(&self) -> OptimizationStats {
        OptimizationStats {
            cse_temps: self.cse.temp_counter,
            inlined_functions: self.inliner.inlined_count(),
            tail_calls_optimized: self.tail_call.optimized_count(),
        }
    }
}
impl Default for AdvancedOptimizer {
    fn default() -> Self {
        Self::new()
    }
}
/// Counters summarising what the optimizer components have done.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizationStats {
    /// Number of `_cse_N` temporaries allocated by the CSE pass.
    pub cse_temps: u32,
    /// Number of calls inlined by the function inliner.
    pub inlined_functions: usize,
    /// Number of functions detected as ending in a self-recursive call.
    pub tail_calls_optimized: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a variable reference.
    fn var(name: &str) -> Expr {
        Expr::Variable(name.to_string())
    }

    /// Shorthand for an integer literal.
    fn int(value: i64) -> Expr {
        Expr::Literal(Literal::Integer(value))
    }

    /// Shorthand for `left <operator> right`.
    fn binop(operator: &str, left: Expr, right: Expr) -> Expr {
        Expr::BinaryOp {
            operator: operator.to_string(),
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    #[test]
    fn test_cse() {
        let mut optimizer = CSEOptimizer::new();
        let sum = binop("+", var("a"), var("b"));
        // The first sighting binds the sum to one fresh temporary...
        let (_, first_bindings) = optimizer.optimize_expr(&sum);
        assert_eq!(first_bindings.len(), 1);
        // ...and a repeat reuses it without emitting another binding.
        let (_, repeat_bindings) = optimizer.optimize_expr(&sum);
        assert!(repeat_bindings.is_empty());
    }

    #[test]
    fn test_loop_unrolling() {
        let optimizer = LoopOptimizer::new(10);
        // Ascending loop from 1 to 3: three iterations, under the limit of 10.
        let short_loop = Stmt::For {
            var_name: "i".to_string(),
            start: int(1),
            direction: crate::ast::ForDirection::To,
            end: int(3),
            body: vec![Stmt::Empty],
        };
        assert!(matches!(optimizer.optimize_loop(&short_loop), Stmt::Block(_)));
    }

    #[test]
    fn test_strength_reduction() {
        let product = binop("*", var("x"), int(8));
        match StrengthReducer.optimize(&product) {
            Expr::BinaryOp { operator, .. } => assert_eq!(operator, "shl"),
            _ => panic!("Expected shift operation"),
        }
    }
}