Skip to main content

oxiz_proof/
heuristic.rs

//! Strategy heuristics learned from successful proofs.
//!
//! This module extracts heuristics and strategies from successful proofs
//! to guide the solver when it tackles future problems.
5
6use crate::proof::{Proof, ProofNodeId, ProofStep};
7use rustc_hash::{FxHashMap, FxHashSet};
8use std::fmt;
9
10/// A heuristic learned from successful proofs.
11#[derive(Debug, Clone, PartialEq)]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13pub struct ProofHeuristic {
14    /// Heuristic name
15    pub name: String,
16    /// Heuristic type
17    pub heuristic_type: HeuristicType,
18    /// Confidence score (0.0 - 1.0)
19    pub confidence: f64,
20    /// Number of proofs supporting this heuristic
21    pub support_count: usize,
22    /// Average improvement when applied
23    pub avg_improvement: f64,
24}
25
/// The categories of guidance a [`ProofHeuristic`] can carry.
///
/// Each variant stores the concrete pattern that was learned for it.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HeuristicType {
    /// Prefer applying inference rules in this order.
    RuleOrdering { preferred_sequence: Vec<String> },
    /// Branch on formulas matching this abstract shape.
    BranchingStrategy { criteria: String },
    /// Favor lemmas of this kind.
    LemmaSelection { pattern: String },
    /// Favor quantifier instantiations triggered by this term pattern.
    InstantiationPreference { trigger_pattern: String },
    /// Combine theories in this order.
    TheoryCombination { theory_order: Vec<String> },
}
41
42impl fmt::Display for HeuristicType {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        match self {
45            HeuristicType::RuleOrdering { preferred_sequence } => {
46                write!(f, "RuleOrdering[{}]", preferred_sequence.join(" → "))
47            }
48            HeuristicType::BranchingStrategy { criteria } => {
49                write!(f, "BranchingStrategy[{}]", criteria)
50            }
51            HeuristicType::LemmaSelection { pattern } => {
52                write!(f, "LemmaSelection[{}]", pattern)
53            }
54            HeuristicType::InstantiationPreference { trigger_pattern } => {
55                write!(f, "InstantiationPreference[{}]", trigger_pattern)
56            }
57            HeuristicType::TheoryCombination { theory_order } => {
58                write!(f, "TheoryCombination[{}]", theory_order.join(" + "))
59            }
60        }
61    }
62}
63
64impl fmt::Display for ProofHeuristic {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        writeln!(f, "Heuristic: {}", self.name)?;
67        writeln!(f, "Type: {}", self.heuristic_type)?;
68        writeln!(f, "Confidence: {:.2}", self.confidence)?;
69        writeln!(f, "Support: {} proofs", self.support_count)?;
70        writeln!(f, "Avg improvement: {:.1}%", self.avg_improvement * 100.0)?;
71        Ok(())
72    }
73}
74
75/// Strategy learner for extracting heuristics from proofs.
76pub struct StrategyLearner {
77    /// Minimum support count for a heuristic
78    min_support: usize,
79    /// Minimum confidence threshold
80    min_confidence: f64,
81    /// Learned heuristics
82    heuristics: Vec<ProofHeuristic>,
83    /// Rule sequence frequency tracker
84    rule_sequences: FxHashMap<Vec<String>, usize>,
85}
86
87impl Default for StrategyLearner {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl StrategyLearner {
94    /// Create a new strategy learner with default settings.
95    pub fn new() -> Self {
96        Self {
97            min_support: 2,
98            min_confidence: 0.5,
99            heuristics: Vec::new(),
100            rule_sequences: FxHashMap::default(),
101        }
102    }
103
104    /// Set the minimum support count.
105    pub fn with_min_support(mut self, support: usize) -> Self {
106        self.min_support = support;
107        self
108    }
109
110    /// Set the minimum confidence threshold.
111    pub fn with_min_confidence(mut self, confidence: f64) -> Self {
112        self.min_confidence = confidence.clamp(0.0, 1.0);
113        self
114    }
115
116    /// Learn heuristics from a collection of successful proofs.
117    pub fn learn_from_proofs(&mut self, proofs: &[&Proof], _proof_stats: &[(f64, f64)]) {
118        // Extract rule ordering heuristics
119        self.learn_rule_ordering(proofs);
120
121        // Extract branching heuristics
122        self.learn_branching_strategies(proofs);
123
124        // Extract lemma selection heuristics
125        self.learn_lemma_selection(proofs);
126
127        // Extract instantiation heuristics
128        self.learn_instantiation_preferences(proofs);
129
130        // Extract theory combination heuristics
131        self.learn_theory_combination(proofs);
132
133        // Filter by confidence and support
134        self.heuristics
135            .retain(|h| h.confidence >= self.min_confidence && h.support_count >= self.min_support);
136
137        // Sort by confidence
138        self.heuristics.sort_by(|a, b| {
139            b.confidence
140                .partial_cmp(&a.confidence)
141                .unwrap_or(std::cmp::Ordering::Equal)
142        });
143    }
144
145    /// Get all learned heuristics.
146    pub fn get_heuristics(&self) -> &[ProofHeuristic] {
147        &self.heuristics
148    }
149
150    /// Get heuristics of a specific type.
151    pub fn get_heuristics_by_type(&self, type_name: &str) -> Vec<&ProofHeuristic> {
152        self.heuristics
153            .iter()
154            .filter(|h| {
155                matches!(
156                    (&h.heuristic_type, type_name),
157                    (HeuristicType::RuleOrdering { .. }, "rule_ordering")
158                        | (HeuristicType::BranchingStrategy { .. }, "branching")
159                        | (HeuristicType::LemmaSelection { .. }, "lemma")
160                        | (
161                            HeuristicType::InstantiationPreference { .. },
162                            "instantiation"
163                        )
164                        | (HeuristicType::TheoryCombination { .. }, "theory")
165                )
166            })
167            .collect()
168    }
169
170    /// Get top N heuristics by confidence.
171    pub fn get_top_heuristics(&self, n: usize) -> Vec<&ProofHeuristic> {
172        self.heuristics.iter().take(n).collect()
173    }
174
175    /// Clear all learned heuristics.
176    pub fn clear(&mut self) {
177        self.heuristics.clear();
178        self.rule_sequences.clear();
179    }
180
181    // Helper: Learn rule ordering preferences
182    fn learn_rule_ordering(&mut self, proofs: &[&Proof]) {
183        let mut sequence_freq: FxHashMap<Vec<String>, usize> = FxHashMap::default();
184
185        for proof in proofs {
186            let sequences = self.extract_rule_sequences(proof, 3);
187            for seq in sequences {
188                *sequence_freq.entry(seq).or_insert(0) += 1;
189            }
190        }
191
192        // Create heuristics for frequent sequences
193        for (seq, count) in sequence_freq.iter() {
194            if *count >= self.min_support {
195                let confidence = (*count as f64) / (proofs.len() as f64);
196                if confidence >= self.min_confidence {
197                    self.heuristics.push(ProofHeuristic {
198                        name: format!("rule_order_{}", seq.join("_")),
199                        heuristic_type: HeuristicType::RuleOrdering {
200                            preferred_sequence: seq.clone(),
201                        },
202                        confidence,
203                        support_count: *count,
204                        avg_improvement: 0.0,
205                    });
206                }
207            }
208        }
209    }
210
211    // Helper: Extract rule sequences from a proof
212    fn extract_rule_sequences(&self, proof: &Proof, length: usize) -> Vec<Vec<String>> {
213        let mut sequences = Vec::new();
214        let nodes: Vec<ProofNodeId> = proof.nodes().iter().map(|n| n.id).collect();
215
216        if nodes.len() < length {
217            return sequences;
218        }
219
220        for window in nodes.windows(length) {
221            let seq: Vec<String> = window
222                .iter()
223                .filter_map(|&id| {
224                    proof.get_node(id).and_then(|node| {
225                        if let ProofStep::Inference { rule, .. } = &node.step {
226                            Some(rule.clone())
227                        } else {
228                            None
229                        }
230                    })
231                })
232                .collect();
233
234            if seq.len() == length {
235                sequences.push(seq);
236            }
237        }
238
239        sequences
240    }
241
242    // Helper: Learn branching strategies
243    fn learn_branching_strategies(&mut self, proofs: &[&Proof]) {
244        let mut branching_patterns: FxHashMap<String, usize> = FxHashMap::default();
245
246        for proof in proofs {
247            // Look for nodes with multiple dependents (branching points)
248            for node in proof.nodes() {
249                let dependents = proof.get_children(node.id);
250                if dependents.len() > 1 {
251                    // This is a branching point
252                    let pattern = self.abstract_branching_pattern(node.conclusion());
253                    *branching_patterns.entry(pattern).or_insert(0) += 1;
254                }
255            }
256        }
257
258        // Create heuristics for common branching patterns
259        for (pattern, count) in branching_patterns.iter() {
260            if *count >= self.min_support {
261                let confidence = (*count as f64) / (proofs.len() as f64);
262                if confidence >= self.min_confidence {
263                    self.heuristics.push(ProofHeuristic {
264                        name: format!("branch_{}", pattern),
265                        heuristic_type: HeuristicType::BranchingStrategy {
266                            criteria: pattern.clone(),
267                        },
268                        confidence,
269                        support_count: *count,
270                        avg_improvement: 0.0,
271                    });
272                }
273            }
274        }
275    }
276
277    // Helper: Abstract branching pattern
278    fn abstract_branching_pattern(&self, conclusion: &str) -> String {
279        // Simple abstraction - in practice would be more sophisticated
280        // Check more specific patterns first
281        if conclusion.contains("forall") {
282            "universal".to_string()
283        } else if conclusion.contains("exists") {
284            "existential".to_string()
285        } else if conclusion.contains(" or ") {
286            "disjunction".to_string()
287        } else if conclusion.contains(" and ") {
288            "conjunction".to_string()
289        } else {
290            "other".to_string()
291        }
292    }
293
294    // Helper: Learn lemma selection patterns
295    fn learn_lemma_selection(&mut self, proofs: &[&Proof]) {
296        let mut lemma_patterns: FxHashMap<String, usize> = FxHashMap::default();
297
298        for proof in proofs {
299            for node in proof.nodes() {
300                if let ProofStep::Inference { rule, .. } = &node.step
301                    && (rule.contains("lemma") || rule.contains("theory"))
302                {
303                    let pattern = self.extract_lemma_pattern(node.conclusion());
304                    *lemma_patterns.entry(pattern).or_insert(0) += 1;
305                }
306            }
307        }
308
309        for (pattern, count) in lemma_patterns.iter() {
310            if *count >= self.min_support {
311                let confidence = (*count as f64) / (proofs.len() as f64);
312                if confidence >= self.min_confidence {
313                    self.heuristics.push(ProofHeuristic {
314                        name: format!("lemma_{}", pattern),
315                        heuristic_type: HeuristicType::LemmaSelection {
316                            pattern: pattern.clone(),
317                        },
318                        confidence,
319                        support_count: *count,
320                        avg_improvement: 0.0,
321                    });
322                }
323            }
324        }
325    }
326
327    // Helper: Extract lemma pattern
328    fn extract_lemma_pattern(&self, conclusion: &str) -> String {
329        // Extract the type of lemma (simplified)
330        // Check more specific patterns first
331        if conclusion.contains("congruence") {
332            "congruence".to_string()
333        } else if conclusion.contains("<=") || conclusion.contains(">=") {
334            "inequality".to_string()
335        } else if conclusion.contains("=") {
336            "equality".to_string()
337        } else {
338            "other".to_string()
339        }
340    }
341
342    // Helper: Learn instantiation preferences
343    fn learn_instantiation_preferences(&mut self, proofs: &[&Proof]) {
344        let mut instantiation_patterns: FxHashMap<String, usize> = FxHashMap::default();
345
346        for proof in proofs {
347            for node in proof.nodes() {
348                if let ProofStep::Inference { rule, .. } = &node.step
349                    && (rule.contains("instantiation") || rule.contains("forall_elim"))
350                {
351                    let pattern = self.extract_trigger_pattern(node.conclusion());
352                    *instantiation_patterns.entry(pattern).or_insert(0) += 1;
353                }
354            }
355        }
356
357        for (pattern, count) in instantiation_patterns.iter() {
358            if *count >= self.min_support {
359                let confidence = (*count as f64) / (proofs.len() as f64);
360                if confidence >= self.min_confidence {
361                    self.heuristics.push(ProofHeuristic {
362                        name: format!("inst_{}", pattern),
363                        heuristic_type: HeuristicType::InstantiationPreference {
364                            trigger_pattern: pattern.clone(),
365                        },
366                        confidence,
367                        support_count: *count,
368                        avg_improvement: 0.0,
369                    });
370                }
371            }
372        }
373    }
374
375    // Helper: Extract trigger pattern
376    fn extract_trigger_pattern(&self, conclusion: &str) -> String {
377        // Extract function applications as triggers (simplified)
378        if let Some(start) = conclusion.find('(')
379            && let Some(end) = conclusion[start..].find(')')
380        {
381            return conclusion[..start + end + 1].to_string();
382        }
383        "default".to_string()
384    }
385
386    // Helper: Learn theory combination strategies
387    fn learn_theory_combination(&mut self, proofs: &[&Proof]) {
388        let mut theory_sequences: FxHashMap<Vec<String>, usize> = FxHashMap::default();
389
390        for proof in proofs {
391            let theories = self.extract_theory_sequence(proof);
392            if !theories.is_empty() {
393                *theory_sequences.entry(theories).or_insert(0) += 1;
394            }
395        }
396
397        for (seq, count) in theory_sequences.iter() {
398            if *count >= self.min_support {
399                let confidence = (*count as f64) / (proofs.len() as f64);
400                if confidence >= self.min_confidence {
401                    self.heuristics.push(ProofHeuristic {
402                        name: format!("theory_comb_{}", seq.join("_")),
403                        heuristic_type: HeuristicType::TheoryCombination {
404                            theory_order: seq.clone(),
405                        },
406                        confidence,
407                        support_count: *count,
408                        avg_improvement: 0.0,
409                    });
410                }
411            }
412        }
413    }
414
415    // Helper: Extract theory sequence from proof
416    fn extract_theory_sequence(&self, proof: &Proof) -> Vec<String> {
417        let mut seen = FxHashSet::default();
418        let mut sequence = Vec::new();
419
420        for node in proof.nodes() {
421            if let ProofStep::Inference { rule, .. } = &node.step {
422                let theory = self.infer_theory_from_rule(rule);
423                if !theory.is_empty() && !seen.contains(&theory) {
424                    seen.insert(theory.clone());
425                    sequence.push(theory);
426                }
427            }
428        }
429
430        sequence
431    }
432
433    // Helper: Infer theory from rule name
434    fn infer_theory_from_rule(&self, rule: &str) -> String {
435        if rule.contains("arith") || rule.contains("farkas") {
436            "arithmetic".to_string()
437        } else if rule.contains("euf") || rule.contains("congruence") {
438            "euf".to_string()
439        } else if rule.contains("array") {
440            "arrays".to_string()
441        } else if rule.contains("bv") || rule.contains("bitvector") {
442            "bitvectors".to_string()
443        } else {
444            String::new()
445        }
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a rule-ordering heuristic with an empty sequence, for tests that
    /// only care about names and confidence.
    fn placeholder(name: &str, confidence: f64) -> ProofHeuristic {
        ProofHeuristic {
            name: name.to_string(),
            heuristic_type: HeuristicType::RuleOrdering {
                preferred_sequence: vec![],
            },
            confidence,
            support_count: 2,
            avg_improvement: 0.0,
        }
    }

    #[test]
    fn test_strategy_learner_new() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.min_support, 2);
        assert_eq!(learner.min_confidence, 0.5);
        assert!(learner.heuristics.is_empty());
    }

    #[test]
    fn test_strategy_learner_with_settings() {
        let learner = StrategyLearner::new()
            .with_min_support(3)
            .with_min_confidence(0.7);
        assert_eq!(learner.min_support, 3);
        assert_eq!(learner.min_confidence, 0.7);
    }

    #[test]
    fn test_with_min_confidence_clamps() {
        assert_eq!(
            StrategyLearner::new().with_min_confidence(1.5).min_confidence,
            1.0
        );
        assert_eq!(
            StrategyLearner::new()
                .with_min_confidence(-0.25)
                .min_confidence,
            0.0
        );
    }

    #[test]
    fn test_heuristic_type_display() {
        let rule_ordering = HeuristicType::RuleOrdering {
            preferred_sequence: vec!["resolution".to_string(), "unit_prop".to_string()],
        };
        assert_eq!(
            rule_ordering.to_string(),
            "RuleOrdering[resolution → unit_prop]"
        );

        let branching = HeuristicType::BranchingStrategy {
            criteria: "disjunction".to_string(),
        };
        assert_eq!(branching.to_string(), "BranchingStrategy[disjunction]");

        let lemma = HeuristicType::LemmaSelection {
            pattern: "equality".to_string(),
        };
        assert_eq!(lemma.to_string(), "LemmaSelection[equality]");

        let inst = HeuristicType::InstantiationPreference {
            trigger_pattern: "f(x)".to_string(),
        };
        assert_eq!(inst.to_string(), "InstantiationPreference[f(x)]");

        let theory = HeuristicType::TheoryCombination {
            theory_order: vec!["arithmetic".to_string(), "euf".to_string()],
        };
        assert_eq!(theory.to_string(), "TheoryCombination[arithmetic + euf]");
    }

    #[test]
    fn test_proof_heuristic_display() {
        let heuristic = ProofHeuristic {
            name: "test_heuristic".to_string(),
            heuristic_type: HeuristicType::RuleOrdering {
                preferred_sequence: vec!["resolution".to_string()],
            },
            confidence: 0.8,
            support_count: 10,
            avg_improvement: 0.15,
        };
        let display = format!("{}", heuristic);
        assert!(display.contains("test_heuristic"));
        assert!(display.contains("0.80"));
        assert!(display.contains("10 proofs"));
        assert!(display.contains("Avg improvement: 15.0%"));
    }

    #[test]
    fn test_clear_heuristics() {
        let mut learner = StrategyLearner::new();
        learner.heuristics.push(placeholder("test", 0.5));
        learner.clear();
        assert!(learner.heuristics.is_empty());
        assert!(learner.rule_sequences.is_empty());
    }

    #[test]
    fn test_get_top_heuristics() {
        let mut learner = StrategyLearner::new();
        learner.heuristics.push(placeholder("h1", 0.9));
        learner.heuristics.push(placeholder("h2", 0.7));
        let top = learner.get_top_heuristics(1);
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].name, "h1");
    }

    #[test]
    fn test_get_heuristics_by_type() {
        let mut learner = StrategyLearner::new();
        learner.heuristics.push(placeholder("order", 0.9));
        learner.heuristics.push(ProofHeuristic {
            name: "branch".to_string(),
            heuristic_type: HeuristicType::BranchingStrategy {
                criteria: "disjunction".to_string(),
            },
            confidence: 0.8,
            support_count: 3,
            avg_improvement: 0.0,
        });

        let ordering = learner.get_heuristics_by_type("rule_ordering");
        assert_eq!(ordering.len(), 1);
        assert_eq!(ordering[0].name, "order");

        let branching = learner.get_heuristics_by_type("branching");
        assert_eq!(branching.len(), 1);
        assert_eq!(branching[0].name, "branch");

        assert!(learner.get_heuristics_by_type("lemma").is_empty());
        assert!(learner.get_heuristics_by_type("unknown").is_empty());
    }

    #[test]
    fn test_learn_from_no_proofs() {
        let mut learner = StrategyLearner::new();
        learner.learn_from_proofs(&[], &[]);
        assert!(learner.get_heuristics().is_empty());
    }

    #[test]
    fn test_abstract_branching_pattern() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.abstract_branching_pattern("x or y"), "disjunction");
        assert_eq!(learner.abstract_branching_pattern("x and y"), "conjunction");
        assert_eq!(
            learner.abstract_branching_pattern("forall x. P(x)"),
            "universal"
        );
        assert_eq!(
            learner.abstract_branching_pattern("exists y. P(y)"),
            "existential"
        );
        assert_eq!(learner.abstract_branching_pattern("p"), "other");
    }

    #[test]
    fn test_extract_lemma_pattern() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.extract_lemma_pattern("x = y"), "equality");
        assert_eq!(learner.extract_lemma_pattern("x <= y"), "inequality");
        assert_eq!(learner.extract_lemma_pattern("x >= 1"), "inequality");
        assert_eq!(
            learner.extract_lemma_pattern("congruence f(x) f(y)"),
            "congruence"
        );
        assert_eq!(learner.extract_lemma_pattern("p"), "other");
    }

    #[test]
    fn test_extract_trigger_pattern() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.extract_trigger_pattern("f(x) = y"), "f(x)");
        assert_eq!(learner.extract_trigger_pattern("x = y"), "default");
    }

    #[test]
    fn test_infer_theory_from_rule() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.infer_theory_from_rule("arith_lemma"), "arithmetic");
        assert_eq!(learner.infer_theory_from_rule("euf_congruence"), "euf");
        assert_eq!(
            learner.infer_theory_from_rule("array_extensionality"),
            "arrays"
        );
        assert_eq!(learner.infer_theory_from_rule("bv_solve"), "bitvectors");
        assert_eq!(
            learner.infer_theory_from_rule("bitvector_extract"),
            "bitvectors"
        );
        assert_eq!(learner.infer_theory_from_rule("resolution"), "");
    }
}