Trait egg::CostFunction

pub trait CostFunction<L: Language> {
    type Cost: PartialOrd + Debug + Clone;

    // Required method
    fn cost<C>(&mut self, enode: &L, costs: C) -> Self::Cost
       where C: FnMut(Id) -> Self::Cost;

    // Provided method
    fn cost_rec(&mut self, expr: &RecExpr<L>) -> Self::Cost { ... }
}

A cost function that can be used by an Extractor.

To extract an expression from an EGraph, the Extractor requires a cost function to perform its greedy search. egg provides the simple AstSize and AstDepth cost functions.
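As a rough picture of what that greedy search does, the sketch below runs a simplified fixpoint over a tiny hypothetical e-graph. Each e-class repeatedly takes its cheapest e-node whose children already have costs, until no class improves. The `Class` type, `node_cost`, and `extract_costs` are illustrative stand-ins written for this sketch, not egg's API, and egg's Extractor may differ in its details.

```rust
/// Hypothetical mini e-graph for illustration: each e-class is a list of
/// e-nodes, and each e-node is (operator name, child e-class indices).
type Class = Vec<(&'static str, Vec<usize>)>;

/// AstSize-style node cost: 1 plus the best known cost of each child class,
/// or None if some child class has no cost yet.
fn node_cost(children: &[usize], best: &[Option<usize>]) -> Option<usize> {
    children
        .iter()
        .try_fold(1usize, |sum, &c| best[c].map(|cc| sum.saturating_add(cc)))
}

/// Simplified greedy extraction: repeatedly give each e-class its cheapest
/// e-node whose children already have costs, until no cost improves.
/// Returns, per class, Some((cost, index of the chosen e-node)).
fn extract_costs(classes: &[Class]) -> Vec<Option<(usize, usize)>> {
    let mut best: Vec<Option<(usize, usize)>> = vec![None; classes.len()];
    loop {
        let mut changed = false;
        for (i, class) in classes.iter().enumerate() {
            // Snapshot of the current best cost of every class.
            let costs: Vec<Option<usize>> = best.iter().map(|b| b.map(|(c, _)| c)).collect();
            for (n, (_op, children)) in class.iter().enumerate() {
                if let Some(c) = node_cost(children, &costs) {
                    // Keep this e-node only if it strictly improves the class.
                    if best[i].map_or(true, |(old, _)| c < old) {
                        best[i] = Some((c, n));
                        changed = true;
                    }
                }
            }
        }
        if !changed {
            return best;
        }
    }
}
```

Here `node_cost` plays the role that CostFunction::cost plays for the Extractor: the lookup into the current best costs corresponds to the `costs` closure passed to `cost`.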

The example below implements a silly but realistic cost function that is essentially AST size weighted by operator:

use egg::*;

struct SillyCostFn;
impl CostFunction<SymbolLang> for SillyCostFn {
    type Cost = f64;
    fn cost<C>(&mut self, enode: &SymbolLang, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost
    {
        let op_cost = match enode.op.as_str() {
            "foo" => 100.0,
            "bar" => 0.7,
            _ => 1.0
        };
        enode.fold(op_cost, |sum, id| sum + costs(id))
    }
}

let e: RecExpr<SymbolLang> = "(do_it foo bar baz)".parse().unwrap();
assert_eq!(SillyCostFn.cost_rec(&e), 102.7);
assert_eq!(AstSize.cost_rec(&e), 4);
assert_eq!(AstDepth.cost_rec(&e), 2);

If you’d like to access the Analysis data or anything else in the e-graph, you can put a reference to the e-graph in your CostFunction:

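use egg::*;

// Placeholder analysis for this example; substitute your own Analysis type.
type MyAnalysis = ();

struct EGraphCostFn<'a> {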
    egraph: &'a EGraph<SymbolLang, MyAnalysis>,
}

impl<'a> CostFunction<SymbolLang> for EGraphCostFn<'a> {
    type Cost = usize;
    fn cost<C>(&mut self, _enode: &SymbolLang, _costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost
    {
        // use self.egraph however you want here
        println!("the egraph has {} classes", self.egraph.number_of_classes());
        1
    }
}

let mut egraph = EGraph::<SymbolLang, MyAnalysis>::default();
let id = egraph.add_expr(&"(foo bar)".parse().unwrap());
let cost_func = EGraphCostFn { egraph: &egraph };
let mut extractor = Extractor::new(&egraph, cost_func);
let _ = extractor.find_best(id);

Note that a particular e-class might occur in an expression multiple times, and a cost function like AstSize counts every occurrence separately. This means that pathological, but nevertheless realistic, cases might overflow usize, even if the actual RecExpr fits compactly in memory. You might want to use saturating_add to ensure your cost function is still monotonic in this situation.
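The sketch below makes the overflow concrete using a hypothetical stand-in for a RecExpr: a list of nodes whose child indices point to earlier nodes. `doubling_chain(n)` builds an expression of n + 1 nodes in which each node uses its predecessor twice, so its tree size is 2^(n+1) - 1. `checked_ast_size` detects the overflow, and `saturating_ast_size` clamps at usize::MAX instead. All of these names are hypothetical.

```rust
/// Hypothetical stand-in for a RecExpr: node `i` lists the indices of its
/// children, which always precede `i` (children-before-parents order).
type Nodes = Vec<Vec<usize>>;

/// Builds x0 = leaf, x_i = (op x_{i-1} x_{i-1}): n + 1 nodes when shared,
/// but 2^(n+1) - 1 nodes when expanded as a tree.
fn doubling_chain(n: usize) -> Nodes {
    let mut nodes = vec![vec![]];
    for i in 1..=n {
        nodes.push(vec![i - 1, i - 1]);
    }
    nodes
}

/// AstSize-style cost (1 + sum of child costs) with saturating arithmetic:
/// clamps at usize::MAX instead of overflowing.
fn saturating_ast_size(nodes: &Nodes) -> usize {
    let mut costs: Vec<usize> = Vec::with_capacity(nodes.len());
    for children in nodes {
        let cost = children
            .iter()
            .fold(1usize, |sum, &c| sum.saturating_add(costs[c]));
        costs.push(cost);
    }
    *costs.last().expect("expression must be non-empty")
}

/// Same cost with checked arithmetic: returns None on overflow.
fn checked_ast_size(nodes: &Nodes) -> Option<usize> {
    let mut costs: Vec<usize> = Vec::with_capacity(nodes.len());
    for children in nodes {
        let mut cost = 1usize;
        for &c in children {
            cost = cost.checked_add(costs[c])?;
        }
        costs.push(cost);
    }
    costs.last().copied()
}
```

Inside an egg cost function, the analogous body would be `enode.fold(1, |sum, id| sum.saturating_add(costs(id)))`.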

Required Associated Types


type Cost: PartialOrd + Debug + Clone

The Cost type. It only requires PartialOrd so you can use floating-point types, but failed comparisons (NaNs) will result in a panic.
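To see why NaN is a problem, recall that `partial_cmp` returns `None` when either operand is NaN. The sketch below assumes, for illustration only, that costs are compared by unwrapping `partial_cmp`. `cmp_costs` and `cheapest` are hypothetical helpers, not egg's code, but they fail in the way the text describes.

```rust
use std::cmp::Ordering;

/// Hypothetical comparison helper: compares two costs with partial_cmp and
/// panics when they are unordered, as happens whenever a NaN is involved.
fn cmp_costs(a: &f64, b: &f64) -> Ordering {
    a.partial_cmp(b)
        .unwrap_or_else(|| panic!("cannot compare costs {:?} and {:?}", a, b))
}

/// Index of the cheapest cost in `costs`, or None if `costs` is empty.
/// Panics if a comparison hits a NaN.
fn cheapest(costs: &[f64]) -> Option<usize> {
    costs
        .iter()
        .enumerate()
        .min_by(|(_, a), (_, b)| cmp_costs(*a, *b))
        .map(|(i, _)| i)
}
```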

Required Methods


fn cost<C>(&mut self, enode: &L, costs: C) -> Self::Cost
where
    C: FnMut(Id) -> Self::Cost,

Calculates the cost of an enode, given costs, a function that returns the already-computed Cost of each child e-class (identified by its Id).

For this to work properly, your cost function should be monotonic, i.e. cost should return a Cost greater than any of the child costs of the given enode.
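For a sum-based cost such as SillyCostFn above, monotonicity holds as long as every operator weight is strictly positive and child costs are non-negative, because the parent's cost is its weight plus the sum of all child costs. The sketch below checks this for one node at a time. The weights mirror SillyCostFn; the `"id"` operator with weight 0.0 and the helper names are hypothetical, added to show a violation.

```rust
/// Hypothetical operator weights: SillyCostFn's, plus an "id" operator with
/// weight 0.0 to show how monotonicity can fail.
fn op_weight(op: &str) -> f64 {
    match op {
        "foo" => 100.0,
        "bar" => 0.7,
        "id" => 0.0,
        _ => 1.0,
    }
}

/// Sum-based cost of a node: its operator weight plus all child costs.
fn node_cost(op: &str, child_costs: &[f64]) -> f64 {
    child_costs.iter().fold(op_weight(op), |sum, c| sum + c)
}

/// Checks the monotonicity requirement for one node: its cost must be
/// strictly greater than every child cost.
fn is_monotonic(op: &str, child_costs: &[f64]) -> bool {
    let parent = node_cost(op, child_costs);
    child_costs.iter().all(|&c| parent > c)
}
```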

Provided Methods


fn cost_rec(&mut self, expr: &RecExpr<L>) -> Self::Cost

Calculates the total cost of a RecExpr.

As provided, this just recursively calls cost all the way down the RecExpr.
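The sketch below shows one way such a provided method can be written, using a simplified, hypothetical trait (`MiniCost`) and expression type (`MiniRecExpr`) rather than egg's own. It relies on the RecExpr layout in which every child appears before its parent and the root is the last node, so a single forward pass computes the same result as recursing from the root. The test mirrors the (do_it foo bar baz) example above.

```rust
/// Hypothetical stand-in for RecExpr<L>: every child index is smaller than
/// its parent's index, and the root is the last node.
struct MiniRecExpr {
    nodes: Vec<(&'static str, Vec<usize>)>,
}

/// Pared-down CostFunction shape: the cost of one node, given a closure
/// that yields each child's already-computed cost.
trait MiniCost {
    fn cost<C: FnMut(usize) -> usize>(&mut self, op: &str, children: &[usize], costs: C) -> usize;

    /// Bottom-up evaluation: children precede parents, so one forward pass
    /// costs every node, and the root's cost is the last one computed.
    fn cost_rec(&mut self, expr: &MiniRecExpr) -> usize {
        let mut computed: Vec<usize> = Vec::with_capacity(expr.nodes.len());
        for (op, children) in &expr.nodes {
            let c = self.cost(op, children, |id| computed[id]);
            computed.push(c);
        }
        *computed.last().expect("RecExpr must be non-empty")
    }
}

/// AstSize analogue: 1 plus the sum of child costs.
struct MiniAstSize;
impl MiniCost for MiniAstSize {
    fn cost<C: FnMut(usize) -> usize>(&mut self, _op: &str, children: &[usize], mut costs: C) -> usize {
        children.iter().fold(1, |sum, &id| sum + costs(id))
    }
}

/// AstDepth analogue: 1 plus the largest child cost (0 for leaves).
struct MiniAstDepth;
impl MiniCost for MiniAstDepth {
    fn cost<C: FnMut(usize) -> usize>(&mut self, _op: &str, children: &[usize], mut costs: C) -> usize {
        1 + children.iter().map(|&id| costs(id)).max().unwrap_or(0)
    }
}
```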

Implementors