1use std::rc::Rc;
29
30use rustc_hash::FxHashMap;
31
32use oxiz_core::ast::{TermId, TermKind, TermManager};
33use oxiz_core::sort::{SortId, SortKind};
34
35use crate::z3_compat::{BV, Bool, Int, Real, Z3Context};
36
/// Coarse classification of a sort, as reported by [`Z3Sort::kind`].
///
/// Sorts without a dedicated variant here (strings, floating point, sort parameters and
/// parametric sorts), as well as sort ids the term manager does not know, are all
/// reported as [`Z3SortKind::Other`].
///
/// Derives `Hash` (in addition to `Eq`) so kinds can be used as keys in hash maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Z3SortKind {
    /// The Boolean sort.
    Bool,
    /// The integer sort.
    Int,
    /// The real sort.
    Real,
    /// A fixed-width bit-vector sort (width available via `Z3Sort::bv_size`).
    BitVec,
    /// An array sort (components available via `Z3Sort::array_domain` / `array_range`).
    Array,
    /// A datatype sort.
    Datatype,
    /// An uninterpreted sort.
    Uninterpreted,
    /// Any other sort, or a sort id that could not be resolved.
    Other,
}
65
66#[derive(Clone)]
74pub struct Z3Sort {
75 pub id: SortId,
77 ctx: Rc<core::cell::RefCell<TermManager>>,
79}
80
81impl Z3Sort {
82 #[must_use]
84 pub fn new(ctx: &Z3Context, id: SortId) -> Self {
85 Self {
86 id,
87 ctx: ctx.tm_handle(),
88 }
89 }
90
91 fn from_handle(ctx: Rc<core::cell::RefCell<TermManager>>, id: SortId) -> Self {
93 Self { id, ctx }
94 }
95
96 #[must_use]
98 pub fn kind(&self) -> Z3SortKind {
99 let tm = self.ctx.borrow();
100 match tm.sorts.get(self.id).map(|s| &s.kind) {
101 Some(SortKind::Bool) => Z3SortKind::Bool,
102 Some(SortKind::Int) => Z3SortKind::Int,
103 Some(SortKind::Real) => Z3SortKind::Real,
104 Some(SortKind::BitVec(_)) => Z3SortKind::BitVec,
105 Some(SortKind::Array { .. }) => Z3SortKind::Array,
106 Some(SortKind::Datatype(_)) => Z3SortKind::Datatype,
107 Some(SortKind::Uninterpreted(_)) => Z3SortKind::Uninterpreted,
108 Some(
109 SortKind::String
110 | SortKind::FloatingPoint { .. }
111 | SortKind::Parameter(_)
112 | SortKind::Parametric { .. },
113 )
114 | None => Z3SortKind::Other,
115 }
116 }
117
118 #[must_use]
122 pub fn bv_size(&self) -> Option<u32> {
123 let tm = self.ctx.borrow();
124 match tm.sorts.get(self.id).map(|s| &s.kind) {
125 Some(&SortKind::BitVec(width)) => Some(width),
126 _ => None,
127 }
128 }
129
130 #[must_use]
134 pub fn array_domain(&self) -> Option<Z3Sort> {
135 let domain = {
136 let tm = self.ctx.borrow();
137 match tm.sorts.get(self.id).map(|s| &s.kind) {
138 Some(&SortKind::Array { domain, .. }) => domain,
139 _ => return None,
140 }
141 };
142 Some(Z3Sort::from_handle(self.ctx.clone(), domain))
143 }
144
145 #[must_use]
149 pub fn array_range(&self) -> Option<Z3Sort> {
150 let range = {
151 let tm = self.ctx.borrow();
152 match tm.sorts.get(self.id).map(|s| &s.kind) {
153 Some(&SortKind::Array { range, .. }) => range,
154 _ => return None,
155 }
156 };
157 Some(Z3Sort::from_handle(self.ctx.clone(), range))
158 }
159
160 #[must_use]
166 pub fn name(&self) -> String {
167 let tm = self.ctx.borrow();
168 tm.sorts
169 .sort_name(self.id)
170 .unwrap_or_else(|| "Unknown".to_string())
171 }
172}
173
174impl core::fmt::Debug for Z3Sort {
175 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
176 f.debug_struct("Z3Sort")
177 .field("id", &self.id)
178 .field("kind", &self.kind())
179 .finish()
180 }
181}
182
183impl core::fmt::Display for Z3Sort {
184 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
185 f.write_str(&self.name())
186 }
187}
188
189impl Z3Context {
192 fn tm_handle(&self) -> Rc<core::cell::RefCell<TermManager>> {
196 self.tm.clone()
197 }
198
199 #[must_use]
205 pub fn sort_of_term(&self, term: TermId) -> Z3Sort {
206 let sort_id = {
207 let tm = self.tm.borrow();
208 tm.get(term).map_or(tm.sorts.bool_sort, |t| t.sort)
209 };
210 Z3Sort::from_handle(self.tm.clone(), sort_id)
211 }
212
213 #[must_use]
215 pub fn sort_of_bool(&self, b: &Bool) -> Z3Sort {
216 self.sort_of_term(b.id)
217 }
218
219 #[must_use]
221 pub fn sort_of_int(&self, x: &Int) -> Z3Sort {
222 self.sort_of_term(x.id)
223 }
224
225 #[must_use]
227 pub fn sort_of_real(&self, x: &Real) -> Z3Sort {
228 self.sort_of_term(x.id)
229 }
230
231 #[must_use]
233 pub fn sort_of_bv(&self, b: &BV) -> Z3Sort {
234 self.sort_of_term(b.id)
235 }
236
237 #[must_use]
239 pub fn wrap_sort(&self, id: SortId) -> Z3Sort {
240 Z3Sort::from_handle(self.tm.clone(), id)
241 }
242}
243
244impl Z3Context {
247 #[must_use]
269 pub fn substitute(&self, expr: TermId, subst: &[(TermId, TermId)]) -> TermId {
270 if subst.is_empty() {
271 return expr;
272 }
273 let map: FxHashMap<TermId, TermId> = subst.iter().copied().collect();
274 let mut cache: FxHashMap<TermId, TermId> = FxHashMap::default();
275 let mut tm = self.tm.borrow_mut();
276 subst_rebuild(&mut tm, expr, &map, &mut cache)
277 }
278}
279
280fn subst_rebuild(
285 tm: &mut TermManager,
286 id: TermId,
287 map: &FxHashMap<TermId, TermId>,
288 cache: &mut FxHashMap<TermId, TermId>,
289) -> TermId {
290 if let Some(&to) = map.get(&id) {
292 return to;
293 }
294 if let Some(&done) = cache.get(&id) {
295 return done;
296 }
297
298 let kind = match tm.get(id).map(|t| t.kind.clone()) {
299 Some(k) => k,
300 None => return id,
301 };
302
303 macro_rules! rec {
306 ($child:expr) => {
307 subst_rebuild(tm, $child, map, cache)
308 };
309 }
310
311 let result = match kind {
312 TermKind::True
316 | TermKind::False
317 | TermKind::IntConst(_)
318 | TermKind::RealConst(_)
319 | TermKind::BitVecConst { .. }
320 | TermKind::StringLit(_)
321 | TermKind::Var(_) => id,
322
323 TermKind::Not(a) => {
325 let na = rec!(a);
326 if na == a { id } else { tm.mk_not(na) }
327 }
328 TermKind::And(args) => rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_and(a)),
329 TermKind::Or(args) => rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_or(a)),
330 TermKind::Xor(a, b) => {
331 let (na, nb) = (rec!(a), rec!(b));
332 if na == a && nb == b {
333 id
334 } else {
335 tm.mk_xor(na, nb)
336 }
337 }
338 TermKind::Implies(a, b) => {
339 let (na, nb) = (rec!(a), rec!(b));
340 if na == a && nb == b {
341 id
342 } else {
343 tm.mk_implies(na, nb)
344 }
345 }
346 TermKind::Ite(c, t, e) => {
347 let (nc, nt, ne) = (rec!(c), rec!(t), rec!(e));
348 if nc == c && nt == t && ne == e {
349 id
350 } else {
351 tm.mk_ite(nc, nt, ne)
352 }
353 }
354
355 TermKind::Eq(a, b) => {
357 let (na, nb) = (rec!(a), rec!(b));
358 if na == a && nb == b {
359 id
360 } else {
361 tm.mk_eq(na, nb)
362 }
363 }
364 TermKind::Distinct(args) => {
365 rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_distinct(a))
366 }
367
368 TermKind::Neg(a) => {
370 let na = rec!(a);
371 if na == a { id } else { tm.mk_neg(na) }
372 }
373 TermKind::Add(args) => rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_add(a)),
374 TermKind::Mul(args) => rebuild_nary(tm, id, &args, map, cache, |tm, a| tm.mk_mul(a)),
375 TermKind::Sub(a, b) => {
376 let (na, nb) = (rec!(a), rec!(b));
377 if na == a && nb == b {
378 id
379 } else {
380 tm.mk_sub(na, nb)
381 }
382 }
383 TermKind::Div(a, b) => {
384 let (na, nb) = (rec!(a), rec!(b));
385 if na == a && nb == b {
386 id
387 } else {
388 tm.mk_div(na, nb)
389 }
390 }
391 TermKind::Mod(a, b) => {
392 let (na, nb) = (rec!(a), rec!(b));
393 if na == a && nb == b {
394 id
395 } else {
396 tm.mk_mod(na, nb)
397 }
398 }
399 TermKind::Lt(a, b) => {
400 let (na, nb) = (rec!(a), rec!(b));
401 if na == a && nb == b {
402 id
403 } else {
404 tm.mk_lt(na, nb)
405 }
406 }
407 TermKind::Le(a, b) => {
408 let (na, nb) = (rec!(a), rec!(b));
409 if na == a && nb == b {
410 id
411 } else {
412 tm.mk_le(na, nb)
413 }
414 }
415 TermKind::Gt(a, b) => {
416 let (na, nb) = (rec!(a), rec!(b));
417 if na == a && nb == b {
418 id
419 } else {
420 tm.mk_gt(na, nb)
421 }
422 }
423 TermKind::Ge(a, b) => {
424 let (na, nb) = (rec!(a), rec!(b));
425 if na == a && nb == b {
426 id
427 } else {
428 tm.mk_ge(na, nb)
429 }
430 }
431
432 TermKind::BvConcat(a, b) => {
434 let (na, nb) = (rec!(a), rec!(b));
435 if na == a && nb == b {
436 id
437 } else {
438 tm.mk_bv_concat(na, nb)
439 }
440 }
441 TermKind::BvExtract { high, low, arg } => {
442 let na = rec!(arg);
443 if na == arg {
444 id
445 } else {
446 tm.mk_bv_extract(high, low, na)
447 }
448 }
449 TermKind::BvNot(a) => {
450 let na = rec!(a);
451 if na == a { id } else { tm.mk_bv_not(na) }
452 }
453 TermKind::BvAnd(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_and),
456 TermKind::BvOr(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_or),
457 TermKind::BvXor(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_xor),
458 TermKind::BvAdd(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_add),
459 TermKind::BvSub(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_sub),
460 TermKind::BvMul(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_mul),
461 TermKind::BvUdiv(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_udiv),
462 TermKind::BvSdiv(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_sdiv),
463 TermKind::BvUrem(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_urem),
464 TermKind::BvSrem(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_srem),
465 TermKind::BvShl(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_shl),
466 TermKind::BvLshr(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_lshr),
467 TermKind::BvAshr(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_ashr),
468 TermKind::BvUlt(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_ult),
469 TermKind::BvUle(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_ule),
470 TermKind::BvSlt(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_slt),
471 TermKind::BvSle(a, b) => rebuild_bin(tm, id, a, b, map, cache, TermManager::mk_bv_sle),
472
473 TermKind::Select(arr, idx) => {
475 let (na, ni) = (rec!(arr), rec!(idx));
476 if na == arr && ni == idx {
477 id
478 } else {
479 tm.mk_select(na, ni)
480 }
481 }
482 TermKind::Store(arr, idx, val) => {
483 let (na, ni, nv) = (rec!(arr), rec!(idx), rec!(val));
484 if na == arr && ni == idx && nv == val {
485 id
486 } else {
487 tm.mk_store(na, ni, nv)
488 }
489 }
490
491 TermKind::Apply { func, args } => {
493 let new_args: smallvec::SmallVec<[TermId; 4]> = args.iter().map(|&a| rec!(a)).collect();
494 if new_args.iter().zip(args.iter()).all(|(a, b)| a == b) {
495 id
496 } else {
497 let func_name = tm.resolve_str(func).to_string();
498 let sort = tm.get(id).map_or(tm.sorts.bool_sort, |t| t.sort);
499 tm.mk_apply(&func_name, new_args, sort)
500 }
501 }
502
503 _ => id,
508 };
509
510 cache.insert(id, result);
511 result
512}
513
514fn rebuild_nary<F>(
521 tm: &mut TermManager,
522 id: TermId,
523 args: &[TermId],
524 map: &FxHashMap<TermId, TermId>,
525 cache: &mut FxHashMap<TermId, TermId>,
526 build: F,
527) -> TermId
528where
529 F: FnOnce(&mut TermManager, smallvec::SmallVec<[TermId; 4]>) -> TermId,
530{
531 let new_args: smallvec::SmallVec<[TermId; 4]> = args
532 .iter()
533 .map(|&a| subst_rebuild(tm, a, map, cache))
534 .collect();
535 if new_args.iter().zip(args.iter()).all(|(a, b)| a == b) {
536 id
537 } else {
538 build(tm, new_args)
539 }
540}
541
542fn rebuild_bin<F>(
545 tm: &mut TermManager,
546 id: TermId,
547 a: TermId,
548 b: TermId,
549 map: &FxHashMap<TermId, TermId>,
550 cache: &mut FxHashMap<TermId, TermId>,
551 build: F,
552) -> TermId
553where
554 F: FnOnce(&mut TermManager, TermId, TermId) -> TermId,
555{
556 let na = subst_rebuild(tm, a, map, cache);
557 let nb = subst_rebuild(tm, b, map, cache);
558 if na == a && nb == b {
559 id
560 } else {
561 build(tm, na, nb)
562 }
563}
564
565#[derive(Debug, Clone)]
575pub struct Z3Pattern {
576 pub terms: Vec<TermId>,
578}
579
580impl Z3Pattern {
581 #[must_use]
583 pub fn len(&self) -> usize {
584 self.terms.len()
585 }
586
587 #[must_use]
589 pub fn is_empty(&self) -> bool {
590 self.terms.is_empty()
591 }
592}
593
594impl Z3Context {
595 #[must_use]
602 pub fn mk_pattern(&self, terms: &[TermId]) -> Z3Pattern {
603 Z3Pattern {
604 terms: terms.to_vec(),
605 }
606 }
607
608 #[must_use]
618 pub fn forall_with_patterns(
619 &self,
620 bound: &[(&str, SortId)],
621 patterns: &[Z3Pattern],
622 body: &Bool,
623 ) -> Bool {
624 let vars: Vec<(&str, SortId)> = bound.to_vec();
625 let pats: Vec<Vec<TermId>> = patterns.iter().map(|p| p.terms.clone()).collect();
626 let id = self
627 .tm
628 .borrow_mut()
629 .mk_forall_with_patterns(vars, body.id, pats);
630 Bool::from_id(id)
631 }
632
633 #[must_use]
638 pub fn exists_with_patterns(
639 &self,
640 bound: &[(&str, SortId)],
641 patterns: &[Z3Pattern],
642 body: &Bool,
643 ) -> Bool {
644 let vars: Vec<(&str, SortId)> = bound.to_vec();
645 let pats: Vec<Vec<TermId>> = patterns.iter().map(|p| p.terms.clone()).collect();
646 let id = self
647 .tm
648 .borrow_mut()
649 .mk_exists_with_patterns(vars, body.id, pats);
650 Bool::from_id(id)
651 }
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use crate::z3_compat::Z3Config;

    /// Fresh context with default configuration.
    fn ctx() -> Z3Context {
        Z3Context::new(&Z3Config::new())
    }

    #[test]
    fn unit_sort_kinds() {
        let c = ctx();
        assert_eq!(c.wrap_sort(c.bool_sort()).kind(), Z3SortKind::Bool);
        assert_eq!(c.wrap_sort(c.int_sort()).kind(), Z3SortKind::Int);
        assert_eq!(c.wrap_sort(c.real_sort()).kind(), Z3SortKind::Real);
        assert_eq!(c.wrap_sort(c.bv_sort(8)).kind(), Z3SortKind::BitVec);
    }

    #[test]
    fn unit_sort_of_term_and_new() {
        let c = ctx();
        let x = Int::new_const(&c, "x");
        assert_eq!(c.sort_of_int(&x).kind(), Z3SortKind::Int);
        assert_eq!(c.sort_of_term(x.id).kind(), Z3SortKind::Int);
        assert_eq!(Z3Sort::new(&c, c.bool_sort()).kind(), Z3SortKind::Bool);
    }

    #[test]
    fn unit_bv_size_and_array() {
        let c = ctx();
        assert_eq!(c.wrap_sort(c.bv_sort(16)).bv_size(), Some(16));
        assert_eq!(c.wrap_sort(c.bool_sort()).bv_size(), None);

        let arr = c.array_sort(c.int_sort(), c.bool_sort());
        let s = c.wrap_sort(arr);
        assert_eq!(s.kind(), Z3SortKind::Array);
        assert_eq!(s.array_domain().map(|d| d.kind()), Some(Z3SortKind::Int));
        assert_eq!(s.array_range().map(|r| r.kind()), Some(Z3SortKind::Bool));
        // Non-array sorts have no components.
        assert!(c.wrap_sort(c.int_sort()).array_domain().is_none());
        assert!(c.wrap_sort(c.int_sort()).array_range().is_none());
    }

    #[test]
    fn unit_substitute_identity() {
        let c = ctx();
        let x = Int::new_const(&c, "x");
        let y = Int::new_const(&c, "y");
        let sum = Int::add(&c, &[x.clone(), y.clone()]);
        assert_eq!(c.substitute(sum.id, &[]), sum.id);
    }

    #[test]
    fn unit_substitute_direct_hit_and_rebuild() {
        let c = ctx();
        let x = Int::new_const(&c, "x");
        let y = Int::new_const(&c, "y");
        // A term that is itself a key maps straight to its replacement.
        assert_eq!(c.substitute(x.id, &[(x.id, y.id)]), y.id);
        // Replacing a child forces the parent to be rebuilt into a different term.
        let sum = Int::add(&c, &[x.clone(), y.clone()]);
        assert_ne!(c.substitute(sum.id, &[(x.id, y.id)]), sum.id);
    }

    #[test]
    fn unit_substitute_absent_key_is_noop() {
        let c = ctx();
        let x = Int::new_const(&c, "x");
        let y = Int::new_const(&c, "y");
        let z = Int::new_const(&c, "z");
        let sum = Int::add(&c, &[x.clone(), y.clone()]);
        // `z` does not occur in `sum`, so the original id must come back unchanged.
        assert_eq!(c.substitute(sum.id, &[(z.id, y.id)]), sum.id);
    }

    #[test]
    fn unit_pattern_basic() {
        let c = ctx();
        let p = c.mk_pattern(&[]);
        assert!(p.is_empty());
        assert_eq!(p.len(), 0);
    }
}