1use lasso::Spur;
10use oxiz_core::ast::{TermId, TermKind, TermManager};
11use rustc_hash::{FxHashMap, FxHashSet};
12use std::cmp::Ordering;
13use std::collections::BinaryHeap;
14
15use super::QuantifiedFormula;
16use super::model_completion::CompletedModel;
17
18#[derive(Debug, Clone)]
20pub struct MBQIHeuristics {
21 pub quantifier_selection: SelectionStrategy,
23 pub trigger_selection: TriggerSelection,
25 pub instantiation_ordering: InstantiationOrdering,
27 pub resource_allocation: ResourceAllocation,
29 pub enable_conflict_analysis: bool,
31 pub enable_model_bounds: bool,
33}
34
35impl MBQIHeuristics {
36 pub fn new() -> Self {
38 Self {
39 quantifier_selection: SelectionStrategy::PriorityBased,
40 trigger_selection: TriggerSelection::MatchingLoopAvoidance,
41 instantiation_ordering: InstantiationOrdering::CostBased,
42 resource_allocation: ResourceAllocation::Balanced,
43 enable_conflict_analysis: true,
44 enable_model_bounds: true,
45 }
46 }
47
48 pub fn conservative() -> Self {
50 Self {
51 quantifier_selection: SelectionStrategy::MostConstrained,
52 trigger_selection: TriggerSelection::MinPatterns,
53 instantiation_ordering: InstantiationOrdering::DepthFirst,
54 resource_allocation: ResourceAllocation::Conservative,
55 enable_conflict_analysis: true,
56 enable_model_bounds: true,
57 }
58 }
59
60 pub fn aggressive() -> Self {
62 Self {
63 quantifier_selection: SelectionStrategy::BreadthFirst,
64 trigger_selection: TriggerSelection::MaxCoverage,
65 instantiation_ordering: InstantiationOrdering::BreadthFirst,
66 resource_allocation: ResourceAllocation::Aggressive,
67 enable_conflict_analysis: false,
68 enable_model_bounds: false,
69 }
70 }
71}
72
73impl Default for MBQIHeuristics {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
/// How [`InstantiationHeuristic::calculate_priority`] scores a quantifier.
///
/// Higher scores are popped first from [`QuantifierQueue`]. `Hash` is derived
/// so the strategy can be used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SelectionStrategy {
    /// Constant score; imposes no preference between quantifiers.
    Sequential,
    /// User weight damped by instantiation count, nesting depth and body size.
    PriorityBased,
    /// Favors quantifiers that have been instantiated fewer times.
    BreadthFirst,
    /// Favors quantifiers that have been instantiated more times.
    DepthFirst,
    /// Favors quantifiers whose bound sorts have small model universes and larger bodies.
    MostConstrained,
    /// Negated `MostConstrained` score.
    LeastConstrained,
    /// Deterministic pseudo-random score derived from the quantifier's term id.
    Random,
}
97
/// How [`InstantiationHeuristic::select_patterns`] chooses among a quantifier's
/// trigger patterns.
///
/// `Hash` is derived so the policy can be used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriggerSelection {
    /// Keep every pattern.
    All,
    /// Keep the single pattern with the fewest distinct variables.
    MinVars,
    /// Keep the single pattern made of the fewest terms.
    MinPatterns,
    /// Greedily keep patterns until all bound variables are covered.
    MaxCoverage,
    /// Drop patterns flagged by the (coarse) matching-loop check.
    MatchingLoopAvoidance,
    /// Keep the quantifier's own patterns (currently identical to `All`).
    UserOnly,
}
114
/// Preferred order in which candidate instantiations are tried.
///
/// NOTE(review): stored in [`MBQIHeuristics`] but not read by the heuristics in
/// this module; the variant descriptions below are inferred from their names.
/// `Hash` is derived so the value can key hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InstantiationOrdering {
    /// Presumably cheapest candidates first.
    CostBased,
    /// Presumably deepest exploration first.
    DepthFirst,
    /// Presumably shallowest exploration first.
    BreadthFirst,
    /// Presumably structurally simplest candidates first.
    SimplestFirst,
    /// Presumably candidates built from ground terms first.
    GroundnessFirst,
}
129
/// Resource budget profile for instantiation.
///
/// NOTE(review): stored in [`MBQIHeuristics`] but not read by the heuristics in
/// this module; the variant descriptions below are inferred from their names.
/// `Hash` is derived so the value can key hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ResourceAllocation {
    /// Presumably a tight budget.
    Conservative,
    /// Presumably a moderate budget.
    Balanced,
    /// Presumably a generous budget.
    Aggressive,
    /// Presumably a budget adjusted at run time.
    Adaptive,
}
142
/// Scores quantifiers and selects trigger patterns according to an
/// [`MBQIHeuristics`] configuration, and tracks per-quantifier outcomes.
#[derive(Debug)]
pub struct InstantiationHeuristic {
    /// Strategy configuration consulted by the scoring and selection methods.
    config: MBQIHeuristics,
    /// Scores written by `calculate_priority`, keyed by the quantifier's term.
    quantifier_scores: FxHashMap<TermId, f64>,
    /// Per-pattern score table.
    /// NOTE(review): only initialized in this file; never read or updated here.
    pattern_scores: FxHashMap<TermId, f64>,
    /// Success/failure tallies per quantifier term, updated by `record_result`.
    success_history: FxHashMap<TermId, SuccessRate>,
}
155
156impl InstantiationHeuristic {
157 pub fn new(config: MBQIHeuristics) -> Self {
159 Self {
160 config,
161 quantifier_scores: FxHashMap::default(),
162 pattern_scores: FxHashMap::default(),
163 success_history: FxHashMap::default(),
164 }
165 }
166
167 pub fn calculate_priority(
169 &mut self,
170 quantifier: &QuantifiedFormula,
171 model: &CompletedModel,
172 manager: &TermManager,
173 ) -> f64 {
174 if let Some(&cached) = self.quantifier_scores.get(&quantifier.term) {
176 return cached;
177 }
178
179 let score = match self.config.quantifier_selection {
180 SelectionStrategy::Sequential => 1.0,
181 SelectionStrategy::PriorityBased => self.priority_based_score(quantifier, manager),
182 SelectionStrategy::BreadthFirst => 1.0 / (1.0 + quantifier.instantiation_count as f64),
183 SelectionStrategy::DepthFirst => quantifier.instantiation_count as f64,
184 SelectionStrategy::MostConstrained => self.constraint_score(quantifier, model, manager),
185 SelectionStrategy::LeastConstrained => {
186 -self.constraint_score(quantifier, model, manager)
187 }
188 SelectionStrategy::Random => {
189 let hash = quantifier.term.raw() as u64;
191 ((hash.wrapping_mul(2654435761) >> 32) as f64) / (u32::MAX as f64)
192 }
193 };
194
195 self.quantifier_scores.insert(quantifier.term, score);
196 score
197 }
198
199 fn priority_based_score(&self, quantifier: &QuantifiedFormula, manager: &TermManager) -> f64 {
201 let weight_factor = quantifier.weight;
203 let inst_factor = 1.0 / (1.0 + quantifier.instantiation_count as f64);
204 let depth_factor = 1.0 / (1.0 + quantifier.nesting_depth as f64);
205 let complexity_factor = 1.0 / (1.0 + self.body_complexity(quantifier.body, manager) as f64);
206
207 weight_factor * inst_factor * depth_factor * complexity_factor
208 }
209
210 fn constraint_score(
212 &self,
213 quantifier: &QuantifiedFormula,
214 model: &CompletedModel,
215 manager: &TermManager,
216 ) -> f64 {
217 let mut score = 0.0;
218
219 for &(_name, sort) in &quantifier.bound_vars {
221 let num_candidates = model.universe(sort).map_or(0, |u| u.len());
222 if num_candidates > 0 {
223 score += 1.0 / num_candidates as f64;
224 } else {
225 score += 1.0;
226 }
227 }
228
229 let complexity = self.body_complexity(quantifier.body, manager);
231 score += complexity as f64 * 0.1;
232
233 score
234 }
235
236 fn body_complexity(&self, term: TermId, manager: &TermManager) -> usize {
238 let mut visited = FxHashSet::default();
239 self.body_complexity_rec(term, manager, &mut visited)
240 }
241
242 fn body_complexity_rec(
243 &self,
244 term: TermId,
245 manager: &TermManager,
246 visited: &mut FxHashSet<TermId>,
247 ) -> usize {
248 if visited.contains(&term) {
249 return 0;
250 }
251 visited.insert(term);
252
253 let Some(t) = manager.get(term) else {
254 return 1;
255 };
256
257 let children_complexity = match &t.kind {
258 TermKind::And(args) | TermKind::Or(args) => args
259 .iter()
260 .map(|&arg| self.body_complexity_rec(arg, manager, visited))
261 .sum(),
262 TermKind::Not(arg) | TermKind::Neg(arg) => {
263 self.body_complexity_rec(*arg, manager, visited)
264 }
265 TermKind::Eq(lhs, rhs)
266 | TermKind::Lt(lhs, rhs)
267 | TermKind::Le(lhs, rhs)
268 | TermKind::Gt(lhs, rhs)
269 | TermKind::Ge(lhs, rhs) => {
270 self.body_complexity_rec(*lhs, manager, visited)
271 + self.body_complexity_rec(*rhs, manager, visited)
272 }
273 TermKind::Apply { args, .. } => args
274 .iter()
275 .map(|&arg| self.body_complexity_rec(arg, manager, visited))
276 .sum(),
277 _ => 0,
278 };
279
280 1 + children_complexity
281 }
282
283 pub fn select_patterns(
285 &self,
286 quantifier: &QuantifiedFormula,
287 manager: &TermManager,
288 ) -> Vec<Vec<TermId>> {
289 match self.config.trigger_selection {
290 TriggerSelection::All => quantifier.patterns.clone(),
291 TriggerSelection::MinVars => self.select_min_vars_patterns(quantifier, manager),
292 TriggerSelection::MinPatterns => self.select_min_patterns(quantifier),
293 TriggerSelection::MaxCoverage => self.select_max_coverage_patterns(quantifier, manager),
294 TriggerSelection::MatchingLoopAvoidance => {
295 self.select_loop_avoiding_patterns(quantifier, manager)
296 }
297 TriggerSelection::UserOnly => quantifier.patterns.clone(),
298 }
299 }
300
301 fn select_min_vars_patterns(
302 &self,
303 quantifier: &QuantifiedFormula,
304 manager: &TermManager,
305 ) -> Vec<Vec<TermId>> {
306 if quantifier.patterns.is_empty() {
307 return vec![];
308 }
309
310 let mut patterns_with_vars: Vec<_> = quantifier
311 .patterns
312 .iter()
313 .map(|pattern| {
314 let num_vars = self.count_vars_in_pattern(pattern, manager);
315 (pattern.clone(), num_vars)
316 })
317 .collect();
318
319 patterns_with_vars.sort_by_key(|(_, num_vars)| *num_vars);
320
321 vec![patterns_with_vars[0].0.clone()]
322 }
323
324 fn select_min_patterns(&self, quantifier: &QuantifiedFormula) -> Vec<Vec<TermId>> {
325 if quantifier.patterns.is_empty() {
326 return vec![];
327 }
328
329 let min_pattern = quantifier
331 .patterns
332 .iter()
333 .min_by_key(|pattern| pattern.len())
334 .cloned();
335
336 min_pattern.map_or_else(Vec::new, |p| vec![p])
337 }
338
339 fn select_max_coverage_patterns(
340 &self,
341 quantifier: &QuantifiedFormula,
342 manager: &TermManager,
343 ) -> Vec<Vec<TermId>> {
344 let mut selected = Vec::new();
346 let mut covered_vars: FxHashSet<Spur> = FxHashSet::default();
347
348 for pattern in &quantifier.patterns {
349 let pattern_vars = self.collect_vars_in_pattern(pattern, manager);
350 let new_vars: FxHashSet<_> = pattern_vars.difference(&covered_vars).copied().collect();
351
352 if !new_vars.is_empty() {
353 selected.push(pattern.clone());
354 covered_vars.extend(new_vars);
355 }
356
357 if covered_vars.len() >= quantifier.num_vars() {
358 break;
359 }
360 }
361
362 selected
363 }
364
365 fn select_loop_avoiding_patterns(
366 &self,
367 quantifier: &QuantifiedFormula,
368 manager: &TermManager,
369 ) -> Vec<Vec<TermId>> {
370 quantifier
372 .patterns
373 .iter()
374 .filter(|pattern| !self.contains_quantified_symbol(pattern, quantifier, manager))
375 .cloned()
376 .collect()
377 }
378
379 fn count_vars_in_pattern(&self, pattern: &[TermId], manager: &TermManager) -> usize {
380 self.collect_vars_in_pattern(pattern, manager).len()
381 }
382
383 fn collect_vars_in_pattern(
384 &self,
385 pattern: &[TermId],
386 manager: &TermManager,
387 ) -> FxHashSet<Spur> {
388 let mut vars = FxHashSet::default();
389 let mut visited = FxHashSet::default();
390
391 for &term in pattern {
392 self.collect_vars_rec(term, &mut vars, &mut visited, manager);
393 }
394
395 vars
396 }
397
398 fn collect_vars_rec(
399 &self,
400 term: TermId,
401 vars: &mut FxHashSet<Spur>,
402 visited: &mut FxHashSet<TermId>,
403 manager: &TermManager,
404 ) {
405 if visited.contains(&term) {
406 return;
407 }
408 visited.insert(term);
409
410 let Some(t) = manager.get(term) else {
411 return;
412 };
413
414 if let TermKind::Var(name) = t.kind {
415 vars.insert(name);
416 return;
417 }
418
419 match &t.kind {
420 TermKind::Apply { args, .. } => {
421 for &arg in args.iter() {
422 self.collect_vars_rec(arg, vars, visited, manager);
423 }
424 }
425 TermKind::Not(arg) | TermKind::Neg(arg) => {
426 self.collect_vars_rec(*arg, vars, visited, manager);
427 }
428 _ => {}
429 }
430 }
431
432 fn contains_quantified_symbol(
433 &self,
434 pattern: &[TermId],
435 _quantifier: &QuantifiedFormula,
436 manager: &TermManager,
437 ) -> bool {
438 for &term in pattern {
439 if self.is_function_application(term, manager) {
440 return true;
441 }
442 }
443 false
444 }
445
446 fn is_function_application(&self, term: TermId, manager: &TermManager) -> bool {
447 let Some(t) = manager.get(term) else {
448 return false;
449 };
450 matches!(t.kind, TermKind::Apply { .. })
451 }
452
453 pub fn record_result(&mut self, quantifier: TermId, success: bool) {
455 let entry = self
456 .success_history
457 .entry(quantifier)
458 .or_insert_with(SuccessRate::new);
459 entry.record(success);
460 }
461
462 pub fn success_rate(&self, quantifier: TermId) -> f64 {
464 self.success_history
465 .get(&quantifier)
466 .map_or(0.5, |sr| sr.rate())
467 }
468}
469
/// Success/failure tally for a single quantifier.
#[derive(Debug, Clone)]
struct SuccessRate {
    /// Attempts recorded as successful.
    successes: usize,
    /// Attempts recorded as unsuccessful.
    failures: usize,
}

impl SuccessRate {
    /// Starts with no recorded attempts.
    fn new() -> Self {
        SuccessRate {
            successes: 0,
            failures: 0,
        }
    }

    /// Counts one attempt, either as a success or as a failure.
    fn record(&mut self, success: bool) {
        let counter = if success {
            &mut self.successes
        } else {
            &mut self.failures
        };
        *counter += 1;
    }

    /// Fraction of recorded attempts that succeeded.
    ///
    /// Returns the neutral value `0.5` while no attempt has been recorded.
    fn rate(&self) -> f64 {
        match self.successes + self.failures {
            0 => 0.5,
            total => self.successes as f64 / total as f64,
        }
    }
}
502
503#[derive(Debug)]
505pub struct QuantifierQueue {
506 heap: BinaryHeap<ScoredQuantifier>,
508 heuristic: InstantiationHeuristic,
510}
511
512impl QuantifierQueue {
513 pub fn new(heuristic: InstantiationHeuristic) -> Self {
515 Self {
516 heap: BinaryHeap::new(),
517 heuristic,
518 }
519 }
520
521 pub fn push(
523 &mut self,
524 quantifier: QuantifiedFormula,
525 model: &CompletedModel,
526 manager: &TermManager,
527 ) {
528 let score = self
529 .heuristic
530 .calculate_priority(&quantifier, model, manager);
531 self.heap.push(ScoredQuantifier { quantifier, score });
532 }
533
534 pub fn pop(&mut self) -> Option<QuantifiedFormula> {
536 self.heap.pop().map(|sq| sq.quantifier)
537 }
538
539 pub fn is_empty(&self) -> bool {
541 self.heap.is_empty()
542 }
543
544 pub fn len(&self) -> usize {
546 self.heap.len()
547 }
548}
549
550#[derive(Debug, Clone)]
552struct ScoredQuantifier {
553 quantifier: QuantifiedFormula,
554 score: f64,
555}
556
557impl PartialEq for ScoredQuantifier {
558 fn eq(&self, other: &Self) -> bool {
559 self.score == other.score
560 }
561}
562
563impl Eq for ScoredQuantifier {}
564
565impl PartialOrd for ScoredQuantifier {
566 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
567 Some(self.cmp(other))
568 }
569}
570
571impl Ord for ScoredQuantifier {
572 fn cmp(&self, other: &Self) -> Ordering {
573 self.score
575 .partial_cmp(&other.score)
576 .unwrap_or(Ordering::Equal)
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mbqi_heuristics_creation() {
        let MBQIHeuristics {
            enable_conflict_analysis,
            ..
        } = MBQIHeuristics::new();
        assert!(enable_conflict_analysis);
    }

    #[test]
    fn test_conservative_heuristics() {
        let selection = MBQIHeuristics::conservative().quantifier_selection;
        assert_eq!(selection, SelectionStrategy::MostConstrained);
    }

    #[test]
    fn test_aggressive_heuristics() {
        let selection = MBQIHeuristics::aggressive().quantifier_selection;
        assert_eq!(selection, SelectionStrategy::BreadthFirst);
    }

    #[test]
    fn test_instantiation_heuristic_creation() {
        let heuristic = InstantiationHeuristic::new(MBQIHeuristics::new());
        assert!(heuristic.quantifier_scores.is_empty());
    }

    #[test]
    fn test_success_rate_tracker() {
        let mut tracker = SuccessRate::new();
        assert_eq!(tracker.rate(), 0.5);

        // Each outcome is recorded, then the running rate is checked.
        for (outcome, expected) in [(true, 1.0), (false, 0.5)] {
            tracker.record(outcome);
            assert_eq!(tracker.rate(), expected);
        }
    }

    #[test]
    fn test_quantifier_queue_creation() {
        let queue = QuantifierQueue::new(InstantiationHeuristic::new(MBQIHeuristics::new()));
        assert!(queue.is_empty());
    }

    #[test]
    fn test_selection_strategy_equality() {
        let sequential = SelectionStrategy::Sequential;
        assert_eq!(sequential, SelectionStrategy::Sequential);
        assert_ne!(sequential, SelectionStrategy::Random);
    }

    #[test]
    fn test_trigger_selection_equality() {
        let all = TriggerSelection::All;
        assert_eq!(all, TriggerSelection::All);
        assert_ne!(all, TriggerSelection::MinVars);
    }

    #[test]
    fn test_resource_allocation_equality() {
        let balanced = ResourceAllocation::Balanced;
        assert_eq!(balanced, ResourceAllocation::Balanced);
        assert_ne!(balanced, ResourceAllocation::Aggressive);
    }
}