use std::{collections::HashMap, sync::Arc};
use crate::{
analysis::{PhiAnalyzer, SsaFunction, SsaType, SsaVarId, VariableOrigin},
compiler::{pass::SsaPass, passes::utils::resolve_chain, CompilerContext, EventKind, EventLog},
metadata::token::Token,
CilObject, Result,
};
/// Upper bound on propagation rounds per method; guards against non-convergence.
const MAX_ITERATIONS: usize = 100;

/// SSA pass that replaces uses of copy destinations (explicit copies and trivial phis)
/// with the copy's ultimate source.
#[derive(Debug, Clone, Copy)]
pub struct CopyPropagationPass;
impl Default for CopyPropagationPass {
fn default() -> Self {
Self::new()
}
}
impl CopyPropagationPass {
    /// Creates a new copy propagation pass.
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Collapses copy chains so every destination maps directly to its final source.
    ///
    /// For example `{v2 → v1, v1 → v0}` becomes `{v2 → v0, v1 → v0}`. Walking each
    /// chain (including any cycle handling) is delegated to [`resolve_chain`].
    fn resolve_chains(copies: &HashMap<SsaVarId, SsaVarId>) -> HashMap<SsaVarId, SsaVarId> {
        copies
            .iter()
            .map(|(&dest, &src)| (dest, resolve_chain(copies, src)))
            .collect()
    }

    /// Performs one round of copy propagation on `ssa`.
    ///
    /// Collects all copies via `PhiAnalyzer::collect_all_copies`, resolves them to their
    /// final sources, propagates local-variable types onto untyped sources, then rewrites
    /// every use of each copy destination to its source. A `CopyPropagated` event is
    /// recorded for every destination whose uses were rewritten.
    ///
    /// Returns the total number of uses replaced. A return of `0` means no uses were
    /// rewritten; variable types may still have been updated in that round.
    fn run_iteration(
        ssa: &mut SsaFunction,
        method_token: Token,
        changes: &mut EventLog,
        assembly: &CilObject,
    ) -> usize {
        let copies = PhiAnalyzer::new(ssa).collect_all_copies();
        if copies.is_empty() {
            return 0;
        }
        let resolved = Self::resolve_chains(&copies);
        Self::propagate_local_types(ssa, &resolved, assembly);

        let mut total_replaced = 0;
        for (&dest, &src) in &resolved {
            // Self-mappings (e.g. produced by a resolved cycle) have nothing to rewrite.
            if dest == src {
                continue;
            }
            let replaced = ssa.replace_uses(dest, src);
            if replaced > 0 {
                changes
                    .record(EventKind::CopyPropagated)
                    .method(method_token)
                    .message(format!("{dest} → {src} ({replaced} uses)"));
                total_replaced += replaced;
            }
        }
        total_replaced
    }

    /// Gives untyped copy sources the declared type of the local they are copied into.
    ///
    /// For each `dest → src` pair where `dest` originates from a method local, the local's
    /// declared signature type is converted to an `SsaType`. If that type is neither
    /// unknown nor `I32`, it is assigned to `src`, but only when `src` currently has an
    /// unknown type. Does nothing when the function carries no original local types.
    ///
    /// NOTE(review): if several destinations share one untyped source, the assignment
    /// applied last wins, and that depends on `HashMap` iteration order.
    fn propagate_local_types(
        ssa: &mut SsaFunction,
        resolved: &HashMap<SsaVarId, SsaVarId>,
        assembly: &CilObject,
    ) {
        // Gather assignments while `ssa` is only borrowed immutably. The local signature
        // table is read in place rather than cloned; this function runs once per round.
        // The borrow ends with this block, so the mutation loop below is allowed.
        let type_assignments: Vec<(SsaVarId, SsaType)> = {
            let Some(original_types) = ssa.original_local_types() else {
                return;
            };

            let mut assignments = Vec::new();
            for (&dest, &src) in resolved {
                if dest == src {
                    continue;
                }
                // Cheap check first: only sources without a known type benefit.
                let src_untyped = ssa
                    .variable(src)
                    .is_some_and(|src_var| src_var.var_type().is_unknown());
                if !src_untyped {
                    continue;
                }
                let Some(dest_var) = ssa.variable(dest) else {
                    continue;
                };
                let VariableOrigin::Local(local_idx) = dest_var.origin() else {
                    continue;
                };
                let Some(sig) = original_types.get(local_idx as usize) else {
                    continue;
                };
                let ssa_type = SsaType::from_type_signature(&sig.base, assembly);
                // Unknown and plain I32 types carry no useful refinement.
                if ssa_type.is_unknown() || matches!(ssa_type, SsaType::I32) {
                    continue;
                }
                assignments.push((src, ssa_type));
            }
            assignments
        };

        for (var_id, ssa_type) in type_assignments {
            if let Some(var) = ssa.variable_mut(var_id) {
                var.set_type(ssa_type);
            }
        }
    }
}
impl SsaPass for CopyPropagationPass {
    fn name(&self) -> &'static str {
        "copy-propagation"
    }

    fn description(&self) -> &'static str {
        "Propagates copy operations, replacing uses with original sources"
    }

    /// Repeats copy propagation on one method. Stops after a round that rewrites no uses,
    /// or after `MAX_ITERATIONS` rounds. Recorded events are then merged into `ctx.events`.
    ///
    /// Returns `Ok(true)` exactly when at least one event was recorded, i.e. at least one
    /// use was rewritten.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method_token: Token,
        ctx: &CompilerContext,
        assembly: &Arc<CilObject>,
    ) -> Result<bool> {
        let mut events = EventLog::new();
        let mut rounds = 0;
        while rounds < MAX_ITERATIONS
            && Self::run_iteration(ssa, method_token, &mut events, assembly) > 0
        {
            rounds += 1;
        }

        if events.is_empty() {
            return Ok(false);
        }
        ctx.events.merge(&events);
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};
    use crate::{
        analysis::{
            CallGraph, ConstValue, DefSite, PhiAnalyzer, PhiNode, PhiOperand, SsaBlock,
            SsaFunction, SsaFunctionBuilder, SsaInstruction, SsaOp, SsaVarId, SsaVariable,
            VariableOrigin,
        },
        compiler::CompilerContext,
        compiler::{CopyPropagationPass, SsaPass},
        metadata::token::Token,
        test::helpers::test_assembly_arc,
    };

    /// Builds a compiler context backed by an empty call graph.
    fn test_context() -> CompilerContext {
        let call_graph = Arc::new(CallGraph::new());
        CompilerContext::new(call_graph)
    }

    #[test]
    fn test_collect_empty_function() {
        let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| b.ret());
        });
        let copies = PhiAnalyzer::new(&ssa).collect_all_copies();
        assert!(copies.is_empty());
    }

    #[test]
    fn test_collect_single_copy() {
        let (ssa, v0, v1) = {
            let mut v0_out = SsaVarId::new();
            let mut v1_out = SsaVarId::new();
            let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    let v0 = b.const_i32(42);
                    let v1 = b.copy(v0);
                    v0_out = v0;
                    v1_out = v1;
                    b.ret();
                });
            });
            (ssa, v0_out, v1_out)
        };
        let copies = PhiAnalyzer::new(&ssa).collect_all_copies();
        assert_eq!(copies.len(), 1);
        assert_eq!(copies.get(&v1), Some(&v0));
    }

    #[test]
    fn test_collect_multiple_copies() {
        let (ssa, v0, v1, v2) = {
            let mut v0_out = SsaVarId::new();
            let mut v1_out = SsaVarId::new();
            let mut v2_out = SsaVarId::new();
            let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    let v0 = b.const_i32(42);
                    let v1 = b.copy(v0);
                    let v2 = b.copy(v1);
                    v0_out = v0;
                    v1_out = v1;
                    v2_out = v2;
                    b.ret();
                });
            });
            (ssa, v0_out, v1_out, v2_out)
        };
        let copies = PhiAnalyzer::new(&ssa).collect_all_copies();
        assert_eq!(copies.len(), 2);
        assert_eq!(copies.get(&v1), Some(&v0));
        assert_eq!(copies.get(&v2), Some(&v1));
    }

    #[test]
    fn test_collect_trivial_phi_all_same() {
        let (ssa, v0, v_phi) = {
            let mut v0_out = SsaVarId::new();
            let mut v_phi_out = SsaVarId::new();
            let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    let v0 = b.const_i32(42);
                    let cond = b.const_true();
                    v0_out = v0;
                    b.branch(cond, 1, 2);
                });
                f.block(1, |b| b.jump(3));
                f.block(2, |b| b.jump(3));
                f.block(3, |b| {
                    let phi_result = b.phi(&[(1, v0_out), (2, v0_out)]);
                    v_phi_out = phi_result;
                    b.ret_val(phi_result);
                });
            });
            (ssa, v0_out, v_phi_out)
        };
        let copies = PhiAnalyzer::new(&ssa).collect_all_copies();
        assert_eq!(copies.len(), 1);
        assert_eq!(copies.get(&v_phi), Some(&v0));
    }

    #[test]
    fn test_collect_trivial_phi_with_self_reference() {
        let mut ssa = SsaFunction::new(0, 0);
        let v0_var = SsaVariable::new(VariableOrigin::Stack(0), 0, DefSite::instruction(0, 0));
        let v0 = v0_var.id();
        ssa.add_variable(v0_var);
        let phi_variable = SsaVariable::new(VariableOrigin::Stack(1), 0, DefSite::phi(1));
        let phi_var = phi_variable.id();
        ssa.add_variable(phi_variable);

        let mut block0 = SsaBlock::new(0);
        block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const {
            dest: v0,
            value: ConstValue::I32(42),
        }));
        block0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 }));
        ssa.add_block(block0);

        let mut block1 = SsaBlock::new(1);
        let mut phi = PhiNode::new(phi_var, VariableOrigin::Stack(1));
        phi.add_operand(PhiOperand::new(v0, 0));
        phi.add_operand(PhiOperand::new(phi_var, 1));
        block1.add_phi(phi);
        block1.add_instruction(SsaInstruction::synthetic(SsaOp::Return {
            value: Some(phi_var),
        }));
        ssa.add_block(block1);

        let copies = PhiAnalyzer::new(&ssa).collect_all_copies();
        assert_eq!(copies.len(), 1);
        assert_eq!(copies.get(&phi_var), Some(&v0));
    }

    #[test]
    fn test_collect_non_trivial_phi() {
        let ssa = {
            let mut v0_out = SsaVarId::new();
            let mut v1_out = SsaVarId::new();
            SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    let cond = b.const_true();
                    b.branch(cond, 1, 2);
                });
                f.block(1, |b| {
                    let v0 = b.const_i32(10);
                    v0_out = v0;
                    b.jump(3);
                });
                f.block(2, |b| {
                    let v1 = b.const_i32(20);
                    v1_out = v1;
                    b.jump(3);
                });
                f.block(3, |b| {
                    let phi_result = b.phi(&[(1, v0_out), (2, v1_out)]);
                    b.ret_val(phi_result);
                });
            })
        };
        let copies = PhiAnalyzer::new(&ssa).collect_all_copies();
        assert!(copies.is_empty());
    }

    #[test]
    fn test_resolve_simple_chain() {
        let v0 = SsaVarId::new();
        let v1 = SsaVarId::new();
        let v2 = SsaVarId::new();
        let mut copies = HashMap::new();
        copies.insert(v2, v1);
        copies.insert(v1, v0);
        let resolved = CopyPropagationPass::resolve_chains(&copies);
        assert_eq!(resolved.get(&v1), Some(&v0));
        assert_eq!(resolved.get(&v2), Some(&v0));
    }

    #[test]
    fn test_resolve_long_chain() {
        let v0 = SsaVarId::new();
        let v1 = SsaVarId::new();
        let v2 = SsaVarId::new();
        let v3 = SsaVarId::new();
        let v4 = SsaVarId::new();
        let mut copies = HashMap::new();
        copies.insert(v4, v3);
        copies.insert(v3, v2);
        copies.insert(v2, v1);
        copies.insert(v1, v0);
        let resolved = CopyPropagationPass::resolve_chains(&copies);
        assert_eq!(resolved.get(&v1), Some(&v0));
        assert_eq!(resolved.get(&v2), Some(&v0));
        assert_eq!(resolved.get(&v3), Some(&v0));
        assert_eq!(resolved.get(&v4), Some(&v0));
    }

    #[test]
    fn test_resolve_cycle() {
        let v1 = SsaVarId::new();
        let v2 = SsaVarId::new();
        let mut copies = HashMap::new();
        copies.insert(v1, v2);
        copies.insert(v2, v1);
        let resolved = CopyPropagationPass::resolve_chains(&copies);
        assert!(resolved.contains_key(&v1));
        assert!(resolved.contains_key(&v2));
    }

    #[test]
    fn test_resolve_multiple_independent_chains() {
        let v0 = SsaVarId::new();
        let v1 = SsaVarId::new();
        let v2 = SsaVarId::new();
        let v3 = SsaVarId::new();
        let v4 = SsaVarId::new();
        let v5 = SsaVarId::new();
        let mut copies = HashMap::new();
        copies.insert(v2, v1);
        copies.insert(v1, v0);
        copies.insert(v5, v4);
        copies.insert(v4, v3);
        let resolved = CopyPropagationPass::resolve_chains(&copies);
        assert_eq!(resolved.get(&v1), Some(&v0));
        assert_eq!(resolved.get(&v2), Some(&v0));
        assert_eq!(resolved.get(&v4), Some(&v3));
        assert_eq!(resolved.get(&v5), Some(&v3));
    }

    #[test]
    fn test_trivial_phi_single_operand() {
        let ssa = SsaFunction::new(0, 0);
        let analyzer = PhiAnalyzer::new(&ssa);
        let result = SsaVarId::new();
        let v0 = SsaVarId::new();
        let mut phi = PhiNode::new(result, VariableOrigin::Local(0));
        phi.add_operand(PhiOperand::new(v0, 0));
        let source = analyzer.is_trivial(&phi);
        assert_eq!(source, Some(v0));
    }

    #[test]
    fn test_trivial_phi_all_same_operands() {
        let ssa = SsaFunction::new(0, 0);
        let analyzer = PhiAnalyzer::new(&ssa);
        let result = SsaVarId::new();
        let v0 = SsaVarId::new();
        let mut phi = PhiNode::new(result, VariableOrigin::Local(0));
        phi.add_operand(PhiOperand::new(v0, 0));
        phi.add_operand(PhiOperand::new(v0, 1));
        phi.add_operand(PhiOperand::new(v0, 2));
        let source = analyzer.is_trivial(&phi);
        assert_eq!(source, Some(v0));
    }

    #[test]
    fn test_trivial_phi_with_self_references() {
        let ssa = SsaFunction::new(0, 0);
        let analyzer = PhiAnalyzer::new(&ssa);
        let v0 = SsaVarId::new();
        let v1 = SsaVarId::new();
        let mut phi = PhiNode::new(v1, VariableOrigin::Local(0));
        phi.add_operand(PhiOperand::new(v0, 0));
        phi.add_operand(PhiOperand::new(v1, 1));
        phi.add_operand(PhiOperand::new(v1, 2));
        let source = analyzer.is_trivial(&phi);
        assert_eq!(source, Some(v0));
    }

    #[test]
    fn test_non_trivial_phi_different_operands() {
        let ssa = SsaFunction::new(0, 0);
        let analyzer = PhiAnalyzer::new(&ssa);
        let v0 = SsaVarId::new();
        let v1 = SsaVarId::new();
        let v5 = SsaVarId::new();
        let mut phi = PhiNode::new(v5, VariableOrigin::Local(0));
        phi.add_operand(PhiOperand::new(v0, 0));
        phi.add_operand(PhiOperand::new(v1, 1));
        let source = analyzer.is_trivial(&phi);
        assert_eq!(source, None);
    }

    #[test]
    fn test_trivial_phi_all_self_references() {
        let ssa = SsaFunction::new(0, 0);
        let analyzer = PhiAnalyzer::new(&ssa);
        let v1 = SsaVarId::new();
        let mut phi = PhiNode::new(v1, VariableOrigin::Local(0));
        phi.add_operand(PhiOperand::new(v1, 0));
        phi.add_operand(PhiOperand::new(v1, 1));
        let source = analyzer.is_trivial(&phi);
        assert_eq!(source, None);
    }

    #[test]
    fn test_propagate_single_copy() {
        let (mut ssa, v0, _v1) = {
            let mut v0_out = SsaVarId::new();
            let mut v1_out = SsaVarId::new();
            let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    let v0 = b.const_i32(42);
                    let v1 = b.copy(v0);
                    let _v2 = b.add(v1, v1);
                    v0_out = v0;
                    v1_out = v1;
                    b.ret_val(v1);
                });
            });
            (ssa, v0_out, v1_out)
        };
        let pass = CopyPropagationPass::new();
        let ctx = test_context();
        let changed = pass
            .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc())
            .unwrap();
        assert!(changed);
        let block = ssa.block(0).unwrap();
        let add_instr = &block.instructions()[2];
        if let SsaOp::Add { left, right, .. } = add_instr.op() {
            assert_eq!(*left, v0);
            assert_eq!(*right, v0);
        } else {
            panic!("Expected Add instruction");
        }
        let ret_instr = &block.instructions()[3];
        if let SsaOp::Return { value } = ret_instr.op() {
            assert_eq!(*value, Some(v0));
        } else {
            panic!("Expected Return instruction");
        }
    }

    #[test]
    fn test_propagate_copy_chain() {
        let (mut ssa, v0) = {
            let mut v0_out = SsaVarId::new();
            let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    let v0 = b.const_i32(42);
                    let v1 = b.copy(v0);
                    let v2 = b.copy(v1);
                    let v3 = b.copy(v2);
                    v0_out = v0;
                    b.ret_val(v3);
                });
            });
            (ssa, v0_out)
        };
        let pass = CopyPropagationPass::new();
        let ctx = test_context();
        let changed = pass
            .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc())
            .unwrap();
        assert!(changed);
        let block = ssa.block(0).unwrap();
        if let Some(SsaOp::Return { value }) = block.terminator_op() {
            assert_eq!(*value, Some(v0));
        } else {
            panic!("Expected Return instruction");
        }
    }

    #[test]
    fn test_propagate_trivial_phi() {
        let (mut ssa, v0) = {
            let mut v0_out = SsaVarId::new();
            let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    v0_out = b.const_i32(42);
                    let cond = b.const_true();
                    b.branch(cond, 1, 2);
                });
                f.block(1, |b| b.jump(3));
                f.block(2, |b| b.jump(3));
                f.block(3, |b| {
                    let phi_result = b.phi(&[(1, v0_out), (2, v0_out)]);
                    let _ = b.add(phi_result, phi_result);
                    b.ret_val(phi_result);
                });
            });
            (ssa, v0_out)
        };
        let pass = CopyPropagationPass::new();
        let ctx = test_context();
        let changed = pass
            .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc())
            .unwrap();
        assert!(changed);
        let block3 = ssa.block(3).unwrap();
        let add_instr = &block3.instructions()[0];
        if let SsaOp::Add { left, right, .. } = add_instr.op() {
            assert_eq!(*left, v0);
            assert_eq!(*right, v0);
        } else {
            panic!("Expected Add instruction");
        }
    }

    #[test]
    fn test_no_propagation_needed() {
        let mut ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| {
                let v0 = b.const_i32(42);
                b.ret_val(v0);
            });
        });
        let pass = CopyPropagationPass::new();
        let ctx = test_context();
        let changed = pass
            .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc())
            .unwrap();
        assert!(!changed);
    }

    #[test]
    fn test_iterative_convergence() {
        let (mut ssa, v0) = {
            let mut v0_out = SsaVarId::new();
            let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    let v0 = b.const_i32(42);
                    v0_out = v0;
                    let v1 = b.copy(v0);
                    let v2 = b.copy(v1);
                    let v3 = b.copy(v2);
                    let v10 = b.add(v1, v2);
                    let v11 = b.add(v10, v3);
                    b.ret_val(v11);
                });
            });
            (ssa, v0_out)
        };
        let pass = CopyPropagationPass::new();
        let ctx = test_context();
        let changed = pass
            .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc())
            .unwrap();
        assert!(changed);
        let block = ssa.block(0).unwrap();
        let add1 = &block.instructions()[4];
        if let SsaOp::Add { left, right, .. } = add1.op() {
            assert_eq!(*left, v0);
            assert_eq!(*right, v0);
        } else {
            panic!("Expected first Add instruction");
        }
        let add2 = &block.instructions()[5];
        if let SsaOp::Add { right, .. } = add2.op() {
            assert_eq!(*right, v0);
        } else {
            panic!("Expected second Add instruction");
        }
    }

    #[test]
    fn test_copy_not_propagated_to_definition() {
        let (mut ssa, v0, v1) = {
            let mut v0_out = SsaVarId::new();
            let mut v1_out = SsaVarId::new();
            let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    let v0 = b.const_i32(42);
                    v0_out = v0;
                    let v1 = b.copy(v0);
                    v1_out = v1;
                    b.ret_val(v1);
                });
            });
            (ssa, v0_out, v1_out)
        };
        let pass = CopyPropagationPass::new();
        let ctx = test_context();
        let _changes = pass
            .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc())
            .unwrap();
        let block = ssa.block(0).unwrap();
        let copy_instr = &block.instructions()[1];
        if let SsaOp::Copy { dest, src } = copy_instr.op() {
            assert_eq!(*dest, v1);
            assert_eq!(*src, v0);
        } else {
            panic!("Expected Copy instruction");
        }
    }

    #[test]
    fn test_phi_operands_preserved() {
        let (mut ssa, _v0, v1, v2) = {
            let mut v0_out = SsaVarId::new();
            let mut v1_out = SsaVarId::new();
            let mut v2_out = SsaVarId::new();
            let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
                f.block(0, |b| {
                    let v0 = b.const_i32(42);
                    v0_out = v0;
                    let v1 = b.copy(v0);
                    v1_out = v1;
                    let cond = b.const_true();
                    b.branch(cond, 1, 2);
                });
                f.block(1, |b| {
                    v2_out = b.const_i32(100);
                    b.jump(3);
                });
                f.block(2, |b| b.jump(3));
                f.block(3, |b| {
                    let phi_result = b.phi(&[(1, v2_out), (2, v1_out)]);
                    b.ret_val(phi_result);
                });
            });
            (ssa, v0_out, v1_out, v2_out)
        };
        let pass = CopyPropagationPass::new();
        let ctx = test_context();
        let _changes = pass
            .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc())
            .unwrap();
        let block3 = ssa.block(3).unwrap();
        let phi = &block3.phi_nodes()[0];
        let operand_values: Vec<_> = phi.operands().iter().map(|op| op.value()).collect();
        assert!(operand_values.contains(&v2));
        assert!(operand_values.contains(&v1));
    }
}