use crate::proof::{Proof, ProofNodeId, ProofStep};
use rustc_hash::{FxHashMap, FxHashSet};
use std::fmt;
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProofTemplate {
pub name: String,
pub steps: Vec<TemplateStep>,
pub parameters: Vec<String>,
pub occurrences: usize,
pub success_rate: f64,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TemplateStep {
pub id: usize,
pub rule: String,
pub premise_ids: Vec<usize>,
pub conclusion_pattern: String,
}
impl fmt::Display for ProofTemplate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Template: {}", self.name)?;
writeln!(f, "Parameters: {}", self.parameters.join(", "))?;
writeln!(f, "Occurrences: {}", self.occurrences)?;
writeln!(f, "Success rate: {:.1}%", self.success_rate * 100.0)?;
writeln!(f, "Steps:")?;
for step in &self.steps {
writeln!(
f,
" [{}] {} from {:?} => {}",
step.id, step.rule, step.premise_ids, step.conclusion_pattern
)?;
}
Ok(())
}
}
pub struct TemplateIdentifier {
min_template_size: usize,
max_template_size: usize,
min_occurrences: usize,
templates: Vec<ProofTemplate>,
}
impl Default for TemplateIdentifier {
fn default() -> Self {
Self::new()
}
}
impl TemplateIdentifier {
pub fn new() -> Self {
Self {
min_template_size: 3,
max_template_size: 10,
min_occurrences: 2,
templates: Vec::new(),
}
}
pub fn with_min_size(mut self, size: usize) -> Self {
self.min_template_size = size;
self
}
pub fn with_max_size(mut self, size: usize) -> Self {
self.max_template_size = size;
self
}
pub fn with_min_occurrences(mut self, occurrences: usize) -> Self {
self.min_occurrences = occurrences;
self
}
pub fn identify_templates(&mut self, proofs: &[&Proof]) {
let mut candidates: Vec<Vec<TemplateStep>> = Vec::new();
for proof in proofs {
candidates.extend(self.extract_candidate_templates(proof));
}
let grouped = self.group_similar_templates(&candidates);
let mut new_templates = Vec::new();
for (pattern, instances) in grouped {
if instances.len() >= self.min_occurrences {
let template = self.create_template(&pattern, instances.len());
new_templates.push(template);
}
}
self.templates.extend(new_templates);
self.templates
.sort_by_key(|t| std::cmp::Reverse(t.occurrences));
}
pub fn get_templates(&self) -> &[ProofTemplate] {
&self.templates
}
pub fn get_templates_by_success_rate(&self) -> Vec<&ProofTemplate> {
let mut templates: Vec<&ProofTemplate> = self.templates.iter().collect();
templates.sort_by(|a, b| {
b.success_rate
.partial_cmp(&a.success_rate)
.unwrap_or(std::cmp::Ordering::Equal)
});
templates
}
pub fn find_template(&self, name: &str) -> Option<&ProofTemplate> {
self.templates.iter().find(|t| t.name == name)
}
pub fn update_success_rate(&mut self, name: &str, success_rate: f64) {
if let Some(template) = self.templates.iter_mut().find(|t| t.name == name) {
template.success_rate = success_rate.clamp(0.0, 1.0);
}
}
pub fn clear(&mut self) {
self.templates.clear();
}
fn extract_candidate_templates(&self, proof: &Proof) -> Vec<Vec<TemplateStep>> {
let mut candidates = Vec::new();
let nodes: Vec<ProofNodeId> = proof.nodes().iter().map(|n| n.id).collect();
for window_size in self.min_template_size..=self.max_template_size.min(nodes.len()) {
for window in nodes.windows(window_size) {
if let Some(template_steps) = self.extract_template_steps(proof, window) {
candidates.push(template_steps);
}
}
}
candidates
}
fn extract_template_steps(
&self,
proof: &Proof,
nodes: &[ProofNodeId],
) -> Option<Vec<TemplateStep>> {
let mut steps = Vec::new();
let mut node_to_id = FxHashMap::default();
for (i, &node_id) in nodes.iter().enumerate() {
node_to_id.insert(node_id, i);
if let Some(node) = proof.get_node(node_id)
&& let ProofStep::Inference { rule, premises, .. } = &node.step
{
let premise_ids: Vec<usize> = premises
.iter()
.filter_map(|&p| node_to_id.get(&p).copied())
.collect();
steps.push(TemplateStep {
id: i,
rule: rule.clone(),
premise_ids,
conclusion_pattern: self.abstract_conclusion(node.conclusion()),
});
}
}
if steps.len() >= self.min_template_size {
Some(steps)
} else {
None
}
}
fn abstract_conclusion(&self, conclusion: &str) -> String {
let mut abstracted = conclusion.to_string();
let re_num = regex::Regex::new(r"\b\d+\b").expect("regex pattern is valid");
abstracted = re_num.replace_all(&abstracted, "$$N").to_string();
let re_str = regex::Regex::new(r#""[^"]*""#).expect("regex pattern is valid");
abstracted = re_str.replace_all(&abstracted, "$$S").to_string();
let re_id = regex::Regex::new(r"\b[a-z][a-z0-9_]*\b").expect("regex pattern is valid");
abstracted = re_id.replace_all(&abstracted, "$$V").to_string();
abstracted
}
fn group_similar_templates<'a>(
&self,
candidates: &'a [Vec<TemplateStep>],
) -> FxHashMap<String, Vec<&'a Vec<TemplateStep>>> {
let mut groups: FxHashMap<String, Vec<&Vec<TemplateStep>>> = FxHashMap::default();
for candidate in candidates {
let signature = self.compute_template_signature(candidate);
groups.entry(signature).or_default().push(candidate);
}
groups
}
fn compute_template_signature(&self, steps: &[TemplateStep]) -> String {
steps
.iter()
.map(|s| format!("{}:{}", s.rule, s.conclusion_pattern))
.collect::<Vec<_>>()
.join("|")
}
fn create_template(&self, pattern: &str, occurrences: usize) -> ProofTemplate {
let parts: Vec<&str> = pattern.split('|').collect();
let mut steps = Vec::new();
let mut parameters = FxHashSet::default();
for (i, part) in parts.iter().enumerate() {
if let Some((rule, conclusion_pattern)) = part.split_once(':') {
for capture in conclusion_pattern.split('$').skip(1) {
if let Some(var) = capture.chars().next() {
parameters.insert(format!("${}", var));
}
}
steps.push(TemplateStep {
id: i,
rule: rule.to_string(),
premise_ids: Vec::new(), conclusion_pattern: conclusion_pattern.to_string(),
});
}
}
let mut params: Vec<String> = parameters.into_iter().collect();
params.sort();
ProofTemplate {
name: format!("template_{}", self.templates.len()),
steps,
parameters: params,
occurrences,
success_rate: 0.0, }
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_template_identifier_new() {
let identifier = TemplateIdentifier::new();
assert_eq!(identifier.min_template_size, 3);
assert_eq!(identifier.max_template_size, 10);
assert_eq!(identifier.min_occurrences, 2);
assert!(identifier.templates.is_empty());
}
#[test]
fn test_template_identifier_with_settings() {
let identifier = TemplateIdentifier::new()
.with_min_size(5)
.with_max_size(15)
.with_min_occurrences(3);
assert_eq!(identifier.min_template_size, 5);
assert_eq!(identifier.max_template_size, 15);
assert_eq!(identifier.min_occurrences, 3);
}
#[test]
fn test_template_step() {
let step = TemplateStep {
id: 0,
rule: "resolution".to_string(),
premise_ids: vec![1, 2],
conclusion_pattern: "x = y".to_string(),
};
assert_eq!(step.id, 0);
assert_eq!(step.rule, "resolution");
assert_eq!(step.premise_ids.len(), 2);
}
#[test]
fn test_proof_template_display() {
let template = ProofTemplate {
name: "test_template".to_string(),
steps: vec![TemplateStep {
id: 0,
rule: "resolution".to_string(),
premise_ids: vec![],
conclusion_pattern: "$V = $V".to_string(),
}],
parameters: vec!["$V".to_string()],
occurrences: 5,
success_rate: 0.8,
};
let display = format!("{}", template);
assert!(display.contains("test_template"));
assert!(display.contains("80.0%"));
}
#[test]
fn test_abstract_conclusion() {
let identifier = TemplateIdentifier::new();
let abstracted = identifier.abstract_conclusion("x + 42 = y");
assert!(abstracted.contains("$N") || abstracted.contains("42"));
assert!(abstracted.contains("$V") || abstracted.contains("x"));
}
#[test]
fn test_update_success_rate() {
let mut identifier = TemplateIdentifier::new();
identifier.templates.push(ProofTemplate {
name: "test".to_string(),
steps: vec![],
parameters: vec![],
occurrences: 1,
success_rate: 0.0,
});
identifier.update_success_rate("test", 0.75);
assert_eq!(identifier.templates[0].success_rate, 0.75);
}
#[test]
fn test_update_success_rate_clamp() {
let mut identifier = TemplateIdentifier::new();
identifier.templates.push(ProofTemplate {
name: "test".to_string(),
steps: vec![],
parameters: vec![],
occurrences: 1,
success_rate: 0.0,
});
identifier.update_success_rate("test", 1.5);
assert_eq!(identifier.templates[0].success_rate, 1.0);
}
#[test]
fn test_find_template() {
let mut identifier = TemplateIdentifier::new();
identifier.templates.push(ProofTemplate {
name: "test".to_string(),
steps: vec![],
parameters: vec![],
occurrences: 1,
success_rate: 0.0,
});
assert!(identifier.find_template("test").is_some());
assert!(identifier.find_template("nonexistent").is_none());
}
#[test]
fn test_clear_templates() {
let mut identifier = TemplateIdentifier::new();
identifier.templates.push(ProofTemplate {
name: "test".to_string(),
steps: vec![],
parameters: vec![],
occurrences: 1,
success_rate: 0.0,
});
identifier.clear();
assert!(identifier.templates.is_empty());
}
#[test]
fn test_get_templates_by_success_rate() {
let mut identifier = TemplateIdentifier::new();
identifier.templates.push(ProofTemplate {
name: "low".to_string(),
steps: vec![],
parameters: vec![],
occurrences: 1,
success_rate: 0.3,
});
identifier.templates.push(ProofTemplate {
name: "high".to_string(),
steps: vec![],
parameters: vec![],
occurrences: 1,
success_rate: 0.9,
});
let sorted = identifier.get_templates_by_success_rate();
assert_eq!(sorted[0].name, "high");
assert_eq!(sorted[1].name, "low");
}
}