chomsky_rule_engine/
lib.rs1#![warn(missing_docs)]
2
3use chomsky_uir::Analysis;
4use chomsky_uir::egraph::EGraph;
5use chomsky_uir::intent::IKun;
6
7pub trait RewriteRule<A: Analysis<IKun>>: Send + Sync {
8 fn name(&self) -> &str;
9 fn apply(&self, egraph: &EGraph<IKun, A>);
10}
11
/// Label attached to each rule stored in a [`RewriteRegistry`].
///
/// A category is only a tag; `RewriteRegistry::get_rules` filters on it. NOTE(review):
/// the intended meaning of each variant is not spelled out in this crate — the variant
/// docs below are inferred from the names and should be confirmed.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RuleCategory {
    /// Algebraic rewrites (presumably identities and simplifications — confirm).
    Algebraic,
    /// Architecture-related rewrites (presumably target-specific — confirm).
    Architectural,
    /// Rewrites tagged as aggressive (presumably costlier or riskier — confirm).
    Aggressive,
    /// Rewrites that concretize abstract constructs (inferred from the name — confirm).
    Concretization,
}
19
20pub struct RewriteRegistry<A: Analysis<IKun>> {
21 rules: Vec<(RuleCategory, Box<dyn RewriteRule<A>>)>,
22}
23
24impl<A: Analysis<IKun>> RewriteRegistry<A> {
25 pub fn new() -> Self {
26 Self { rules: Vec::new() }
27 }
28
29 pub fn register(&mut self, category: RuleCategory, rule: Box<dyn RewriteRule<A>>) {
30 self.rules.push((category, rule));
31 }
32
33 pub fn get_rules(
34 &self,
35 category: RuleCategory,
36 ) -> impl Iterator<Item = &Box<dyn RewriteRule<A>>> {
37 self.rules
38 .iter()
39 .filter(move |(cat, _)| *cat == category)
40 .map(|(_, rule)| rule)
41 }
42
43 pub fn all_rules(&self) -> impl Iterator<Item = &Box<dyn RewriteRule<A>>> {
44 self.rules.iter().map(|(_, rule)| rule)
45 }
46
47 pub fn into_rules(self) -> Vec<(RuleCategory, Box<dyn RewriteRule<A>>)> {
48 self.rules
49 }
50}
51
/// Budget for [`SaturationScheduler::run`]: an iteration limit plus a wall-clock timeout.
// Public configuration type: derive the common traits so callers can debug-print,
// copy and compare scheduler settings (all fields are `usize`/`Duration`, which
// support every derived trait).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SaturationScheduler {
    /// Maximum number of rule-application iterations `run` performs.
    pub fuel: usize,
    /// Wall-clock budget for `run`; it is checked before each iteration starts.
    pub timeout: std::time::Duration,
}
56
57impl Default for SaturationScheduler {
58 fn default() -> Self {
59 Self {
60 fuel: 10,
61 timeout: std::time::Duration::from_secs(5),
62 }
63 }
64}
65
66impl SaturationScheduler {
67 pub fn run<A: Analysis<IKun>>(&self, egraph: &EGraph<IKun, A>, registry: &RewriteRegistry<A>) {
68 let start_time = std::time::Instant::now();
69
70 for i in 0..self.fuel {
71 if start_time.elapsed() > self.timeout {
72 println!("Saturation timeout reached at iteration {}", i);
73 break;
74 }
75
76 let prev_nodes_count = egraph.memo.len();
77 let prev_classes_count = egraph.classes.len();
78
79 for rule in registry.all_rules() {
81 rule.apply(egraph);
82 }
83
84 egraph.rebuild();
85
86 let current_nodes_count = egraph.memo.len();
87 let current_classes_count = egraph.classes.len();
88
89 if current_nodes_count == prev_nodes_count
90 && current_classes_count == prev_classes_count
91 {
92 println!("Saturation reached at iteration {}", i);
93 break;
94 }
95 }
96 }
97}
98
99pub fn ikun_registry<A: Analysis<IKun> + 'static>() -> RewriteRegistry<A> {
100 RewriteRegistry::new()
101}
102
103pub fn default_rules<A: Analysis<IKun> + 'static>() -> Vec<Box<dyn RewriteRule<A>>> {
104 vec![]
105}