1use crate::proof::{Proof, ProofNodeId, ProofStep};
7use rustc_hash::{FxHashMap, FxHashSet};
8use std::fmt;
9
/// A strategy hint mined from a set of proofs, together with the statistics
/// that justified keeping it.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProofHeuristic {
    /// Identifier built from a category prefix and the mined pattern
    /// (e.g. `rule_order_a_b_c`, `branch_disjunction` as produced by `StrategyLearner`).
    pub name: String,
    /// The kind of hint and the pattern it carries.
    pub heuristic_type: HeuristicType,
    /// Support count divided by the number of proofs analysed; `Display`
    /// prints it with two decimals.
    pub confidence: f64,
    /// Number of supporting observations counted by the learner
    /// (`Display` renders it as "N proofs").
    pub support_count: usize,
    /// Average improvement as a fraction; `Display` multiplies it by 100 and
    /// shows a percentage. `StrategyLearner` in this file always sets `0.0`.
    pub avg_improvement: f64,
}
25
/// The category of a [`ProofHeuristic`] and the pattern it recommends.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HeuristicType {
    /// Inference-rule names observed on consecutive nodes of a proof, in order.
    RuleOrdering { preferred_sequence: Vec<String> },
    /// Coarse shape of the conclusion at nodes with more than one child
    /// (as produced by `StrategyLearner`: "universal", "existential",
    /// "disjunction", "conjunction" or "other").
    BranchingStrategy { criteria: String },
    /// Class of conclusion ("congruence", "inequality", "equality", "other")
    /// derived by rules whose name mentions "lemma" or "theory".
    LemmaSelection { pattern: String },
    /// Conclusion prefix of an instantiation step, used as a trigger pattern.
    InstantiationPreference { trigger_pattern: String },
    /// Theories in the order they are first used within a proof.
    TheoryCombination { theory_order: Vec<String> },
}
41
42impl fmt::Display for HeuristicType {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 match self {
45 HeuristicType::RuleOrdering { preferred_sequence } => {
46 write!(f, "RuleOrdering[{}]", preferred_sequence.join(" → "))
47 }
48 HeuristicType::BranchingStrategy { criteria } => {
49 write!(f, "BranchingStrategy[{}]", criteria)
50 }
51 HeuristicType::LemmaSelection { pattern } => {
52 write!(f, "LemmaSelection[{}]", pattern)
53 }
54 HeuristicType::InstantiationPreference { trigger_pattern } => {
55 write!(f, "InstantiationPreference[{}]", trigger_pattern)
56 }
57 HeuristicType::TheoryCombination { theory_order } => {
58 write!(f, "TheoryCombination[{}]", theory_order.join(" + "))
59 }
60 }
61 }
62}
63
64impl fmt::Display for ProofHeuristic {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 writeln!(f, "Heuristic: {}", self.name)?;
67 writeln!(f, "Type: {}", self.heuristic_type)?;
68 writeln!(f, "Confidence: {:.2}", self.confidence)?;
69 writeln!(f, "Support: {} proofs", self.support_count)?;
70 writeln!(f, "Avg improvement: {:.1}%", self.avg_improvement * 100.0)?;
71 Ok(())
72 }
73}
74
75pub struct StrategyLearner {
77 min_support: usize,
79 min_confidence: f64,
81 heuristics: Vec<ProofHeuristic>,
83 rule_sequences: FxHashMap<Vec<String>, usize>,
85}
86
87impl Default for StrategyLearner {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl StrategyLearner {
94 pub fn new() -> Self {
96 Self {
97 min_support: 2,
98 min_confidence: 0.5,
99 heuristics: Vec::new(),
100 rule_sequences: FxHashMap::default(),
101 }
102 }
103
104 pub fn with_min_support(mut self, support: usize) -> Self {
106 self.min_support = support;
107 self
108 }
109
110 pub fn with_min_confidence(mut self, confidence: f64) -> Self {
112 self.min_confidence = confidence.clamp(0.0, 1.0);
113 self
114 }
115
116 pub fn learn_from_proofs(&mut self, proofs: &[&Proof], _proof_stats: &[(f64, f64)]) {
118 self.learn_rule_ordering(proofs);
120
121 self.learn_branching_strategies(proofs);
123
124 self.learn_lemma_selection(proofs);
126
127 self.learn_instantiation_preferences(proofs);
129
130 self.learn_theory_combination(proofs);
132
133 self.heuristics
135 .retain(|h| h.confidence >= self.min_confidence && h.support_count >= self.min_support);
136
137 self.heuristics.sort_by(|a, b| {
139 b.confidence
140 .partial_cmp(&a.confidence)
141 .unwrap_or(std::cmp::Ordering::Equal)
142 });
143 }
144
145 pub fn get_heuristics(&self) -> &[ProofHeuristic] {
147 &self.heuristics
148 }
149
150 pub fn get_heuristics_by_type(&self, type_name: &str) -> Vec<&ProofHeuristic> {
152 self.heuristics
153 .iter()
154 .filter(|h| {
155 matches!(
156 (&h.heuristic_type, type_name),
157 (HeuristicType::RuleOrdering { .. }, "rule_ordering")
158 | (HeuristicType::BranchingStrategy { .. }, "branching")
159 | (HeuristicType::LemmaSelection { .. }, "lemma")
160 | (
161 HeuristicType::InstantiationPreference { .. },
162 "instantiation"
163 )
164 | (HeuristicType::TheoryCombination { .. }, "theory")
165 )
166 })
167 .collect()
168 }
169
170 pub fn get_top_heuristics(&self, n: usize) -> Vec<&ProofHeuristic> {
172 self.heuristics.iter().take(n).collect()
173 }
174
175 pub fn clear(&mut self) {
177 self.heuristics.clear();
178 self.rule_sequences.clear();
179 }
180
181 fn learn_rule_ordering(&mut self, proofs: &[&Proof]) {
183 let mut sequence_freq: FxHashMap<Vec<String>, usize> = FxHashMap::default();
184
185 for proof in proofs {
186 let sequences = self.extract_rule_sequences(proof, 3);
187 for seq in sequences {
188 *sequence_freq.entry(seq).or_insert(0) += 1;
189 }
190 }
191
192 for (seq, count) in sequence_freq.iter() {
194 if *count >= self.min_support {
195 let confidence = (*count as f64) / (proofs.len() as f64);
196 if confidence >= self.min_confidence {
197 self.heuristics.push(ProofHeuristic {
198 name: format!("rule_order_{}", seq.join("_")),
199 heuristic_type: HeuristicType::RuleOrdering {
200 preferred_sequence: seq.clone(),
201 },
202 confidence,
203 support_count: *count,
204 avg_improvement: 0.0,
205 });
206 }
207 }
208 }
209 }
210
211 fn extract_rule_sequences(&self, proof: &Proof, length: usize) -> Vec<Vec<String>> {
213 let mut sequences = Vec::new();
214 let nodes: Vec<ProofNodeId> = proof.nodes().iter().map(|n| n.id).collect();
215
216 if nodes.len() < length {
217 return sequences;
218 }
219
220 for window in nodes.windows(length) {
221 let seq: Vec<String> = window
222 .iter()
223 .filter_map(|&id| {
224 proof.get_node(id).and_then(|node| {
225 if let ProofStep::Inference { rule, .. } = &node.step {
226 Some(rule.clone())
227 } else {
228 None
229 }
230 })
231 })
232 .collect();
233
234 if seq.len() == length {
235 sequences.push(seq);
236 }
237 }
238
239 sequences
240 }
241
242 fn learn_branching_strategies(&mut self, proofs: &[&Proof]) {
244 let mut branching_patterns: FxHashMap<String, usize> = FxHashMap::default();
245
246 for proof in proofs {
247 for node in proof.nodes() {
249 let dependents = proof.get_children(node.id);
250 if dependents.len() > 1 {
251 let pattern = self.abstract_branching_pattern(node.conclusion());
253 *branching_patterns.entry(pattern).or_insert(0) += 1;
254 }
255 }
256 }
257
258 for (pattern, count) in branching_patterns.iter() {
260 if *count >= self.min_support {
261 let confidence = (*count as f64) / (proofs.len() as f64);
262 if confidence >= self.min_confidence {
263 self.heuristics.push(ProofHeuristic {
264 name: format!("branch_{}", pattern),
265 heuristic_type: HeuristicType::BranchingStrategy {
266 criteria: pattern.clone(),
267 },
268 confidence,
269 support_count: *count,
270 avg_improvement: 0.0,
271 });
272 }
273 }
274 }
275 }
276
277 fn abstract_branching_pattern(&self, conclusion: &str) -> String {
279 if conclusion.contains("forall") {
282 "universal".to_string()
283 } else if conclusion.contains("exists") {
284 "existential".to_string()
285 } else if conclusion.contains(" or ") {
286 "disjunction".to_string()
287 } else if conclusion.contains(" and ") {
288 "conjunction".to_string()
289 } else {
290 "other".to_string()
291 }
292 }
293
294 fn learn_lemma_selection(&mut self, proofs: &[&Proof]) {
296 let mut lemma_patterns: FxHashMap<String, usize> = FxHashMap::default();
297
298 for proof in proofs {
299 for node in proof.nodes() {
300 if let ProofStep::Inference { rule, .. } = &node.step
301 && (rule.contains("lemma") || rule.contains("theory"))
302 {
303 let pattern = self.extract_lemma_pattern(node.conclusion());
304 *lemma_patterns.entry(pattern).or_insert(0) += 1;
305 }
306 }
307 }
308
309 for (pattern, count) in lemma_patterns.iter() {
310 if *count >= self.min_support {
311 let confidence = (*count as f64) / (proofs.len() as f64);
312 if confidence >= self.min_confidence {
313 self.heuristics.push(ProofHeuristic {
314 name: format!("lemma_{}", pattern),
315 heuristic_type: HeuristicType::LemmaSelection {
316 pattern: pattern.clone(),
317 },
318 confidence,
319 support_count: *count,
320 avg_improvement: 0.0,
321 });
322 }
323 }
324 }
325 }
326
327 fn extract_lemma_pattern(&self, conclusion: &str) -> String {
329 if conclusion.contains("congruence") {
332 "congruence".to_string()
333 } else if conclusion.contains("<=") || conclusion.contains(">=") {
334 "inequality".to_string()
335 } else if conclusion.contains("=") {
336 "equality".to_string()
337 } else {
338 "other".to_string()
339 }
340 }
341
342 fn learn_instantiation_preferences(&mut self, proofs: &[&Proof]) {
344 let mut instantiation_patterns: FxHashMap<String, usize> = FxHashMap::default();
345
346 for proof in proofs {
347 for node in proof.nodes() {
348 if let ProofStep::Inference { rule, .. } = &node.step
349 && (rule.contains("instantiation") || rule.contains("forall_elim"))
350 {
351 let pattern = self.extract_trigger_pattern(node.conclusion());
352 *instantiation_patterns.entry(pattern).or_insert(0) += 1;
353 }
354 }
355 }
356
357 for (pattern, count) in instantiation_patterns.iter() {
358 if *count >= self.min_support {
359 let confidence = (*count as f64) / (proofs.len() as f64);
360 if confidence >= self.min_confidence {
361 self.heuristics.push(ProofHeuristic {
362 name: format!("inst_{}", pattern),
363 heuristic_type: HeuristicType::InstantiationPreference {
364 trigger_pattern: pattern.clone(),
365 },
366 confidence,
367 support_count: *count,
368 avg_improvement: 0.0,
369 });
370 }
371 }
372 }
373 }
374
375 fn extract_trigger_pattern(&self, conclusion: &str) -> String {
377 if let Some(start) = conclusion.find('(')
379 && let Some(end) = conclusion[start..].find(')')
380 {
381 return conclusion[..start + end + 1].to_string();
382 }
383 "default".to_string()
384 }
385
386 fn learn_theory_combination(&mut self, proofs: &[&Proof]) {
388 let mut theory_sequences: FxHashMap<Vec<String>, usize> = FxHashMap::default();
389
390 for proof in proofs {
391 let theories = self.extract_theory_sequence(proof);
392 if !theories.is_empty() {
393 *theory_sequences.entry(theories).or_insert(0) += 1;
394 }
395 }
396
397 for (seq, count) in theory_sequences.iter() {
398 if *count >= self.min_support {
399 let confidence = (*count as f64) / (proofs.len() as f64);
400 if confidence >= self.min_confidence {
401 self.heuristics.push(ProofHeuristic {
402 name: format!("theory_comb_{}", seq.join("_")),
403 heuristic_type: HeuristicType::TheoryCombination {
404 theory_order: seq.clone(),
405 },
406 confidence,
407 support_count: *count,
408 avg_improvement: 0.0,
409 });
410 }
411 }
412 }
413 }
414
415 fn extract_theory_sequence(&self, proof: &Proof) -> Vec<String> {
417 let mut seen = FxHashSet::default();
418 let mut sequence = Vec::new();
419
420 for node in proof.nodes() {
421 if let ProofStep::Inference { rule, .. } = &node.step {
422 let theory = self.infer_theory_from_rule(rule);
423 if !theory.is_empty() && !seen.contains(&theory) {
424 seen.insert(theory.clone());
425 sequence.push(theory);
426 }
427 }
428 }
429
430 sequence
431 }
432
433 fn infer_theory_from_rule(&self, rule: &str) -> String {
435 if rule.contains("arith") || rule.contains("farkas") {
436 "arithmetic".to_string()
437 } else if rule.contains("euf") || rule.contains("congruence") {
438 "euf".to_string()
439 } else if rule.contains("array") {
440 "arrays".to_string()
441 } else if rule.contains("bv") || rule.contains("bitvector") {
442 "bitvectors".to_string()
443 } else {
444 String::new()
445 }
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a heuristic with default support and no improvement.
    fn heuristic(name: &str, heuristic_type: HeuristicType, confidence: f64) -> ProofHeuristic {
        ProofHeuristic {
            name: name.to_string(),
            heuristic_type,
            confidence,
            support_count: 2,
            avg_improvement: 0.0,
        }
    }

    #[test]
    fn test_strategy_learner_new() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.min_support, 2);
        assert_eq!(learner.min_confidence, 0.5);
        assert!(learner.heuristics.is_empty());
    }

    #[test]
    fn test_strategy_learner_with_settings() {
        let learner = StrategyLearner::new()
            .with_min_support(3)
            .with_min_confidence(0.7);
        assert_eq!(learner.min_support, 3);
        assert_eq!(learner.min_confidence, 0.7);
    }

    #[test]
    fn test_with_min_confidence_clamps() {
        assert_eq!(StrategyLearner::new().with_min_confidence(1.5).min_confidence, 1.0);
        assert_eq!(StrategyLearner::new().with_min_confidence(-0.2).min_confidence, 0.0);
    }

    #[test]
    fn test_heuristic_type_display() {
        let rule_ordering = HeuristicType::RuleOrdering {
            preferred_sequence: vec!["resolution".to_string(), "unit_prop".to_string()],
        };
        assert_eq!(
            rule_ordering.to_string(),
            "RuleOrdering[resolution → unit_prop]"
        );

        let branching = HeuristicType::BranchingStrategy {
            criteria: "disjunction".to_string(),
        };
        assert_eq!(branching.to_string(), "BranchingStrategy[disjunction]");
    }

    #[test]
    fn test_heuristic_type_display_remaining_variants() {
        let lemma = HeuristicType::LemmaSelection {
            pattern: "equality".to_string(),
        };
        assert_eq!(lemma.to_string(), "LemmaSelection[equality]");

        let inst = HeuristicType::InstantiationPreference {
            trigger_pattern: "P(x)".to_string(),
        };
        assert_eq!(inst.to_string(), "InstantiationPreference[P(x)]");

        let theory = HeuristicType::TheoryCombination {
            theory_order: vec!["arithmetic".to_string(), "euf".to_string()],
        };
        assert_eq!(theory.to_string(), "TheoryCombination[arithmetic + euf]");
    }

    #[test]
    fn test_proof_heuristic_display() {
        let heuristic = ProofHeuristic {
            name: "test_heuristic".to_string(),
            heuristic_type: HeuristicType::RuleOrdering {
                preferred_sequence: vec!["resolution".to_string()],
            },
            confidence: 0.8,
            support_count: 10,
            avg_improvement: 0.15,
        };
        let display = format!("{}", heuristic);
        assert!(display.contains("test_heuristic"));
        assert!(display.contains("0.80"));
        assert!(display.contains("10 proofs"));
        assert!(display.contains("15.0%"));
    }

    #[test]
    fn test_clear_heuristics() {
        let mut learner = StrategyLearner::new();
        learner.heuristics.push(heuristic(
            "test",
            HeuristicType::RuleOrdering {
                preferred_sequence: vec![],
            },
            0.5,
        ));
        learner.rule_sequences.insert(vec!["a".to_string()], 1);
        learner.clear();
        assert!(learner.heuristics.is_empty());
        assert!(learner.rule_sequences.is_empty());
    }

    #[test]
    fn test_get_top_heuristics() {
        let mut learner = StrategyLearner::new();
        learner.heuristics.push(heuristic(
            "h1",
            HeuristicType::RuleOrdering {
                preferred_sequence: vec![],
            },
            0.9,
        ));
        learner.heuristics.push(heuristic(
            "h2",
            HeuristicType::RuleOrdering {
                preferred_sequence: vec![],
            },
            0.7,
        ));
        let top = learner.get_top_heuristics(1);
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].name, "h1");
        assert_eq!(learner.get_top_heuristics(10).len(), 2);
    }

    #[test]
    fn test_get_heuristics_by_type() {
        let mut learner = StrategyLearner::new();
        learner.heuristics.push(heuristic(
            "r",
            HeuristicType::RuleOrdering {
                preferred_sequence: vec![],
            },
            0.9,
        ));
        learner.heuristics.push(heuristic(
            "b",
            HeuristicType::BranchingStrategy {
                criteria: "disjunction".to_string(),
            },
            0.8,
        ));
        learner.heuristics.push(heuristic(
            "t",
            HeuristicType::TheoryCombination {
                theory_order: vec!["euf".to_string()],
            },
            0.6,
        ));

        let branching = learner.get_heuristics_by_type("branching");
        assert_eq!(branching.len(), 1);
        assert_eq!(branching[0].name, "b");
        assert_eq!(learner.get_heuristics_by_type("rule_ordering")[0].name, "r");
        assert_eq!(learner.get_heuristics_by_type("theory")[0].name, "t");
        assert!(learner.get_heuristics_by_type("lemma").is_empty());
        assert!(learner.get_heuristics_by_type("unknown").is_empty());
    }

    #[test]
    fn test_abstract_branching_pattern() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.abstract_branching_pattern("x or y"), "disjunction");
        assert_eq!(learner.abstract_branching_pattern("x and y"), "conjunction");
        assert_eq!(
            learner.abstract_branching_pattern("forall x. P(x)"),
            "universal"
        );
        assert_eq!(
            learner.abstract_branching_pattern("exists y. Q(y)"),
            "existential"
        );
        assert_eq!(learner.abstract_branching_pattern("P"), "other");
    }

    #[test]
    fn test_extract_lemma_pattern() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.extract_lemma_pattern("x = y"), "equality");
        assert_eq!(learner.extract_lemma_pattern("x <= y"), "inequality");
        assert_eq!(learner.extract_lemma_pattern("x >= 0"), "inequality");
        assert_eq!(
            learner.extract_lemma_pattern("congruence f(x) f(y)"),
            "congruence"
        );
        assert_eq!(learner.extract_lemma_pattern("P(x)"), "other");
    }

    #[test]
    fn test_extract_trigger_pattern() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.extract_trigger_pattern("P(x) and Q(y)"), "P(x)");
        assert_eq!(learner.extract_trigger_pattern("no application"), "default");
        assert_eq!(learner.extract_trigger_pattern(""), "default");
    }

    #[test]
    fn test_infer_theory_from_rule() {
        let learner = StrategyLearner::new();
        assert_eq!(learner.infer_theory_from_rule("arith_lemma"), "arithmetic");
        assert_eq!(learner.infer_theory_from_rule("farkas_lemma"), "arithmetic");
        assert_eq!(learner.infer_theory_from_rule("euf_congruence"), "euf");
        assert_eq!(
            learner.infer_theory_from_rule("array_extensionality"),
            "arrays"
        );
        assert_eq!(learner.infer_theory_from_rule("bv_solve"), "bitvectors");
        assert_eq!(learner.infer_theory_from_rule("resolution"), "");
    }
}