egg/extract.rs
1use std::cmp::Ordering;
2use std::fmt::Debug;
3
4use crate::util::{hashmap_with_capacity, HashMap};
5use crate::{Analysis, EClass, EGraph, Id, Language, RecExpr};
6
7/** Extracting a single [`RecExpr`] from an [`EGraph`].
8
9```
10use egg::*;
11
12define_language! {
13 enum SimpleLanguage {
14 Num(i32),
15 "+" = Add([Id; 2]),
16 "*" = Mul([Id; 2]),
17 }
18}
19
20let rules: &[Rewrite<SimpleLanguage, ()>] = &[
21 rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
22 rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"),
23
24 rewrite!("add-0"; "(+ ?a 0)" => "?a"),
25 rewrite!("mul-0"; "(* ?a 0)" => "0"),
26 rewrite!("mul-1"; "(* ?a 1)" => "?a"),
27];
28
29let start = "(+ 0 (* 1 10))".parse().unwrap();
30let runner = Runner::default().with_expr(&start).run(rules);
31let (egraph, root) = (runner.egraph, runner.roots[0]);
32
33let mut extractor = Extractor::new(&egraph, AstSize);
34let (best_cost, best) = extractor.find_best(root);
35assert_eq!(best_cost, 1);
36assert_eq!(best, "10".parse().unwrap());
37```
38
39**/
40#[derive(Debug)]
41pub struct Extractor<'a, CF: CostFunction<L>, L: Language, N: Analysis<L>> {
42 cost_function: CF,
43 costs: HashMap<Id, (CF::Cost, L)>,
44 egraph: &'a EGraph<L, N>,
45}
46
47/** A cost function that can be used by an [`Extractor`].
48
To extract an expression from an [`EGraph`], the [`Extractor`]
requires a cost function to perform its greedy search.
51`egg` provides the simple [`AstSize`] and [`AstDepth`] cost functions.
52
The code below is a silly but realistic example of
implementing a cost function that is essentially AST size weighted by
the operator:
56```
57# use egg::*;
58struct SillyCostFn;
59impl CostFunction<SymbolLang> for SillyCostFn {
60 type Cost = f64;
61 fn cost<C>(&mut self, enode: &SymbolLang, mut costs: C) -> Self::Cost
62 where
63 C: FnMut(Id) -> Self::Cost
64 {
65 let op_cost = match enode.op.as_str() {
66 "foo" => 100.0,
67 "bar" => 0.7,
68 _ => 1.0
69 };
70 enode.fold(op_cost, |sum, id| sum + costs(id))
71 }
72}
73
74let e: RecExpr<SymbolLang> = "(do_it foo bar baz)".parse().unwrap();
75assert_eq!(SillyCostFn.cost_rec(&e), 102.7);
76assert_eq!(AstSize.cost_rec(&e), 4);
77assert_eq!(AstDepth.cost_rec(&e), 2);
78```
79
80If you'd like to access the [`Analysis`] data or anything else in the e-graph,
81you can put a reference to the e-graph in your [`CostFunction`]:
82
83```
84# use egg::*;
85# type MyAnalysis = ();
86struct EGraphCostFn<'a> {
87 egraph: &'a EGraph<SymbolLang, MyAnalysis>,
88}
89
90impl<'a> CostFunction<SymbolLang> for EGraphCostFn<'a> {
91 type Cost = usize;
92 fn cost<C>(&mut self, enode: &SymbolLang, mut costs: C) -> Self::Cost
93 where
94 C: FnMut(Id) -> Self::Cost
95 {
96 // use self.egraph however you want here
97 println!("the egraph has {} classes", self.egraph.number_of_classes());
98 return 1
99 }
100}
101
102let mut egraph = EGraph::<SymbolLang, MyAnalysis>::default();
103let id = egraph.add_expr(&"(foo bar)".parse().unwrap());
104let cost_func = EGraphCostFn { egraph: &egraph };
105let mut extractor = Extractor::new(&egraph, cost_func);
106let _ = extractor.find_best(id);
107```
108
109Note that a particular e-class might occur in an expression multiple times.
110This means that pathological, but nevertheless realistic cases
111might overflow `usize` if you implement a cost function like [`AstSize`],
112even if the actual [`RecExpr`] fits compactly in memory.
113You might want to use [`saturating_add`](u64::saturating_add) to
114ensure your cost function is still monotonic in this situation.
115**/
116pub trait CostFunction<L: Language> {
117 /// The `Cost` type. It only requires `PartialOrd` so you can use
118 /// floating point types, but failed comparisons (`NaN`s) will
119 /// result in a panic.
120 type Cost: PartialOrd + Debug + Clone;
121
122 /// Calculates the cost of an enode whose children are `Cost`s.
123 ///
124 /// For this to work properly, your cost function should be
125 /// _monotonic_, i.e. `cost` should return a `Cost` greater than
126 /// any of the child costs of the given enode.
127 fn cost<C>(&mut self, enode: &L, costs: C) -> Self::Cost
128 where
129 C: FnMut(Id) -> Self::Cost;
130
131 /// Calculates the total cost of a [`RecExpr`].
132 ///
133 /// As provided, this just recursively calls `cost` all the way
134 /// down the [`RecExpr`].
135 ///
136 fn cost_rec(&mut self, expr: &RecExpr<L>) -> Self::Cost {
137 let mut costs = hashmap_with_capacity::<Id, Self::Cost>(expr.len());
138 for (i, node) in expr.items() {
139 let cost = self.cost(node, |i| costs[&i].clone());
140 costs.insert(i, cost);
141 }
142 let root = expr.root();
143 costs[&root].clone()
144 }
145}
146
147/** A simple [`CostFunction`] that counts total AST size.
148
149```
150# use egg::*;
151let e: RecExpr<SymbolLang> = "(do_it foo bar baz)".parse().unwrap();
152assert_eq!(AstSize.cost_rec(&e), 4);
153```
154
155**/
156#[derive(Debug)]
157pub struct AstSize;
158impl<L: Language> CostFunction<L> for AstSize {
159 type Cost = usize;
160 fn cost<C>(&mut self, enode: &L, mut costs: C) -> Self::Cost
161 where
162 C: FnMut(Id) -> Self::Cost,
163 {
164 enode.fold(1, |sum, id| sum.saturating_add(costs(id)))
165 }
166}
167
168/** A simple [`CostFunction`] that counts maximum AST depth.
169
170```
171# use egg::*;
172let e: RecExpr<SymbolLang> = "(do_it foo bar baz)".parse().unwrap();
173assert_eq!(AstDepth.cost_rec(&e), 2);
174```
175
176**/
177#[derive(Debug)]
178pub struct AstDepth;
179impl<L: Language> CostFunction<L> for AstDepth {
180 type Cost = usize;
181 fn cost<C>(&mut self, enode: &L, mut costs: C) -> Self::Cost
182 where
183 C: FnMut(Id) -> Self::Cost,
184 {
185 1 + enode.fold(0, |max, id| max.max(costs(id)))
186 }
187}
188
/// Orders two optional costs, treating a missing cost (`None`) as larger
/// than any present cost, so that `None` never wins a minimum search
/// against a `Some`.
///
/// # Panics
///
/// Panics if both costs are present but cannot be compared, i.e.
/// `partial_cmp` returns `None` (for example when a floating point cost
/// is `NaN`).
fn cmp<T: PartialOrd>(a: &Option<T>, b: &Option<T>) -> Ordering {
    match (a, b) {
        (None, None) => Ordering::Equal,
        (None, Some(_)) => Ordering::Greater,
        (Some(_), None) => Ordering::Less,
        // An incomparable pair means the cost function produced a value
        // outside a total order; say so instead of a bare `unwrap` panic.
        (Some(a), Some(b)) => a
            .partial_cmp(b)
            .expect("cost comparison failed: costs must be totally ordered (is a cost NaN?)"),
    }
}
198
199impl<'a, CF, L, N> Extractor<'a, CF, L, N>
200where
201 CF: CostFunction<L>,
202 L: Language,
203 N: Analysis<L>,
204{
205 /// Create a new `Extractor` given an `EGraph` and a
206 /// `CostFunction`.
207 ///
208 /// The extraction does all the work on creation, so this function
209 /// performs the greedy search for cheapest representative of each
210 /// eclass.
211 pub fn new(egraph: &'a EGraph<L, N>, cost_function: CF) -> Self {
212 let costs = HashMap::default();
213 let mut extractor = Extractor {
214 costs,
215 egraph,
216 cost_function,
217 };
218 extractor.find_costs();
219
220 extractor
221 }
222
223 /// Find the cheapest (lowest cost) represented `RecExpr` in the
224 /// given eclass.
225 pub fn find_best(&self, eclass: Id) -> (CF::Cost, RecExpr<L>) {
226 let (cost, root) = self.costs[&self.egraph.find(eclass)].clone();
227 let expr = root.build_recexpr(|id| self.find_best_node(id).clone());
228 (cost, expr)
229 }
230
231 /// Find the cheapest e-node in the given e-class.
232 pub fn find_best_node(&self, eclass: Id) -> &L {
233 &self.costs[&self.egraph.find(eclass)].1
234 }
235
236 /// Find the cost of the term that would be extracted from this e-class.
237 pub fn find_best_cost(&self, eclass: Id) -> CF::Cost {
238 let (cost, _) = &self.costs[&self.egraph.find(eclass)];
239 cost.clone()
240 }
241
242 fn node_total_cost(&mut self, node: &L) -> Option<CF::Cost> {
243 let eg = &self.egraph;
244 let has_cost = |id| self.costs.contains_key(&eg.find(id));
245 if node.all(has_cost) {
246 let costs = &self.costs;
247 let cost_f = |id| costs[&eg.find(id)].0.clone();
248 Some(self.cost_function.cost(node, cost_f))
249 } else {
250 None
251 }
252 }
253
254 fn find_costs(&mut self) {
255 let mut did_something = true;
256 while did_something {
257 did_something = false;
258
259 for class in self.egraph.classes() {
260 let pass = self.make_pass(class);
261 match (self.costs.get(&class.id), pass) {
262 (None, Some(new)) => {
263 self.costs.insert(class.id, new);
264 did_something = true;
265 }
266 (Some(old), Some(new)) if new.0 < old.0 => {
267 self.costs.insert(class.id, new);
268 did_something = true;
269 }
270 _ => (),
271 }
272 }
273 }
274
275 for class in self.egraph.classes() {
276 if !self.costs.contains_key(&class.id) {
277 log::warn!(
278 "Failed to compute cost for eclass {}: {:?}",
279 class.id,
280 class.nodes
281 )
282 }
283 }
284 }
285
286 fn make_pass(&mut self, eclass: &EClass<L, N::Data>) -> Option<(CF::Cost, L)> {
287 let (cost, node) = eclass
288 .iter()
289 .map(|n| (self.node_total_cost(n), n))
290 .min_by(|a, b| cmp(&a.0, &b.0))
291 .unwrap_or_else(|| panic!("Can't extract, eclass is empty: {:#?}", eclass));
292 cost.map(|c| (c, node.clone()))
293 }
294}
295
#[cfg(test)]
mod tests {
    use crate::*;

    /// Runs a rule that keeps growing the term, then checks that
    /// extraction with `AstSize` still returns the original expression.
    #[test]
    fn ast_size_overflow() {
        let start = "(meow 42)".parse().unwrap();
        let explode: &[Rewrite<SymbolLang, ()>] =
            &[rewrite!("explode"; "(meow ?a)" => "(meow (meow ?a ?a))")];

        let runner = Runner::default()
            .with_iter_limit(100)
            .with_expr(&start)
            .run(explode);
        let root = runner.roots[0];

        let (_, best_expr) = Extractor::new(&runner.egraph, AstSize).find_best(root);
        assert_eq!(best_expr, start);
    }
}
315}