prune_lang/interp/solver/
smtlib.rs1use super::common::*;
2use super::*;
3
4use easy_smt::{Context, ContextBuilder, SExpr};
5
/// Selects which external SMT solver executable an `SmtLibSolver` drives.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SolverBackend {
    /// Z3, started as the `z3` executable.
    Z3,
    /// cvc5, started as the `cvc5` executable.
    CVC5,
}
11
12pub struct SmtLibSolver {
13 ctx: Context,
14}
15
16impl SmtLibSolver {
17 pub fn new(backend: SolverBackend) -> Self {
18 let mut ctx_bld = ContextBuilder::new();
19 match backend {
20 SolverBackend::Z3 => {
21 ctx_bld.solver("z3").solver_args(["-smt2", "-in", "-v:0"]);
22 }
23 SolverBackend::CVC5 => {
24 ctx_bld
25 .solver("cvc5")
26 .solver_args(["--quiet", "--lang=smt2", "--incremental"]);
27 }
28 }
29
30 let mut ctx = ctx_bld.build().unwrap();
32 ctx.set_logic("QF_NIA").unwrap();
33 match backend {
34 SolverBackend::Z3 => {
35 ctx.set_option(":timeout", ctx.numeral(1000)).unwrap();
36 }
37 SolverBackend::CVC5 => {
38 ctx.set_option(":tlimit-per", ctx.numeral(1000)).unwrap();
39 }
40 }
41
42 ctx.push().unwrap();
44
45 SmtLibSolver { ctx }
46 }
47
48 pub fn check_sat(
49 &mut self,
50 prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
51 ) -> Option<HashMap<IdentCtx, LitVal>> {
52 if prims.is_empty() {
54 return Some(HashMap::new());
55 }
56
57 self.ctx.pop().unwrap();
59 self.ctx.push().unwrap();
60
61 let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
62 let sexp_map = self.solve_constraints(prims, &ty_map);
63
64 let check_res = self.ctx.check().unwrap();
65 if check_res == easy_smt::Response::Sat {
66 let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
67 let res = vars
68 .iter()
69 .cloned()
70 .zip(
71 self.ctx
72 .get_value(vars.iter().map(|var| sexp_map[var]).collect())
73 .unwrap()
74 .iter()
75 .map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
76 )
77 .collect();
78
79 Some(res)
80 } else {
81 None
82 }
83 }
84
85 fn solve_constraints(
86 &mut self,
87 prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
88 ty_map: &HashMap<IdentCtx, LitType>,
89 ) -> HashMap<IdentCtx, SExpr> {
90 let sexp_map: HashMap<IdentCtx, SExpr> = ty_map
91 .iter()
92 .map(|(var, typ)| {
93 let sort = match typ {
94 LitType::TyInt => self.ctx.int_sort(),
95 LitType::TyFloat => self.ctx.real_sort(),
96 LitType::TyBool => self.ctx.bool_sort(),
97 LitType::TyChar => todo!(),
98 };
99 let sexp = self.ctx.declare_const(format!("{:?}", var), sort).unwrap();
100 (*var, sexp)
101 })
102 .collect();
103
104 for (prim, args) in prims.iter() {
105 let args: Vec<SExpr> = args
106 .iter()
107 .map(|arg| self.atom_to_sexp(arg, &sexp_map))
108 .collect();
109
110 match (prim, &args[..]) {
111 (
112 Prim::IAdd | Prim::ISub | Prim::IMul | Prim::IDiv | Prim::IRem,
113 &[arg1, arg2, arg3],
114 ) => {
115 let res = match prim {
116 Prim::IAdd => self.ctx.plus(arg1, arg2),
117 Prim::ISub => self.ctx.sub(arg1, arg2),
118 Prim::IMul => self.ctx.times(arg1, arg2),
119 Prim::IDiv => self.ctx.div(arg1, arg2),
120 Prim::IRem => self.ctx.rem(arg1, arg2),
121 _ => unreachable!(),
122 };
123 self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
124 }
125 (Prim::INeg, &[arg1, arg2]) => {
126 let res = self.ctx.negate(arg1);
127 self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
128 }
129 (Prim::ICmp(cmp), &[arg1, arg2, arg3]) => {
130 let res = match cmp {
131 Compare::Lt => self.ctx.lt(arg1, arg2),
132 Compare::Le => self.ctx.lte(arg1, arg2),
133 Compare::Eq => self.ctx.eq(arg1, arg2),
134 Compare::Ge => self.ctx.gte(arg1, arg2),
135 Compare::Gt => self.ctx.gt(arg1, arg2),
136 Compare::Ne => self.ctx.not(self.ctx.eq(arg1, arg2)),
137 };
138 self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
139 }
140 (Prim::BAnd | Prim::BOr, &[arg1, arg2, arg3]) => {
141 let res = match prim {
142 Prim::BAnd => self.ctx.and(arg1, arg2),
143 Prim::BOr => self.ctx.or(arg1, arg2),
144 _ => unreachable!(),
145 };
146 self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
147 }
148 (Prim::BNot, &[arg1, arg2]) => {
149 let res = self.ctx.not(arg1);
150 self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
151 }
152 _ => {
153 panic!("wrong arity of primitives!");
154 }
155 }
156 }
157
158 sexp_map
159 }
160
161 fn atom_to_sexp(&self, atom: &AtomVal<IdentCtx>, map: &HashMap<IdentCtx, SExpr>) -> SExpr {
162 match atom {
163 Term::Var(var) => map[var],
164 Term::Lit(LitVal::Int(x)) => self.ctx.numeral(*x),
165 Term::Lit(LitVal::Float(x)) => self.ctx.decimal(*x),
166 Term::Lit(LitVal::Bool(x)) => {
167 if *x {
168 self.ctx.true_()
169 } else {
170 self.ctx.false_()
171 }
172 }
173 Term::Lit(LitVal::Char(_x)) => todo!(),
174 Term::Cons(_cons, _flds) => unreachable!(),
175 }
176 }
177
178 fn sexp_to_lit_val(&self, sexpr: SExpr) -> Option<LitVal> {
179 if let Some(res) = self.ctx.get_i64(sexpr) {
180 return Some(LitVal::Int(res));
181 }
182 if let Some(res) = self.ctx.get_f64(sexpr) {
183 return Some(LitVal::Float(res));
184 }
185 if let Some(res) = self.ctx.get_atom(sexpr) {
186 match res {
187 "true" => {
188 return Some(LitVal::Bool(true));
189 }
190 "false" => {
191 return Some(LitVal::Bool(false));
192 }
193 _ => {
194 return None;
195 }
196 }
197 }
198
199 None
202 }
203}
204
205impl common::PrimSolver for SmtLibSolver {
206 fn check_sat(
207 &mut self,
208 prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
209 ) -> Option<HashMap<IdentCtx, LitVal>> {
210 if prims.is_empty() {
212 return Some(HashMap::new());
213 }
214
215 self.ctx.pop().unwrap();
217 self.ctx.push().unwrap();
218
219 let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
220 let sexp_map = self.solve_constraints(prims, &ty_map);
221
222 let check_res = self.ctx.check().unwrap();
223 if check_res == easy_smt::Response::Sat {
224 let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
225 let res = vars
226 .iter()
227 .cloned()
228 .zip(
229 self.ctx
230 .get_value(vars.iter().map(|var| sexp_map[var]).collect())
231 .unwrap()
232 .iter()
233 .map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
234 )
235 .collect();
236
237 Some(res)
238 } else {
239 None
240 }
241 }
242}