use std::cell::RefCell;
use std::rc::Rc;
use num_bigint::BigInt;
use num_rational::Rational64;
use oxiz_core::ast::{TermId, TermManager};
use oxiz_core::sort::SortId;
use crate::Context;
use crate::SolverResult;
use crate::solver::SolverConfig;
#[derive(Debug, Clone, Default)]
pub struct Z3Config {
inner: SolverConfig,
}
impl Z3Config {
#[must_use]
pub fn new() -> Self {
Self {
inner: SolverConfig::default(),
}
}
pub fn set_proof(&mut self, enabled: bool) -> &mut Self {
self.inner.proof = enabled;
self
}
#[must_use]
pub fn as_solver_config(&self) -> &SolverConfig {
&self.inner
}
}
/// Term-building context mirroring Z3's `Context`.
///
/// Holds an interior-mutable [`TermManager`] that the typed term constructors
/// in this module build into, plus a copy of the configuration it was created
/// from.
pub struct Z3Context {
    /// Term manager that the `build!`-based constructors borrow mutably.
    pub(crate) tm: Rc<RefCell<TermManager>>,
    /// Clone of the configuration passed to [`Z3Context::new`].
    pub(crate) config: SolverConfig,
}

impl Z3Context {
    /// Creates a context with a brand-new term manager and a clone of `cfg`.
    #[must_use]
    pub fn new(cfg: &Z3Config) -> Self {
        let config = cfg.inner.clone();
        let tm = Rc::new(RefCell::new(TermManager::new()));
        Self { tm, config }
    }

    /// Returns the Boolean sort of this context's term manager.
    #[must_use]
    pub fn bool_sort(&self) -> SortId {
        let manager = self.tm.borrow();
        manager.sorts.bool_sort
    }

    /// Returns the integer sort of this context's term manager.
    #[must_use]
    pub fn int_sort(&self) -> SortId {
        let manager = self.tm.borrow();
        manager.sorts.int_sort
    }

    /// Returns the real sort of this context's term manager.
    #[must_use]
    pub fn real_sort(&self) -> SortId {
        let manager = self.tm.borrow();
        manager.sorts.real_sort
    }

    /// Returns the bit-vector sort of the given `width`.
    ///
    /// Takes a mutable borrow of the term manager because the sort table's
    /// `bitvec` lookup is called through `borrow_mut`.
    #[must_use]
    pub fn bv_sort(&self, width: u32) -> SortId {
        let mut manager = self.tm.borrow_mut();
        manager.sorts.bitvec(width)
    }
}
/// Outcome of a satisfiability check, mirroring Z3's `SatResult`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SatResult {
    /// The current assertions are satisfiable.
    Sat,
    /// The current assertions are unsatisfiable.
    Unsat,
    /// The solver could not decide.
    Unknown,
}

impl From<SolverResult> for SatResult {
    /// Maps each native result onto the variant with the same name.
    fn from(result: SolverResult) -> Self {
        match result {
            SolverResult::Sat => Self::Sat,
            SolverResult::Unsat => Self::Unsat,
            SolverResult::Unknown => Self::Unknown,
        }
    }
}

impl std::fmt::Display for SatResult {
    /// Writes the lowercase keyword: `sat`, `unsat` or `unknown`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Sat => "sat",
            Self::Unsat => "unsat",
            Self::Unknown => "unknown",
        };
        f.write_str(keyword)
    }
}
/// A model returned by [`Z3Solver::get_model`].
///
/// Each entry is a `(name, sort, value)` triple of strings, stored in the
/// order the underlying context produced them.
// `Debug`/`Clone`/`PartialEq`/`Eq` let callers print, keep and compare models;
// every field is a plain `String`, so the derives are cheap and always valid.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Z3Model {
    entries: Vec<(String, String, String)>,
}

impl Z3Model {
    /// Wraps the `(name, sort, value)` triples returned by `Context::get_model`.
    fn from_context_model(entries: Vec<(String, String, String)>) -> Self {
        Self { entries }
    }

    /// Returns the value string recorded for the constant called `name`.
    ///
    /// Returns `None` when no entry has that name; if several entries share
    /// the name, the first one wins.
    #[must_use]
    pub fn eval_const(&self, name: &str) -> Option<&str> {
        self.entries
            .iter()
            .find(|(n, _, _)| n == name)
            .map(|(_, _, v)| v.as_str())
    }

    /// Returns every `(name, sort, value)` triple in stored order.
    #[must_use]
    pub fn entries(&self) -> &[(String, String, String)] {
        &self.entries
    }
}

impl std::fmt::Display for Z3Model {
    /// Writes `(model`, then one ` (define-fun <name> () <sort> <value>)` line
    /// per entry, then a closing `)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "(model")?;
        for (name, sort, value) in &self.entries {
            writeln!(f, " (define-fun {} () {} {})", name, sort, value)?;
        }
        write!(f, ")")
    }
}
pub struct Z3Solver {
ctx: Context,
}
impl Z3Solver {
#[must_use]
pub fn new(z3ctx: &Z3Context) -> Self {
let mut ctx = Context::new();
if z3ctx.config.proof {
ctx.set_option("produce-proofs", "true");
}
Self { ctx }
}
pub fn assert(&mut self, t: &Bool) {
self.ctx.assert(t.id);
}
#[must_use]
pub fn check(&mut self) -> SatResult {
self.ctx.check_sat().into()
}
pub fn push(&mut self) {
self.ctx.push();
}
pub fn pop(&mut self) {
self.ctx.pop();
}
#[must_use]
pub fn get_model(&self) -> Option<Z3Model> {
self.ctx.get_model().map(Z3Model::from_context_model)
}
pub fn set_logic(&mut self, logic: &str) {
self.ctx.set_logic(logic);
}
#[must_use]
pub fn context(&self) -> &Context {
&self.ctx
}
pub fn context_mut(&mut self) -> &mut Context {
&mut self.ctx
}
}
// Forwards `$method($arg, ...)` to the term manager stored in `$ctx.tm`.
//
// The mutable borrow of `$ctx.tm` is taken *before* the arguments are
// evaluated, so an argument must not borrow the same term manager itself
// (e.g. by calling `ctx.bool_sort()`), or the `RefCell` panics at runtime.
// Compute such values into a local first.
macro_rules! build {
    ($ctx:expr, $method:ident $(, $arg:expr)* ) => {
        ::std::cell::RefCell::borrow_mut(&*$ctx.tm).$method($($arg),*)
    };
}
/// A Boolean-sorted term handle, mirroring Z3's `Bool` AST wrapper.
#[derive(Debug, Clone)]
pub struct Bool {
    /// Identifier of the term in the term manager it was built with.
    pub id: TermId,
}

impl Bool {
    /// Collects the raw term ids of `args`, preserving their order.
    fn operand_ids(args: &[Bool]) -> Vec<TermId> {
        args.iter().map(|arg| arg.id).collect()
    }

    /// Wraps an existing term id; the id's sort is not checked.
    #[must_use]
    pub fn from_id(id: TermId) -> Self {
        Self { id }
    }

    /// Builds a Boolean variable term named `name` (via `mk_var`).
    #[must_use]
    pub fn new_const(ctx: &Z3Context, name: &str) -> Self {
        // Fetch the sort first: `build!` holds a mutable borrow of the term
        // manager while its arguments are evaluated.
        let sort = ctx.bool_sort();
        Self::from_id(build!(ctx, mk_var, name, sort))
    }

    /// Builds the Boolean literal `value`.
    #[must_use]
    pub fn from_bool(ctx: &Z3Context, value: bool) -> Self {
        Self::from_id(build!(ctx, mk_bool, value))
    }

    /// Builds the conjunction of `args`.
    #[must_use]
    pub fn and(ctx: &Z3Context, args: &[Bool]) -> Self {
        let operands = Self::operand_ids(args);
        Self::from_id(build!(ctx, mk_and, operands))
    }

    /// Builds the disjunction of `args`.
    #[must_use]
    pub fn or(ctx: &Z3Context, args: &[Bool]) -> Self {
        let operands = Self::operand_ids(args);
        Self::from_id(build!(ctx, mk_or, operands))
    }

    /// Builds the negation of `arg`.
    #[must_use]
    pub fn not(ctx: &Z3Context, arg: &Bool) -> Self {
        Self::from_id(build!(ctx, mk_not, arg.id))
    }

    /// Builds the implication `lhs => rhs`.
    #[must_use]
    pub fn implies(ctx: &Z3Context, lhs: &Bool, rhs: &Bool) -> Self {
        Self::from_id(build!(ctx, mk_implies, lhs.id, rhs.id))
    }

    /// Builds `lhs <=> rhs`, expressed as an equality between Boolean terms.
    #[must_use]
    pub fn iff(ctx: &Z3Context, lhs: &Bool, rhs: &Bool) -> Self {
        Self::from_id(build!(ctx, mk_eq, lhs.id, rhs.id))
    }

    /// Builds the exclusive-or of `lhs` and `rhs`.
    #[must_use]
    pub fn xor(ctx: &Z3Context, lhs: &Bool, rhs: &Bool) -> Self {
        Self::from_id(build!(ctx, mk_xor, lhs.id, rhs.id))
    }
}

impl From<Bool> for TermId {
    /// Unwraps the handle into its raw term id.
    fn from(b: Bool) -> Self {
        let Bool { id } = b;
        id
    }
}
/// An integer-sorted term handle, mirroring Z3's `Int` AST wrapper.
#[derive(Debug, Clone)]
pub struct Int {
    /// Identifier of the term in the term manager it was built with.
    pub id: TermId,
}

impl Int {
    /// Collects the raw term ids of `args`, preserving their order.
    fn operand_ids(args: &[Int]) -> Vec<TermId> {
        args.iter().map(|arg| arg.id).collect()
    }

    /// Wraps an existing term id; the id's sort is not checked.
    #[must_use]
    pub fn from_id(id: TermId) -> Self {
        Self { id }
    }

    /// Builds an integer variable term named `name` (via `mk_var`).
    #[must_use]
    pub fn new_const(ctx: &Z3Context, name: &str) -> Self {
        // Fetch the sort first: `build!` holds a mutable borrow of the term
        // manager while its arguments are evaluated.
        let sort = ctx.int_sort();
        Self::from_id(build!(ctx, mk_var, name, sort))
    }

    /// Builds the integer literal `value`.
    #[must_use]
    pub fn from_i64(ctx: &Z3Context, value: i64) -> Self {
        let literal = BigInt::from(value);
        Self::from_id(build!(ctx, mk_int, literal))
    }

    /// Builds the sum of `args`.
    #[must_use]
    pub fn add(ctx: &Z3Context, args: &[Int]) -> Self {
        let operands = Self::operand_ids(args);
        Self::from_id(build!(ctx, mk_add, operands))
    }

    /// Builds `lhs - rhs`.
    #[must_use]
    pub fn sub(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Self {
        Self::from_id(build!(ctx, mk_sub, lhs.id, rhs.id))
    }

    /// Builds the product of `args`.
    #[must_use]
    pub fn mul(ctx: &Z3Context, args: &[Int]) -> Self {
        let operands = Self::operand_ids(args);
        Self::from_id(build!(ctx, mk_mul, operands))
    }

    /// Builds the arithmetic negation of `arg`.
    #[must_use]
    pub fn neg(ctx: &Z3Context, arg: &Int) -> Self {
        Self::from_id(build!(ctx, mk_neg, arg.id))
    }

    /// Builds the division term `mk_div(lhs, rhs)`.
    ///
    /// NOTE(review): rounding semantics are those of the term manager's
    /// `mk_div`, presumably SMT-LIB integer `div`. Confirm.
    #[must_use]
    pub fn div(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Self {
        Self::from_id(build!(ctx, mk_div, lhs.id, rhs.id))
    }

    /// Builds the modulus term `mk_mod(lhs, rhs)`.
    #[must_use]
    pub fn modulo(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Self {
        Self::from_id(build!(ctx, mk_mod, lhs.id, rhs.id))
    }

    /// Builds the predicate `lhs < rhs`.
    #[must_use]
    pub fn lt(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
        Bool { id: build!(ctx, mk_lt, lhs.id, rhs.id) }
    }

    /// Builds the predicate `lhs <= rhs`.
    #[must_use]
    pub fn le(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
        Bool { id: build!(ctx, mk_le, lhs.id, rhs.id) }
    }

    /// Builds the predicate `lhs > rhs`.
    #[must_use]
    pub fn gt(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
        Bool { id: build!(ctx, mk_gt, lhs.id, rhs.id) }
    }

    /// Builds the predicate `lhs >= rhs`.
    #[must_use]
    pub fn ge(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
        Bool { id: build!(ctx, mk_ge, lhs.id, rhs.id) }
    }

    /// Builds the equality `lhs = rhs`.
    #[must_use]
    pub fn eq(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
        Bool { id: build!(ctx, mk_eq, lhs.id, rhs.id) }
    }
}

impl From<Int> for TermId {
    /// Unwraps the handle into its raw term id.
    fn from(x: Int) -> Self {
        let Int { id } = x;
        id
    }
}
/// A real-sorted term handle, mirroring Z3's `Real` AST wrapper.
#[derive(Debug, Clone)]
pub struct Real {
    /// Identifier of the term in the term manager it was built with.
    pub id: TermId,
}

impl Real {
    /// Wraps an existing term id; the id's sort is not checked.
    #[must_use]
    pub fn from_id(id: TermId) -> Self {
        Self { id }
    }

    /// Builds a real variable term named `name` (via `mk_var`).
    #[must_use]
    pub fn new_const(ctx: &Z3Context, name: &str) -> Self {
        // Fetch the sort first: `build!` holds a mutable borrow of the term
        // manager while its arguments are evaluated.
        let sort = ctx.real_sort();
        let id = build!(ctx, mk_var, name, sort);
        Self { id }
    }

    /// Builds the rational literal `num / den`.
    ///
    /// # Panics
    ///
    /// Panics if `den` is zero.
    #[must_use]
    pub fn from_frac(ctx: &Z3Context, num: i64, den: i64) -> Self {
        // `Rational64::new` would also panic on a zero denominator; checking
        // here gives a message that names this API.
        assert_ne!(den, 0, "Real::from_frac: denominator must be non-zero");
        let id = build!(ctx, mk_real, Rational64::new(num, den));
        Self { id }
    }

    /// Builds the sum of `args`.
    #[must_use]
    pub fn add(ctx: &Z3Context, args: &[Real]) -> Self {
        let ids: Vec<TermId> = args.iter().map(|x| x.id).collect();
        let id = build!(ctx, mk_add, ids);
        Self { id }
    }

    /// Builds `lhs - rhs`.
    #[must_use]
    pub fn sub(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Self {
        let id = build!(ctx, mk_sub, lhs.id, rhs.id);
        Self { id }
    }

    /// Builds the product of `args`.
    #[must_use]
    pub fn mul(ctx: &Z3Context, args: &[Real]) -> Self {
        let ids: Vec<TermId> = args.iter().map(|x| x.id).collect();
        let id = build!(ctx, mk_mul, ids);
        Self { id }
    }

    /// Builds the arithmetic negation of `arg`, matching `Int::neg`.
    #[must_use]
    pub fn neg(ctx: &Z3Context, arg: &Real) -> Self {
        let id = build!(ctx, mk_neg, arg.id);
        Self { id }
    }

    /// Builds the predicate `lhs < rhs`.
    #[must_use]
    pub fn lt(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Bool {
        let id = build!(ctx, mk_lt, lhs.id, rhs.id);
        Bool { id }
    }

    /// Builds the predicate `lhs <= rhs`.
    #[must_use]
    pub fn le(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Bool {
        let id = build!(ctx, mk_le, lhs.id, rhs.id);
        Bool { id }
    }

    /// Builds the predicate `lhs > rhs`, matching `Int::gt`.
    #[must_use]
    pub fn gt(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Bool {
        let id = build!(ctx, mk_gt, lhs.id, rhs.id);
        Bool { id }
    }

    /// Builds the predicate `lhs >= rhs`, matching `Int::ge`.
    #[must_use]
    pub fn ge(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Bool {
        let id = build!(ctx, mk_ge, lhs.id, rhs.id);
        Bool { id }
    }

    /// Builds the equality `lhs = rhs`.
    #[must_use]
    pub fn eq(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Bool {
        let id = build!(ctx, mk_eq, lhs.id, rhs.id);
        Bool { id }
    }
}

impl From<Real> for TermId {
    /// Unwraps the handle into its raw term id.
    fn from(x: Real) -> Self {
        x.id
    }
}
/// A fixed-width bit-vector term handle, mirroring Z3's `BV` AST wrapper.
///
/// `width` is tracked on the Rust side next to the term id, so result widths
/// are computed without querying the term manager. It is only as accurate as
/// the width supplied when the handle was created (see [`BV::from_id`]).
#[derive(Debug, Clone)]
pub struct BV {
    /// Identifier of the term in the term manager it was built with.
    pub id: TermId,
    /// Bit width recorded for this term.
    pub width: u32,
}

impl BV {
    /// Returns the shared operand width of a same-width binary operator.
    ///
    /// SMT-LIB bit-vector operators other than `concat`/`extract` require both
    /// operands to have one width. A mismatch means an ill-sorted term, so it
    /// is reported in debug builds (release builds keep `lhs.width`).
    fn same_width(op: &str, lhs: &BV, rhs: &BV) -> u32 {
        debug_assert_eq!(
            lhs.width, rhs.width,
            "{}: operand widths differ ({} vs {})",
            op, lhs.width, rhs.width
        );
        lhs.width
    }

    /// Wraps an existing term id with a caller-supplied width (not checked).
    #[must_use]
    pub fn from_id(id: TermId, width: u32) -> Self {
        Self { id, width }
    }

    /// Builds a bit-vector variable term named `name` of the given `width`.
    #[must_use]
    pub fn new_const(ctx: &Z3Context, name: &str, width: u32) -> Self {
        // Fetch the sort first: `build!` holds a mutable borrow of the term
        // manager while its arguments are evaluated.
        let sort = ctx.bv_sort(width);
        let id = build!(ctx, mk_var, name, sort);
        Self { id, width }
    }

    /// Builds the bit-vector literal `value` of the given `width` (via `mk_bitvec`).
    #[must_use]
    pub fn from_u64(ctx: &Z3Context, value: u64, width: u32) -> Self {
        let id = build!(ctx, mk_bitvec, BigInt::from(value), width);
        Self { id, width }
    }

    /// Builds `bvadd lhs rhs`.
    #[must_use]
    pub fn bvadd(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvadd", lhs, rhs);
        let id = build!(ctx, mk_bv_add, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds `bvsub lhs rhs`.
    #[must_use]
    pub fn bvsub(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvsub", lhs, rhs);
        let id = build!(ctx, mk_bv_sub, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds `bvmul lhs rhs`.
    #[must_use]
    pub fn bvmul(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvmul", lhs, rhs);
        let id = build!(ctx, mk_bv_mul, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds `bvand lhs rhs`.
    #[must_use]
    pub fn bvand(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvand", lhs, rhs);
        let id = build!(ctx, mk_bv_and, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds `bvor lhs rhs`.
    #[must_use]
    pub fn bvor(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvor", lhs, rhs);
        let id = build!(ctx, mk_bv_or, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds `bvxor lhs rhs`.
    #[must_use]
    pub fn bvxor(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvxor", lhs, rhs);
        let id = build!(ctx, mk_bv_xor, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds the bitwise complement of `arg`.
    #[must_use]
    pub fn bvnot(ctx: &Z3Context, arg: &BV) -> Self {
        let width = arg.width;
        let id = build!(ctx, mk_bv_not, arg.id);
        Self { id, width }
    }

    /// Builds the two's-complement negation of `arg`.
    #[must_use]
    pub fn bvneg(ctx: &Z3Context, arg: &BV) -> Self {
        let width = arg.width;
        let id = build!(ctx, mk_bv_neg, arg.id);
        Self { id, width }
    }

    /// Builds the unsigned comparison `lhs < rhs`.
    #[must_use]
    pub fn bvult(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
        Self::same_width("bvult", lhs, rhs);
        let id = build!(ctx, mk_bv_ult, lhs.id, rhs.id);
        Bool { id }
    }

    /// Builds the signed comparison `lhs < rhs`.
    #[must_use]
    pub fn bvslt(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
        Self::same_width("bvslt", lhs, rhs);
        let id = build!(ctx, mk_bv_slt, lhs.id, rhs.id);
        Bool { id }
    }

    /// Builds the unsigned comparison `lhs <= rhs`.
    #[must_use]
    pub fn bvule(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
        Self::same_width("bvule", lhs, rhs);
        let id = build!(ctx, mk_bv_ule, lhs.id, rhs.id);
        Bool { id }
    }

    /// Builds the signed comparison `lhs <= rhs`.
    #[must_use]
    pub fn bvsle(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
        Self::same_width("bvsle", lhs, rhs);
        let id = build!(ctx, mk_bv_sle, lhs.id, rhs.id);
        Bool { id }
    }

    /// Builds the equality `lhs = rhs`.
    #[must_use]
    pub fn eq(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
        Self::same_width("eq", lhs, rhs);
        let id = build!(ctx, mk_eq, lhs.id, rhs.id);
        Bool { id }
    }

    /// Builds the left shift `bvshl lhs rhs`.
    #[must_use]
    pub fn bvshl(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvshl", lhs, rhs);
        let id = build!(ctx, mk_bv_shl, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds the logical right shift `bvlshr lhs rhs`.
    #[must_use]
    pub fn bvlshr(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvlshr", lhs, rhs);
        let id = build!(ctx, mk_bv_lshr, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds the arithmetic right shift `bvashr lhs rhs`.
    #[must_use]
    pub fn bvashr(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvashr", lhs, rhs);
        let id = build!(ctx, mk_bv_ashr, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds the unsigned division `bvudiv lhs rhs`.
    #[must_use]
    pub fn bvudiv(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvudiv", lhs, rhs);
        let id = build!(ctx, mk_bv_udiv, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds the unsigned remainder `bvurem lhs rhs`.
    #[must_use]
    pub fn bvurem(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        let width = Self::same_width("bvurem", lhs, rhs);
        let id = build!(ctx, mk_bv_urem, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Builds the concatenation of `lhs` (high bits) and `rhs` (low bits).
    ///
    /// # Panics
    ///
    /// Panics if the combined width does not fit in a `u32`.
    #[must_use]
    pub fn concat(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
        // Checked so an absurd total panics instead of wrapping in release builds.
        let width = lhs
            .width
            .checked_add(rhs.width)
            .expect("concat: combined bit-vector width overflows u32");
        let id = build!(ctx, mk_bv_concat, lhs.id, rhs.id);
        Self { id, width }
    }

    /// Extracts bits `high` down to `low` (inclusive) of `arg`, giving a
    /// `high - low + 1`-bit term.
    ///
    /// # Panics
    ///
    /// Panics if `high < low`, in every build profile: otherwise the result
    /// width would silently wrap around in release builds. In debug builds it
    /// also panics if `high` is not below `arg.width`.
    #[must_use]
    pub fn extract(ctx: &Z3Context, high: u32, low: u32, arg: &BV) -> Self {
        assert!(
            high >= low,
            "extract: high ({}) must be >= low ({})",
            high,
            low
        );
        debug_assert!(
            high < arg.width,
            "extract: high ({}) must be < operand width ({})",
            high,
            arg.width
        );
        let width = high - low + 1;
        let id = build!(ctx, mk_bv_extract, high, low, arg.id);
        Self { id, width }
    }
}

impl From<BV> for TermId {
    /// Unwraps the handle into its raw term id (the width is dropped).
    fn from(b: BV) -> Self {
        b.id
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bool_and_sat() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);
        // Build the formula in the solver's own term manager: `Z3Solver::new`
        // creates a separate `Context`, so term ids minted by `ctx` are not
        // guaranteed to be valid inside the solver.
        let bool_sort = solver.ctx.terms.sorts.bool_sort;
        let p = solver.ctx.terms.mk_var("p", bool_sort);
        let q = solver.ctx.terms.mk_var("q", bool_sort);
        let conj = solver.ctx.terms.mk_and(vec![p, q]);
        solver.ctx.assert(conj);
        assert_eq!(solver.check(), SatResult::Sat);
    }

    #[test]
    fn test_bool_and_unsat() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);
        let t = solver.ctx.terms.mk_true();
        let f = solver.ctx.terms.mk_false();
        solver.ctx.assert(t);
        solver.ctx.assert(f);
        assert_eq!(solver.check(), SatResult::Unsat);
    }

    #[test]
    fn test_sat_result_from_solver_result() {
        assert_eq!(SatResult::from(SolverResult::Sat), SatResult::Sat);
        assert_eq!(SatResult::from(SolverResult::Unsat), SatResult::Unsat);
        assert_eq!(SatResult::from(SolverResult::Unknown), SatResult::Unknown);
    }

    #[test]
    fn test_sat_result_display() {
        assert_eq!(SatResult::Sat.to_string(), "sat");
        assert_eq!(SatResult::Unsat.to_string(), "unsat");
        assert_eq!(SatResult::Unknown.to_string(), "unknown");
    }

    #[test]
    fn test_bool_api_term_building() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let p = Bool::new_const(&ctx, "p");
        let q = Bool::new_const(&ctx, "q");
        let _lit = Bool::from_bool(&ctx, false);
        let _conj = Bool::and(&ctx, &[p.clone(), q.clone()]);
        let _disj = Bool::or(&ctx, &[p.clone(), q.clone()]);
        let _neg = Bool::not(&ctx, &p);
        let _impl = Bool::implies(&ctx, &p, &q);
        let _iff = Bool::iff(&ctx, &p, &q);
        let _xor = Bool::xor(&ctx, &p, &q);
    }

    #[test]
    fn test_int_api_term_building() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let x = Int::new_const(&ctx, "x");
        let y = Int::new_const(&ctx, "y");
        let five = Int::from_i64(&ctx, 5);
        let _sum = Int::add(&ctx, &[x.clone(), y.clone()]);
        let _diff = Int::sub(&ctx, &x, &y);
        let _prod = Int::mul(&ctx, &[x.clone(), five.clone()]);
        let _neg = Int::neg(&ctx, &x);
        let _div = Int::div(&ctx, &x, &five);
        let _mod = Int::modulo(&ctx, &x, &five);
        let _lt = Int::lt(&ctx, &x, &five);
        let _le = Int::le(&ctx, &x, &y);
        let _gt = Int::gt(&ctx, &x, &y);
        let _ge = Int::ge(&ctx, &x, &five);
        let _eq = Int::eq(&ctx, &x, &y);
    }

    #[test]
    fn test_real_api_term_building() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let r = Real::new_const(&ctx, "r");
        let s = Real::new_const(&ctx, "s");
        let half = Real::from_frac(&ctx, 1, 2);
        let _sum = Real::add(&ctx, &[r.clone(), half.clone()]);
        let _diff = Real::sub(&ctx, &r, &s);
        let _prod = Real::mul(&ctx, &[r.clone(), s.clone()]);
        let _lt = Real::lt(&ctx, &r, &half);
        let _le = Real::le(&ctx, &r, &s);
        let _eq = Real::eq(&ctx, &r, &s);
    }

    #[test]
    fn test_bv_api_term_building() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let a = BV::new_const(&ctx, "a", 32);
        let b = BV::new_const(&ctx, "b", 32);
        let lit = BV::from_u64(&ctx, 42, 32);
        let add = BV::bvadd(&ctx, &a, &b);
        assert_eq!(add.width, 32);
        let _and = BV::bvand(&ctx, &a, &b);
        let _ult = BV::bvult(&ctx, &a, &lit);
        let _slt = BV::bvslt(&ctx, &a, &b);
        let _eq = BV::eq(&ctx, &a, &b);
        let shl = BV::bvshl(&ctx, &a, &b);
        assert_eq!(shl.width, 32);
        let not_a = BV::bvnot(&ctx, &a);
        assert_eq!(not_a.width, 32);
        let concat = BV::concat(&ctx, &a, &b);
        assert_eq!(concat.width, 64);
        let extr = BV::extract(&ctx, 7, 0, &a);
        assert_eq!(extr.width, 8);
    }

    #[test]
    fn test_push_pop() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);
        let t = solver.ctx.terms.mk_true();
        solver.ctx.assert(t);
        solver.push();
        let f = solver.ctx.terms.mk_false();
        solver.ctx.assert(f);
        assert_eq!(solver.check(), SatResult::Unsat);
        solver.pop();
        assert_eq!(solver.check(), SatResult::Sat);
    }

    #[test]
    fn test_int_solver_sat() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);
        solver.set_logic("QF_LIA");
        let x = solver
            .ctx
            .terms
            .mk_var("x", solver.ctx.terms.sorts.int_sort);
        let five = solver.ctx.terms.mk_int(BigInt::from(5));
        let ten = solver.ctx.terms.mk_int(BigInt::from(10));
        let c1 = solver.ctx.terms.mk_ge(x, five);
        let c2 = solver.ctx.terms.mk_le(x, ten);
        solver.ctx.assert(c1);
        solver.ctx.assert(c2);
        assert_eq!(solver.check(), SatResult::Sat);
    }

    #[test]
    fn test_get_model() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);
        let bool_sort = solver.ctx.terms.sorts.bool_sort;
        let _p = solver.ctx.declare_const("p", bool_sort);
        let t = solver.ctx.terms.mk_true();
        solver.ctx.assert(t);
        assert_eq!(solver.check(), SatResult::Sat);
        let model = solver.get_model();
        assert!(model.is_some(), "Expected a model after SAT");
    }
}