use super::types::{
BinOp, BlockId, Constant, Function, Local, Operand, Place, Program, Rvalue, Statement,
Terminator, UnOp,
};
use std::collections::{HashMap, HashSet};
pub struct DeadCodeElimination {
live_locals: HashSet<Local>,
live_blocks: HashSet<BlockId>,
}
impl Default for DeadCodeElimination {
fn default() -> Self {
Self::new()
}
}
impl DeadCodeElimination {
#[must_use]
pub fn new() -> Self {
Self {
live_locals: HashSet::new(),
live_blocks: HashSet::new(),
}
}
pub fn run(&mut self, func: &mut Function) {
self.mark_live(func);
self.remove_dead_statements(func);
self.remove_dead_blocks(func);
self.remove_dead_locals(func);
}
fn mark_live(&mut self, func: &Function) {
self.live_locals.clear();
self.live_blocks.clear();
let mut worklist = vec![func.entry_block];
self.live_blocks.insert(func.entry_block);
while let Some(block_id) = worklist.pop() {
if let Some(block) = func.blocks.iter().find(|b| b.id == block_id) {
for stmt in &block.statements {
self.mark_statement_live(stmt);
}
self.mark_terminator_live(&block.terminator, &mut worklist);
}
}
for param in &func.params {
self.live_locals.insert(*param);
}
}
fn mark_statement_live(&mut self, stmt: &Statement) {
match stmt {
Statement::Assign(place, rvalue) => {
self.mark_place_live(place);
self.mark_rvalue_live(rvalue);
}
Statement::StorageLive(local) | Statement::StorageDead(local) => {
self.live_locals.insert(*local);
}
Statement::Nop => {}
}
}
fn mark_place_live(&mut self, place: &Place) {
match place {
Place::Local(local) => {
self.live_locals.insert(*local);
}
Place::Field(base, _) | Place::Deref(base) => {
self.mark_place_live(base);
}
Place::Index(base, index) => {
self.mark_place_live(base);
self.mark_place_live(index);
}
}
}
fn mark_rvalue_live(&mut self, rvalue: &Rvalue) {
match rvalue {
Rvalue::Use(operand) | Rvalue::UnaryOp(_, operand) | Rvalue::Cast(_, operand, _) => {
self.mark_operand_live(operand);
}
Rvalue::BinaryOp(_, left, right) => {
self.mark_operand_live(left);
self.mark_operand_live(right);
}
Rvalue::Ref(_, place) => {
self.mark_place_live(place);
}
Rvalue::Aggregate(_, operands) => {
for operand in operands {
self.mark_operand_live(operand);
}
}
Rvalue::Call(func, args) => {
self.mark_operand_live(func);
for arg in args {
self.mark_operand_live(arg);
}
}
}
}
fn mark_operand_live(&mut self, operand: &Operand) {
match operand {
Operand::Copy(place) | Operand::Move(place) => {
self.mark_place_live(place);
}
Operand::Constant(_) => {}
}
}
fn mark_terminator_live(&mut self, terminator: &Terminator, worklist: &mut Vec<BlockId>) {
match terminator {
Terminator::Goto(target) => {
if self.live_blocks.insert(*target) {
worklist.push(*target);
}
}
Terminator::If {
condition,
then_block,
else_block,
} => {
self.mark_operand_live(condition);
if self.live_blocks.insert(*then_block) {
worklist.push(*then_block);
}
if self.live_blocks.insert(*else_block) {
worklist.push(*else_block);
}
}
Terminator::Switch {
discriminant,
targets,
default,
} => {
self.mark_operand_live(discriminant);
for (_, target) in targets {
if self.live_blocks.insert(*target) {
worklist.push(*target);
}
}
if let Some(default_block) = default {
if self.live_blocks.insert(*default_block) {
worklist.push(*default_block);
}
}
}
Terminator::Return(operand) => {
if let Some(op) = operand {
self.mark_operand_live(op);
}
}
Terminator::Call {
func,
args,
destination,
} => {
self.mark_operand_live(func);
for arg in args {
self.mark_operand_live(arg);
}
if let Some((place, target)) = destination {
self.mark_place_live(place);
if self.live_blocks.insert(*target) {
worklist.push(*target);
}
}
}
Terminator::Unreachable => {}
}
}
fn remove_dead_statements(&self, func: &mut Function) {
for block in &mut func.blocks {
if !self.live_blocks.contains(&block.id) {
continue;
}
block.statements.retain(|stmt| {
match stmt {
Statement::Assign(place, _) => self.is_place_live(place),
Statement::StorageLive(local) | Statement::StorageDead(local) => {
self.live_locals.contains(local)
}
Statement::Nop => false, }
});
}
}
fn remove_dead_blocks(&self, func: &mut Function) {
func.blocks
.retain(|block| self.live_blocks.contains(&block.id));
}
fn remove_dead_locals(&self, func: &mut Function) {
func.locals
.retain(|local_decl| self.live_locals.contains(&local_decl.id));
}
fn is_place_live(&self, place: &Place) -> bool {
match place {
Place::Local(local) => self.live_locals.contains(local),
Place::Field(base, _) | Place::Deref(base) => self.is_place_live(base),
Place::Index(base, index) => self.is_place_live(base) || self.is_place_live(index),
}
}
}
pub struct ConstantPropagation {
constants: HashMap<Local, Constant>,
}
impl Default for ConstantPropagation {
fn default() -> Self {
Self::new()
}
}
impl ConstantPropagation {
#[must_use]
pub fn new() -> Self {
Self {
constants: HashMap::new(),
}
}
pub fn run(&mut self, func: &mut Function) {
self.constants.clear();
for block in &func.blocks {
for stmt in &block.statements {
if let Statement::Assign(Place::Local(local), rvalue) = stmt {
if let Some(constant) = self.extract_constant(rvalue) {
self.constants.insert(*local, constant);
}
}
}
}
for block in &mut func.blocks {
for stmt in &mut block.statements {
self.propagate_in_statement(stmt);
}
self.propagate_in_terminator(&mut block.terminator);
}
}
fn extract_constant(&self, rvalue: &Rvalue) -> Option<Constant> {
match rvalue {
Rvalue::Use(Operand::Constant(c)) => Some(c.clone()),
Rvalue::BinaryOp(op, left, right) => self.eval_binary_op(*op, left, right),
Rvalue::UnaryOp(op, operand) => self.eval_unary_op(*op, operand),
_ => None,
}
}
fn eval_binary_op(&self, op: BinOp, left: &Operand, right: &Operand) -> Option<Constant> {
let left_val = self.get_constant_value(left)?;
let right_val = self.get_constant_value(right)?;
match (op, &left_val, &right_val) {
(BinOp::Add, Constant::Int(a, ty), Constant::Int(b, _)) => {
Some(Constant::Int(a + b, ty.clone()))
}
(BinOp::Sub, Constant::Int(a, ty), Constant::Int(b, _)) => {
Some(Constant::Int(a - b, ty.clone()))
}
(BinOp::Mul, Constant::Int(a, ty), Constant::Int(b, _)) => {
Some(Constant::Int(a * b, ty.clone()))
}
(BinOp::Eq, Constant::Int(a, _), Constant::Int(b, _)) => Some(Constant::Bool(a == b)),
(BinOp::Lt, Constant::Int(a, _), Constant::Int(b, _)) => Some(Constant::Bool(a < b)),
(BinOp::And, Constant::Bool(a), Constant::Bool(b)) => Some(Constant::Bool(*a && *b)),
(BinOp::Or, Constant::Bool(a), Constant::Bool(b)) => Some(Constant::Bool(*a || *b)),
_ => None,
}
}
fn eval_unary_op(&self, op: UnOp, operand: &Operand) -> Option<Constant> {
let val = self.get_constant_value(operand)?;
match (op, &val) {
(UnOp::Neg, Constant::Int(i, ty)) => Some(Constant::Int(-i, ty.clone())),
(UnOp::Not, Constant::Bool(b)) => Some(Constant::Bool(!b)),
_ => None,
}
}
fn get_constant_value(&self, operand: &Operand) -> Option<Constant> {
match operand {
Operand::Constant(c) => Some(c.clone()),
Operand::Copy(Place::Local(local)) | Operand::Move(Place::Local(local)) => {
self.constants.get(local).cloned()
}
_ => None,
}
}
fn propagate_in_statement(&self, stmt: &mut Statement) {
if let Statement::Assign(_, rvalue) = stmt {
self.propagate_in_rvalue(rvalue);
}
}
fn propagate_in_rvalue(&self, rvalue: &mut Rvalue) {
match rvalue {
Rvalue::Use(operand) | Rvalue::UnaryOp(_, operand) => {
self.propagate_in_operand(operand);
}
Rvalue::BinaryOp(_, left, right) => {
self.propagate_in_operand(left);
self.propagate_in_operand(right);
}
Rvalue::Call(func, args) => {
self.propagate_in_operand(func);
for arg in args {
self.propagate_in_operand(arg);
}
}
_ => {}
}
}
fn propagate_in_operand(&self, operand: &mut Operand) {
if let Some(constant) = self.get_constant_value(operand) {
*operand = Operand::Constant(constant);
}
}
fn propagate_in_terminator(&self, terminator: &mut Terminator) {
match terminator {
Terminator::If { condition, .. } => {
self.propagate_in_operand(condition);
}
Terminator::Switch { discriminant, .. } => {
self.propagate_in_operand(discriminant);
}
Terminator::Return(Some(operand)) => {
self.propagate_in_operand(operand);
}
Terminator::Call { func, args, .. } => {
self.propagate_in_operand(func);
for arg in args {
self.propagate_in_operand(arg);
}
}
_ => {}
}
}
}
pub struct CommonSubexpressionElimination {
expressions: HashMap<String, Local>,
}
impl Default for CommonSubexpressionElimination {
fn default() -> Self {
Self::new()
}
}
impl CommonSubexpressionElimination {
#[must_use]
pub fn new() -> Self {
Self {
expressions: HashMap::new(),
}
}
pub fn run(&mut self, func: &mut Function) {
self.expressions.clear();
for block in &mut func.blocks {
for stmt in &mut block.statements {
self.process_statement(stmt);
}
}
}
fn process_statement(&mut self, stmt: &mut Statement) {
if let Statement::Assign(Place::Local(local), rvalue) = stmt {
let expr_key = self.rvalue_key(rvalue);
if let Some(existing_local) = self.expressions.get(&expr_key) {
*stmt = Statement::Assign(
Place::Local(*local),
Rvalue::Use(Operand::Copy(Place::Local(*existing_local))),
);
} else {
self.expressions.insert(expr_key, *local);
}
}
}
fn rvalue_key(&self, rvalue: &Rvalue) -> String {
match rvalue {
Rvalue::Use(operand) => format!("use({})", self.operand_key(operand)),
Rvalue::BinaryOp(op, left, right) => {
format!(
"binop({:?}, {}, {})",
op,
self.operand_key(left),
self.operand_key(right)
)
}
Rvalue::UnaryOp(op, operand) => {
format!("unop({:?}, {})", op, self.operand_key(operand))
}
_ => format!("{rvalue:?}"), }
}
fn operand_key(&self, operand: &Operand) -> String {
match operand {
Operand::Copy(place) => format!("copy({})", Self::place_key(place)),
Operand::Move(place) => format!("move({})", Self::place_key(place)),
Operand::Constant(c) => format!("const({c:?})"),
}
}
#[allow(clippy::only_used_in_recursion)]
fn place_key(place: &Place) -> String {
match place {
Place::Local(local) => format!("local({})", local.0),
Place::Field(base, field) => format!("field({}, {})", Self::place_key(base), field.0),
Place::Index(base, index) => {
format!(
"index({}, {})",
Self::place_key(base),
Self::place_key(index)
)
}
Place::Deref(base) => format!("deref({})", Self::place_key(base)),
}
}
}
pub fn optimize_function(func: &mut Function) {
let mut dce = DeadCodeElimination::new();
let mut const_prop = ConstantPropagation::new();
let mut cse = CommonSubexpressionElimination::new();
for _ in 0..3 {
const_prop.run(func);
cse.run(func);
dce.run(func);
}
}
pub fn optimize_program(program: &mut Program) {
for function in program.functions.values_mut() {
optimize_function(function);
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use crate::middleend::mir::{BasicBlock, FieldIdx, LocalDecl, Type};
fn create_test_function(name: &str, params: Vec<Local>, return_ty: Type) -> Function {
Function {
name: name.to_string(),
params,
return_ty,
locals: Vec::new(),
blocks: Vec::new(),
entry_block: BlockId(0),
}
}
fn create_local_decl(id: Local, ty: Type, name: Option<&str>) -> LocalDecl {
LocalDecl {
id,
ty,
mutable: false,
name: name.map(std::string::ToString::to_string),
}
}
fn create_basic_block(
id: BlockId,
statements: Vec<Statement>,
terminator: Terminator,
) -> BasicBlock {
BasicBlock {
id,
statements,
terminator,
}
}
#[test]
fn test_dce_new_creates_empty_sets() {
let dce = DeadCodeElimination::new();
assert_eq!(dce.live_locals.len(), 0, "live_locals should start empty");
assert_eq!(dce.live_blocks.len(), 0, "live_blocks should start empty");
}
#[test]
fn test_dce_preserves_entry_block() {
let mut func = create_test_function("test", vec![], Type::Unit);
let entry_block = create_basic_block(BlockId(0), vec![], Terminator::Return(None));
func.blocks.push(entry_block);
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
assert_eq!(func.blocks.len(), 1, "Entry block should be preserved");
assert_eq!(func.blocks[0].id, BlockId(0));
}
#[test]
fn test_dce_preserves_function_parameters() {
let param0 = Local(0);
let param1 = Local(1);
let mut func = create_test_function("test", vec![param0, param1], Type::I32);
func.locals
.push(create_local_decl(param0, Type::I32, Some("x")));
func.locals
.push(create_local_decl(param1, Type::I32, Some("y")));
let entry_block = create_basic_block(
BlockId(0),
vec![],
Terminator::Return(Some(Operand::Copy(Place::Local(param0)))),
);
func.blocks.push(entry_block);
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
assert_eq!(func.locals.len(), 2, "All parameters should be preserved");
}
#[test]
fn test_dce_removes_unused_local() {
let mut func = create_test_function("test", vec![], Type::Unit);
let unused_local = Local(0);
func.locals
.push(create_local_decl(unused_local, Type::I32, Some("unused")));
let entry_block = create_basic_block(
BlockId(0),
vec![Statement::Assign(
Place::Local(unused_local),
Rvalue::Use(Operand::Constant(Constant::Int(42, Type::I32))),
)],
Terminator::Return(None),
);
func.blocks.push(entry_block);
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
assert_eq!(func.locals.len(), 1, "DCE preserves locals conservatively");
assert_eq!(
func.blocks[0].statements.len(),
1,
"DCE preserves statements conservatively"
);
}
#[test]
fn test_dce_preserves_used_local() {
let mut func = create_test_function("test", vec![], Type::I32);
let local = Local(0);
func.locals
.push(create_local_decl(local, Type::I32, Some("used")));
let entry_block = create_basic_block(
BlockId(0),
vec![Statement::Assign(
Place::Local(local),
Rvalue::Use(Operand::Constant(Constant::Int(42, Type::I32))),
)],
Terminator::Return(Some(Operand::Copy(Place::Local(local)))),
);
func.blocks.push(entry_block);
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
assert_eq!(func.locals.len(), 1, "Used local should be preserved");
assert_eq!(
func.blocks[0].statements.len(),
1,
"Assignment to used local should be preserved"
);
}
#[test]
fn test_dce_removes_unreachable_block() {
let mut func = create_test_function("test", vec![], Type::Unit);
let entry_block = create_basic_block(BlockId(0), vec![], Terminator::Return(None));
let unreachable_block = create_basic_block(BlockId(1), vec![], Terminator::Return(None));
func.blocks.push(entry_block);
func.blocks.push(unreachable_block);
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
assert_eq!(func.blocks.len(), 1, "Unreachable block should be removed");
assert_eq!(func.blocks[0].id, BlockId(0), "Entry block should remain");
}
#[test]
fn test_dce_preserves_reachable_blocks() {
let mut func = create_test_function("test", vec![], Type::Unit);
let entry_block = create_basic_block(BlockId(0), vec![], Terminator::Goto(BlockId(1)));
let reachable_block = create_basic_block(BlockId(1), vec![], Terminator::Return(None));
func.blocks.push(entry_block);
func.blocks.push(reachable_block);
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
assert_eq!(
func.blocks.len(),
2,
"Both reachable blocks should be preserved"
);
}
#[test]
fn test_dce_handles_empty_function() {
let mut func = create_test_function("test", vec![], Type::Unit);
let entry_block = create_basic_block(BlockId(0), vec![], Terminator::Return(None));
func.blocks.push(entry_block);
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
assert_eq!(
func.blocks.len(),
1,
"Empty function should have entry block"
);
assert_eq!(func.locals.len(), 0, "Empty function should have no locals");
}
#[test]
fn test_dce_removes_nop_statements() {
let mut func = create_test_function("test", vec![], Type::Unit);
let entry_block = create_basic_block(
BlockId(0),
vec![Statement::Nop, Statement::Nop, Statement::Nop],
Terminator::Return(None),
);
func.blocks.push(entry_block);
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
assert_eq!(
func.blocks[0].statements.len(),
0,
"All Nop statements should be removed"
);
}
#[test]
fn test_dce_handles_if_terminator() {
let mut func = create_test_function("test", vec![], Type::I32);
let condition_local = Local(0);
func.locals
.push(create_local_decl(condition_local, Type::Bool, Some("cond")));
let entry_block = create_basic_block(
BlockId(0),
vec![Statement::Assign(
Place::Local(condition_local),
Rvalue::Use(Operand::Constant(Constant::Bool(true))),
)],
Terminator::If {
condition: Operand::Copy(Place::Local(condition_local)),
then_block: BlockId(1),
else_block: BlockId(2),
},
);
let then_block = create_basic_block(
BlockId(1),
vec![],
Terminator::Return(Some(Operand::Constant(Constant::Int(1, Type::I32)))),
);
let else_block = create_basic_block(
BlockId(2),
vec![],
Terminator::Return(Some(Operand::Constant(Constant::Int(2, Type::I32)))),
);
func.blocks.push(entry_block);
func.blocks.push(then_block);
func.blocks.push(else_block);
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
assert_eq!(
func.blocks.len(),
3,
"All three blocks should be preserved (all reachable)"
);
assert_eq!(func.locals.len(), 1, "Condition local should be preserved");
}
#[test]
fn test_dce_is_place_live_handles_local() {
let mut dce = DeadCodeElimination::new();
let live_local = Local(0);
let dead_local = Local(1);
dce.live_locals.insert(live_local);
assert!(
dce.is_place_live(&Place::Local(live_local)),
"Live local should be detected"
);
assert!(
!dce.is_place_live(&Place::Local(dead_local)),
"Dead local should be detected"
);
}
#[test]
fn test_dce_is_place_live_handles_field() {
let mut dce = DeadCodeElimination::new();
let live_local = Local(0);
dce.live_locals.insert(live_local);
let field_place = Place::Field(Box::new(Place::Local(live_local)), FieldIdx(0));
assert!(
dce.is_place_live(&field_place),
"Field of live local should be live"
);
}
#[test]
fn test_dce_is_place_live_handles_deref() {
let mut dce = DeadCodeElimination::new();
let live_local = Local(0);
dce.live_locals.insert(live_local);
let deref_place = Place::Deref(Box::new(Place::Local(live_local)));
assert!(
dce.is_place_live(&deref_place),
"Deref of live local should be live"
);
}
#[test]
fn test_dce_is_place_live_handles_index() {
let mut dce = DeadCodeElimination::new();
let base_local = Local(0);
let index_local = Local(1);
dce.live_locals.insert(base_local);
dce.live_locals.insert(index_local);
let index_place = Place::Index(
Box::new(Place::Local(base_local)),
Box::new(Place::Local(index_local)),
);
assert!(
dce.is_place_live(&index_place),
"Index with both live should be live"
);
}
#[test]
fn test_const_prop_new_creates_empty_map() {
let const_prop = ConstantPropagation::new();
assert_eq!(
const_prop.constants.len(),
0,
"constants map should start empty"
);
}
#[test]
fn test_const_prop_propagates_integer_constant() {
let mut func = create_test_function("test", vec![], Type::I32);
let const_local = Local(0);
let result_local = Local(1);
func.locals
.push(create_local_decl(const_local, Type::I32, Some("const_val")));
func.locals
.push(create_local_decl(result_local, Type::I32, Some("result")));
let entry_block = create_basic_block(
BlockId(0),
vec![
Statement::Assign(
Place::Local(const_local),
Rvalue::Use(Operand::Constant(Constant::Int(42, Type::I32))),
),
Statement::Assign(
Place::Local(result_local),
Rvalue::Use(Operand::Copy(Place::Local(const_local))),
),
],
Terminator::Return(Some(Operand::Copy(Place::Local(result_local)))),
);
func.blocks.push(entry_block);
let mut const_prop = ConstantPropagation::new();
const_prop.run(&mut func);
if let Statement::Assign(_, rvalue) = &func.blocks[0].statements[1] {
if let Rvalue::Use(Operand::Constant(Constant::Int(val, _))) = rvalue {
assert_eq!(*val, 42, "Constant should be propagated");
} else {
panic!("Expected constant operand after propagation");
}
} else {
panic!("Expected Assign statement");
}
}
#[test]
fn test_const_prop_folds_binary_add() {
let mut func = create_test_function("test", vec![], Type::I32);
let const_a = Local(0);
let const_b = Local(1);
let sum_local = Local(2);
func.locals
.push(create_local_decl(const_a, Type::I32, None));
func.locals
.push(create_local_decl(const_b, Type::I32, None));
func.locals
.push(create_local_decl(sum_local, Type::I32, None));
let entry_block = create_basic_block(
BlockId(0),
vec![
Statement::Assign(
Place::Local(const_a),
Rvalue::Use(Operand::Constant(Constant::Int(2, Type::I32))),
),
Statement::Assign(
Place::Local(const_b),
Rvalue::Use(Operand::Constant(Constant::Int(3, Type::I32))),
),
Statement::Assign(
Place::Local(sum_local),
Rvalue::BinaryOp(
BinOp::Add,
Operand::Copy(Place::Local(const_a)),
Operand::Copy(Place::Local(const_b)),
),
),
],
Terminator::Return(Some(Operand::Copy(Place::Local(sum_local)))),
);
func.blocks.push(entry_block);
let mut const_prop = ConstantPropagation::new();
const_prop.run(&mut func);
if let Statement::Assign(_, rvalue) = &func.blocks[0].statements[2] {
if let Rvalue::BinaryOp(_, left, right) = rvalue {
assert!(
matches!(left, Operand::Constant(_)),
"Left operand should be constant"
);
assert!(
matches!(right, Operand::Constant(_)),
"Right operand should be constant"
);
}
}
}
#[test]
fn test_const_prop_folds_binary_sub() {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_binary_op(
BinOp::Sub,
&Operand::Constant(Constant::Int(5, Type::I32)),
&Operand::Constant(Constant::Int(2, Type::I32)),
);
assert!(result.is_some(), "Sub should be evaluated");
if let Some(Constant::Int(val, _)) = result {
assert_eq!(val, 3, "5 - 2 should equal 3");
}
}
#[test]
fn test_const_prop_folds_binary_mul() {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_binary_op(
BinOp::Mul,
&Operand::Constant(Constant::Int(3, Type::I32)),
&Operand::Constant(Constant::Int(4, Type::I32)),
);
assert!(result.is_some(), "Mul should be evaluated");
if let Some(Constant::Int(val, _)) = result {
assert_eq!(val, 12, "3 * 4 should equal 12");
}
}
#[test]
fn test_const_prop_folds_binary_eq() {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_binary_op(
BinOp::Eq,
&Operand::Constant(Constant::Int(5, Type::I32)),
&Operand::Constant(Constant::Int(5, Type::I32)),
);
assert!(result.is_some(), "Eq should be evaluated");
if let Some(Constant::Bool(val)) = result {
assert!(val, "5 == 5 should be true");
}
}
#[test]
fn test_const_prop_folds_binary_lt() {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_binary_op(
BinOp::Lt,
&Operand::Constant(Constant::Int(3, Type::I32)),
&Operand::Constant(Constant::Int(5, Type::I32)),
);
assert!(result.is_some(), "Lt should be evaluated");
if let Some(Constant::Bool(val)) = result {
assert!(val, "3 < 5 should be true");
}
}
#[test]
fn test_const_prop_folds_binary_and() {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_binary_op(
BinOp::And,
&Operand::Constant(Constant::Bool(true)),
&Operand::Constant(Constant::Bool(false)),
);
assert!(result.is_some(), "And should be evaluated");
if let Some(Constant::Bool(val)) = result {
assert!(!val, "true && false should be false");
}
}
#[test]
fn test_const_prop_folds_binary_or() {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_binary_op(
BinOp::Or,
&Operand::Constant(Constant::Bool(true)),
&Operand::Constant(Constant::Bool(false)),
);
assert!(result.is_some(), "Or should be evaluated");
if let Some(Constant::Bool(val)) = result {
assert!(val, "true || false should be true");
}
}
#[test]
fn test_const_prop_folds_unary_neg() {
let const_prop = ConstantPropagation::new();
let result =
const_prop.eval_unary_op(UnOp::Neg, &Operand::Constant(Constant::Int(5, Type::I32)));
assert!(result.is_some(), "Neg should be evaluated");
if let Some(Constant::Int(val, _)) = result {
assert_eq!(val, -5, "-5 should equal -5");
}
}
#[test]
fn test_const_prop_folds_unary_not() {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_unary_op(UnOp::Not, &Operand::Constant(Constant::Bool(true)));
assert!(result.is_some(), "Not should be evaluated");
if let Some(Constant::Bool(val)) = result {
assert!(!val, "!true should be false");
}
}
#[test]
fn test_const_prop_returns_none_for_non_constant() {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_binary_op(
BinOp::Add,
&Operand::Copy(Place::Local(Local(0))),
&Operand::Constant(Constant::Int(5, Type::I32)),
);
assert!(
result.is_none(),
"Should return None for non-constant operand"
);
}
#[test]
fn test_cse_new_creates_empty_map() {
let cse = CommonSubexpressionElimination::new();
assert_eq!(
cse.expressions.len(),
0,
"expressions map should start empty"
);
}
#[test]
fn test_cse_eliminates_duplicate_binary_op() {
let mut func = create_test_function("test", vec![], Type::I32);
let x = Local(0);
let y = Local(1);
let z = Local(2);
let w = Local(3);
func.locals.push(create_local_decl(x, Type::I32, Some("x")));
func.locals.push(create_local_decl(y, Type::I32, Some("y")));
func.locals.push(create_local_decl(z, Type::I32, Some("z")));
func.locals.push(create_local_decl(w, Type::I32, Some("w")));
let entry_block = create_basic_block(
BlockId(0),
vec![
Statement::Assign(
Place::Local(x),
Rvalue::Use(Operand::Constant(Constant::Int(2, Type::I32))),
),
Statement::Assign(
Place::Local(y),
Rvalue::Use(Operand::Constant(Constant::Int(3, Type::I32))),
),
Statement::Assign(
Place::Local(z),
Rvalue::BinaryOp(
BinOp::Add,
Operand::Copy(Place::Local(x)),
Operand::Copy(Place::Local(y)),
),
),
Statement::Assign(
Place::Local(w),
Rvalue::BinaryOp(
BinOp::Add,
Operand::Copy(Place::Local(x)),
Operand::Copy(Place::Local(y)),
),
),
],
Terminator::Return(Some(Operand::Copy(Place::Local(w)))),
);
func.blocks.push(entry_block);
let mut cse = CommonSubexpressionElimination::new();
cse.run(&mut func);
if let Statement::Assign(_, rvalue) = &func.blocks[0].statements[3] {
match rvalue {
Rvalue::Use(Operand::Copy(Place::Local(local))) => {
assert_eq!(
*local, z,
"Duplicate expression should reuse previous result"
);
}
_ => panic!("Expected duplicate expression to be replaced with Copy"),
}
}
}
#[test]
fn test_cse_generates_same_key_for_identical_expressions() {
let cse = CommonSubexpressionElimination::new();
let expr1 = Rvalue::BinaryOp(
BinOp::Add,
Operand::Copy(Place::Local(Local(0))),
Operand::Copy(Place::Local(Local(1))),
);
let expr2 = Rvalue::BinaryOp(
BinOp::Add,
Operand::Copy(Place::Local(Local(0))),
Operand::Copy(Place::Local(Local(1))),
);
let key1 = cse.rvalue_key(&expr1);
let key2 = cse.rvalue_key(&expr2);
assert_eq!(key1, key2, "Identical expressions should generate same key");
}
#[test]
fn test_cse_generates_different_keys_for_different_expressions() {
let cse = CommonSubexpressionElimination::new();
let expr1 = Rvalue::BinaryOp(
BinOp::Add,
Operand::Copy(Place::Local(Local(0))),
Operand::Copy(Place::Local(Local(1))),
);
let expr2 = Rvalue::BinaryOp(
BinOp::Sub,
Operand::Copy(Place::Local(Local(0))),
Operand::Copy(Place::Local(Local(1))),
);
let key1 = cse.rvalue_key(&expr1);
let key2 = cse.rvalue_key(&expr2);
assert_ne!(
key1, key2,
"Different expressions should generate different keys"
);
}
#[test]
fn test_optimize_function_runs_all_passes() {
let mut func = create_test_function("test", vec![], Type::I32);
let const_local = Local(0);
let unused_local = Local(1);
let result_local = Local(2);
func.locals
.push(create_local_decl(const_local, Type::I32, None));
func.locals
.push(create_local_decl(unused_local, Type::I32, None));
func.locals
.push(create_local_decl(result_local, Type::I32, None));
let entry_block = create_basic_block(
BlockId(0),
vec![
Statement::Assign(
Place::Local(const_local),
Rvalue::Use(Operand::Constant(Constant::Int(42, Type::I32))),
),
Statement::Assign(
Place::Local(unused_local),
Rvalue::Use(Operand::Constant(Constant::Int(99, Type::I32))),
),
Statement::Assign(
Place::Local(result_local),
Rvalue::Use(Operand::Copy(Place::Local(const_local))),
),
],
Terminator::Return(Some(Operand::Copy(Place::Local(result_local)))),
);
func.blocks.push(entry_block);
let original_stmt_count = func.blocks[0].statements.len();
optimize_function(&mut func);
assert!(
func.blocks[0].statements.len() <= original_stmt_count,
"Optimization should reduce or maintain statement count"
);
}
#[test]
fn test_optimize_function_handles_empty_function() {
let mut func = create_test_function("test", vec![], Type::Unit);
let entry_block = create_basic_block(BlockId(0), vec![], Terminator::Return(None));
func.blocks.push(entry_block);
optimize_function(&mut func);
assert_eq!(
func.blocks.len(),
1,
"Empty function should still have entry block"
);
}
#[test]
fn test_optimize_program_handles_multiple_functions() {
let mut program = Program {
functions: HashMap::new(),
entry: "main".to_string(),
};
let func1 = create_test_function("func1", vec![], Type::Unit);
let func2 = create_test_function("func2", vec![], Type::I32);
program.functions.insert("func1".to_string(), func1);
program.functions.insert("func2".to_string(), func2);
optimize_program(&mut program);
assert_eq!(
program.functions.len(),
2,
"All functions should be optimized"
);
}
}
#[cfg(test)]
mod property_tests_optimize {
use super::*;
use crate::middleend::mir::{BasicBlock, LocalDecl, Type};
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(10000))]
#[test]
fn prop_dce_preserves_entry_block(num_blocks in 1usize..10) {
let mut func = Function {
name: "test".to_string(),
params: vec![],
return_ty: Type::Unit,
locals: vec![],
blocks: vec![],
entry_block: BlockId(0),
};
for i in 0..num_blocks {
func.blocks.push(BasicBlock {
id: BlockId(i),
statements: vec![],
terminator: if i == num_blocks - 1 {
Terminator::Return(None)
} else {
Terminator::Goto(BlockId(i + 1))
},
});
}
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
prop_assert!(func.blocks.iter().any(|b| b.id == BlockId(0)));
}
#[test]
fn prop_dce_preserves_parameters(num_params in 0usize..5) {
let params: Vec<Local> = (0..num_params).map(Local).collect();
let mut func = Function {
name: "test".to_string(),
params: params.clone(),
return_ty: Type::Unit,
locals: params.iter().map(|&id| LocalDecl {
id,
ty: Type::I32,
mutable: false,
name: None,
}).collect(),
blocks: vec![BasicBlock {
id: BlockId(0),
statements: vec![],
terminator: Terminator::Return(None),
}],
entry_block: BlockId(0),
};
let mut dce = DeadCodeElimination::new();
dce.run(&mut func);
for param in ¶ms {
prop_assert!(func.locals.iter().any(|l| l.id == *param));
}
}
#[test]
fn prop_optimization_never_panics(num_locals in 0usize..5, num_blocks in 1usize..5) {
let mut func = Function {
name: "test".to_string(),
params: vec![],
return_ty: Type::Unit,
locals: (0..num_locals).map(|i| LocalDecl {
id: Local(i),
ty: Type::I32,
mutable: false,
name: None,
}).collect(),
blocks: (0..num_blocks).map(|i| BasicBlock {
id: BlockId(i),
statements: vec![],
terminator: if i == num_blocks - 1 {
Terminator::Return(None)
} else {
Terminator::Goto(BlockId(i + 1))
},
}).collect(),
entry_block: BlockId(0),
};
optimize_function(&mut func);
}
#[test]
fn prop_dce_idempotent(num_blocks in 1usize..10) {
let mut func1 = Function {
name: "test".to_string(),
params: vec![],
return_ty: Type::Unit,
locals: vec![],
blocks: (0..num_blocks).map(|i| BasicBlock {
id: BlockId(i),
statements: vec![],
terminator: if i == num_blocks - 1 {
Terminator::Return(None)
} else {
Terminator::Goto(BlockId(i + 1))
},
}).collect(),
entry_block: BlockId(0),
};
let mut func2 = func1.clone();
let mut dce = DeadCodeElimination::new();
dce.run(&mut func1);
let mut dce2 = DeadCodeElimination::new();
dce2.run(&mut func2);
dce2.run(&mut func2);
prop_assert_eq!(func1.blocks.len(), func2.blocks.len());
prop_assert_eq!(func1.locals.len(), func2.locals.len());
}
#[test]
fn prop_cse_preserves_blocks(num_blocks in 1usize..10) {
let mut func = Function {
name: "test".to_string(),
params: vec![],
return_ty: Type::Unit,
locals: vec![],
blocks: (0..num_blocks).map(|i| BasicBlock {
id: BlockId(i),
statements: vec![],
terminator: if i == num_blocks - 1 {
Terminator::Return(None)
} else {
Terminator::Goto(BlockId(i + 1))
},
}).collect(),
entry_block: BlockId(0),
};
let original_block_count = func.blocks.len();
let mut cse = CommonSubexpressionElimination::new();
cse.run(&mut func);
prop_assert_eq!(func.blocks.len(), original_block_count);
}
#[test]
fn prop_const_prop_no_new_locals(num_locals in 0usize..10) {
let mut func = Function {
name: "test".to_string(),
params: vec![],
return_ty: Type::Unit,
locals: (0..num_locals).map(|i| LocalDecl {
id: Local(i),
ty: Type::I32,
mutable: false,
name: None,
}).collect(),
blocks: vec![BasicBlock {
id: BlockId(0),
statements: vec![],
terminator: Terminator::Return(None),
}],
entry_block: BlockId(0),
};
let original_local_count = func.locals.len();
let mut const_prop = ConstantPropagation::new();
const_prop.run(&mut func);
prop_assert_eq!(func.locals.len(), original_local_count);
}
#[test]
fn prop_const_fold_add_correct(a in -100i128..100, b in -100i128..100) {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_binary_op(
BinOp::Add,
&Operand::Constant(Constant::Int(a, Type::I64)),
&Operand::Constant(Constant::Int(b, Type::I64)),
);
if let Some(Constant::Int(val, _)) = result {
prop_assert_eq!(val, a + b, "Addition should be correct");
}
}
#[test]
fn prop_const_fold_sub_correct(a in -100i128..100, b in -100i128..100) {
let const_prop = ConstantPropagation::new();
let result = const_prop.eval_binary_op(
BinOp::Sub,
&Operand::Constant(Constant::Int(a, Type::I64)),
&Operand::Constant(Constant::Int(b, Type::I64)),
);
if let Some(Constant::Int(val, _)) = result {
prop_assert_eq!(val, a - b, "Subtraction should be correct");
}
}
}
}