Skip to main content

slotted_egraphs/extract/
mod.rs

1use crate::*;
2
3mod cost;
4pub use cost::*;
5
6mod with_ord;
7pub use with_ord::*;
8
9use std::collections::BinaryHeap;
10
11/// An object used for quickly extracting terms (i.e. [RecExpr]s) using a given [CostFunction].
12///
13/// Creating an Extractor will setup an extraction-table which then allows you to extract terms from many e-classes efficiently.
14/// It is most useful when doing "bulk" extractions for many classes.
15pub struct Extractor<L: Language, CF: CostFunction<L>> {
16    pub(crate) map: HashMap<Id, WithOrdRev<L, CF::Cost>>,
17}
18
19impl<L: Language, CF: CostFunction<L>> Extractor<L, CF> {
20    pub fn new<N: Analysis<L>>(eg: &EGraph<L, N>, cost_fn: CF) -> Self {
21        if CHECKS {
22            eg.check();
23        }
24
25        // all the L in `map` and `queue` have to be
26        // - in "normal-form", i.e. calling lookup on them yields an identity AppliedId.
27        // - every internal slot needs to be refreshed.
28
29        // maps eclass id to their optimal RecExpr.
30        let mut map: HashMap<Id, WithOrdRev<L, CF::Cost>> = HashMap::default();
31        let mut queue: BinaryHeap<WithOrdRev<L, CF::Cost>> = BinaryHeap::new();
32
33        for id in eg.ids() {
34            for x in eg.enodes(id) {
35                if x.applied_id_occurrences().is_empty() {
36                    let x = eg.class_nf(&x);
37                    let c = cost_fn.cost(&x, |_| panic!());
38                    queue.push(WithOrdRev(x, c));
39                }
40            }
41        }
42
43        while let Some(WithOrdRev(enode, c)) = queue.pop() {
44            let i = eg.lookup(&enode).unwrap();
45            if map.contains_key(&i.id) {
46                continue;
47            }
48            map.insert(i.id, WithOrdRev(enode, c));
49
50            for x in eg.usages(i.id).clone() {
51                if x.applied_id_occurrences()
52                    .iter()
53                    .all(|i| map.contains_key(&i.id))
54                {
55                    if eg
56                        .lookup(&x)
57                        .map(|i| map.contains_key(&i.id))
58                        .unwrap_or(false)
59                    {
60                        continue;
61                    }
62                    let x = eg.class_nf(&x);
63                    let c = cost_fn.cost(&x, |i| map[&i].1.clone());
64                    queue.push(WithOrdRev(x, c));
65                }
66            }
67        }
68
69        Self { map }
70    }
71
72    pub fn extract<N: Analysis<L>>(&self, i: &AppliedId, eg: &EGraph<L, N>) -> RecExpr<L> {
73        let i = eg.find_applied_id(i);
74
75        let mut children = Vec::new();
76
77        // do I need to refresh some slots here?
78        let l = self.map[&i.id].0.apply_slotmap(&i.m);
79        for child in l.applied_id_occurrences() {
80            let n = self.extract(&child, eg);
81            children.push(n);
82        }
83
84        RecExpr { node: l, children }
85    }
86
87    pub fn get_best_cost<N: Analysis<L>>(&self, i: &AppliedId) -> CF::Cost {
88        self.map[&i.id].1.clone()
89    }
90}
91
92pub fn ast_size_extract<L: Language, N: Analysis<L>>(
93    i: &AppliedId,
94    eg: &EGraph<L, N>,
95) -> RecExpr<L> {
96    extract::<L, N, AstSize>(i, eg)
97}
98
99// `i` is not allowed to have free variables, hence prefer `Id` over `AppliedId`.
100pub fn extract<L: Language, N: Analysis<L>, CF: CostFunction<L> + Default>(
101    i: &AppliedId,
102    eg: &EGraph<L, N>,
103) -> RecExpr<L> {
104    let cost_fn = CF::default();
105    let extractor = Extractor::<L, CF>::new(eg, cost_fn);
106    let out = extractor.extract(&i, eg);
107    if CHECKS {
108        let i = eg.find_id(i.id);
109        let cost_fn = CF::default();
110        assert_eq!(cost_fn.cost_rec(&out), extractor.map[&i].1);
111    }
112    out
113}