egg/
lp_extract.rs

1use coin_cbc::{Col, Model, Sense};
2
3use crate::*;
4
5/// A cost function to be used by an [`LpExtractor`].
6#[cfg_attr(docsrs, doc(cfg(feature = "lp")))]
7pub trait LpCostFunction<L: Language, N: Analysis<L>> {
8    /// Returns the cost of the given e-node.
9    ///
10    /// This function may look at other parts of the e-graph to compute the cost
11    /// of the given e-node.
12    fn node_cost(&mut self, egraph: &EGraph<L, N>, eclass: Id, enode: &L) -> f64;
13}
14
15#[cfg_attr(docsrs, doc(cfg(feature = "lp")))]
16impl<L: Language, N: Analysis<L>> LpCostFunction<L, N> for AstSize {
17    fn node_cost(&mut self, _egraph: &EGraph<L, N>, _eclass: Id, _enode: &L) -> f64 {
18        1.0
19    }
20}
21
22/// A structure to perform extraction using integer linear programming.
23/// This uses the [`cbc`](https://projects.coin-or.org/Cbc) solver.
24/// You must have it installed on your machine to use this feature.
25/// You can install it using:
26///
27/// | OS               | Command                                  |
28/// |------------------|------------------------------------------|
29/// | Fedora / Red Hat | `sudo dnf install coin-or-Cbc-devel`     |
30/// | Ubuntu / Debian  | `sudo apt-get install coinor-libcbc-dev` |
31/// | macOS            | `brew install cbc`                       |
32///
33/// On macOS, you might also need the following in your `.zshrc` file:
34/// `export LIBRARY_PATH=$LIBRARY_PATH:$(brew --prefix)/lib`
35///
36/// # Example
37/// ```
38/// use egg::*;
39/// let mut egraph = EGraph::<SymbolLang, ()>::default();
40///
41/// let f = egraph.add_expr(&"(f x x x)".parse().unwrap());
42/// let g = egraph.add_expr(&"(g (g x))".parse().unwrap());
43/// egraph.union(f, g);
44/// egraph.rebuild();
45///
46/// let best = Extractor::new(&egraph, AstSize).find_best(f).1;
47/// let lp_best = LpExtractor::new(&egraph, AstSize).solve(f);
48///
49/// // In regular extraction, cost is measures on the tree.
50/// assert_eq!(best.to_string(), "(g (g x))");
51///
52/// // Using ILP only counts common sub-expressions once,
53/// // so it can lead to a smaller DAG expression.
54/// assert_eq!(lp_best.to_string(), "(f x x x)");
55/// assert_eq!(lp_best.len(), 2);
56/// ```
57#[cfg_attr(docsrs, doc(cfg(feature = "lp")))]
58pub struct LpExtractor<'a, L: Language, N: Analysis<L>> {
59    egraph: &'a EGraph<L, N>,
60    model: Model,
61    vars: HashMap<Id, ClassVars>,
62}
63
64struct ClassVars {
65    active: Col,
66    order: Col,
67    nodes: Vec<Col>,
68}
69
70impl<'a, L, N> LpExtractor<'a, L, N>
71where
72    L: Language,
73    N: Analysis<L>,
74{
75    /// Create an [`LpExtractor`] using costs from the given [`LpCostFunction`].
76    /// See those docs for details.
77    pub fn new<CF>(egraph: &'a EGraph<L, N>, mut cost_function: CF) -> Self
78    where
79        CF: LpCostFunction<L, N>,
80    {
81        let max_order = egraph.total_number_of_nodes() as f64 * 10.0;
82
83        let mut model = Model::default();
84
85        let vars: HashMap<Id, ClassVars> = egraph
86            .classes()
87            .map(|class| {
88                let cvars = ClassVars {
89                    active: model.add_binary(),
90                    order: model.add_col(),
91                    nodes: class.nodes.iter().map(|_| model.add_binary()).collect(),
92                };
93                model.set_col_upper(cvars.order, max_order);
94                (class.id, cvars)
95            })
96            .collect();
97
98        let mut cycles: HashSet<(Id, usize)> = Default::default();
99        find_cycles(egraph, |id, i| {
100            cycles.insert((id, i));
101        });
102
103        for (&id, class) in &vars {
104            // class active == some node active
105            // sum(for node_active in class) == class_active
106            let row = model.add_row();
107            model.set_row_equal(row, 0.0);
108            model.set_weight(row, class.active, -1.0);
109            for &node_active in &class.nodes {
110                model.set_weight(row, node_active, 1.0);
111            }
112
113            for (i, (node, &node_active)) in egraph[id].iter().zip(&class.nodes).enumerate() {
114                if cycles.contains(&(id, i)) {
115                    model.set_col_upper(node_active, 0.0);
116                    model.set_col_lower(node_active, 0.0);
117                    continue;
118                }
119
120                for child in node.children() {
121                    let child_active = vars[child].active;
122                    // node active implies child active, encoded as:
123                    //   node_active <= child_active
124                    //   node_active - child_active <= 0
125                    let row = model.add_row();
126                    model.set_row_upper(row, 0.0);
127                    model.set_weight(row, node_active, 1.0);
128                    model.set_weight(row, child_active, -1.0);
129                }
130            }
131        }
132
133        model.set_obj_sense(Sense::Minimize);
134        for class in egraph.classes() {
135            for (node, &node_active) in class.iter().zip(&vars[&class.id].nodes) {
136                model.set_obj_coeff(node_active, cost_function.node_cost(egraph, class.id, node));
137            }
138        }
139
140        dbg!(max_order);
141
142        Self {
143            egraph,
144            model,
145            vars,
146        }
147    }
148
149    /// Set the cbc timeout in seconds.
150    pub fn timeout(&mut self, seconds: f64) -> &mut Self {
151        self.model.set_parameter("seconds", &seconds.to_string());
152        self
153    }
154
155    /// Extract a single rooted term.
156    ///
157    /// This is just a shortcut for [`LpExtractor::solve_multiple`].
158    pub fn solve(&mut self, root: Id) -> RecExpr<L> {
159        self.solve_multiple(&[root]).0
160    }
161
162    /// Extract (potentially multiple) roots
163    pub fn solve_multiple(&mut self, roots: &[Id]) -> (RecExpr<L>, Vec<Id>) {
164        let egraph = self.egraph;
165
166        for class in self.vars.values() {
167            self.model.set_binary(class.active);
168        }
169
170        for root in roots {
171            let root = &egraph.find(*root);
172            self.model.set_col_lower(self.vars[root].active, 1.0);
173        }
174
175        let solution = self.model.solve();
176        log::info!(
177            "CBC status {:?}, {:?}",
178            solution.raw().status(),
179            solution.raw().secondary_status()
180        );
181
182        let mut todo: Vec<Id> = roots.iter().map(|id| self.egraph.find(*id)).collect();
183        let mut expr = RecExpr::default();
184        // converts e-class ids to e-node ids
185        let mut ids: HashMap<Id, Id> = HashMap::default();
186
187        while let Some(&id) = todo.last() {
188            if ids.contains_key(&id) {
189                todo.pop();
190                continue;
191            }
192            let v = &self.vars[&id];
193            assert!(solution.col(v.active) > 0.0);
194            let node_idx = v.nodes.iter().position(|&n| solution.col(n) > 0.0).unwrap();
195            let node = &self.egraph[id].nodes[node_idx];
196            if node.all(|child| ids.contains_key(&child)) {
197                let new_id = expr.add(node.clone().map_children(|i| ids[&self.egraph.find(i)]));
198                ids.insert(id, new_id);
199                todo.pop();
200            } else {
201                todo.extend_from_slice(node.children())
202            }
203        }
204
205        let root_idxs = roots.iter().map(|root| ids[root]).collect();
206
207        assert!(expr.is_dag(), "LpExtract found a cyclic term!: {:?}", expr);
208        (expr, root_idxs)
209    }
210}
211
212fn find_cycles<L, N>(egraph: &EGraph<L, N>, mut f: impl FnMut(Id, usize))
213where
214    L: Language,
215    N: Analysis<L>,
216{
217    enum Color {
218        White,
219        Gray,
220        Black,
221    }
222    type Enter = bool;
223
224    let mut color: HashMap<Id, Color> = egraph.classes().map(|c| (c.id, Color::White)).collect();
225    let mut stack: Vec<(Enter, Id)> = egraph.classes().map(|c| (true, c.id)).collect();
226
227    while let Some((enter, id)) = stack.pop() {
228        if enter {
229            *color.get_mut(&id).unwrap() = Color::Gray;
230            stack.push((false, id));
231            for (i, node) in egraph[id].iter().enumerate() {
232                for child in node.children() {
233                    match &color[child] {
234                        Color::White => stack.push((true, *child)),
235                        Color::Gray => f(id, i),
236                        Color::Black => (),
237                    }
238                }
239            }
240        } else {
241            *color.get_mut(&id).unwrap() = Color::Black;
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use crate::{SymbolLang as S, *};

    /// Two roots, `(f (+ a a))` and `(g (+ a a))`, share a sub-term; the
    /// extracted DAG must hold it only once: `a`, `+`, `f` and `g`.
    #[test]
    fn simple_lp_extract_two() {
        let mut egraph = EGraph::<S, ()>::default();
        let leaf = egraph.add(S::leaf("a"));
        let sum = egraph.add(S::new("+", vec![leaf, leaf]));
        let roots = [
            egraph.add(S::new("f", vec![sum])),
            egraph.add(S::new("g", vec![sum])),
        ];

        let mut extractor = LpExtractor::new(&egraph, AstSize);
        // Ten seconds is far more than this tiny model needs.
        extractor.timeout(10.0);
        let (term, root_ids) = extractor.solve_multiple(&roots);
        println!("{:?}", term);
        println!("{}", term);
        assert_eq!(term.len(), 4);
        assert_eq!(root_ids.len(), 2);
    }
}