1use coin_cbc::{Col, Model, Sense};
2
3use crate::*;
4
5#[cfg_attr(docsrs, doc(cfg(feature = "lp")))]
7pub trait LpCostFunction<L: Language, N: Analysis<L>> {
8 fn node_cost(&mut self, egraph: &EGraph<L, N>, eclass: Id, enode: &L) -> f64;
13}
14
15#[cfg_attr(docsrs, doc(cfg(feature = "lp")))]
16impl<L: Language, N: Analysis<L>> LpCostFunction<L, N> for AstSize {
17 fn node_cost(&mut self, _egraph: &EGraph<L, N>, _eclass: Id, _enode: &L) -> f64 {
18 1.0
19 }
20}
21
22#[cfg_attr(docsrs, doc(cfg(feature = "lp")))]
58pub struct LpExtractor<'a, L: Language, N: Analysis<L>> {
59 egraph: &'a EGraph<L, N>,
60 model: Model,
61 vars: HashMap<Id, ClassVars>,
62}
63
64struct ClassVars {
65 active: Col,
66 order: Col,
67 nodes: Vec<Col>,
68}
69
70impl<'a, L, N> LpExtractor<'a, L, N>
71where
72 L: Language,
73 N: Analysis<L>,
74{
75 pub fn new<CF>(egraph: &'a EGraph<L, N>, mut cost_function: CF) -> Self
78 where
79 CF: LpCostFunction<L, N>,
80 {
81 let max_order = egraph.total_number_of_nodes() as f64 * 10.0;
82
83 let mut model = Model::default();
84
85 let vars: HashMap<Id, ClassVars> = egraph
86 .classes()
87 .map(|class| {
88 let cvars = ClassVars {
89 active: model.add_binary(),
90 order: model.add_col(),
91 nodes: class.nodes.iter().map(|_| model.add_binary()).collect(),
92 };
93 model.set_col_upper(cvars.order, max_order);
94 (class.id, cvars)
95 })
96 .collect();
97
98 let mut cycles: HashSet<(Id, usize)> = Default::default();
99 find_cycles(egraph, |id, i| {
100 cycles.insert((id, i));
101 });
102
103 for (&id, class) in &vars {
104 let row = model.add_row();
107 model.set_row_equal(row, 0.0);
108 model.set_weight(row, class.active, -1.0);
109 for &node_active in &class.nodes {
110 model.set_weight(row, node_active, 1.0);
111 }
112
113 for (i, (node, &node_active)) in egraph[id].iter().zip(&class.nodes).enumerate() {
114 if cycles.contains(&(id, i)) {
115 model.set_col_upper(node_active, 0.0);
116 model.set_col_lower(node_active, 0.0);
117 continue;
118 }
119
120 for child in node.children() {
121 let child_active = vars[child].active;
122 let row = model.add_row();
126 model.set_row_upper(row, 0.0);
127 model.set_weight(row, node_active, 1.0);
128 model.set_weight(row, child_active, -1.0);
129 }
130 }
131 }
132
133 model.set_obj_sense(Sense::Minimize);
134 for class in egraph.classes() {
135 for (node, &node_active) in class.iter().zip(&vars[&class.id].nodes) {
136 model.set_obj_coeff(node_active, cost_function.node_cost(egraph, class.id, node));
137 }
138 }
139
140 dbg!(max_order);
141
142 Self {
143 egraph,
144 model,
145 vars,
146 }
147 }
148
149 pub fn timeout(&mut self, seconds: f64) -> &mut Self {
151 self.model.set_parameter("seconds", &seconds.to_string());
152 self
153 }
154
155 pub fn solve(&mut self, root: Id) -> RecExpr<L> {
159 self.solve_multiple(&[root]).0
160 }
161
162 pub fn solve_multiple(&mut self, roots: &[Id]) -> (RecExpr<L>, Vec<Id>) {
164 let egraph = self.egraph;
165
166 for class in self.vars.values() {
167 self.model.set_binary(class.active);
168 }
169
170 for root in roots {
171 let root = &egraph.find(*root);
172 self.model.set_col_lower(self.vars[root].active, 1.0);
173 }
174
175 let solution = self.model.solve();
176 log::info!(
177 "CBC status {:?}, {:?}",
178 solution.raw().status(),
179 solution.raw().secondary_status()
180 );
181
182 let mut todo: Vec<Id> = roots.iter().map(|id| self.egraph.find(*id)).collect();
183 let mut expr = RecExpr::default();
184 let mut ids: HashMap<Id, Id> = HashMap::default();
186
187 while let Some(&id) = todo.last() {
188 if ids.contains_key(&id) {
189 todo.pop();
190 continue;
191 }
192 let v = &self.vars[&id];
193 assert!(solution.col(v.active) > 0.0);
194 let node_idx = v.nodes.iter().position(|&n| solution.col(n) > 0.0).unwrap();
195 let node = &self.egraph[id].nodes[node_idx];
196 if node.all(|child| ids.contains_key(&child)) {
197 let new_id = expr.add(node.clone().map_children(|i| ids[&self.egraph.find(i)]));
198 ids.insert(id, new_id);
199 todo.pop();
200 } else {
201 todo.extend_from_slice(node.children())
202 }
203 }
204
205 let root_idxs = roots.iter().map(|root| ids[root]).collect();
206
207 assert!(expr.is_dag(), "LpExtract found a cyclic term!: {:?}", expr);
208 (expr, root_idxs)
209 }
210}
211
212fn find_cycles<L, N>(egraph: &EGraph<L, N>, mut f: impl FnMut(Id, usize))
213where
214 L: Language,
215 N: Analysis<L>,
216{
217 enum Color {
218 White,
219 Gray,
220 Black,
221 }
222 type Enter = bool;
223
224 let mut color: HashMap<Id, Color> = egraph.classes().map(|c| (c.id, Color::White)).collect();
225 let mut stack: Vec<(Enter, Id)> = egraph.classes().map(|c| (true, c.id)).collect();
226
227 while let Some((enter, id)) = stack.pop() {
228 if enter {
229 *color.get_mut(&id).unwrap() = Color::Gray;
230 stack.push((false, id));
231 for (i, node) in egraph[id].iter().enumerate() {
232 for child in node.children() {
233 match &color[child] {
234 Color::White => stack.push((true, *child)),
235 Color::Gray => f(id, i),
236 Color::Black => (),
237 }
238 }
239 }
240 } else {
241 *color.get_mut(&id).unwrap() = Color::Black;
242 }
243 }
244}
245
#[cfg(test)]
mod tests {
    use crate::{SymbolLang as S, *};

    /// Extracts `f(a + a)` and `g(a + a)` together and checks that the result
    /// holds four nodes and reports one id per root.
    #[test]
    fn simple_lp_extract_two() {
        let mut egraph = EGraph::<S, ()>::default();
        let leaf = egraph.add(S::leaf("a"));
        let sum = egraph.add(S::new("+", vec![leaf, leaf]));
        let f = egraph.add(S::new("f", vec![sum]));
        let g = egraph.add(S::new("g", vec![sum]));

        let mut extractor = LpExtractor::new(&egraph, AstSize);
        extractor.timeout(10.0);

        let roots = [f, g];
        let (expr, root_ids) = extractor.solve_multiple(&roots);
        println!("{:?}", expr);
        println!("{}", expr);
        assert_eq!(expr.len(), 4);
        assert_eq!(root_ids.len(), 2);
    }
}