1use crate::proof::{Proof, ProofStep};
7use rustc_hash::FxHashMap;
8use std::fmt;
9
10#[derive(Debug, Clone, PartialEq)]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13pub struct LemmaPattern {
14 pub rule: String,
16 pub num_premises: usize,
18 pub variables: Vec<String>,
20 pub structure: PatternStructure,
22 pub frequency: usize,
24 pub avg_depth: f64,
26}
27
/// Tree describing the syntactic shape of a conclusion.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum PatternStructure {
    /// A leaf: any text that has no further recognised structure.
    Atom(String),
    /// A function-style application such as `f(x, y)`.
    App {
        /// Text appearing before the opening parenthesis.
        func: String,
        /// Parsed arguments, in source order.
        args: Vec<PatternStructure>,
    },
    /// An infix operator applied to two operands.
    Binary {
        /// Operator text, e.g. `=` or `and`.
        op: String,
        /// Operand to the left of the operator.
        left: Box<PatternStructure>,
        /// Operand to the right of the operator.
        right: Box<PatternStructure>,
    },
    /// A quantified formula such as `forall x. body`.
    Quantified {
        /// Quantifier keyword as written in the source text.
        quantifier: String,
        /// Text between the quantifier and the first `.`.
        var: String,
        /// The formula following the `.`.
        body: Box<PatternStructure>,
    },
}
60
61impl fmt::Display for PatternStructure {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 match self {
64 PatternStructure::Atom(a) => write!(f, "{}", a),
65 PatternStructure::App { func, args } => {
66 write!(f, "{}(", func)?;
67 for (i, arg) in args.iter().enumerate() {
68 if i > 0 {
69 write!(f, ", ")?;
70 }
71 write!(f, "{}", arg)?;
72 }
73 write!(f, ")")
74 }
75 PatternStructure::Binary { op, left, right } => {
76 write!(f, "({} {} {})", left, op, right)
77 }
78 PatternStructure::Quantified {
79 quantifier,
80 var,
81 body,
82 } => {
83 write!(f, "{} {}. {}", quantifier, var, body)
84 }
85 }
86 }
87}
88
89pub struct PatternExtractor {
91 min_frequency: usize,
93 max_depth: usize,
95 patterns: FxHashMap<String, LemmaPattern>,
97}
98
99impl Default for PatternExtractor {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl PatternExtractor {
106 pub fn new() -> Self {
108 Self {
109 min_frequency: 2,
110 max_depth: 5,
111 patterns: FxHashMap::default(),
112 }
113 }
114
115 pub fn with_min_frequency(mut self, freq: usize) -> Self {
117 self.min_frequency = freq;
118 self
119 }
120
121 pub fn with_max_depth(mut self, depth: usize) -> Self {
123 self.max_depth = depth;
124 self
125 }
126
127 pub fn extract_patterns(&mut self, proof: &Proof) {
129 let mut pattern_occurrences: FxHashMap<String, (usize, Vec<f64>)> = FxHashMap::default();
130
131 for node in proof.nodes() {
132 let depth = node.depth;
133
134 if let ProofStep::Inference { rule, premises, .. } = &node.step {
135 let pattern_key = self.create_pattern_key(rule, premises.len(), node.conclusion());
137
138 pattern_occurrences
140 .entry(pattern_key.clone())
141 .or_insert_with(|| (0, Vec::new()))
142 .0 += 1;
143 pattern_occurrences
144 .get_mut(&pattern_key)
145 .expect("key exists after entry().or_insert_with()")
146 .1
147 .push(depth as f64);
148
149 if let Some(pattern) =
151 self.extract_pattern_structure(rule, premises.len(), node.conclusion())
152 {
153 self.patterns.insert(pattern_key, pattern);
154 }
155 }
156 }
157
158 for (key, pattern) in &mut self.patterns {
160 if let Some((freq, depths)) = pattern_occurrences.get(key) {
161 pattern.frequency = *freq;
162 if !depths.is_empty() {
163 pattern.avg_depth = depths.iter().sum::<f64>() / depths.len() as f64;
164 }
165 }
166 }
167 }
168
169 pub fn get_patterns(&self) -> Vec<&LemmaPattern> {
171 self.patterns
172 .values()
173 .filter(|p| p.frequency >= self.min_frequency)
174 .collect()
175 }
176
177 pub fn get_patterns_by_frequency(&self) -> Vec<&LemmaPattern> {
179 let mut patterns = self.get_patterns();
180 patterns.sort_by_key(|p| std::cmp::Reverse(p.frequency));
181 patterns
182 }
183
184 pub fn get_patterns_for_rule(&self, rule: &str) -> Vec<&LemmaPattern> {
186 self.patterns
187 .values()
188 .filter(|p| p.rule == rule && p.frequency >= self.min_frequency)
189 .collect()
190 }
191
192 pub fn clear(&mut self) {
194 self.patterns.clear();
195 }
196
197 fn create_pattern_key(&self, rule: &str, num_premises: usize, conclusion: &str) -> String {
199 format!(
200 "{}:{}:{}",
201 rule,
202 num_premises,
203 self.abstract_conclusion(conclusion)
204 )
205 }
206
207 fn abstract_conclusion(&self, conclusion: &str) -> String {
209 let mut abstracted = conclusion.to_string();
211
212 let re_num = regex::Regex::new(r"\b\d+\b").expect("regex pattern is valid");
214 abstracted = re_num.replace_all(&abstracted, "$$N").to_string();
215
216 let re_str = regex::Regex::new(r#""[^"]*""#).expect("regex pattern is valid");
218 abstracted = re_str.replace_all(&abstracted, "$$S").to_string();
219
220 abstracted
221 }
222
223 fn extract_pattern_structure(
225 &self,
226 rule: &str,
227 num_premises: usize,
228 conclusion: &str,
229 ) -> Option<LemmaPattern> {
230 let structure = Self::parse_conclusion_structure(conclusion);
232 let variables = self.extract_variables(&structure);
233
234 Some(LemmaPattern {
235 rule: rule.to_string(),
236 num_premises,
237 variables,
238 structure,
239 frequency: 0,
240 avg_depth: 0.0,
241 })
242 }
243
244 fn parse_conclusion_structure(conclusion: &str) -> PatternStructure {
246 let trimmed = conclusion.trim();
248
249 if (trimmed.starts_with("forall") || trimmed.starts_with("exists"))
251 && let Some((quantifier, rest)) = trimmed.split_once(' ')
252 && let Some((var, body)) = rest.split_once('.')
253 {
254 return PatternStructure::Quantified {
255 quantifier: quantifier.to_string(),
256 var: var.trim().to_string(),
257 body: Box::new(Self::parse_conclusion_structure(body.trim())),
258 };
259 }
260
261 for op in &["=", "<=", ">=", "<", ">", "!=", "and", "or", "=>"] {
263 if let Some(pos) = trimmed.find(op) {
264 let left = &trimmed[..pos];
265 let right = &trimmed[pos + op.len()..];
266 if !left.is_empty() && !right.is_empty() {
267 return PatternStructure::Binary {
268 op: op.to_string(),
269 left: Box::new(Self::parse_conclusion_structure(left.trim())),
270 right: Box::new(Self::parse_conclusion_structure(right.trim())),
271 };
272 }
273 }
274 }
275
276 if let Some(pos) = trimmed.find('(')
278 && trimmed.ends_with(')')
279 {
280 let func = &trimmed[..pos];
281 let args_str = &trimmed[pos + 1..trimmed.len() - 1];
282 let args = args_str
283 .split(',')
284 .map(|a| Self::parse_conclusion_structure(a.trim()))
285 .collect();
286 return PatternStructure::App {
287 func: func.trim().to_string(),
288 args,
289 };
290 }
291
292 PatternStructure::Atom(trimmed.to_string())
294 }
295
296 fn extract_variables(&self, structure: &PatternStructure) -> Vec<String> {
298 let mut vars = Vec::new();
299 Self::extract_variables_rec(structure, &mut vars);
300 vars.sort();
301 vars.dedup();
302 vars
303 }
304
305 fn extract_variables_rec(structure: &PatternStructure, vars: &mut Vec<String>) {
306 match structure {
307 PatternStructure::Atom(a) => {
308 if a.starts_with('$') || a.chars().next().is_some_and(|c| c.is_lowercase()) {
309 vars.push(a.clone());
310 }
311 }
312 PatternStructure::App { args, .. } => {
313 for arg in args {
314 Self::extract_variables_rec(arg, vars);
315 }
316 }
317 PatternStructure::Binary { left, right, .. } => {
318 Self::extract_variables_rec(left, vars);
319 Self::extract_variables_rec(right, vars);
320 }
321 PatternStructure::Quantified { var, body, .. } => {
322 vars.push(var.clone());
323 Self::extract_variables_rec(body, vars);
324 }
325 }
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an atom in expected values.
    fn atom(text: &str) -> PatternStructure {
        PatternStructure::Atom(text.to_string())
    }

    #[test]
    fn test_pattern_extractor_new() {
        let extractor = PatternExtractor::new();
        assert_eq!(extractor.min_frequency, 2);
        assert_eq!(extractor.max_depth, 5);
        assert!(extractor.patterns.is_empty());
    }

    #[test]
    fn test_pattern_extractor_with_settings() {
        let extractor = PatternExtractor::new()
            .with_min_frequency(3)
            .with_max_depth(10);
        assert_eq!(extractor.min_frequency, 3);
        assert_eq!(extractor.max_depth, 10);
    }

    #[test]
    fn test_pattern_structure_display() {
        assert_eq!(atom("x").to_string(), "x");

        let app = PatternStructure::App {
            func: "f".to_string(),
            args: vec![atom("x"), atom("y")],
        };
        assert_eq!(app.to_string(), "f(x, y)");

        let binary = PatternStructure::Binary {
            op: "=".to_string(),
            left: Box::new(atom("x")),
            right: Box::new(atom("y")),
        };
        assert_eq!(binary.to_string(), "(x = y)");

        let quantified = PatternStructure::Quantified {
            quantifier: "forall".to_string(),
            var: "x".to_string(),
            body: Box::new(binary),
        };
        assert_eq!(quantified.to_string(), "forall x. (x = y)");
    }

    #[test]
    fn test_parse_atom() {
        let structure = PatternExtractor::parse_conclusion_structure("x");
        assert_eq!(structure, atom("x"));
    }

    #[test]
    fn test_parse_binary() {
        let structure = PatternExtractor::parse_conclusion_structure("x = y");
        let expected = PatternStructure::Binary {
            op: "=".to_string(),
            left: Box::new(atom("x")),
            right: Box::new(atom("y")),
        };
        assert_eq!(structure, expected);
    }

    #[test]
    fn test_parse_app() {
        let structure = PatternExtractor::parse_conclusion_structure("f(x, y)");
        if let PatternStructure::App { func, args } = structure {
            assert_eq!(func, "f");
            assert_eq!(args, vec![atom("x"), atom("y")]);
        } else {
            panic!("Expected App pattern");
        }
    }

    #[test]
    fn test_abstract_conclusion() {
        let extractor = PatternExtractor::new();
        // Exact expectations: a disjunction that also accepts the unabstracted
        // input would pass even if abstraction did nothing.
        assert_eq!(extractor.abstract_conclusion("x + 42 = y"), "x + $N = y");
        assert_eq!(
            extractor.abstract_conclusion("name = \"bob\""),
            "name = $S"
        );
    }

    #[test]
    fn test_extract_variables() {
        let extractor = PatternExtractor::new();
        let structure = PatternStructure::App {
            func: "f".to_string(),
            // Out of order and duplicated to exercise sorting and de-duplication.
            args: vec![atom("y"), atom("x"), atom("y")],
        };
        let vars = extractor.extract_variables(&structure);
        assert_eq!(vars, vec!["x".to_string(), "y".to_string()]);
    }

    #[test]
    fn test_extract_patterns_empty_proof() {
        let mut extractor = PatternExtractor::new();
        let proof = Proof::new();
        extractor.extract_patterns(&proof);
        assert!(extractor.get_patterns().is_empty());
    }

    #[test]
    fn test_clear_patterns() {
        let mut extractor = PatternExtractor::new();
        let proof = Proof::new();
        extractor.extract_patterns(&proof);
        extractor.clear();
        assert!(extractor.patterns.is_empty());
    }
}