use std::{
collections::{HashMap, HashSet},
fmt,
};
use crate::{
analysis::symbolic::ops::SymbolicOp,
ir::{value::ConstValue, variable::SsaVarId},
target::Target,
PointerSize,
};
/// A symbolic expression tree over typed constants and variables.
///
/// Leaves are constants, SSA variables, or variables identified by name;
/// interior nodes apply a [`SymbolicOp`] to one (`Unary`) or two (`Binary`)
/// boxed sub-expressions.
#[derive(Debug, Clone, PartialEq)]
pub enum SymbolicExpr<T: Target> {
    /// A typed constant leaf.
    Constant(ConstValue<T>),
    /// A leaf referring to an SSA variable by id.
    Variable(SsaVarId),
    /// A leaf referring to a variable by name rather than by SSA id.
    NamedVar(String),
    /// `op operand`.
    Unary {
        op: SymbolicOp,
        operand: Box<Self>,
    },
    /// `left op right`.
    Binary {
        op: SymbolicOp,
        left: Box<Self>,
        right: Box<Self>,
    },
}
impl<T: Target> SymbolicExpr<T> {
    /// Wraps a typed constant as an expression leaf.
    #[must_use]
    pub fn constant(value: ConstValue<T>) -> Self {
        Self::Constant(value)
    }

    /// Creates a constant leaf holding `ConstValue::I64(value)`.
    #[must_use]
    pub fn constant_i64(value: i64) -> Self {
        Self::Constant(ConstValue::I64(value))
    }

    /// Creates a constant leaf holding `ConstValue::I32(value)`.
    #[must_use]
    pub fn constant_i32(value: i32) -> Self {
        Self::Constant(ConstValue::I32(value))
    }

    /// Creates a leaf that refers to the SSA variable `var`.
    #[must_use]
    pub const fn variable(var: SsaVarId) -> Self {
        Self::Variable(var)
    }

    /// Creates a leaf that refers to a variable by name instead of by SSA id.
    ///
    /// Named leaves are resolved only by [`Self::evaluate_named`] and the
    /// `substitute_named*` methods; [`Self::evaluate`] treats them as unknown.
    #[must_use]
    pub fn named(name: impl Into<String>) -> Self {
        Self::NamedVar(name.into())
    }

    /// Builds the unary node `op operand` without simplifying it.
    #[must_use]
    pub fn unary(op: SymbolicOp, operand: Self) -> Self {
        Self::Unary {
            op,
            operand: Box::new(operand),
        }
    }

    /// Builds the binary node `left op right` without simplifying it.
    #[must_use]
    pub fn binary(op: SymbolicOp, left: Self, right: Self) -> Self {
        Self::Binary {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// Returns `true` if this node is a constant leaf.
    ///
    /// This is a shape check only: an unsimplified tree such as `1 + 2` is not
    /// reported as constant.
    #[must_use]
    pub const fn is_constant(&self) -> bool {
        matches!(self, Self::Constant(_))
    }

    /// Returns `true` if this node is a variable leaf of either kind
    /// (`Variable` or `NamedVar`).
    #[must_use]
    pub const fn is_variable(&self) -> bool {
        matches!(self, Self::Variable(_) | Self::NamedVar(_))
    }

    /// Borrows the constant if this node is a constant leaf.
    #[must_use]
    pub const fn as_constant(&self) -> Option<&ConstValue<T>> {
        match self {
            Self::Constant(v) => Some(v),
            _ => None,
        }
    }

    /// Returns the constant as an `i64` if this node is a constant leaf and
    /// `ConstValue::as_i64` yields a value for it.
    #[must_use]
    pub fn as_i64(&self) -> Option<i64> {
        match self {
            Self::Constant(v) => v.as_i64(),
            _ => None,
        }
    }

    /// Returns the SSA id if this node is a `Variable` leaf.
    ///
    /// Unlike [`Self::is_variable`], a `NamedVar` leaf yields `None` here.
    #[must_use]
    pub const fn as_variable(&self) -> Option<SsaVarId> {
        match self {
            Self::Variable(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns the set of SSA variables referenced anywhere in the tree.
    ///
    /// `NamedVar` leaves are not included; see [`Self::named_variables`].
    #[must_use]
    pub fn variables(&self) -> HashSet<SsaVarId> {
        let mut vars = HashSet::new();
        self.collect_variables(&mut vars);
        vars
    }

    /// Adds every SSA variable in this subtree to `vars`.
    ///
    /// One set is shared across the whole traversal instead of allocating and
    /// merging a fresh set at every node.
    fn collect_variables(&self, vars: &mut HashSet<SsaVarId>) {
        match self {
            Self::Constant(_) | Self::NamedVar(_) => {}
            Self::Variable(v) => {
                vars.insert(*v);
            }
            Self::Unary { operand, .. } => operand.collect_variables(vars),
            Self::Binary { left, right, .. } => {
                left.collect_variables(vars);
                right.collect_variables(vars);
            }
        }
    }

    /// Returns the set of variable names referenced by `NamedVar` leaves.
    ///
    /// SSA `Variable` leaves are not included; see [`Self::variables`].
    #[must_use]
    pub fn named_variables(&self) -> HashSet<String> {
        let mut vars = HashSet::new();
        self.collect_named_variables(&mut vars);
        vars
    }

    /// Adds the name of every `NamedVar` leaf in this subtree to `vars`.
    fn collect_named_variables(&self, vars: &mut HashSet<String>) {
        match self {
            Self::Constant(_) | Self::Variable(_) => {}
            Self::NamedVar(name) => {
                vars.insert(name.clone());
            }
            Self::Unary { operand, .. } => operand.collect_named_variables(vars),
            Self::Binary { left, right, .. } => {
                left.collect_named_variables(vars);
                right.collect_named_variables(vars);
            }
        }
    }

    /// Evaluates the tree to a constant, looking SSA variables up in `bindings`.
    ///
    /// Returns `None` if the tree contains a `NamedVar`, an SSA variable absent
    /// from `bindings`, or an operation that [`evaluate_unary_typed`] /
    /// [`evaluate_binary_typed`] cannot fold.
    #[must_use]
    pub fn evaluate(
        &self,
        bindings: &HashMap<SsaVarId, ConstValue<T>>,
        ptr_size: PointerSize,
    ) -> Option<ConstValue<T>> {
        match self {
            Self::Constant(v) => Some(v.clone()),
            Self::Variable(var) => bindings.get(var).cloned(),
            Self::NamedVar(_) => None,
            Self::Unary { op, operand } => {
                let v = operand.evaluate(bindings, ptr_size)?;
                evaluate_unary_typed(*op, &v, ptr_size)
            }
            Self::Binary { op, left, right } => {
                let l = left.evaluate(bindings, ptr_size)?;
                let r = right.evaluate(bindings, ptr_size)?;
                evaluate_binary_typed(*op, &l, &r, ptr_size)
            }
        }
    }

    /// Evaluates the tree to a constant, looking `NamedVar` leaves up by name.
    ///
    /// Returns `None` if the tree contains an SSA `Variable`, a name absent from
    /// `bindings`, or an operation the typed evaluators cannot fold.
    #[must_use]
    pub fn evaluate_named(
        &self,
        bindings: &HashMap<&str, ConstValue<T>>,
        ptr_size: PointerSize,
    ) -> Option<ConstValue<T>> {
        match self {
            Self::Constant(v) => Some(v.clone()),
            Self::Variable(_) => None,
            Self::NamedVar(name) => bindings.get(name.as_str()).cloned(),
            Self::Unary { op, operand } => {
                let v = operand.evaluate_named(bindings, ptr_size)?;
                evaluate_unary_typed(*op, &v, ptr_size)
            }
            Self::Binary { op, left, right } => {
                let l = left.evaluate_named(bindings, ptr_size)?;
                let r = right.evaluate_named(bindings, ptr_size)?;
                evaluate_binary_typed(*op, &l, &r, ptr_size)
            }
        }
    }

    /// Returns a copy with every occurrence of SSA variable `var` replaced by a
    /// clone of `replacement`. The result is not simplified.
    #[must_use]
    pub fn substitute(&self, var: SsaVarId, replacement: &Self) -> Self {
        match self {
            Self::Constant(v) => Self::Constant(v.clone()),
            Self::Variable(v) if *v == var => replacement.clone(),
            Self::Variable(v) => Self::Variable(*v),
            Self::NamedVar(name) => Self::NamedVar(name.clone()),
            Self::Unary { op, operand } => Self::Unary {
                op: *op,
                operand: Box::new(operand.substitute(var, replacement)),
            },
            Self::Binary { op, left, right } => Self::Binary {
                op: *op,
                left: Box::new(left.substitute(var, replacement)),
                right: Box::new(right.substitute(var, replacement)),
            },
        }
    }

    /// Replaces every `NamedVar(name)` with the constant `I64(value)` and then
    /// simplifies the result.
    #[must_use]
    pub fn substitute_named(&self, name: &str, value: i64, ptr_size: PointerSize) -> Self {
        self.substitute_named_expr(name, &Self::Constant(ConstValue::I64(value)))
            .simplify(ptr_size)
    }

    /// Returns a copy with every `NamedVar(name)` replaced by a clone of
    /// `replacement`. The result is not simplified.
    #[must_use]
    pub fn substitute_named_expr(&self, name: &str, replacement: &Self) -> Self {
        match self {
            Self::Constant(v) => Self::Constant(v.clone()),
            Self::Variable(v) => Self::Variable(*v),
            Self::NamedVar(n) if n == name => replacement.clone(),
            Self::NamedVar(n) => Self::NamedVar(n.clone()),
            Self::Unary { op, operand } => Self::Unary {
                op: *op,
                operand: Box::new(operand.substitute_named_expr(name, replacement)),
            },
            Self::Binary { op, left, right } => Self::Binary {
                op: *op,
                left: Box::new(left.substitute_named_expr(name, replacement)),
                right: Box::new(right.substitute_named_expr(name, replacement)),
            },
        }
    }

    /// Returns a simplified copy of the expression.
    ///
    /// Works bottom-up: children are simplified first. Each node is then either
    /// constant-folded (via [`evaluate_unary_typed`] / [`evaluate_binary_typed`])
    /// or rewritten with these algebraic identities:
    ///
    /// * `-(-x)` and `!(!x)` become `x`.
    /// * `x ^ x` and `x - x` become `0`; `x | x` and `x & x` become `x`.
    /// * `(x ^ c) ^ c`, in any operand order, becomes `x`.
    /// * Identities with `0`, `1` and all-ones constants are applied on either side.
    ///
    /// The annihilating rules (`x * 0`, `x & 0`) return the zero constant already
    /// present in the tree, so its `ConstValue` width is preserved. `x - x` and
    /// `x ^ x` have no constant to take a width from and produce `I32(0)`.
    ///
    /// NOTE(review): these rewrites assume integer semantics. If `ConstValue`
    /// carries floating-point values, `x - x` and `x * 0` are not `0` when `x` is
    /// NaN or infinite.
    #[must_use]
    pub fn simplify(&self, ptr_size: PointerSize) -> Self {
        match self {
            Self::Constant(_) | Self::Variable(_) | Self::NamedVar(_) => self.clone(),
            Self::Unary { op, operand } => {
                Self::simplify_unary_node(*op, operand.simplify(ptr_size), ptr_size)
            }
            Self::Binary { op, left, right } => Self::simplify_binary_node(
                *op,
                left.simplify(ptr_size),
                right.simplify(ptr_size),
                ptr_size,
            ),
        }
    }

    /// Simplifies `op operand`, where `operand` has already been simplified.
    fn simplify_unary_node(op: SymbolicOp, operand: Self, ptr_size: PointerSize) -> Self {
        if let Self::Constant(v) = &operand {
            if let Some(result) = evaluate_unary_typed(op, v, ptr_size) {
                return Self::Constant(result);
            }
        }
        match operand {
            // Negation and bitwise not are involutions: -(-x) == x, !(!x) == x.
            Self::Unary {
                op: inner_op,
                operand: inner,
            } if inner_op == op && matches!(op, SymbolicOp::Neg | SymbolicOp::Not) => *inner,
            other => Self::unary(op, other),
        }
    }

    /// Simplifies `left op right`, where both children have already been simplified.
    fn simplify_binary_node(op: SymbolicOp, left: Self, right: Self, ptr_size: PointerSize) -> Self {
        if let (Self::Constant(l), Self::Constant(r)) = (&left, &right) {
            if let Some(result) = evaluate_binary_typed(op, l, r, ptr_size) {
                return Self::Constant(result);
            }
        }

        if left == right {
            match op {
                // No constant operand to take a width from, so I32 is used.
                SymbolicOp::Xor | SymbolicOp::Sub => return Self::Constant(ConstValue::I32(0)),
                SymbolicOp::Or | SymbolicOp::And => return left,
                _ => {}
            }
        }

        if op == SymbolicOp::Xor {
            if let Some(cancelled) = Self::cancel_repeated_xor(&right, &left)
                .or_else(|| Self::cancel_repeated_xor(&left, &right))
            {
                return cancelled;
            }
        }

        if let Self::Constant(r) = &right {
            if r.is_zero() {
                match op {
                    SymbolicOp::Add | SymbolicOp::Sub | SymbolicOp::Xor | SymbolicOp::Or => {
                        return left;
                    }
                    // Reuse the zero operand so the result keeps its width.
                    SymbolicOp::Mul | SymbolicOp::And => return Self::Constant(r.clone()),
                    _ => {}
                }
            } else if r.is_one() {
                if matches!(op, SymbolicOp::Mul | SymbolicOp::DivS | SymbolicOp::DivU) {
                    return left;
                }
            } else if r.is_all_ones() {
                match op {
                    SymbolicOp::And => return left,
                    SymbolicOp::Or => return Self::Constant(r.clone()),
                    SymbolicOp::Xor => return Self::unary(SymbolicOp::Not, left),
                    _ => {}
                }
            }
        }

        if let Self::Constant(l) = &left {
            if l.is_zero() {
                match op {
                    SymbolicOp::Add | SymbolicOp::Xor | SymbolicOp::Or => return right,
                    SymbolicOp::Sub => return Self::unary(SymbolicOp::Neg, right),
                    // Reuse the zero operand so the result keeps its width.
                    SymbolicOp::Mul | SymbolicOp::And => return Self::Constant(l.clone()),
                    _ => {}
                }
            } else if l.is_one() {
                if op == SymbolicOp::Mul {
                    return right;
                }
            } else if l.is_all_ones() {
                match op {
                    SymbolicOp::And => return right,
                    SymbolicOp::Or => return Self::Constant(l.clone()),
                    SymbolicOp::Xor => return Self::unary(SymbolicOp::Not, right),
                    _ => {}
                }
            }
        }

        Self::binary(op, left, right)
    }

    /// XOR cancellation: if `constant` is the constant `c` and `other` is
    /// `a ^ c` or `c ^ a`, returns `a`, because `(a ^ c) ^ c == a`.
    fn cancel_repeated_xor(constant: &Self, other: &Self) -> Option<Self> {
        let c = match constant {
            Self::Constant(c) => c,
            _ => return None,
        };
        match other {
            Self::Binary {
                op: SymbolicOp::Xor,
                left,
                right,
            } => match (left.as_ref(), right.as_ref()) {
                (_, Self::Constant(k)) if k == c => Some(left.as_ref().clone()),
                (Self::Constant(k), _) if k == c => Some(right.as_ref().clone()),
                _ => None,
            },
            _ => None,
        }
    }

    /// Returns the height of the tree.
    ///
    /// Leaves have depth 0, and each operator node adds 1 to its deepest child.
    /// Additions saturate instead of overflowing.
    #[must_use]
    pub fn depth(&self) -> usize {
        match self {
            Self::Constant(_) | Self::Variable(_) | Self::NamedVar(_) => 0,
            Self::Unary { operand, .. } => 1usize.saturating_add(operand.depth()),
            Self::Binary { left, right, .. } => {
                1usize.saturating_add(left.depth().max(right.depth()))
            }
        }
    }
}
/// Renders the expression in fully parenthesised infix form.
///
/// SSA variables print as `v<index>`, named variables print as their name,
/// unary nodes print as `(<op><operand>)`, and binary nodes print as
/// `(<left> <op> <right>)`.
impl<T: Target> fmt::Display for SymbolicExpr<T>
where
    T::TypeRef: fmt::Display,
    T::MethodRef: fmt::Display,
    T::FieldRef: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Constant(value) => write!(f, "{value}"),
            Self::Variable(id) => write!(f, "v{}", id.index()),
            Self::NamedVar(label) => f.write_str(label),
            Self::Unary {
                op: operator,
                operand: inner,
            } => write!(f, "({operator}{inner})"),
            Self::Binary {
                op: operator,
                left: lhs,
                right: rhs,
            } => write!(f, "({lhs} {operator} {rhs})"),
        }
    }
}
/// Lifts a typed constant directly into a constant leaf.
impl<T: Target> From<ConstValue<T>> for SymbolicExpr<T> {
    fn from(value: ConstValue<T>) -> Self {
        Self::constant(value)
    }
}
/// Lifts an `i32` into a constant leaf holding `ConstValue::I32`.
impl<T: Target> From<i32> for SymbolicExpr<T> {
    fn from(value: i32) -> Self {
        Self::constant_i32(value)
    }
}
/// Lifts an `i64` into a constant leaf holding `ConstValue::I64`.
impl<T: Target> From<i64> for SymbolicExpr<T> {
    fn from(value: i64) -> Self {
        Self::constant_i64(value)
    }
}
/// Folds a unary [`SymbolicOp`] applied to a constant.
///
/// Only `Not` (via `ConstValue::bitwise_not`) and `Neg` (via
/// `ConstValue::negate`) are evaluated. Every other operator yields `None`,
/// and so does any case where the `ConstValue` helper itself returns `None`.
pub fn evaluate_unary_typed<T: Target>(
    op: SymbolicOp,
    value: &ConstValue<T>,
    ptr_size: PointerSize,
) -> Option<ConstValue<T>> {
    match op {
        SymbolicOp::Not => value.bitwise_not(ptr_size),
        SymbolicOp::Neg => value.negate(ptr_size),
        // Listed explicitly rather than with `_`, so adding an operator forces
        // a decision here.
        SymbolicOp::Add
        | SymbolicOp::Sub
        | SymbolicOp::Mul
        | SymbolicOp::DivS
        | SymbolicOp::DivU
        | SymbolicOp::RemS
        | SymbolicOp::RemU
        | SymbolicOp::And
        | SymbolicOp::Or
        | SymbolicOp::Xor
        | SymbolicOp::Shl
        | SymbolicOp::ShrS
        | SymbolicOp::ShrU
        | SymbolicOp::Eq
        | SymbolicOp::Ne
        | SymbolicOp::LtS
        | SymbolicOp::LtU
        | SymbolicOp::GtS
        | SymbolicOp::GtU
        | SymbolicOp::LeS
        | SymbolicOp::LeU
        | SymbolicOp::GeS
        | SymbolicOp::GeU
        | SymbolicOp::Rol
        | SymbolicOp::Ror
        | SymbolicOp::Rcl
        | SymbolicOp::Rcr
        | SymbolicOp::BSwap
        | SymbolicOp::BRev
        | SymbolicOp::BitScanForward
        | SymbolicOp::BitScanReverse
        | SymbolicOp::Popcount
        | SymbolicOp::Parity => None,
    }
}
/// Folds a binary [`SymbolicOp`] applied to two constants.
///
/// Each operator is delegated to the matching `ConstValue` helper. `Ne`, `Le*`
/// and `Ge*` are computed as the complement of `ceq`, `cgt*` and `clt*`,
/// producing `I32(1)` / `I32(0)`. Unary and unsupported operators (rotates,
/// byte/bit reversal, bit scans, popcount, parity) yield `None`. Any `None`
/// from a helper is passed through unchanged.
///
/// NOTE(review): `DivU`/`RemU` share `div`/`rem` with the signed variants.
/// Confirm those helpers apply unsigned semantics where the operands require it.
///
/// NOTE(review): computing `a <= b` as `!(a > b)` differs from a direct
/// comparison for unordered floating-point operands (NaN), if `ConstValue`
/// carries floats.
pub fn evaluate_binary_typed<T: Target>(
    op: SymbolicOp,
    left: &ConstValue<T>,
    right: &ConstValue<T>,
    ptr_size: PointerSize,
) -> Option<ConstValue<T>> {
    // Logical complement of a 0 / non-0 comparison result, as an I32 flag.
    let complement = |flag: ConstValue<T>| -> ConstValue<T> {
        if flag.is_zero() {
            ConstValue::I32(1)
        } else {
            ConstValue::I32(0)
        }
    };

    match op {
        // Arithmetic.
        SymbolicOp::Add => left.add(right, ptr_size),
        SymbolicOp::Sub => left.sub(right, ptr_size),
        SymbolicOp::Mul => left.mul(right, ptr_size),
        SymbolicOp::DivS | SymbolicOp::DivU => left.div(right, ptr_size),
        SymbolicOp::RemS | SymbolicOp::RemU => left.rem(right, ptr_size),

        // Bitwise.
        SymbolicOp::And => left.bitwise_and(right, ptr_size),
        SymbolicOp::Or => left.bitwise_or(right, ptr_size),
        SymbolicOp::Xor => left.bitwise_xor(right, ptr_size),

        // Shifts: `shr` receives `false` for ShrS and `true` for ShrU.
        SymbolicOp::Shl => left.shl(right, ptr_size),
        SymbolicOp::ShrS => left.shr(right, false, ptr_size),
        SymbolicOp::ShrU => left.shr(right, true, ptr_size),

        // Comparisons with a direct helper.
        SymbolicOp::Eq => left.ceq(right),
        SymbolicOp::LtS => left.clt(right),
        SymbolicOp::LtU => left.clt_un(right),
        SymbolicOp::GtS => left.cgt(right),
        SymbolicOp::GtU => left.cgt_un(right),

        // Comparisons expressed as the complement of another comparison.
        SymbolicOp::Ne => left.ceq(right).map(complement),
        SymbolicOp::LeS => left.cgt(right).map(complement),
        SymbolicOp::LeU => left.cgt_un(right).map(complement),
        SymbolicOp::GeS => left.clt(right).map(complement),
        SymbolicOp::GeU => left.clt_un(right).map(complement),

        // Unary operators and operators not folded here.
        SymbolicOp::Neg
        | SymbolicOp::Not
        | SymbolicOp::Rol
        | SymbolicOp::Ror
        | SymbolicOp::Rcl
        | SymbolicOp::Rcr
        | SymbolicOp::BSwap
        | SymbolicOp::BRev
        | SymbolicOp::BitScanForward
        | SymbolicOp::BitScanReverse
        | SymbolicOp::Popcount
        | SymbolicOp::Parity => None,
    }
}