use std::collections::HashMap;
use crate::{
analysis::ssa::{
symbolic::{expr::SymbolicExpr, ops::SymbolicOp},
ConstValue, SsaFunction, SsaOp, SsaVarId,
},
metadata::typesystem::PointerSize,
};
/// Symbolic evaluator over an [`SsaFunction`].
///
/// Evaluates SSA operations one at a time and records, for each destination
/// variable, a [`SymbolicExpr`] built from the expressions already known for
/// its operands. Operands with no recorded expression are represented as
/// plain variable references.
#[derive(Debug)]
pub struct SymbolicEvaluator<'a> {
    /// Function whose blocks are looked up by `evaluate_block`.
    ssa: &'a SsaFunction,
    /// Expression currently bound to each SSA variable (seeded by the
    /// `set_*` methods and extended by evaluation).
    expressions: HashMap<SsaVarId, SymbolicExpr>,
    /// Pointer width passed to every `SymbolicExpr::simplify` call.
    pointer_size: PointerSize,
}
impl<'a> SymbolicEvaluator<'a> {
#[must_use]
pub fn new(ssa: &'a SsaFunction, ptr_size: PointerSize) -> Self {
Self {
ssa,
expressions: HashMap::new(),
pointer_size: ptr_size,
}
}
pub fn set_symbolic(&mut self, var: SsaVarId, name: impl Into<String>) {
self.expressions.insert(var, SymbolicExpr::named(name));
}
pub fn set_constant(&mut self, var: SsaVarId, value: ConstValue) {
self.expressions.insert(var, SymbolicExpr::constant(value));
}
#[must_use]
pub fn get_expression(&self, var: SsaVarId) -> Option<&SymbolicExpr> {
self.expressions.get(&var)
}
#[must_use]
pub fn get_simplified(&self, var: SsaVarId) -> Option<SymbolicExpr> {
self.expressions
.get(&var)
.map(|e| e.simplify(self.pointer_size))
}
#[must_use]
pub fn expressions(&self) -> &HashMap<SsaVarId, SymbolicExpr> {
&self.expressions
}
pub fn evaluate_block(&mut self, block_idx: usize) {
let Some(block) = self.ssa.block(block_idx) else {
return;
};
for instr in block.instructions() {
self.evaluate_op(instr.op());
}
}
pub fn evaluate_blocks(&mut self, block_indices: &[usize]) {
for &block_idx in block_indices {
self.evaluate_block(block_idx);
}
}
pub fn evaluate_op(&mut self, op: &SsaOp) {
match op {
SsaOp::Const { dest, value } => {
self.expressions
.insert(*dest, SymbolicExpr::constant(value.clone()));
}
SsaOp::Copy { dest, src } => {
if let Some(expr) = self.expressions.get(src) {
self.expressions.insert(*dest, expr.clone());
} else {
self.expressions.insert(*dest, SymbolicExpr::variable(*src));
}
}
SsaOp::Add { dest, left, right } => {
self.eval_binary(*dest, *left, *right, SymbolicOp::Add);
}
SsaOp::Sub { dest, left, right } => {
self.eval_binary(*dest, *left, *right, SymbolicOp::Sub);
}
SsaOp::Mul { dest, left, right } => {
self.eval_binary(*dest, *left, *right, SymbolicOp::Mul);
}
SsaOp::Div {
dest,
left,
right,
unsigned,
} => {
let op = if *unsigned {
SymbolicOp::DivU
} else {
SymbolicOp::DivS
};
self.eval_binary(*dest, *left, *right, op);
}
SsaOp::Rem {
dest,
left,
right,
unsigned,
} => {
let op = if *unsigned {
SymbolicOp::RemU
} else {
SymbolicOp::RemS
};
self.eval_binary(*dest, *left, *right, op);
}
SsaOp::Xor { dest, left, right } => {
self.eval_binary(*dest, *left, *right, SymbolicOp::Xor);
}
SsaOp::And { dest, left, right } => {
self.eval_binary(*dest, *left, *right, SymbolicOp::And);
}
SsaOp::Or { dest, left, right } => {
self.eval_binary(*dest, *left, *right, SymbolicOp::Or);
}
SsaOp::Shl {
dest,
value,
amount,
} => {
self.eval_binary(*dest, *value, *amount, SymbolicOp::Shl);
}
SsaOp::Shr {
dest,
value,
amount,
unsigned,
} => {
let op = if *unsigned {
SymbolicOp::ShrU
} else {
SymbolicOp::ShrS
};
self.eval_binary(*dest, *value, *amount, op);
}
SsaOp::Neg { dest, operand } => {
self.eval_unary(*dest, *operand, SymbolicOp::Neg);
}
SsaOp::Not { dest, operand } => {
self.eval_unary(*dest, *operand, SymbolicOp::Not);
}
SsaOp::Ceq { dest, left, right } => {
self.eval_binary(*dest, *left, *right, SymbolicOp::Eq);
}
SsaOp::Cgt {
dest,
left,
right,
unsigned,
} => {
let op = if *unsigned {
SymbolicOp::GtU
} else {
SymbolicOp::GtS
};
self.eval_binary(*dest, *left, *right, op);
}
SsaOp::Clt {
dest,
left,
right,
unsigned,
} => {
let op = if *unsigned {
SymbolicOp::LtU
} else {
SymbolicOp::LtS
};
self.eval_binary(*dest, *left, *right, op);
}
SsaOp::Conv { dest, operand, .. } => {
if let Some(expr) = self.expressions.get(operand) {
self.expressions.insert(*dest, expr.clone());
} else {
self.expressions
.insert(*dest, SymbolicExpr::variable(*operand));
}
}
_ => {}
}
}
fn eval_binary(&mut self, dest: SsaVarId, left: SsaVarId, right: SsaVarId, op: SymbolicOp) {
let left_expr = self
.expressions
.get(&left)
.cloned()
.unwrap_or_else(|| SymbolicExpr::variable(left));
let right_expr = self
.expressions
.get(&right)
.cloned()
.unwrap_or_else(|| SymbolicExpr::variable(right));
let result = SymbolicExpr::binary(op, left_expr, right_expr).simplify(self.pointer_size);
self.expressions.insert(dest, result);
}
fn eval_unary(&mut self, dest: SsaVarId, operand: SsaVarId, op: SymbolicOp) {
let operand_expr = self
.expressions
.get(&operand)
.cloned()
.unwrap_or_else(|| SymbolicExpr::variable(operand));
let result = SymbolicExpr::unary(op, operand_expr).simplify(self.pointer_size);
self.expressions.insert(dest, result);
}
}