// Source file: oxiz_solver/mbqi/heuristics.rs

//! MBQI Heuristics and Selection Strategies
//!
//! This module implements heuristics that guide model-based quantifier instantiation (MBQI), including:
//! - Quantifier selection strategies
//! - Trigger/pattern selection and multi-trigger ranking
//! - Instantiation ordering
//! - Resource allocation (instantiation budgets)
8
9#[allow(unused_imports)]
10use crate::prelude::*;
11use core::cmp::Ordering;
12use oxiz_core::ast::{TermId, TermKind, TermManager};
13use oxiz_core::interner::Spur;
14
15use super::model_completion::CompletedModel;
16use super::{ConflictScores, QuantifiedFormula, QuantifierId};
17
18/// Overall MBQI heuristics configuration
19#[derive(Debug, Clone)]
20pub struct MBQIHeuristics {
21    /// Quantifier selection strategy
22    pub quantifier_selection: SelectionStrategy,
23    /// Trigger selection strategy
24    pub trigger_selection: TriggerSelection,
25    /// Instantiation ordering
26    pub instantiation_ordering: InstantiationOrdering,
27    /// Resource allocation strategy
28    pub resource_allocation: ResourceAllocation,
29    /// Enable conflict analysis
30    pub enable_conflict_analysis: bool,
31    /// Enable model-based bounds
32    pub enable_model_bounds: bool,
33}
34
35impl MBQIHeuristics {
36    /// Create default heuristics
37    pub fn new() -> Self {
38        Self {
39            quantifier_selection: SelectionStrategy::PriorityBased,
40            trigger_selection: TriggerSelection::MatchingLoopAvoidance,
41            instantiation_ordering: InstantiationOrdering::CostBased,
42            resource_allocation: ResourceAllocation::Balanced,
43            enable_conflict_analysis: true,
44            enable_model_bounds: true,
45        }
46    }
47
48    /// Create conservative heuristics (fewer instantiations)
49    pub fn conservative() -> Self {
50        Self {
51            quantifier_selection: SelectionStrategy::MostConstrained,
52            trigger_selection: TriggerSelection::MinPatterns,
53            instantiation_ordering: InstantiationOrdering::DepthFirst,
54            resource_allocation: ResourceAllocation::Conservative,
55            enable_conflict_analysis: true,
56            enable_model_bounds: true,
57        }
58    }
59
60    /// Create aggressive heuristics (more instantiations)
61    pub fn aggressive() -> Self {
62        Self {
63            quantifier_selection: SelectionStrategy::BreadthFirst,
64            trigger_selection: TriggerSelection::MaxCoverage,
65            instantiation_ordering: InstantiationOrdering::BreadthFirst,
66            resource_allocation: ResourceAllocation::Aggressive,
67            enable_conflict_analysis: false,
68            enable_model_bounds: false,
69        }
70    }
71}
72
73impl Default for MBQIHeuristics {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
/// Policy used to decide which quantifier MBQI instantiates next.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SelectionStrategy {
    /// Intended to follow definition order; no quantifier is preferred over another.
    Sequential,
    /// Rank quantifiers by a combined priority score.
    PriorityBased,
    /// Rotate through quantifiers, favouring the least-instantiated ones.
    BreadthFirst,
    /// Keep working on a quantifier before moving to the next one.
    DepthFirst,
    /// Prefer the most constrained quantifiers.
    MostConstrained,
    /// Prefer the least constrained quantifiers.
    LeastConstrained,
    /// Pseudo-random, but deterministic, selection.
    Random,
}
97
/// Policy used to choose which trigger patterns of a quantifier are used.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TriggerSelection {
    /// Keep every available pattern.
    All,
    /// Keep the pattern mentioning the fewest variables.
    MinVars,
    /// Keep the pattern made of the fewest terms.
    MinPatterns,
    /// Keep patterns that together cover the quantifier's variables.
    MaxCoverage,
    /// Drop patterns suspected of causing matching loops.
    MatchingLoopAvoidance,
    /// Only use user-specified patterns.
    UserOnly,
}
114
/// Order in which generated instantiations are considered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InstantiationOrdering {
    /// Cheapest estimated instantiation first.
    CostBased,
    /// Depth-first order.
    DepthFirst,
    /// Breadth-first order.
    BreadthFirst,
    /// Simpler instantiations first.
    SimplestFirst,
    /// Instantiations with more ground terms first.
    GroundnessFirst,
}
129
/// How much instantiation effort MBQI is allowed to spend.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ResourceAllocation {
    /// Spend little: few instantiations.
    Conservative,
    /// Middle ground.
    Balanced,
    /// Spend a lot: many instantiations.
    Aggressive,
    /// Adjust the spend based on progress.
    Adaptive,
}
142
143/// Budget for MBQI instantiations.
144#[derive(Debug, Clone)]
145pub struct MBQIBudget {
146    /// Total budget available for a round.
147    pub global_budget: u32,
148    /// Per-quantifier slices of the total budget.
149    pub per_quantifier: FxHashMap<QuantifierId, u32>,
150    /// Remaining global budget after consumption.
151    pub remaining_global: u32,
152}
153
154impl MBQIBudget {
155    /// Create a fresh budget.
156    pub fn new(global_budget: u32) -> Self {
157        Self {
158            global_budget,
159            per_quantifier: FxHashMap::default(),
160            remaining_global: global_budget,
161        }
162    }
163
164    /// Distribute the remaining budget across quantifiers, weighted by conflict scores.
165    pub fn carve_per_quantifier(
166        &mut self,
167        quantifiers: &[QuantifierId],
168        conflict_scores: Option<&ConflictScores>,
169    ) {
170        self.per_quantifier.clear();
171        self.remaining_global = self.global_budget;
172
173        if quantifiers.is_empty() || self.global_budget == 0 {
174            return;
175        }
176
177        let total_weight: u64 = quantifiers
178            .iter()
179            .map(|qid| {
180                conflict_scores
181                    .and_then(|scores| scores.score(*qid))
182                    .map_or(1_u64, |score| score as u64 + 1)
183            })
184            .sum();
185
186        let mut assigned = 0_u32;
187        for (index, &qid) in quantifiers.iter().enumerate() {
188            let weight = conflict_scores
189                .and_then(|scores| scores.score(qid))
190                .map_or(1_u64, |score| score as u64 + 1);
191            let mut share = ((self.global_budget as u64 * weight) / total_weight) as u32;
192            if share == 0 {
193                share = 1;
194            }
195            if index + 1 == quantifiers.len() {
196                share = self.global_budget.saturating_sub(assigned);
197            }
198            assigned = assigned.saturating_add(share);
199            self.per_quantifier.insert(qid, share);
200        }
201    }
202
203    /// Consume part of the budget for one quantifier.
204    pub fn consume(&mut self, qid: QuantifierId, amount: u32) -> bool {
205        if amount == 0 {
206            return true;
207        }
208        let Some(remaining_for_q) = self.per_quantifier.get_mut(&qid) else {
209            return false;
210        };
211        if *remaining_for_q < amount || self.remaining_global < amount {
212            return false;
213        }
214        *remaining_for_q -= amount;
215        self.remaining_global -= amount;
216        true
217    }
218}
219
220/// Instantiation heuristic scorer
221#[derive(Debug)]
222pub struct InstantiationHeuristic {
223    /// Heuristics configuration
224    config: MBQIHeuristics,
225    /// Quantifier scores
226    quantifier_scores: FxHashMap<TermId, f64>,
227    /// Pattern quality scores
228    pattern_scores: FxHashMap<TermId, f64>,
229    /// Historical success rates
230    success_history: FxHashMap<TermId, SuccessRate>,
231}
232
233impl InstantiationHeuristic {
234    /// Create a new heuristic
235    pub fn new(config: MBQIHeuristics) -> Self {
236        Self {
237            config,
238            quantifier_scores: FxHashMap::default(),
239            pattern_scores: FxHashMap::default(),
240            success_history: FxHashMap::default(),
241        }
242    }
243
244    /// Calculate priority score for a quantifier
245    pub fn calculate_priority(
246        &mut self,
247        quantifier: &QuantifiedFormula,
248        model: &CompletedModel,
249        manager: &TermManager,
250    ) -> f64 {
251        // Check cache
252        if let Some(&cached) = self.quantifier_scores.get(&quantifier.term) {
253            return cached;
254        }
255
256        let score = match self.config.quantifier_selection {
257            SelectionStrategy::Sequential => 1.0,
258            SelectionStrategy::PriorityBased => self.priority_based_score(quantifier, manager),
259            SelectionStrategy::BreadthFirst => 1.0 / (1.0 + quantifier.instantiation_count as f64),
260            SelectionStrategy::DepthFirst => quantifier.instantiation_count as f64,
261            SelectionStrategy::MostConstrained => self.constraint_score(quantifier, model, manager),
262            SelectionStrategy::LeastConstrained => {
263                -self.constraint_score(quantifier, model, manager)
264            }
265            SelectionStrategy::Random => {
266                // Use deterministic pseudo-random based on term ID
267                let hash = quantifier.term.raw() as u64;
268                ((hash.wrapping_mul(2654435761) >> 32) as f64) / (u32::MAX as f64)
269            }
270        };
271
272        self.quantifier_scores.insert(quantifier.term, score);
273        score
274    }
275
276    /// Calculate priority-based score
277    fn priority_based_score(&self, quantifier: &QuantifiedFormula, manager: &TermManager) -> f64 {
278        // Combine multiple factors
279        let weight_factor = quantifier.weight;
280        let inst_factor = 1.0 / (1.0 + quantifier.instantiation_count as f64);
281        let depth_factor = 1.0 / (1.0 + quantifier.nesting_depth as f64);
282        let complexity_factor = 1.0 / (1.0 + self.body_complexity(quantifier.body, manager) as f64);
283
284        weight_factor * inst_factor * depth_factor * complexity_factor
285    }
286
287    /// Calculate constraint score (higher = more constrained)
288    fn constraint_score(
289        &self,
290        quantifier: &QuantifiedFormula,
291        model: &CompletedModel,
292        manager: &TermManager,
293    ) -> f64 {
294        let mut score = 0.0;
295
296        // Count available candidates for each variable
297        for &(_name, sort) in &quantifier.bound_vars {
298            let num_candidates = model.universe(sort).map_or(0, |u| u.len());
299            if num_candidates > 0 {
300                score += 1.0 / num_candidates as f64;
301            } else {
302                score += 1.0;
303            }
304        }
305
306        // Add complexity penalty
307        let complexity = self.body_complexity(quantifier.body, manager);
308        score += complexity as f64 * 0.1;
309
310        score
311    }
312
313    /// Calculate body complexity
314    fn body_complexity(&self, term: TermId, manager: &TermManager) -> usize {
315        let mut visited = FxHashSet::default();
316        self.body_complexity_rec(term, manager, &mut visited)
317    }
318
319    fn body_complexity_rec(
320        &self,
321        term: TermId,
322        manager: &TermManager,
323        visited: &mut FxHashSet<TermId>,
324    ) -> usize {
325        if visited.contains(&term) {
326            return 0;
327        }
328        visited.insert(term);
329
330        let Some(t) = manager.get(term) else {
331            return 1;
332        };
333
334        let children_complexity = match &t.kind {
335            TermKind::And(args) | TermKind::Or(args) => args
336                .iter()
337                .map(|&arg| self.body_complexity_rec(arg, manager, visited))
338                .sum(),
339            TermKind::Not(arg) | TermKind::Neg(arg) => {
340                self.body_complexity_rec(*arg, manager, visited)
341            }
342            TermKind::Eq(lhs, rhs)
343            | TermKind::Lt(lhs, rhs)
344            | TermKind::Le(lhs, rhs)
345            | TermKind::Gt(lhs, rhs)
346            | TermKind::Ge(lhs, rhs) => {
347                self.body_complexity_rec(*lhs, manager, visited)
348                    + self.body_complexity_rec(*rhs, manager, visited)
349            }
350            TermKind::Apply { args, .. } => args
351                .iter()
352                .map(|&arg| self.body_complexity_rec(arg, manager, visited))
353                .sum(),
354            _ => 0,
355        };
356
357        1 + children_complexity
358    }
359
360    /// Select patterns for a quantifier
361    pub fn select_patterns(
362        &self,
363        quantifier: &QuantifiedFormula,
364        manager: &TermManager,
365    ) -> Vec<Vec<TermId>> {
366        match self.config.trigger_selection {
367            TriggerSelection::All => quantifier.patterns.clone(),
368            TriggerSelection::MinVars => self.select_min_vars_patterns(quantifier, manager),
369            TriggerSelection::MinPatterns => self.select_min_patterns(quantifier),
370            TriggerSelection::MaxCoverage => self.select_max_coverage_patterns(quantifier, manager),
371            TriggerSelection::MatchingLoopAvoidance => {
372                self.select_loop_avoiding_patterns(quantifier, manager)
373            }
374            TriggerSelection::UserOnly => quantifier.patterns.clone(),
375        }
376    }
377
378    fn select_min_vars_patterns(
379        &self,
380        quantifier: &QuantifiedFormula,
381        manager: &TermManager,
382    ) -> Vec<Vec<TermId>> {
383        if quantifier.patterns.is_empty() {
384            return vec![];
385        }
386
387        let mut patterns_with_vars: Vec<_> = quantifier
388            .patterns
389            .iter()
390            .map(|pattern| {
391                let num_vars = self.count_vars_in_pattern(pattern, manager);
392                (pattern.clone(), num_vars)
393            })
394            .collect();
395
396        patterns_with_vars.sort_by_key(|(_, num_vars)| *num_vars);
397
398        vec![patterns_with_vars[0].0.clone()]
399    }
400
401    fn select_min_patterns(&self, quantifier: &QuantifiedFormula) -> Vec<Vec<TermId>> {
402        if quantifier.patterns.is_empty() {
403            return vec![];
404        }
405
406        // Select the pattern with fewest terms
407        let min_pattern = quantifier
408            .patterns
409            .iter()
410            .min_by_key(|pattern| pattern.len())
411            .cloned();
412
413        min_pattern.map_or_else(Vec::new, |p| vec![p])
414    }
415
416    fn select_max_coverage_patterns(
417        &self,
418        quantifier: &QuantifiedFormula,
419        manager: &TermManager,
420    ) -> Vec<Vec<TermId>> {
421        // Select patterns that together cover all variables
422        let mut selected = Vec::new();
423        let mut covered_vars: FxHashSet<Spur> = FxHashSet::default();
424
425        for pattern in &quantifier.patterns {
426            let pattern_vars = self.collect_vars_in_pattern(pattern, manager);
427            let new_vars: FxHashSet<_> = pattern_vars.difference(&covered_vars).copied().collect();
428
429            if !new_vars.is_empty() {
430                selected.push(pattern.clone());
431                covered_vars.extend(new_vars);
432            }
433
434            if covered_vars.len() >= quantifier.num_vars() {
435                break;
436            }
437        }
438
439        selected
440    }
441
442    fn select_loop_avoiding_patterns(
443        &self,
444        quantifier: &QuantifiedFormula,
445        manager: &TermManager,
446    ) -> Vec<Vec<TermId>> {
447        // Avoid patterns that contain the quantified function symbol
448        quantifier
449            .patterns
450            .iter()
451            .filter(|pattern| !self.contains_quantified_symbol(pattern, quantifier, manager))
452            .cloned()
453            .collect()
454    }
455
456    fn count_vars_in_pattern(&self, pattern: &[TermId], manager: &TermManager) -> usize {
457        self.collect_vars_in_pattern(pattern, manager).len()
458    }
459
460    fn collect_vars_in_pattern(
461        &self,
462        pattern: &[TermId],
463        manager: &TermManager,
464    ) -> FxHashSet<Spur> {
465        let mut vars = FxHashSet::default();
466        let mut visited = FxHashSet::default();
467
468        for &term in pattern {
469            self.collect_vars_rec(term, &mut vars, &mut visited, manager);
470        }
471
472        vars
473    }
474
475    fn collect_vars_rec(
476        &self,
477        term: TermId,
478        vars: &mut FxHashSet<Spur>,
479        visited: &mut FxHashSet<TermId>,
480        manager: &TermManager,
481    ) {
482        if visited.contains(&term) {
483            return;
484        }
485        visited.insert(term);
486
487        let Some(t) = manager.get(term) else {
488            return;
489        };
490
491        if let TermKind::Var(name) = t.kind {
492            vars.insert(name);
493            return;
494        }
495
496        match &t.kind {
497            TermKind::Apply { args, .. } => {
498                for &arg in args.iter() {
499                    self.collect_vars_rec(arg, vars, visited, manager);
500                }
501            }
502            TermKind::Not(arg) | TermKind::Neg(arg) => {
503                self.collect_vars_rec(*arg, vars, visited, manager);
504            }
505            _ => {}
506        }
507    }
508
509    fn contains_quantified_symbol(
510        &self,
511        pattern: &[TermId],
512        _quantifier: &QuantifiedFormula,
513        manager: &TermManager,
514    ) -> bool {
515        for &term in pattern {
516            if self.is_function_application(term, manager) {
517                return true;
518            }
519        }
520        false
521    }
522
523    fn is_function_application(&self, term: TermId, manager: &TermManager) -> bool {
524        let Some(t) = manager.get(term) else {
525            return false;
526        };
527        matches!(t.kind, TermKind::Apply { .. })
528    }
529
530    /// Record success or failure of an instantiation
531    pub fn record_result(&mut self, quantifier: TermId, success: bool) {
532        let entry = self
533            .success_history
534            .entry(quantifier)
535            .or_insert_with(SuccessRate::new);
536        entry.record(success);
537    }
538
539    /// Get success rate for a quantifier
540    pub fn success_rate(&self, quantifier: TermId) -> f64 {
541        self.success_history
542            .get(&quantifier)
543            .map_or(0.5, |sr| sr.rate())
544    }
545}
546
/// Running tally of successful and failed instantiations for one quantifier.
#[derive(Debug, Clone)]
struct SuccessRate {
    /// Number of recorded successes.
    successes: usize,
    /// Number of recorded failures.
    failures: usize,
}

impl SuccessRate {
    /// Start with no observations.
    fn new() -> Self {
        SuccessRate {
            successes: 0,
            failures: 0,
        }
    }

    /// Add one observation to the matching counter.
    fn record(&mut self, success: bool) {
        let counter = if success {
            &mut self.successes
        } else {
            &mut self.failures
        };
        *counter += 1;
    }

    /// Fraction of observations that were successes; 0.5 when nothing has been recorded.
    fn rate(&self) -> f64 {
        match self.successes + self.failures {
            0 => 0.5,
            total => self.successes as f64 / total as f64,
        }
    }
}
579
580/// Priority queue for quantifiers
581#[derive(Debug)]
582pub struct QuantifierQueue {
583    /// Heap of scored quantifiers
584    heap: BinaryHeap<ScoredQuantifier>,
585    /// Heuristic for scoring
586    heuristic: InstantiationHeuristic,
587}
588
589impl QuantifierQueue {
590    /// Create a new queue
591    pub fn new(heuristic: InstantiationHeuristic) -> Self {
592        Self {
593            heap: BinaryHeap::new(),
594            heuristic,
595        }
596    }
597
598    /// Add a quantifier to the queue
599    pub fn push(
600        &mut self,
601        quantifier: QuantifiedFormula,
602        model: &CompletedModel,
603        manager: &TermManager,
604    ) {
605        let score = self
606            .heuristic
607            .calculate_priority(&quantifier, model, manager);
608        self.heap.push(ScoredQuantifier { quantifier, score });
609    }
610
611    /// Pop the highest priority quantifier
612    pub fn pop(&mut self) -> Option<QuantifiedFormula> {
613        self.heap.pop().map(|sq| sq.quantifier)
614    }
615
616    /// Check if empty
617    pub fn is_empty(&self) -> bool {
618        self.heap.is_empty()
619    }
620
621    /// Get length
622    pub fn len(&self) -> usize {
623        self.heap.len()
624    }
625}
626
627/// Scored quantifier for priority queue
628#[derive(Debug, Clone)]
629struct ScoredQuantifier {
630    quantifier: QuantifiedFormula,
631    score: f64,
632}
633
634impl PartialEq for ScoredQuantifier {
635    fn eq(&self, other: &Self) -> bool {
636        self.score == other.score
637    }
638}
639
640impl Eq for ScoredQuantifier {}
641
642impl PartialOrd for ScoredQuantifier {
643    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
644        Some(self.cmp(other))
645    }
646}
647
648impl Ord for ScoredQuantifier {
649    fn cmp(&self, other: &Self) -> Ordering {
650        // Higher score = higher priority (max-heap)
651        self.score
652            .partial_cmp(&other.score)
653            .unwrap_or(Ordering::Equal)
654    }
655}
656
/// Formula used by the multi-trigger scorer to rank trigger sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ScorerPolicy {
    /// Prefer smaller trigger sets (the original, simpler behaviour).
    Conservative,
    /// Combine syntactic depth, shared variables, and ground terms into one score.
    Ranked,
}
665
666/// A trigger set (a multi-pattern: a list of terms that together cover all variables)
667#[derive(Debug, Clone)]
668pub struct TriggerSet {
669    /// The pattern terms forming this trigger set
670    pub terms: Vec<TermId>,
671    /// Syntactic depth of the deepest term in this set
672    pub max_depth: usize,
673    /// Number of distinct variables shared across terms
674    pub shared_var_count: usize,
675}
676
677impl TriggerSet {
678    /// Create a new trigger set from pattern terms
679    pub fn new(terms: Vec<TermId>) -> Self {
680        Self {
681            terms,
682            max_depth: 0,
683            shared_var_count: 0,
684        }
685    }
686
687    /// Create with precomputed metrics
688    pub fn with_metrics(terms: Vec<TermId>, max_depth: usize, shared_var_count: usize) -> Self {
689        Self {
690            terms,
691            max_depth,
692            shared_var_count,
693        }
694    }
695}
696
697/// A trigger set annotated with a ranking score
698#[derive(Debug, Clone)]
699pub struct ScoredTriggerSet {
700    /// The trigger set
701    pub triggers: TriggerSet,
702    /// Score (higher is better / higher priority)
703    pub score: f64,
704}
705
706/// Multi-trigger scorer: ranks candidate trigger sets
707#[derive(Debug, Clone)]
708pub struct MultiTriggerScorer {
709    /// Scoring policy
710    pub policy: ScorerPolicy,
711    /// Number of top candidates to return
712    pub top_k: usize,
713}
714
715impl MultiTriggerScorer {
716    /// Create a new scorer
717    pub fn new(policy: ScorerPolicy, top_k: usize) -> Self {
718        Self { policy, top_k }
719    }
720
721    /// Score a collection of trigger-set candidates and return the top-k.
722    ///
723    /// Scoring criteria (Ranked policy):
724    ///   (a) syntactic depth: deeper terms get a lower score
725    ///   (b) shared variable count: more shared variables get a higher score
726    ///   (c) ground-term reachability: trigger sets whose terms appear in
727    ///       the equality graph get a bonus
728    pub fn score_trigger_sets(
729        &self,
730        candidates: &[TriggerSet],
731        manager: &TermManager,
732    ) -> Vec<ScoredTriggerSet> {
733        if candidates.is_empty() {
734            return Vec::new();
735        }
736
737        let mut scored: Vec<ScoredTriggerSet> = candidates
738            .iter()
739            .map(|ts| {
740                let score = self.compute_score(ts, manager);
741                ScoredTriggerSet {
742                    triggers: ts.clone(),
743                    score,
744                }
745            })
746            .collect();
747
748        // Sort descending by score
749        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
750        scored.truncate(self.top_k);
751        scored
752    }
753
754    fn compute_score(&self, ts: &TriggerSet, manager: &TermManager) -> f64 {
755        match self.policy {
756            ScorerPolicy::Conservative => {
757                // Conservative: prefer smaller trigger sets
758                1.0 / (1.0 + ts.terms.len() as f64)
759            }
760            ScorerPolicy::Ranked => self.ranked_score(ts, manager),
761        }
762    }
763
764    fn ranked_score(&self, ts: &TriggerSet, manager: &TermManager) -> f64 {
765        // (a) Depth component: deeper terms are less preferred.
766        //     depth_penalty lowers score as max_depth increases.
767        let depth = if ts.max_depth > 0 {
768            ts.max_depth
769        } else {
770            // Compute depth on the fly if not pre-set
771            ts.terms
772                .iter()
773                .map(|&t| self.term_depth(t, manager, 0, 20))
774                .max()
775                .unwrap_or(0)
776        };
777        let depth_component = 1.0 / (1.0 + depth as f64);
778
779        // (b) Shared variable component: more shared variables → higher score.
780        let shared = if ts.shared_var_count > 0 {
781            ts.shared_var_count
782        } else {
783            // Compute shared vars on the fly
784            self.count_shared_vars(&ts.terms, manager)
785        };
786        let shared_component = 1.0 + shared as f64;
787
788        // (c) Ground-term reachability: terms that are ground (no free vars) are
789        //     reachable in the E-graph and provide stronger triggering.
790        let ground_count = ts
791            .terms
792            .iter()
793            .filter(|&&t| self.is_ground(t, manager))
794            .count();
795        let ground_component = 1.0 + ground_count as f64 * 0.5;
796
797        depth_component * shared_component * ground_component
798    }
799
800    /// Compute term depth (capped at `max_depth_cap` to avoid large recursion)
801    fn term_depth(&self, term: TermId, manager: &TermManager, current: usize, cap: usize) -> usize {
802        if current >= cap {
803            return current;
804        }
805        let Some(t) = manager.get(term) else {
806            return current;
807        };
808        match &t.kind {
809            TermKind::Apply { args, .. } => args
810                .iter()
811                .map(|&a| self.term_depth(a, manager, current + 1, cap))
812                .max()
813                .unwrap_or(current),
814            TermKind::Not(a) | TermKind::Neg(a) => self.term_depth(*a, manager, current + 1, cap),
815            TermKind::And(args) | TermKind::Or(args) => args
816                .iter()
817                .map(|&a| self.term_depth(a, manager, current + 1, cap))
818                .max()
819                .unwrap_or(current),
820            TermKind::Eq(l, r)
821            | TermKind::Lt(l, r)
822            | TermKind::Le(l, r)
823            | TermKind::Gt(l, r)
824            | TermKind::Ge(l, r) => {
825                let ld = self.term_depth(*l, manager, current + 1, cap);
826                let rd = self.term_depth(*r, manager, current + 1, cap);
827                ld.max(rd)
828            }
829            _ => current,
830        }
831    }
832
833    /// Count variables that appear in more than one term (shared across the trigger set)
834    fn count_shared_vars(&self, terms: &[TermId], manager: &TermManager) -> usize {
835        if terms.len() <= 1 {
836            return 0;
837        }
838
839        let var_sets: Vec<FxHashSet<Spur>> = terms
840            .iter()
841            .map(|&t| self.collect_vars(t, manager))
842            .collect();
843
844        let mut frequencies: FxHashMap<Spur, usize> = FxHashMap::default();
845        for vars in &var_sets {
846            for &var in vars {
847                *frequencies.entry(var).or_insert(0) += 1;
848            }
849        }
850
851        frequencies.values().filter(|&&count| count >= 2).count()
852    }
853
854    fn collect_vars(&self, term: TermId, manager: &TermManager) -> FxHashSet<Spur> {
855        let mut vars = FxHashSet::default();
856        let mut visited = FxHashSet::default();
857        self.collect_vars_rec(term, manager, &mut vars, &mut visited);
858        vars
859    }
860
861    fn collect_vars_rec(
862        &self,
863        term: TermId,
864        manager: &TermManager,
865        vars: &mut FxHashSet<Spur>,
866        visited: &mut FxHashSet<TermId>,
867    ) {
868        if !visited.insert(term) {
869            return;
870        }
871        let Some(t) = manager.get(term) else {
872            return;
873        };
874        if let TermKind::Var(name) = t.kind {
875            vars.insert(name);
876            return;
877        }
878        match &t.kind {
879            TermKind::Apply { args, .. } => {
880                for &a in args.iter() {
881                    self.collect_vars_rec(a, manager, vars, visited);
882                }
883            }
884            TermKind::Not(a) | TermKind::Neg(a) => {
885                self.collect_vars_rec(*a, manager, vars, visited);
886            }
887            TermKind::And(args) | TermKind::Or(args) => {
888                for &a in args {
889                    self.collect_vars_rec(a, manager, vars, visited);
890                }
891            }
892            TermKind::Eq(l, r)
893            | TermKind::Lt(l, r)
894            | TermKind::Le(l, r)
895            | TermKind::Gt(l, r)
896            | TermKind::Ge(l, r) => {
897                self.collect_vars_rec(*l, manager, vars, visited);
898                self.collect_vars_rec(*r, manager, vars, visited);
899            }
900            _ => {}
901        }
902    }
903
904    /// Return true if the term contains no free variables
905    fn is_ground(&self, term: TermId, manager: &TermManager) -> bool {
906        self.collect_vars(term, manager).is_empty()
907    }
908}
909
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mbqi_heuristics_creation() {
        let defaults = MBQIHeuristics::new();
        assert!(defaults.enable_conflict_analysis);
    }

    #[test]
    fn test_conservative_heuristics() {
        let profile = MBQIHeuristics::conservative();
        assert!(matches!(
            profile.quantifier_selection,
            SelectionStrategy::MostConstrained
        ));
    }

    #[test]
    fn test_aggressive_heuristics() {
        let profile = MBQIHeuristics::aggressive();
        assert!(matches!(
            profile.quantifier_selection,
            SelectionStrategy::BreadthFirst
        ));
    }

    #[test]
    fn test_instantiation_heuristic_creation() {
        let heuristic = InstantiationHeuristic::new(MBQIHeuristics::new());
        assert!(heuristic.quantifier_scores.is_empty());
    }

    #[test]
    fn test_success_rate_tracker() {
        let mut tracker = SuccessRate::new();
        assert_eq!(tracker.rate(), 0.5);

        for (outcome, expected) in [(true, 1.0), (false, 0.5)] {
            tracker.record(outcome);
            assert_eq!(tracker.rate(), expected);
        }
    }

    #[test]
    fn test_quantifier_queue_creation() {
        let queue = QuantifierQueue::new(InstantiationHeuristic::new(MBQIHeuristics::new()));
        assert!(queue.is_empty());
    }

    #[test]
    fn test_selection_strategy_equality() {
        let sequential = SelectionStrategy::Sequential;
        assert_eq!(sequential, SelectionStrategy::Sequential);
        assert_ne!(sequential, SelectionStrategy::Random);
    }

    #[test]
    fn test_trigger_selection_equality() {
        let all = TriggerSelection::All;
        assert_eq!(all, TriggerSelection::All);
        assert_ne!(all, TriggerSelection::MinVars);
    }

    #[test]
    fn test_resource_allocation_equality() {
        let balanced = ResourceAllocation::Balanced;
        assert_eq!(balanced, ResourceAllocation::Balanced);
        assert_ne!(balanced, ResourceAllocation::Aggressive);
    }
}