1use crate::proof::{Proof, ProofNodeId, ProofStep};
7use rustc_hash::{FxHashMap, FxHashSet};
8use std::fmt;
9
/// A reusable proof pattern mined from one or more proofs.
///
/// Produced by [`TemplateIdentifier::identify_templates`]. A template records the
/// abstract step sequence that recurred, the placeholder parameters appearing in its
/// conclusion patterns, how often it was seen, and an externally supplied success rate.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProofTemplate {
    /// Template identifier; generated templates are named `template_<n>`.
    pub name: String,
    /// The abstract steps making up the pattern, in the order their nodes appear in
    /// the source proof.
    pub steps: Vec<TemplateStep>,
    /// Distinct placeholder names (e.g. `$N`, `$S`, `$V`) found in the steps'
    /// conclusion patterns, sorted.
    pub parameters: Vec<String>,
    /// Number of candidate step sequences that matched this pattern.
    pub occurrences: usize,
    /// Fraction of successful applications, shown by `Display` as a percentage.
    /// Generated templates start at `0.0`; `update_success_rate` clamps values into
    /// `[0.0, 1.0]`.
    pub success_rate: f64,
}
25
/// One abstracted inference step inside a [`ProofTemplate`] or a mined candidate.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TemplateStep {
    /// Position of the step: its index within the scanned window of proof nodes for
    /// mined candidates, or its index within `steps` for generated templates.
    pub id: usize,
    /// Name of the inference rule applied at this step.
    pub rule: String,
    /// `id`s of this step's premises that were already seen in the same window;
    /// premises outside the window are dropped. Generated templates leave this empty.
    pub premise_ids: Vec<usize>,
    /// The step's conclusion with digit runs, double-quoted strings, and identifiers
    /// starting with a lowercase letter replaced by `$N`, `$S`, and `$V` respectively.
    pub conclusion_pattern: String,
}
39
40impl fmt::Display for ProofTemplate {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 writeln!(f, "Template: {}", self.name)?;
43 writeln!(f, "Parameters: {}", self.parameters.join(", "))?;
44 writeln!(f, "Occurrences: {}", self.occurrences)?;
45 writeln!(f, "Success rate: {:.1}%", self.success_rate * 100.0)?;
46 writeln!(f, "Steps:")?;
47 for step in &self.steps {
48 writeln!(
49 f,
50 " [{}] {} from {:?} => {}",
51 step.id, step.rule, step.premise_ids, step.conclusion_pattern
52 )?;
53 }
54 Ok(())
55 }
56}
57
/// Mines recurring step sequences from proofs and stores them as [`ProofTemplate`]s.
///
/// Configure it with the `with_*` builder methods, then call `identify_templates`.
pub struct TemplateIdentifier {
    /// Smallest window (in proof nodes) that is scanned, and the minimum number of
    /// inference steps a candidate must contain.
    min_template_size: usize,
    /// Largest window (in proof nodes) that is scanned.
    max_template_size: usize,
    /// Minimum number of matching candidates required to create a template.
    min_occurrences: usize,
    /// Templates discovered so far, accumulated across `identify_templates` calls.
    templates: Vec<ProofTemplate>,
}
69
70impl Default for TemplateIdentifier {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl TemplateIdentifier {
77 pub fn new() -> Self {
79 Self {
80 min_template_size: 3,
81 max_template_size: 10,
82 min_occurrences: 2,
83 templates: Vec::new(),
84 }
85 }
86
87 pub fn with_min_size(mut self, size: usize) -> Self {
89 self.min_template_size = size;
90 self
91 }
92
93 pub fn with_max_size(mut self, size: usize) -> Self {
95 self.max_template_size = size;
96 self
97 }
98
99 pub fn with_min_occurrences(mut self, occurrences: usize) -> Self {
101 self.min_occurrences = occurrences;
102 self
103 }
104
105 pub fn identify_templates(&mut self, proofs: &[&Proof]) {
107 let mut candidates: Vec<Vec<TemplateStep>> = Vec::new();
109
110 for proof in proofs {
111 candidates.extend(self.extract_candidate_templates(proof));
112 }
113
114 let grouped = self.group_similar_templates(&candidates);
116
117 let mut new_templates = Vec::new();
119 for (pattern, instances) in grouped {
120 if instances.len() >= self.min_occurrences {
121 let template = self.create_template(&pattern, instances.len());
122 new_templates.push(template);
123 }
124 }
125
126 self.templates.extend(new_templates);
127
128 self.templates
130 .sort_by_key(|t| std::cmp::Reverse(t.occurrences));
131 }
132
133 pub fn get_templates(&self) -> &[ProofTemplate] {
135 &self.templates
136 }
137
138 pub fn get_templates_by_success_rate(&self) -> Vec<&ProofTemplate> {
140 let mut templates: Vec<&ProofTemplate> = self.templates.iter().collect();
141 templates.sort_by(|a, b| {
142 b.success_rate
143 .partial_cmp(&a.success_rate)
144 .unwrap_or(std::cmp::Ordering::Equal)
145 });
146 templates
147 }
148
149 pub fn find_template(&self, name: &str) -> Option<&ProofTemplate> {
151 self.templates.iter().find(|t| t.name == name)
152 }
153
154 pub fn update_success_rate(&mut self, name: &str, success_rate: f64) {
156 if let Some(template) = self.templates.iter_mut().find(|t| t.name == name) {
157 template.success_rate = success_rate.clamp(0.0, 1.0);
158 }
159 }
160
161 pub fn clear(&mut self) {
163 self.templates.clear();
164 }
165
166 fn extract_candidate_templates(&self, proof: &Proof) -> Vec<Vec<TemplateStep>> {
168 let mut candidates = Vec::new();
169 let nodes: Vec<ProofNodeId> = proof.nodes().iter().map(|n| n.id).collect();
170
171 for window_size in self.min_template_size..=self.max_template_size.min(nodes.len()) {
173 for window in nodes.windows(window_size) {
174 if let Some(template_steps) = self.extract_template_steps(proof, window) {
175 candidates.push(template_steps);
176 }
177 }
178 }
179
180 candidates
181 }
182
183 fn extract_template_steps(
185 &self,
186 proof: &Proof,
187 nodes: &[ProofNodeId],
188 ) -> Option<Vec<TemplateStep>> {
189 let mut steps = Vec::new();
190 let mut node_to_id = FxHashMap::default();
191
192 for (i, &node_id) in nodes.iter().enumerate() {
193 node_to_id.insert(node_id, i);
194
195 if let Some(node) = proof.get_node(node_id)
196 && let ProofStep::Inference { rule, premises, .. } = &node.step
197 {
198 let premise_ids: Vec<usize> = premises
200 .iter()
201 .filter_map(|&p| node_to_id.get(&p).copied())
202 .collect();
203
204 steps.push(TemplateStep {
205 id: i,
206 rule: rule.clone(),
207 premise_ids,
208 conclusion_pattern: self.abstract_conclusion(node.conclusion()),
209 });
210 }
211 }
212
213 if steps.len() >= self.min_template_size {
214 Some(steps)
215 } else {
216 None
217 }
218 }
219
220 fn abstract_conclusion(&self, conclusion: &str) -> String {
222 let mut abstracted = conclusion.to_string();
224
225 let re_num = regex::Regex::new(r"\b\d+\b").expect("regex pattern is valid");
227 abstracted = re_num.replace_all(&abstracted, "$$N").to_string();
228
229 let re_str = regex::Regex::new(r#""[^"]*""#).expect("regex pattern is valid");
231 abstracted = re_str.replace_all(&abstracted, "$$S").to_string();
232
233 let re_id = regex::Regex::new(r"\b[a-z][a-z0-9_]*\b").expect("regex pattern is valid");
235 abstracted = re_id.replace_all(&abstracted, "$$V").to_string();
236
237 abstracted
238 }
239
240 fn group_similar_templates<'a>(
242 &self,
243 candidates: &'a [Vec<TemplateStep>],
244 ) -> FxHashMap<String, Vec<&'a Vec<TemplateStep>>> {
245 let mut groups: FxHashMap<String, Vec<&Vec<TemplateStep>>> = FxHashMap::default();
246
247 for candidate in candidates {
248 let signature = self.compute_template_signature(candidate);
249 groups.entry(signature).or_default().push(candidate);
250 }
251
252 groups
253 }
254
255 fn compute_template_signature(&self, steps: &[TemplateStep]) -> String {
257 steps
258 .iter()
259 .map(|s| format!("{}:{}", s.rule, s.conclusion_pattern))
260 .collect::<Vec<_>>()
261 .join("|")
262 }
263
264 fn create_template(&self, pattern: &str, occurrences: usize) -> ProofTemplate {
266 let parts: Vec<&str> = pattern.split('|').collect();
268 let mut steps = Vec::new();
269 let mut parameters = FxHashSet::default();
270
271 for (i, part) in parts.iter().enumerate() {
272 if let Some((rule, conclusion_pattern)) = part.split_once(':') {
273 for capture in conclusion_pattern.split('$').skip(1) {
275 if let Some(var) = capture.chars().next() {
276 parameters.insert(format!("${}", var));
277 }
278 }
279
280 steps.push(TemplateStep {
281 id: i,
282 rule: rule.to_string(),
283 premise_ids: Vec::new(), conclusion_pattern: conclusion_pattern.to_string(),
285 });
286 }
287 }
288
289 let mut params: Vec<String> = parameters.into_iter().collect();
290 params.sort();
291
292 ProofTemplate {
293 name: format!("template_{}", self.templates.len()),
294 steps,
295 parameters: params,
296 occurrences,
297 success_rate: 0.0, }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a step-less template with the given name and success rate.
    fn bare_template(name: &str, success_rate: f64) -> ProofTemplate {
        ProofTemplate {
            name: name.to_string(),
            steps: vec![],
            parameters: vec![],
            occurrences: 1,
            success_rate,
        }
    }

    #[test]
    fn test_template_identifier_new() {
        let identifier = TemplateIdentifier::new();
        assert_eq!(identifier.min_template_size, 3);
        assert_eq!(identifier.max_template_size, 10);
        assert_eq!(identifier.min_occurrences, 2);
        assert!(identifier.templates.is_empty());
        assert!(identifier.get_templates().is_empty());
    }

    #[test]
    fn test_template_identifier_with_settings() {
        let identifier = TemplateIdentifier::new()
            .with_min_size(5)
            .with_max_size(15)
            .with_min_occurrences(3);
        assert_eq!(identifier.min_template_size, 5);
        assert_eq!(identifier.max_template_size, 15);
        assert_eq!(identifier.min_occurrences, 3);
    }

    #[test]
    fn test_template_step() {
        let step = TemplateStep {
            id: 0,
            rule: "resolution".to_string(),
            premise_ids: vec![1, 2],
            conclusion_pattern: "x = y".to_string(),
        };
        assert_eq!(step.id, 0);
        assert_eq!(step.rule, "resolution");
        assert_eq!(step.premise_ids.len(), 2);
    }

    #[test]
    fn test_proof_template_display() {
        let template = ProofTemplate {
            name: "test_template".to_string(),
            steps: vec![TemplateStep {
                id: 0,
                rule: "resolution".to_string(),
                premise_ids: vec![],
                conclusion_pattern: "$V = $V".to_string(),
            }],
            parameters: vec!["$V".to_string()],
            occurrences: 5,
            success_rate: 0.8,
        };
        let display = format!("{}", template);
        assert!(display.contains("test_template"));
        assert!(display.contains("80.0%"));
    }

    #[test]
    fn test_abstract_conclusion() {
        let identifier = TemplateIdentifier::new();
        assert_eq!(identifier.abstract_conclusion("x + 42 = y"), "$V + $N = $V");
    }

    #[test]
    fn test_abstract_conclusion_strings_and_calls() {
        let identifier = TemplateIdentifier::new();
        assert_eq!(
            identifier.abstract_conclusion("name = \"hello world\""),
            "$V = $S"
        );
        // Digits glued to an identifier are part of it, not a separate number.
        assert_eq!(identifier.abstract_conclusion("f(x1, 7)"), "$V($V, $N)");
    }

    #[test]
    fn test_update_success_rate() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(bare_template("test", 0.0));
        identifier.update_success_rate("test", 0.75);
        assert_eq!(identifier.templates[0].success_rate, 0.75);
    }

    #[test]
    fn test_update_success_rate_clamp() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(bare_template("test", 0.0));
        identifier.update_success_rate("test", 1.5);
        assert_eq!(identifier.templates[0].success_rate, 1.0);
        identifier.update_success_rate("test", -0.5);
        assert_eq!(identifier.templates[0].success_rate, 0.0);
    }

    #[test]
    fn test_update_success_rate_unknown_name_is_noop() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(bare_template("test", 0.25));
        identifier.update_success_rate("missing", 0.9);
        assert_eq!(identifier.templates[0].success_rate, 0.25);
    }

    #[test]
    fn test_find_template() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(bare_template("test", 0.0));
        assert!(identifier.find_template("test").is_some());
        assert!(identifier.find_template("nonexistent").is_none());
    }

    #[test]
    fn test_clear_templates() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(bare_template("test", 0.0));
        identifier.clear();
        assert!(identifier.templates.is_empty());
    }

    #[test]
    fn test_get_templates_by_success_rate() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(bare_template("low", 0.3));
        identifier.templates.push(bare_template("high", 0.9));
        let sorted = identifier.get_templates_by_success_rate();
        assert_eq!(sorted[0].name, "high");
        assert_eq!(sorted[1].name, "low");
    }
}
444}