Skip to main content

oxiz_solver/mbqi/
heuristics.rs

1//! MBQI Heuristics and Selection Strategies
2//!
3//! This module implements various heuristics for guiding MBQI, including:
4//! - Quantifier selection strategies
5//! - Trigger/pattern selection
6//! - Instantiation ordering
7//! - Resource allocation
8
9use lasso::Spur;
10use oxiz_core::ast::{TermId, TermKind, TermManager};
11use rustc_hash::{FxHashMap, FxHashSet};
12use std::cmp::Ordering;
13use std::collections::BinaryHeap;
14
15use super::QuantifiedFormula;
16use super::model_completion::CompletedModel;
17
18/// Overall MBQI heuristics configuration
19#[derive(Debug, Clone)]
20pub struct MBQIHeuristics {
21    /// Quantifier selection strategy
22    pub quantifier_selection: SelectionStrategy,
23    /// Trigger selection strategy
24    pub trigger_selection: TriggerSelection,
25    /// Instantiation ordering
26    pub instantiation_ordering: InstantiationOrdering,
27    /// Resource allocation strategy
28    pub resource_allocation: ResourceAllocation,
29    /// Enable conflict analysis
30    pub enable_conflict_analysis: bool,
31    /// Enable model-based bounds
32    pub enable_model_bounds: bool,
33}
34
35impl MBQIHeuristics {
36    /// Create default heuristics
37    pub fn new() -> Self {
38        Self {
39            quantifier_selection: SelectionStrategy::PriorityBased,
40            trigger_selection: TriggerSelection::MatchingLoopAvoidance,
41            instantiation_ordering: InstantiationOrdering::CostBased,
42            resource_allocation: ResourceAllocation::Balanced,
43            enable_conflict_analysis: true,
44            enable_model_bounds: true,
45        }
46    }
47
48    /// Create conservative heuristics (fewer instantiations)
49    pub fn conservative() -> Self {
50        Self {
51            quantifier_selection: SelectionStrategy::MostConstrained,
52            trigger_selection: TriggerSelection::MinPatterns,
53            instantiation_ordering: InstantiationOrdering::DepthFirst,
54            resource_allocation: ResourceAllocation::Conservative,
55            enable_conflict_analysis: true,
56            enable_model_bounds: true,
57        }
58    }
59
60    /// Create aggressive heuristics (more instantiations)
61    pub fn aggressive() -> Self {
62        Self {
63            quantifier_selection: SelectionStrategy::BreadthFirst,
64            trigger_selection: TriggerSelection::MaxCoverage,
65            instantiation_ordering: InstantiationOrdering::BreadthFirst,
66            resource_allocation: ResourceAllocation::Aggressive,
67            enable_conflict_analysis: false,
68            enable_model_bounds: false,
69        }
70    }
71}
72
73impl Default for MBQIHeuristics {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
/// Strategy for selecting which quantifiers to instantiate.
///
/// Derives `Hash` (in addition to `Eq`) so strategies can be used as keys in
/// hash maps/sets, e.g. for per-strategy statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SelectionStrategy {
    /// Select in order of definition
    Sequential,
    /// Select based on priority scores
    PriorityBased,
    /// Breadth-first (rotate through quantifiers)
    BreadthFirst,
    /// Depth-first (exhaust one before moving to next)
    DepthFirst,
    /// Most constrained first
    MostConstrained,
    /// Least constrained first
    LeastConstrained,
    /// Random selection
    Random,
}
97
/// Strategy for trigger/pattern selection.
///
/// Derives `Hash` (in addition to `Eq`) so strategies can be used as keys in
/// hash maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriggerSelection {
    /// Use all available patterns
    All,
    /// Use patterns with minimum variables
    MinVars,
    /// Use patterns with minimum terms
    MinPatterns,
    /// Maximize coverage of quantifier body
    MaxCoverage,
    /// Avoid patterns that cause matching loops
    MatchingLoopAvoidance,
    /// User-specified patterns only
    UserOnly,
}
114
/// Ordering for generated instantiations.
///
/// Derives `Hash` (in addition to `Eq`) so orderings can be used as keys in
/// hash maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InstantiationOrdering {
    /// Order by estimated cost
    CostBased,
    /// Depth-first
    DepthFirst,
    /// Breadth-first
    BreadthFirst,
    /// Prefer simpler instantiations
    SimplestFirst,
    /// Prefer instantiations with more ground terms
    GroundnessFirst,
}
129
/// Resource allocation strategy.
///
/// Derives `Hash` (in addition to `Eq`) so strategies can be used as keys in
/// hash maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ResourceAllocation {
    /// Conservative (few instantiations)
    Conservative,
    /// Balanced
    Balanced,
    /// Aggressive (many instantiations)
    Aggressive,
    /// Adaptive (adjust based on progress)
    Adaptive,
}
142
143/// Instantiation heuristic scorer
144#[derive(Debug)]
145pub struct InstantiationHeuristic {
146    /// Heuristics configuration
147    config: MBQIHeuristics,
148    /// Quantifier scores
149    quantifier_scores: FxHashMap<TermId, f64>,
150    /// Pattern quality scores
151    pattern_scores: FxHashMap<TermId, f64>,
152    /// Historical success rates
153    success_history: FxHashMap<TermId, SuccessRate>,
154}
155
156impl InstantiationHeuristic {
157    /// Create a new heuristic
158    pub fn new(config: MBQIHeuristics) -> Self {
159        Self {
160            config,
161            quantifier_scores: FxHashMap::default(),
162            pattern_scores: FxHashMap::default(),
163            success_history: FxHashMap::default(),
164        }
165    }
166
167    /// Calculate priority score for a quantifier
168    pub fn calculate_priority(
169        &mut self,
170        quantifier: &QuantifiedFormula,
171        model: &CompletedModel,
172        manager: &TermManager,
173    ) -> f64 {
174        // Check cache
175        if let Some(&cached) = self.quantifier_scores.get(&quantifier.term) {
176            return cached;
177        }
178
179        let score = match self.config.quantifier_selection {
180            SelectionStrategy::Sequential => 1.0,
181            SelectionStrategy::PriorityBased => self.priority_based_score(quantifier, manager),
182            SelectionStrategy::BreadthFirst => 1.0 / (1.0 + quantifier.instantiation_count as f64),
183            SelectionStrategy::DepthFirst => quantifier.instantiation_count as f64,
184            SelectionStrategy::MostConstrained => self.constraint_score(quantifier, model, manager),
185            SelectionStrategy::LeastConstrained => {
186                -self.constraint_score(quantifier, model, manager)
187            }
188            SelectionStrategy::Random => {
189                // Use deterministic pseudo-random based on term ID
190                let hash = quantifier.term.raw() as u64;
191                ((hash.wrapping_mul(2654435761) >> 32) as f64) / (u32::MAX as f64)
192            }
193        };
194
195        self.quantifier_scores.insert(quantifier.term, score);
196        score
197    }
198
199    /// Calculate priority-based score
200    fn priority_based_score(&self, quantifier: &QuantifiedFormula, manager: &TermManager) -> f64 {
201        // Combine multiple factors
202        let weight_factor = quantifier.weight;
203        let inst_factor = 1.0 / (1.0 + quantifier.instantiation_count as f64);
204        let depth_factor = 1.0 / (1.0 + quantifier.nesting_depth as f64);
205        let complexity_factor = 1.0 / (1.0 + self.body_complexity(quantifier.body, manager) as f64);
206
207        weight_factor * inst_factor * depth_factor * complexity_factor
208    }
209
210    /// Calculate constraint score (higher = more constrained)
211    fn constraint_score(
212        &self,
213        quantifier: &QuantifiedFormula,
214        model: &CompletedModel,
215        manager: &TermManager,
216    ) -> f64 {
217        let mut score = 0.0;
218
219        // Count available candidates for each variable
220        for &(_name, sort) in &quantifier.bound_vars {
221            let num_candidates = model.universe(sort).map_or(0, |u| u.len());
222            if num_candidates > 0 {
223                score += 1.0 / num_candidates as f64;
224            } else {
225                score += 1.0;
226            }
227        }
228
229        // Add complexity penalty
230        let complexity = self.body_complexity(quantifier.body, manager);
231        score += complexity as f64 * 0.1;
232
233        score
234    }
235
236    /// Calculate body complexity
237    fn body_complexity(&self, term: TermId, manager: &TermManager) -> usize {
238        let mut visited = FxHashSet::default();
239        self.body_complexity_rec(term, manager, &mut visited)
240    }
241
242    fn body_complexity_rec(
243        &self,
244        term: TermId,
245        manager: &TermManager,
246        visited: &mut FxHashSet<TermId>,
247    ) -> usize {
248        if visited.contains(&term) {
249            return 0;
250        }
251        visited.insert(term);
252
253        let Some(t) = manager.get(term) else {
254            return 1;
255        };
256
257        let children_complexity = match &t.kind {
258            TermKind::And(args) | TermKind::Or(args) => args
259                .iter()
260                .map(|&arg| self.body_complexity_rec(arg, manager, visited))
261                .sum(),
262            TermKind::Not(arg) | TermKind::Neg(arg) => {
263                self.body_complexity_rec(*arg, manager, visited)
264            }
265            TermKind::Eq(lhs, rhs)
266            | TermKind::Lt(lhs, rhs)
267            | TermKind::Le(lhs, rhs)
268            | TermKind::Gt(lhs, rhs)
269            | TermKind::Ge(lhs, rhs) => {
270                self.body_complexity_rec(*lhs, manager, visited)
271                    + self.body_complexity_rec(*rhs, manager, visited)
272            }
273            TermKind::Apply { args, .. } => args
274                .iter()
275                .map(|&arg| self.body_complexity_rec(arg, manager, visited))
276                .sum(),
277            _ => 0,
278        };
279
280        1 + children_complexity
281    }
282
283    /// Select patterns for a quantifier
284    pub fn select_patterns(
285        &self,
286        quantifier: &QuantifiedFormula,
287        manager: &TermManager,
288    ) -> Vec<Vec<TermId>> {
289        match self.config.trigger_selection {
290            TriggerSelection::All => quantifier.patterns.clone(),
291            TriggerSelection::MinVars => self.select_min_vars_patterns(quantifier, manager),
292            TriggerSelection::MinPatterns => self.select_min_patterns(quantifier),
293            TriggerSelection::MaxCoverage => self.select_max_coverage_patterns(quantifier, manager),
294            TriggerSelection::MatchingLoopAvoidance => {
295                self.select_loop_avoiding_patterns(quantifier, manager)
296            }
297            TriggerSelection::UserOnly => quantifier.patterns.clone(),
298        }
299    }
300
301    fn select_min_vars_patterns(
302        &self,
303        quantifier: &QuantifiedFormula,
304        manager: &TermManager,
305    ) -> Vec<Vec<TermId>> {
306        if quantifier.patterns.is_empty() {
307            return vec![];
308        }
309
310        let mut patterns_with_vars: Vec<_> = quantifier
311            .patterns
312            .iter()
313            .map(|pattern| {
314                let num_vars = self.count_vars_in_pattern(pattern, manager);
315                (pattern.clone(), num_vars)
316            })
317            .collect();
318
319        patterns_with_vars.sort_by_key(|(_, num_vars)| *num_vars);
320
321        vec![patterns_with_vars[0].0.clone()]
322    }
323
324    fn select_min_patterns(&self, quantifier: &QuantifiedFormula) -> Vec<Vec<TermId>> {
325        if quantifier.patterns.is_empty() {
326            return vec![];
327        }
328
329        // Select the pattern with fewest terms
330        let min_pattern = quantifier
331            .patterns
332            .iter()
333            .min_by_key(|pattern| pattern.len())
334            .cloned();
335
336        min_pattern.map_or_else(Vec::new, |p| vec![p])
337    }
338
339    fn select_max_coverage_patterns(
340        &self,
341        quantifier: &QuantifiedFormula,
342        manager: &TermManager,
343    ) -> Vec<Vec<TermId>> {
344        // Select patterns that together cover all variables
345        let mut selected = Vec::new();
346        let mut covered_vars: FxHashSet<Spur> = FxHashSet::default();
347
348        for pattern in &quantifier.patterns {
349            let pattern_vars = self.collect_vars_in_pattern(pattern, manager);
350            let new_vars: FxHashSet<_> = pattern_vars.difference(&covered_vars).copied().collect();
351
352            if !new_vars.is_empty() {
353                selected.push(pattern.clone());
354                covered_vars.extend(new_vars);
355            }
356
357            if covered_vars.len() >= quantifier.num_vars() {
358                break;
359            }
360        }
361
362        selected
363    }
364
365    fn select_loop_avoiding_patterns(
366        &self,
367        quantifier: &QuantifiedFormula,
368        manager: &TermManager,
369    ) -> Vec<Vec<TermId>> {
370        // Avoid patterns that contain the quantified function symbol
371        quantifier
372            .patterns
373            .iter()
374            .filter(|pattern| !self.contains_quantified_symbol(pattern, quantifier, manager))
375            .cloned()
376            .collect()
377    }
378
379    fn count_vars_in_pattern(&self, pattern: &[TermId], manager: &TermManager) -> usize {
380        self.collect_vars_in_pattern(pattern, manager).len()
381    }
382
383    fn collect_vars_in_pattern(
384        &self,
385        pattern: &[TermId],
386        manager: &TermManager,
387    ) -> FxHashSet<Spur> {
388        let mut vars = FxHashSet::default();
389        let mut visited = FxHashSet::default();
390
391        for &term in pattern {
392            self.collect_vars_rec(term, &mut vars, &mut visited, manager);
393        }
394
395        vars
396    }
397
398    fn collect_vars_rec(
399        &self,
400        term: TermId,
401        vars: &mut FxHashSet<Spur>,
402        visited: &mut FxHashSet<TermId>,
403        manager: &TermManager,
404    ) {
405        if visited.contains(&term) {
406            return;
407        }
408        visited.insert(term);
409
410        let Some(t) = manager.get(term) else {
411            return;
412        };
413
414        if let TermKind::Var(name) = t.kind {
415            vars.insert(name);
416            return;
417        }
418
419        match &t.kind {
420            TermKind::Apply { args, .. } => {
421                for &arg in args.iter() {
422                    self.collect_vars_rec(arg, vars, visited, manager);
423                }
424            }
425            TermKind::Not(arg) | TermKind::Neg(arg) => {
426                self.collect_vars_rec(*arg, vars, visited, manager);
427            }
428            _ => {}
429        }
430    }
431
432    fn contains_quantified_symbol(
433        &self,
434        pattern: &[TermId],
435        _quantifier: &QuantifiedFormula,
436        manager: &TermManager,
437    ) -> bool {
438        for &term in pattern {
439            if self.is_function_application(term, manager) {
440                return true;
441            }
442        }
443        false
444    }
445
446    fn is_function_application(&self, term: TermId, manager: &TermManager) -> bool {
447        let Some(t) = manager.get(term) else {
448            return false;
449        };
450        matches!(t.kind, TermKind::Apply { .. })
451    }
452
453    /// Record success or failure of an instantiation
454    pub fn record_result(&mut self, quantifier: TermId, success: bool) {
455        let entry = self
456            .success_history
457            .entry(quantifier)
458            .or_insert_with(SuccessRate::new);
459        entry.record(success);
460    }
461
462    /// Get success rate for a quantifier
463    pub fn success_rate(&self, quantifier: TermId) -> f64 {
464        self.success_history
465            .get(&quantifier)
466            .map_or(0.5, |sr| sr.rate())
467    }
468}
469
/// Running tally of successful and failed instantiations.
#[derive(Debug, Clone)]
struct SuccessRate {
    /// Number of recorded successes
    successes: usize,
    /// Number of recorded failures
    failures: usize,
}

impl SuccessRate {
    /// Empty tally (no successes, no failures).
    fn new() -> Self {
        SuccessRate {
            failures: 0,
            successes: 0,
        }
    }

    /// Bump the success counter when `success` is true, else the failure one.
    fn record(&mut self, success: bool) {
        let counter = if success {
            &mut self.successes
        } else {
            &mut self.failures
        };
        *counter += 1;
    }

    /// Fraction of outcomes that were successes; a neutral 0.5 before any
    /// outcome has been recorded.
    fn rate(&self) -> f64 {
        match self.successes + self.failures {
            0 => 0.5,
            total => self.successes as f64 / total as f64,
        }
    }
}
502
503/// Priority queue for quantifiers
504#[derive(Debug)]
505pub struct QuantifierQueue {
506    /// Heap of scored quantifiers
507    heap: BinaryHeap<ScoredQuantifier>,
508    /// Heuristic for scoring
509    heuristic: InstantiationHeuristic,
510}
511
512impl QuantifierQueue {
513    /// Create a new queue
514    pub fn new(heuristic: InstantiationHeuristic) -> Self {
515        Self {
516            heap: BinaryHeap::new(),
517            heuristic,
518        }
519    }
520
521    /// Add a quantifier to the queue
522    pub fn push(
523        &mut self,
524        quantifier: QuantifiedFormula,
525        model: &CompletedModel,
526        manager: &TermManager,
527    ) {
528        let score = self
529            .heuristic
530            .calculate_priority(&quantifier, model, manager);
531        self.heap.push(ScoredQuantifier { quantifier, score });
532    }
533
534    /// Pop the highest priority quantifier
535    pub fn pop(&mut self) -> Option<QuantifiedFormula> {
536        self.heap.pop().map(|sq| sq.quantifier)
537    }
538
539    /// Check if empty
540    pub fn is_empty(&self) -> bool {
541        self.heap.is_empty()
542    }
543
544    /// Get length
545    pub fn len(&self) -> usize {
546        self.heap.len()
547    }
548}
549
550/// Scored quantifier for priority queue
551#[derive(Debug, Clone)]
552struct ScoredQuantifier {
553    quantifier: QuantifiedFormula,
554    score: f64,
555}
556
557impl PartialEq for ScoredQuantifier {
558    fn eq(&self, other: &Self) -> bool {
559        self.score == other.score
560    }
561}
562
563impl Eq for ScoredQuantifier {}
564
565impl PartialOrd for ScoredQuantifier {
566    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
567        Some(self.cmp(other))
568    }
569}
570
571impl Ord for ScoredQuantifier {
572    fn cmp(&self, other: &Self) -> Ordering {
573        // Higher score = higher priority (max-heap)
574        self.score
575            .partial_cmp(&other.score)
576            .unwrap_or(Ordering::Equal)
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mbqi_heuristics_creation() {
        assert!(MBQIHeuristics::new().enable_conflict_analysis);
    }

    #[test]
    fn test_conservative_heuristics() {
        let selection = MBQIHeuristics::conservative().quantifier_selection;
        assert_eq!(selection, SelectionStrategy::MostConstrained);
    }

    #[test]
    fn test_aggressive_heuristics() {
        let selection = MBQIHeuristics::aggressive().quantifier_selection;
        assert_eq!(selection, SelectionStrategy::BreadthFirst);
    }

    #[test]
    fn test_instantiation_heuristic_creation() {
        let scorer = InstantiationHeuristic::new(MBQIHeuristics::new());
        assert!(scorer.quantifier_scores.is_empty());
    }

    #[test]
    fn test_success_rate_tracker() {
        let mut tracker = SuccessRate::new();
        assert_eq!(tracker.rate(), 0.5);

        // Each outcome is recorded, then the running rate is checked
        for (outcome, expected) in [(true, 1.0), (false, 0.5)] {
            tracker.record(outcome);
            assert_eq!(tracker.rate(), expected);
        }
    }

    #[test]
    fn test_quantifier_queue_creation() {
        let scorer = InstantiationHeuristic::new(MBQIHeuristics::new());
        assert!(QuantifierQueue::new(scorer).is_empty());
    }

    #[test]
    fn test_selection_strategy_equality() {
        let sequential = SelectionStrategy::Sequential;
        assert_eq!(sequential, SelectionStrategy::Sequential);
        assert_ne!(sequential, SelectionStrategy::Random);
    }

    #[test]
    fn test_trigger_selection_equality() {
        let all = TriggerSelection::All;
        assert_eq!(all, TriggerSelection::All);
        assert_ne!(all, TriggerSelection::MinVars);
    }

    #[test]
    fn test_resource_allocation_equality() {
        let balanced = ResourceAllocation::Balanced;
        assert_eq!(balanced, ResourceAllocation::Balanced);
        assert_ne!(balanced, ResourceAllocation::Aggressive);
    }
}