slotted_egraphs/extract/
cost.rs1use crate::*;
2
3use std::marker::PhantomData;
4
5pub trait CostFunction<L: Language> {
9 type Cost: Ord + Clone + Debug;
10 fn cost<C>(&self, enode: &L, costs: C) -> Self::Cost
11 where
12 C: Fn(Id) -> Self::Cost;
13
14 fn cost_rec(&self, expr: &RecExpr<L>) -> Self::Cost {
15 let child_costs: Vec<Self::Cost> = expr.children.iter().map(|x| self.cost_rec(x)).collect();
16 let c = |i: Id| child_costs[i.0].clone();
17 let mut node = expr.node.clone();
18 for (i, x) in node.applied_id_occurrences_mut().iter_mut().enumerate() {
19 **x = AppliedId::new(Id(i), SlotMap::new());
20 }
21 self.cost(&node, c)
22 }
23}
24
/// Cost function that measures expression size. Its `CostFunction` impl
/// charges each node `1` plus the (saturating) sum of its children's costs.
///
/// This is a stateless unit struct, so it derives the common standard traits
/// (`Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`, `Hash`) in addition to `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AstSize;
28
29impl<L: Language> CostFunction<L> for AstSize {
30 type Cost = u64;
31
32 fn cost<C>(&self, enode: &L, costs: C) -> u64
33 where
34 C: Fn(Id) -> u64,
35 {
36 let mut s: u64 = 1;
37 for x in enode.applied_id_occurrences() {
38 s = s.saturating_add(costs(x.id));
39 }
40 s
41 }
42}