1use crate::*;
2use std::any::Any;
3
4mod ematch;
5pub use ematch::*;
6
7mod pattern;
8pub use pattern::*;
9
10mod subst_method;
11pub use subst_method::*;
12
13pub struct Rewrite<L: Language, N: Analysis<L> = ()> {
15 pub(crate) searcher: Box<dyn Fn(&EGraph<L, N>) -> Box<dyn Any>>,
16 pub(crate) applier: Box<dyn Fn(Box<dyn Any>, &mut EGraph<L, N>)>,
17}
18
19pub struct RewriteT<L: Language, N: Analysis<L> = (), T: Any = ()> {
25 pub searcher: Box<dyn Fn(&EGraph<L, N>) -> T>,
26 pub applier: Box<dyn Fn(T, &mut EGraph<L, N>)>,
27}
28
29impl<L: Language + 'static, N: Analysis<L> + 'static, T: 'static> RewriteT<L, N, T> {
30 pub fn into(self) -> Rewrite<L, N> {
32 let searcher = self.searcher;
33 let applier = self.applier;
34 Rewrite {
35 searcher: Box::new(move |eg| Box::new((*searcher)(eg))),
36 applier: Box::new(move |t, eg| (*applier)(any_to_t(t), eg)),
37 }
38 }
39}
40
/// Recovers a concrete `T` from a type-erased box.
///
/// `RewriteT::into` uses this to hand a searcher's boxed result back to its
/// typed applier.
///
/// # Panics
///
/// Panics if `t` does not hold a value of type `T`. The panic message names
/// the expected type so that mismatches are diagnosable.
pub fn any_to_t<T: Any>(t: Box<dyn Any>) -> T {
    match t.downcast::<T>() {
        Ok(value) => *value,
        Err(_) => panic!(
            "any_to_t: type-erased value is not a `{}`",
            std::any::type_name::<T>()
        ),
    }
}
44
45pub fn apply_rewrites<L: Language, N: Analysis<L>>(
48 eg: &mut EGraph<L, N>,
49 rewrites: &[Rewrite<L, N>],
50) -> bool {
51 let prog = eg.progress();
52
53 let ts: Vec<Box<dyn Any>> = rewrites.iter().map(|rw| (*rw.searcher)(eg)).collect();
54 for (rw, t) in rewrites.iter().zip(ts.into_iter()) {
55 (*rw.applier)(t, eg);
56 }
57
58 prog != eg.progress()
59}
60
61impl<L: Language + 'static, N: Analysis<L> + 'static> Rewrite<L, N> {
62 pub fn new(rule: &str, a: &str, b: &str) -> Self {
64 Self::new_if(rule, a, b, |_, _| true)
65 }
66
67 pub fn new_if(
69 rule: &str,
70 a: &str,
71 b: &str,
72 cond: impl Fn(&Subst, &EGraph<L, N>) -> bool + 'static,
73 ) -> Self {
74 let a = Pattern::parse(a).unwrap();
75 let b = Pattern::parse(b).unwrap();
76 let rule = rule.to_string();
77 let a2 = a.clone();
78 RewriteT {
79 searcher: Box::new(move |eg| ematch_all(eg, &a)),
80 applier: Box::new(move |substs, eg| {
81 Self::apply_substs_cond(substs, &cond, &a2, &b, &rule, eg)
82 }),
83 }
84 .into()
85 }
86
87 fn apply_substs_cond(
88 substs: Vec<Subst>,
89 cond: &(impl Fn(&Subst, &EGraph<L, N>) -> bool + 'static),
90 a: &Pattern<L>,
91 b: &Pattern<L>,
92 rule: &str,
93 eg: &mut EGraph<L, N>,
94 ) {
95 for subst in substs {
96 if cond(&subst, eg) {
97 eg.union_instantiations(a, b, &subst, Some(rule.to_string()));
98 }
99 }
100 }
101}
102
/// Snapshot of e-graph size counters, produced by `EGraph::progress`.
///
/// Two snapshots compare equal iff all four counters match. `apply_rewrites`
/// treats a difference between the before/after snapshots as progress.
/// `Debug`/`Clone`/`Copy` are derived so snapshots can be logged, kept, and
/// asserted on cheaply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProgressMeasure {
    /// Length of the e-graph's `classes` collection. This may include
    /// non-canonical entries (TODO confirm against `EGraph`).
    pub number_of_classes: usize,

    /// Number of ids returned by `EGraph::ids()`.
    pub number_of_live_classes: usize,

    /// Sum of `slots(id).len()` over the ids returned by `EGraph::ids()`.
    pub sum_of_slots: usize,

    /// Sum of each class's `group.count()` over the ids returned by `EGraph::ids()`.
    pub sum_of_symmetries: usize,
}
118
119impl<L: Language, N: Analysis<L>> EGraph<L, N> {
120 pub fn progress(&self) -> ProgressMeasure {
122 let ids = self.ids();
123 ProgressMeasure {
124 number_of_classes: self.classes.len(),
125 number_of_live_classes: ids.len(),
126 sum_of_symmetries: ids.iter().map(|x| self.classes[x].group.count()).sum(),
127 sum_of_slots: ids.iter().map(|x| self.slots(*x).len()).sum(),
128 }
129 }
130}
131
/// Shorthand for building a `Rewrite`.
///
/// - `rw!(name; lhs => rhs)` expands to `Rewrite::new(name, lhs, rhs)`.
/// - `rw!(name; lhs => rhs, if !cond)` expands to
///   `Rewrite::new_if(name, lhs, rhs, not(cond))`.
/// - `rw!(name; lhs => rhs, if cond)` expands to
///   `Rewrite::new_if(name, lhs, rhs, cond)`.
///
/// `Rewrite` and `not` are resolved at the call site, so they must be in scope there.
#[macro_export]
macro_rules! rw {
    ($rule:expr; $from:expr => $to:expr) => {
        Rewrite::new($rule, $from, $to)
    };

    // This arm must stay ahead of the general `if` arm. Otherwise `$guard:expr`
    // would swallow `!cond` as a boolean negation, and `!` cannot negate a
    // closure-valued condition.
    ($rule:expr; $from:expr => $to:expr, if !$guard:expr) => {
        Rewrite::new_if($rule, $from, $to, not($guard))
    };

    ($rule:expr; $from:expr => $to:expr, if $guard:expr) => {
        Rewrite::new_if($rule, $from, $to, $guard)
    };
}
146
147pub trait Cond<L, N>: Fn(&Subst, &EGraph<L, N>) -> bool + 'static {}
148impl<T, L: Language, N: Analysis<L>> Cond<L, N> for T where
149 T: Fn(&Subst, &EGraph<L, N>) -> bool + 'static
150{
151}
152
153pub fn slot_free_in<L: Language, N: Analysis<L>>(slot: &str, var: &str) -> impl Cond<L, N> {
154 let s: Slot = Slot::named(slot);
155 let var = var.to_string();
156 move |subst, _| !subst[&*var].slots().contains(&s)
157}
158
159pub fn or<L: Language, N: Analysis<L>>(x: impl Cond<L, N>, y: impl Cond<L, N>) -> impl Cond<L, N> {
160 move |subst, eg| x(subst, eg) || y(subst, eg)
161}
162
163pub fn and<L: Language, N: Analysis<L>>(x: impl Cond<L, N>, y: impl Cond<L, N>) -> impl Cond<L, N> {
164 move |subst, eg| x(subst, eg) && y(subst, eg)
165}
166
167pub fn not<L: Language, N: Analysis<L>>(x: impl Cond<L, N>) -> impl Cond<L, N> {
168 move |subst, eg| !x(subst, eg)
169}