use super::{Transform, TransformCategory, TransformLevel};
use crate::mir::{Block, Function, Immediate, Instruction, Operand, Register};
use std::collections::HashMap;
/// Block-local copy propagation pass over MIR functions.
///
/// Within each basic block, the source of an integer identity operation
/// (e.g. `x + 0`, `x * 1`) is forwarded to later uses of its destination.
#[derive(Debug, Default, Clone, Copy)]
pub struct CopyPropagation;
impl Transform for CopyPropagation {
    /// Runs the pass on `func`; the work is done by `CopyPropagation::apply_internal`.
    fn apply(&self, func: &mut Function) -> Result<bool, String> {
        Self::apply_internal(self, func)
    }

    /// The pass name.
    fn name(&self) -> &'static str {
        "copy_propagation"
    }

    /// One-line description of the pass.
    fn description(&self) -> &'static str {
        "Replaces variable uses with their source values"
    }

    /// This pass belongs to the copy-propagation category.
    fn category(&self) -> TransformCategory {
        TransformCategory::CopyPropagation
    }

    /// This pass is reported at the stable level.
    fn level(&self) -> TransformLevel {
        TransformLevel::Stable
    }
}
impl CopyPropagation {
fn apply_internal(&self, func: &mut Function) -> Result<bool, String> {
const MAX_BLOCKS: usize = 500;
const MAX_INSTRUCTIONS_PER_BLOCK: usize = 1_000;
if func.blocks.len() > MAX_BLOCKS {
return Err(format!(
"Function too large for copy propagation ({} blocks, max {})",
func.blocks.len(),
MAX_BLOCKS
));
}
for block in &func.blocks {
if block.instructions.len() > MAX_INSTRUCTIONS_PER_BLOCK {
return Err(format!(
"Block '{}' too large for copy propagation ({} instructions, max {})",
block.label,
block.instructions.len(),
MAX_INSTRUCTIONS_PER_BLOCK
));
}
}
let mut changed = false;
for block in &mut func.blocks {
if self.propagate_copies_in_block(block) {
changed = true;
}
}
Ok(changed)
}
fn propagate_copies_in_block(&self, block: &mut Block) -> bool {
let mut changed = false;
let mut value_map = HashMap::new();
let mut propagations_this_block = 0;
let max_propagations_per_block = 50;
let mut new_instructions = Vec::new();
for instr in &block.instructions {
let mut new_instr = instr.clone();
if propagations_this_block < max_propagations_per_block
&& self.propagate_copies(&mut new_instr, &value_map)
{
changed = true;
propagations_this_block += 1;
}
if let Some(def_reg) = new_instr.def_reg() {
value_map.remove(def_reg);
value_map.retain(|_, v| !matches!(v, Operand::Register(r) if r == def_reg));
}
if let Instruction::IntBinary {
op: crate::mir::IntBinOp::Add,
dst,
lhs,
rhs,
..
} = &new_instr
&& let Operand::Immediate(Immediate::I64(0)) = rhs
&& let Operand::Register(src_reg) = lhs
{
value_map.insert(dst.clone(), Operand::Register(src_reg.clone()));
}
match &new_instr {
Instruction::IntBinary {
op: crate::mir::IntBinOp::Sub,
dst,
lhs,
rhs,
..
} => {
if let Operand::Immediate(Immediate::I64(0)) = rhs
&& let Operand::Register(src_reg) = lhs
{
value_map.insert(dst.clone(), Operand::Register(src_reg.clone()));
}
}
Instruction::IntBinary {
op: crate::mir::IntBinOp::Mul,
dst,
lhs,
rhs,
..
} => {
if let Operand::Immediate(Immediate::I64(1)) = rhs
&& let Operand::Register(src_reg) = lhs
{
value_map.insert(dst.clone(), Operand::Register(src_reg.clone()));
}
if let Operand::Immediate(Immediate::I64(1)) = lhs
&& let Operand::Register(src_reg) = rhs
{
value_map.insert(dst.clone(), Operand::Register(src_reg.clone()));
}
}
Instruction::IntBinary {
op: crate::mir::IntBinOp::Or,
dst,
lhs,
rhs,
..
} => {
if let Operand::Immediate(Immediate::I64(0)) = rhs
&& let Operand::Register(src_reg) = lhs
{
value_map.insert(dst.clone(), Operand::Register(src_reg.clone()));
}
if let Operand::Immediate(Immediate::I64(0)) = lhs
&& let Operand::Register(src_reg) = rhs
{
value_map.insert(dst.clone(), Operand::Register(src_reg.clone()));
}
}
Instruction::IntBinary {
op: crate::mir::IntBinOp::And,
dst,
lhs,
rhs,
..
} => {
if let Operand::Immediate(Immediate::I64(-1)) = rhs
&& let Operand::Register(src_reg) = lhs
{
value_map.insert(dst.clone(), Operand::Register(src_reg.clone()));
}
}
Instruction::IntBinary {
op: crate::mir::IntBinOp::Xor,
dst,
lhs,
rhs,
..
} => {
if let Operand::Immediate(Immediate::I64(0)) = rhs
&& let Operand::Register(src_reg) = lhs
{
value_map.insert(dst.clone(), Operand::Register(src_reg.clone()));
}
}
_ => {}
}
new_instructions.push(new_instr);
}
block.instructions = new_instructions;
changed
}
fn propagate_copies(
&self,
instr: &mut Instruction,
value_map: &HashMap<Register, Operand>,
) -> bool {
let mut changed = false;
match instr {
Instruction::IntBinary { lhs, rhs, .. }
| Instruction::FloatBinary { lhs, rhs, .. }
| Instruction::IntCmp { lhs, rhs, .. }
| Instruction::FloatCmp { lhs, rhs, .. } => {
changed |= self.replace_operand(lhs, value_map);
changed |= self.replace_operand(rhs, value_map);
}
Instruction::FloatUnary { src, .. } => {
changed |= self.replace_operand(src, value_map);
}
Instruction::Select {
cond: _,
true_val,
false_val,
..
} => {
changed |= self.replace_operand(true_val, value_map);
changed |= self.replace_operand(false_val, value_map);
}
Instruction::Load { addr, .. } => {
if let crate::mir::AddressMode::BaseOffset { base, offset: _ } = addr
&& let Some(new_base) = value_map.get(base)
&& let Operand::Register(new_reg) = new_base
{
*base = new_reg.clone();
changed = true;
}
}
Instruction::Store { src, addr, .. } => {
changed |= self.replace_operand(src, value_map);
if let crate::mir::AddressMode::BaseOffset { base, offset: _ } = addr
&& let Some(new_base) = value_map.get(base)
&& let Operand::Register(new_reg) = new_base
{
*base = new_reg.clone();
changed = true;
}
}
Instruction::Call { args, .. } => {
for arg in args {
changed |= self.replace_operand(arg, value_map);
}
}
Instruction::Br { cond, .. } => {
if let Some(new_cond) = value_map.get(cond)
&& let Operand::Register(new_reg) = new_cond
{
*cond = new_reg.clone();
changed = true;
}
}
Instruction::Switch { value, .. } => {
if let Some(new_value) = value_map.get(value)
&& let Operand::Register(new_reg) = new_value
{
*value = new_reg.clone();
changed = true;
}
}
Instruction::Ret { value: Some(val) } => {
changed |= self.replace_operand(val, value_map);
}
_ => {}
}
changed
}
fn replace_operand(
&self,
operand: &mut Operand,
value_map: &HashMap<Register, Operand>,
) -> bool {
if let Operand::Register(reg) = operand
&& let Some(replacement) = value_map.get(reg)
{
*operand = replacement.clone();
return true;
}
false
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::mir::{
        FunctionBuilder, Immediate, IntBinOp, MirType, Operand, ScalarType, VirtualReg,
    };

    /// The scalar type every test instruction operates on.
    fn i64_ty() -> MirType {
        MirType::Scalar(ScalarType::I64)
    }

    /// Builds a register operand.
    fn reg(r: impl Into<Register>) -> Operand {
        Operand::Register(r.into())
    }

    /// Builds a 64-bit immediate operand.
    fn imm(value: i64) -> Operand {
        Operand::Immediate(Immediate::I64(value))
    }

    /// Builds the `i64` instruction `dst = lhs <op> rhs`.
    fn int_op(op: IntBinOp, dst: impl Into<Register>, lhs: Operand, rhs: Operand) -> Instruction {
        Instruction::IntBinary {
            op,
            ty: i64_ty(),
            dst: dst.into(),
            lhs,
            rhs,
        }
    }

    /// Returns the `lhs` operand of an `IntBinary` instruction.
    fn lhs_of(instr: &Instruction) -> &Operand {
        match instr {
            Instruction::IntBinary { lhs, .. } => lhs,
            _ => panic!("Expected IntBinary instruction"),
        }
    }

    /// Returns the operand of the block's final `ret <value>`.
    fn returned(block: &Block) -> &Operand {
        match block.instructions.last() {
            Some(Instruction::Ret { value: Some(value) }) => value,
            _ => panic!("Expected `ret <value>` terminator"),
        }
    }

    #[test]
    fn test_copy_propagation_basic() {
        let mut func = FunctionBuilder::new("test")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(1), reg(VirtualReg::gpr(0)), imm(0)))
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(2), reg(VirtualReg::gpr(1)), imm(0)))
            .instr(Instruction::Ret {
                value: Some(reg(VirtualReg::gpr(2))),
            })
            .build();
        let cp = CopyPropagation;
        let changed = cp
            .apply(&mut func)
            .expect("Copy propagation should succeed");
        assert!(changed);

        let entry = func.get_block("entry").expect("entry block exists");
        assert_eq!(entry.instructions.len(), 3);
        match &entry.instructions[1] {
            Instruction::IntBinary { dst, lhs, rhs, .. } => {
                assert_eq!(dst, &VirtualReg::gpr(2).into());
                assert_eq!(lhs, &reg(VirtualReg::gpr(0)));
                assert_eq!(rhs, &imm(0));
            }
            _ => panic!("Expected IntBinary instruction"),
        }
        // The returned copy is forwarded as well.
        assert_eq!(returned(entry), &reg(VirtualReg::gpr(0)));
    }

    #[test]
    fn test_copy_propagation_register_redefinition() {
        let mut func = FunctionBuilder::new("test")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(1), reg(VirtualReg::gpr(0)), imm(0)))
            .instr(int_op(
                IntBinOp::Add,
                VirtualReg::gpr(1),
                reg(VirtualReg::gpr(3)),
                reg(VirtualReg::gpr(4)),
            ))
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(2), reg(VirtualReg::gpr(1)), imm(0)))
            .instr(Instruction::Ret {
                value: Some(reg(VirtualReg::gpr(2))),
            })
            .build();
        let cp = CopyPropagation;
        let changed = cp
            .apply(&mut func)
            .expect("Copy propagation should succeed");
        assert!(changed);

        let entry = func.get_block("entry").expect("entry block exists");
        assert_eq!(entry.instructions.len(), 4);
        // The redefinition of r1 kills the stale `r1 -> r0` copy.
        match &entry.instructions[2] {
            Instruction::IntBinary { dst, lhs, rhs, .. } => {
                assert_eq!(dst, &VirtualReg::gpr(2).into());
                assert_eq!(lhs, &reg(VirtualReg::gpr(1)));
                assert_eq!(rhs, &imm(0));
            }
            _ => panic!("Expected IntBinary instruction"),
        }
        assert_eq!(returned(entry), &reg(VirtualReg::gpr(1)));
    }

    #[test]
    fn test_copy_propagation_empty_function() {
        let mut func = FunctionBuilder::new("empty")
            .returns(i64_ty())
            .block("entry")
            .instr(Instruction::Ret { value: None })
            .build();
        let cp = CopyPropagation;
        let changed = cp.apply(&mut func).expect("should succeed");
        assert!(!changed);
    }

    #[test]
    fn test_copy_propagation_no_infinite_loop() {
        // r1 = r0; r2 = r1; r0 = r2 — a copy cycle that flows back into r0.
        let mut func = FunctionBuilder::new("circular")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(1), reg(VirtualReg::gpr(0)), imm(0)))
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(2), reg(VirtualReg::gpr(1)), imm(0)))
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(0), reg(VirtualReg::gpr(2)), imm(0)))
            .instr(Instruction::Ret {
                value: Some(reg(VirtualReg::gpr(0))),
            })
            .build();
        let cp = CopyPropagation;
        for _ in 0..10 {
            cp.apply(&mut func).expect("should succeed");
        }

        let entry = func.get_block("entry").unwrap();
        assert_eq!(entry.instructions.len(), 4);
        // The chain collapses onto the original r0, and r0 is still what is returned.
        assert_eq!(lhs_of(&entry.instructions[1]), &reg(VirtualReg::gpr(0)));
        assert_eq!(lhs_of(&entry.instructions[2]), &reg(VirtualReg::gpr(0)));
        assert_eq!(returned(entry), &reg(VirtualReg::gpr(0)));
    }

    #[test]
    fn test_copy_propagation_mul_by_one() {
        let mut func = FunctionBuilder::new("mul_one")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Mul, VirtualReg::gpr(1), reg(VirtualReg::gpr(0)), imm(1)))
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(2), reg(VirtualReg::gpr(1)), imm(5)))
            .instr(Instruction::Ret {
                value: Some(reg(VirtualReg::gpr(2))),
            })
            .build();
        let cp = CopyPropagation;
        let changed = cp.apply(&mut func).expect("should succeed");
        assert!(changed);

        let entry = func.get_block("entry").unwrap();
        assert_eq!(lhs_of(&entry.instructions[1]), &reg(VirtualReg::gpr(0)));
    }

    #[test]
    fn test_copy_propagation_loop_counter_pattern() {
        let mut func = FunctionBuilder::new("loop_counter")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(0), imm(0), imm(0)))
            .instr(Instruction::Jmp {
                target: "loop".to_string(),
            })
            .block("loop")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(0), reg(VirtualReg::gpr(0)), imm(1)))
            .instr(Instruction::Ret {
                value: Some(reg(VirtualReg::gpr(0))),
            })
            .build();
        let cp = CopyPropagation;
        let changed = cp.apply(&mut func).expect("should succeed");
        // Neither block contains a register-to-register identity copy.
        assert!(!changed);

        let loop_block = func.get_block("loop").unwrap();
        assert_eq!(lhs_of(&loop_block.instructions[0]), &reg(VirtualReg::gpr(0)));
        assert_eq!(returned(loop_block), &reg(VirtualReg::gpr(0)));
    }

    #[test]
    fn test_copy_propagation_block_size_pattern() {
        let mut func = FunctionBuilder::new("block_size")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(0), imm(8), imm(0)))
            .instr(int_op(
                IntBinOp::Add,
                VirtualReg::gpr(1),
                reg(VirtualReg::gpr(2)),
                reg(VirtualReg::gpr(0)),
            ))
            .instr(Instruction::Ret {
                value: Some(reg(VirtualReg::gpr(1))),
            })
            .build();
        let cp = CopyPropagation;
        let changed = cp.apply(&mut func).expect("should succeed");
        // `r0 = 8 + 0` is a constant, not a register copy, so nothing is forwarded.
        assert!(!changed);

        let entry = func.get_block("entry").unwrap();
        assert_eq!(lhs_of(&entry.instructions[1]), &reg(VirtualReg::gpr(2)));
        assert_eq!(returned(entry), &reg(VirtualReg::gpr(1)));
    }

    #[test]
    fn test_copy_propagation_chained_adds_zero() {
        let mut func = FunctionBuilder::new("chained_zero")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(1), reg(VirtualReg::gpr(0)), imm(0)))
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(2), reg(VirtualReg::gpr(1)), imm(0)))
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(3), reg(VirtualReg::gpr(2)), imm(0)))
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(4), reg(VirtualReg::gpr(3)), imm(100)))
            .instr(Instruction::Ret {
                value: Some(reg(VirtualReg::gpr(4))),
            })
            .build();
        let cp = CopyPropagation;
        let changed = cp.apply(&mut func).expect("should succeed");
        assert!(changed);

        let entry = func.get_block("entry").unwrap();
        assert_eq!(lhs_of(&entry.instructions[3]), &reg(VirtualReg::gpr(0)));
    }

    #[test]
    fn test_copy_propagation_same_reg_redefined_in_block() {
        let mut func = FunctionBuilder::new("redefine_same")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(0), imm(0), imm(0)))
            .instr(int_op(
                IntBinOp::Mul,
                VirtualReg::gpr(0),
                reg(VirtualReg::gpr(1)),
                reg(VirtualReg::gpr(2)),
            ))
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(0), reg(VirtualReg::gpr(0)), imm(1)))
            .instr(Instruction::Ret {
                value: Some(reg(VirtualReg::gpr(0))),
            })
            .build();
        let cp = CopyPropagation;
        let changed = cp.apply(&mut func).expect("should succeed");
        assert!(!changed);

        let entry = func.get_block("entry").unwrap();
        assert_eq!(lhs_of(&entry.instructions[2]), &reg(VirtualReg::gpr(0)));
        assert_eq!(returned(entry), &reg(VirtualReg::gpr(0)));
    }

    #[test]
    fn test_copy_propagation_unroll_pattern() {
        let mut func = FunctionBuilder::new("unroll_pattern")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(0), reg(VirtualReg::gpr(10)), imm(0)))
            .build();
        let entry = func.blocks.get_mut(0).unwrap();
        for i in 1..=16 {
            entry.push(int_op(
                IntBinOp::Add,
                VirtualReg::gpr(i),
                reg(VirtualReg::gpr(0)),
                imm(i as i64),
            ));
        }
        let mut sum_reg = VirtualReg::gpr(1);
        for i in 2..=16 {
            entry.push(int_op(
                IntBinOp::Add,
                VirtualReg::gpr(50 + i),
                reg(sum_reg),
                reg(VirtualReg::gpr(i)),
            ));
            sum_reg = VirtualReg::gpr(50 + i);
        }
        entry.push(Instruction::Ret {
            value: Some(reg(sum_reg)),
        });

        let cp = CopyPropagation;
        let changed = cp.apply(&mut func).expect("should succeed");
        assert!(changed);

        // Every `ri = r0 + i` now reads r10 directly.
        let entry = func.get_block("entry").unwrap();
        for instr in &entry.instructions[1..=16] {
            assert_eq!(lhs_of(instr), &reg(VirtualReg::gpr(10)));
        }
    }

    #[test]
    fn test_copy_propagation_accumulator_pattern() {
        let mut func = FunctionBuilder::new("accumulator")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(0), imm(0), imm(0)))
            .build();
        let entry = func.blocks.get_mut(0).unwrap();
        for i in 1..=8 {
            entry.push(int_op(
                IntBinOp::Add,
                VirtualReg::gpr(0),
                reg(VirtualReg::gpr(0)),
                reg(VirtualReg::gpr(100 + i)),
            ));
        }
        entry.push(Instruction::Ret {
            value: Some(reg(VirtualReg::gpr(0))),
        });

        let cp = CopyPropagation;
        let changed = cp.apply(&mut func).expect("should succeed");
        assert!(!changed);
    }

    #[test]
    fn test_copy_propagation_many_blocks_with_copies() {
        let mut func = FunctionBuilder::new("multi_block_copies")
            .returns(i64_ty())
            .block("entry")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(1), reg(VirtualReg::gpr(0)), imm(0)))
            .instr(Instruction::Jmp {
                target: "block1".to_string(),
            })
            .block("block1")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(2), reg(VirtualReg::gpr(1)), imm(0)))
            .instr(Instruction::Jmp {
                target: "block2".to_string(),
            })
            .block("block2")
            .instr(int_op(IntBinOp::Add, VirtualReg::gpr(3), reg(VirtualReg::gpr(2)), imm(0)))
            .instr(Instruction::Ret {
                value: Some(reg(VirtualReg::gpr(3))),
            })
            .build();
        let cp = CopyPropagation;
        let changed = cp.apply(&mut func).expect("should succeed");
        // Only block2's `ret` is rewritten: copies never flow across block boundaries.
        assert!(changed);

        let block1 = func.get_block("block1").unwrap();
        assert_eq!(lhs_of(&block1.instructions[0]), &reg(VirtualReg::gpr(1)));
        let block2 = func.get_block("block2").unwrap();
        assert_eq!(lhs_of(&block2.instructions[0]), &reg(VirtualReg::gpr(2)));
        assert_eq!(returned(block2), &reg(VirtualReg::gpr(2)));
    }
}