1use std::cell::RefCell;
37use std::rc::Rc;
38
39use num_bigint::BigInt;
40use num_rational::Rational64;
41
42use oxiz_core::ast::{TermId, TermManager};
43use oxiz_core::sort::SortId;
44
45use crate::Context;
46use crate::SolverResult;
47use crate::solver::SolverConfig;
48
49#[derive(Debug, Clone, Default)]
55pub struct Z3Config {
56 inner: SolverConfig,
57}
58
59impl Z3Config {
60 #[must_use]
62 pub fn new() -> Self {
63 Self {
64 inner: SolverConfig::default(),
65 }
66 }
67
68 pub fn set_proof(&mut self, enabled: bool) -> &mut Self {
70 self.inner.proof = enabled;
71 self
72 }
73
74 #[must_use]
76 pub fn as_solver_config(&self) -> &SolverConfig {
77 &self.inner
78 }
79}
80
81pub struct Z3Context {
91 pub(crate) tm: Rc<RefCell<TermManager>>,
93 pub(crate) config: SolverConfig,
95}
96
97impl Z3Context {
98 #[must_use]
100 pub fn new(cfg: &Z3Config) -> Self {
101 Self {
102 tm: Rc::new(RefCell::new(TermManager::new())),
103 config: cfg.inner.clone(),
104 }
105 }
106
107 #[must_use]
109 pub fn bool_sort(&self) -> SortId {
110 self.tm.borrow().sorts.bool_sort
111 }
112
113 #[must_use]
115 pub fn int_sort(&self) -> SortId {
116 self.tm.borrow().sorts.int_sort
117 }
118
119 #[must_use]
121 pub fn real_sort(&self) -> SortId {
122 self.tm.borrow().sorts.real_sort
123 }
124
125 #[must_use]
127 pub fn bv_sort(&self, width: u32) -> SortId {
128 self.tm.borrow_mut().sorts.bitvec(width)
129 }
130}
131
/// Outcome of a satisfiability check.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum SatResult {
    /// The assertions are satisfiable.
    Sat,
    /// The assertions are unsatisfiable.
    Unsat,
    /// The solver could not decide.
    Unknown,
}
144
145impl From<SolverResult> for SatResult {
146 fn from(r: SolverResult) -> Self {
147 match r {
148 SolverResult::Sat => SatResult::Sat,
149 SolverResult::Unsat => SatResult::Unsat,
150 SolverResult::Unknown => SatResult::Unknown,
151 }
152 }
153}
154
155impl std::fmt::Display for SatResult {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 match self {
158 SatResult::Sat => write!(f, "sat"),
159 SatResult::Unsat => write!(f, "unsat"),
160 SatResult::Unknown => write!(f, "unknown"),
161 }
162 }
163}
164
/// A satisfying assignment extracted from the solver.
pub struct Z3Model {
    /// `(name, sort, value)` triples, all pre-rendered as text, in solver order.
    entries: Vec<(String, String, String)>,
}
174
175impl Z3Model {
176 fn from_context_model(entries: Vec<(String, String, String)>) -> Self {
177 Self { entries }
178 }
179
180 #[must_use]
184 pub fn eval_const(&self, name: &str) -> Option<&str> {
185 self.entries
186 .iter()
187 .find(|(n, _, _)| n == name)
188 .map(|(_, _, v)| v.as_str())
189 }
190
191 #[must_use]
193 pub fn entries(&self) -> &[(String, String, String)] {
194 &self.entries
195 }
196}
197
198impl std::fmt::Display for Z3Model {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 writeln!(f, "(model")?;
201 for (name, sort, value) in &self.entries {
202 writeln!(f, " (define-fun {} () {} {})", name, sort, value)?;
203 }
204 write!(f, ")")
205 }
206}
207
208pub struct Z3Solver {
214 ctx: Context,
215}
216
217impl Z3Solver {
218 #[must_use]
223 pub fn new(z3ctx: &Z3Context) -> Self {
224 let mut ctx = Context::new();
225 if z3ctx.config.proof {
227 ctx.set_option("produce-proofs", "true");
228 }
229 Self { ctx }
233 }
234
235 pub fn assert(&mut self, t: &Bool) {
237 self.ctx.assert(t.id);
238 }
239
240 #[must_use]
242 pub fn check(&mut self) -> SatResult {
243 self.ctx.check_sat().into()
244 }
245
246 pub fn push(&mut self) {
248 self.ctx.push();
249 }
250
251 pub fn pop(&mut self) {
253 self.ctx.pop();
254 }
255
256 #[must_use]
258 pub fn get_model(&self) -> Option<Z3Model> {
259 self.ctx.get_model().map(Z3Model::from_context_model)
260 }
261
262 pub fn set_logic(&mut self, logic: &str) {
264 self.ctx.set_logic(logic);
265 }
266
267 #[must_use]
269 pub fn context(&self) -> &Context {
270 &self.ctx
271 }
272
273 pub fn context_mut(&mut self) -> &mut Context {
275 &mut self.ctx
276 }
277}
278
/// Calls a term-manager builder method on a context's shared term manager.
///
/// `build!(ctx, mk_foo, a, b)` calls `mk_foo(a, b)` on a mutable borrow of
/// `ctx.tm`. That borrow is a temporary, so it is released at the end of the
/// enclosing statement. Because the receiver is borrowed before the arguments are
/// evaluated, arguments must not themselves borrow `ctx.tm` (compute sorts first).
macro_rules! build {
    ($ctx:expr, $method:ident $(, $arg:expr)* ) => {
        (*$ctx.tm).borrow_mut().$method($($arg),*)
    };
}
289
290#[derive(Debug, Clone)]
294pub struct Bool {
295 pub id: TermId,
297}
298
299impl Bool {
300 #[must_use]
302 pub fn from_id(id: TermId) -> Self {
303 Self { id }
304 }
305
306 #[must_use]
308 pub fn new_const(ctx: &Z3Context, name: &str) -> Self {
309 let sort = ctx.bool_sort();
310 let id = build!(ctx, mk_var, name, sort);
311 Self { id }
312 }
313
314 #[must_use]
316 pub fn from_bool(ctx: &Z3Context, value: bool) -> Self {
317 let id = build!(ctx, mk_bool, value);
318 Self { id }
319 }
320
321 #[must_use]
323 pub fn and(ctx: &Z3Context, args: &[Bool]) -> Self {
324 let ids: Vec<TermId> = args.iter().map(|b| b.id).collect();
325 let id = build!(ctx, mk_and, ids);
326 Self { id }
327 }
328
329 #[must_use]
331 pub fn or(ctx: &Z3Context, args: &[Bool]) -> Self {
332 let ids: Vec<TermId> = args.iter().map(|b| b.id).collect();
333 let id = build!(ctx, mk_or, ids);
334 Self { id }
335 }
336
337 #[must_use]
339 pub fn not(ctx: &Z3Context, arg: &Bool) -> Self {
340 let id = build!(ctx, mk_not, arg.id);
341 Self { id }
342 }
343
344 #[must_use]
346 pub fn implies(ctx: &Z3Context, lhs: &Bool, rhs: &Bool) -> Self {
347 let id = build!(ctx, mk_implies, lhs.id, rhs.id);
348 Self { id }
349 }
350
351 #[must_use]
355 pub fn iff(ctx: &Z3Context, lhs: &Bool, rhs: &Bool) -> Self {
356 let id = build!(ctx, mk_eq, lhs.id, rhs.id);
358 Self { id }
359 }
360
361 #[must_use]
363 pub fn xor(ctx: &Z3Context, lhs: &Bool, rhs: &Bool) -> Self {
364 let id = build!(ctx, mk_xor, lhs.id, rhs.id);
365 Self { id }
366 }
367}
368
369impl From<Bool> for TermId {
370 fn from(b: Bool) -> Self {
371 b.id
372 }
373}
374
375#[derive(Debug, Clone)]
379pub struct Int {
380 pub id: TermId,
382}
383
384impl Int {
385 #[must_use]
387 pub fn from_id(id: TermId) -> Self {
388 Self { id }
389 }
390
391 #[must_use]
393 pub fn new_const(ctx: &Z3Context, name: &str) -> Self {
394 let sort = ctx.int_sort();
395 let id = build!(ctx, mk_var, name, sort);
396 Self { id }
397 }
398
399 #[must_use]
401 pub fn from_i64(ctx: &Z3Context, value: i64) -> Self {
402 let id = build!(ctx, mk_int, BigInt::from(value));
403 Self { id }
404 }
405
406 #[must_use]
408 pub fn add(ctx: &Z3Context, args: &[Int]) -> Self {
409 let ids: Vec<TermId> = args.iter().map(|x| x.id).collect();
410 let id = build!(ctx, mk_add, ids);
411 Self { id }
412 }
413
414 #[must_use]
416 pub fn sub(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Self {
417 let id = build!(ctx, mk_sub, lhs.id, rhs.id);
418 Self { id }
419 }
420
421 #[must_use]
423 pub fn mul(ctx: &Z3Context, args: &[Int]) -> Self {
424 let ids: Vec<TermId> = args.iter().map(|x| x.id).collect();
425 let id = build!(ctx, mk_mul, ids);
426 Self { id }
427 }
428
429 #[must_use]
431 pub fn neg(ctx: &Z3Context, arg: &Int) -> Self {
432 let id = build!(ctx, mk_neg, arg.id);
433 Self { id }
434 }
435
436 #[must_use]
438 pub fn div(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Self {
439 let id = build!(ctx, mk_div, lhs.id, rhs.id);
440 Self { id }
441 }
442
443 #[must_use]
445 pub fn modulo(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Self {
446 let id = build!(ctx, mk_mod, lhs.id, rhs.id);
447 Self { id }
448 }
449
450 #[must_use]
452 pub fn lt(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
453 let id = build!(ctx, mk_lt, lhs.id, rhs.id);
454 Bool { id }
455 }
456
457 #[must_use]
459 pub fn le(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
460 let id = build!(ctx, mk_le, lhs.id, rhs.id);
461 Bool { id }
462 }
463
464 #[must_use]
466 pub fn gt(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
467 let id = build!(ctx, mk_gt, lhs.id, rhs.id);
468 Bool { id }
469 }
470
471 #[must_use]
473 pub fn ge(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
474 let id = build!(ctx, mk_ge, lhs.id, rhs.id);
475 Bool { id }
476 }
477
478 #[must_use]
480 pub fn eq(ctx: &Z3Context, lhs: &Int, rhs: &Int) -> Bool {
481 let id = build!(ctx, mk_eq, lhs.id, rhs.id);
482 Bool { id }
483 }
484}
485
486impl From<Int> for TermId {
487 fn from(x: Int) -> Self {
488 x.id
489 }
490}
491
492#[derive(Debug, Clone)]
496pub struct Real {
497 pub id: TermId,
499}
500
501impl Real {
502 #[must_use]
504 pub fn from_id(id: TermId) -> Self {
505 Self { id }
506 }
507
508 #[must_use]
510 pub fn new_const(ctx: &Z3Context, name: &str) -> Self {
511 let sort = ctx.real_sort();
512 let id = build!(ctx, mk_var, name, sort);
513 Self { id }
514 }
515
516 #[must_use]
518 pub fn from_frac(ctx: &Z3Context, num: i64, den: i64) -> Self {
519 let id = build!(ctx, mk_real, Rational64::new(num, den));
520 Self { id }
521 }
522
523 #[must_use]
525 pub fn add(ctx: &Z3Context, args: &[Real]) -> Self {
526 let ids: Vec<TermId> = args.iter().map(|x| x.id).collect();
527 let id = build!(ctx, mk_add, ids);
528 Self { id }
529 }
530
531 #[must_use]
533 pub fn sub(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Self {
534 let id = build!(ctx, mk_sub, lhs.id, rhs.id);
535 Self { id }
536 }
537
538 #[must_use]
540 pub fn mul(ctx: &Z3Context, args: &[Real]) -> Self {
541 let ids: Vec<TermId> = args.iter().map(|x| x.id).collect();
542 let id = build!(ctx, mk_mul, ids);
543 Self { id }
544 }
545
546 #[must_use]
548 pub fn lt(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Bool {
549 let id = build!(ctx, mk_lt, lhs.id, rhs.id);
550 Bool { id }
551 }
552
553 #[must_use]
555 pub fn le(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Bool {
556 let id = build!(ctx, mk_le, lhs.id, rhs.id);
557 Bool { id }
558 }
559
560 #[must_use]
562 pub fn eq(ctx: &Z3Context, lhs: &Real, rhs: &Real) -> Bool {
563 let id = build!(ctx, mk_eq, lhs.id, rhs.id);
564 Bool { id }
565 }
566}
567
568impl From<Real> for TermId {
569 fn from(x: Real) -> Self {
570 x.id
571 }
572}
573
574#[derive(Debug, Clone)]
578pub struct BV {
579 pub id: TermId,
581 pub width: u32,
583}
584
585impl BV {
586 #[must_use]
588 pub fn from_id(id: TermId, width: u32) -> Self {
589 Self { id, width }
590 }
591
592 #[must_use]
594 pub fn new_const(ctx: &Z3Context, name: &str, width: u32) -> Self {
595 let sort = ctx.bv_sort(width);
596 let id = build!(ctx, mk_var, name, sort);
597 Self { id, width }
598 }
599
600 #[must_use]
602 pub fn from_u64(ctx: &Z3Context, value: u64, width: u32) -> Self {
603 let id = build!(ctx, mk_bitvec, BigInt::from(value), width);
604 Self { id, width }
605 }
606
607 #[must_use]
609 pub fn bvadd(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
610 let width = lhs.width;
611 let id = build!(ctx, mk_bv_add, lhs.id, rhs.id);
612 Self { id, width }
613 }
614
615 #[must_use]
617 pub fn bvsub(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
618 let width = lhs.width;
619 let id = build!(ctx, mk_bv_sub, lhs.id, rhs.id);
620 Self { id, width }
621 }
622
623 #[must_use]
625 pub fn bvmul(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
626 let width = lhs.width;
627 let id = build!(ctx, mk_bv_mul, lhs.id, rhs.id);
628 Self { id, width }
629 }
630
631 #[must_use]
633 pub fn bvand(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
634 let width = lhs.width;
635 let id = build!(ctx, mk_bv_and, lhs.id, rhs.id);
636 Self { id, width }
637 }
638
639 #[must_use]
641 pub fn bvor(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
642 let width = lhs.width;
643 let id = build!(ctx, mk_bv_or, lhs.id, rhs.id);
644 Self { id, width }
645 }
646
647 #[must_use]
649 pub fn bvxor(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
650 let width = lhs.width;
651 let id = build!(ctx, mk_bv_xor, lhs.id, rhs.id);
652 Self { id, width }
653 }
654
655 #[must_use]
657 pub fn bvnot(ctx: &Z3Context, arg: &BV) -> Self {
658 let width = arg.width;
659 let id = build!(ctx, mk_bv_not, arg.id);
660 Self { id, width }
661 }
662
663 #[must_use]
665 pub fn bvneg(ctx: &Z3Context, arg: &BV) -> Self {
666 let width = arg.width;
667 let id = build!(ctx, mk_bv_neg, arg.id);
668 Self { id, width }
669 }
670
671 #[must_use]
673 pub fn bvult(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
674 let id = build!(ctx, mk_bv_ult, lhs.id, rhs.id);
675 Bool { id }
676 }
677
678 #[must_use]
680 pub fn bvslt(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
681 let id = build!(ctx, mk_bv_slt, lhs.id, rhs.id);
682 Bool { id }
683 }
684
685 #[must_use]
687 pub fn bvule(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
688 let id = build!(ctx, mk_bv_ule, lhs.id, rhs.id);
689 Bool { id }
690 }
691
692 #[must_use]
694 pub fn bvsle(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
695 let id = build!(ctx, mk_bv_sle, lhs.id, rhs.id);
696 Bool { id }
697 }
698
699 #[must_use]
701 pub fn eq(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Bool {
702 let id = build!(ctx, mk_eq, lhs.id, rhs.id);
703 Bool { id }
704 }
705
706 #[must_use]
708 pub fn bvshl(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
709 let width = lhs.width;
710 let id = build!(ctx, mk_bv_shl, lhs.id, rhs.id);
711 Self { id, width }
712 }
713
714 #[must_use]
716 pub fn bvlshr(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
717 let width = lhs.width;
718 let id = build!(ctx, mk_bv_lshr, lhs.id, rhs.id);
719 Self { id, width }
720 }
721
722 #[must_use]
724 pub fn bvashr(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
725 let width = lhs.width;
726 let id = build!(ctx, mk_bv_ashr, lhs.id, rhs.id);
727 Self { id, width }
728 }
729
730 #[must_use]
732 pub fn bvudiv(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
733 let width = lhs.width;
734 let id = build!(ctx, mk_bv_udiv, lhs.id, rhs.id);
735 Self { id, width }
736 }
737
738 #[must_use]
740 pub fn bvurem(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
741 let width = lhs.width;
742 let id = build!(ctx, mk_bv_urem, lhs.id, rhs.id);
743 Self { id, width }
744 }
745
746 #[must_use]
750 pub fn concat(ctx: &Z3Context, lhs: &BV, rhs: &BV) -> Self {
751 let width = lhs.width + rhs.width;
752 let id = build!(ctx, mk_bv_concat, lhs.id, rhs.id);
753 Self { id, width }
754 }
755
756 #[must_use]
765 pub fn extract(ctx: &Z3Context, high: u32, low: u32, arg: &BV) -> Self {
766 debug_assert!(
767 high >= low,
768 "extract: high ({}) must be >= low ({})",
769 high,
770 low
771 );
772 let width = high - low + 1;
773 let id = build!(ctx, mk_bv_extract, high, low, arg.id);
774 Self { id, width }
775 }
776}
777
778impl From<BV> for TermId {
779 fn from(b: BV) -> Self {
780 b.id
781 }
782}
783
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bool_and_sat() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);

        // Created only to exercise constant declaration; the `_` prefix keeps the
        // test warning-free under `-D warnings`.
        let _p = Bool::new_const(&ctx, "p");
        let _q = Bool::new_const(&ctx, "q");

        let true_p = Bool::from_bool(&ctx, true);
        // NOTE(review): `true_p` was built in `ctx`'s term manager, but the solver
        // owns a separate `Context`; asserting its id here relies on both term
        // managers numbering terms identically — confirm this is intended.
        solver.ctx.assert(true_p.id);

        assert_eq!(solver.check(), SatResult::Sat);
    }

    #[test]
    fn test_bool_and_unsat() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);

        let t = solver.ctx.terms.mk_true();
        let f = solver.ctx.terms.mk_false();
        solver.ctx.assert(t);
        solver.ctx.assert(f);

        assert_eq!(solver.check(), SatResult::Unsat);
    }

    #[test]
    fn test_sat_result_from_solver_result() {
        assert_eq!(SatResult::from(SolverResult::Sat), SatResult::Sat);
        assert_eq!(SatResult::from(SolverResult::Unsat), SatResult::Unsat);
        assert_eq!(SatResult::from(SolverResult::Unknown), SatResult::Unknown);
    }

    #[test]
    fn test_sat_result_display() {
        assert_eq!(SatResult::Sat.to_string(), "sat");
        assert_eq!(SatResult::Unsat.to_string(), "unsat");
        assert_eq!(SatResult::Unknown.to_string(), "unknown");
    }

    #[test]
    fn test_model_eval_and_display() {
        let model = Z3Model::from_context_model(vec![
            ("x".to_string(), "Int".to_string(), "5".to_string()),
            ("b".to_string(), "Bool".to_string(), "true".to_string()),
        ]);
        assert_eq!(model.eval_const("x"), Some("5"));
        assert_eq!(model.eval_const("b"), Some("true"));
        assert_eq!(model.eval_const("missing"), None);
        assert_eq!(model.entries().len(), 2);

        let text = model.to_string();
        assert!(text.starts_with("(model\n"));
        assert!(text.contains("(define-fun x () Int 5)"));
        assert!(text.ends_with(')'));
    }

    #[test]
    fn test_bool_api_term_building() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);

        let p = Bool::new_const(&ctx, "p");
        let q = Bool::new_const(&ctx, "q");
        let _conj = Bool::and(&ctx, &[p.clone(), q.clone()]);
        let _disj = Bool::or(&ctx, &[p.clone(), q.clone()]);
        let _neg = Bool::not(&ctx, &p);
        let _impl = Bool::implies(&ctx, &p, &q);
        let _iff = Bool::iff(&ctx, &p, &q);
    }

    #[test]
    fn test_int_api_term_building() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);

        let x = Int::new_const(&ctx, "x");
        let y = Int::new_const(&ctx, "y");
        let five = Int::from_i64(&ctx, 5);

        let _sum = Int::add(&ctx, &[x.clone(), y.clone()]);
        let _diff = Int::sub(&ctx, &x, &y);
        let _prod = Int::mul(&ctx, &[x.clone(), five.clone()]);
        let _lt = Int::lt(&ctx, &x, &five);
        let _le = Int::le(&ctx, &x, &y);
        let _eq = Int::eq(&ctx, &x, &y);
    }

    #[test]
    fn test_bv_api_term_building() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);

        let a = BV::new_const(&ctx, "a", 32);
        let b = BV::new_const(&ctx, "b", 32);
        let lit = BV::from_u64(&ctx, 42, 32);

        let add = BV::bvadd(&ctx, &a, &b);
        assert_eq!(add.width, 32);
        let _and = BV::bvand(&ctx, &a, &b);
        let _ult = BV::bvult(&ctx, &a, &lit);
        let concat = BV::concat(&ctx, &a, &b);
        assert_eq!(concat.width, 64);
        let extr = BV::extract(&ctx, 7, 0, &a);
        assert_eq!(extr.width, 8);
    }

    #[test]
    fn test_push_pop() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);

        let t = solver.ctx.terms.mk_true();
        solver.ctx.assert(t);

        solver.push();
        let f = solver.ctx.terms.mk_false();
        solver.ctx.assert(f);
        assert_eq!(solver.check(), SatResult::Unsat);

        solver.pop();
        assert_eq!(solver.check(), SatResult::Sat);
    }

    #[test]
    fn test_int_solver_sat() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);
        solver.set_logic("QF_LIA");

        let x = solver
            .ctx
            .terms
            .mk_var("x", solver.ctx.terms.sorts.int_sort);
        let five = solver.ctx.terms.mk_int(BigInt::from(5));
        let ten = solver.ctx.terms.mk_int(BigInt::from(10));
        let c1 = solver.ctx.terms.mk_ge(x, five);
        let c2 = solver.ctx.terms.mk_le(x, ten);
        solver.ctx.assert(c1);
        solver.ctx.assert(c2);

        assert_eq!(solver.check(), SatResult::Sat);
    }

    #[test]
    fn test_get_model() {
        let cfg = Z3Config::new();
        let ctx = Z3Context::new(&cfg);
        let mut solver = Z3Solver::new(&ctx);

        let bool_sort = solver.ctx.terms.sorts.bool_sort;
        let _p = solver.ctx.declare_const("p", bool_sort);
        let t = solver.ctx.terms.mk_true();
        solver.ctx.assert(t);

        assert_eq!(solver.check(), SatResult::Sat);
        let model = solver.get_model();
        assert!(model.is_some(), "Expected a model after SAT");
    }
}