Skip to main content

slotted_egraphs/extract/
cost.rs

1use crate::*;
2
3use std::marker::PhantomData;
4
5/// A cost function to guide extraction.
6///
7/// If you want to use your e-graph analysis in your cost function, then your cost function should hold a reference to the e-graph.
8pub trait CostFunction<L: Language> {
9    type Cost: Ord + Clone + Debug;
10    fn cost<C>(&self, enode: &L, costs: C) -> Self::Cost
11    where
12        C: Fn(Id) -> Self::Cost;
13
14    fn cost_rec(&self, expr: &RecExpr<L>) -> Self::Cost {
15        let child_costs: Vec<Self::Cost> = expr.children.iter().map(|x| self.cost_rec(x)).collect();
16        let c = |i: Id| child_costs[i.0].clone();
17        let mut node = expr.node.clone();
18        for (i, x) in node.applied_id_occurrences_mut().iter_mut().enumerate() {
19            **x = AppliedId::new(Id(i), SlotMap::new());
20        }
21        self.cost(&node, c)
22    }
23}
24
/// The 'default' [CostFunction]. It measures the size of the abstract syntax tree of the
/// corresponding term.
///
/// This is a stateless unit struct. It is `Copy` so that it can be passed around freely, and
/// `Debug` so that it can appear in diagnostics.
#[derive(Debug, Clone, Copy, Default)]
pub struct AstSize;
28
29impl<L: Language> CostFunction<L> for AstSize {
30    type Cost = u64;
31
32    fn cost<C>(&self, enode: &L, costs: C) -> u64
33    where
34        C: Fn(Id) -> u64,
35    {
36        let mut s: u64 = 1;
37        for x in enode.applied_id_occurrences() {
38            s = s.saturating_add(costs(x.id));
39        }
40        s
41    }
42}