// prune_lang/interp/solver/smtlib.rs

1use super::common::*;
2use super::*;
3
4use easy_smt::{Context, ContextBuilder, SExpr};
5
/// Selects which external SMT solver executable [`SmtLibSolver`] drives.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SolverBackend {
    /// The `z3` binary, run in SMT-LIB2 stdin mode (`-smt2 -in`).
    Z3,
    /// The `cvc5` binary, run in incremental SMT-LIB2 mode
    /// (`--lang=smt2 --incremental`).
    CVC5,
}
11
12pub struct SmtLibSolver {
13    ctx: Context,
14}
15
16impl SmtLibSolver {
17    pub fn new(backend: SolverBackend) -> Self {
18        let mut ctx_bld = ContextBuilder::new();
19        match backend {
20            SolverBackend::Z3 => {
21                ctx_bld.solver("z3").solver_args(["-smt2", "-in", "-v:0"]);
22            }
23            SolverBackend::CVC5 => {
24                ctx_bld
25                    .solver("cvc5")
26                    .solver_args(["--quiet", "--lang=smt2", "--incremental"]);
27            }
28        }
29
30        // ctx_bld.replay_file(Some(std::fs::File::create("replay.smt2").unwrap()));
31        let mut ctx = ctx_bld.build().unwrap();
32        ctx.set_logic("QF_NIA").unwrap();
33        match backend {
34            SolverBackend::Z3 => {
35                ctx.set_option(":timeout", ctx.numeral(1000)).unwrap();
36            }
37            SolverBackend::CVC5 => {
38                ctx.set_option(":tlimit-per", ctx.numeral(1000)).unwrap();
39            }
40        }
41
42        // push an empty context for reset
43        ctx.push().unwrap();
44
45        SmtLibSolver { ctx }
46    }
47
48    pub fn check_sat(
49        &mut self,
50        prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
51    ) -> Option<HashMap<IdentCtx, LitVal>> {
52        // fast path for empty solver query
53        if prims.is_empty() {
54            return Some(HashMap::new());
55        }
56
57        // reset solver state
58        self.ctx.pop().unwrap();
59        self.ctx.push().unwrap();
60
61        let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
62        let sexp_map = self.solve_constraints(prims, &ty_map);
63
64        let check_res = self.ctx.check().unwrap();
65        if check_res == easy_smt::Response::Sat {
66            let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
67            let res = vars
68                .iter()
69                .cloned()
70                .zip(
71                    self.ctx
72                        .get_value(vars.iter().map(|var| sexp_map[var]).collect())
73                        .unwrap()
74                        .iter()
75                        .map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
76                )
77                .collect();
78
79            Some(res)
80        } else {
81            None
82        }
83    }
84
85    fn solve_constraints(
86        &mut self,
87        prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
88        ty_map: &HashMap<IdentCtx, LitType>,
89    ) -> HashMap<IdentCtx, SExpr> {
90        let sexp_map: HashMap<IdentCtx, SExpr> = ty_map
91            .iter()
92            .map(|(var, typ)| {
93                let sort = match typ {
94                    LitType::TyInt => self.ctx.int_sort(),
95                    LitType::TyFloat => self.ctx.real_sort(),
96                    LitType::TyBool => self.ctx.bool_sort(),
97                    LitType::TyChar => todo!(),
98                };
99                let sexp = self.ctx.declare_const(format!("{:?}", var), sort).unwrap();
100                (*var, sexp)
101            })
102            .collect();
103
104        for (prim, args) in prims.iter() {
105            let args: Vec<SExpr> = args
106                .iter()
107                .map(|arg| self.atom_to_sexp(arg, &sexp_map))
108                .collect();
109
110            match (prim, &args[..]) {
111                (
112                    Prim::IAdd | Prim::ISub | Prim::IMul | Prim::IDiv | Prim::IRem,
113                    &[arg1, arg2, arg3],
114                ) => {
115                    let res = match prim {
116                        Prim::IAdd => self.ctx.plus(arg1, arg2),
117                        Prim::ISub => self.ctx.sub(arg1, arg2),
118                        Prim::IMul => self.ctx.times(arg1, arg2),
119                        Prim::IDiv => self.ctx.div(arg1, arg2),
120                        Prim::IRem => self.ctx.rem(arg1, arg2),
121                        _ => unreachable!(),
122                    };
123                    self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
124                }
125                (Prim::INeg, &[arg1, arg2]) => {
126                    let res = self.ctx.negate(arg1);
127                    self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
128                }
129                (Prim::ICmp(cmp), &[arg1, arg2, arg3]) => {
130                    let res = match cmp {
131                        Compare::Lt => self.ctx.lt(arg1, arg2),
132                        Compare::Le => self.ctx.lte(arg1, arg2),
133                        Compare::Eq => self.ctx.eq(arg1, arg2),
134                        Compare::Ge => self.ctx.gte(arg1, arg2),
135                        Compare::Gt => self.ctx.gt(arg1, arg2),
136                        Compare::Ne => self.ctx.not(self.ctx.eq(arg1, arg2)),
137                    };
138                    self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
139                }
140                (Prim::BAnd | Prim::BOr, &[arg1, arg2, arg3]) => {
141                    let res = match prim {
142                        Prim::BAnd => self.ctx.and(arg1, arg2),
143                        Prim::BOr => self.ctx.or(arg1, arg2),
144                        _ => unreachable!(),
145                    };
146                    self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
147                }
148                (Prim::BNot, &[arg1, arg2]) => {
149                    let res = self.ctx.not(arg1);
150                    self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
151                }
152                _ => {
153                    panic!("wrong arity of primitives!");
154                }
155            }
156        }
157
158        sexp_map
159    }
160
161    fn atom_to_sexp(&self, atom: &AtomVal<IdentCtx>, map: &HashMap<IdentCtx, SExpr>) -> SExpr {
162        match atom {
163            Term::Var(var) => map[var],
164            Term::Lit(LitVal::Int(x)) => self.ctx.numeral(*x),
165            Term::Lit(LitVal::Float(x)) => self.ctx.decimal(*x),
166            Term::Lit(LitVal::Bool(x)) => {
167                if *x {
168                    self.ctx.true_()
169                } else {
170                    self.ctx.false_()
171                }
172            }
173            Term::Lit(LitVal::Char(_x)) => todo!(),
174            Term::Cons(_cons, _flds) => unreachable!(),
175        }
176    }
177
178    fn sexp_to_lit_val(&self, sexpr: SExpr) -> Option<LitVal> {
179        if let Some(res) = self.ctx.get_i64(sexpr) {
180            return Some(LitVal::Int(res));
181        }
182        if let Some(res) = self.ctx.get_f64(sexpr) {
183            return Some(LitVal::Float(res));
184        }
185        if let Some(res) = self.ctx.get_atom(sexpr) {
186            match res {
187                "true" => {
188                    return Some(LitVal::Bool(true));
189                }
190                "false" => {
191                    return Some(LitVal::Bool(false));
192                }
193                _ => {
194                    return None;
195                }
196            }
197        }
198
199        // todo: basic type `Char``
200
201        None
202    }
203}
204
205impl common::PrimSolver for SmtLibSolver {
206    fn check_sat(
207        &mut self,
208        prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
209    ) -> Option<HashMap<IdentCtx, LitVal>> {
210        // fast path for empty solver query
211        if prims.is_empty() {
212            return Some(HashMap::new());
213        }
214
215        // reset solver state
216        self.ctx.pop().unwrap();
217        self.ctx.push().unwrap();
218
219        let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
220        let sexp_map = self.solve_constraints(prims, &ty_map);
221
222        let check_res = self.ctx.check().unwrap();
223        if check_res == easy_smt::Response::Sat {
224            let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
225            let res = vars
226                .iter()
227                .cloned()
228                .zip(
229                    self.ctx
230                        .get_value(vars.iter().map(|var| sexp_map[var]).collect())
231                        .unwrap()
232                        .iter()
233                        .map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
234                )
235                .collect();
236
237            Some(res)
238        } else {
239            None
240        }
241    }
242}