Skip to main content

slotted_egraphs/rewrite/
mod.rs

1use crate::*;
2use std::any::Any;
3
4mod ematch;
5pub use ematch::*;
6
7mod pattern;
8pub use pattern::*;
9
10mod subst_method;
11pub use subst_method::*;
12
13/// An equational rewrite rule.
14pub struct Rewrite<L: Language, N: Analysis<L> = ()> {
15    pub(crate) searcher: Box<dyn Fn(&EGraph<L, N>) -> Box<dyn Any>>,
16    pub(crate) applier: Box<dyn Fn(Box<dyn Any>, &mut EGraph<L, N>)>,
17}
18
19/// Use this type when you want to build your own [Rewrite].
20///
21/// The type parameter `T` can be anything you want, as long as the `searcher` creates it, and the `applier` consumes it.
22///
23/// In most cases, `T` is a [Subst].
24pub struct RewriteT<L: Language, N: Analysis<L> = (), T: Any = ()> {
25    pub searcher: Box<dyn Fn(&EGraph<L, N>) -> T>,
26    pub applier: Box<dyn Fn(T, &mut EGraph<L, N>)>,
27}
28
29impl<L: Language + 'static, N: Analysis<L> + 'static, T: 'static> RewriteT<L, N, T> {
30    /// Use this function to convert it to an actual [Rewrite].
31    pub fn into(self) -> Rewrite<L, N> {
32        let searcher = self.searcher;
33        let applier = self.applier;
34        Rewrite {
35            searcher: Box::new(move |eg| Box::new((*searcher)(eg))),
36            applier: Box::new(move |t, eg| (*applier)(any_to_t(t), eg)),
37        }
38    }
39}
40
/// Recovers a concrete `T` from a type-erased payload.
///
/// This is the inverse of boxing a value as `Box<dyn Any>`: the box is downcast to `T`
/// and the value is moved out of it.
///
/// # Panics
///
/// Panics if the value stored in `t` is not of type `T`. The panic message names the
/// expected type so that a mismatched searcher/applier pair is easy to diagnose.
pub fn any_to_t<T: Any>(t: Box<dyn Any>) -> T {
    match t.downcast::<T>() {
        Ok(value) => *value,
        Err(_) => panic!(
            "any_to_t: payload is not of the expected type `{}`",
            std::any::type_name::<T>()
        ),
    }
}
44
45/// Applies each given rewrite rule to the E-Graph once.
46/// Returns an indicator for whether the e-graph changed as a result.
47pub fn apply_rewrites<L: Language, N: Analysis<L>>(
48    eg: &mut EGraph<L, N>,
49    rewrites: &[Rewrite<L, N>],
50) -> bool {
51    let prog = eg.progress();
52
53    let ts: Vec<Box<dyn Any>> = rewrites.iter().map(|rw| (*rw.searcher)(eg)).collect();
54    for (rw, t) in rewrites.iter().zip(ts.into_iter()) {
55        (*rw.applier)(t, eg);
56    }
57
58    prog != eg.progress()
59}
60
61impl<L: Language + 'static, N: Analysis<L> + 'static> Rewrite<L, N> {
62    /// Create a rewrite rule by specifing a left- and right-hand side of your equation.
63    pub fn new(rule: &str, a: &str, b: &str) -> Self {
64        Self::new_if(rule, a, b, |_, _| true)
65    }
66
67    /// Create a conditional rewrite rule.
68    pub fn new_if(
69        rule: &str,
70        a: &str,
71        b: &str,
72        cond: impl Fn(&Subst, &EGraph<L, N>) -> bool + 'static,
73    ) -> Self {
74        let a = Pattern::parse(a).unwrap();
75        let b = Pattern::parse(b).unwrap();
76        let rule = rule.to_string();
77        let a2 = a.clone();
78        RewriteT {
79            searcher: Box::new(move |eg| ematch_all(eg, &a)),
80            applier: Box::new(move |substs, eg| {
81                Self::apply_substs_cond(substs, &cond, &a2, &b, &rule, eg)
82            }),
83        }
84        .into()
85    }
86
87    fn apply_substs_cond(
88        substs: Vec<Subst>,
89        cond: &(impl Fn(&Subst, &EGraph<L, N>) -> bool + 'static),
90        a: &Pattern<L>,
91        b: &Pattern<L>,
92        rule: &str,
93        eg: &mut EGraph<L, N>,
94    ) {
95        for subst in substs {
96            if cond(&subst, eg) {
97                eg.union_instantiations(a, b, &subst, Some(rule.to_string()));
98            }
99        }
100    }
101}
102
/// A Progress Measure to check saturation of an e-graph with.
///
/// Two measures compare equal exactly when all four counters are equal. Comparing the
/// measure taken before and after a rewriting round tells whether the round changed
/// the e-graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProgressMeasure {
    /// How many classes were allocated in this e-graph. This measure is strictly growing.
    pub number_of_classes: usize,

    /// How many classes are still "live". If `number_of_classes` is unchanged, this can
    /// only decrease (by union).
    pub number_of_live_classes: usize,

    /// How many parameter-slots are still in the e-classes. If `number_of_classes` and
    /// `number_of_live_classes` are unchanged, this can only decrease (by proving a
    /// redundancy by union).
    pub sum_of_slots: usize,

    /// How many symmetries the e-graph knows. If `number_of_classes`,
    /// `number_of_live_classes` and `sum_of_slots` are unchanged, this can only increase
    /// (by proving a symmetry by union).
    pub sum_of_symmetries: usize,
}
118
119impl<L: Language, N: Analysis<L>> EGraph<L, N> {
120    /// Computes the [ProgressMeasure] of this E-Graph.
121    pub fn progress(&self) -> ProgressMeasure {
122        let ids = self.ids();
123        ProgressMeasure {
124            number_of_classes: self.classes.len(),
125            number_of_live_classes: ids.len(),
126            sum_of_symmetries: ids.iter().map(|x| self.classes[x].group.count()).sum(),
127            sum_of_slots: ids.iter().map(|x| self.slots(*x).len()).sum(),
128        }
129    }
130}
131
/// Shorthand for constructing a `Rewrite`.
///
/// Supported forms:
///
/// - `rw!(name; lhs => rhs)` expands to `Rewrite::new(name, lhs, rhs)`.
/// - `rw!(name; lhs => rhs, if cond)` expands to `Rewrite::new_if(name, lhs, rhs, cond)`.
/// - `rw!(name; lhs => rhs, if !cond)` expands to `Rewrite::new_if(name, lhs, rhs, not(cond))`.
///
/// The expansion refers to `Rewrite` (and `not` for the negated form) by their
/// unqualified names, so they must be in scope at the call site.
#[macro_export]
macro_rules! rw {
    // Unconditional rule.
    ($rule:expr; $from:expr => $to:expr) => {
        Rewrite::new($rule, $from, $to)
    };

    // Negated condition. This arm has to come before the general conditional arm,
    // which would otherwise capture `!cond` as a single expression.
    ($rule:expr; $from:expr => $to:expr, if !$guard:expr) => {
        Rewrite::new_if($rule, $from, $to, not($guard))
    };

    // Plain condition.
    ($rule:expr; $from:expr => $to:expr, if $guard:expr) => {
        Rewrite::new_if($rule, $from, $to, $guard)
    };
}
146
147pub trait Cond<L, N>: Fn(&Subst, &EGraph<L, N>) -> bool + 'static {}
148impl<T, L: Language, N: Analysis<L>> Cond<L, N> for T where
149    T: Fn(&Subst, &EGraph<L, N>) -> bool + 'static
150{
151}
152
153pub fn slot_free_in<L: Language, N: Analysis<L>>(slot: &str, var: &str) -> impl Cond<L, N> {
154    let s: Slot = Slot::named(slot);
155    let var = var.to_string();
156    move |subst, _| !subst[&*var].slots().contains(&s)
157}
158
159pub fn or<L: Language, N: Analysis<L>>(x: impl Cond<L, N>, y: impl Cond<L, N>) -> impl Cond<L, N> {
160    move |subst, eg| x(subst, eg) || y(subst, eg)
161}
162
163pub fn and<L: Language, N: Analysis<L>>(x: impl Cond<L, N>, y: impl Cond<L, N>) -> impl Cond<L, N> {
164    move |subst, eg| x(subst, eg) && y(subst, eg)
165}
166
167pub fn not<L: Language, N: Analysis<L>>(x: impl Cond<L, N>) -> impl Cond<L, N> {
168    move |subst, eg| !x(subst, eg)
169}