1use crate::*;
2use std::time::Duration;
3use std::time::Instant;
4
/// Record of one iteration performed by a [`Runner`].
///
/// The runner appends one of these to [`Runner::iterations`] after each step.
pub struct Iteration<IterData> {
    /// User-defined payload for this iteration, produced by
    /// [`IterationData::make`] once the iteration's work is done.
    pub data: IterData,

    /// Value of `EGraph::total_number_of_nodes` when the iteration finished.
    pub num_nodes: usize,

    /// Instant at which the iteration finished, if recorded.
    pub finish_time: Option<Instant>,
}
12pub trait IterationData<L, N>: Sized
13where
14 L: Language,
15 N: Analysis<L>,
16{
17 fn make<CustomErrorT>(runner: &Runner<L, N, Self, CustomErrorT>) -> Self
20 where
21 CustomErrorT: Clone;
22}
23
24impl<L, N> IterationData<L, N> for ()
25where
26 L: Language,
27 N: Analysis<L>,
28{
29 fn make<CustomErrorT>(_: &Runner<L, N, Self, CustomErrorT>) -> Self
30 where
31 CustomErrorT: Clone,
32 {
33 }
34}
35
/// Stop conditions that a [`Runner`] checks after each iteration.
///
/// The fields are private; configure them through the runner's
/// `with_iter_limit`, `with_node_limit` and `with_time_limit` builders.
pub struct RunnerLimits {
    /// Iteration budget (`Runner::new` uses 30).
    iter_limit: usize,
    /// Upper bound on the e-graph's total node count (`Runner::new` uses 10_000).
    node_limit: usize,
    /// Filled in lazily when the runner starts its first iteration; elapsed
    /// time is measured from here.
    start_time: Option<Instant>,
    /// Wall-clock budget measured from `start_time` (`Runner::new` uses 60 s).
    time_limit: Duration,
}
42pub type RunnerResult<T, CustomErrorT = String> = Result<T, StopReason<CustomErrorT>>;
44
45impl RunnerLimits {
46 fn check_limits<L, N, CustomErrorT>(
47 &self,
48 iteration: usize,
49 eg: &EGraph<L, N>,
50 ) -> RunnerResult<(), CustomErrorT>
51 where
52 L: Language,
53 N: Analysis<L>,
54 CustomErrorT: Clone,
55 {
56 let elapsed = self.start_time.unwrap().elapsed();
57 if iteration > self.iter_limit {
58 Err(StopReason::IterationLimit)
59 } else if eg.total_number_of_nodes() > self.node_limit {
60 Err(StopReason::NodeLimit)
61 } else if elapsed > self.time_limit {
62 Err(StopReason::TimeLimit)
63 } else {
64 Ok(())
65 }
66 }
67}
68
69pub struct Runner<L: Language, N: Analysis<L> = (), IterData = (), CustomErrorT = String>
70where
71 IterData: IterationData<L, N>,
72 CustomErrorT: Clone,
73{
74 pub egraph: EGraph<L, N>,
76 pub iterations: Vec<Iteration<IterData>>,
78 pub roots: Vec<AppliedId>,
81 pub stop_reason: Option<StopReason<CustomErrorT>>,
84
85 pub limits: RunnerLimits,
87 pub hooks: Vec<Box<dyn FnMut(&mut Self) -> Result<(), CustomErrorT> + 'static>>,
89}
90
91impl<L, N, IterData, CustomErrorT> Runner<L, N, IterData, CustomErrorT>
92where
93 L: Language,
94 N: Analysis<L>,
95 IterData: IterationData<L, N>,
96 CustomErrorT: Clone,
97{
98 pub fn new(n: N) -> Self {
99 Self {
100 egraph: EGraph::new(n),
101 iterations: vec![],
102 stop_reason: None,
103 limits: RunnerLimits {
104 iter_limit: 30,
105 node_limit: 10_000,
106 time_limit: Duration::from_secs(60),
107 start_time: None,
109 },
110 hooks: vec![],
111 roots: vec![],
112 }
113 }
114 pub fn with_expr(mut self, expr: &RecExpr<L>) -> Self {
115 let id = self.egraph.add_expr(expr.clone());
116 self.roots.push(id);
117 self
118 }
119 pub fn with_hook<F>(mut self, hook: F) -> Self
120 where
121 F: FnMut(&mut Self) -> Result<(), CustomErrorT> + 'static,
122 {
123 self.hooks.push(Box::new(hook));
124 self
125 }
126 pub fn with_egraph(mut self, egraph: EGraph<L, N>) -> Self {
127 self.egraph = egraph;
129 self
130 }
131 pub fn with_node_limit(mut self, node_limit: usize) -> Self {
132 self.limits.node_limit = node_limit;
133 self
134 }
135 pub fn with_iter_limit(mut self, iter_limit: usize) -> Self {
136 self.limits.iter_limit = iter_limit;
137 self
138 }
139 pub fn with_time_limit(mut self, time_limit: Duration) -> Self {
140 self.limits.time_limit = time_limit;
141 self
142 }
143
144 fn check_limits(&mut self) -> RunnerResult<(), CustomErrorT> {
145 self.limits
146 .check_limits(self.iterations.len(), &self.egraph)
147 }
148 pub fn run(&mut self, rewrites: &[Rewrite<L, N>]) -> Report<CustomErrorT> {
149 loop {
150 if let Some(_) = self.stop_reason {
151 break;
152 }
153 let iter = self.run_one(rewrites);
154 self.iterations.push(iter);
155 }
156 Report {
157 iterations: self.iterations.len(),
158 stop_reason: self.stop_reason.clone().unwrap(),
159 egraph_nodes: self.egraph.total_number_of_nodes(),
160 egraph_classes: self.egraph.classes.len(),
161 total_time: self
162 .iterations
163 .last()
164 .unwrap()
165 .finish_time
166 .unwrap()
167 .duration_since(self.limits.start_time.unwrap())
168 .as_secs_f64(),
169 }
170 }
171 fn run_one(&mut self, rewrites: &[Rewrite<L, N>]) -> Iteration<IterData> {
172 assert!(self.stop_reason.is_none());
173
174 self.limits.start_time.get_or_insert_with(Instant::now);
176 let mut hooks = std::mem::take(&mut self.hooks);
177
178 let mut result = Ok(());
179
180 let progress = apply_rewrites(&mut self.egraph, rewrites);
182 result = result
183 .and_then(|_| {
184 hooks
185 .iter_mut()
186 .try_for_each(|hook| hook(self).map_err(|err| StopReason::Other(err)))
187 })
188 .and_then(|_| self.check_limits());
189
190 if !progress {
191 result = result.and_then(|_| Err(StopReason::Saturated));
192 }
193
194 if let Err(stop_reason) = result {
195 self.stop_reason = Some(stop_reason);
196 }
197 self.hooks = hooks;
198 Iteration {
199 data: IterData::make(self),
200 num_nodes: self.egraph.total_number_of_nodes(),
201 finish_time: Some(Instant::now()),
202 }
203 }
204}
205
206impl<L, N, IterData, CustomErrorT> Default for Runner<L, N, IterData, CustomErrorT>
207where
208 L: Language,
209 N: Analysis<L> + Default,
210 IterData: IterationData<L, N>,
211 CustomErrorT: Clone,
212{
213 fn default() -> Self {
214 Runner::new(Default::default())
215 }
216}