use std::cmp::Ordering;
use std::fmt::Debug;
use crate::util::{hashmap_with_capacity, HashMap};
use crate::{Analysis, EClass, EGraph, Id, Language, RecExpr};
/// Picks, for every e-class of a borrowed [`EGraph`], the e-node with the
/// lowest cost according to a [`CostFunction`].
///
/// The per-class results are computed once, when the extractor is built with
/// `Extractor::new`, and are then only read by the `find_best*` accessors.
#[derive(Debug)]
pub struct Extractor<'a, CF, L, N>
where
    CF: CostFunction<L>,
    L: Language,
    N: Analysis<L>,
{
    /// The user-supplied cost function; it is called through `&mut self`.
    cost_function: CF,
    /// For each e-class id: the best cost found so far and the e-node that
    /// achieves it.
    costs: HashMap<Id, (CF::Cost, L)>,
    /// The e-graph being extracted from; borrowed for the extractor's lifetime.
    egraph: &'a EGraph<L, N>,
}
/// A way of assigning a cost to terms of the language `L`.
///
/// Implementors only provide [`CostFunction::cost`], which prices a single
/// e-node given a way to look up the cost of each of its children.
pub trait CostFunction<L: Language> {
    /// The type of a cost. It only needs a partial order, so an implementor
    /// may choose a type whose values are not all comparable.
    type Cost: PartialOrd + Debug + Clone;

    /// Computes the cost of `enode`.
    ///
    /// `costs` maps the id of a child of `enode` to that child's cost; the
    /// callers in this file only ever pass lookups into already-computed
    /// costs.
    fn cost<C>(&mut self, enode: &L, costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost;

    /// Computes the cost of a whole expression by pricing every node yielded
    /// by `expr.items()` in order and returning the cost recorded for
    /// `expr.root()`.
    ///
    /// NOTE(review): this relies on each node's children having been yielded
    /// (and therefore priced) before the node itself; otherwise the child
    /// lookup below fails. Confirm against `RecExpr`'s ordering guarantee.
    fn cost_rec(&mut self, expr: &RecExpr<L>) -> Self::Cost {
        let mut known = hashmap_with_capacity::<Id, Self::Cost>(expr.len());
        for (id, node) in expr.items() {
            // Children are looked up in the costs accumulated so far.
            let total = self.cost(node, |child| known[&child].clone());
            known.insert(id, total);
        }
        known[&expr.root()].clone()
    }
}
/// A [`CostFunction`] that measures term size: every e-node costs one plus
/// the sum of its children's costs.
#[derive(Debug)]
pub struct AstSize;

impl<L: Language> CostFunction<L> for AstSize {
    type Cost = usize;

    /// Returns `1 + sum of child costs`, saturating at `usize::MAX` instead
    /// of overflowing.
    fn cost<C>(&mut self, enode: &L, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost,
    {
        // Start at 1 to account for the e-node itself.
        let add_child = |running: usize, child: Id| running.saturating_add(costs(child));
        enode.fold(1, add_child)
    }
}
/// A [`CostFunction`] that measures term depth: every e-node costs one more
/// than its most expensive child, and an e-node without children costs one.
#[derive(Debug)]
pub struct AstDepth;

impl<L: Language> CostFunction<L> for AstDepth {
    type Cost = usize;

    /// Returns `1 + max of child costs` (or `1` when there are no children).
    ///
    /// The increment saturates at `usize::MAX`, matching how `AstSize`
    /// accumulates, so a child cost of `usize::MAX` can neither panic in a
    /// debug build nor wrap around to zero in a release build.
    fn cost<C>(&mut self, enode: &L, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost,
    {
        let deepest_child = enode.fold(0usize, |deepest: usize, child| deepest.max(costs(child)));
        deepest_child.saturating_add(1)
    }
}
/// Orders two optional values so that `None` sorts after every `Some`.
///
/// Two `None`s are equal; two `Some`s are ordered by their contents.
///
/// # Panics
///
/// Panics if both values are `Some` but their contents are not comparable
/// (`partial_cmp` returns `None`).
fn cmp<T: PartialOrd>(a: &Option<T>, b: &Option<T>) -> Ordering {
    if let (Some(lhs), Some(rhs)) = (a, b) {
        return lhs.partial_cmp(rhs).unwrap();
    }
    // At least one side is `None`. Since `false < true`, comparing the
    // "is missing" flags makes the missing side the greater one.
    a.is_none().cmp(&b.is_none())
}
impl<'a, CF, L, N> Extractor<'a, CF, L, N>
where
    CF: CostFunction<L>,
    L: Language,
    N: Analysis<L>,
{
    /// Creates an extractor for `egraph` and immediately computes the best
    /// cost and e-node for every e-class it can, using `cost_function`.
    pub fn new(egraph: &'a EGraph<L, N>, cost_function: CF) -> Self {
        let mut this = Extractor {
            cost_function,
            costs: HashMap::default(),
            egraph,
        };
        this.find_costs();
        this
    }

    /// Returns the best cost recorded for `eclass` together with the
    /// expression built by recursively taking the best e-node of each child
    /// class.
    ///
    /// The lookup indexes the cost table directly, so it is expected to fail
    /// (presumably by panicking) for a class no cost could be computed for.
    pub fn find_best(&self, eclass: Id) -> (CF::Cost, RecExpr<L>) {
        let canonical = self.egraph.find(eclass);
        let (cost, root) = self.costs[&canonical].clone();
        let expr = root.build_recexpr(|child| self.find_best_node(child).clone());
        (cost, expr)
    }

    /// Returns the best e-node recorded for `eclass`.
    ///
    /// `eclass` is canonicalized with `EGraph::find` before the lookup.
    pub fn find_best_node(&self, eclass: Id) -> &L {
        let (_, node) = &self.costs[&self.egraph.find(eclass)];
        node
    }

    /// Returns a clone of the best cost recorded for `eclass`.
    ///
    /// `eclass` is canonicalized with `EGraph::find` before the lookup.
    pub fn find_best_cost(&self, eclass: Id) -> CF::Cost {
        self.costs[&self.egraph.find(eclass)].0.clone()
    }

    /// Prices `node` with the cost function, or returns `None` when at least
    /// one of its children has no recorded cost yet.
    fn node_total_cost(&mut self, node: &L) -> Option<CF::Cost> {
        // Borrow the two read-only fields separately so the cost function
        // can still be borrowed mutably below.
        let eg = &self.egraph;
        let known = &self.costs;
        if !node.all(|id| known.contains_key(&eg.find(id))) {
            return None;
        }
        let child_cost = |id| known[&eg.find(id)].0.clone();
        Some(self.cost_function.cost(node, child_cost))
    }

    /// Repeatedly sweeps over all e-classes, recording a class's candidate
    /// whenever the class has no cost yet or the candidate is strictly
    /// cheaper, until a full sweep changes nothing. Afterwards, logs a
    /// warning for every class that still has no cost.
    fn find_costs(&mut self) {
        loop {
            let mut changed = false;
            for class in self.egraph.classes() {
                if let Some(candidate) = self.make_pass(class) {
                    let improves = match self.costs.get(&class.id) {
                        None => true,
                        Some(current) => candidate.0 < current.0,
                    };
                    if improves {
                        self.costs.insert(class.id, candidate);
                        changed = true;
                    }
                }
            }
            if !changed {
                break;
            }
        }

        for class in self.egraph.classes() {
            if self.costs.contains_key(&class.id) {
                continue;
            }
            log::warn!(
                "Failed to compute cost for eclass {}: {:?}",
                class.id,
                class.nodes
            );
        }
    }

    /// Prices every e-node of `eclass` and returns the cheapest one with its
    /// cost, or `None` if no e-node could be priced yet.
    ///
    /// # Panics
    ///
    /// Panics if `eclass` contains no e-nodes.
    fn make_pass(&mut self, eclass: &EClass<L, N::Data>) -> Option<(CF::Cost, L)> {
        let priced = eclass.iter().map(|n| (self.node_total_cost(n), n));
        // `cmp` sorts unpriced (`None`) e-nodes last, so the minimum is
        // unpriced only when every e-node is.
        match priced.min_by(|lhs, rhs| cmp(&lhs.0, &rhs.0)) {
            Some((Some(cost), node)) => Some((cost, node.clone())),
            Some((None, _)) => None,
            None => panic!("Can't extract, eclass is empty: {:#?}", eclass),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::*;

    /// Runs a rewrite that keeps nesting the term deeper and checks that
    /// extraction with `AstSize` still returns the original expression.
    ///
    /// NOTE(review): judging by the test name, this presumably guards the
    /// saturating addition in `AstSize` against overflow; confirm.
    #[test]
    fn ast_size_overflow() {
        let explode: &[Rewrite<SymbolLang, ()>] =
            &[rewrite!("explode"; "(meow ?a)" => "(meow (meow ?a ?a))")];
        let seed = "(meow 42)".parse().unwrap();

        let saturated = Runner::default()
            .with_iter_limit(100)
            .with_expr(&seed)
            .run(explode);

        // Only the extracted expression matters here; the cost is ignored.
        let (_, smallest) =
            Extractor::new(&saturated.egraph, AstSize).find_best(saturated.roots[0]);
        assert_eq!(smallest, seed);
    }
}