use super::common::*;
use super::*;
use easy_smt::{Context, ContextBuilder, SExpr};
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SolverBackend {
Z3,
CVC5,
}
pub struct SmtLibSolver {
ctx: Context,
}
impl SmtLibSolver {
pub fn new(backend: SolverBackend) -> Self {
let mut ctx_bld = ContextBuilder::new();
match backend {
SolverBackend::Z3 => {
ctx_bld.solver("z3").solver_args(["-smt2", "-in", "-v:0"]);
}
SolverBackend::CVC5 => {
ctx_bld
.solver("cvc5")
.solver_args(["--quiet", "--lang=smt2", "--incremental"]);
}
}
let mut ctx = ctx_bld.build().unwrap();
ctx.set_logic("QF_NIA").unwrap();
match backend {
SolverBackend::Z3 => {
ctx.set_option(":timeout", ctx.numeral(1000)).unwrap();
}
SolverBackend::CVC5 => {
ctx.set_option(":tlimit-per", ctx.numeral(1000)).unwrap();
}
}
ctx.push().unwrap();
SmtLibSolver { ctx }
}
pub fn check_sat(
&mut self,
prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
) -> Option<HashMap<IdentCtx, LitVal>> {
if prims.is_empty() {
return Some(HashMap::new());
}
self.ctx.pop().unwrap();
self.ctx.push().unwrap();
let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
let sexp_map = self.solve_constraints(prims, &ty_map);
let check_res = self.ctx.check().unwrap();
if check_res == easy_smt::Response::Sat {
let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
let res = vars
.iter()
.cloned()
.zip(
self.ctx
.get_value(vars.iter().map(|var| sexp_map[var]).collect())
.unwrap()
.iter()
.map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
)
.collect();
Some(res)
} else {
None
}
}
fn solve_constraints(
&mut self,
prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
ty_map: &HashMap<IdentCtx, LitType>,
) -> HashMap<IdentCtx, SExpr> {
let sexp_map: HashMap<IdentCtx, SExpr> = ty_map
.iter()
.map(|(var, typ)| {
let sort = match typ {
LitType::TyInt => self.ctx.int_sort(),
LitType::TyFloat => self.ctx.real_sort(),
LitType::TyBool => self.ctx.bool_sort(),
LitType::TyChar => todo!(),
};
let sexp = self.ctx.declare_const(format!("{var:?}"), sort).unwrap();
(*var, sexp)
})
.collect();
for (prim, args) in prims {
let args: Vec<SExpr> = args
.iter()
.map(|arg| self.atom_to_sexp(arg, &sexp_map))
.collect();
match (prim, &args[..]) {
(
Prim::IAdd | Prim::ISub | Prim::IMul | Prim::IDiv | Prim::IRem,
&[arg1, arg2, arg3],
) => {
let res = match prim {
Prim::IAdd => self.ctx.plus(arg1, arg2),
Prim::ISub => self.ctx.sub(arg1, arg2),
Prim::IMul => self.ctx.times(arg1, arg2),
Prim::IDiv => self.ctx.div(arg1, arg2),
Prim::IRem => self.ctx.rem(arg1, arg2),
_ => unreachable!(),
};
self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
}
(Prim::INeg, &[arg1, arg2]) => {
let res = self.ctx.negate(arg1);
self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
}
(Prim::ICmp(cmp), &[arg1, arg2, arg3]) => {
let res = match cmp {
Compare::Lt => self.ctx.lt(arg1, arg2),
Compare::Le => self.ctx.lte(arg1, arg2),
Compare::Eq => self.ctx.eq(arg1, arg2),
Compare::Ge => self.ctx.gte(arg1, arg2),
Compare::Gt => self.ctx.gt(arg1, arg2),
Compare::Ne => self.ctx.not(self.ctx.eq(arg1, arg2)),
};
self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
}
(Prim::BAnd | Prim::BOr, &[arg1, arg2, arg3]) => {
let res = match prim {
Prim::BAnd => self.ctx.and(arg1, arg2),
Prim::BOr => self.ctx.or(arg1, arg2),
_ => unreachable!(),
};
self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
}
(Prim::BNot, &[arg1, arg2]) => {
let res = self.ctx.not(arg1);
self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
}
_ => {
panic!("wrong arity of primitives!");
}
}
}
sexp_map
}
fn atom_to_sexp(&self, atom: &AtomVal<IdentCtx>, map: &HashMap<IdentCtx, SExpr>) -> SExpr {
match atom {
Term::Var(var) => map[var],
Term::Lit(LitVal::Int(x)) => self.ctx.numeral(*x),
Term::Lit(LitVal::Float(x)) => self.ctx.decimal(*x),
Term::Lit(LitVal::Bool(x)) => {
if *x {
self.ctx.true_()
} else {
self.ctx.false_()
}
}
Term::Lit(LitVal::Char(_x)) => todo!(),
Term::Cons(_cons, _flds) => unreachable!(),
}
}
fn sexp_to_lit_val(&self, sexpr: SExpr) -> Option<LitVal> {
if let Some(res) = self.ctx.get_i64(sexpr) {
return Some(LitVal::Int(res));
}
if let Some(res) = self.ctx.get_f64(sexpr) {
return Some(LitVal::Float(res));
}
if let Some(res) = self.ctx.get_atom(sexpr) {
match res {
"true" => {
return Some(LitVal::Bool(true));
}
"false" => {
return Some(LitVal::Bool(false));
}
_ => {
return None;
}
}
}
None
}
}
impl common::PrimSolver for SmtLibSolver {
fn check_sat(
&mut self,
prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
) -> Option<HashMap<IdentCtx, LitVal>> {
if prims.is_empty() {
return Some(HashMap::new());
}
self.ctx.pop().unwrap();
self.ctx.push().unwrap();
let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
let sexp_map = self.solve_constraints(prims, &ty_map);
let check_res = self.ctx.check().unwrap();
if check_res == easy_smt::Response::Sat {
let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
let res = vars
.iter()
.cloned()
.zip(
self.ctx
.get_value(vars.iter().map(|var| sexp_map[var]).collect())
.unwrap()
.iter()
.map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
)
.collect();
Some(res)
} else {
None
}
}
}