Skip to main content

oxiz_proof/
pattern.rs

1//! Lemma pattern extraction from proofs.
2//!
3//! This module extracts reusable patterns from successful proofs to enable
4//! proof-based learning and improve solver heuristics.
5
6use crate::proof::{Proof, ProofStep};
7use rustc_hash::FxHashMap;
8use std::fmt;
9
/// A pattern extracted from a lemma or proof fragment.
///
/// Produced by [`PatternExtractor`], which groups inference steps that share a
/// rule, a premise count and an abstracted conclusion. `Eq` is not derived
/// because `avg_depth` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LemmaPattern {
    /// The inference rule used by the grouped inference steps
    pub rule: String,
    /// Number of premises of the grouped inference steps
    pub num_premises: usize,
    /// Variables in the pattern (abstracted).
    ///
    /// As filled in by [`PatternExtractor`]: the sorted, de-duplicated names of
    /// atoms that start with `$` or a lowercase letter, plus quantifier-bound
    /// variables.
    pub variables: Vec<String>,
    /// Pattern structure (simplified AST) parsed from a conclusion
    pub structure: PatternStructure,
    /// Frequency of this pattern in the proof corpus, i.e. the number of
    /// matching inference steps counted by the extractor
    pub frequency: usize,
    /// Average proof-node depth at which this pattern appears
    pub avg_depth: f64,
}
27
/// The structure of a pattern (simplified representation).
///
/// This is a lightweight tree built by a best-effort string parser, not a full
/// term representation. Its `Display` form prints atoms verbatim,
/// applications as `f(a, b)`, binary operations fully parenthesised as
/// `(l op r)`, and quantifiers as `q v. body`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PatternStructure {
    /// Atomic pattern (variable or constant), holding the raw token text
    Atom(String),
    /// Application of a function/predicate, e.g. `f(x, y)`
    App {
        /// Function name
        func: String,
        /// Arguments, in source order
        args: Vec<PatternStructure>,
    },
    /// Binary operation pattern, e.g. `x = y` or `p and q`
    Binary {
        /// Operator text as it appeared in the source (e.g. `=`, `<=`, `and`)
        op: String,
        /// Left operand
        left: Box<PatternStructure>,
        /// Right operand
        right: Box<PatternStructure>,
    },
    /// Quantified pattern, e.g. `forall x. body`
    Quantified {
        /// Quantifier keyword (`forall` or `exists`)
        quantifier: String,
        /// Bound variable
        var: String,
        /// Body
        body: Box<PatternStructure>,
    },
}
60
61impl fmt::Display for PatternStructure {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        match self {
64            PatternStructure::Atom(a) => write!(f, "{}", a),
65            PatternStructure::App { func, args } => {
66                write!(f, "{}(", func)?;
67                for (i, arg) in args.iter().enumerate() {
68                    if i > 0 {
69                        write!(f, ", ")?;
70                    }
71                    write!(f, "{}", arg)?;
72                }
73                write!(f, ")")
74            }
75            PatternStructure::Binary { op, left, right } => {
76                write!(f, "({} {} {})", left, op, right)
77            }
78            PatternStructure::Quantified {
79                quantifier,
80                var,
81                body,
82            } => {
83                write!(f, "{} {}. {}", quantifier, var, body)
84            }
85        }
86    }
87}
88
89/// Pattern extractor for analyzing proofs.
90pub struct PatternExtractor {
91    /// Minimum pattern frequency to report
92    min_frequency: usize,
93    /// Maximum pattern depth
94    max_depth: usize,
95    /// Extracted patterns
96    patterns: FxHashMap<String, LemmaPattern>,
97}
98
99impl Default for PatternExtractor {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl PatternExtractor {
106    /// Create a new pattern extractor with default settings.
107    pub fn new() -> Self {
108        Self {
109            min_frequency: 2,
110            max_depth: 5,
111            patterns: FxHashMap::default(),
112        }
113    }
114
115    /// Set the minimum frequency threshold.
116    pub fn with_min_frequency(mut self, freq: usize) -> Self {
117        self.min_frequency = freq;
118        self
119    }
120
121    /// Set the maximum pattern depth.
122    pub fn with_max_depth(mut self, depth: usize) -> Self {
123        self.max_depth = depth;
124        self
125    }
126
127    /// Extract patterns from a proof.
128    pub fn extract_patterns(&mut self, proof: &Proof) {
129        let mut pattern_occurrences: FxHashMap<String, (usize, Vec<f64>)> = FxHashMap::default();
130
131        for node in proof.nodes() {
132            let depth = node.depth;
133
134            if let ProofStep::Inference { rule, premises, .. } = &node.step {
135                // Create a pattern key
136                let pattern_key = self.create_pattern_key(rule, premises.len(), node.conclusion());
137
138                // Track occurrences
139                pattern_occurrences
140                    .entry(pattern_key.clone())
141                    .or_insert_with(|| (0, Vec::new()))
142                    .0 += 1;
143                pattern_occurrences
144                    .get_mut(&pattern_key)
145                    .expect("key exists after entry().or_insert_with()")
146                    .1
147                    .push(depth as f64);
148
149                // Extract pattern structure
150                if let Some(pattern) =
151                    self.extract_pattern_structure(rule, premises.len(), node.conclusion())
152                {
153                    self.patterns.insert(pattern_key, pattern);
154                }
155            }
156        }
157
158        // Update pattern frequencies and average depths
159        for (key, pattern) in &mut self.patterns {
160            if let Some((freq, depths)) = pattern_occurrences.get(key) {
161                pattern.frequency = *freq;
162                if !depths.is_empty() {
163                    pattern.avg_depth = depths.iter().sum::<f64>() / depths.len() as f64;
164                }
165            }
166        }
167    }
168
169    /// Get all extracted patterns that meet the minimum frequency threshold.
170    pub fn get_patterns(&self) -> Vec<&LemmaPattern> {
171        self.patterns
172            .values()
173            .filter(|p| p.frequency >= self.min_frequency)
174            .collect()
175    }
176
177    /// Get patterns sorted by frequency (most common first).
178    pub fn get_patterns_by_frequency(&self) -> Vec<&LemmaPattern> {
179        let mut patterns = self.get_patterns();
180        patterns.sort_by_key(|p| std::cmp::Reverse(p.frequency));
181        patterns
182    }
183
184    /// Get patterns for a specific rule.
185    pub fn get_patterns_for_rule(&self, rule: &str) -> Vec<&LemmaPattern> {
186        self.patterns
187            .values()
188            .filter(|p| p.rule == rule && p.frequency >= self.min_frequency)
189            .collect()
190    }
191
192    /// Clear all extracted patterns.
193    pub fn clear(&mut self) {
194        self.patterns.clear();
195    }
196
197    // Helper: Create a unique key for a pattern
198    fn create_pattern_key(&self, rule: &str, num_premises: usize, conclusion: &str) -> String {
199        format!(
200            "{}:{}:{}",
201            rule,
202            num_premises,
203            self.abstract_conclusion(conclusion)
204        )
205    }
206
207    // Helper: Abstract conclusion by replacing specific values with variables
208    fn abstract_conclusion(&self, conclusion: &str) -> String {
209        // Simple abstraction: replace numbers and specific identifiers with placeholders
210        let mut abstracted = conclusion.to_string();
211
212        // Replace numbers with $$N (need to escape $ as $$)
213        let re_num = regex::Regex::new(r"\b\d+\b").expect("regex pattern is valid");
214        abstracted = re_num.replace_all(&abstracted, "$$N").to_string();
215
216        // Replace quoted strings with $$S
217        let re_str = regex::Regex::new(r#""[^"]*""#).expect("regex pattern is valid");
218        abstracted = re_str.replace_all(&abstracted, "$$S").to_string();
219
220        abstracted
221    }
222
223    // Helper: Extract pattern structure from conclusion
224    fn extract_pattern_structure(
225        &self,
226        rule: &str,
227        num_premises: usize,
228        conclusion: &str,
229    ) -> Option<LemmaPattern> {
230        // Parse conclusion into a structure (simplified for now)
231        let structure = Self::parse_conclusion_structure(conclusion);
232        let variables = self.extract_variables(&structure);
233
234        Some(LemmaPattern {
235            rule: rule.to_string(),
236            num_premises,
237            variables,
238            structure,
239            frequency: 0,
240            avg_depth: 0.0,
241        })
242    }
243
244    // Helper: Parse conclusion into pattern structure
245    fn parse_conclusion_structure(conclusion: &str) -> PatternStructure {
246        // Simplified parsing - in a real implementation, this would use a proper parser
247        let trimmed = conclusion.trim();
248
249        // Check for quantifiers
250        if (trimmed.starts_with("forall") || trimmed.starts_with("exists"))
251            && let Some((quantifier, rest)) = trimmed.split_once(' ')
252            && let Some((var, body)) = rest.split_once('.')
253        {
254            return PatternStructure::Quantified {
255                quantifier: quantifier.to_string(),
256                var: var.trim().to_string(),
257                body: Box::new(Self::parse_conclusion_structure(body.trim())),
258            };
259        }
260
261        // Check for binary operators
262        for op in &["=", "<=", ">=", "<", ">", "!=", "and", "or", "=>"] {
263            if let Some(pos) = trimmed.find(op) {
264                let left = &trimmed[..pos];
265                let right = &trimmed[pos + op.len()..];
266                if !left.is_empty() && !right.is_empty() {
267                    return PatternStructure::Binary {
268                        op: op.to_string(),
269                        left: Box::new(Self::parse_conclusion_structure(left.trim())),
270                        right: Box::new(Self::parse_conclusion_structure(right.trim())),
271                    };
272                }
273            }
274        }
275
276        // Check for function application
277        if let Some(pos) = trimmed.find('(')
278            && trimmed.ends_with(')')
279        {
280            let func = &trimmed[..pos];
281            let args_str = &trimmed[pos + 1..trimmed.len() - 1];
282            let args = args_str
283                .split(',')
284                .map(|a| Self::parse_conclusion_structure(a.trim()))
285                .collect();
286            return PatternStructure::App {
287                func: func.trim().to_string(),
288                args,
289            };
290        }
291
292        // Default: atom
293        PatternStructure::Atom(trimmed.to_string())
294    }
295
296    // Helper: Extract variables from pattern structure
297    fn extract_variables(&self, structure: &PatternStructure) -> Vec<String> {
298        let mut vars = Vec::new();
299        Self::extract_variables_rec(structure, &mut vars);
300        vars.sort();
301        vars.dedup();
302        vars
303    }
304
305    fn extract_variables_rec(structure: &PatternStructure, vars: &mut Vec<String>) {
306        match structure {
307            PatternStructure::Atom(a) => {
308                if a.starts_with('$') || a.chars().next().is_some_and(|c| c.is_lowercase()) {
309                    vars.push(a.clone());
310                }
311            }
312            PatternStructure::App { args, .. } => {
313                for arg in args {
314                    Self::extract_variables_rec(arg, vars);
315                }
316            }
317            PatternStructure::Binary { left, right, .. } => {
318                Self::extract_variables_rec(left, vars);
319                Self::extract_variables_rec(right, vars);
320            }
321            PatternStructure::Quantified { var, body, .. } => {
322                vars.push(var.clone());
323                Self::extract_variables_rec(body, vars);
324            }
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a pattern with the given rule and frequency; other fields are placeholders.
    fn sample_pattern(rule: &str, frequency: usize) -> LemmaPattern {
        LemmaPattern {
            rule: rule.to_string(),
            num_premises: 2,
            variables: Vec::new(),
            structure: PatternStructure::Atom("p".to_string()),
            frequency,
            avg_depth: 1.0,
        }
    }

    #[test]
    fn test_pattern_extractor_new() {
        let extractor = PatternExtractor::new();
        assert_eq!(extractor.min_frequency, 2);
        assert_eq!(extractor.max_depth, 5);
        assert!(extractor.patterns.is_empty());
    }

    #[test]
    fn test_pattern_extractor_with_settings() {
        let extractor = PatternExtractor::new()
            .with_min_frequency(3)
            .with_max_depth(10);
        assert_eq!(extractor.min_frequency, 3);
        assert_eq!(extractor.max_depth, 10);
    }

    #[test]
    fn test_pattern_structure_display() {
        let atom = PatternStructure::Atom("x".to_string());
        assert_eq!(atom.to_string(), "x");

        let app = PatternStructure::App {
            func: "f".to_string(),
            args: vec![
                PatternStructure::Atom("x".to_string()),
                PatternStructure::Atom("y".to_string()),
            ],
        };
        assert_eq!(app.to_string(), "f(x, y)");

        let binary = PatternStructure::Binary {
            op: "=".to_string(),
            left: Box::new(PatternStructure::Atom("x".to_string())),
            right: Box::new(PatternStructure::Atom("y".to_string())),
        };
        assert_eq!(binary.to_string(), "(x = y)");

        let quantified = PatternStructure::Quantified {
            quantifier: "forall".to_string(),
            var: "x".to_string(),
            body: Box::new(PatternStructure::Atom("p".to_string())),
        };
        assert_eq!(quantified.to_string(), "forall x. p");
    }

    #[test]
    fn test_parse_atom() {
        let structure = PatternExtractor::parse_conclusion_structure("x");
        assert!(matches!(structure, PatternStructure::Atom(_)));
    }

    #[test]
    fn test_parse_binary() {
        let structure = PatternExtractor::parse_conclusion_structure("x = y");
        assert!(matches!(structure, PatternStructure::Binary { .. }));
    }

    #[test]
    fn test_parse_app() {
        let structure = PatternExtractor::parse_conclusion_structure("f(x, y)");
        if let PatternStructure::App { func, args } = structure {
            assert_eq!(func, "f");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected App pattern");
        }
    }

    #[test]
    fn test_parse_quantified() {
        let structure = PatternExtractor::parse_conclusion_structure("forall x. p(x)");
        match structure {
            PatternStructure::Quantified {
                quantifier,
                var,
                body,
            } => {
                assert_eq!(quantifier, "forall");
                assert_eq!(var, "x");
                assert!(matches!(*body, PatternStructure::App { .. }));
            }
            other => panic!("Expected Quantified pattern, got {other:?}"),
        }
    }

    #[test]
    fn test_abstract_conclusion() {
        let extractor = PatternExtractor::new();
        // `$$` in the regex replacement string must yield a literal `$`.
        assert_eq!(extractor.abstract_conclusion("x + 42 = y"), "x + $N = y");
        // Digits that are part of an identifier are left alone.
        assert_eq!(extractor.abstract_conclusion("x1 = 7"), "x1 = $N");
        assert_eq!(
            extractor.abstract_conclusion(r#"f("abc") = 3"#),
            "f($S) = $N"
        );
    }

    #[test]
    fn test_extract_variables() {
        let extractor = PatternExtractor::new();
        let structure = PatternStructure::App {
            func: "f".to_string(),
            args: vec![
                PatternStructure::Atom("x".to_string()),
                PatternStructure::Atom("y".to_string()),
            ],
        };
        let vars = extractor.extract_variables(&structure);
        assert_eq!(vars.len(), 2);
        assert!(vars.contains(&"x".to_string()));
        assert!(vars.contains(&"y".to_string()));
    }

    #[test]
    fn test_extract_variables_quantified_sorted_dedup() {
        let extractor = PatternExtractor::new();
        let structure = PatternStructure::Quantified {
            quantifier: "forall".to_string(),
            var: "y".to_string(),
            body: Box::new(PatternStructure::App {
                func: "f".to_string(),
                args: vec![
                    PatternStructure::Atom("x".to_string()),
                    PatternStructure::Atom("C".to_string()),
                    PatternStructure::Atom("$N".to_string()),
                    PatternStructure::Atom("x".to_string()),
                ],
            }),
        };
        // Uppercase atoms are constants; `$` placeholders and bound vars are kept.
        assert_eq!(extractor.extract_variables(&structure), vec!["$N", "x", "y"]);
    }

    #[test]
    fn test_frequency_filtering_and_ordering() {
        let mut extractor = PatternExtractor::new().with_min_frequency(2);
        extractor
            .patterns
            .insert("a".to_string(), sample_pattern("resolution", 5));
        extractor
            .patterns
            .insert("b".to_string(), sample_pattern("resolution", 1));
        extractor
            .patterns
            .insert("c".to_string(), sample_pattern("mp", 3));

        let freqs: Vec<usize> = extractor
            .get_patterns_by_frequency()
            .iter()
            .map(|p| p.frequency)
            .collect();
        assert_eq!(freqs, vec![5, 3]);
        assert_eq!(extractor.get_patterns_for_rule("resolution").len(), 1);
        assert!(extractor.get_patterns_for_rule("unknown").is_empty());
    }

    #[test]
    fn test_extract_patterns_empty_proof() {
        let mut extractor = PatternExtractor::new();
        let proof = Proof::new();
        extractor.extract_patterns(&proof);
        assert!(extractor.get_patterns().is_empty());
    }

    #[test]
    fn test_clear_patterns() {
        let mut extractor = PatternExtractor::new();
        extractor
            .patterns
            .insert("a".to_string(), sample_pattern("resolution", 5));
        let proof = Proof::new();
        extractor.extract_patterns(&proof);
        extractor.clear();
        assert!(extractor.patterns.is_empty());
    }
}