prune_lang/constr/
solver.rs1use super::subst::*;
2use super::*;
3use crate::cli::args;
4use backend::SmtSolver;
5
6pub struct Solver {
7 ty_map: EnvMap<IdentCtx, TypeId>,
8 subst: Subst,
9 constr: Box<dyn SmtSolver>,
10 unify_vec: Vec<(TermCtx, TermCtx)>,
11 solve_vec: Vec<(Prim, Vec<AtomCtx>)>,
12 saves: Vec<(usize, usize)>,
13}
14
15impl fmt::Display for Solver {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 let unify_vec = self
18 .unify_vec
19 .iter()
20 .map(|(lhs, rhs)| format!("{} = {}", lhs, rhs))
21 .format(", ");
22 writeln!(f, "unify: [{}]", unify_vec)?;
23
24 let solve_vec = self
25 .solve_vec
26 .iter()
27 .map(|(prim, args)| format!("{:?}({})", prim, args.iter().format(", ")))
28 .format(",");
29 writeln!(f, "solve: [{}]", solve_vec)?;
30 Ok(())
31 }
32}
33
34impl Solver {
35 pub fn new(backend: args::SmtBackend) -> Solver {
36 let subst = Subst::new();
37
38 let constr = match backend {
39 args::SmtBackend::Z3Inc => Box::new(backend::incr_smt::IncrSmtSolver::new(
40 backend::SmtBackend::Z3,
41 )) as Box<dyn SmtSolver>,
42 args::SmtBackend::Z3Sq => Box::new(backend::non_incr_smt::NonIncrSmtSolver::new(
43 backend::SmtBackend::Z3,
44 )) as Box<dyn SmtSolver>,
45 args::SmtBackend::CVC5Inc => Box::new(backend::incr_smt::IncrSmtSolver::new(
46 backend::SmtBackend::CVC5,
47 )) as Box<dyn SmtSolver>,
48 args::SmtBackend::CVC5Sq => Box::new(backend::non_incr_smt::NonIncrSmtSolver::new(
49 backend::SmtBackend::CVC5,
50 )) as Box<dyn SmtSolver>,
51 args::SmtBackend::NoSmt => {
52 Box::new(backend::no_smt::NoSmtSolver::new()) as Box<dyn SmtSolver>
53 }
54 };
55
56 Solver {
57 ty_map: EnvMap::new(),
58 subst,
59 constr,
60 unify_vec: Vec::new(),
61 solve_vec: Vec::new(),
62 saves: Vec::new(),
63 }
64 }
65
66 pub fn is_empty(&self) -> bool {
67 self.saves.is_empty() && self.subst.is_empty() && self.constr.is_empty()
68 }
69
70 pub fn reset(&mut self) {
71 self.ty_map.clear();
72 self.subst.reset();
73 self.constr.reset();
74 self.unify_vec.clear();
75 self.solve_vec.clear();
76 self.saves.clear();
77 }
78
79 pub fn savepoint(&mut self) {
80 self.ty_map.enter_scope();
81 self.subst.savepoint();
82 self.constr.savepoint();
83 self.saves
84 .push((self.unify_vec.len(), self.solve_vec.len()));
85 }
86
87 pub fn backtrack(&mut self) {
88 assert!(!self.saves.is_empty());
89 self.ty_map.leave_scope();
90 self.subst.backtrack();
91 self.constr.backtrack();
92 let (len1, len2) = self.saves.pop().unwrap();
93 for _ in 0..(self.unify_vec.len() - len1) {
94 self.unify_vec.pop().unwrap();
95 }
96 for _ in 0..(self.solve_vec.len() - len2) {
97 self.solve_vec.pop().unwrap();
98 }
99 }
100}
101
102impl Solver {
103 pub fn declare(&mut self, var: &IdentCtx, typ: &TypeId) {
104 assert!(!self.ty_map.contains_key(var));
105 self.ty_map.insert(*var, typ.clone());
106 if let Term::Lit(lit) = typ {
107 self.constr.declare_var(var, lit);
108 }
109 }
110
111 pub fn unify(&mut self, lhs: TermCtx, rhs: TermCtx) -> Option<()> {
112 self.unify_vec.push((lhs.clone(), rhs.clone()));
113 let mut subst = self.subst.unify(lhs, rhs)?;
114 for (x, term) in subst.drain(..) {
115 if self.ty_map[&x].is_lit() {
116 self.constr.push_eq(x, term);
117 }
118 }
119 Some(())
120 }
121
122 pub fn push_cons(&mut self, prim: Prim, args: Vec<AtomCtx>) {
123 self.solve_vec.push((prim, args.clone()));
124 self.constr.push_cons(prim, args);
125 }
126
127 pub fn check_complete(&mut self) -> bool {
128 self.constr.check_complete()
129 }
130
131 pub fn check_sound(&mut self) -> bool {
132 self.constr.check_sound()
133 }
134
135 pub fn get_value(&mut self, vars: &[IdentCtx]) -> Vec<TermCtx> {
136 let terms: Vec<TermCtx> = vars
137 .iter()
138 .map(|var| self.subst.merge(&Term::Var(*var)))
139 .collect();
140
141 let lit_vars: Vec<IdentCtx> = terms
142 .iter()
143 .flat_map(|term| {
144 term.free_vars()
145 .iter()
146 .filter(|var| self.ty_map[var].is_lit())
147 .cloned()
148 .collect::<Vec<_>>()
149 })
150 .collect();
151
152 if lit_vars.is_empty() {
153 return terms;
154 }
155
156 let map = self
157 .constr
158 .get_value(&lit_vars)
159 .into_iter()
160 .map(|(k, v)| (k, Term::Lit(v)))
161 .collect();
162
163 terms
164 .into_iter()
165 .map(|term| term.substitute(&map))
166 .collect()
167 }
168}
169
170#[test]
171fn test_solver() {
172 let x = Ident::dummy(&"x");
173 let y = Ident::dummy(&"y");
174 let z = Ident::dummy(&"z");
175 let cons = Ident::dummy(&"cons");
176
177 let mut sol: Solver = Solver::new(args::SmtBackend::Z3Inc);
178
179 sol.declare(&x.tag_ctx(0), &TypeId::Lit(LitType::TyInt));
180 sol.declare(&y.tag_ctx(0), &TypeId::Lit(LitType::TyInt));
181
182 sol.push_cons(
183 Prim::ICmp(Compare::Lt),
184 vec![
185 Term::Var(x.tag_ctx(0)),
186 Term::Var(y.tag_ctx(0)),
187 Term::Lit(LitVal::Bool(true)),
188 ],
189 );
190
191 sol.savepoint();
192
193 sol.unify(Term::Var(x.tag_ctx(0)), Term::Var(y.tag_ctx(0)))
194 .unwrap();
195
196 assert!(!sol.check_complete());
197
198 sol.backtrack();
199 sol.savepoint();
200
201 sol.unify(
202 Term::Var(z.tag_ctx(0)),
203 Term::Cons(OptCons::Some(cons), vec![Term::Var(x.tag_ctx(0))]),
204 )
205 .unwrap();
206
207 sol.unify(
208 Term::Var(z.tag_ctx(0)),
209 Term::Cons(OptCons::Some(cons), vec![Term::Var(y.tag_ctx(0))]),
210 );
211
212 assert!(!sol.check_complete());
213
214 sol.backtrack();
215}