use crate::proof::{Proof, ProofStep};
use rustc_hash::FxHashMap;
use std::fmt;
/// A recurring inference shape recorded by `PatternExtractor`.
///
/// A pattern is built with zeroed statistics; `frequency` and `avg_depth`
/// are filled in afterwards by `PatternExtractor::extract_patterns`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LemmaPattern {
/// Name of the inference rule the conclusion was derived with.
pub rule: String,
/// Number of premises the inference step had.
pub num_premises: usize,
/// Sorted, de-duplicated variable names collected from `structure`.
pub variables: Vec<String>,
/// Parsed shape of the step's conclusion.
pub structure: PatternStructure,
/// Occurrence count assigned by the extractor (0 until then).
pub frequency: usize,
/// Mean depth of the proof nodes that were counted (0.0 until then).
pub avg_depth: f64,
}
/// Tree form of a conclusion string.
///
/// The `Display` impl renders each variant back to text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PatternStructure {
/// A leaf: text that was not decomposed any further.
Atom(String),
/// A function-style application, displayed as `func(arg, arg, ...)`.
App {
/// Text in front of the opening parenthesis.
func: String,
/// Parsed arguments, in source order.
args: Vec<PatternStructure>,
},
/// An infix operation, displayed as `(left op right)`.
Binary {
/// Operator text exactly as matched (for example `=` or `and`).
op: String,
/// Operand on the left of the operator.
left: Box<PatternStructure>,
/// Operand on the right of the operator.
right: Box<PatternStructure>,
},
/// A quantified formula, displayed as `quantifier var. body`.
Quantified {
/// Leading keyword of the formula (the parser recognises `forall`/`exists` prefixes).
quantifier: String,
/// Text between the quantifier keyword and the first `.`.
var: String,
/// Parsed formula after the `.`.
body: Box<PatternStructure>,
},
}
impl fmt::Display for PatternStructure {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PatternStructure::Atom(a) => write!(f, "{}", a),
PatternStructure::App { func, args } => {
write!(f, "{}(", func)?;
for (i, arg) in args.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", arg)?;
}
write!(f, ")")
}
PatternStructure::Binary { op, left, right } => {
write!(f, "({} {} {})", left, op, right)
}
PatternStructure::Quantified {
quantifier,
var,
body,
} => {
write!(f, "{} {}. {}", quantifier, var, body)
}
}
}
}
/// Collects `LemmaPattern`s from proofs and serves them back filtered by
/// how often they occurred.
pub struct PatternExtractor {
/// Patterns whose `frequency` is below this value are left out of query results.
min_frequency: usize,
/// Configured via `with_max_depth`.
/// NOTE(review): no method visible in this file reads this value — confirm
/// whether a depth limit is meant to be enforced somewhere.
max_depth: usize,
/// Recorded patterns, keyed by the string built in `create_pattern_key`
/// (`rule:num_premises:abstracted conclusion`).
patterns: FxHashMap<String, LemmaPattern>,
}
impl Default for PatternExtractor {
fn default() -> Self {
Self::new()
}
}
impl PatternExtractor {
/// Creates an extractor with a minimum frequency of 2, a maximum depth of 5
/// and no recorded patterns.
pub fn new() -> Self {
    let patterns = FxHashMap::default();
    Self {
        min_frequency: 2,
        max_depth: 5,
        patterns,
    }
}
pub fn with_min_frequency(mut self, freq: usize) -> Self {
self.min_frequency = freq;
self
}
pub fn with_max_depth(mut self, depth: usize) -> Self {
self.max_depth = depth;
self
}
pub fn extract_patterns(&mut self, proof: &Proof) {
let mut pattern_occurrences: FxHashMap<String, (usize, Vec<f64>)> = FxHashMap::default();
for node in proof.nodes() {
let depth = node.depth;
if let ProofStep::Inference { rule, premises, .. } = &node.step {
let pattern_key = self.create_pattern_key(rule, premises.len(), node.conclusion());
pattern_occurrences
.entry(pattern_key.clone())
.or_insert_with(|| (0, Vec::new()))
.0 += 1;
pattern_occurrences
.get_mut(&pattern_key)
.expect("key exists after entry().or_insert_with()")
.1
.push(depth as f64);
if let Some(pattern) =
self.extract_pattern_structure(rule, premises.len(), node.conclusion())
{
self.patterns.insert(pattern_key, pattern);
}
}
}
for (key, pattern) in &mut self.patterns {
if let Some((freq, depths)) = pattern_occurrences.get(key) {
pattern.frequency = *freq;
if !depths.is_empty() {
pattern.avg_depth = depths.iter().sum::<f64>() / depths.len() as f64;
}
}
}
}
/// Returns every recorded pattern whose frequency is at least the configured
/// minimum, in the map's iteration order.
pub fn get_patterns(&self) -> Vec<&LemmaPattern> {
    let threshold = self.min_frequency;
    let mut frequent = Vec::new();
    for pattern in self.patterns.values() {
        if pattern.frequency >= threshold {
            frequent.push(pattern);
        }
    }
    frequent
}
/// Same selection as [`get_patterns`](Self::get_patterns), ordered from the
/// most to the least frequent pattern.
pub fn get_patterns_by_frequency(&self) -> Vec<&LemmaPattern> {
    let mut ranked = self.get_patterns();
    // Stable sort with the operands swapped: descending by frequency, ties
    // stay in their incoming order.
    ranked.sort_by(|a, b| b.frequency.cmp(&a.frequency));
    ranked
}
/// Returns the sufficiently frequent patterns that were recorded for the
/// inference rule named `rule`.
pub fn get_patterns_for_rule(&self, rule: &str) -> Vec<&LemmaPattern> {
    self.get_patterns()
        .into_iter()
        .filter(|pattern| pattern.rule == rule)
        .collect()
}
pub fn clear(&mut self) {
self.patterns.clear();
}
/// Builds the map key for an inference step:
/// `rule:num_premises:abstracted conclusion`.
fn create_pattern_key(&self, rule: &str, num_premises: usize, conclusion: &str) -> String {
    let shape = self.abstract_conclusion(conclusion);
    format!("{rule}:{num_premises}:{shape}")
}
/// Replaces literals in `conclusion` with placeholders so that conclusions
/// differing only in constants share one key.
///
/// Standalone runs of digits become `$N`, then double-quoted strings become
/// `$S` (in the replacement strings `$$` is the escape for a literal `$`).
///
/// # Panics
///
/// Only if one of the hard-coded regular expressions failed to compile, which
/// would be a programming error.
fn abstract_conclusion(&self, conclusion: &str) -> String {
    // This runs once per inference node, so the expressions are compiled on
    // first use and shared afterwards instead of being rebuilt on every call.
    static NUMBER: std::sync::LazyLock<regex::Regex> = std::sync::LazyLock::new(|| {
        regex::Regex::new(r"\b\d+\b").expect("regex pattern is valid")
    });
    static STRING_LITERAL: std::sync::LazyLock<regex::Regex> = std::sync::LazyLock::new(|| {
        regex::Regex::new(r#""[^"]*""#).expect("regex pattern is valid")
    });

    let without_numbers = NUMBER.replace_all(conclusion, "$$N");
    STRING_LITERAL
        .replace_all(&without_numbers, "$$S")
        .into_owned()
}
/// Builds a `LemmaPattern` for one inference step.
///
/// The conclusion is parsed into a structure and its variables are collected;
/// `frequency` and `avg_depth` start at zero. Always returns `Some`.
fn extract_pattern_structure(
    &self,
    rule: &str,
    num_premises: usize,
    conclusion: &str,
) -> Option<LemmaPattern> {
    let structure = Self::parse_conclusion_structure(conclusion);
    let variables = self.extract_variables(&structure);
    let pattern = LemmaPattern {
        rule: rule.to_owned(),
        num_premises,
        variables,
        structure,
        frequency: 0,
        avg_depth: 0.0,
    };
    Some(pattern)
}
/// Parses a conclusion string into a `PatternStructure`.
///
/// The input is trimmed and then matched against these forms, in order:
///
/// 1. `forall <var>. <body>` / `exists <var>. <body>` → `Quantified`. The
///    first word must be exactly `forall` or `exists`.
/// 2. `<left> <op> <right>` → `Binary`, split at the leftmost occurrence of
///    the loosest-binding operator present outside all parentheses. Operators
///    from loosest to tightest: `=>`, `or`, `and`, then `<=`, `>=`, `!=`, `=`,
///    `<`, `>`. `and`/`or` must stand alone as words, symbolic operators must
///    not be glued to another of `<`, `>`, `=`, `!`, and both operands must be
///    non-blank.
/// 3. `<func>(<args>)`, where the closing parenthesis matches the first
///    opening one → `App`. Arguments are split at commas outside nested
///    parentheses; an empty argument list gives no arguments. With an empty
///    `func` and a single argument the parentheses are grouping only and the
///    inner expression is returned.
/// 4. Anything else → `Atom` holding the trimmed text.
fn parse_conclusion_structure(conclusion: &str) -> PatternStructure {
    /// Infix operators, loosest-binding first. Multi-character comparison
    /// operators precede the single characters they contain.
    const OPERATORS: [&str; 9] = ["=>", "or", "and", "<=", ">=", "!=", "=", "<", ">"];

    /// Characters that may continue an identifier next to a word operator.
    fn is_ident(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    /// Characters that may continue a symbolic operator.
    fn is_symbol(c: char) -> bool {
        matches!(c, '<' | '>' | '=' | '!')
    }

    /// Byte offset of the leftmost usable occurrence of `op` in `text` at
    /// parenthesis depth zero, if any.
    fn find_operator(text: &str, op: &str) -> Option<usize> {
        let glue: fn(char) -> bool = if op.chars().all(char::is_alphabetic) {
            is_ident
        } else {
            is_symbol
        };
        let mut depth = 0usize;
        for (pos, ch) in text.char_indices() {
            match ch {
                '(' => depth += 1,
                // Tolerate unbalanced input instead of underflowing.
                ')' => depth = depth.saturating_sub(1),
                _ if depth == 0 && text[pos..].starts_with(op) => {
                    let left = &text[..pos];
                    let right = &text[pos + op.len()..];
                    let standalone = !left.chars().next_back().is_some_and(glue)
                        && !right.chars().next().is_some_and(glue);
                    if standalone && !left.trim().is_empty() && !right.trim().is_empty() {
                        return Some(pos);
                    }
                }
                _ => {}
            }
        }
        None
    }

    /// Byte offset of the `)` that closes the `(` located at `open`.
    fn matching_close(text: &str, open: usize) -> Option<usize> {
        let mut depth = 0usize;
        for (pos, ch) in text[open..].char_indices() {
            match ch {
                '(' => depth += 1,
                ')' => {
                    // `text[open]` is `(`, so depth is at least 1 here.
                    depth -= 1;
                    if depth == 0 {
                        return Some(open + pos);
                    }
                }
                _ => {}
            }
        }
        None
    }

    /// Splits an argument list at commas that are outside nested parentheses.
    fn split_arguments(args: &str) -> Vec<&str> {
        let mut parts = Vec::new();
        if args.trim().is_empty() {
            return parts;
        }
        let mut depth = 0usize;
        let mut start = 0;
        for (pos, ch) in args.char_indices() {
            match ch {
                '(' => depth += 1,
                ')' => depth = depth.saturating_sub(1),
                ',' if depth == 0 => {
                    parts.push(&args[start..pos]);
                    start = pos + 1;
                }
                _ => {}
            }
        }
        parts.push(&args[start..]);
        parts
    }

    let trimmed = conclusion.trim();

    if let Some((quantifier, rest)) = trimmed.split_once(' ')
        && matches!(quantifier, "forall" | "exists")
        && let Some((var, body)) = rest.split_once('.')
    {
        return PatternStructure::Quantified {
            quantifier: quantifier.to_string(),
            var: var.trim().to_string(),
            body: Box::new(Self::parse_conclusion_structure(body)),
        };
    }

    for op in OPERATORS {
        if let Some(pos) = find_operator(trimmed, op) {
            return PatternStructure::Binary {
                op: op.to_string(),
                left: Box::new(Self::parse_conclusion_structure(&trimmed[..pos])),
                right: Box::new(Self::parse_conclusion_structure(&trimmed[pos + op.len()..])),
            };
        }
    }

    if let Some(open) = trimmed.find('(')
        && matching_close(trimmed, open) == Some(trimmed.len() - 1)
    {
        let func = trimmed[..open].trim();
        let args = split_arguments(&trimmed[open + 1..trimmed.len() - 1]);
        if func.is_empty() && args.len() == 1 {
            // Pure grouping parentheses, e.g. the `(x = y)` form that
            // `Display` produces for binary nodes.
            return Self::parse_conclusion_structure(args[0]);
        }
        return PatternStructure::App {
            func: func.to_string(),
            args: args
                .into_iter()
                .map(Self::parse_conclusion_structure)
                .collect(),
        };
    }

    PatternStructure::Atom(trimmed.to_string())
}
/// Returns the variable names found in `structure`, sorted and without
/// duplicates.
fn extract_variables(&self, structure: &PatternStructure) -> Vec<String> {
    let mut found = Vec::new();
    Self::extract_variables_rec(structure, &mut found);
    // An ordered set yields the names sorted and unique in one pass.
    let unique: std::collections::BTreeSet<String> = found.into_iter().collect();
    unique.into_iter().collect()
}
/// Appends the variable names of `structure` to `vars` in pre-order.
///
/// An atom counts as a variable when it starts with `$` or with a lowercase
/// character. The bound variable of a quantified node is always recorded.
/// Duplicates are not removed here.
fn extract_variables_rec(structure: &PatternStructure, vars: &mut Vec<String>) {
    match structure {
        PatternStructure::Atom(name) => {
            let looks_like_variable = match name.chars().next() {
                Some('$') => true,
                Some(first) => first.is_lowercase(),
                None => false,
            };
            if looks_like_variable {
                vars.push(name.clone());
            }
        }
        PatternStructure::App { args, .. } => {
            args.iter()
                .for_each(|arg| Self::extract_variables_rec(arg, vars));
        }
        PatternStructure::Binary { left, right, .. } => {
            for side in [left, right] {
                Self::extract_variables_rec(side, vars);
            }
        }
        PatternStructure::Quantified { var, body, .. } => {
            vars.push(var.clone());
            Self::extract_variables_rec(body, vars);
        }
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an atom in the assertions below.
    fn atom(text: &str) -> PatternStructure {
        PatternStructure::Atom(text.to_string())
    }

    #[test]
    fn test_pattern_extractor_new() {
        let extractor = PatternExtractor::new();
        assert_eq!(extractor.min_frequency, 2);
        assert_eq!(extractor.max_depth, 5);
        assert!(extractor.patterns.is_empty());
    }

    #[test]
    fn test_pattern_extractor_with_settings() {
        let extractor = PatternExtractor::new()
            .with_min_frequency(3)
            .with_max_depth(10);
        assert_eq!(extractor.min_frequency, 3);
        assert_eq!(extractor.max_depth, 10);
    }

    #[test]
    fn test_pattern_structure_display() {
        assert_eq!(atom("x").to_string(), "x");

        let app = PatternStructure::App {
            func: "f".to_string(),
            args: vec![atom("x"), atom("y")],
        };
        assert_eq!(app.to_string(), "f(x, y)");

        let binary = PatternStructure::Binary {
            op: "=".to_string(),
            left: Box::new(atom("x")),
            right: Box::new(atom("y")),
        };
        assert_eq!(binary.to_string(), "(x = y)");

        let quantified = PatternStructure::Quantified {
            quantifier: "forall".to_string(),
            var: "x".to_string(),
            body: Box::new(app),
        };
        assert_eq!(quantified.to_string(), "forall x. f(x, y)");
    }

    #[test]
    fn test_parse_atom() {
        let structure = PatternExtractor::parse_conclusion_structure("x");
        assert_eq!(structure, atom("x"));
    }

    #[test]
    fn test_parse_binary() {
        let structure = PatternExtractor::parse_conclusion_structure("x = y");
        let expected = PatternStructure::Binary {
            op: "=".to_string(),
            left: Box::new(atom("x")),
            right: Box::new(atom("y")),
        };
        assert_eq!(structure, expected);
    }

    #[test]
    fn test_parse_app() {
        let structure = PatternExtractor::parse_conclusion_structure("f(x, y)");
        let expected = PatternStructure::App {
            func: "f".to_string(),
            args: vec![atom("x"), atom("y")],
        };
        assert_eq!(structure, expected);
    }

    #[test]
    fn test_abstract_conclusion() {
        let extractor = PatternExtractor::new();
        // Pin the exact output: a check that also accepts the untouched
        // input would pass even if no abstraction happened at all.
        assert_eq!(extractor.abstract_conclusion("x + 42 = y"), "x + $N = y");
        assert_eq!(
            extractor.abstract_conclusion(r#"name = "bob""#),
            "name = $S"
        );
        assert_eq!(extractor.abstract_conclusion("x = y"), "x = y");
    }

    #[test]
    fn test_extract_variables() {
        let extractor = PatternExtractor::new();
        let structure = PatternStructure::App {
            func: "f".to_string(),
            args: vec![atom("y"), atom("x"), atom("x"), atom("Const")],
        };
        // Sorted, de-duplicated, and without the capitalised atom.
        let vars = extractor.extract_variables(&structure);
        assert_eq!(vars, vec!["x".to_string(), "y".to_string()]);
    }

    #[test]
    fn test_extract_patterns_empty_proof() {
        let mut extractor = PatternExtractor::new();
        let proof = Proof::new();
        extractor.extract_patterns(&proof);
        assert!(extractor.get_patterns().is_empty());
        assert!(extractor.get_patterns_by_frequency().is_empty());
    }

    #[test]
    fn test_clear_patterns() {
        let mut extractor = PatternExtractor::new();
        let proof = Proof::new();
        extractor.extract_patterns(&proof);
        extractor.clear();
        assert!(extractor.patterns.is_empty());
    }
}