1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
use crate::{
    egraph::{EClass, EGraph},
    expr::{Expr, Id, Language, RecExpr},
    util::HashMap,
};

use log::*;

/// Numeric cost of an expression. Lower is cheaper: the extractor keeps the
/// candidate with the minimum cost for each e-class.
pub type Cost = u64;

/// Recursively computes the cost of `expr`, recording every visited
/// subexpression in `map`.
///
/// A subexpression that is already a key of `map` is charged a flat cost of 1
/// instead of being costed again, so repeated subterms are cheap. Otherwise
/// each child is costed first, and the node's cost comes from `L::cost` applied
/// to the node with its children replaced by their costs.
fn calculate_cost_rec<L: Language>(map: &mut HashMap<RecExpr<L>, Cost>, expr: &RecExpr<L>) -> Cost {
    if map.contains_key(expr) {
        // Previously seen subexpression: flat charge.
        1
    } else {
        // Children are costed (and recorded) before this node is recorded.
        let costed_children = expr
            .as_ref()
            .map_children(|child| calculate_cost_rec(map, &child));
        let node_cost = L::cost(&costed_children);
        map.insert(expr.clone(), node_cost);
        node_cost
    }
}

pub fn calculate_cost<L: Language>(expr: &RecExpr<L>) -> Cost {
    let mut map = HashMap::default();
    let cost = calculate_cost_rec(&mut map, expr);
    trace!("Found cost to be {}\n  {}", cost, expr.to_sexp());
    cost
}

/// An expression paired with its cost, as produced by the extractor via
/// [`calculate_cost`].
pub struct CostExpr<L: Language> {
    /// Cost of `expr`; lower is cheaper.
    pub cost: Cost,
    /// The concrete expression that was costed.
    pub expr: RecExpr<L>,
}

/// Finds, for the e-classes of an [`EGraph`] that it can cost, the cheapest
/// expression each one represents.
pub struct Extractor<'a, L: Language, M> {
    /// Best (lowest-cost) expression found so far, keyed by e-class id.
    costs: HashMap<Id, CostExpr<L>>,
    /// The e-graph being extracted from.
    egraph: &'a EGraph<L, M>,
}

impl<'a, L: Language, M> Extractor<'a, L, M> {
    /// Builds an extractor for `egraph` and eagerly computes the cheapest
    /// expression for every e-class that can be costed.
    pub fn new(egraph: &'a EGraph<L, M>) -> Self {
        // The cost table starts empty; `find_costs` fills it in by iterating
        // over all classes until a full pass changes nothing.
        let mut extractor = Extractor {
            costs: HashMap::default(),
            egraph,
        };
        extractor.find_costs();
        extractor
    }

    /// Returns the cheapest expression found for `eclass`, together with its cost.
    ///
    /// # Panics
    ///
    /// Panics if no cost was recorded for `eclass`. This happens when the id is
    /// not a key of the cost table, or when every node of the class depends on a
    /// class that could never be costed.
    pub fn find_best(&self, eclass: Id) -> &CostExpr<L> {
        self.costs
            .get(&eclass)
            .expect("no cost was found for the requested e-class")
    }

    /// Builds a concrete expression for `root` by substituting each child
    /// e-class id with that class's current best expression.
    ///
    /// Returns `None` if any child class has no cost recorded yet.
    fn build_expr(&self, root: &Expr<L, Id>) -> Option<RecExpr<L>> {
        let expr = root
            .map_children_result(|id| {
                self.costs
                    .get(&id)
                    .map(|cost_expr| cost_expr.expr.clone())
                    .ok_or(())
            })
            .ok()?;
        Some(expr.into())
    }

    /// Builds the expression for `node` from its children's current best
    /// expressions and costs it with [`calculate_cost`].
    ///
    /// Returns `None` if some child class has no cost yet.
    fn node_total_cost(&self, node: &Expr<L, Id>) -> Option<CostExpr<L>> {
        let expr = self.build_expr(node)?;
        let cost = calculate_cost(&expr);
        Some(CostExpr { cost, expr })
    }

    /// Repeatedly sweeps all e-classes until a full pass changes no cost.
    ///
    /// Each change either costs a class for the first time or strictly lowers
    /// an existing (non-negative, integral) cost. So with finitely many classes
    /// the loop terminates.
    fn find_costs(&mut self) {
        // Copy the shared reference out so iterating the classes does not keep
        // `self` borrowed while `make_pass` mutates the cost table.
        let egraph = self.egraph;
        let mut did_something = true;
        let mut loops = 0;
        while did_something {
            did_something = false;

            for class in egraph.classes() {
                did_something |= self.make_pass(class);
            }

            loops += 1;
        }

        info!("Took {} loops to find costs", loops);
    }

    /// Recomputes the best candidate for `class` and records it if it improves
    /// on what is stored.
    ///
    /// Returns `true` only if the cost table changed. That happens when the class
    /// received its first cost, or when a strictly cheaper expression was found.
    fn make_pass(&mut self, class: &EClass<L, M>) -> bool {
        // Cheapest node among those whose children all have costs already.
        let candidate = class
            .iter()
            .filter_map(|n| self.node_total_cost(n))
            .min_by_key(|ce| ce.cost);

        let new = match candidate {
            Some(new) => new,
            // No node in this class can be costed yet, so nothing changed.
            // Reporting a change here would keep `find_costs` looping forever
            // on a class that can never be costed, e.g. one whose only node
            // has the class itself as a child.
            None => return false,
        };

        let improved = match self.costs.get(&class.id) {
            Some(old) => new.cost < old.cost,
            None => true,
        };
        if improved {
            self.costs.insert(class.id, new);
        }
        improved
    }
}