use std::rc::Rc;
use rustc_hash::FxHashMap;
use oxiz_core::ast::{TermId, TermKind, TermManager};
use oxiz_core::sort::{SortId, SortKind};
use crate::z3_compat::{BV, Bool, Int, Real, Z3Context};
/// Coarse classification of a sort, as reported by [`Z3Sort::kind`].
///
/// `Hash` is derived alongside the equality traits so a kind can be used
/// directly as a `HashMap`/`HashSet` key (for example to bucket terms by
/// sort kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Z3SortKind {
    /// Reported for `SortKind::Bool`.
    Bool,
    /// Reported for `SortKind::Int`.
    Int,
    /// Reported for `SortKind::Real`.
    Real,
    /// Reported for `SortKind::BitVec`; the width is available through
    /// [`Z3Sort::bv_size`].
    BitVec,
    /// Reported for `SortKind::Array`; the component sorts are available
    /// through [`Z3Sort::array_domain`] and [`Z3Sort::array_range`].
    Array,
    /// Reported for `SortKind::Datatype`.
    Datatype,
    /// Reported for `SortKind::Uninterpreted`.
    Uninterpreted,
    /// Reported for every remaining sort kind (string, floating-point,
    /// sort parameters, parametric sorts) and for sort ids the term manager
    /// does not know.
    Other,
}
/// A sort identifier bundled with a shared handle to the term manager that
/// is consulted whenever the sort is inspected.
#[derive(Clone)]
pub struct Z3Sort {
    /// Identifier used to look the sort up in the term manager's sort table.
    pub id: SortId,
    /// Shared, interior-mutable handle to the term manager. Cloning a
    /// `Z3Sort` clones this `Rc`, not the manager itself.
    ctx: Rc<core::cell::RefCell<TermManager>>,
}
impl Z3Sort {
    /// Wraps `id` as a sort that shares the term manager of `ctx`.
    #[must_use]
    pub fn new(ctx: &Z3Context, id: SortId) -> Self {
        Self::from_handle(ctx.tm_handle(), id)
    }

    /// Builds a sort from an already-shared term-manager handle.
    fn from_handle(ctx: Rc<core::cell::RefCell<TermManager>>, id: SortId) -> Self {
        Self { id, ctx }
    }

    /// Classifies this sort.
    ///
    /// Returns [`Z3SortKind::Other`] for string, floating-point, parameter
    /// and parametric sorts, and also when the id is not present in the
    /// manager's sort table.
    ///
    /// Borrows the shared term manager immutably for the duration of the
    /// call.
    #[must_use]
    pub fn kind(&self) -> Z3SortKind {
        let tm = self.ctx.borrow();
        let sort = match tm.sorts.get(self.id) {
            Some(sort) => sort,
            None => return Z3SortKind::Other,
        };
        // Deliberately no wildcard arm: a new `SortKind` variant should be a
        // compile error here rather than silently becoming `Other`.
        match &sort.kind {
            SortKind::Bool => Z3SortKind::Bool,
            SortKind::Int => Z3SortKind::Int,
            SortKind::Real => Z3SortKind::Real,
            SortKind::BitVec(_) => Z3SortKind::BitVec,
            SortKind::Array { .. } => Z3SortKind::Array,
            SortKind::Datatype(_) => Z3SortKind::Datatype,
            SortKind::Uninterpreted(_) => Z3SortKind::Uninterpreted,
            SortKind::String
            | SortKind::FloatingPoint { .. }
            | SortKind::Parameter(_)
            | SortKind::Parametric { .. } => Z3SortKind::Other,
        }
    }

    /// Returns the bit width when this is a bit-vector sort, `None` for any
    /// other sort or an unknown id.
    #[must_use]
    pub fn bv_size(&self) -> Option<u32> {
        let tm = self.ctx.borrow();
        match &tm.sorts.get(self.id)?.kind {
            SortKind::BitVec(width) => Some(*width),
            _ => None,
        }
    }

    /// Returns the domain (index) sort when this is an array sort, `None`
    /// for any other sort or an unknown id.
    #[must_use]
    pub fn array_domain(&self) -> Option<Z3Sort> {
        let tm = self.ctx.borrow();
        let domain = match &tm.sorts.get(self.id)?.kind {
            SortKind::Array { domain, .. } => *domain,
            _ => return None,
        };
        // Release the manager borrow before handing out a new handle.
        drop(tm);
        Some(Self::from_handle(Rc::clone(&self.ctx), domain))
    }

    /// Returns the range (element) sort when this is an array sort, `None`
    /// for any other sort or an unknown id.
    #[must_use]
    pub fn array_range(&self) -> Option<Z3Sort> {
        let tm = self.ctx.borrow();
        let range = match &tm.sorts.get(self.id)?.kind {
            SortKind::Array { range, .. } => *range,
            _ => return None,
        };
        // Release the manager borrow before handing out a new handle.
        drop(tm);
        Some(Self::from_handle(Rc::clone(&self.ctx), range))
    }

    /// Returns the name the sort table reports for this sort, or the literal
    /// `"Unknown"` when it reports none.
    #[must_use]
    pub fn name(&self) -> String {
        let tm = self.ctx.borrow();
        match tm.sorts.sort_name(self.id) {
            Some(name) => name,
            None => "Unknown".to_string(),
        }
    }
}
impl core::fmt::Debug for Z3Sort {
    /// Formats as a `Z3Sort` struct with the raw `id` and the computed
    /// `kind`; the term-manager handle itself is not printed.
    ///
    /// Computing the kind borrows the shared term manager.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut repr = f.debug_struct("Z3Sort");
        repr.field("id", &self.id);
        repr.field("kind", &self.kind());
        repr.finish()
    }
}
impl core::fmt::Display for Z3Sort {
    /// Writes the sort's name exactly as returned by [`Z3Sort::name`].
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = self.name();
        f.write_str(label.as_str())
    }
}
impl Z3Context {
    /// Returns another owner of the context's shared term manager.
    fn tm_handle(&self) -> Rc<core::cell::RefCell<TermManager>> {
        Rc::clone(&self.tm)
    }

    /// Returns the sort recorded for `term`.
    ///
    /// When the term manager has no entry for `term`, the Boolean sort is
    /// returned instead.
    #[must_use]
    pub fn sort_of_term(&self, term: TermId) -> Z3Sort {
        let tm = self.tm.borrow();
        let sort_id = match tm.get(term) {
            Some(entry) => entry.sort,
            None => tm.sorts.bool_sort,
        };
        // End the borrow before a new handle to the manager is created.
        drop(tm);
        self.wrap_sort(sort_id)
    }

    /// Returns the sort of the term behind the Boolean expression `b`.
    #[must_use]
    pub fn sort_of_bool(&self, b: &Bool) -> Z3Sort {
        let term = b.id;
        self.sort_of_term(term)
    }

    /// Returns the sort of the term behind the integer expression `x`.
    #[must_use]
    pub fn sort_of_int(&self, x: &Int) -> Z3Sort {
        let term = x.id;
        self.sort_of_term(term)
    }

    /// Returns the sort of the term behind the real expression `x`.
    #[must_use]
    pub fn sort_of_real(&self, x: &Real) -> Z3Sort {
        let term = x.id;
        self.sort_of_term(term)
    }

    /// Returns the sort of the term behind the bit-vector expression `b`.
    #[must_use]
    pub fn sort_of_bv(&self, b: &BV) -> Z3Sort {
        let term = b.id;
        self.sort_of_term(term)
    }

    /// Wraps a raw sort id in a [`Z3Sort`] tied to this context's term
    /// manager. The id is not validated here.
    #[must_use]
    pub fn wrap_sort(&self, id: SortId) -> Z3Sort {
        Z3Sort::from_handle(Rc::clone(&self.tm), id)
    }
}
impl Z3Context {
    /// Replaces terms inside `expr` according to `subst` and returns the id
    /// of the resulting term.
    ///
    /// Each pair is `(from, to)`. If the same `from` appears more than once,
    /// the last pair wins. A replacement term is inserted as-is and is not
    /// itself searched for further matches. An empty `subst` returns `expr`
    /// without touching the term manager.
    ///
    /// Mutably borrows the shared term manager for the duration of the
    /// rewrite.
    #[must_use]
    pub fn substitute(&self, expr: TermId, subst: &[(TermId, TermId)]) -> TermId {
        if subst.is_empty() {
            expr
        } else {
            let replacements: FxHashMap<TermId, TermId> = subst.iter().copied().collect();
            // Memo table so shared sub-terms are rebuilt only once.
            let mut memo: FxHashMap<TermId, TermId> = FxHashMap::default();
            let mut manager = self.tm.borrow_mut();
            subst_rebuild(&mut manager, expr, &replacements, &mut memo)
        }
    }
}
fn subst_rebuild(
tm: &mut TermManager,
id: TermId,
map: &FxHashMap<TermId, TermId>,
cache: &mut FxHashMap<TermId, TermId>,
) -> TermId {
if let Some(&to) = map.get(&id) {
return to;
}
if let Some(&done) = cache.get(&id) {
return done;
}
let kind = match tm.get(id).map(|t| t.kind.clone()) {
Some(k) => k,
None => return id,
};
macro_rules! rec {
($child:expr) => {
subst_rebuild(tm, $child, map, cache)
};
}
let result = match kind {
TermKind::True
| TermKind::False
| TermKind::IntConst(_)
| TermKind::RealConst(_)
| TermKind::BitVecConst { .. }
| TermKind::StringLit(_)
| TermKind::Var(_) => id,
TermKind::Not(a) => {
let na = rec!(a);
if na == a { id } else { tm.mk_not(na) }
}
TermKind::And(args) => rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_and(a)),
TermKind::Or(args) => rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_or(a)),
TermKind::Xor(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_xor(na, nb)
}
}
TermKind::Implies(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_implies(na, nb)
}
}
TermKind::Ite(c, t, e) => {
let (nc, nt, ne) = (rec!(c), rec!(t), rec!(e));
if nc == c && nt == t && ne == e {
id
} else {
tm.mk_ite(nc, nt, ne)
}
}
TermKind::Eq(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_eq(na, nb)
}
}
TermKind::Distinct(args) => {
rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_distinct(a))
}
TermKind::Neg(a) => {
let na = rec!(a);
if na == a { id } else { tm.mk_neg(na) }
}
TermKind::Add(args) => rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_add(a)),
TermKind::Mul(args) => rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_mul(a)),
TermKind::Sub(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_sub(na, nb)
}
}
TermKind::Div(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_div(na, nb)
}
}
TermKind::Mod(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_mod(na, nb)
}
}
TermKind::Lt(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_lt(na, nb)
}
}
TermKind::Le(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_le(na, nb)
}
}
TermKind::Gt(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_gt(na, nb)
}
}
TermKind::Ge(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_ge(na, nb)
}
}
TermKind::BvConcat(a, b) => {
let (na, nb) = (rec!(a), rec!(b));
if na == a && nb == b {
id
} else {
tm.mk_bv_concat(na, nb)
}
}
TermKind::BvExtract { high, low, arg } => {
let na = rec!(arg);
if na == arg {
id
} else {
tm.mk_bv_extract(high, low, na)
}
}
TermKind::BvNot(a) => {
let na = rec!(a);
if na == a { id } else { tm.mk_bv_not(na) }
}
TermKind::BvAnd(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_and),
TermKind::BvOr(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_or),
TermKind::BvXor(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_xor),
TermKind::BvAdd(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_add),
TermKind::BvSub(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_sub),
TermKind::BvMul(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_mul),
TermKind::BvUdiv(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_udiv),
TermKind::BvSdiv(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_sdiv),
TermKind::BvUrem(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_urem),
TermKind::BvSrem(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_srem),
TermKind::BvShl(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_shl),
TermKind::BvLshr(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_lshr),
TermKind::BvAshr(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_ashr),
TermKind::BvUlt(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_ult),
TermKind::BvUle(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_ule),
TermKind::BvSlt(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_slt),
TermKind::BvSle(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_sle),
TermKind::Select(arr, idx) => {
let (na, ni) = (rec!(arr), rec!(idx));
if na == arr && ni == idx {
id
} else {
tm.mk_select(na, ni)
}
}
TermKind::Store(arr, idx, val) => {
let (na, ni, nv) = (rec!(arr), rec!(idx), rec!(val));
if na == arr && ni == idx && nv == val {
id
} else {
tm.mk_store(na, ni, nv)
}
}
TermKind::Apply { func, args } => {
let new_args: smallvec::SmallVec<[TermId; 4]> = args.iter().map(|&a| rec!(a)).collect();
if new_args.iter().zip(args.iter()).all(|(a, b)| a == b) {
id
} else {
let func_name = tm.resolve_str(func).to_string();
let sort = tm.get(id).map_or(tm.sorts.bool_sort, |t| t.sort);
tm.mk_apply(&func_name, new_args, sort)
}
}
_ => id,
};
cache.insert(id, result);
result
}
/// Rebuilds a variadic node.
///
/// Every element of `args` is rewritten with [`subst_rebuild`], in order.
/// If none of them changed, the original `id` is returned and `build` is
/// never called; otherwise `build` receives the rewritten arguments and its
/// result is returned.
fn rebuild_nary<F>(
    tm: &mut TermManager,
    id: TermId,
    args: &[TermId],
    map: &FxHashMap<TermId, TermId>,
    cache: &mut FxHashMap<TermId, TermId>,
    build: F,
) -> TermId
where
    F: FnOnce(&mut TermManager, smallvec::SmallVec<[TermId; 4]>) -> TermId,
{
    let mut rebuilt = smallvec::SmallVec::<[TermId; 4]>::with_capacity(args.len());
    let mut changed = false;
    for &arg in args {
        let new_arg = subst_rebuild(tm, arg, map, cache);
        changed |= new_arg != arg;
        rebuilt.push(new_arg);
    }
    if changed {
        build(tm, rebuilt)
    } else {
        id
    }
}
/// Rebuilds a two-child node.
///
/// `a` is rewritten first, then `b`. If both come back unchanged, the
/// original `id` is returned and `build` is never called; otherwise `build`
/// receives the rewritten children and its result is returned.
fn rebuild_bin<F>(
    tm: &mut TermManager,
    id: TermId,
    a: TermId,
    b: TermId,
    map: &FxHashMap<TermId, TermId>,
    cache: &mut FxHashMap<TermId, TermId>,
    build: F,
) -> TermId
where
    F: FnOnce(&mut TermManager, TermId, TermId) -> TermId,
{
    let lhs = subst_rebuild(tm, a, map, cache);
    let rhs = subst_rebuild(tm, b, map, cache);
    if (lhs, rhs) == (a, b) {
        return id;
    }
    build(tm, lhs, rhs)
}
/// A list of terms handed as one pattern to the `*_with_patterns`
/// quantifier builders.
///
/// NOTE(review): presumably these act as quantifier instantiation triggers,
/// as in Z3 — confirm against the term manager's quantifier API.
#[derive(Debug, Clone)]
pub struct Z3Pattern {
    /// The terms making up this pattern, in the order they were supplied.
    pub terms: Vec<TermId>,
}
impl Z3Pattern {
    /// Number of terms in the pattern.
    #[must_use]
    pub fn len(&self) -> usize {
        self.terms.as_slice().len()
    }

    /// `true` when the pattern contains no terms.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.terms.as_slice().is_empty()
    }
}
impl Z3Context {
    /// Creates a pattern holding a copy of `terms`. The context itself is
    /// not consulted or modified.
    #[must_use]
    pub fn mk_pattern(&self, terms: &[TermId]) -> Z3Pattern {
        let terms = Vec::from(terms);
        Z3Pattern { terms }
    }

    /// Builds a universally quantified formula over `bound` with `body` as
    /// its body, passing the term lists of `patterns` on to the term manager.
    ///
    /// `bound` lists the quantified variables as `(name, sort)` pairs.
    /// Mutably borrows the shared term manager while the term is created.
    #[must_use]
    pub fn forall_with_patterns(
        &self,
        bound: &[(&str, SortId)],
        patterns: &[Z3Pattern],
        body: &Bool,
    ) -> Bool {
        let vars: Vec<(&str, SortId)> = Vec::from(bound);
        let mut pats: Vec<Vec<TermId>> = Vec::with_capacity(patterns.len());
        for pattern in patterns {
            pats.push(pattern.terms.clone());
        }
        let mut tm = self.tm.borrow_mut();
        let id = tm.mk_forall_with_patterns(vars, body.id, pats);
        // Release the manager before wrapping the resulting id.
        drop(tm);
        Bool::from_id(id)
    }

    /// Builds an existentially quantified formula over `bound` with `body`
    /// as its body, passing the term lists of `patterns` on to the term
    /// manager.
    ///
    /// `bound` lists the quantified variables as `(name, sort)` pairs.
    /// Mutably borrows the shared term manager while the term is created.
    #[must_use]
    pub fn exists_with_patterns(
        &self,
        bound: &[(&str, SortId)],
        patterns: &[Z3Pattern],
        body: &Bool,
    ) -> Bool {
        let vars: Vec<(&str, SortId)> = Vec::from(bound);
        let mut pats: Vec<Vec<TermId>> = Vec::with_capacity(patterns.len());
        for pattern in patterns {
            pats.push(pattern.terms.clone());
        }
        let mut tm = self.tm.borrow_mut();
        let id = tm.mk_exists_with_patterns(vars, body.id, pats);
        // Release the manager before wrapping the resulting id.
        drop(tm);
        Bool::from_id(id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::z3_compat::Z3Config;

    /// Fresh context with the default configuration.
    fn ctx() -> Z3Context {
        Z3Context::new(&Z3Config::new())
    }

    /// The built-in sorts are classified as their matching kinds.
    #[test]
    fn unit_sort_kinds() {
        let c = ctx();
        assert_eq!(c.wrap_sort(c.bool_sort()).kind(), Z3SortKind::Bool);
        assert_eq!(c.wrap_sort(c.int_sort()).kind(), Z3SortKind::Int);
        assert_eq!(c.wrap_sort(c.real_sort()).kind(), Z3SortKind::Real);
        assert_eq!(c.wrap_sort(c.bv_sort(8)).kind(), Z3SortKind::BitVec);
    }

    /// Bit-vector width and array component accessors.
    #[test]
    fn unit_bv_size_and_array() {
        let c = ctx();
        assert_eq!(c.wrap_sort(c.bv_sort(16)).bv_size(), Some(16));
        assert_eq!(c.wrap_sort(c.bool_sort()).bv_size(), None);
        let arr = c.array_sort(c.int_sort(), c.bool_sort());
        let s = c.wrap_sort(arr);
        assert_eq!(s.kind(), Z3SortKind::Array);
        assert_eq!(s.array_domain().map(|d| d.kind()), Some(Z3SortKind::Int));
        assert_eq!(s.array_range().map(|r| r.kind()), Some(Z3SortKind::Bool));
    }

    /// An empty substitution returns the expression untouched.
    #[test]
    fn unit_substitute_identity() {
        let c = ctx();
        let x = Int::new_const(&c, "x");
        let y = Int::new_const(&c, "y");
        let sum = Int::add(&c, &[x.clone(), y.clone()]);
        assert_eq!(c.substitute(sum.id, &[]), sum.id);
    }

    /// A term that is itself a substitution source is replaced by its target.
    #[test]
    fn unit_substitute_direct_hit() {
        let c = ctx();
        let x = Int::new_const(&c, "x");
        let y = Int::new_const(&c, "y");
        assert_eq!(c.substitute(x.id, &[(x.id, y.id)]), y.id);
    }

    /// A non-empty substitution that matches nothing walks the term and
    /// returns the original id rather than building a new term.
    #[test]
    fn unit_substitute_no_match_keeps_term() {
        let c = ctx();
        let x = Int::new_const(&c, "x");
        let y = Int::new_const(&c, "y");
        let z = Int::new_const(&c, "z");
        let sum = Int::add(&c, &[x.clone(), y.clone()]);
        assert_eq!(c.substitute(sum.id, &[(z.id, x.id)]), sum.id);
    }

    /// An empty pattern reports zero length.
    #[test]
    fn unit_pattern_basic() {
        let c = ctx();
        let p = c.mk_pattern(&[]);
        assert!(p.is_empty());
        assert_eq!(p.len(), 0);
    }
}