use std::collections::HashMap;
use crate::{
analysis::symbolic::{expr::SymbolicExpr, ops::SymbolicOp},
ir::{function::SsaFunction, ops::SsaOp, value::ConstValue, variable::SsaVarId},
target::Target,
PointerSize,
};
/// Symbolic evaluator that walks the operations of an [`SsaFunction`] and
/// records, for each SSA variable it sees defined, the [`SymbolicExpr`] that
/// variable stands for.
#[derive(Debug)]
pub struct SymbolicEvaluator<'a, T>
where
    T: Target,
{
    /// Function whose blocks and argument/local variables are consulted.
    ssa: &'a SsaFunction<T>,
    /// Current symbolic binding for every SSA variable that has one.
    expressions: HashMap<SsaVarId, SymbolicExpr<T>>,
    /// Value passed to `SymbolicExpr::simplify` whenever an expression is simplified.
    pointer_size: PointerSize,
}
impl<'a, T: Target> SymbolicEvaluator<'a, T> {
    /// Creates an evaluator over `ssa` with no variable bindings.
    ///
    /// `ptr_size` is the value handed to every `SymbolicExpr::simplify` call
    /// this evaluator makes.
    #[must_use]
    pub fn new(ssa: &'a SsaFunction<T>, ptr_size: PointerSize) -> Self {
        Self {
            ssa,
            expressions: HashMap::new(),
            pointer_size: ptr_size,
        }
    }

    /// Binds `var` to a named symbolic leaf, replacing any previous binding.
    pub fn set_symbolic(&mut self, var: SsaVarId, name: impl Into<String>) {
        self.expressions.insert(var, SymbolicExpr::named(name));
    }

    /// Binds `var` to a constant leaf, replacing any previous binding.
    pub fn set_constant(&mut self, var: SsaVarId, value: ConstValue<T>) {
        self.expressions.insert(var, SymbolicExpr::constant(value));
    }

    /// Returns the expression currently bound to `var`, exactly as stored,
    /// or `None` if `var` has no binding.
    #[must_use]
    pub fn get_expression(&self, var: SsaVarId) -> Option<&SymbolicExpr<T>> {
        self.expressions.get(&var)
    }

    /// Returns a freshly simplified copy of the expression bound to `var`,
    /// or `None` if `var` has no binding. The stored binding is not modified.
    #[must_use]
    pub fn get_simplified(&self, var: SsaVarId) -> Option<SymbolicExpr<T>> {
        self.expressions
            .get(&var)
            .map(|e| e.simplify(self.pointer_size))
    }

    /// Returns every binding recorded so far.
    #[must_use]
    pub fn expressions(&self) -> &HashMap<SsaVarId, SymbolicExpr<T>> {
        &self.expressions
    }

    /// Evaluates each instruction of block `block_idx` in order.
    ///
    /// Does nothing when `self.ssa.block(block_idx)` returns `None`.
    pub fn evaluate_block(&mut self, block_idx: usize) {
        let Some(block) = self.ssa.block(block_idx) else {
            return;
        };
        for instr in block.instructions() {
            self.evaluate_op(instr.op());
        }
    }

    /// Evaluates the given blocks one after another, in the order listed.
    ///
    /// No control-flow ordering is applied here; the caller chooses the order.
    pub fn evaluate_blocks(&mut self, block_indices: &[usize]) {
        for &block_idx in block_indices {
            self.evaluate_block(block_idx);
        }
    }

    /// Records the symbolic effect of a single SSA operation.
    ///
    /// - Arithmetic, bitwise, shift/rotate, comparison and bit-manipulation
    ///   ops bind their destination to a *simplified* expression built from
    ///   the operands' current bindings. Unbound operands become opaque
    ///   variable leaves.
    /// - `Const` binds the constant as-is.
    /// - `Copy`, `Conv`, `LoadArg` and `LoadLocal` alias an existing
    ///   expression without simplifying it.
    /// - Any other op is ignored. Its destination stays unbound, so later
    ///   reads see it as an opaque variable.
    pub fn evaluate_op(&mut self, op: &SsaOp<T>) {
        match op {
            SsaOp::Const { dest, value } => {
                self.expressions
                    .insert(*dest, SymbolicExpr::constant(value.clone()));
            }
            SsaOp::Copy { dest, src } => self.bind_alias(*dest, *src),
            SsaOp::Add {
                dest, left, right, ..
            } => {
                self.eval_binary(*dest, *left, *right, SymbolicOp::Add);
            }
            SsaOp::Sub {
                dest, left, right, ..
            } => {
                self.eval_binary(*dest, *left, *right, SymbolicOp::Sub);
            }
            SsaOp::Mul {
                dest, left, right, ..
            } => {
                self.eval_binary(*dest, *left, *right, SymbolicOp::Mul);
            }
            SsaOp::Div {
                dest,
                left,
                right,
                unsigned,
                ..
            } => {
                let sym_op = if *unsigned {
                    SymbolicOp::DivU
                } else {
                    SymbolicOp::DivS
                };
                self.eval_binary(*dest, *left, *right, sym_op);
            }
            SsaOp::Rem {
                dest,
                left,
                right,
                unsigned,
                ..
            } => {
                let sym_op = if *unsigned {
                    SymbolicOp::RemU
                } else {
                    SymbolicOp::RemS
                };
                self.eval_binary(*dest, *left, *right, sym_op);
            }
            SsaOp::Xor {
                dest, left, right, ..
            } => {
                self.eval_binary(*dest, *left, *right, SymbolicOp::Xor);
            }
            SsaOp::And {
                dest, left, right, ..
            } => {
                self.eval_binary(*dest, *left, *right, SymbolicOp::And);
            }
            SsaOp::Or {
                dest, left, right, ..
            } => {
                self.eval_binary(*dest, *left, *right, SymbolicOp::Or);
            }
            SsaOp::Shl {
                dest,
                value,
                amount,
                ..
            } => {
                self.eval_binary(*dest, *value, *amount, SymbolicOp::Shl);
            }
            SsaOp::Shr {
                dest,
                value,
                amount,
                unsigned,
                ..
            } => {
                let sym_op = if *unsigned {
                    SymbolicOp::ShrU
                } else {
                    SymbolicOp::ShrS
                };
                self.eval_binary(*dest, *value, *amount, sym_op);
            }
            SsaOp::Neg { dest, operand, .. } => {
                self.eval_unary(*dest, *operand, SymbolicOp::Neg);
            }
            SsaOp::Not { dest, operand, .. } => {
                self.eval_unary(*dest, *operand, SymbolicOp::Not);
            }
            SsaOp::Ceq { dest, left, right } => {
                self.eval_binary(*dest, *left, *right, SymbolicOp::Eq);
            }
            SsaOp::Cgt {
                dest,
                left,
                right,
                unsigned,
            } => {
                let sym_op = if *unsigned {
                    SymbolicOp::GtU
                } else {
                    SymbolicOp::GtS
                };
                self.eval_binary(*dest, *left, *right, sym_op);
            }
            SsaOp::Clt {
                dest,
                left,
                right,
                unsigned,
            } => {
                let sym_op = if *unsigned {
                    SymbolicOp::LtU
                } else {
                    SymbolicOp::LtS
                };
                self.eval_binary(*dest, *left, *right, sym_op);
            }
            // NOTE(review): the conversion is treated as identity. Whatever
            // target type the op carries is ignored, so truncation and
            // sign/zero extension are not modeled.
            SsaOp::Conv { dest, operand, .. } => self.bind_alias(*dest, *operand),
            SsaOp::LoadArg { dest, arg_index } => {
                // Alias the version-0 SSA variable of this argument. If the
                // function has none, `dest` is left unbound.
                let source = self
                    .ssa
                    .variables_from_argument(*arg_index)
                    .find(|v| v.version() == 0)
                    .map(|v| v.id());
                if let Some(src) = source {
                    self.bind_alias(*dest, src);
                }
            }
            SsaOp::LoadLocal { dest, local_index } => {
                // Alias the version-0 SSA variable of this local. If the
                // function has none, `dest` is left unbound.
                let source = self
                    .ssa
                    .variables_from_local(*local_index)
                    .find(|v| v.version() == 0)
                    .map(|v| v.id());
                if let Some(src) = source {
                    self.bind_alias(*dest, src);
                }
            }
            SsaOp::Rol {
                dest,
                value,
                amount,
            } => {
                self.eval_binary(*dest, *value, *amount, SymbolicOp::Rol);
            }
            SsaOp::Ror {
                dest,
                value,
                amount,
            } => {
                self.eval_binary(*dest, *value, *amount, SymbolicOp::Ror);
            }
            SsaOp::Rcl {
                dest,
                value,
                amount,
            } => {
                self.eval_binary(*dest, *value, *amount, SymbolicOp::Rcl);
            }
            SsaOp::Rcr {
                dest,
                value,
                amount,
            } => {
                self.eval_binary(*dest, *value, *amount, SymbolicOp::Rcr);
            }
            SsaOp::BSwap { dest, src } => {
                self.eval_unary(*dest, *src, SymbolicOp::BSwap);
            }
            SsaOp::BRev { dest, src } => {
                self.eval_unary(*dest, *src, SymbolicOp::BRev);
            }
            SsaOp::BitScanForward { dest, src } => {
                self.eval_unary(*dest, *src, SymbolicOp::BitScanForward);
            }
            SsaOp::BitScanReverse { dest, src } => {
                self.eval_unary(*dest, *src, SymbolicOp::BitScanReverse);
            }
            SsaOp::Popcount { dest, src } => {
                self.eval_unary(*dest, *src, SymbolicOp::Popcount);
            }
            SsaOp::Parity { dest, src } => {
                self.eval_unary(*dest, *src, SymbolicOp::Parity);
            }
            // Operations not modeled here leave their destination unbound.
            _ => {}
        }
    }

    /// Returns a clone of `var`'s current binding, or an opaque variable leaf
    /// for `var` when it has none.
    fn lookup_or_opaque(&self, var: SsaVarId) -> SymbolicExpr<T> {
        self.expressions
            .get(&var)
            .cloned()
            .unwrap_or_else(|| SymbolicExpr::variable(var))
    }

    /// Binds `dest` to the same, unsimplified expression that `src` currently
    /// stands for.
    fn bind_alias(&mut self, dest: SsaVarId, src: SsaVarId) {
        let expr = self.lookup_or_opaque(src);
        self.expressions.insert(dest, expr);
    }

    /// Binds `dest` to the simplified expression `op(left, right)`.
    fn eval_binary(&mut self, dest: SsaVarId, left: SsaVarId, right: SsaVarId, op: SymbolicOp) {
        let left_expr = self.lookup_or_opaque(left);
        let right_expr = self.lookup_or_opaque(right);
        let result = SymbolicExpr::binary(op, left_expr, right_expr).simplify(self.pointer_size);
        self.expressions.insert(dest, result);
    }

    /// Binds `dest` to the simplified expression `op(operand)`.
    fn eval_unary(&mut self, dest: SsaVarId, operand: SsaVarId, op: SymbolicOp) {
        let operand_expr = self.lookup_or_opaque(operand);
        let result = SymbolicExpr::unary(op, operand_expr).simplify(self.pointer_size);
        self.expressions.insert(dest, result);
    }
}