Skip to main content

prune_lang/interp/solver/smtlib.rs

1use super::common::*;
2use super::*;
3
4use easy_smt::{Context, ContextBuilder, SExpr};
5
6pub struct SmtLibSolver {
7    ctx: Context,
8}
9
10impl SmtLibSolver {
11    pub fn new(backend: common::SolverBackend) -> Self {
12        let mut ctx_bld = ContextBuilder::new();
13        match backend {
14            common::SolverBackend::Z3 => {
15                ctx_bld.solver("z3").solver_args(["-smt2", "-in", "-v:0"]);
16            }
17            common::SolverBackend::CVC5 => {
18                ctx_bld
19                    .solver("cvc5")
20                    .solver_args(["--quiet", "--lang=smt2", "--incremental"]);
21            }
22        }
23
24        // ctx_bld.replay_file(Some(std::fs::File::create("replay.smt2").unwrap()));
25        let mut ctx = ctx_bld.build().unwrap();
26        ctx.set_logic("QF_NIA").unwrap();
27        match backend {
28            common::SolverBackend::Z3 => {
29                ctx.set_option(":timeout", ctx.numeral(1000)).unwrap();
30            }
31            common::SolverBackend::CVC5 => {
32                ctx.set_option(":tlimit-per", ctx.numeral(1000)).unwrap();
33            }
34        }
35
36        // push an empty context for reset
37        ctx.push().unwrap();
38
39        SmtLibSolver { ctx }
40    }
41
42    pub fn check_sat(
43        &mut self,
44        prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
45    ) -> Option<HashMap<IdentCtx, LitVal>> {
46        // fast path for empty solver query
47        if prims.is_empty() {
48            return Some(HashMap::new());
49        }
50
51        // reset solver state
52        self.ctx.pop().unwrap();
53        self.ctx.push().unwrap();
54
55        let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
56        let sexp_map = self.solve_constraints(prims, &ty_map);
57
58        let check_res = self.ctx.check().unwrap();
59        if check_res == easy_smt::Response::Sat {
60            let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
61            let res = vars
62                .iter()
63                .cloned()
64                .zip(
65                    self.ctx
66                        .get_value(vars.iter().map(|var| sexp_map[var]).collect())
67                        .unwrap()
68                        .iter()
69                        .map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
70                )
71                .collect();
72
73            Some(res)
74        } else {
75            None
76        }
77    }
78
79    fn solve_constraints(
80        &mut self,
81        prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
82        ty_map: &HashMap<IdentCtx, LitType>,
83    ) -> HashMap<IdentCtx, SExpr> {
84        let sexp_map: HashMap<IdentCtx, SExpr> = ty_map
85            .iter()
86            .map(|(var, typ)| {
87                let sort = match typ {
88                    LitType::TyInt => self.ctx.int_sort(),
89                    LitType::TyFloat => self.ctx.real_sort(),
90                    LitType::TyBool => self.ctx.bool_sort(),
91                    LitType::TyChar => todo!(),
92                };
93                let sexp = self.ctx.declare_const(format!("{:?}", var), sort).unwrap();
94                (*var, sexp)
95            })
96            .collect();
97
98        for (prim, args) in prims.iter() {
99            let args: Vec<SExpr> = args
100                .iter()
101                .map(|arg| self.atom_to_sexp(arg, &sexp_map))
102                .collect();
103
104            match (prim, &args[..]) {
105                (
106                    Prim::IAdd | Prim::ISub | Prim::IMul | Prim::IDiv | Prim::IRem,
107                    &[arg1, arg2, arg3],
108                ) => {
109                    let res = match prim {
110                        Prim::IAdd => self.ctx.plus(arg1, arg2),
111                        Prim::ISub => self.ctx.sub(arg1, arg2),
112                        Prim::IMul => self.ctx.times(arg1, arg2),
113                        Prim::IDiv => self.ctx.div(arg1, arg2),
114                        Prim::IRem => self.ctx.rem(arg1, arg2),
115                        _ => unreachable!(),
116                    };
117                    self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
118                }
119                (Prim::INeg, &[arg1, arg2]) => {
120                    let res = self.ctx.negate(arg1);
121                    self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
122                }
123                (Prim::ICmp(cmp), &[arg1, arg2, arg3]) => {
124                    let res = match cmp {
125                        Compare::Lt => self.ctx.lt(arg1, arg2),
126                        Compare::Le => self.ctx.lte(arg1, arg2),
127                        Compare::Eq => self.ctx.eq(arg1, arg2),
128                        Compare::Ge => self.ctx.gte(arg1, arg2),
129                        Compare::Gt => self.ctx.gt(arg1, arg2),
130                        Compare::Ne => self.ctx.not(self.ctx.eq(arg1, arg2)),
131                    };
132                    self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
133                }
134                (Prim::BAnd | Prim::BOr, &[arg1, arg2, arg3]) => {
135                    let res = match prim {
136                        Prim::BAnd => self.ctx.and(arg1, arg2),
137                        Prim::BOr => self.ctx.or(arg1, arg2),
138                        _ => unreachable!(),
139                    };
140                    self.ctx.assert(self.ctx.eq(res, arg3)).unwrap();
141                }
142                (Prim::BNot, &[arg1, arg2]) => {
143                    let res = self.ctx.not(arg1);
144                    self.ctx.assert(self.ctx.eq(res, arg2)).unwrap();
145                }
146                _ => {
147                    panic!("wrong arity of primitives!");
148                }
149            }
150        }
151
152        sexp_map
153    }
154
155    fn atom_to_sexp(&self, atom: &AtomVal<IdentCtx>, map: &HashMap<IdentCtx, SExpr>) -> SExpr {
156        match atom {
157            Term::Var(var) => map[var],
158            Term::Lit(LitVal::Int(x)) => self.ctx.numeral(*x),
159            Term::Lit(LitVal::Float(x)) => self.ctx.decimal(*x),
160            Term::Lit(LitVal::Bool(x)) => {
161                if *x {
162                    self.ctx.true_()
163                } else {
164                    self.ctx.false_()
165                }
166            }
167            Term::Lit(LitVal::Char(_x)) => todo!(),
168            Term::Cons(_cons, _flds) => unreachable!(),
169        }
170    }
171
172    fn sexp_to_lit_val(&self, sexpr: SExpr) -> Option<LitVal> {
173        if let Some(res) = self.ctx.get_i64(sexpr) {
174            return Some(LitVal::Int(res));
175        }
176        if let Some(res) = self.ctx.get_f64(sexpr) {
177            return Some(LitVal::Float(res));
178        }
179        if let Some(res) = self.ctx.get_atom(sexpr) {
180            match res {
181                "true" => {
182                    return Some(LitVal::Bool(true));
183                }
184                "false" => {
185                    return Some(LitVal::Bool(false));
186                }
187                _ => {
188                    return None;
189                }
190            }
191        }
192
193        // todo: basic type `Char``
194
195        None
196    }
197}
198
199impl common::PrimSolver for SmtLibSolver {
200    fn check_sat(
201        &mut self,
202        prims: &[(Prim, Vec<AtomVal<IdentCtx>>)],
203    ) -> Option<HashMap<IdentCtx, LitVal>> {
204        // fast path for empty solver query
205        if prims.is_empty() {
206            return Some(HashMap::new());
207        }
208
209        // reset solver state
210        self.ctx.pop().unwrap();
211        self.ctx.push().unwrap();
212
213        let ty_map: HashMap<IdentCtx, LitType> = infer_type(prims);
214        let sexp_map = self.solve_constraints(prims, &ty_map);
215
216        let check_res = self.ctx.check().unwrap();
217        if check_res == easy_smt::Response::Sat {
218            let vars: Vec<IdentCtx> = ty_map.keys().copied().collect();
219            let res = vars
220                .iter()
221                .cloned()
222                .zip(
223                    self.ctx
224                        .get_value(vars.iter().map(|var| sexp_map[var]).collect())
225                        .unwrap()
226                        .iter()
227                        .map(|(_var, val)| self.sexp_to_lit_val(*val).unwrap()),
228                )
229                .collect();
230
231            Some(res)
232        } else {
233            None
234        }
235    }
236}