1use std::{borrow::Cow, fs::OpenOptions, io::Write, marker::PhantomData, path::Path};
2
3use crate::Recorder;
4
/// An [`egg::RewriteScheduler`] wrapper that forwards every call to `inner_scheduler` and,
/// when logging is enabled, writes the data produced by its recorders to `out_file`.
///
/// Each row is comma-separated in the column order written by [`write_headers`]
/// (`id,iteration,rule_name,rule,when,name,value`).
pub struct LoggingScheduler<'a, S, L, N> {
    // Value written in the `id` column of every row ("default" unless overridden).
    identifier: String,
    // Destination for log rows (stdout unless overridden).
    out_file: Box<dyn Write + 'a>,
    // Scheduler that performs the actual searching and rewriting.
    inner_scheduler: S,
    // When false, recorders are not consulted and no rows are written.
    logging: bool,
    // Ties the otherwise-unused language `L` and analysis `N` parameters to the type.
    phantom: PhantomData<(L, N)>,
    // Recorders queried, in insertion order, around each search and rewrite application.
    recorders: Vec<Box<dyn Recorder<L, N> + 'a>>,
}
13
/// Creates (or truncates) the file at `path` and writes the CSV header row whose columns
/// match the rows emitted by [`LoggingScheduler`].
///
/// # Panics
///
/// Panics if the file cannot be opened or the header line cannot be written.
pub fn write_headers(path: impl AsRef<Path>) {
    let mut csv = OpenOptions::new()
        .create(true)
        .write(true)
        .truncate(true)
        .open(path)
        .unwrap();
    writeln!(csv, "id,iteration,rule_name,rule,when,name,value").unwrap();
}
23
24impl<'a, S, L, N> LoggingScheduler<'a, S, L, N> {
25 pub fn new(scheduler: S) -> Self {
26 LoggingScheduler {
27 identifier: "default".to_string(),
28 out_file: Box::new(std::io::stdout()),
29 inner_scheduler: scheduler,
30 logging: false,
31 phantom: PhantomData,
32 recorders: vec![],
33 }
34 }
35
36 pub fn with_identifier(mut self, id: impl ToString) -> Self {
37 self.identifier = id.to_string();
38 self
39 }
40
41 pub fn with_out_file<W: Write + 'a>(mut self, out_file: W) -> Self {
42 self.out_file = Box::new(out_file);
43 self
44 }
45
46 pub fn with_logging_enabled(mut self, enabled: bool) -> Self {
47 self.logging = enabled;
48 self
49 }
50
51 pub fn with_recorder<D: Recorder<L, N> + 'a>(mut self, datum: D) -> Self
52 where
53 L: egg::Language,
54 N: egg::Analysis<L>,
55 {
56 self.recorders.push(Box::new(datum));
57 self
58 }
59
60 pub fn identifier(&mut self, id: impl ToString) -> &mut Self {
61 self.identifier = id.to_string();
62 self
63 }
64
65 pub fn out_file(&mut self, out_file: impl Write + 'a) -> &mut Self {
66 self.out_file = Box::new(out_file);
67 self
68 }
69
70 pub fn logging_enabled(&mut self, enabled: bool) -> &mut Self {
71 self.logging = enabled;
72 self
73 }
74
75 pub fn record<D: Recorder<L, N> + 'a>(&mut self, datum: D) -> &mut Self
76 where
77 L: egg::Language,
78 N: egg::Analysis<L>,
79 {
80 self.recorders.push(Box::new(datum));
81 self
82 }
83
84 fn write(
85 &mut self,
86 iteration: usize,
87 rule: &egg::Rewrite<L, N>,
88 typ: &str,
89 id: Cow<'static, str>,
90 datum: String,
91 ) where
92 L: egg::Language + std::fmt::Display,
93 N: egg::Analysis<L>,
94 {
95 writeln!(
96 &mut self.out_file,
97 "{},{},{},{},{},{},{}",
98 self.identifier,
99 iteration,
100 rule.name,
101 rewrite_str(rule),
102 typ,
103 id,
104 datum
105 )
106 .unwrap();
107 }
108}
109
110impl<'a, S, L, N> From<S> for LoggingScheduler<'a, S, L, N>
111where
112 S: egg::RewriteScheduler<L, N>,
113 L: egg::Language,
114 N: egg::Analysis<L>,
115{
116 fn from(value: S) -> Self {
117 LoggingScheduler::new(value)
118 }
119}
120
121fn rewrite_str<L, N>(rewrite: &egg::Rewrite<L, N>) -> String
122where
123 L: egg::Language + std::fmt::Display,
124 N: egg::Analysis<L>,
125{
126 if let (Some(searcher), Some(applier)) = (
127 rewrite.searcher.get_pattern_ast(),
128 rewrite.applier.get_pattern_ast(),
129 ) {
130 format!("{searcher} => {applier}")
131 } else {
132 format!("name_{}", rewrite.name)
133 }
134}
135
136impl<'a, S, L, N> egg::RewriteScheduler<L, N> for LoggingScheduler<'a, S, L, N>
137where
138 S: egg::RewriteScheduler<L, N>,
139 L: egg::Language + std::fmt::Display,
140 N: egg::Analysis<L>,
141{
142 fn can_stop(&mut self, iteration: usize) -> bool {
143 if !self.logging {
145 return self.inner_scheduler.can_stop(iteration);
146 }
147
148 self.inner_scheduler.can_stop(iteration)
149 }
150
151 fn search_rewrite<'s>(
152 &mut self,
153 iteration: usize,
154 egraph: &egg::EGraph<L, N>,
155 rewrite: &'s egg::Rewrite<L, N>,
156 ) -> Vec<egg::SearchMatches<'s, L>> {
157 if !self.logging {
159 return self
160 .inner_scheduler
161 .search_rewrite(iteration, egraph, rewrite);
162 }
163
164 self.recorders
165 .iter()
166 .map(|recorder| {
167 (
168 recorder.identifier(),
169 recorder.record_before_search(iteration, egraph, rewrite),
170 )
171 })
172 .collect::<Vec<_>>()
173 .into_iter()
174 .for_each(|(id, datum)| {
175 if let Some(datum) = datum {
176 self.write(iteration, rewrite, "before_search", id, datum);
177 }
178 });
179
180 let matches = self
181 .inner_scheduler
182 .search_rewrite(iteration, egraph, rewrite);
183
184 self.recorders
185 .iter()
186 .map(|recorder| {
187 (
188 recorder.identifier(),
189 recorder.record_after_search(iteration, egraph, rewrite, &matches),
190 )
191 })
192 .collect::<Vec<_>>()
193 .into_iter()
194 .for_each(|(id, datum)| {
195 if let Some(datum) = datum {
196 self.write(iteration, rewrite, "before_search", id, datum);
197 }
198 });
199
200 matches
201 }
202
203 fn apply_rewrite(
204 &mut self,
205 iteration: usize,
206 egraph: &mut egg::EGraph<L, N>,
207 rewrite: &egg::Rewrite<L, N>,
208 matches: Vec<egg::SearchMatches<L>>,
209 ) -> usize {
210 if !self.logging {
212 return self
213 .inner_scheduler
214 .apply_rewrite(iteration, egraph, rewrite, matches);
215 }
216
217 self.recorders
218 .iter()
219 .map(|recorder| {
220 (
221 recorder.identifier(),
222 recorder.record_before_rewrite(iteration, egraph, rewrite, &matches),
223 )
224 })
225 .collect::<Vec<_>>()
226 .into_iter()
227 .for_each(|(id, datum)| {
228 if let Some(datum) = datum {
229 self.write(iteration, rewrite, "before_rewrite", id, datum);
230 }
231 });
232
233 let n_matches = self
234 .inner_scheduler
235 .apply_rewrite(iteration, egraph, rewrite, matches);
236
237 self.recorders
238 .iter()
239 .map(|recorder| {
240 (
241 recorder.identifier(),
242 recorder.record_after_rewrite(iteration, egraph, rewrite, n_matches),
243 )
244 })
245 .collect::<Vec<_>>()
246 .into_iter()
247 .for_each(|(id, datum)| {
248 if let Some(datum) = datum {
249 self.write(iteration, rewrite, "after_rewrite", id, datum);
250 }
251 });
252
253 n_matches
254 }
255}