use std::{collections::HashMap, sync::Arc};
use crate::{
analysis::{simplify_op, ConstValue, SimplifyResult, SsaFunction, SsaOp, SsaVarId},
compiler::{pass::SsaPass, CompilerContext, EventKind, EventLog},
metadata::token::Token,
CilObject, Result,
};
/// SSA pass that rewrites instructions recognised by [`simplify_op`] as
/// algebraic identities into either a constant (`Const`) or a plain `Copy`.
///
/// Examples of handled identities include `x xor x = 0`, `x or x = x`,
/// division by one (becomes a copy of the dividend) and comparing a variable
/// with itself.
///
/// The pass is stateless, so it is a unit struct; `Default` yields the same
/// value as `AlgebraicSimplificationPass::new()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AlgebraicSimplificationPass;
/// Replacement chosen for an instruction whose result can be simplified.
#[derive(Debug, Clone)]
enum Simplification {
    /// Rewrite the instruction into an `SsaOp::Const` producing this value.
    Constant(ConstValue),
    /// Rewrite the instruction into an `SsaOp::Copy` reading from this variable.
    Copy(SsaVarId),
}
/// An instruction selected for rewriting, together with its replacement.
#[derive(Debug)]
struct SimplificationCandidate {
    /// Index of the block containing the instruction, as yielded by
    /// `SsaFunction::iter_instructions`.
    block_idx: usize,
    /// Index of the instruction within that block's instruction list.
    instr_idx: usize,
    /// Destination variable of the original instruction; the replacement op
    /// writes to the same variable.
    dest: SsaVarId,
    /// What the instruction is rewritten into.
    simplification: Simplification,
    /// Human-readable note attached to the event recorded for this rewrite.
    description: String,
}
impl AlgebraicSimplificationPass {
#[must_use]
pub fn new() -> Self {
Self
}
fn find_candidates(
ssa: &SsaFunction,
constants: &HashMap<SsaVarId, ConstValue>,
) -> Vec<SimplificationCandidate> {
let mut candidates = Vec::new();
for (block_idx, instr_idx, instr) in ssa.iter_instructions() {
let op = instr.op();
if let Some(candidate) = Self::check_simplification(op, block_idx, instr_idx, constants)
{
candidates.push(candidate);
}
}
candidates
}
fn check_simplification(
op: &SsaOp,
block_idx: usize,
instr_idx: usize,
constants: &HashMap<SsaVarId, ConstValue>,
) -> Option<SimplificationCandidate> {
let dest = op.dest()?;
match simplify_op(op, constants) {
SimplifyResult::Constant(value) => Some(SimplificationCandidate {
block_idx,
instr_idx,
dest,
simplification: Simplification::Constant(value),
description: "algebraic → const".to_string(),
}),
SimplifyResult::Copy(src) => Some(SimplificationCandidate {
block_idx,
instr_idx,
dest,
simplification: Simplification::Copy(src),
description: "algebraic → copy".to_string(),
}),
SimplifyResult::None => None,
}
}
fn apply_simplifications(
ssa: &mut SsaFunction,
candidates: Vec<SimplificationCandidate>,
method_token: Token,
changes: &mut EventLog,
) {
for candidate in candidates {
if let Some(block) = ssa.block_mut(candidate.block_idx) {
let instr = &mut block.instructions_mut()[candidate.instr_idx];
let new_op = match candidate.simplification {
Simplification::Constant(value) => SsaOp::Const {
dest: candidate.dest,
value,
},
Simplification::Copy(src) => SsaOp::Copy {
dest: candidate.dest,
src,
},
};
instr.set_op(new_op);
changes
.record(EventKind::ConstantFolded)
.at(method_token, candidate.instr_idx)
.message(&candidate.description);
}
}
}
}
impl SsaPass for AlgebraicSimplificationPass {
    fn name(&self) -> &'static str {
        "algebraic-simplification"
    }

    fn description(&self) -> &'static str {
        "Simplify algebraic identities (x xor x = 0, x or x = x, etc.)"
    }

    /// Runs the pass on a single method.
    ///
    /// It collects the rewritable instructions using the constants currently
    /// known in `ssa`, then rewrites them in place. Any events recorded along
    /// the way are merged into `ctx.events`.
    ///
    /// Returns `Ok(true)` if at least one event was recorded, `Ok(false)`
    /// otherwise.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method_token: Token,
        ctx: &CompilerContext,
        _assembly: &Arc<CilObject>,
    ) -> Result<bool> {
        let mut log = EventLog::new();
        let known_constants = ssa.find_constants();
        let rewrites = Self::find_candidates(ssa, &known_constants);
        Self::apply_simplifications(ssa, rewrites, method_token, &mut log);

        if log.is_empty() {
            return Ok(false);
        }

        ctx.events.merge(&log);
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `check_simplification` on `op` at position (0, 0) and returns the
    /// chosen simplification, panicking if the op was not simplified.
    fn simplify(op: &SsaOp, constants: &HashMap<SsaVarId, ConstValue>) -> Simplification {
        AlgebraicSimplificationPass::check_simplification(op, 0, 0, constants)
            .expect("op should have been simplified")
            .simplification
    }

    #[test]
    fn test_div_by_one() {
        let left = SsaVarId::from_index(0);
        let right = SsaVarId::from_index(1);
        let dest = SsaVarId::from_index(2);
        let constants: HashMap<SsaVarId, ConstValue> = [(right, ConstValue::I32(1))].into();
        let op = SsaOp::Div {
            dest,
            left,
            right,
            unsigned: false,
        };
        assert!(matches!(simplify(&op, &constants), Simplification::Copy(v) if v == left));
    }

    #[test]
    fn test_rem_by_one() {
        let left = SsaVarId::from_index(0);
        let right = SsaVarId::from_index(1);
        let dest = SsaVarId::from_index(2);
        let constants: HashMap<SsaVarId, ConstValue> = [(right, ConstValue::I32(1))].into();
        let op = SsaOp::Rem {
            dest,
            left,
            right,
            unsigned: false,
        };
        assert!(matches!(
            simplify(&op, &constants),
            Simplification::Constant(ConstValue::I32(0))
        ));
    }

    #[test]
    fn test_ceq_same_var() {
        let x = SsaVarId::from_index(0);
        let dest = SsaVarId::from_index(1);
        let op = SsaOp::Ceq {
            dest,
            left: x,
            right: x,
        };
        assert!(matches!(
            simplify(&op, &HashMap::new()),
            Simplification::Constant(ConstValue::I32(1))
        ));
    }

    #[test]
    fn test_clt_same_var() {
        let x = SsaVarId::from_index(0);
        let dest = SsaVarId::from_index(1);
        let op = SsaOp::Clt {
            dest,
            left: x,
            right: x,
            unsigned: false,
        };
        assert!(matches!(
            simplify(&op, &HashMap::new()),
            Simplification::Constant(ConstValue::I32(0))
        ));
    }

    #[test]
    fn test_cgt_same_var() {
        let x = SsaVarId::from_index(0);
        let dest = SsaVarId::from_index(1);
        let op = SsaOp::Cgt {
            dest,
            left: x,
            right: x,
            unsigned: false,
        };
        assert!(matches!(
            simplify(&op, &HashMap::new()),
            Simplification::Constant(ConstValue::I32(0))
        ));
    }

    #[test]
    fn test_candidate_records_position_and_dest() {
        let x = SsaVarId::from_index(0);
        let dest = SsaVarId::from_index(1);
        let op = SsaOp::Ceq {
            dest,
            left: x,
            right: x,
        };
        let candidate =
            AlgebraicSimplificationPass::check_simplification(&op, 3, 7, &HashMap::new())
                .expect("x == x should be simplified");
        assert_eq!(candidate.block_idx, 3);
        assert_eq!(candidate.instr_idx, 7);
        assert!(candidate.dest == dest);
    }

    #[test]
    fn test_ceq_distinct_vars_not_simplified() {
        let op = SsaOp::Ceq {
            dest: SsaVarId::from_index(2),
            left: SsaVarId::from_index(0),
            right: SsaVarId::from_index(1),
        };
        assert!(
            AlgebraicSimplificationPass::check_simplification(&op, 0, 0, &HashMap::new())
                .is_none()
        );
    }
}