use super::{
BVLitValue, Context, Expr, ExprMap, ExprRef, SerializableIrNode, SparseExprMap, TypeCheck,
WidthInt, do_transform_expr, find_symbols,
};
use crate::expr::meta::get_fixed_point;
use crate::expr::transform::ExprTransformMode;
use crate::smt::{CheckSatResponse, SolverContext};
use baa::BitVecOps;
use smallvec::{SmallVec, smallvec};
/// Simplifies a single expression to a fixed point.
///
/// A fresh sparse cache is created for this call and discarded afterwards, so no
/// simplification results are shared between calls.
pub fn simplify_single_expression(ctx: &mut Context, expr: ExprRef) -> ExprRef {
    Simplifier::new(SparseExprMap::default()).simplify(ctx, expr)
}
/// Expression simplifier that memoizes the rewrite result of every visited expression.
///
/// The cache maps each expression to `Some(replacement)` if a rewrite was found, or `None`.
pub struct Simplifier<T: ExprMap<Option<ExprRef>>> {
    cache: T,
}

impl<T: ExprMap<Option<ExprRef>>> Simplifier<T> {
    /// Creates a simplifier backed by the given (possibly pre-populated) cache.
    pub fn new(cache: T) -> Self {
        Simplifier { cache }
    }

    /// Rewrites `e` with the simplification rules until a fixed point is reached and
    /// returns the resulting expression.
    pub fn simplify(&mut self, ctx: &mut Context, e: ExprRef) -> ExprRef {
        let roots = vec![e];
        do_transform_expr(
            ctx,
            ExprTransformMode::FixedPoint,
            &mut self.cache,
            roots,
            simplify,
        );
        get_fixed_point(&mut self.cache, e).unwrap()
    }

    /// Uses an SMT solver to check every cached rewrite `original -> simplified`.
    ///
    /// For each cached pair, the free symbols of both sides are declared in a fresh solver
    /// scope, and `original != simplified` is asserted.
    /// - A satisfiable query is a wrong rewrite. It is printed together with its
    ///   counterexample and counted.
    /// - An `Unknown` answer is counted neither as correct nor as incorrect.
    ///
    /// Returns the number of incorrect rewrites.
    ///
    /// # Errors
    /// Propagates any solver error.
    pub fn verify_simplification(
        &self,
        ctx: &mut Context,
        solver: &mut impl SolverContext,
    ) -> crate::smt::Result<usize> {
        let mut incorrect = 0;
        let mut n_correct = 0;
        for (key, &entry) in self.cache.iter() {
            let Some(simplified) = entry else {
                continue;
            };
            // symbols of the original come first, followed by those only in the result
            let key_symbols = find_symbols(ctx, key);
            let simplified_symbols = find_symbols(ctx, simplified);
            let symbols: Vec<_> = key_symbols
                .union(&simplified_symbols)
                .cloned()
                .collect();

            solver.push()?;
            for &sym in &symbols {
                solver.declare_const(ctx, sym)?;
            }
            let not_eq = ctx.distinct(key, simplified);
            solver.assert(ctx, not_eq)?;
            match solver.check_sat()? {
                CheckSatResponse::Sat => {
                    Self::report_mismatch(ctx, solver, key, simplified, &symbols)?;
                    incorrect += 1;
                }
                CheckSatResponse::Unsat => n_correct += 1,
                // undecided: neither counted as correct nor as incorrect
                CheckSatResponse::Unknown => {}
            }
            solver.pop()?;
        }
        if incorrect > 0 {
            println!(
                "{incorrect} / {} simplifications were incorrect. See log.",
                incorrect + n_correct
            );
        }
        Ok(incorrect)
    }

    /// Prints a rewrite that the solver showed to be wrong, together with the values of
    /// both sides and of all free symbols in the current model.
    fn report_mismatch(
        ctx: &mut Context,
        solver: &mut impl SolverContext,
        key: ExprRef,
        simplified: ExprRef,
        symbols: &[ExprRef],
    ) -> crate::smt::Result<()> {
        let key_value = solver.get_value(ctx, key)?;
        let simplified_value = solver.get_value(ctx, simplified)?;
        println!(
            "{} ({}) =/= ({}) {}",
            key.serialize_to_str(ctx),
            key_value.serialize_to_str(ctx),
            simplified_value.serialize_to_str(ctx),
            simplified.serialize_to_str(ctx)
        );
        let assignment = symbols
            .iter()
            .map(|&sym| -> crate::smt::Result<String> {
                let value = solver.get_value(ctx, sym)?;
                Ok(format!(
                    "{}={}",
                    sym.serialize_to_str(ctx),
                    value.serialize_to_str(ctx)
                ))
            })
            .collect::<crate::smt::Result<Vec<String>>>()?;
        println!(" w/ {}", assignment.join(", "));
        Ok(())
    }
}
/// Applies a single simplification step to `expr`, whose (already transformed) operands
/// are given in `children`.
///
/// Returns `Some(replacement)` if a rule applies, `None` otherwise. Expressions whose
/// operand count does not match their kind are left untouched.
pub(crate) fn simplify(ctx: &mut Context, expr: ExprRef, children: &[ExprRef]) -> Option<ExprRef> {
    let node = ctx[expr].clone();
    match *children {
        // unary operations
        [e] => match node {
            Expr::BVNot(..) => simplify_bv_not(ctx, e),
            Expr::BVZeroExt { by, .. } => simplify_bv_zero_ext(ctx, e, by),
            Expr::BVSignExt { by, .. } => simplify_bv_sign_ext(ctx, e, by),
            Expr::BVSlice { hi, lo, .. } => simplify_bv_slice(ctx, e, hi, lo),
            _ => None,
        },
        // binary operations
        [a, b] => match node {
            Expr::BVConcat(..) => simplify_bv_concat(ctx, a, b),
            Expr::BVEqual(..) => simplify_bv_equal(ctx, a, b),
            Expr::BVAnd(..) => simplify_bv_and(ctx, a, b),
            Expr::BVOr(..) => simplify_bv_or(ctx, a, b),
            Expr::BVXor(..) => simplify_bv_xor(ctx, a, b),
            Expr::BVImplies(..) => simplify_bv_implies(ctx, a, b),
            Expr::BVGreaterEqual(..) => simplify_bv_greater_equal(ctx, a, b),
            Expr::BVAdd(..) => simplify_bv_add(ctx, a, b),
            Expr::BVMul(..) => simplify_bv_mul(ctx, a, b),
            Expr::BVShiftLeft(_, _, w) => simplify_bv_shift_left(ctx, a, b, w),
            Expr::BVShiftRight(_, _, w) => simplify_bv_shift_right(ctx, a, b, w),
            Expr::BVArithmeticShiftRight(_, _, w) => {
                simplify_bv_arithmetic_shift_right(ctx, a, b, w)
            }
            _ => None,
        },
        // ternary operations
        [cond, tru, fals] => match node {
            Expr::BVIte { .. } => simplify_ite(ctx, cond, tru, fals),
            _ => None,
        },
        _ => None,
    }
}
/// Simplifies `ite(cond, tru, fals)`.
///
/// Identical branches and literal conditions are resolved directly. For 1-bit values
/// with at least one literal branch, the `ite` is turned into boolean logic.
fn simplify_ite(ctx: &mut Context, cond: ExprRef, tru: ExprRef, fals: ExprRef) -> Option<ExprRef> {
    /// Returns the boolean value of `e` if it is a (1-bit) literal.
    fn lit_to_bool(ctx: &Context, e: ExprRef) -> Option<bool> {
        match ctx[e] {
            Expr::BVLiteral(value) => Some(value.get(ctx).to_bool().unwrap()),
            _ => None,
        }
    }

    // ite(c, x, x) -> x
    if tru == fals {
        return Some(tru);
    }
    // constant condition: select the corresponding branch
    if let Expr::BVLiteral(c) = ctx[cond] {
        let taken = if c.get(ctx).is_false() { fals } else { tru };
        return Some(taken);
    }

    let width = ctx[tru].get_bv_type(ctx).unwrap();
    debug_assert_eq!(ctx[fals].get_bv_type(ctx), ctx[tru].get_bv_type(ctx));
    if width != 1 {
        return None;
    }

    let rewritten = match (lit_to_bool(ctx, tru), lit_to_bool(ctx, fals)) {
        // ite(c, 1, 0) -> c
        (Some(true), Some(false)) => cond,
        // ite(c, 0, 1) -> !c
        (Some(false), Some(true)) => ctx.not(cond),
        (Some(_), Some(_)) => unreachable!(
            "both arguments are the same, this should have been handled earlier"
        ),
        // ite(c, 1, f) -> c | f
        (Some(true), None) => ctx.or(cond, fals),
        // ite(c, 0, f) -> !c & f
        (Some(false), None) => ctx.build(|c| c.and(c.not(cond), fals)),
        // ite(c, t, 1) -> !c | t
        (None, Some(true)) => ctx.build(|c| c.or(c.not(cond), tru)),
        // ite(c, t, 0) -> c & t
        (None, Some(false)) => ctx.and(cond, tru),
        (None, None) => return None,
    };
    Some(rewritten)
}
/// Result of classifying the two operands of a binary operation by whether they are
/// bit-vector literals (see `find_lits_commutative`).
enum Lits {
    /// Both operands are literals, in operand order (`a`, `b`).
    Two(BVLitValue, BVLitValue),
    /// Exactly one operand is a literal:
    /// `((literal value, literal expression), the non-literal operand)`.
    One((BVLitValue, ExprRef), ExprRef),
    /// Neither operand is a literal.
    None,
}
/// Classifies the operands `a` and `b` of a commutative operation by literal-ness.
///
/// If both are literals their values are returned in operand order. If only one is,
/// it is returned first, together with the other operand.
#[inline]
fn find_lits_commutative(ctx: &Context, a: ExprRef, b: ExprRef) -> Lits {
    let literal_of = |e: ExprRef| match &ctx[e] {
        Expr::BVLiteral(value) => Some(*value),
        _ => None,
    };
    match (literal_of(a), literal_of(b)) {
        (Some(va), Some(vb)) => Lits::Two(va, vb),
        (Some(va), None) => Lits::One((va, a), b),
        (None, Some(vb)) => Lits::One((vb, b), a),
        (None, None) => Lits::None,
    }
}
/// Looks for a concatenation among `a` and `b`, checking `a` first.
///
/// Returns `(upper part, lower part, other operand)` of the first concatenation found.
#[inline]
fn find_one_concat(ctx: &Context, a: ExprRef, b: ExprRef) -> Option<(ExprRef, ExprRef, ExprRef)> {
    [(a, b), (b, a)]
        .into_iter()
        .find_map(|(candidate, other)| match &ctx[candidate] {
            Expr::BVConcat(upper, lower, _) => Some((*upper, *lower, other)),
            _ => None,
        })
}
/// Simplifies the equality `a == b`.
///
/// - Syntactically equal operands are `true`.
/// - Two distinct literals are `false`.
/// - Comparing against a `true`/`false` literal is reduced to the other operand (or its
///   negation).
/// - An equality with a concatenation is split into two equalities on the parts.
fn simplify_bv_equal(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
    // x == x -> true
    if a == b {
        return Some(ctx.get_true());
    }
    match find_lits_commutative(ctx, a, b) {
        // equal literals would have been hash-consed into the same ExprRef
        Lits::Two(va, vb) => {
            debug_assert!(!va.get(ctx).is_equal(&vb.get(ctx)));
            return Some(ctx.get_false());
        }
        // x == true -> x
        Lits::One((lit, _), expr) if lit.is_true() => return Some(expr),
        // x == false -> !x
        Lits::One((lit, _), expr) if lit.is_false() => return Some(ctx.not(expr)),
        _ => {}
    }

    // concat(u, l) == x  ->  u == x[hi part] && l == x[lo part]
    let (upper, lower, other) = find_one_concat(ctx, a, b)?;
    let upper_width = ctx[upper].get_bv_type(ctx).unwrap();
    let lower_width = ctx[lower].get_bv_type(ctx).unwrap();
    let width = upper_width + lower_width;
    debug_assert_eq!(width, other.get_bv_type(ctx).unwrap());
    let eq_upper = ctx.build(|c| c.equal(upper, c.slice(other, width - 1, lower_width)));
    let eq_lower = ctx.build(|c| c.equal(lower, c.slice(other, lower_width - 1, 0)));
    Some(ctx.and(eq_upper, eq_lower))
}
fn simplify_bv_and(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
if a == b {
return Some(a);
}
match find_lits_commutative(ctx, a, b) {
Lits::Two(va, vb) => {
Some(ctx.bv_lit(&va.get(ctx).and(&vb.get(ctx))))
}
Lits::One((lit, lit_expr), expr) => {
if lit.get(ctx).is_zero() {
Some(lit_expr)
} else if lit.get(ctx).is_all_ones() {
Some(expr)
} else {
if let Expr::BVConcat(a, b, width) = ctx[expr].clone() {
let b_width = b.get_bv_type(ctx).unwrap();
debug_assert_eq!(width, b_width + a.get_bv_type(ctx).unwrap());
let a_mask = ctx.bv_lit(&lit.get(ctx).slice(width - 1, b_width));
let b_mask = ctx.bv_lit(&lit.get(ctx).slice(b_width - 1, 0));
Some(ctx.build(|c| c.concat(c.and(a, a_mask), c.and(b, b_mask))))
} else {
let width = expr.get_bv_type(ctx).unwrap();
let ones = lit.get(ctx).bit_set_intervals();
debug_assert!(!ones.is_empty());
let mut bit = 0;
let mut values: SmallVec<[ExprRef; 6]> = smallvec![];
for interval in ones.into_iter() {
if interval.start > bit {
values.push(ctx.zero(interval.start - bit));
}
values.push(ctx.slice(expr, interval.end - 1, interval.start));
bit = interval.end;
}
if bit < width {
values.push(ctx.zero(width - bit));
}
debug_assert!(values.len() > 1);
let out = values
.into_iter()
.rev()
.reduce(|a, b| ctx.concat(a, b))
.unwrap();
Some(out)
}
}
}
Lits::None => {
match (&ctx[a], &ctx[b]) {
(Expr::BVNot(inner, w), _) if *inner == b => Some(ctx.zero(*w)),
(_, Expr::BVNot(inner, w)) if *inner == a => Some(ctx.zero(*w)),
(Expr::BVNot(a, _), Expr::BVNot(b, _)) => {
let or = ctx.or(*a, *b);
Some(ctx.not(or))
}
_ => None,
}
}
}
}
/// Simplifies the bit-wise `a | b`.
///
/// - Idempotence and constant folding are applied.
/// - Neutral (all zeros) and absorbing (all ones) literals are resolved.
/// - `x | !x` and De Morgan on two negations are rewritten.
fn simplify_bv_or(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
    // x | x -> x
    if a == b {
        return Some(a);
    }
    match find_lits_commutative(ctx, a, b) {
        Lits::Two(va, vb) => {
            let folded = va.get(ctx).or(&vb.get(ctx));
            Some(ctx.bv_lit(&folded))
        }
        // x | 0 -> x
        Lits::One((lit, _), expr) if lit.get(ctx).is_zero() => Some(expr),
        // x | 1...1 -> 1...1
        Lits::One((lit, lit_expr), _) if lit.get(ctx).is_all_ones() => Some(lit_expr),
        Lits::One(..) => None,
        Lits::None => match (ctx[a].clone(), ctx[b].clone()) {
            // !x | x -> 1...1
            (Expr::BVNot(inner, w), _) if inner == b => Some(ctx.ones(w)),
            // x | !x -> 1...1
            (_, Expr::BVNot(inner, w)) if inner == a => Some(ctx.ones(w)),
            // !x | !y -> !(x & y)
            (Expr::BVNot(x, _), Expr::BVNot(y, _)) => {
                let and = ctx.and(x, y);
                Some(ctx.not(and))
            }
            _ => None,
        },
    }
}
/// Simplifies the bit-wise `a ^ b`.
///
/// - `x ^ x` becomes zero, and two literals are folded.
/// - An all-zeros literal is dropped; an all-ones literal turns into a negation.
/// - `x ^ !x` becomes all ones.
fn simplify_bv_xor(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
    // x ^ x -> 0
    if a == b {
        let width = ctx[a].get_bv_type(ctx).unwrap();
        return Some(ctx.zero(width));
    }
    match find_lits_commutative(ctx, a, b) {
        Lits::Two(va, vb) => {
            let folded = va.get(ctx).xor(&vb.get(ctx));
            Some(ctx.bv_lit(&folded))
        }
        // x ^ 0 -> x
        Lits::One((lit, _), expr) if lit.get(ctx).is_zero() => Some(expr),
        // x ^ 1...1 -> !x
        Lits::One((lit, _), expr) if lit.get(ctx).is_all_ones() => Some(ctx.not(expr)),
        Lits::One(..) => None,
        Lits::None => match (ctx[a].clone(), ctx[b].clone()) {
            // !x ^ x -> 1...1
            (Expr::BVNot(inner, w), _) if inner == b => Some(ctx.ones(w)),
            // x ^ !x -> 1...1
            (_, Expr::BVNot(inner, w)) if inner == a => Some(ctx.ones(w)),
            _ => None,
        },
    }
}
/// Simplifies `a => b` when the premise `a` is a literal.
///
/// A false premise makes the implication `true`; any other literal premise reduces it
/// to `b`.
fn simplify_bv_implies(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
    let Expr::BVLiteral(premise) = ctx[a] else {
        return None;
    };
    if premise.get(ctx).is_false() {
        Some(ctx.get_true())
    } else {
        Some(b)
    }
}
/// Constant-folds `a >= b` when both operands are literals.
///
/// The comparison is delegated to `is_greater_or_equal` of the literal values.
fn simplify_bv_greater_equal(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
    let (Expr::BVLiteral(lhs), Expr::BVLiteral(rhs)) = (&ctx[a], &ctx[b]) else {
        return None;
    };
    let result = lhs.get(ctx).is_greater_or_equal(&rhs.get(ctx));
    Some(ctx.bv_lit(&result.into()))
}
/// Simplifies `!e`: removes double negation and folds literals.
fn simplify_bv_not(ctx: &mut Context, e: ExprRef) -> Option<ExprRef> {
    match ctx[e].clone() {
        // !!x -> x
        Expr::BVNot(inner, _) => Some(inner),
        // constant folding
        Expr::BVLiteral(value) => {
            let flipped = value.get(ctx).not();
            Some(ctx.bv_lit(&flipped))
        }
        _ => None,
    }
}
/// Simplifies a zero extension of `e` by `by` bits.
///
/// - Extending by zero bits is a no-op.
/// - Literals are folded.
/// - Anything else is rewritten into a concatenation with a zero prefix.
fn simplify_bv_zero_ext(ctx: &mut Context, e: ExprRef, by: WidthInt) -> Option<ExprRef> {
    if by == 0 {
        return Some(e);
    }
    let rewritten = if let Expr::BVLiteral(value) = ctx[e].clone() {
        ctx.bv_lit(&value.get(ctx).zero_extend(by))
    } else {
        // zext(x, n) -> concat(0^n, x)
        ctx.build(|c| c.concat(c.zero(by), e))
    };
    Some(rewritten)
}
/// Simplifies a sign extension of `e` by `by` bits.
///
/// - Extending by zero bits is a no-op.
/// - Literals are folded.
/// - Nested sign extensions are merged.
fn simplify_bv_sign_ext(ctx: &mut Context, e: ExprRef, by: WidthInt) -> Option<ExprRef> {
    if by == 0 {
        return Some(e);
    }
    match ctx[e].clone() {
        Expr::BVLiteral(value) => {
            let extended = value.get(ctx).sign_extend(by);
            Some(ctx.bv_lit(&extended))
        }
        // sext(sext(x, m), n) -> sext(x, m + n)
        Expr::BVSignExt {
            e: inner,
            by: inner_by,
            ..
        } => Some(ctx.sign_extend(inner, by + inner_by)),
        _ => None,
    }
}
/// Simplifies `concat(a, b)` (`a` forms the upper bits).
///
/// - Nested concatenations on the left are re-associated to the right.
/// - Adjacent literals are merged.
/// - Two adjacent slices of the same expression are fused into one slice.
fn simplify_bv_concat(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
    let upper = ctx[a].clone();
    let lower = ctx[b].clone();
    match (upper, lower) {
        // concat(concat(x, y), z) -> concat(x, concat(y, z))
        (Expr::BVConcat(x, y, _), _) => Some(ctx.build(|c| c.concat(x, c.concat(y, b)))),
        // constant folding
        (Expr::BVLiteral(lhs), Expr::BVLiteral(rhs)) => {
            Some(ctx.bv_lit(&lhs.get(ctx).concat(&rhs.get(ctx))))
        }
        // concat(lit1, concat(lit2, x)) -> concat(lit1 ++ lit2, x)
        (Expr::BVLiteral(lhs), Expr::BVConcat(inner_upper, inner_lower, _)) => {
            let Expr::BVLiteral(rhs) = ctx[inner_upper] else {
                return None;
            };
            let merged = ctx.bv_lit(&lhs.get(ctx).concat(&rhs.get(ctx)));
            Some(ctx.concat(merged, inner_lower))
        }
        // concat(x[h:m+1], x[m:l]) -> x[h:l]
        (
            Expr::BVSlice {
                e: src_upper,
                hi: hi_upper,
                lo: lo_upper,
            },
            Expr::BVSlice {
                e: src_lower,
                hi: hi_lower,
                lo: lo_lower,
            },
        ) if src_upper == src_lower && lo_upper == hi_lower + 1 => {
            Some(ctx.slice(src_upper, hi_upper, lo_lower))
        }
        _ => None,
    }
}
/// Simplifies the slice `e[hi:lo]` (both bounds inclusive) by pushing it into `e`.
///
/// - Slices of slices and slices of literals are folded.
/// - Slices of concatenations select the relevant part(s).
/// - Slices of sign extensions select original bits and/or copies of the MSB.
/// - Slices are distributed over `ite` and over bit-wise operations.
/// - For negation, `+`, `-` and `*`, the slice is only distributed when `lo == 0`: the low
///   result bits of these operations depend only on the low bits of their operands.
fn simplify_bv_slice(ctx: &mut Context, e: ExprRef, hi: WidthInt, lo: WidthInt) -> Option<ExprRef> {
    debug_assert!(hi >= lo);
    match ctx[e].clone() {
        // x[a:b][hi:lo] -> x[hi + b : lo + b]
        Expr::BVSlice {
            lo: inner_lo,
            e: inner_e,
            ..
        } => Some(ctx.slice(inner_e, hi + inner_lo, lo + inner_lo)),
        Expr::BVLiteral(value) => Some(ctx.bv_lit(&value.get(ctx).slice(hi, lo))),
        Expr::BVConcat(a, b, _) => {
            let b_width = b.get_bv_type(ctx).unwrap();
            if hi < b_width {
                // entirely inside the lower part
                Some(ctx.slice(b, hi, lo))
            } else if lo >= b_width {
                // entirely inside the upper part
                Some(ctx.slice(a, hi - b_width, lo - b_width))
            } else {
                // straddles both parts
                let a_slice = ctx.slice(a, hi - b_width, 0);
                let b_slice = ctx.slice(b, b_width - 1, lo);
                Some(ctx.concat(a_slice, b_slice))
            }
        }
        Expr::BVSignExt { e, .. } => {
            let e_width = e.get_bv_type(ctx).unwrap();
            if hi < e_width {
                // only original bits are selected
                Some(ctx.slice(e, hi, lo))
            } else if lo >= e_width {
                // Only extension bits are selected. Every one of them is a copy of the MSB
                // of `e`. (Slicing `e[e_width-1:lo]` here would produce an invalid slice
                // with hi < lo.)
                let msb = ctx.slice(e, e_width - 1, e_width - 1);
                if hi == lo {
                    Some(msb)
                } else {
                    Some(ctx.sign_extend(msb, hi - lo))
                }
            } else {
                // original bits [e_width-1:lo] followed by (hi - e_width + 1) copies of the MSB
                let inner = ctx.slice(e, e_width - 1, lo);
                Some(ctx.sign_extend(inner, hi - e_width + 1))
            }
        }
        Expr::BVIte { cond, tru, fals } => {
            Some(ctx.build(|c| c.ite(cond, c.slice(tru, hi, lo), c.slice(fals, hi, lo))))
        }
        Expr::BVNot(e, _) => Some(ctx.build(|c| c.not(c.slice(e, hi, lo)))),
        Expr::BVNegate(e, _) if lo == 0 => Some(ctx.build(|c| c.negate(c.slice(e, hi, lo)))),
        Expr::BVAnd(a, b, _) => Some(ctx.build(|c| c.and(c.slice(a, hi, lo), c.slice(b, hi, lo)))),
        Expr::BVOr(a, b, _) => Some(ctx.build(|c| c.or(c.slice(a, hi, lo), c.slice(b, hi, lo)))),
        Expr::BVXor(a, b, _) => Some(ctx.build(|c| c.xor(c.slice(a, hi, lo), c.slice(b, hi, lo)))),
        Expr::BVAdd(a, b, _) if lo == 0 => {
            Some(ctx.build(|c| c.add(c.slice(a, hi, lo), c.slice(b, hi, lo))))
        }
        Expr::BVSub(a, b, _) if lo == 0 => {
            Some(ctx.build(|c| c.sub(c.slice(a, hi, lo), c.slice(b, hi, lo))))
        }
        Expr::BVMul(a, b, _) if lo == 0 => {
            Some(ctx.build(|c| c.mul(c.slice(a, hi, lo), c.slice(b, hi, lo))))
        }
        _ => None,
    }
}
/// Simplifies the logical left shift `a << b` of bit-width `width`.
///
/// - Both operands literal: constant folding.
/// - Literal shift amount `0`: `a`.
/// - Literal shift amount `>= width`: all zeros.
/// - Any other literal shift amount `n`: `concat(a[width-1-n:0], 0^n)`.
fn simplify_bv_shift_left(
    ctx: &mut Context,
    a: ExprRef,
    b: ExprRef,
    width: WidthInt,
) -> Option<ExprRef> {
    match (&ctx[a], &ctx[b]) {
        (Expr::BVLiteral(va), Expr::BVLiteral(vb)) => {
            Some(ctx.bv_lit(&va.get(ctx).shift_left(&vb.get(ctx))))
        }
        (_, Expr::BVLiteral(by)) => {
            // Compare in u64 *before* narrowing to `WidthInt`. Otherwise a wide shift
            // amount (e.g. 2^32 + 1) would be truncated into a small, wrong shift.
            let amount = by.get(ctx).to_u64();
            match amount {
                Some(0) => Some(a),
                Some(n) if n < u64::from(width) => {
                    // lossless: n < width
                    let by = n as WidthInt;
                    let msb = width - 1 - by;
                    Some(ctx.build(|c| c.concat(c.slice(a, msb, 0), c.zero(by))))
                }
                // shift amount >= width (or not even representable as u64): all bits are
                // shifted out
                _ => Some(ctx.zero(width)),
            }
        }
        (_, _) => None,
    }
}
/// Simplifies the logical right shift `a >> b` of bit-width `width`.
///
/// - Both operands literal: constant folding.
/// - Literal shift amount `0`: `a`.
/// - Literal shift amount `>= width`: all zeros.
/// - Any other literal shift amount `n`: `zext(a[width-1:n], n)`.
fn simplify_bv_shift_right(
    ctx: &mut Context,
    a: ExprRef,
    b: ExprRef,
    width: WidthInt,
) -> Option<ExprRef> {
    match (&ctx[a], &ctx[b]) {
        (Expr::BVLiteral(va), Expr::BVLiteral(vb)) => {
            Some(ctx.bv_lit(&va.get(ctx).shift_right(&vb.get(ctx))))
        }
        (_, Expr::BVLiteral(by)) => {
            // Compare in u64 *before* narrowing to `WidthInt`. Otherwise a wide shift
            // amount (e.g. 2^32 + 1) would be truncated into a small, wrong shift.
            let amount = by.get(ctx).to_u64();
            match amount {
                Some(0) => Some(a),
                Some(n) if n < u64::from(width) => {
                    // lossless: n < width
                    let by = n as WidthInt;
                    let msb = width - 1;
                    let lsb = by;
                    Some(ctx.build(|c| c.zero_extend(c.slice(a, msb, lsb), by)))
                }
                // shift amount >= width (or not even representable as u64): all bits are
                // shifted out
                _ => Some(ctx.zero(width)),
            }
        }
        (_, _) => None,
    }
}
/// Simplifies the arithmetic right shift of `a` by `b` with bit-width `width`.
///
/// - Both operands literal: constant folding.
/// - Literal shift amount `0`: `a`.
/// - Literal shift amount `>= width`: every bit becomes a copy of the MSB of `a`.
/// - Any other literal shift amount `n`: `sext(a[width-1:n], n)`.
fn simplify_bv_arithmetic_shift_right(
    ctx: &mut Context,
    a: ExprRef,
    b: ExprRef,
    width: WidthInt,
) -> Option<ExprRef> {
    match (&ctx[a], &ctx[b]) {
        (Expr::BVLiteral(va), Expr::BVLiteral(vb)) => {
            Some(ctx.bv_lit(&va.get(ctx).arithmetic_shift_right(&vb.get(ctx))))
        }
        (_, Expr::BVLiteral(by)) => {
            // Compare in u64 *before* narrowing to `WidthInt`. Otherwise a wide shift
            // amount (e.g. 2^32 + 1) would be truncated into a small, wrong shift.
            let amount = by.get(ctx).to_u64();
            match amount {
                Some(0) => Some(a),
                Some(n) if n < u64::from(width) => {
                    // lossless: n < width
                    let by = n as WidthInt;
                    let msb = width - 1;
                    let lsb = by;
                    Some(ctx.build(|c| c.sign_extend(c.slice(a, msb, lsb), by)))
                }
                // shift amount >= width (or not even representable as u64): the result is
                // the MSB of `a` replicated over the full width
                _ => Some(ctx.build(|c| {
                    c.sign_extend(c.slice(a, width - 1, width - 1), width - 1)
                })),
            }
        }
        (_, _) => None,
    }
}
/// Simplifies `a + b`: folds two literals and removes an additive zero.
fn simplify_bv_add(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
    match find_lits_commutative(ctx, a, b) {
        Lits::Two(va, vb) => {
            let sum = va.get(ctx).add(&vb.get(ctx));
            Some(ctx.bv_lit(&sum))
        }
        // x + 0 -> x
        Lits::One((lit, _), other) if lit.get(ctx).is_zero() => Some(other),
        _ => None,
    }
}
/// Simplifies `a * b`.
///
/// Two literals are folded. With a single literal factor:
/// - `0` is absorbing;
/// - `1` is neutral;
/// - a power of two becomes a left shift.
fn simplify_bv_mul(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
    let (factor, factor_expr, other) = match find_lits_commutative(ctx, a, b) {
        Lits::Two(va, vb) => return Some(ctx.bv_lit(&va.get(ctx).mul(&vb.get(ctx)))),
        Lits::One((lit, lit_expr), other) => (lit, lit_expr, other),
        Lits::None => return None,
    };
    let value = factor.get(ctx);
    if value.is_zero() {
        // x * 0 -> 0
        Some(factor_expr)
    } else if value.is_one() {
        // x * 1 -> x
        Some(other)
    } else if let Some(log_2) = value.is_pow_2() {
        // x * 2^k -> x << k
        let amount = ctx.bit_vec_val(log_2, value.width());
        Some(ctx.shift_left(other, amount))
    } else {
        None
    }
}