from .primitives import GROUP as PRIMITIVES, prim_from_bv, prim_to_bv, bvadd,\
    bvult, bvzeroext, bvsplit, bvconcat, bvsignext
from cdsl.ast import Var
from cdsl.types import BVType
from .elaborate import elaborate
from z3 import BitVec, ZeroExt, SignExt, And, Extract, Concat, Not, Solver,\
    unsat, BoolRef, BitVecVal, If, ULT
from z3.z3core import Z3_mk_eq

try:
    from typing import TYPE_CHECKING, Tuple, Dict, List  # noqa
    from cdsl.xform import Rtl, XForm  # noqa
    from cdsl.ast import VarAtomMap, Atom  # noqa
    from cdsl.ti import VarTyping  # noqa
    if TYPE_CHECKING:
        from z3 import ExprRef, BitVecRef  # noqa
        Z3VarMap = Dict[Var, BitVecRef]
except ImportError:
    TYPE_CHECKING = False
def mk_eq(e1, e2):
    # type: (ExprRef, ExprRef) -> ExprRef
    """
    Return the z3 boolean expression equating `e1` and `e2`.

    The equality is built directly via the low-level `Z3_mk_eq` C API in the
    context of `e1`, then wrapped as a `BoolRef` in that same context.
    NOTE(review): presumably done instead of `e1 == e2` so static type
    checkers see an expression rather than a `bool` -- confirm.
    """
    raw_eq = Z3_mk_eq(e1.ctx_ref(), e1.as_ast(), e2.as_ast())
    return BoolRef(raw_eq, e1.ctx)
def to_smt(r):
    # type: (Rtl) -> Tuple[List[ExprRef], Z3VarMap]
    """
    Encode a concrete Rtl `r`, made only of primitive instructions, as a z3
    query.

    Returns a tuple (query, var_m) where:
        - query is a list of z3 boolean expressions, one per instruction of
          `r` other than prim_to_bv / prim_from_bv;
        - var_m maps each Var that is the argument of a prim_to_bv, or the
          definition of a prim_from_bv, to the z3 bitvector constant it is
          associated with.
    """
    assert r.is_concrete()
    # Only primitive instructions can be encoded.
    primitives = set(PRIMITIVES.instructions)
    assert set(d.expr.inst for d in r.rtl).issubset(primitives)

    q = []  # type: List[ExprRef]
    m = {}  # type: Z3VarMap

    # Declare a z3 bitvector constant, named after the Var, for every Var
    # whose concrete type is a BVType.
    var_to_bv = {}  # type: Z3VarMap
    for v in r.vars():
        typ = v.get_typevar().singleton_type()
        if not isinstance(typ, BVType):
            continue
        var_to_bv[v] = BitVec(v.name, typ.bits)

    # Encode each instruction as an equality assertion.
    for d in r.rtl:
        inst = d.expr.inst

        # prim_to_bv / prim_from_bv only relate a non-bitvector Var to its
        # bitvector counterpart: record that in `m`, no assertion needed.
        if inst == prim_to_bv:
            assert isinstance(d.expr.args[0], Var)
            m[d.expr.args[0]] = var_to_bv[d.defs[0]]
            continue

        if inst == prim_from_bv:
            assert isinstance(d.expr.args[0], Var)
            m[d.defs[0]] = var_to_bv[d.expr.args[0]]
            continue

        exp = None  # type: ExprRef
        if inst in (bvadd, bvult):  # Binary instructions
            assert len(d.expr.args) == 2 and len(d.defs) == 1
            lhs = d.expr.args[0]
            rhs = d.expr.args[1]
            df = d.defs[0]
            assert isinstance(lhs, Var) and isinstance(rhs, Var)

            if inst == bvadd:
                # Result has the same width as the operands.
                exp = var_to_bv[lhs] + var_to_bv[rhs]
            else:
                assert inst == bvult
                # z3's `<` on bitvectors is *signed*; bvult is unsigned.
                exp = ULT(var_to_bv[lhs], var_to_bv[rhs])
                # The comparison yields a z3 Bool; the defined value is a
                # 1-bit bitvector (1 for true, 0 for false).
                exp = If(exp, BitVecVal(1, 1), BitVecVal(0, 1))

            exp = mk_eq(var_to_bv[df], exp)
        elif inst == bvzeroext:
            arg = d.expr.args[0]
            df = d.defs[0]
            assert isinstance(arg, Var)
            fromW = arg.get_typevar().singleton_type().width()
            toW = df.get_typevar().singleton_type().width()
            # ZeroExt takes the number of bits to add, not the final width.
            exp = mk_eq(var_to_bv[df], ZeroExt(toW - fromW, var_to_bv[arg]))
        elif inst == bvsignext:
            arg = d.expr.args[0]
            df = d.defs[0]
            assert isinstance(arg, Var)
            fromW = arg.get_typevar().singleton_type().width()
            toW = df.get_typevar().singleton_type().width()
            # SignExt takes the number of bits to add, not the final width.
            exp = mk_eq(var_to_bv[df], SignExt(toW - fromW, var_to_bv[arg]))
        elif inst == bvsplit:
            arg = d.expr.args[0]
            assert isinstance(arg, Var)
            arg_typ = arg.get_typevar().singleton_type()
            width = arg_typ.width()
            assert (width % 2 == 0)

            # defs[0] is the low half, defs[1] the high half.
            lo = d.defs[0]
            hi = d.defs[1]
            exp = And(mk_eq(var_to_bv[lo],
                            Extract(width // 2 - 1, 0, var_to_bv[arg])),
                      mk_eq(var_to_bv[hi],
                            Extract(width - 1, width // 2, var_to_bv[arg])))
        elif inst == bvconcat:
            assert isinstance(d.expr.args[0], Var) and \
                isinstance(d.expr.args[1], Var)
            lo = d.expr.args[0]
            hi = d.expr.args[1]
            df = d.defs[0]
            # z3's Concat places its first argument in the high bits.
            exp = mk_eq(var_to_bv[df], Concat(var_to_bv[hi], var_to_bv[lo]))
        else:
            assert False, "Unknown primitive instruction {}".format(inst)

        q.append(exp)

    return (q, m)
def equivalent(r1, r2, inp_m, out_m):
    # type: (Rtl, Rtl, VarAtomMap, VarAtomMap) -> List[ExprRef]
    """
    Build a z3 query checking whether two concrete primitive Rtls agree.

    :param r1: concrete source Rtl.
    :param r2: concrete destination Rtl.
    :param inp_m: maps the free (input) Vars of r1 onto those of r2.
    :param out_m: maps output Vars of r1 onto the corresponding Vars of r2.
    :returns: a list of z3 assertions stating "all inputs are equal and not
        all outputs are equal". If it is unsatisfiable, r1 and r2 are
        equivalent; otherwise a model is a counterexample.
    """
    # inp_m must relate exactly the free vars of r1 and the free vars of r2.
    # (No such rule is checked for out_m.)
    assert set(r1.free_vars()) == set(inp_m.keys())
    assert set(r2.free_vars()) == set(inp_m.values())

    # to_smt names each z3 bitvector after its Var, so give r1's and r2's
    # Vars distinct suffixes to keep their z3 constants apart.
    src_m = {v: Var(v.name + ".a", v.get_typevar()) for v in r1.vars()}  # type: VarAtomMap # noqa
    dst_m = {v: Var(v.name + ".b", v.get_typevar()) for v in r2.vars()}  # type: VarAtomMap # noqa
    r1 = r1.copy(src_m)
    r2 = r2.copy(dst_m)

    def _translate(m, k_m, v_m):
        # type: (VarAtomMap, VarAtomMap, VarAtomMap) -> VarAtomMap
        """Return a copy of `m` with keys renamed by k_m, values by v_m."""
        res = {}  # type: VarAtomMap
        # Iterate the map passed in (previously `m1`, the enclosing local
        # that is not yet bound here, which raised NameError).
        for (k, v) in m.items():
            new_k = k_m[k]
            new_v = v_m[v]
            assert isinstance(new_k, Var)
            res[new_k] = new_v
        return res

    # Express inp_m / out_m in terms of the renamed (.a / .b) Vars.
    inp_m = _translate(inp_m, src_m, dst_m)
    out_m = _translate(out_m, src_m, dst_m)

    # Encode r1 and r2 as SMT queries.
    (q1, m1) = to_smt(r1)
    (q2, m2) = to_smt(r2)

    def _equalities(var_m):
        # type: (VarAtomMap) -> List[ExprRef]
        """Equate the r1 bitvector of each key with the r2 one of its value."""
        eqs = []  # type: List[ExprRef]
        for (v1, v2) in var_m.items():
            assert isinstance(v2, Var)
            eqs.append(mk_eq(m1[v1], m2[v2]))
        return eqs

    args_eq_exp = _equalities(inp_m)
    results_eq_exp = _equalities(out_m)

    # Satisfiable iff some equal inputs lead to differing outputs.
    return q1 + q2 + args_eq_exp + [Not(And(*results_eq_exp))]
def xform_correct(x, typing):
    # type: (XForm, VarTyping) -> bool
    """
    Check whether XForm `x` is semantics-preserving under a concrete typing.

    The source pattern is instantiated with the types from `typing` and
    rewritten by `x`. Both sides are then elaborated to primitives, and z3 is
    asked for inputs on which their outputs differ. Returns True exactly when
    the solver reports no such inputs exist (unsat).
    """
    assert x.ti.permits(typing)

    # Concrete-typed copy of the source pattern, and what x rewrites it to.
    concrete_of = {v: Var(v.name, typing[v]) for v in x.src.vars()}  # type: VarAtomMap  # noqa
    src = x.src.copy(concrete_of)
    dst = x.apply(src)
    # NOTE(review): presumably maps x.dst's pattern Vars to atoms of `dst`;
    # it is looked up below with source-pattern Vars.
    dst_atoms = x.dst.substitution(dst, {})

    # Split the source Vars into input and output correspondences.
    inputs = {}  # type: VarAtomMap
    outputs = {}  # type: VarAtomMap
    for pattern_var in x.src.vars():
        concrete = concrete_of[pattern_var]
        assert isinstance(concrete, Var)
        if pattern_var.is_input():
            bucket = inputs
        elif pattern_var.is_output():
            bucket = outputs
        else:
            continue
        bucket[concrete] = dst_atoms[pattern_var]

    query = equivalent(elaborate(src), elaborate(dst), inputs, outputs)

    solver = Solver()
    solver.add(*query)
    verdict = solver.check()
    return verdict == unsat