1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
use crate::{
egraph::{EClass, EGraph},
expr::{Expr, Id, Language, RecExpr},
util::HashMap,
};
use log::*;
/// Cost assigned to an expression. Extraction in this module prefers lower values.
pub type Cost = u64;
/// Recursively compute the cost of `expr`, recording every visited subexpression in `map`.
///
/// A subexpression that is already a key of `map` contributes a flat cost of `1`
/// instead of being costed again. NOTE(review): this appears to model sharing of
/// repeated subterms (a DAG-style cost); confirm that is intended. Only key presence
/// in `map` is consulted here. The stored cost values are never read back.
fn calculate_cost_rec<L: Language>(
    map: &mut HashMap<RecExpr<L>, Cost>,
    expr: &RecExpr<L>,
) -> Cost {
    if map.contains_key(expr) {
        1
    } else {
        // Swap each child for its recursively computed cost, then let the
        // language score this node given those child costs.
        let costed_node = expr
            .as_ref()
            .map_children(|child| calculate_cost_rec(map, &child));
        let node_cost = L::cost(&costed_node);
        map.insert(expr.clone(), node_cost);
        node_cost
    }
}
/// Compute the cost of `expr`, starting from an empty table of seen subexpressions.
///
/// The resulting cost and the expression's s-expression form are logged at
/// `trace` level before the cost is returned.
pub fn calculate_cost<L: Language>(expr: &RecExpr<L>) -> Cost {
    let mut seen = HashMap::default();
    let total = calculate_cost_rec(&mut seen, expr);
    trace!("Found cost to be {}\n {}", total, expr.to_sexp());
    total
}
/// An expression together with the cost computed for it.
pub struct CostExpr<L: Language> {
    /// Cost of `expr`. Within this module it is produced by `calculate_cost`.
    pub cost: Cost,
    /// The expression being costed.
    pub expr: RecExpr<L>,
}
/// Finds a lowest-cost expression for the e-classes of a borrowed [`EGraph`].
pub struct Extractor<'a, L: Language, M> {
    /// Cheapest expression found so far for each e-class that could be costed,
    /// keyed by e-class id.
    costs: HashMap<Id, CostExpr<L>>,
    /// The e-graph being extracted from. It is borrowed, not owned.
    egraph: &'a EGraph<L, M>,
}
impl<'a, L: Language, M> Extractor<'a, L, M> {
pub fn new(egraph: &'a EGraph<L, M>) -> Self {
let costs = HashMap::default();
let mut extractor = Extractor { costs, egraph };
extractor.find_costs();
extractor
}
pub fn find_best(&self, eclass: Id) -> &CostExpr<L> {
&self.costs[&eclass]
}
fn build_expr(&self, root: &Expr<L, Id>) -> Option<RecExpr<L>> {
let expr = root
.map_children_result(|id| {
self.costs
.get(&id)
.map(|cost_expr| cost_expr.expr.clone())
.ok_or(())
})
.ok()?;
Some(expr.into())
}
fn node_total_cost(&self, node: &Expr<L, Id>) -> Option<CostExpr<L>> {
let expr = self.build_expr(node)?;
let cost = calculate_cost(&expr);
Some(CostExpr { cost, expr })
}
fn find_costs(&mut self) {
let mut did_something = true;
let mut loops = 0;
while did_something {
did_something = false;
for class in self.egraph.classes() {
did_something |= self.make_pass(class);
}
loops += 1;
}
info!("Took {} loops to find costs", loops);
}
fn make_pass(&mut self, class: &EClass<L, M>) -> bool {
let new = class
.iter()
.filter_map(|n| self.node_total_cost(n))
.min_by_key(|ce| ce.cost);
let new = match new {
Some(new) => new,
None => return true,
};
if let Some(old) = self.costs.get(&class.id) {
if new.cost < old.cost {
self.costs.insert(class.id, new);
true
} else {
false
}
} else {
self.costs.insert(class.id, new);
true
}
}
}