use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};
use crate::{
    analysis::{ConstValue, DefUseIndex, SsaFunction, SsaOp, SsaVarId},
    compiler::{pass::SsaPass, CompilerContext, EventKind, EventLog},
    metadata::{token::Token, typesystem::PointerSize},
    CilObject, Result,
};
/// SSA pass that rewrites chains such as `(x op c1) op c2` into `x op c`, where `c`
/// is `c1` and `c2` folded together, so that the two constants collapse into one.
pub struct ReassociationPass;

impl Default for ReassociationPass {
    /// Same as [`ReassociationPass::new`].
    fn default() -> Self {
        ReassociationPass::new()
    }
}
/// A matched `(base op c1) op c2` chain, recorded during analysis and rewritten later.
///
/// The "outer" instruction is `dest = inner_dest op c2`. The "inner" instruction is
/// `inner_dest = base op c1` (or `c1 op base` for commutative operations). It defines the
/// outer instruction's left operand.
#[derive(Debug)]
struct ReassociationCandidate {
/// Block index of the outer instruction.
block_idx: usize,
/// Instruction index of the outer instruction within `block_idx`.
instr_idx: usize,
/// Destination variable of the outer instruction.
dest: SsaVarId,
/// Non-constant operand of the inner instruction.
base_var: SsaVarId,
/// Constant operand of the inner instruction.
const1_var: SsaVarId,
/// Constant right operand of the outer instruction.
const2_var: SsaVarId,
/// Value of `const1_var` captured at analysis time.
const1_value: ConstValue,
/// Value of `const2_var` captured at analysis time.
const2_value: ConstValue,
/// Block index of the inner instruction.
inner_block: usize,
/// Instruction index of the inner instruction within `inner_block`.
inner_instr: usize,
/// Destination of the inner instruction, i.e. the outer instruction's left operand.
inner_dest: SsaVarId,
/// Operation kind shared by the inner and outer instructions.
op_kind: OpKind,
}
/// Binary operations the pass knows how to reassociate.
///
/// For `Shl` and `Shr`, the left operand is the shifted value and the right operand is
/// the shift amount (see `get_op_kind` / `make_op`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OpKind {
/// `SsaOp::Add`.
Add,
/// `SsaOp::Sub`.
Sub,
/// `SsaOp::Mul`.
Mul,
/// `SsaOp::And`.
And,
/// `SsaOp::Or`.
Or,
/// `SsaOp::Xor`.
Xor,
/// `SsaOp::Shl`.
Shl,
/// `SsaOp::Shr`; `unsigned` selects `shr.un` over `shr` (see `name`).
Shr { unsigned: bool },
}
impl OpKind {
fn combine(
self,
c1: &ConstValue,
c2: &ConstValue,
ptr_size: PointerSize,
) -> Option<ConstValue> {
match self {
OpKind::Add | OpKind::Sub | OpKind::Shl | OpKind::Shr { .. } => c1.add(c2, ptr_size),
OpKind::Mul => c1.mul(c2, ptr_size),
OpKind::And => c1.bitwise_and(c2, ptr_size),
OpKind::Or => c1.bitwise_or(c2, ptr_size),
OpKind::Xor => c1.bitwise_xor(c2, ptr_size),
}
}
fn name(self) -> &'static str {
match self {
OpKind::Add => "add",
OpKind::Sub => "sub",
OpKind::Mul => "mul",
OpKind::And => "and",
OpKind::Or => "or",
OpKind::Xor => "xor",
OpKind::Shl => "shl",
OpKind::Shr { unsigned: false } => "shr",
OpKind::Shr { unsigned: true } => "shr.un",
}
}
fn combine_name(self) -> &'static str {
match self {
OpKind::Add | OpKind::Mul | OpKind::And | OpKind::Or | OpKind::Xor => self.name(),
OpKind::Sub | OpKind::Shl | OpKind::Shr { .. } => "add",
}
}
const fn is_commutative(self) -> bool {
match self {
OpKind::Add | OpKind::Mul | OpKind::And | OpKind::Or | OpKind::Xor => true,
OpKind::Sub | OpKind::Shl | OpKind::Shr { .. } => false,
}
}
}
impl ReassociationPass {
    /// Creates a new reassociation pass (the pass carries no state).
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Splits a supported binary operation into `(kind, dest, left, right)`.
    ///
    /// For shifts, `left` is the shifted value and `right` is the shift amount.
    /// Returns `None` for any operation this pass does not handle.
    fn get_op_kind(op: &SsaOp) -> Option<(OpKind, SsaVarId, SsaVarId, SsaVarId)> {
        match op {
            SsaOp::Add { dest, left, right } => Some((OpKind::Add, *dest, *left, *right)),
            SsaOp::Sub { dest, left, right } => Some((OpKind::Sub, *dest, *left, *right)),
            SsaOp::Mul { dest, left, right } => Some((OpKind::Mul, *dest, *left, *right)),
            SsaOp::And { dest, left, right } => Some((OpKind::And, *dest, *left, *right)),
            SsaOp::Or { dest, left, right } => Some((OpKind::Or, *dest, *left, *right)),
            SsaOp::Xor { dest, left, right } => Some((OpKind::Xor, *dest, *left, *right)),
            SsaOp::Shl {
                dest,
                value,
                amount,
            } => Some((OpKind::Shl, *dest, *value, *amount)),
            SsaOp::Shr {
                dest,
                value,
                amount,
                unsigned,
            } => Some((
                OpKind::Shr {
                    unsigned: *unsigned,
                },
                *dest,
                *value,
                *amount,
            )),
            _ => None,
        }
    }

    /// Builds the `SsaOp` for `kind`; the inverse of [`Self::get_op_kind`].
    fn make_op(kind: OpKind, dest: SsaVarId, left: SsaVarId, right: SsaVarId) -> SsaOp {
        match kind {
            OpKind::Add => SsaOp::Add { dest, left, right },
            OpKind::Sub => SsaOp::Sub { dest, left, right },
            OpKind::Mul => SsaOp::Mul { dest, left, right },
            OpKind::And => SsaOp::And { dest, left, right },
            OpKind::Or => SsaOp::Or { dest, left, right },
            OpKind::Xor => SsaOp::Xor { dest, left, right },
            OpKind::Shl => SsaOp::Shl {
                dest,
                value: left,
                amount: right,
            },
            OpKind::Shr { unsigned } => SsaOp::Shr {
                dest,
                value: left,
                amount: right,
                unsigned,
            },
        }
    }

    /// Collects every instruction of `ssa` that can be reassociated.
    ///
    /// All inputs describe `ssa` before any rewrite, so candidates may overlap.
    /// [`Self::apply_reassociations`] detects and skips such conflicts.
    fn find_candidates(
        ssa: &SsaFunction,
        constants: &HashMap<SsaVarId, ConstValue>,
        index: &DefUseIndex,
        uses: &HashMap<SsaVarId, usize>,
    ) -> Vec<ReassociationCandidate> {
        let mut candidates = Vec::new();
        for (block_idx, instr_idx, instr) in ssa.iter_instructions() {
            if let Some(candidate) =
                Self::check_reassociation(instr.op(), block_idx, instr_idx, constants, index, uses)
            {
                candidates.push(candidate);
            }
        }
        candidates
    }

    /// Checks whether `op` (at `block_idx`/`instr_idx`) has the shape
    /// `dest = (base op c1) op c2` and can be rewritten safely.
    ///
    /// Requirements:
    /// - The outer right operand `c2` is a known constant.
    /// - The outer left operand is produced by an instruction of the same [`OpKind`].
    ///   That instruction's result has no other use, because it is rewritten in place.
    /// - One operand of that inner instruction is a constant `c1` defined by an
    ///   `SsaOp::Const` and used only there, because its definition is overwritten with
    ///   the folded value. For non-commutative ops this must be the right operand; for
    ///   commutative ops it may be either.
    fn check_reassociation(
        op: &SsaOp,
        block_idx: usize,
        instr_idx: usize,
        constants: &HashMap<SsaVarId, ConstValue>,
        index: &DefUseIndex,
        uses: &HashMap<SsaVarId, usize>,
    ) -> Option<ReassociationCandidate> {
        let (outer_kind, dest, outer_left, outer_right) = Self::get_op_kind(op)?;
        let c2_value = constants.get(&outer_right)?;
        let (inner_block, inner_instr, inner_op) = index.full_definition(outer_left)?;
        let (inner_kind, inner_dest, inner_left, inner_right) = Self::get_op_kind(inner_op)?;
        if inner_kind != outer_kind {
            return None;
        }
        // The inner instruction is rewritten in place, so only the outer instruction may
        // observe its result.
        if uses.get(&inner_dest).copied().unwrap_or(0) > 1 {
            return None;
        }

        // `(base, constant)` operand orders to try: the right-hand constant always, the
        // left-hand one only when the operation is commutative.
        let swapped = outer_kind
            .is_commutative()
            .then_some((inner_right, inner_left));
        std::iter::once((inner_left, inner_right))
            .chain(swapped)
            .find_map(|(base_var, const1_var)| {
                let const1_value = constants.get(&const1_var)?;
                if !Self::is_exclusive_const(const1_var, index, uses) {
                    return None;
                }
                Some(ReassociationCandidate {
                    block_idx,
                    instr_idx,
                    dest,
                    base_var,
                    const1_var,
                    const2_var: outer_right,
                    const1_value: const1_value.clone(),
                    const2_value: c2_value.clone(),
                    inner_block,
                    inner_instr,
                    inner_dest,
                    op_kind: outer_kind,
                })
            })
    }

    /// Returns `true` if `var` is defined by an `SsaOp::Const` and used exactly once.
    /// Overwriting its definition then cannot affect any instruction except that one use.
    fn is_exclusive_const(
        var: SsaVarId,
        index: &DefUseIndex,
        uses: &HashMap<SsaVarId, usize>,
    ) -> bool {
        uses.get(&var).copied() == Some(1)
            && matches!(index.full_definition(var), Some((_, _, SsaOp::Const { .. })))
    }

    /// Rewrites every non-conflicting candidate and records one event per rewrite.
    ///
    /// For a candidate `t = base op c1; dest = t op c2` this:
    /// 1. Overwrites the `Const` defining `c1` with the folded value of `c1` and `c2`.
    /// 2. Rebuilds the inner instruction as `t = base op c1`, with operands normalized.
    /// 3. Turns the outer instruction into `dest = copy t`.
    ///
    /// Candidates are skipped when:
    /// - `c1` has no `SsaOp::Const` definition anywhere in the function.
    /// - The constants cannot be folded.
    /// - An instruction the candidate would rewrite has already been rewritten by an
    ///   earlier candidate, as in `((x + a) + b) + c`. Candidates were computed against the
    ///   unmodified SSA, so such a candidate's operands and constants may be stale.
    fn apply_reassociations(
        ssa: &mut SsaFunction,
        candidates: Vec<ReassociationCandidate>,
        method_token: Token,
        changes: &mut EventLog,
        ptr_size: PointerSize,
    ) {
        // Each SSA variable has a single definition, so this maps every constant variable
        // to the one instruction that defines it, wherever that instruction lives.
        let const_sites: HashMap<SsaVarId, (usize, usize)> = ssa
            .iter_instructions()
            .filter_map(|(block_idx, instr_idx, instr)| match instr.op() {
                SsaOp::Const { dest, .. } => Some((*dest, (block_idx, instr_idx))),
                _ => None,
            })
            .collect();
        // Instructions already rewritten during this call.
        let mut rewritten: HashSet<(usize, usize)> = HashSet::new();

        for candidate in candidates {
            let Some(&const_site) = const_sites.get(&candidate.const1_var) else {
                continue;
            };
            let inner_site = (candidate.inner_block, candidate.inner_instr);
            let outer_site = (candidate.block_idx, candidate.instr_idx);
            let sites = [const_site, inner_site, outer_site];
            if sites.iter().any(|site| rewritten.contains(site)) {
                continue;
            }
            let Some(combined) = candidate.op_kind.combine(
                &candidate.const1_value,
                &candidate.const2_value,
                ptr_size,
            ) else {
                continue;
            };

            Self::set_op_at(
                ssa,
                const_site,
                SsaOp::Const {
                    dest: candidate.const1_var,
                    value: combined,
                },
            );
            Self::set_op_at(
                ssa,
                inner_site,
                Self::make_op(
                    candidate.op_kind,
                    candidate.inner_dest,
                    candidate.base_var,
                    candidate.const1_var,
                ),
            );
            Self::set_op_at(
                ssa,
                outer_site,
                SsaOp::Copy {
                    dest: candidate.dest,
                    src: candidate.inner_dest,
                },
            );
            rewritten.extend(sites);

            changes
                .record(EventKind::ConstantFolded)
                .at(method_token, candidate.instr_idx)
                .message(format!(
                    "reassociate: (x {} c1) {} c2 → x {} (c1 {} c2)",
                    candidate.op_kind.name(),
                    candidate.op_kind.name(),
                    candidate.op_kind.name(),
                    candidate.op_kind.combine_name()
                ));
        }
    }

    /// Replaces the operation at `site`, given as `(block index, instruction index)`.
    ///
    /// Sites come from this same function and rewrites only replace operations in place,
    /// so sites stay valid. An out-of-range site is ignored rather than panicking.
    fn set_op_at(ssa: &mut SsaFunction, (block_idx, instr_idx): (usize, usize), op: SsaOp) {
        if let Some(instr) = ssa
            .block_mut(block_idx)
            .and_then(|block| block.instructions_mut().get_mut(instr_idx))
        {
            instr.set_op(op);
        }
    }
}
impl SsaPass for ReassociationPass {
    /// Identifier under which this pass is registered.
    fn name(&self) -> &'static str {
        "reassociation"
    }

    /// Human-readable summary of what the pass does.
    fn description(&self) -> &'static str {
        "Reorder operations to enable constant folding (add, sub, mul, and, or, xor, shl, shr)"
    }

    /// Finds reassociation candidates in `ssa`, rewrites them, and merges the recorded
    /// events into `ctx.events`.
    ///
    /// Returns `Ok(true)` when at least one event was recorded, `Ok(false)` otherwise.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method_token: Token,
        ctx: &CompilerContext,
        assembly: &Arc<CilObject>,
    ) -> Result<bool> {
        let is_64bit = assembly.file().pe().is_64bit;
        let ptr_size = PointerSize::from_pe(is_64bit);

        // Analysis phase: everything below reads the SSA before any rewrite happens.
        let known_constants = ssa.find_constants();
        let def_use = DefUseIndex::build_with_ops(ssa);
        let use_counts = ssa.count_uses();
        let candidates = Self::find_candidates(ssa, &known_constants, &def_use, &use_counts);

        let mut log = EventLog::new();
        Self::apply_reassociations(ssa, candidates, method_token, &mut log, ptr_size);

        if log.is_empty() {
            return Ok(false);
        }
        ctx.events.merge(&log);
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use crate::{analysis::ConstValue, metadata::typesystem::PointerSize};
    use super::OpKind;

    /// Folds `lhs` and `rhs` with `kind`, targeting a 64-bit image.
    fn fold(kind: OpKind, lhs: ConstValue, rhs: ConstValue) -> Option<ConstValue> {
        kind.combine(&lhs, &rhs, PointerSize::Bit64)
    }

    #[test]
    fn test_op_kind_combine_add() {
        assert_eq!(
            fold(OpKind::Add, ConstValue::I32(5), ConstValue::I32(3)),
            Some(ConstValue::I32(8))
        );
    }

    #[test]
    fn test_op_kind_combine_xor() {
        assert_eq!(
            fold(OpKind::Xor, ConstValue::I32(0xF0), ConstValue::I32(0x0F)),
            Some(ConstValue::I32(0xFF))
        );
    }

    #[test]
    fn test_op_kind_combine_mul() {
        assert_eq!(
            fold(OpKind::Mul, ConstValue::I32(7), ConstValue::I32(11)),
            Some(ConstValue::I32(77))
        );
    }

    #[test]
    fn test_op_kind_combine_and() {
        assert_eq!(
            fold(OpKind::And, ConstValue::I32(0xFF), ConstValue::I32(0x0F)),
            Some(ConstValue::I32(0x0F))
        );
    }

    #[test]
    fn test_op_kind_combine_or() {
        assert_eq!(
            fold(OpKind::Or, ConstValue::I32(0xF0), ConstValue::I32(0x0F)),
            Some(ConstValue::I32(0xFF))
        );
    }

    #[test]
    fn test_op_kind_combine_sub() {
        // (x - 5) - 3 == x - (5 + 3)
        assert_eq!(
            fold(OpKind::Sub, ConstValue::I32(5), ConstValue::I32(3)),
            Some(ConstValue::I32(8))
        );
    }

    #[test]
    fn test_op_kind_combine_shl() {
        // (x << 2) << 3 == x << (2 + 3)
        assert_eq!(
            fold(OpKind::Shl, ConstValue::I32(2), ConstValue::I32(3)),
            Some(ConstValue::I32(5))
        );
    }

    #[test]
    fn test_op_kind_combine_shr() {
        assert_eq!(
            fold(
                OpKind::Shr { unsigned: false },
                ConstValue::I32(4),
                ConstValue::I32(2)
            ),
            Some(ConstValue::I32(6))
        );
    }

    #[test]
    fn test_op_kind_combine_shr_unsigned() {
        assert_eq!(
            fold(
                OpKind::Shr { unsigned: true },
                ConstValue::I32(4),
                ConstValue::I32(2)
            ),
            Some(ConstValue::I32(6))
        );
    }

    #[test]
    fn test_op_kind_is_commutative() {
        let commutative = [OpKind::Add, OpKind::Mul, OpKind::And, OpKind::Or, OpKind::Xor];
        let ordered = [
            OpKind::Sub,
            OpKind::Shl,
            OpKind::Shr { unsigned: false },
            OpKind::Shr { unsigned: true },
        ];
        for kind in commutative {
            assert!(kind.is_commutative(), "{kind:?} should be commutative");
        }
        for kind in ordered {
            assert!(!kind.is_commutative(), "{kind:?} should not be commutative");
        }
    }

    #[test]
    fn test_op_kind_combine_name() {
        let expected = [
            (OpKind::Add, "add"),
            (OpKind::Mul, "mul"),
            (OpKind::And, "and"),
            (OpKind::Or, "or"),
            (OpKind::Xor, "xor"),
            (OpKind::Sub, "add"),
            (OpKind::Shl, "add"),
            (OpKind::Shr { unsigned: false }, "add"),
            (OpKind::Shr { unsigned: true }, "add"),
        ];
        for (kind, name) in expected {
            assert_eq!(kind.combine_name(), name, "{kind:?}");
        }
    }

    #[test]
    fn test_op_kind_name() {
        let expected = [
            (OpKind::Add, "add"),
            (OpKind::Sub, "sub"),
            (OpKind::Mul, "mul"),
            (OpKind::And, "and"),
            (OpKind::Or, "or"),
            (OpKind::Xor, "xor"),
            (OpKind::Shl, "shl"),
            (OpKind::Shr { unsigned: false }, "shr"),
            (OpKind::Shr { unsigned: true }, "shr.un"),
        ];
        for (kind, name) in expected {
            assert_eq!(kind.name(), name, "{kind:?}");
        }
    }
}