use std::collections::{HashMap, HashSet};
use crate::{
analysis::ssa::{ConstValue, SsaFunction, SsaOp, SsaVarId},
metadata::typesystem::PointerSize,
};
/// Evaluates SSA variables and operations to constant values where possible.
///
/// Results are memoized per variable. A set of in-progress variables stops
/// evaluation from re-entering a variable it is already working on, and a
/// depth limit bounds how far definitions are followed.
pub struct ConstEvaluator<'a> {
    /// SSA function whose variable definitions are looked up during evaluation.
    ssa: &'a SsaFunction,
    /// Memoized results; a `None` entry records a variable that did not
    /// evaluate to a constant.
    cache: HashMap<SsaVarId, Option<ConstValue>>,
    /// Variables currently being evaluated; reaching one of them again yields
    /// no value instead of recursing.
    visiting: HashSet<SsaVarId>,
    /// Recursion depth beyond which evaluation gives up.
    max_depth: usize,
    /// Pointer size passed to the `ConstValue` arithmetic and conversion helpers.
    pointer_size: PointerSize,
}
impl<'a> ConstEvaluator<'a> {
    /// Depth limit applied by [`ConstEvaluator::new`].
    const DEFAULT_MAX_DEPTH: usize = 20;

    /// Creates an evaluator over `ssa` using the default depth limit.
    #[must_use]
    pub fn new(ssa: &'a SsaFunction, ptr_size: PointerSize) -> Self {
        Self::with_max_depth(ssa, Self::DEFAULT_MAX_DEPTH, ptr_size)
    }

    /// Creates an evaluator over `ssa` that follows definitions at most
    /// `max_depth` levels deep.
    #[must_use]
    pub fn with_max_depth(ssa: &'a SsaFunction, max_depth: usize, ptr_size: PointerSize) -> Self {
        Self {
            ssa,
            cache: HashMap::new(),
            visiting: HashSet::new(),
            max_depth,
            pointer_size: ptr_size,
        }
    }

    /// Records `value` as the constant for `var`, replacing any memoized result.
    ///
    /// The value lives in the same cache as computed results, so
    /// [`ConstEvaluator::clear_cache`] discards it as well.
    pub fn set_known(&mut self, var: SsaVarId, value: ConstValue) {
        self.cache.insert(var, Some(value));
    }

    /// Evaluates `var` to a constant, or returns `None` when no constant could
    /// be determined.
    ///
    /// A `None` caused solely by the depth limit is not memoized, so the same
    /// variable can still resolve once more of its dependencies are cached or
    /// it is reached through a shorter path.
    pub fn evaluate_var(&mut self, var: SsaVarId) -> Option<ConstValue> {
        self.evaluate_var_depth(var, 0)
    }

    /// Evaluates `var` as if it had been reached at recursion depth `depth`.
    fn evaluate_var_depth(&mut self, var: SsaVarId, depth: usize) -> Option<ConstValue> {
        let mut truncated = false;
        self.eval_var_tracked(var, depth, &mut truncated)
    }

    /// Evaluates `var`, setting `truncated` when the depth limit cut the
    /// evaluation short.
    ///
    /// A `None` produced by truncation depends on the depth at which the
    /// variable was reached rather than on the variable itself. Memoizing it
    /// would make later, shallower queries of the same variable return a stale
    /// `None`, so such results are left out of the cache.
    fn eval_var_tracked(
        &mut self,
        var: SsaVarId,
        depth: usize,
        truncated: &mut bool,
    ) -> Option<ConstValue> {
        if depth > self.max_depth {
            *truncated = true;
            return None;
        }
        if let Some(cached) = self.cache.get(&var) {
            return cached.clone();
        }
        // `insert` returns false when `var` is already being evaluated, which
        // means its definition (transitively) refers back to itself.
        if !self.visiting.insert(var) {
            return None;
        }

        // Copy the shared reference so the definition borrow is independent
        // of `self`, which the recursive call needs mutably.
        let ssa = self.ssa;
        let result = match ssa.get_definition(var) {
            Some(op) => self.eval_op_tracked(op, depth, truncated),
            None => None,
        };

        self.visiting.remove(&var);
        // Only memoize results that do not depend on the current depth.
        if result.is_some() || !*truncated {
            self.cache.insert(var, result.clone());
        }
        result
    }

    /// Evaluates a single operation to a constant, or returns `None` when an
    /// operand is not constant, the operation is not supported, or the
    /// underlying `ConstValue` operation yields no value.
    pub fn evaluate_op(&mut self, op: &SsaOp) -> Option<ConstValue> {
        self.evaluate_op_depth(op, 0)
    }

    /// Evaluates `op` as if it had been reached at recursion depth `depth`.
    fn evaluate_op_depth(&mut self, op: &SsaOp, depth: usize) -> Option<ConstValue> {
        let mut truncated = false;
        self.eval_op_tracked(op, depth, &mut truncated)
    }

    /// Evaluates both operands of a binary operation one level deeper.
    ///
    /// The second operand is not evaluated when the first has no constant value.
    fn eval_operands(
        &mut self,
        first: SsaVarId,
        second: SsaVarId,
        depth: usize,
        truncated: &mut bool,
    ) -> Option<(ConstValue, ConstValue)> {
        let a = self.eval_var_tracked(first, depth + 1, truncated)?;
        let b = self.eval_var_tracked(second, depth + 1, truncated)?;
        Some((a, b))
    }

    /// Evaluates `op`, setting `truncated` when the depth limit was hit while
    /// evaluating it or one of its operands.
    fn eval_op_tracked(
        &mut self,
        op: &SsaOp,
        depth: usize,
        truncated: &mut bool,
    ) -> Option<ConstValue> {
        if depth > self.max_depth {
            *truncated = true;
            return None;
        }
        match op {
            SsaOp::Const { value, .. } => Some(value.clone()),
            SsaOp::Copy { src, .. } => self.eval_var_tracked(*src, depth + 1, truncated),
            SsaOp::Xor { left, right, .. } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.bitwise_xor(&r, self.pointer_size)
            }
            SsaOp::And { left, right, .. } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.bitwise_and(&r, self.pointer_size)
            }
            SsaOp::Or { left, right, .. } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.bitwise_or(&r, self.pointer_size)
            }
            SsaOp::Add { left, right, .. } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.add(&r, self.pointer_size)
            }
            SsaOp::Sub { left, right, .. } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.sub(&r, self.pointer_size)
            }
            SsaOp::Mul { left, right, .. } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.mul(&r, self.pointer_size)
            }
            SsaOp::Div { left, right, .. } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.div(&r, self.pointer_size)
            }
            SsaOp::Rem { left, right, .. } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.rem(&r, self.pointer_size)
            }
            SsaOp::Shl { value, amount, .. } => {
                let (v, a) = self.eval_operands(*value, *amount, depth, truncated)?;
                v.shl(&a, self.pointer_size)
            }
            SsaOp::Shr {
                value,
                amount,
                unsigned,
                ..
            } => {
                let (v, a) = self.eval_operands(*value, *amount, depth, truncated)?;
                v.shr(&a, *unsigned, self.pointer_size)
            }
            SsaOp::Neg { operand, .. } => {
                let v = self.eval_var_tracked(*operand, depth + 1, truncated)?;
                v.negate(self.pointer_size)
            }
            SsaOp::Not { operand, .. } => {
                let v = self.eval_var_tracked(*operand, depth + 1, truncated)?;
                v.bitwise_not(self.pointer_size)
            }
            SsaOp::Ceq { left, right, .. } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.ceq(&r)
            }
            SsaOp::Clt {
                left,
                right,
                unsigned,
                ..
            } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                if *unsigned {
                    l.clt_un(&r)
                } else {
                    l.clt(&r)
                }
            }
            SsaOp::Cgt {
                left,
                right,
                unsigned,
                ..
            } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                if *unsigned {
                    l.cgt_un(&r)
                } else {
                    l.cgt(&r)
                }
            }
            SsaOp::AddOvf {
                left,
                right,
                unsigned,
                ..
            } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.add_checked(&r, *unsigned, self.pointer_size)
            }
            SsaOp::SubOvf {
                left,
                right,
                unsigned,
                ..
            } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.sub_checked(&r, *unsigned, self.pointer_size)
            }
            SsaOp::MulOvf {
                left,
                right,
                unsigned,
                ..
            } => {
                let (l, r) = self.eval_operands(*left, *right, depth, truncated)?;
                l.mul_checked(&r, *unsigned, self.pointer_size)
            }
            SsaOp::Conv {
                operand,
                target,
                overflow_check,
                unsigned,
                ..
            } => {
                let v = self.eval_var_tracked(*operand, depth + 1, truncated)?;
                if *overflow_check {
                    v.convert_to_checked(target, *unsigned, self.pointer_size)
                } else {
                    v.convert_to(target, *unsigned, self.pointer_size)
                }
            }
            // Every other operation is treated as not constant-evaluable.
            _ => None,
        }
    }

    /// Consumes the evaluator and returns every variable with a known constant.
    ///
    /// Cache entries recording "no constant" are dropped.
    #[must_use]
    pub fn into_results(self) -> HashMap<SsaVarId, ConstValue> {
        self.cache
            .into_iter()
            .filter_map(|(var, opt)| opt.map(|val| (var, val)))
            .collect()
    }

    /// Returns the SSA function this evaluator reads definitions from.
    #[must_use]
    pub fn ssa(&self) -> &SsaFunction {
        self.ssa
    }

    /// Drops all memoized results, including values injected via
    /// [`ConstEvaluator::set_known`].
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        analysis::ssa::{
            ConstEvaluator, ConstValue, DefSite, SsaBlock, SsaFunction, SsaInstruction, SsaOp,
            SsaVarId, SsaVariable, VariableOrigin,
        },
        metadata::typesystem::PointerSize,
    };

    /// Appends a synthetic instruction loading the `i32` constant `value` into `dest`.
    fn push_i32_const(block: &mut SsaBlock, dest: SsaVarId, value: i32) {
        block.add_instruction(SsaInstruction::synthetic(SsaOp::Const {
            dest,
            value: ConstValue::I32(value),
        }));
    }

    /// Terminates `block` with a value-less return and attaches it to `function`.
    fn seal(function: &mut SsaFunction, mut block: SsaBlock) {
        block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None }));
        function.add_block(block);
    }

    /// A variable defined by a `Const` instruction evaluates to that constant.
    #[test]
    fn test_evaluate_constant() {
        let mut function = SsaFunction::new(0, 0);
        let mut entry = SsaBlock::new(0);

        let variable = SsaVariable::new(VariableOrigin::Stack(0), 0, DefSite::instruction(0, 0));
        let id = variable.id();
        function.add_variable(variable);

        push_i32_const(&mut entry, id, 42);
        seal(&mut function, entry);

        let mut evaluator = ConstEvaluator::new(&function, PointerSize::Bit64);
        assert_eq!(evaluator.evaluate_var(id), Some(ConstValue::I32(42)));
    }

    /// A `Copy` of a constant variable evaluates to the source constant.
    #[test]
    fn test_evaluate_copy_chain() {
        let mut function = SsaFunction::new(0, 0);
        let mut entry = SsaBlock::new(0);

        let source = SsaVariable::new(VariableOrigin::Stack(0), 0, DefSite::instruction(0, 0));
        let source_id = source.id();
        function.add_variable(source);

        let copy = SsaVariable::new(VariableOrigin::Stack(1), 0, DefSite::instruction(0, 1));
        let copy_id = copy.id();
        function.add_variable(copy);

        push_i32_const(&mut entry, source_id, 100);
        entry.add_instruction(SsaInstruction::synthetic(SsaOp::Copy {
            dest: copy_id,
            src: source_id,
        }));
        seal(&mut function, entry);

        let mut evaluator = ConstEvaluator::new(&function, PointerSize::Bit64);
        assert_eq!(evaluator.evaluate_var(copy_id), Some(ConstValue::I32(100)));
    }

    /// A value injected with `set_known` is returned without any definition.
    #[test]
    fn test_set_known_value() {
        let function = SsaFunction::new(0, 0);
        let id = SsaVarId::new();

        let mut evaluator = ConstEvaluator::new(&function, PointerSize::Bit64);
        evaluator.set_known(id, ConstValue::I32(999));

        assert_eq!(evaluator.evaluate_var(id), Some(ConstValue::I32(999)));
    }

    /// `into_results` returns every variable that has a known constant.
    #[test]
    fn test_into_results() {
        let function = SsaFunction::new(0, 0);
        let first = SsaVarId::new();
        let second = SsaVarId::new();

        let mut evaluator = ConstEvaluator::new(&function, PointerSize::Bit64);
        evaluator.set_known(first, ConstValue::I32(1));
        evaluator.set_known(second, ConstValue::I32(2));
        for id in [first, second] {
            evaluator.evaluate_var(id);
        }

        let known = evaluator.into_results();
        assert_eq!(known.len(), 2);
        assert_eq!(known.get(&first), Some(&ConstValue::I32(1)));
        assert_eq!(known.get(&second), Some(&ConstValue::I32(2)));
    }
}