prune_lang/interp/solver/
smtlib.rs1use super::common::*;
2use super::*;
3
4use easy_smt::{Context, ContextBuilder, SExpr};
5
6pub struct SmtLibSolver {
7 ctx: Context,
8}
9
10impl SmtLibSolver {
11 pub fn new(backend: common::SolverBackend) -> Self {
12 let mut ctx_bld = ContextBuilder::new();
13 match backend {
14 common::SolverBackend::Z3 => {
15 ctx_bld.solver("z3").solver_args(["-smt2", "-in", "-v:0"]);
16 }
17 common::SolverBackend::CVC5 => {
18 ctx_bld
19 .solver("cvc5")
20 .solver_args(["--quiet", "--lang=smt2", "--incremental"]);
21 }
22 }
23
24 let mut ctx = ctx_bld.build().unwrap();
26 ctx.set_logic("QF_NIA").unwrap();
27 match backend {
28 common::SolverBackend::Z3 => {
29 ctx.set_option(":timeout", ctx.numeral(1000)).unwrap();
30 }
31 common::SolverBackend::CVC5 => {
32 ctx.set_option(":tlimit-per", ctx.numeral(1000)).unwrap();
33 }
34 }
35
36 ctx.push().unwrap();
38
39 SmtLibSolver { ctx }
40 }
41
42 pub fn check_sat(
43 &mut self,
44 prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
45 ) -> Option<HashMap<IdentCtx, LitVal>> {
46 if prims.is_empty() {
48 return Some(HashMap::new());
49 }
50
51 self.ctx.pop().unwrap();
53 self.ctx.push().unwrap();
54
55 let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
56 let sexp_map = self.solve_constraints(prims, &ty_map);
57
58 let check_res = self.ctx.check().unwrap();
59 if check_res == easy_smt::Response::Sat {
60 let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
61 let res = vars
62 .iter()
63 .cloned()
64 .zip(
65 self.ctx
66 .get_value(vars.iter().map(|var| sexp_map[var]).collect())
67 .unwrap()
68 .iter()
69 .map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
70 )
71 .collect();
72
73 Some(res)
74 } else {
75 None
76 }
77 }
78
79 fn solve_constraints(
80 &mut self,
81 prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
82 ty_map: &HashMap<IdentCtx, LitType>,
83 ) -> HashMap<IdentCtx, SExpr> {
84 let sexp_map: HashMap<IdentCtx, SExpr> = ty_map
85 .iter()
86 .map(|(var, typ)| {
87 let sort = match typ {
88 LitType::TyInt => self.ctx.int_sort(),
89 LitType::TyFloat => self.ctx.real_sort(),
90 LitType::TyBool => self.ctx.bool_sort(),
91 LitType::TyChar => todo!(),
92 };
93 let sexp = self.ctx.declare_const(format!("{:?}", var), sort).unwrap();
94 (*var, sexp)
95 })
96 .collect();
97
98 for (prim, args) in prims.iter() {
99 let args: Vec<SExpr> = args
100 .iter()
101 .map(|arg| self.atom_to_sexp(arg, &sexp_map))
102 .collect();
103
104 match (prim, &args[..]) {
105 (
106 Prim::IAdd | Prim::ISub | Prim::IMul | Prim::IDiv | Prim::IRem,
107 &[arg1, arg2, arg3],
108 ) => {
109 let res = match prim {
110 Prim::IAdd => self.ctx.plus(arg1, arg2),
111 Prim::ISub => self.ctx.sub(arg1, arg2),
112 Prim::IMul => self.ctx.times(arg1, arg2),
113 Prim::IDiv => self.ctx.div(arg1, arg2),
114 Prim::IRem => self.ctx.rem(arg1, arg2),
115 _ => unreachable!(),
116 };
117 self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
118 }
119 (Prim::INeg, &[arg1, arg2]) => {
120 let res = self.ctx.negate(arg1);
121 self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
122 }
123 (Prim::ICmp(cmp), &[arg1, arg2, arg3]) => {
124 let res = match cmp {
125 Compare::Lt => self.ctx.lt(arg1, arg2),
126 Compare::Le => self.ctx.lte(arg1, arg2),
127 Compare::Eq => self.ctx.eq(arg1, arg2),
128 Compare::Ge => self.ctx.gte(arg1, arg2),
129 Compare::Gt => self.ctx.gt(arg1, arg2),
130 Compare::Ne => self.ctx.not(self.ctx.eq(arg1, arg2)),
131 };
132 self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
133 }
134 (Prim::BAnd | Prim::BOr, &[arg1, arg2, arg3]) => {
135 let res = match prim {
136 Prim::BAnd => self.ctx.and(arg1, arg2),
137 Prim::BOr => self.ctx.or(arg1, arg2),
138 _ => unreachable!(),
139 };
140 self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
141 }
142 (Prim::BNot, &[arg1, arg2]) => {
143 let res = self.ctx.not(arg1);
144 self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
145 }
146 _ => {
147 panic!("wrong arity of primitives!");
148 }
149 }
150 }
151
152 sexp_map
153 }
154
155 fn atom_to_sexp(&self, atom: &AtomVal<IdentCtx>, map: &HashMap<IdentCtx, SExpr>) -> SExpr {
156 match atom {
157 Term::Var(var) => map[var],
158 Term::Lit(LitVal::Int(x)) => self.ctx.numeral(*x),
159 Term::Lit(LitVal::Float(x)) => self.ctx.decimal(*x),
160 Term::Lit(LitVal::Bool(x)) => {
161 if *x {
162 self.ctx.true_()
163 } else {
164 self.ctx.false_()
165 }
166 }
167 Term::Lit(LitVal::Char(_x)) => todo!(),
168 Term::Cons(_cons, _flds) => unreachable!(),
169 }
170 }
171
172 fn sexp_to_lit_val(&self, sexpr: SExpr) -> Option<LitVal> {
173 if let Some(res) = self.ctx.get_i64(sexpr) {
174 return Some(LitVal::Int(res));
175 }
176 if let Some(res) = self.ctx.get_f64(sexpr) {
177 return Some(LitVal::Float(res));
178 }
179 if let Some(res) = self.ctx.get_atom(sexpr) {
180 match res {
181 "true" => {
182 return Some(LitVal::Bool(true));
183 }
184 "false" => {
185 return Some(LitVal::Bool(false));
186 }
187 _ => {
188 return None;
189 }
190 }
191 }
192
193 None
196 }
197}
198
199impl common::PrimSolver for SmtLibSolver {
200 fn check_sat(
201 &mut self,
202 prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
203 ) -> Option<HashMap<IdentCtx, LitVal>> {
204 if prims.is_empty() {
206 return Some(HashMap::new());
207 }
208
209 self.ctx.pop().unwrap();
211 self.ctx.push().unwrap();
212
213 let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
214 let sexp_map = self.solve_constraints(prims, &ty_map);
215
216 let check_res = self.ctx.check().unwrap();
217 if check_res == easy_smt::Response::Sat {
218 let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
219 let res = vars
220 .iter()
221 .cloned()
222 .zip(
223 self.ctx
224 .get_value(vars.iter().map(|var| sexp_map[var]).collect())
225 .unwrap()
226 .iter()
227 .map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
228 )
229 .collect();
230
231 Some(res)
232 } else {
233 None
234 }
235 }
236}