Skip to main content

chomsky_rule_engine/
lib.rs

1#![warn(missing_docs)]
2
3use chomsky_uir::Analysis;
4use chomsky_uir::egraph::EGraph;
5use chomsky_uir::intent::IKun;
6
7pub trait RewriteRule<A: Analysis<IKun>>: Send + Sync {
8    fn name(&self) -> &str;
9    fn apply(&self, egraph: &EGraph<IKun, A>);
10}
11
/// Tag attached to every rule stored in a `RewriteRegistry`, used to select
/// a subset of rules by category.
///
/// NOTE(review): this file only uses the variants as opaque tags; the exact
/// scope of each category is decided by whoever registers the rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RuleCategory {
    /// Rules registered as algebraic rewrites.
    Algebraic,
    /// Rules registered as architectural rewrites.
    Architectural,
    /// Rules registered as aggressive rewrites.
    Aggressive,
    /// Rules registered as concretization rewrites.
    Concretization,
}
19
20pub struct RewriteRegistry<A: Analysis<IKun>> {
21    rules: Vec<(RuleCategory, Box<dyn RewriteRule<A>>)>,
22}
23
24impl<A: Analysis<IKun>> RewriteRegistry<A> {
25    pub fn new() -> Self {
26        Self { rules: Vec::new() }
27    }
28
29    pub fn register(&mut self, category: RuleCategory, rule: Box<dyn RewriteRule<A>>) {
30        self.rules.push((category, rule));
31    }
32
33    pub fn get_rules(
34        &self,
35        category: RuleCategory,
36    ) -> impl Iterator<Item = &Box<dyn RewriteRule<A>>> {
37        self.rules
38            .iter()
39            .filter(move |(cat, _)| *cat == category)
40            .map(|(_, rule)| rule)
41    }
42
43    pub fn all_rules(&self) -> impl Iterator<Item = &Box<dyn RewriteRule<A>>> {
44        self.rules.iter().map(|(_, rule)| rule)
45    }
46
47    pub fn into_rules(self) -> Vec<(RuleCategory, Box<dyn RewriteRule<A>>)> {
48        self.rules
49    }
50}
51
/// Budget for a saturation run: an iteration cap plus a wall-clock limit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SaturationScheduler {
    /// Maximum number of iterations a run may perform.
    pub fuel: usize,
    /// Wall-clock limit for a run; it is compared against the elapsed time
    /// at the start of each iteration.
    pub timeout: std::time::Duration,
}

impl Default for SaturationScheduler {
    /// Returns a scheduler with 10 iterations of fuel and a 5 second timeout.
    fn default() -> Self {
        Self {
            fuel: 10,
            timeout: std::time::Duration::from_secs(5),
        }
    }
}
65
66impl SaturationScheduler {
67    pub fn run<A: Analysis<IKun>>(&self, egraph: &EGraph<IKun, A>, registry: &RewriteRegistry<A>) {
68        let start_time = std::time::Instant::now();
69
70        for i in 0..self.fuel {
71            if start_time.elapsed() > self.timeout {
72                println!("Saturation timeout reached at iteration {}", i);
73                break;
74            }
75
76            let prev_nodes_count = egraph.memo.len();
77            let prev_classes_count = egraph.classes.len();
78
79            // Apply all rules
80            for rule in registry.all_rules() {
81                rule.apply(egraph);
82            }
83
84            egraph.rebuild();
85
86            let current_nodes_count = egraph.memo.len();
87            let current_classes_count = egraph.classes.len();
88
89            if current_nodes_count == prev_nodes_count
90                && current_classes_count == prev_classes_count
91            {
92                println!("Saturation reached at iteration {}", i);
93                break;
94            }
95        }
96    }
97}
98
99pub fn ikun_registry<A: Analysis<IKun> + 'static>() -> RewriteRegistry<A> {
100    RewriteRegistry::new()
101}
102
103pub fn default_rules<A: Analysis<IKun> + 'static>() -> Vec<Box<dyn RewriteRule<A>>> {
104    vec![]
105}