use std::collections::HashMap;
use crate::{
analysis::ssa::{ConstValue, SsaFunction, SsaOp, SsaVarId},
metadata::typesystem::PointerSize,
utils::BitSet,
};
/// Memoizing constant evaluator over the variable definitions of an [`SsaFunction`].
///
/// A variable is resolved by recursively evaluating the operation that defines it
/// (looked up with `SsaFunction::get_definition`), following operands at most
/// `max_depth` levels deep. Results, including "did not fold" (`None`), are cached per
/// variable. The exception is a `None` produced because the depth limit was hit
/// somewhere below the variable. Such a result depends on the depth at which the
/// variable was reached, so it is recomputed on the next request instead of being
/// memoized.
pub struct ConstEvaluator<'a> {
    /// Function whose variable definitions are evaluated.
    ssa: &'a SsaFunction,
    /// Memoized per-variable results; `None` records a variable that did not fold.
    cache: HashMap<SsaVarId, Option<ConstValue>>,
    /// Variables on the current evaluation path, used to break definition cycles.
    visiting: BitSet,
    /// Maximum recursion depth (in operand hops) before evaluation gives up.
    max_depth: usize,
    /// Pointer size forwarded to the `ConstValue` arithmetic helpers.
    pointer_size: PointerSize,
    /// Number of times the depth limit cut evaluation short. It is compared before and
    /// after evaluating a variable to decide whether a `None` result may be cached.
    depth_limit_hits: usize,
}
impl<'a> ConstEvaluator<'a> {
    /// Recursion depth used by [`ConstEvaluator::new`].
    const DEFAULT_MAX_DEPTH: usize = 20;

    /// Creates an evaluator for `ssa` with the default maximum depth (20).
    #[must_use]
    pub fn new(ssa: &'a SsaFunction, ptr_size: PointerSize) -> Self {
        Self::with_max_depth(ssa, Self::DEFAULT_MAX_DEPTH, ptr_size)
    }

    /// Creates an evaluator for `ssa` that follows operand definitions at most
    /// `max_depth` levels deep.
    #[must_use]
    pub fn with_max_depth(ssa: &'a SsaFunction, max_depth: usize, ptr_size: PointerSize) -> Self {
        Self {
            ssa,
            cache: HashMap::new(),
            // `.max(1)` keeps the bit set non-empty even for a function with no variables.
            visiting: BitSet::new(ssa.variable_count().max(1)),
            max_depth,
            pointer_size: ptr_size,
            depth_limit_hits: 0,
        }
    }

    /// Seeds the cache so that `var` evaluates to `value`, overriding its definition.
    ///
    /// Results already cached for other variables are not invalidated, so seed values
    /// before evaluating (or after [`ConstEvaluator::clear_cache`]).
    pub fn set_known(&mut self, var: SsaVarId, value: ConstValue) {
        self.cache.insert(var, Some(value));
    }

    /// Evaluates `var` to a constant.
    ///
    /// Returns `None` in any of these cases:
    /// - `var` has no definition or its defining operation cannot be folded;
    /// - an operand is not constant;
    /// - a definition cycle is found;
    /// - the depth limit is exceeded.
    pub fn evaluate_var(&mut self, var: SsaVarId) -> Option<ConstValue> {
        self.evaluate_var_depth(var, 0)
    }

    /// Evaluates `var`, which was reached `depth` operand hops from the query root.
    fn evaluate_var_depth(&mut self, var: SsaVarId, depth: usize) -> Option<ConstValue> {
        // Only exact results are ever memoized, so a cache hit is valid at any depth.
        if let Some(cached) = self.cache.get(&var) {
            return cached.clone();
        }
        if depth > self.max_depth {
            self.note_depth_limit();
            return None;
        }

        let slot = var.index();
        // Variables outside the bit set's range are evaluated without cycle tracking.
        let tracked = slot < self.visiting.len();
        if tracked {
            if self.visiting.contains(slot) {
                // `var` is already being evaluated further up the stack: a cycle.
                return None;
            }
            self.visiting.insert(slot);
        }

        let hits_before = self.depth_limit_hits;
        let result = self
            .ssa
            .get_definition(var)
            .and_then(|op| self.evaluate_op_depth(op, depth));

        if tracked {
            self.visiting.remove(slot);
        }

        // A `None` caused by the depth limit reflects where `var` was reached, not what
        // it is. Caching it would make later, shallower queries fail spuriously.
        if result.is_some() || self.depth_limit_hits == hits_before {
            self.cache.insert(var, result.clone());
        }
        result
    }

    /// Evaluates `op` directly, resolving its operands through this evaluator.
    pub fn evaluate_op(&mut self, op: &SsaOp) -> Option<ConstValue> {
        self.evaluate_op_depth(op, 0)
    }

    /// Evaluates `op`, whose result sits `depth` hops from the query root; its operands
    /// are evaluated at `depth + 1`.
    fn evaluate_op_depth(&mut self, op: &SsaOp, depth: usize) -> Option<ConstValue> {
        if depth > self.max_depth {
            self.note_depth_limit();
            return None;
        }
        // `evaluate_const_op` has no `Copy` arm, so copies are forwarded here.
        if let SsaOp::Copy { src, .. } = op {
            return self.evaluate_var_depth(*src, depth + 1);
        }
        let ptr_size = self.pointer_size;
        evaluate_const_op(op, |var| self.evaluate_var_depth(var, depth + 1), ptr_size)
    }

    /// Records that the depth limit truncated an evaluation.
    fn note_depth_limit(&mut self) {
        self.depth_limit_hits = self.depth_limit_hits.wrapping_add(1);
    }

    /// Consumes the evaluator and returns every variable that resolved to a constant,
    /// including values seeded with [`ConstEvaluator::set_known`].
    #[must_use]
    pub fn into_results(self) -> HashMap<SsaVarId, ConstValue> {
        self.cache
            .into_iter()
            .filter_map(|(var, opt)| opt.map(|val| (var, val)))
            .collect()
    }

    /// Returns the SSA function this evaluator reads definitions from.
    #[must_use]
    pub fn ssa(&self) -> &SsaFunction {
        self.ssa
    }

    /// Discards all memoized results, including values seeded with `set_known`.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }
}
/// Folds a single SSA operation to a constant value, if possible.
///
/// `get_const` is asked for the constant value of each operand variable, in operand
/// order, and folding stops at the first operand it cannot resolve. Supported
/// operations are:
/// - constant loads;
/// - arithmetic, bitwise and shift operations;
/// - negation and bitwise not;
/// - comparisons;
/// - overflow-checked arithmetic;
/// - conversions.
///
/// Each delegates to the corresponding `ConstValue` method, passing `ptr_size` where
/// that method takes it. Every other operation, including `Copy`, yields `None`.
pub fn evaluate_const_op(
    op: &SsaOp,
    mut get_const: impl FnMut(SsaVarId) -> Option<ConstValue>,
    ptr_size: PointerSize,
) -> Option<ConstValue> {
    match op {
        SsaOp::Const { value, .. } => Some(value.clone()),

        // Binary arithmetic: both operands must fold (left is resolved first).
        SsaOp::Add { left, right, .. } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.add(&rhs, ptr_size)
        }
        SsaOp::Sub { left, right, .. } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.sub(&rhs, ptr_size)
        }
        SsaOp::Mul { left, right, .. } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.mul(&rhs, ptr_size)
        }
        SsaOp::Div { left, right, .. } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.div(&rhs, ptr_size)
        }
        SsaOp::Rem { left, right, .. } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.rem(&rhs, ptr_size)
        }

        // Bitwise logic.
        SsaOp::Xor { left, right, .. } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.bitwise_xor(&rhs, ptr_size)
        }
        SsaOp::And { left, right, .. } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.bitwise_and(&rhs, ptr_size)
        }
        SsaOp::Or { left, right, .. } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.bitwise_or(&rhs, ptr_size)
        }

        // Shifts: the shifted value is resolved before the shift amount.
        SsaOp::Shl { value, amount, .. } => {
            let (base, count) = (get_const(*value)?, get_const(*amount)?);
            base.shl(&count, ptr_size)
        }
        SsaOp::Shr {
            value,
            amount,
            unsigned,
            ..
        } => {
            let (base, count) = (get_const(*value)?, get_const(*amount)?);
            base.shr(&count, *unsigned, ptr_size)
        }

        // Unary operations.
        SsaOp::Neg { operand, .. } => get_const(*operand)?.negate(ptr_size),
        SsaOp::Not { operand, .. } => get_const(*operand)?.bitwise_not(ptr_size),

        // Comparisons; `unsigned` selects the `_un` variant.
        SsaOp::Ceq { left, right, .. } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.ceq(&rhs)
        }
        SsaOp::Clt {
            left,
            right,
            unsigned,
            ..
        } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            if *unsigned {
                lhs.clt_un(&rhs)
            } else {
                lhs.clt(&rhs)
            }
        }
        SsaOp::Cgt {
            left,
            right,
            unsigned,
            ..
        } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            if *unsigned {
                lhs.cgt_un(&rhs)
            } else {
                lhs.cgt(&rhs)
            }
        }

        // Overflow-checked arithmetic.
        SsaOp::AddOvf {
            left,
            right,
            unsigned,
            ..
        } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.add_checked(&rhs, *unsigned, ptr_size)
        }
        SsaOp::SubOvf {
            left,
            right,
            unsigned,
            ..
        } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.sub_checked(&rhs, *unsigned, ptr_size)
        }
        SsaOp::MulOvf {
            left,
            right,
            unsigned,
            ..
        } => {
            let (lhs, rhs) = (get_const(*left)?, get_const(*right)?);
            lhs.mul_checked(&rhs, *unsigned, ptr_size)
        }

        // Conversions, optionally overflow-checked.
        SsaOp::Conv {
            operand,
            target,
            overflow_check,
            unsigned,
            ..
        } => {
            let source = get_const(*operand)?;
            if *overflow_check {
                source.convert_to_checked(target, *unsigned, ptr_size)
            } else {
                source.convert_to(target, *unsigned, ptr_size)
            }
        }

        // Anything else (including `Copy`) is not folded here.
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        analysis::ssa::{
            ConstEvaluator, ConstValue, DefSite, SsaBlock, SsaFunction, SsaInstruction, SsaOp,
            SsaType, SsaVarId, VariableOrigin,
        },
        metadata::typesystem::PointerSize,
    };

    /// Creates a variable of unknown type for `origin`, defined at `site`.
    fn new_var(ssa: &mut SsaFunction, origin: VariableOrigin, site: DefSite) -> SsaVarId {
        ssa.create_variable(origin, 0, site, SsaType::Unknown)
    }

    /// Appends `op` to `block` as a synthetic instruction.
    fn push(block: &mut SsaBlock, op: SsaOp) {
        block.add_instruction(SsaInstruction::synthetic(op));
    }

    #[test]
    fn test_evaluate_constant() {
        let mut ssa = SsaFunction::new(0, 0);
        let answer = new_var(&mut ssa, VariableOrigin::Local(0), DefSite::instruction(0, 0));

        let mut block = SsaBlock::new(0);
        push(
            &mut block,
            SsaOp::Const {
                dest: answer,
                value: ConstValue::I32(42),
            },
        );
        push(&mut block, SsaOp::Return { value: None });
        ssa.add_block(block);

        let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64);
        assert_eq!(evaluator.evaluate_var(answer), Some(ConstValue::I32(42)));
    }

    #[test]
    fn test_evaluate_copy_chain() {
        let mut ssa = SsaFunction::new(0, 0);
        let source = new_var(&mut ssa, VariableOrigin::Local(0), DefSite::instruction(0, 0));
        let copied = new_var(&mut ssa, VariableOrigin::Local(1), DefSite::instruction(0, 1));

        let mut block = SsaBlock::new(0);
        push(
            &mut block,
            SsaOp::Const {
                dest: source,
                value: ConstValue::I32(100),
            },
        );
        push(
            &mut block,
            SsaOp::Copy {
                dest: copied,
                src: source,
            },
        );
        push(&mut block, SsaOp::Return { value: None });
        ssa.add_block(block);

        let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64);
        assert_eq!(evaluator.evaluate_var(copied), Some(ConstValue::I32(100)));
    }

    #[test]
    fn test_set_known_value() {
        let ssa = SsaFunction::new(0, 0);
        let seeded = SsaVarId::from_index(0);

        let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64);
        evaluator.set_known(seeded, ConstValue::I32(999));

        assert_eq!(evaluator.evaluate_var(seeded), Some(ConstValue::I32(999)));
    }

    #[test]
    fn test_into_results() {
        let ssa = SsaFunction::new(0, 0);
        let first = SsaVarId::from_index(0);
        let second = SsaVarId::from_index(1);

        let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64);
        for (var, n) in [(first, 1), (second, 2)] {
            evaluator.set_known(var, ConstValue::I32(n));
        }
        for var in [first, second] {
            let _ = evaluator.evaluate_var(var);
        }

        let results = evaluator.into_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results.get(&first), Some(&ConstValue::I32(1)));
        assert_eq!(results.get(&second), Some(&ConstValue::I32(2)));
    }
}