Skip to main content

slotted_egraphs/run/
runner.rs

1use crate::*;
2use std::time::Duration;
3use std::time::Instant;
4
/// Record of a single iteration performed by a [`Runner`].
pub struct Iteration<IterData> {
    /// The user provided annotation for this iteration
    pub data: IterData,
    // TODO: record more per-iteration statistics.
    /// Total number of e-nodes in the e-graph when this iteration ended.
    pub num_nodes: usize,
    /// When this iteration finished. Iterations recorded by [`Runner`] always set this.
    pub finish_time: Option<Instant>,
}
12pub trait IterationData<L, N>: Sized
13where
14    L: Language,
15    N: Analysis<L>,
16{
17    /// Given the current [`Runner`], make the
18    /// data to be put in this [`Iteration`].
19    fn make<CustomErrorT>(runner: &Runner<L, N, Self, CustomErrorT>) -> Self
20    where
21        CustomErrorT: Clone;
22}
23
24impl<L, N> IterationData<L, N> for ()
25where
26    L: Language,
27    N: Analysis<L>,
28{
29    fn make<CustomErrorT>(_: &Runner<L, N, Self, CustomErrorT>) -> Self
30    where
31        CustomErrorT: Clone,
32    {
33    }
34}
35
/// Stop conditions checked by a [`Runner`] after each iteration's rewrites are applied.
pub struct RunnerLimits {
    /// Bound on the number of iterations (set with `Runner::with_iter_limit`).
    iter_limit: usize,
    /// The runner stops once the e-graph holds more than this many e-nodes.
    node_limit: usize,
    /// When the first iteration started; `None` until the runner begins running.
    start_time: Option<Instant>,
    /// The runner stops once more than this much wall-clock time has elapsed since `start_time`.
    time_limit: Duration,
}
42/// Type alias for the result of a [`Runner`].
43pub type RunnerResult<T, CustomErrorT = String> = Result<T, StopReason<CustomErrorT>>;
44
45impl RunnerLimits {
46    fn check_limits<L, N, CustomErrorT>(
47        &self,
48        iteration: usize,
49        eg: &EGraph<L, N>,
50    ) -> RunnerResult<(), CustomErrorT>
51    where
52        L: Language,
53        N: Analysis<L>,
54        CustomErrorT: Clone,
55    {
56        let elapsed = self.start_time.unwrap().elapsed();
57        if iteration > self.iter_limit {
58            Err(StopReason::IterationLimit)
59        } else if eg.total_number_of_nodes() > self.node_limit {
60            Err(StopReason::NodeLimit)
61        } else if elapsed > self.time_limit {
62            Err(StopReason::TimeLimit)
63        } else {
64            Ok(())
65        }
66    }
67}
68
69pub struct Runner<L: Language, N: Analysis<L> = (), IterData = (), CustomErrorT = String>
70where
71    IterData: IterationData<L, N>,
72    CustomErrorT: Clone,
73{
74    /// The [`EGraph`] used.
75    pub egraph: EGraph<L, N>,
76    /// Data accumulated over each [`Iteration`].
77    pub iterations: Vec<Iteration<IterData>>,
78    /// The roots of expressions added by the
79    /// [`with_expr`](Runner::with_expr()) method, in insertion order.
80    pub roots: Vec<AppliedId>,
81    /// Why the `Runner` stopped. This will be `None` if it hasn't
82    /// stopped yet.
83    pub stop_reason: Option<StopReason<CustomErrorT>>,
84
85    // Initial expressions in the EGraph
86    pub limits: RunnerLimits,
87    /// hooks
88    pub hooks: Vec<Box<dyn FnMut(&mut Self) -> Result<(), CustomErrorT> + 'static>>,
89}
90
91impl<L, N, IterData, CustomErrorT> Runner<L, N, IterData, CustomErrorT>
92where
93    L: Language,
94    N: Analysis<L>,
95    IterData: IterationData<L, N>,
96    CustomErrorT: Clone,
97{
98    pub fn new(n: N) -> Self {
99        Self {
100            egraph: EGraph::new(n),
101            iterations: vec![],
102            stop_reason: None,
103            limits: RunnerLimits {
104                iter_limit: 30,
105                node_limit: 10_000,
106                time_limit: Duration::from_secs(60),
107                // The start_time is initialized when the Runner is ran
108                start_time: None,
109            },
110            hooks: vec![],
111            roots: vec![],
112        }
113    }
114    pub fn with_expr(mut self, expr: &RecExpr<L>) -> Self {
115        let id = self.egraph.add_expr(expr.clone());
116        self.roots.push(id);
117        self
118    }
119    pub fn with_hook<F>(mut self, hook: F) -> Self
120    where
121        F: FnMut(&mut Self) -> Result<(), CustomErrorT> + 'static,
122    {
123        self.hooks.push(Box::new(hook));
124        self
125    }
126    pub fn with_egraph(mut self, egraph: EGraph<L, N>) -> Self {
127        // You should probably not use this if you use `with_expr` as well
128        self.egraph = egraph;
129        self
130    }
131    pub fn with_node_limit(mut self, node_limit: usize) -> Self {
132        self.limits.node_limit = node_limit;
133        self
134    }
135    pub fn with_iter_limit(mut self, iter_limit: usize) -> Self {
136        self.limits.iter_limit = iter_limit;
137        self
138    }
139    pub fn with_time_limit(mut self, time_limit: Duration) -> Self {
140        self.limits.time_limit = time_limit;
141        self
142    }
143
144    fn check_limits(&mut self) -> RunnerResult<(), CustomErrorT> {
145        self.limits
146            .check_limits(self.iterations.len(), &self.egraph)
147    }
148    pub fn run(&mut self, rewrites: &[Rewrite<L, N>]) -> Report<CustomErrorT> {
149        loop {
150            if let Some(_) = self.stop_reason {
151                break;
152            }
153            let iter = self.run_one(rewrites);
154            self.iterations.push(iter);
155        }
156        Report {
157            iterations: self.iterations.len(),
158            stop_reason: self.stop_reason.clone().unwrap(),
159            egraph_nodes: self.egraph.total_number_of_nodes(),
160            egraph_classes: self.egraph.classes.len(),
161            total_time: self
162                .iterations
163                .last()
164                .unwrap()
165                .finish_time
166                .unwrap()
167                .duration_since(self.limits.start_time.unwrap())
168                .as_secs_f64(),
169        }
170    }
171    fn run_one(&mut self, rewrites: &[Rewrite<L, N>]) -> Iteration<IterData> {
172        assert!(self.stop_reason.is_none());
173
174        // if the runner has not started, start the timer
175        self.limits.start_time.get_or_insert_with(Instant::now);
176        let mut hooks = std::mem::take(&mut self.hooks);
177
178        let mut result = Ok(());
179
180        // Apply rewrites, then check hooks, then check limits, then check if saturated.
181        let progress = apply_rewrites(&mut self.egraph, rewrites);
182        result = result
183            .and_then(|_| {
184                hooks
185                    .iter_mut()
186                    .try_for_each(|hook| hook(self).map_err(|err| StopReason::Other(err)))
187            })
188            .and_then(|_| self.check_limits());
189
190        if !progress {
191            result = result.and_then(|_| Err(StopReason::Saturated));
192        }
193
194        if let Err(stop_reason) = result {
195            self.stop_reason = Some(stop_reason);
196        }
197        self.hooks = hooks;
198        Iteration {
199            data: IterData::make(self),
200            num_nodes: self.egraph.total_number_of_nodes(),
201            finish_time: Some(Instant::now()),
202        }
203    }
204}
205
206impl<L, N, IterData, CustomErrorT> Default for Runner<L, N, IterData, CustomErrorT>
207where
208    L: Language,
209    N: Analysis<L> + Default,
210    IterData: IterationData<L, N>,
211    CustomErrorT: Clone,
212{
213    fn default() -> Self {
214        Runner::new(Default::default())
215    }
216}