prune_lang/constr/solver.rs

1use super::subst::*;
2use super::*;
3use crate::cli::args;
4use backend::SmtSolver;
5
6pub struct Solver {
7    ty_map: EnvMap<IdentCtx, TypeId>,
8    subst: Subst,
9    constr: Box<dyn SmtSolver>,
10    unify_vec: Vec<(TermCtx, TermCtx)>,
11    solve_vec: Vec<(Prim, Vec<AtomCtx>)>,
12    saves: Vec<(usize, usize)>,
13}
14
15impl fmt::Display for Solver {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        let unify_vec = self
18            .unify_vec
19            .iter()
20            .map(|(lhs, rhs)| format!("{} = {}", lhs, rhs))
21            .format(", ");
22        writeln!(f, "unify: [{}]", unify_vec)?;
23
24        let solve_vec = self
25            .solve_vec
26            .iter()
27            .map(|(prim, args)| format!("{:?}({})", prim, args.iter().format(", ")))
28            .format(",");
29        writeln!(f, "solve: [{}]", solve_vec)?;
30        Ok(())
31    }
32}
33
34impl Solver {
35    pub fn new(backend: args::SmtBackend) -> Solver {
36        let subst = Subst::new();
37
38        let constr = match backend {
39            args::SmtBackend::Z3Inc => Box::new(backend::incr_smt::IncrSmtSolver::new(
40                backend::SmtBackend::Z3,
41            )) as Box<dyn SmtSolver>,
42            args::SmtBackend::Z3Sq => Box::new(backend::non_incr_smt::NonIncrSmtSolver::new(
43                backend::SmtBackend::Z3,
44            )) as Box<dyn SmtSolver>,
45            args::SmtBackend::CVC5Inc => Box::new(backend::incr_smt::IncrSmtSolver::new(
46                backend::SmtBackend::CVC5,
47            )) as Box<dyn SmtSolver>,
48            args::SmtBackend::CVC5Sq => Box::new(backend::non_incr_smt::NonIncrSmtSolver::new(
49                backend::SmtBackend::CVC5,
50            )) as Box<dyn SmtSolver>,
51            args::SmtBackend::NoSmt => {
52                Box::new(backend::no_smt::NoSmtSolver::new()) as Box<dyn SmtSolver>
53            }
54        };
55
56        Solver {
57            ty_map: EnvMap::new(),
58            subst,
59            constr,
60            unify_vec: Vec::new(),
61            solve_vec: Vec::new(),
62            saves: Vec::new(),
63        }
64    }
65
66    pub fn is_empty(&self) -> bool {
67        self.saves.is_empty() && self.subst.is_empty() && self.constr.is_empty()
68    }
69
70    pub fn reset(&mut self) {
71        self.ty_map.clear();
72        self.subst.reset();
73        self.constr.reset();
74        self.unify_vec.clear();
75        self.solve_vec.clear();
76        self.saves.clear();
77    }
78
79    pub fn savepoint(&mut self) {
80        self.ty_map.enter_scope();
81        self.subst.savepoint();
82        self.constr.savepoint();
83        self.saves
84            .push((self.unify_vec.len(), self.solve_vec.len()));
85    }
86
87    pub fn backtrack(&mut self) {
88        assert!(!self.saves.is_empty());
89        self.ty_map.leave_scope();
90        self.subst.backtrack();
91        self.constr.backtrack();
92        let (len1, len2) = self.saves.pop().unwrap();
93        for _ in 0..(self.unify_vec.len() - len1) {
94            self.unify_vec.pop().unwrap();
95        }
96        for _ in 0..(self.solve_vec.len() - len2) {
97            self.solve_vec.pop().unwrap();
98        }
99    }
100}
101
102impl Solver {
103    pub fn declare(&mut self, var: &IdentCtx, typ: &TypeId) {
104        assert!(!self.ty_map.contains_key(var));
105        self.ty_map.insert(*var, typ.clone());
106        if let Term::Lit(lit) = typ {
107            self.constr.declare_var(var, lit);
108        }
109    }
110
111    pub fn unify(&mut self, lhs: TermCtx, rhs: TermCtx) -> Option<()> {
112        self.unify_vec.push((lhs.clone(), rhs.clone()));
113        let mut subst = self.subst.unify(lhs, rhs)?;
114        for (x, term) in subst.drain(..) {
115            if self.ty_map[&x].is_lit() {
116                self.constr.push_eq(x, term);
117            }
118        }
119        Some(())
120    }
121
122    pub fn push_cons(&mut self, prim: Prim, args: Vec<AtomCtx>) {
123        self.solve_vec.push((prim, args.clone()));
124        self.constr.push_cons(prim, args);
125    }
126
127    pub fn check_complete(&mut self) -> bool {
128        self.constr.check_complete()
129    }
130
131    pub fn check_sound(&mut self) -> bool {
132        self.constr.check_sound()
133    }
134
135    pub fn get_value(&mut self, vars: &[IdentCtx]) -> Vec<TermCtx> {
136        let terms: Vec<TermCtx> = vars
137            .iter()
138            .map(|var| self.subst.merge(&Term::Var(*var)))
139            .collect();
140
141        let lit_vars: Vec<IdentCtx> = terms
142            .iter()
143            .flat_map(|term| {
144                term.free_vars()
145                    .iter()
146                    .filter(|var| self.ty_map[var].is_lit())
147                    .cloned()
148                    .collect::<Vec<_>>()
149            })
150            .collect();
151
152        if lit_vars.is_empty() {
153            return terms;
154        }
155
156        let map = self
157            .constr
158            .get_value(&lit_vars)
159            .into_iter()
160            .map(|(k, v)| (k, Term::Lit(v)))
161            .collect();
162
163        terms
164            .into_iter()
165            .map(|term| term.substitute(&map))
166            .collect()
167    }
168}
169
#[test]
fn test_solver() {
    let x = Ident::dummy(&"x");
    let y = Ident::dummy(&"y");
    let z = Ident::dummy(&"z");
    let cons = Ident::dummy(&"cons");
    // Context-tagged variables reused by every step below.
    let (x0, y0, z0) = (x.tag_ctx(0), y.tag_ctx(0), z.tag_ctx(0));

    let mut sol = Solver::new(args::SmtBackend::Z3Inc);
    sol.declare(&x0, &TypeId::Lit(LitType::TyInt));
    sol.declare(&y0, &TypeId::Lit(LitType::TyInt));

    // Base constraint: primitive `ICmp(Lt)` over `x`, `y` and the literal `true`.
    sol.push_cons(
        Prim::ICmp(Compare::Lt),
        vec![Term::Var(x0), Term::Var(y0), Term::Lit(LitVal::Bool(true))],
    );

    // Branch 1: unify `x` with `y` directly.
    sol.savepoint();
    sol.unify(Term::Var(x0), Term::Var(y0)).unwrap();
    assert!(!sol.check_complete());
    sol.backtrack();

    // Branch 2: relate `x` and `y` indirectly via `z = cons(x)` and `z = cons(y)`.
    // NOTE(review): `z` is never declared, yet `Solver::unify` indexes `ty_map`
    // for every binding it produces; confirm `EnvMap` indexing tolerates a
    // missing key, or declare `z` before this branch. The second `unify`
    // result is not checked.
    sol.savepoint();
    sol.unify(
        Term::Var(z0),
        Term::Cons(OptCons::Some(cons), vec![Term::Var(x0)]),
    )
    .unwrap();
    sol.unify(
        Term::Var(z0),
        Term::Cons(OptCons::Some(cons), vec![Term::Var(y0)]),
    );
    assert!(!sol.check_complete());
    sol.backtrack();
}