use std::{collections::HashSet, sync::Arc};
use crate::{
analysis::{ConstValue, DefUseIndex, SsaFunction, SsaOp, SsaVarId, ValueRange},
compiler::{pass::SsaPass, CompilerContext, EventKind, EventLog},
metadata::token::Token,
utils::is_power_of_two,
CilObject, Result,
};
/// SSA pass that rewrites `mul`/`div`/`rem` by a power-of-two constant into the
/// cheaper `shl`/`shr`/`and` form.
///
/// The pass is stateless, so the derived `Default` is the same value `new()`
/// returns.
#[derive(Default)]
pub struct StrengthReductionPass;
/// Position of an instruction inside an SSA function.
#[derive(Clone, Copy, Debug)]
struct InstrLocation {
    /// Index of the containing block.
    block_idx: usize,
    /// Index of the instruction within that block.
    instr_idx: usize,
}
/// Read-only helper that decides whether one `mul`/`div`/`rem` instruction can be
/// turned into a shift or mask.
///
/// A reduction rewrites the *definition* of the constant operand in place (to the
/// shift amount or mask). It is therefore only legal when that constant has a
/// single use and has not already been claimed by another candidate.
struct ReductionChecker<'a> {
    /// Def-use information used to find constant definitions and count uses.
    index: &'a DefUseIndex,
    /// Constants already claimed by an earlier candidate in this run.
    used_constants: &'a HashSet<SsaVarId>,
}

impl<'a> ReductionChecker<'a> {
    fn new(index: &'a DefUseIndex, used_constants: &'a HashSet<SsaVarId>) -> Self {
        Self {
            index,
            used_constants,
        }
    }

    /// Resolves `const_var` to a power-of-two constant that may be rewritten in place.
    ///
    /// Returns `(const_block, const_instr, value, exponent)`, where `exponent` is what
    /// `is_power_of_two(value)` reports. Returns `None` in any of these cases:
    /// - `const_var` is not defined by an `SsaOp::Const`.
    /// - `as_i64` yields no value for the constant.
    /// - `is_power_of_two` rejects the value.
    /// - The constant does not have exactly one use.
    /// - The constant was already claimed.
    fn single_use_power_of_two(&self, const_var: SsaVarId) -> Option<(usize, usize, i64, i32)> {
        let (const_block, const_instr, const_op) = self.index.full_definition(const_var)?;
        let SsaOp::Const {
            value: const_value, ..
        } = const_op
        else {
            return None;
        };
        let value = const_value.as_i64()?;
        let exponent = i32::from(is_power_of_two(value)?);
        // Rewriting the constant's definition would silently change every other user.
        if self.index.use_count(const_var) != 1 || self.used_constants.contains(&const_var) {
            return None;
        }
        Some((const_block, const_instr, value, exponent))
    }

    /// Tries to rewrite `dest = value_var * const_var` as `dest = value_var << exponent`.
    ///
    /// On success, the definition of `const_var` becomes the `I32` shift amount.
    fn try_mul_reduction(
        &self,
        dest: SsaVarId,
        value_var: SsaVarId,
        const_var: SsaVarId,
        location: InstrLocation,
    ) -> Option<ReductionCandidate> {
        // `c * c`: rewriting `c` would also change the value being shifted.
        if value_var == const_var {
            return None;
        }
        let (const_block, const_instr, value, exponent) =
            self.single_use_power_of_two(const_var)?;
        Some(ReductionCandidate {
            location,
            const_var,
            const_block,
            const_instr,
            new_const_value: ConstValue::I32(exponent),
            new_op: SsaOp::Shl {
                dest,
                value: value_var,
                amount: const_var,
            },
            description: format!("mul x, {value} → shl x, {exponent}"),
        })
    }

    /// Tries to rewrite `dest = dividend / divisor_var` as a right shift.
    ///
    /// For the signed form the caller must already have proven `dividend >= 0`.
    /// Signed division rounds toward zero, which differs from an arithmetic
    /// shift for negative values.
    fn try_div_reduction(
        &self,
        dest: SsaVarId,
        dividend: SsaVarId,
        divisor_var: SsaVarId,
        unsigned: bool,
        location: InstrLocation,
    ) -> Option<ReductionCandidate> {
        // `c / c`: rewriting `c` would also change the dividend.
        if dividend == divisor_var {
            return None;
        }
        let (const_block, const_instr, value, exponent) =
            self.single_use_power_of_two(divisor_var)?;
        let desc = if unsigned {
            format!("div.un x, {value} → shr.un x, {exponent}")
        } else {
            format!("div x, {value} → shr x, {exponent} (x >= 0)")
        };
        Some(ReductionCandidate {
            location,
            const_var: divisor_var,
            const_block,
            const_instr,
            new_const_value: ConstValue::I32(exponent),
            new_op: SsaOp::Shr {
                dest,
                value: dividend,
                amount: divisor_var,
                unsigned,
            },
            description: desc,
        })
    }

    /// Tries to rewrite `dest = dividend % divisor_var` as `dest = dividend & (divisor - 1)`.
    ///
    /// For the signed form the caller must already have proven `dividend >= 0`.
    /// The reduction is skipped when the mask does not fit in an `i32`.
    fn try_rem_reduction(
        &self,
        dest: SsaVarId,
        dividend: SsaVarId,
        divisor_var: SsaVarId,
        unsigned: bool,
        location: InstrLocation,
    ) -> Option<ReductionCandidate> {
        // `c % c`: rewriting `c` would also change the dividend.
        if dividend == divisor_var {
            return None;
        }
        let (const_block, const_instr, value, _exponent) =
            self.single_use_power_of_two(divisor_var)?;
        // The previous `as i32` cast truncated 64-bit masks (e.g. 2^32 - 1 became -1),
        // which turned `x % 2^32` into `x & -1 == x`. Refuse instead of truncating.
        let mask = i32::try_from(value.checked_sub(1)?).ok()?;
        // NOTE(review): the mask is always written back as an `I32` constant, even if
        // the divisor constant was 64-bit. ECMA-335 `and` does not accept mixed
        // int32/int64 operands, so confirm the emitter widens it for int64 dividends.
        let desc = if unsigned {
            format!("rem.un x, {value} → and x, {mask}")
        } else {
            format!("rem x, {value} → and x, {mask} (x >= 0)")
        };
        Some(ReductionCandidate {
            location,
            const_var: divisor_var,
            const_block,
            const_instr,
            new_const_value: ConstValue::I32(mask),
            new_op: SsaOp::And {
                dest,
                left: dividend,
                right: divisor_var,
            },
            description: desc,
        })
    }
}
/// A reduction found during the scan and applied later in one batch.
#[derive(Debug)]
struct ReductionCandidate {
    /// Where the `mul`/`div`/`rem` instruction being replaced lives.
    location: InstrLocation,
    /// The constant operand whose definition gets rewritten.
    const_var: SsaVarId,
    /// Block index of the constant's defining instruction.
    const_block: usize,
    /// Instruction index (within `const_block`) of the constant's definition.
    const_instr: usize,
    /// Replacement value for the constant (shift amount or mask).
    new_const_value: ConstValue,
    /// Replacement operation for the instruction at `location`.
    new_op: SsaOp,
    /// Human-readable summary recorded in the event log.
    description: String,
}
impl StrengthReductionPass {
#[must_use]
pub fn new() -> Self {
Self
}
fn find_candidates(
ssa: &SsaFunction,
index: &DefUseIndex,
ctx: &CompilerContext,
method_token: Token,
) -> Vec<ReductionCandidate> {
let mut candidates = Vec::new();
let mut used_constants: HashSet<SsaVarId> = HashSet::new();
for (block_idx, instr_idx, instr) in ssa.iter_instructions() {
let checker = ReductionChecker::new(index, &used_constants);
let location = InstrLocation {
block_idx,
instr_idx,
};
if let Some(candidate) =
Self::check_reduction(instr.op(), location, &checker, ctx, method_token)
{
used_constants.insert(candidate.const_var);
candidates.push(candidate);
}
}
candidates
}
fn check_reduction(
op: &SsaOp,
location: InstrLocation,
checker: &ReductionChecker<'_>,
ctx: &CompilerContext,
method_token: Token,
) -> Option<ReductionCandidate> {
match op {
SsaOp::Mul { dest, left, right } => {
if let Some(candidate) = checker.try_mul_reduction(*dest, *left, *right, location) {
return Some(candidate);
}
checker.try_mul_reduction(*dest, *right, *left, location)
}
SsaOp::Div {
dest,
left,
right,
unsigned: true,
} => checker.try_div_reduction(*dest, *left, *right, true, location),
SsaOp::Div {
dest,
left,
right,
unsigned: false,
} => {
if Self::is_provably_non_negative(*left, ctx, method_token) {
checker.try_div_reduction(*dest, *left, *right, false, location)
} else {
None
}
}
SsaOp::Rem {
dest,
left,
right,
unsigned: true,
} => checker.try_rem_reduction(*dest, *left, *right, true, location),
SsaOp::Rem {
dest,
left,
right,
unsigned: false,
} => {
if Self::is_provably_non_negative(*left, ctx, method_token) {
checker.try_rem_reduction(*dest, *left, *right, false, location)
} else {
None
}
}
_ => None,
}
}
fn is_provably_non_negative(var: SsaVarId, ctx: &CompilerContext, method_token: Token) -> bool {
ctx.with_known_range(method_token, var, ValueRange::is_always_non_negative)
.unwrap_or(false)
}
fn apply_reductions(
ssa: &mut SsaFunction,
candidates: Vec<ReductionCandidate>,
method_token: Token,
changes: &mut EventLog,
) {
for candidate in candidates {
if let Some(block) = ssa.block_mut(candidate.const_block) {
let const_instr = &mut block.instructions_mut()[candidate.const_instr];
const_instr.set_op(SsaOp::Const {
dest: candidate.const_var,
value: candidate.new_const_value,
});
}
if let Some(block) = ssa.block_mut(candidate.location.block_idx) {
let instr = &mut block.instructions_mut()[candidate.location.instr_idx];
instr.set_op(candidate.new_op);
changes
.record(EventKind::StrengthReduced)
.at(method_token, candidate.location.instr_idx)
.message(&candidate.description);
}
}
}
}
impl SsaPass for StrengthReductionPass {
    fn name(&self) -> &'static str {
        "strength-reduction"
    }

    fn description(&self) -> &'static str {
        "Transform expensive operations (mul/div/rem) to cheaper equivalents (shl/shr/and)"
    }

    /// Finds and applies all reductions for one method.
    ///
    /// Returns `Ok(true)` if at least one event was recorded. Recorded events
    /// are merged into `ctx.events` only in that case.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method_token: Token,
        ctx: &CompilerContext,
        _assembly: &Arc<CilObject>,
    ) -> Result<bool> {
        // The def-use index is only needed while searching; candidates own their data.
        let candidates = {
            let index = DefUseIndex::build_with_ops(ssa);
            Self::find_candidates(ssa, &index, ctx, method_token)
        };

        let mut log = EventLog::new();
        Self::apply_reductions(ssa, candidates, method_token, &mut log);

        if log.is_empty() {
            return Ok(false);
        }
        ctx.events.merge(&log);
        Ok(true)
    }
}