1use lasso::Spur;
23use oxiz_core::ast::{TermId, TermKind, TermManager};
24use oxiz_core::error::Result;
25use oxiz_core::sort::SortId;
26use oxiz_core::tactic::{GroundTermCollector, PatternMatcher};
27use rustc_hash::{FxHashMap, FxHashSet};
28use smallvec::SmallVec;
29
30use crate::theory::{EqualityNotification, Theory, TheoryCombination, TheoryId, TheoryResult};
31
/// Tunable settings for quantifier instantiation.
#[derive(Clone, Debug)]
pub struct QuantifierConfig {
    /// When `false`, `QuantifierSolver::do_ematch` returns no lemmas at all.
    pub enable_ematch: bool,
    /// Switch for MBQI. NOTE(review): presumably model-based quantifier
    /// instantiation; nothing in this part of the file reads it — confirm.
    pub enable_mbqi: bool,
    /// Upper bound on the instantiation count of a single tracked quantifier.
    pub max_inst_per_quantifier: usize,
    /// Upper bound on the total number of instantiations across all quantifiers.
    pub max_total_instantiations: usize,
    /// Eagerness level. NOTE(review): not consulted in the code visible here;
    /// its scale and meaning need to be confirmed against the consumer.
    pub eagerness: u8,
}

impl Default for QuantifierConfig {
    /// Both strategies enabled, 100 instances per quantifier, 10 000 in total,
    /// eagerness 5.
    fn default() -> Self {
        QuantifierConfig {
            eagerness: 5,
            max_total_instantiations: 10_000,
            max_inst_per_quantifier: 100,
            enable_mbqi: true,
            enable_ematch: true,
        }
    }
}
58
59#[derive(Debug, Clone)]
61pub struct TrackedQuantifier {
62 pub term: TermId,
64 pub bound_vars: SmallVec<[(Spur, SortId); 2]>,
66 pub body: TermId,
68 pub patterns: SmallVec<[SmallVec<[TermId; 2]>; 2]>,
70 pub universal: bool,
72 pub instantiation_count: usize,
74}
75
76impl TrackedQuantifier {
77 pub fn from_forall(
79 term: TermId,
80 vars: SmallVec<[(Spur, SortId); 2]>,
81 body: TermId,
82 patterns: SmallVec<[SmallVec<[TermId; 2]>; 2]>,
83 ) -> Self {
84 Self {
85 term,
86 bound_vars: vars,
87 body,
88 patterns,
89 universal: true,
90 instantiation_count: 0,
91 }
92 }
93
94 pub fn from_exists(
96 term: TermId,
97 vars: SmallVec<[(Spur, SortId); 2]>,
98 body: TermId,
99 patterns: SmallVec<[SmallVec<[TermId; 2]>; 2]>,
100 ) -> Self {
101 Self {
102 term,
103 bound_vars: vars,
104 body,
105 patterns,
106 universal: false,
107 instantiation_count: 0,
108 }
109 }
110}
111
112#[derive(Debug, Clone)]
114pub struct InstantiationLemma {
115 pub quantifier: TermId,
117 pub substitution: FxHashMap<Spur, TermId>,
119 pub instance: TermId,
121}
122
/// Counters describing the work done by a `QuantifierSolver`.
///
/// Every counter is zero in the `Default` value.
#[derive(Clone, Debug, Default)]
pub struct QuantifierStats {
    /// Number of quantifiers registered so far.
    pub num_quantifiers: usize,
    /// Instances produced by E-matching.
    pub ematch_instantiations: usize,
    /// Instances attributed to MBQI. NOTE(review): never written in the code
    /// visible here — confirm where it is updated.
    pub mbqi_instantiations: usize,
    /// Instances produced overall; this is the value compared against
    /// `QuantifierConfig::max_total_instantiations`.
    pub total_instantiations: usize,
    /// Conflict counter. NOTE(review): never written in the code visible here.
    pub conflicts: usize,
}
137
138#[derive(Debug)]
140pub struct QuantifierSolver {
141 config: QuantifierConfig,
143 quantifiers: Vec<TrackedQuantifier>,
145 pattern_matcher: PatternMatcher,
147 ground_collector: GroundTermCollector,
149 generated_instances: FxHashSet<(TermId, Vec<(Spur, TermId)>)>,
151 context_stack: Vec<usize>,
153 pending_lemmas: Vec<InstantiationLemma>,
155 stats: QuantifierStats,
157}
158
159impl Default for QuantifierSolver {
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
165impl QuantifierSolver {
166 pub fn new() -> Self {
168 Self::with_config(QuantifierConfig::default())
169 }
170
171 pub fn with_config(config: QuantifierConfig) -> Self {
173 Self {
174 config,
175 quantifiers: Vec::new(),
176 pattern_matcher: PatternMatcher::new(),
177 ground_collector: GroundTermCollector::new(),
178 generated_instances: FxHashSet::default(),
179 context_stack: Vec::new(),
180 pending_lemmas: Vec::new(),
181 stats: QuantifierStats::default(),
182 }
183 }
184
185 pub fn config(&self) -> &QuantifierConfig {
187 &self.config
188 }
189
190 pub fn set_config(&mut self, config: QuantifierConfig) {
192 self.config = config;
193 }
194
195 pub fn stats(&self) -> &QuantifierStats {
197 &self.stats
198 }
199
200 pub fn reset_stats(&mut self) {
202 self.stats = QuantifierStats::default();
203 }
204
205 pub fn add_quantifier(&mut self, term: TermId, manager: &TermManager) {
207 let Some(t) = manager.get(term) else {
208 return;
209 };
210
211 match &t.kind {
212 TermKind::Forall {
213 vars,
214 body,
215 patterns,
216 } => {
217 let tracked =
218 TrackedQuantifier::from_forall(term, vars.clone(), *body, patterns.clone());
219
220 self.pattern_matcher.add_pattern(term, manager);
222
223 self.quantifiers.push(tracked);
224 self.stats.num_quantifiers += 1;
225 }
226 TermKind::Exists {
227 vars,
228 body,
229 patterns,
230 } => {
231 let tracked =
232 TrackedQuantifier::from_exists(term, vars.clone(), *body, patterns.clone());
233 self.quantifiers.push(tracked);
234 self.stats.num_quantifiers += 1;
235 }
236 _ => {}
237 }
238 }
239
240 pub fn collect_ground_terms(&mut self, term: TermId, manager: &TermManager) {
242 self.ground_collector.collect(term, manager);
243 }
244
245 pub fn get_ground_terms(&self, sort: SortId) -> &[TermId] {
247 self.ground_collector.get_terms(sort)
248 }
249
250 pub fn do_ematch(&mut self, manager: &mut TermManager) -> Vec<InstantiationLemma> {
252 if !self.config.enable_ematch {
253 return Vec::new();
254 }
255
256 let mut lemmas = Vec::new();
257
258 let bindings = self
260 .pattern_matcher
261 .match_against(&self.ground_collector, manager);
262
263 for binding in bindings {
264 if self.stats.total_instantiations >= self.config.max_total_instantiations {
266 break;
267 }
268
269 let Some(quantifier_id) = self.pattern_matcher.get_quantifier(binding.pattern_idx)
271 else {
272 continue;
273 };
274
275 let quantifier_idx = self
276 .quantifiers
277 .iter()
278 .position(|q| q.term == quantifier_id);
279 let Some(idx) = quantifier_idx else {
280 continue;
281 };
282
283 if self.quantifiers[idx].instantiation_count >= self.config.max_inst_per_quantifier {
284 continue;
285 }
286
287 let mut key_vec: Vec<_> = binding.substitution.iter().map(|(&k, &v)| (k, v)).collect();
289 key_vec.sort_by_key(|(k, _)| k.into_inner());
290 let key = (quantifier_id, key_vec.clone());
291
292 if self.generated_instances.contains(&key) {
293 continue;
294 }
295
296 let Some(instance) = self.pattern_matcher.instantiate(&binding, manager) else {
298 continue;
299 };
300
301 let lemma = InstantiationLemma {
302 quantifier: quantifier_id,
303 substitution: binding.substitution.clone(),
304 instance,
305 };
306
307 self.generated_instances.insert(key);
308 self.quantifiers[idx].instantiation_count += 1;
309 self.stats.ematch_instantiations += 1;
310 self.stats.total_instantiations += 1;
311 lemmas.push(lemma);
312 }
313
314 lemmas
315 }
316
317 pub fn get_pending_lemmas(&mut self) -> Vec<InstantiationLemma> {
319 std::mem::take(&mut self.pending_lemmas)
320 }
321
322 pub fn has_quantifiers(&self) -> bool {
324 !self.quantifiers.is_empty()
325 }
326
327 pub fn num_quantifiers(&self) -> usize {
329 self.quantifiers.len()
330 }
331}
332
333impl Theory for QuantifierSolver {
334 fn id(&self) -> TheoryId {
335 TheoryId::Bool
337 }
338
339 fn name(&self) -> &str {
340 "Quantifier"
341 }
342
343 fn can_handle(&self, term: TermId) -> bool {
344 let _ = term;
347 false
348 }
349
350 fn assert_true(&mut self, _term: TermId) -> Result<TheoryResult> {
351 self.pending_lemmas.clear();
354 Ok(TheoryResult::Sat)
355 }
356
357 fn assert_false(&mut self, _term: TermId) -> Result<TheoryResult> {
358 self.pending_lemmas.clear();
360 Ok(TheoryResult::Sat)
361 }
362
363 fn check(&mut self) -> Result<TheoryResult> {
364 if !self.pending_lemmas.is_empty() {
366 let propagations: Vec<_> = self
367 .pending_lemmas
368 .iter()
369 .map(|l| (l.instance, vec![l.quantifier]))
370 .collect();
371 return Ok(TheoryResult::Propagate(propagations));
372 }
373
374 Ok(TheoryResult::Sat)
375 }
376
377 fn push(&mut self) {
378 self.context_stack.push(self.quantifiers.len());
379 }
380
381 fn pop(&mut self) {
382 if let Some(num) = self.context_stack.pop() {
383 self.quantifiers.truncate(num);
384 }
385 }
386
387 fn reset(&mut self) {
388 self.quantifiers.clear();
389 self.pattern_matcher = PatternMatcher::new();
390 self.ground_collector.clear();
391 self.generated_instances.clear();
392 self.context_stack.clear();
393 self.pending_lemmas.clear();
394 self.stats = QuantifierStats::default();
395 }
396
397 fn get_model(&self) -> Vec<(TermId, TermId)> {
398 Vec::new()
399 }
400}
401
402impl TheoryCombination for QuantifierSolver {
403 fn notify_equality(&mut self, _eq: EqualityNotification) -> bool {
404 true
407 }
408
409 fn get_shared_equalities(&self) -> Vec<EqualityNotification> {
410 Vec::new()
411 }
412
413 fn is_relevant(&self, term: TermId) -> bool {
414 self.quantifiers.iter().any(|q| q.term == term)
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh solver tracks nothing.
    #[test]
    fn test_quantifier_solver_new() {
        let fresh = QuantifierSolver::new();
        assert!(!fresh.has_quantifiers());
        assert_eq!(fresh.num_quantifiers(), 0);
    }

    /// The default configuration enables both strategies with eagerness 5.
    #[test]
    fn test_quantifier_config_default() {
        let defaults = QuantifierConfig::default();
        assert_eq!(defaults.eagerness, 5);
        assert!(defaults.enable_mbqi);
        assert!(defaults.enable_ematch);
    }

    /// A balanced push/pop on an empty solver leaves it empty.
    #[test]
    fn test_quantifier_solver_push_pop() {
        let mut s = QuantifierSolver::new();

        s.push();
        s.pop();

        assert_eq!(s.num_quantifiers(), 0);
    }

    /// `reset` discards previously collected ground terms.
    #[test]
    fn test_quantifier_solver_reset() {
        let mut terms = TermManager::new();
        let mut s = QuantifierSolver::new();

        let literal_one = terms.mk_int(1);
        s.collect_ground_terms(literal_one, &terms);
        s.reset();

        assert!(s.ground_collector.is_empty());
    }

    /// Counters start at zero.
    #[test]
    fn test_quantifier_solver_stats() {
        let s = QuantifierSolver::new();

        assert_eq!(s.stats().num_quantifiers, 0);
        assert_eq!(s.stats().total_instantiations, 0);
    }

    /// Basic `Theory` behavior: name, scope handling, and an idle `check`.
    #[test]
    fn test_theory_trait() {
        let mut s = QuantifierSolver::new();

        assert_eq!(s.name(), "Quantifier");

        s.push();
        s.pop();

        assert!(matches!(s.check().unwrap(), TheoryResult::Sat));
    }
}