egg_stats/
logging_scheduler.rs

1use std::{borrow::Cow, fs::OpenOptions, io::Write, marker::PhantomData, path::Path};
2
3use crate::Recorder;
4
5pub struct LoggingScheduler<'a, S, L, N> {
6    identifier: String,
7    out_file: Box<dyn Write + 'a>,
8    inner_scheduler: S,
9    logging: bool,
10    phantom: PhantomData<(L, N)>,
11    recorders: Vec<Box<dyn Recorder<L, N> + 'a>>,
12}
13
/// Creates (or truncates) the file at `path` and writes the CSV header line
/// `id,iteration,rule_name,rule,when,name,value` followed by a newline.
///
/// # Panics
///
/// Panics if the file cannot be opened or the header cannot be written.
pub fn write_headers(path: impl AsRef<Path>) {
    let mut options = OpenOptions::new();
    options.write(true).create(true).truncate(true);

    let mut header_file = options.open(path).unwrap();
    header_file
        .write_all(b"id,iteration,rule_name,rule,when,name,value\n")
        .unwrap()
}
23
24impl<'a, S, L, N> LoggingScheduler<'a, S, L, N> {
25    pub fn new(scheduler: S) -> Self {
26        LoggingScheduler {
27            identifier: "default".to_string(),
28            out_file: Box::new(std::io::stdout()),
29            inner_scheduler: scheduler,
30            logging: false,
31            phantom: PhantomData,
32            recorders: vec![],
33        }
34    }
35
36    pub fn with_identifier(mut self, id: impl ToString) -> Self {
37        self.identifier = id.to_string();
38        self
39    }
40
41    pub fn with_out_file<W: Write + 'a>(mut self, out_file: W) -> Self {
42        self.out_file = Box::new(out_file);
43        self
44    }
45
46    pub fn with_logging_enabled(mut self, enabled: bool) -> Self {
47        self.logging = enabled;
48        self
49    }
50
51    pub fn with_recorder<D: Recorder<L, N> + 'a>(mut self, datum: D) -> Self
52    where
53        L: egg::Language,
54        N: egg::Analysis<L>,
55    {
56        self.recorders.push(Box::new(datum));
57        self
58    }
59
60    pub fn identifier(&mut self, id: impl ToString) -> &mut Self {
61        self.identifier = id.to_string();
62        self
63    }
64
65    pub fn out_file(&mut self, out_file: impl Write + 'a) -> &mut Self {
66        self.out_file = Box::new(out_file);
67        self
68    }
69
70    pub fn logging_enabled(&mut self, enabled: bool) -> &mut Self {
71        self.logging = enabled;
72        self
73    }
74
75    pub fn record<D: Recorder<L, N> + 'a>(&mut self, datum: D) -> &mut Self
76    where
77        L: egg::Language,
78        N: egg::Analysis<L>,
79    {
80        self.recorders.push(Box::new(datum));
81        self
82    }
83
84    fn write(
85        &mut self,
86        iteration: usize,
87        rule: &egg::Rewrite<L, N>,
88        typ: &str,
89        id: Cow<'static, str>,
90        datum: String,
91    ) where
92        L: egg::Language + std::fmt::Display,
93        N: egg::Analysis<L>,
94    {
95        writeln!(
96            &mut self.out_file,
97            "{},{},{},{},{},{},{}",
98            self.identifier,
99            iteration,
100            rule.name,
101            rewrite_str(rule),
102            typ,
103            id,
104            datum
105        )
106        .unwrap();
107    }
108}
109
110impl<'a, S, L, N> From<S> for LoggingScheduler<'a, S, L, N>
111where
112    S: egg::RewriteScheduler<L, N>,
113    L: egg::Language,
114    N: egg::Analysis<L>,
115{
116    fn from(value: S) -> Self {
117        LoggingScheduler::new(value)
118    }
119}
120
121fn rewrite_str<L, N>(rewrite: &egg::Rewrite<L, N>) -> String
122where
123    L: egg::Language + std::fmt::Display,
124    N: egg::Analysis<L>,
125{
126    if let (Some(searcher), Some(applier)) = (
127        rewrite.searcher.get_pattern_ast(),
128        rewrite.applier.get_pattern_ast(),
129    ) {
130        format!("{searcher} => {applier}")
131    } else {
132        format!("name_{}", rewrite.name)
133    }
134}
135
136impl<'a, S, L, N> egg::RewriteScheduler<L, N> for LoggingScheduler<'a, S, L, N>
137where
138    S: egg::RewriteScheduler<L, N>,
139    L: egg::Language + std::fmt::Display,
140    N: egg::Analysis<L>,
141{
142    fn can_stop(&mut self, iteration: usize) -> bool {
143        // if disabled, just call underlying scheduler
144        if !self.logging {
145            return self.inner_scheduler.can_stop(iteration);
146        }
147
148        self.inner_scheduler.can_stop(iteration)
149    }
150
151    fn search_rewrite<'s>(
152        &mut self,
153        iteration: usize,
154        egraph: &egg::EGraph<L, N>,
155        rewrite: &'s egg::Rewrite<L, N>,
156    ) -> Vec<egg::SearchMatches<'s, L>> {
157        // if disabled, just call underlying scheduler
158        if !self.logging {
159            return self
160                .inner_scheduler
161                .search_rewrite(iteration, egraph, rewrite);
162        }
163
164        self.recorders
165            .iter()
166            .map(|recorder| {
167                (
168                    recorder.identifier(),
169                    recorder.record_before_search(iteration, egraph, rewrite),
170                )
171            })
172            .collect::<Vec<_>>()
173            .into_iter()
174            .for_each(|(id, datum)| {
175                if let Some(datum) = datum {
176                    self.write(iteration, rewrite, "before_search", id, datum);
177                }
178            });
179
180        let matches = self
181            .inner_scheduler
182            .search_rewrite(iteration, egraph, rewrite);
183
184        self.recorders
185            .iter()
186            .map(|recorder| {
187                (
188                    recorder.identifier(),
189                    recorder.record_after_search(iteration, egraph, rewrite, &matches),
190                )
191            })
192            .collect::<Vec<_>>()
193            .into_iter()
194            .for_each(|(id, datum)| {
195                if let Some(datum) = datum {
196                    self.write(iteration, rewrite, "before_search", id, datum);
197                }
198            });
199
200        matches
201    }
202
203    fn apply_rewrite(
204        &mut self,
205        iteration: usize,
206        egraph: &mut egg::EGraph<L, N>,
207        rewrite: &egg::Rewrite<L, N>,
208        matches: Vec<egg::SearchMatches<L>>,
209    ) -> usize {
210        // if disabled, just call underlying scheduler
211        if !self.logging {
212            return self
213                .inner_scheduler
214                .apply_rewrite(iteration, egraph, rewrite, matches);
215        }
216
217        self.recorders
218            .iter()
219            .map(|recorder| {
220                (
221                    recorder.identifier(),
222                    recorder.record_before_rewrite(iteration, egraph, rewrite, &matches),
223                )
224            })
225            .collect::<Vec<_>>()
226            .into_iter()
227            .for_each(|(id, datum)| {
228                if let Some(datum) = datum {
229                    self.write(iteration, rewrite, "before_rewrite", id, datum);
230                }
231            });
232
233        let n_matches = self
234            .inner_scheduler
235            .apply_rewrite(iteration, egraph, rewrite, matches);
236
237        self.recorders
238            .iter()
239            .map(|recorder| {
240                (
241                    recorder.identifier(),
242                    recorder.record_after_rewrite(iteration, egraph, rewrite, n_matches),
243                )
244            })
245            .collect::<Vec<_>>()
246            .into_iter()
247            .for_each(|(id, datum)| {
248                if let Some(datum) = datum {
249                    self.write(iteration, rewrite, "after_rewrite", id, datum);
250                }
251            });
252
253        n_matches
254    }
255}