Skip to main content

oxiz_proof/
template.rs

1//! Proof template identification and reuse.
2//!
3//! This module identifies reusable proof templates that can be instantiated
4//! for similar problems, enabling proof-based learning.
5
6use crate::proof::{Proof, ProofNodeId, ProofStep};
7use rustc_hash::{FxHashMap, FxHashSet};
8use std::fmt;
9
10/// A proof template representing a reusable proof pattern.
11#[derive(Debug, Clone, PartialEq)]
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13pub struct ProofTemplate {
14    /// Template name/identifier
15    pub name: String,
16    /// Template steps (abstracted)
17    pub steps: Vec<TemplateStep>,
18    /// Template parameters (variables to instantiate)
19    pub parameters: Vec<String>,
20    /// Number of times this template was found
21    pub occurrences: usize,
22    /// Success rate when applied
23    pub success_rate: f64,
24}
25
/// One abstracted inference step inside a [`ProofTemplate`].
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TemplateStep {
    /// Position of this step within its template.
    pub id: usize,
    /// Name of the inference rule applied at this step.
    pub rule: String,
    /// Ids of the template steps this step was derived from.
    pub premise_ids: Vec<usize>,
    /// The step's conclusion with concrete values replaced by placeholders.
    pub conclusion_pattern: String,
}
39
40impl fmt::Display for ProofTemplate {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        writeln!(f, "Template: {}", self.name)?;
43        writeln!(f, "Parameters: {}", self.parameters.join(", "))?;
44        writeln!(f, "Occurrences: {}", self.occurrences)?;
45        writeln!(f, "Success rate: {:.1}%", self.success_rate * 100.0)?;
46        writeln!(f, "Steps:")?;
47        for step in &self.steps {
48            writeln!(
49                f,
50                "  [{}] {} from {:?} => {}",
51                step.id, step.rule, step.premise_ids, step.conclusion_pattern
52            )?;
53        }
54        Ok(())
55    }
56}
57
58/// Template identifier for analyzing proofs.
59pub struct TemplateIdentifier {
60    /// Minimum template size (number of steps)
61    min_template_size: usize,
62    /// Maximum template size
63    max_template_size: usize,
64    /// Minimum occurrences to consider a template
65    min_occurrences: usize,
66    /// Identified templates
67    templates: Vec<ProofTemplate>,
68}
69
70impl Default for TemplateIdentifier {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl TemplateIdentifier {
77    /// Create a new template identifier with default settings.
78    pub fn new() -> Self {
79        Self {
80            min_template_size: 3,
81            max_template_size: 10,
82            min_occurrences: 2,
83            templates: Vec::new(),
84        }
85    }
86
87    /// Set the minimum template size.
88    pub fn with_min_size(mut self, size: usize) -> Self {
89        self.min_template_size = size;
90        self
91    }
92
93    /// Set the maximum template size.
94    pub fn with_max_size(mut self, size: usize) -> Self {
95        self.max_template_size = size;
96        self
97    }
98
99    /// Set the minimum occurrences threshold.
100    pub fn with_min_occurrences(mut self, occurrences: usize) -> Self {
101        self.min_occurrences = occurrences;
102        self
103    }
104
105    /// Identify templates from a collection of proofs.
106    pub fn identify_templates(&mut self, proofs: &[&Proof]) {
107        // Extract candidate subproofs from each proof
108        let mut candidates: Vec<Vec<TemplateStep>> = Vec::new();
109
110        for proof in proofs {
111            candidates.extend(self.extract_candidate_templates(proof));
112        }
113
114        // Group similar candidates
115        let grouped = self.group_similar_templates(&candidates);
116
117        // Convert groups to templates
118        let mut new_templates = Vec::new();
119        for (pattern, instances) in grouped {
120            if instances.len() >= self.min_occurrences {
121                let template = self.create_template(&pattern, instances.len());
122                new_templates.push(template);
123            }
124        }
125
126        self.templates.extend(new_templates);
127
128        // Sort templates by occurrences
129        self.templates
130            .sort_by_key(|t| std::cmp::Reverse(t.occurrences));
131    }
132
133    /// Get all identified templates.
134    pub fn get_templates(&self) -> &[ProofTemplate] {
135        &self.templates
136    }
137
138    /// Get templates sorted by success rate.
139    pub fn get_templates_by_success_rate(&self) -> Vec<&ProofTemplate> {
140        let mut templates: Vec<&ProofTemplate> = self.templates.iter().collect();
141        templates.sort_by(|a, b| {
142            b.success_rate
143                .partial_cmp(&a.success_rate)
144                .unwrap_or(std::cmp::Ordering::Equal)
145        });
146        templates
147    }
148
149    /// Find a template by name.
150    pub fn find_template(&self, name: &str) -> Option<&ProofTemplate> {
151        self.templates.iter().find(|t| t.name == name)
152    }
153
154    /// Update success rate for a template.
155    pub fn update_success_rate(&mut self, name: &str, success_rate: f64) {
156        if let Some(template) = self.templates.iter_mut().find(|t| t.name == name) {
157            template.success_rate = success_rate.clamp(0.0, 1.0);
158        }
159    }
160
161    /// Clear all templates.
162    pub fn clear(&mut self) {
163        self.templates.clear();
164    }
165
166    // Helper: Extract candidate templates from a proof
167    fn extract_candidate_templates(&self, proof: &Proof) -> Vec<Vec<TemplateStep>> {
168        let mut candidates = Vec::new();
169        let nodes: Vec<ProofNodeId> = proof.nodes().iter().map(|n| n.id).collect();
170
171        // Try to extract subproofs of various sizes
172        for window_size in self.min_template_size..=self.max_template_size.min(nodes.len()) {
173            for window in nodes.windows(window_size) {
174                if let Some(template_steps) = self.extract_template_steps(proof, window) {
175                    candidates.push(template_steps);
176                }
177            }
178        }
179
180        candidates
181    }
182
183    // Helper: Extract template steps from a sequence of nodes
184    fn extract_template_steps(
185        &self,
186        proof: &Proof,
187        nodes: &[ProofNodeId],
188    ) -> Option<Vec<TemplateStep>> {
189        let mut steps = Vec::new();
190        let mut node_to_id = FxHashMap::default();
191
192        for (i, &node_id) in nodes.iter().enumerate() {
193            node_to_id.insert(node_id, i);
194
195            if let Some(node) = proof.get_node(node_id)
196                && let ProofStep::Inference { rule, premises, .. } = &node.step
197            {
198                // Map premises to template IDs
199                let premise_ids: Vec<usize> = premises
200                    .iter()
201                    .filter_map(|&p| node_to_id.get(&p).copied())
202                    .collect();
203
204                steps.push(TemplateStep {
205                    id: i,
206                    rule: rule.clone(),
207                    premise_ids,
208                    conclusion_pattern: self.abstract_conclusion(node.conclusion()),
209                });
210            }
211        }
212
213        if steps.len() >= self.min_template_size {
214            Some(steps)
215        } else {
216            None
217        }
218    }
219
220    // Helper: Abstract a conclusion by replacing specific values
221    fn abstract_conclusion(&self, conclusion: &str) -> String {
222        // Simple abstraction similar to pattern extraction
223        let mut abstracted = conclusion.to_string();
224
225        // Replace numbers (need to escape $ as $$)
226        let re_num = regex::Regex::new(r"\b\d+\b").expect("regex pattern is valid");
227        abstracted = re_num.replace_all(&abstracted, "$$N").to_string();
228
229        // Replace quoted strings
230        let re_str = regex::Regex::new(r#""[^"]*""#).expect("regex pattern is valid");
231        abstracted = re_str.replace_all(&abstracted, "$$S").to_string();
232
233        // Replace specific identifiers (lowercase starting)
234        let re_id = regex::Regex::new(r"\b[a-z][a-z0-9_]*\b").expect("regex pattern is valid");
235        abstracted = re_id.replace_all(&abstracted, "$$V").to_string();
236
237        abstracted
238    }
239
240    // Helper: Group similar templates
241    fn group_similar_templates<'a>(
242        &self,
243        candidates: &'a [Vec<TemplateStep>],
244    ) -> FxHashMap<String, Vec<&'a Vec<TemplateStep>>> {
245        let mut groups: FxHashMap<String, Vec<&Vec<TemplateStep>>> = FxHashMap::default();
246
247        for candidate in candidates {
248            let signature = self.compute_template_signature(candidate);
249            groups.entry(signature).or_default().push(candidate);
250        }
251
252        groups
253    }
254
255    // Helper: Compute a signature for a template
256    fn compute_template_signature(&self, steps: &[TemplateStep]) -> String {
257        steps
258            .iter()
259            .map(|s| format!("{}:{}", s.rule, s.conclusion_pattern))
260            .collect::<Vec<_>>()
261            .join("|")
262    }
263
264    // Helper: Create a template from a pattern
265    fn create_template(&self, pattern: &str, occurrences: usize) -> ProofTemplate {
266        // Parse the pattern to extract steps
267        let parts: Vec<&str> = pattern.split('|').collect();
268        let mut steps = Vec::new();
269        let mut parameters = FxHashSet::default();
270
271        for (i, part) in parts.iter().enumerate() {
272            if let Some((rule, conclusion_pattern)) = part.split_once(':') {
273                // Extract parameters from conclusion pattern
274                for capture in conclusion_pattern.split('$').skip(1) {
275                    if let Some(var) = capture.chars().next() {
276                        parameters.insert(format!("${}", var));
277                    }
278                }
279
280                steps.push(TemplateStep {
281                    id: i,
282                    rule: rule.to_string(),
283                    premise_ids: Vec::new(), // Simplified
284                    conclusion_pattern: conclusion_pattern.to_string(),
285                });
286            }
287        }
288
289        let mut params: Vec<String> = parameters.into_iter().collect();
290        params.sort();
291
292        ProofTemplate {
293            name: format!("template_{}", self.templates.len()),
294            steps,
295            parameters: params,
296            occurrences,
297            success_rate: 0.0, // Will be updated during use
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_template_identifier_new() {
        let identifier = TemplateIdentifier::new();
        assert_eq!(identifier.min_template_size, 3);
        assert_eq!(identifier.max_template_size, 10);
        assert_eq!(identifier.min_occurrences, 2);
        assert!(identifier.templates.is_empty());
    }

    #[test]
    fn test_default_matches_new() {
        let defaulted = TemplateIdentifier::default();
        let fresh = TemplateIdentifier::new();
        assert_eq!(defaulted.min_template_size, fresh.min_template_size);
        assert_eq!(defaulted.max_template_size, fresh.max_template_size);
        assert_eq!(defaulted.min_occurrences, fresh.min_occurrences);
        assert!(defaulted.templates.is_empty());
    }

    #[test]
    fn test_template_identifier_with_settings() {
        let identifier = TemplateIdentifier::new()
            .with_min_size(5)
            .with_max_size(15)
            .with_min_occurrences(3);
        assert_eq!(identifier.min_template_size, 5);
        assert_eq!(identifier.max_template_size, 15);
        assert_eq!(identifier.min_occurrences, 3);
    }

    #[test]
    fn test_template_step() {
        let step = TemplateStep {
            id: 0,
            rule: "resolution".to_string(),
            premise_ids: vec![1, 2],
            conclusion_pattern: "x = y".to_string(),
        };
        assert_eq!(step.id, 0);
        assert_eq!(step.rule, "resolution");
        assert_eq!(step.premise_ids.len(), 2);
    }

    #[test]
    fn test_proof_template_display() {
        let template = ProofTemplate {
            name: "test_template".to_string(),
            steps: vec![TemplateStep {
                id: 0,
                rule: "resolution".to_string(),
                premise_ids: vec![],
                conclusion_pattern: "$V = $V".to_string(),
            }],
            parameters: vec!["$V".to_string()],
            occurrences: 5,
            success_rate: 0.8,
        };
        let display = format!("{}", template);
        assert!(display.contains("test_template"));
        assert!(display.contains("80.0%"));
        assert!(display.contains("  [0] resolution from [] => $V = $V"));
    }

    #[test]
    fn test_abstract_conclusion() {
        let identifier = TemplateIdentifier::new();
        // Numbers become `$N` and lowercase identifiers become `$V`. Operators
        // and spacing are preserved.
        assert_eq!(identifier.abstract_conclusion("x + 42 = y"), "$V + $N = $V");
    }

    #[test]
    fn test_abstract_conclusion_quoted_string() {
        let identifier = TemplateIdentifier::new();
        assert_eq!(
            identifier.abstract_conclusion(r#"f("hello world") = 7"#),
            "$V($S) = $N"
        );
    }

    #[test]
    fn test_abstract_conclusion_identifier_rules() {
        let identifier = TemplateIdentifier::new();
        // Digits inside an identifier are not numbers; uppercase-initial
        // names are left untouched.
        assert_eq!(identifier.abstract_conclusion("foo_1 = Bar"), "$V = Bar");
    }

    #[test]
    fn test_update_success_rate() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(ProofTemplate {
            name: "test".to_string(),
            steps: vec![],
            parameters: vec![],
            occurrences: 1,
            success_rate: 0.0,
        });
        identifier.update_success_rate("test", 0.75);
        assert_eq!(identifier.templates[0].success_rate, 0.75);
    }

    #[test]
    fn test_update_success_rate_clamp() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(ProofTemplate {
            name: "test".to_string(),
            steps: vec![],
            parameters: vec![],
            occurrences: 1,
            success_rate: 0.0,
        });
        identifier.update_success_rate("test", 1.5);
        assert_eq!(identifier.templates[0].success_rate, 1.0);
        identifier.update_success_rate("test", -0.5);
        assert_eq!(identifier.templates[0].success_rate, 0.0);
    }

    #[test]
    fn test_update_success_rate_unknown_name_is_noop() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(ProofTemplate {
            name: "test".to_string(),
            steps: vec![],
            parameters: vec![],
            occurrences: 1,
            success_rate: 0.25,
        });
        identifier.update_success_rate("missing", 0.9);
        assert_eq!(identifier.templates[0].success_rate, 0.25);
    }

    #[test]
    fn test_find_template() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(ProofTemplate {
            name: "test".to_string(),
            steps: vec![],
            parameters: vec![],
            occurrences: 1,
            success_rate: 0.0,
        });
        assert!(identifier.find_template("test").is_some());
        assert!(identifier.find_template("nonexistent").is_none());
    }

    #[test]
    fn test_clear_templates() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(ProofTemplate {
            name: "test".to_string(),
            steps: vec![],
            parameters: vec![],
            occurrences: 1,
            success_rate: 0.0,
        });
        identifier.clear();
        assert!(identifier.templates.is_empty());
        assert!(identifier.get_templates().is_empty());
    }

    #[test]
    fn test_get_templates_by_success_rate() {
        let mut identifier = TemplateIdentifier::new();
        identifier.templates.push(ProofTemplate {
            name: "low".to_string(),
            steps: vec![],
            parameters: vec![],
            occurrences: 1,
            success_rate: 0.3,
        });
        identifier.templates.push(ProofTemplate {
            name: "high".to_string(),
            steps: vec![],
            parameters: vec![],
            occurrences: 1,
            success_rate: 0.9,
        });
        let sorted = identifier.get_templates_by_success_rate();
        assert_eq!(sorted[0].name, "high");
        assert_eq!(sorted[1].name, "low");
    }
}