egglog/
scheduler.rs

1use std::sync::Arc;
2use std::sync::Mutex;
3use std::time::Duration;
4
5use core_relations::{ExecutionState, ExternalFunction, Value};
6use egglog_bridge::{
7    ColumnTy, DefaultVal, FunctionConfig, FunctionId, MergeFn, RuleId, TableAction,
8};
9use numeric_id::define_id;
10
11use crate::{ast::ResolvedVar, core::GenericAtomTerm, core::ResolvedCoreRule, util::IndexMap, *};
12
13/// A scheduler decides which matches to be applied for a rule.
14///
15/// The matches that are not chosen in this iteration will be delayed
16/// to the next iteration.
17pub trait Scheduler: dyn_clone::DynClone + Send + Sync {
18    /// Whether or not the rules can be considered as saturated (i.e.,
19    /// `run_report.updated == false`).
20    ///
21    /// This is only called when the runner is otherwise saturated.
22    /// Default implementation just returns `true`.
23    fn can_stop(&mut self, rules: &[&str], ruleset: &str) -> bool {
24        let _ = (rules, ruleset);
25        true
26    }
27
28    /// Filter the matches for a rule.
29    ///
30    /// Return `true` if the scheduler's next run of the rule should feed
31    /// `filter_matches` with a new iteration of matches.
32    fn filter_matches(&mut self, rule: &str, ruleset: &str, matches: &mut Matches) -> bool;
33}
34
35dyn_clone::clone_trait_object!(Scheduler);
36
37/// A collection of matches produced by a rule.
38/// The user can choose which matches to be fired.
39pub struct Matches {
40    matches: Vec<Value>,
41    chosen: Vec<usize>,
42    vars: Vec<ResolvedVar>,
43    all_chosen: bool,
44}
45
46/// A match is a tuple of values corresponding to the variables in a rule.
47/// It allows you to retrieve the value corresponding to a variable in the match.
48pub struct Match<'a> {
49    values: &'a [Value],
50    vars: &'a [ResolvedVar],
51}
52
53impl Match<'_> {
54    /// Get the value corresponding a variable in this match.
55    pub fn get_value(&self, var: &str) -> Value {
56        let idx = self.vars.iter().position(|v| v.name == var).unwrap();
57        self.values[idx]
58    }
59}
60
61impl Matches {
62    fn new(matches: Vec<Value>, vars: Vec<ResolvedVar>) -> Self {
63        let total_len = matches.len();
64        let tuple_len = vars.len();
65        assert!(total_len % tuple_len == 0);
66        Self {
67            matches,
68            vars,
69            chosen: Vec::new(),
70            all_chosen: false,
71        }
72    }
73
74    /// The number of matches in total.
75    pub fn match_size(&self) -> usize {
76        self.matches.len() / self.vars.len()
77    }
78
79    /// The length of a tuple.
80    pub fn tuple_len(&self) -> usize {
81        self.vars.len()
82    }
83
84    /// Get `idx`-th match.
85    pub fn get_match(&self, idx: usize) -> Match<'_> {
86        Match {
87            values: &self.matches[idx * self.tuple_len()..(idx + 1) * self.tuple_len()],
88            vars: &self.vars,
89        }
90    }
91
92    /// Pick the match at `idx` to be fired.
93    pub fn choose(&mut self, idx: usize) {
94        self.chosen.push(idx);
95    }
96
97    /// Pick all matches to be fired.
98    ///
99    /// This is more efficient than calling `choose` for each match.
100    pub fn choose_all(&mut self) {
101        self.all_chosen = true;
102    }
103
104    /// Apply the chosen matches and return the residual matches.
105    fn instantiate(
106        mut self,
107        state: &mut ExecutionState<'_>,
108        mut table_action: TableAction,
109    ) -> Vec<Value> {
110        let tuple_len = self.tuple_len();
111        let unit = state.base_values().get(());
112
113        if self.all_chosen {
114            for row in self.matches.chunks(tuple_len) {
115                table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
116            }
117            vec![]
118        } else {
119            for idx in self.chosen.iter() {
120                let row = &self.matches[idx * tuple_len..(idx + 1) * tuple_len];
121                table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
122            }
123
124            // swap remove the chosen matches
125            self.chosen.sort_unstable();
126            self.chosen.dedup();
127            let mut p = self.match_size();
128            for c in self.chosen.into_iter().rev() {
129                // It's important to decrement `p` first, because otherwise it might underflow when
130                // matches are exhausted.
131                p -= 1;
132                if c != p {
133                    let idx_c = c * tuple_len;
134                    let idx_p = p * tuple_len;
135                    for i in 0..tuple_len {
136                        self.matches.swap(idx_c + i, idx_p + i);
137                    }
138                }
139            }
140            self.matches.truncate(p * tuple_len);
141
142            self.matches
143        }
144    }
145}
146
147define_id!(
148    pub SchedulerId, u32,
149    "A unique identifier for a scheduler in the EGraph."
150);
151
152impl EGraph {
153    /// Register a new scheduler and return its id.
154    pub fn add_scheduler(&mut self, scheduler: Box<dyn Scheduler>) -> SchedulerId {
155        self.schedulers.push(SchedulerRecord {
156            scheduler,
157            rule_info: Default::default(),
158        })
159    }
160
161    /// Removes a scheduler
162    pub fn remove_scheduler(&mut self, scheduler_id: SchedulerId) -> Option<Box<dyn Scheduler>> {
163        self.schedulers.take(scheduler_id).map(|r| r.scheduler)
164    }
165
166    /// Runs a ruleset for one iteration using the given ruleset
167    pub fn step_rules_with_scheduler(
168        &mut self,
169        scheduler_id: SchedulerId,
170        ruleset: &str,
171    ) -> Result<RunReport, Error> {
172        fn collect_rules<'a>(
173            ruleset: &str,
174            rulesets: &'a IndexMap<String, Ruleset>,
175            ids: &mut Vec<(String, &'a ResolvedCoreRule)>,
176        ) {
177            match &rulesets[ruleset] {
178                Ruleset::Rules(rules) => {
179                    for (rule_name, (core_rule, _)) in rules.iter() {
180                        ids.push((rule_name.clone(), core_rule));
181                    }
182                }
183                Ruleset::Combined(sub_rulesets) => {
184                    for sub_ruleset in sub_rulesets {
185                        collect_rules(sub_ruleset, rulesets, ids);
186                    }
187                }
188            }
189        }
190
191        let mut rules = Vec::new();
192        let rulesets = std::mem::take(&mut self.rulesets);
193        collect_rules(ruleset, &rulesets, &mut rules);
194        let mut schedulers = std::mem::take(&mut self.schedulers);
195
196        // Step 1: build all the query/action rules and worklist if have not already
197        let record = &mut schedulers[scheduler_id];
198        rules.iter().for_each(|(id, rule)| {
199            record
200                .rule_info
201                .entry((*id).to_owned())
202                .or_insert_with(|| SchedulerRuleInfo::new(self, rule, id));
203        });
204
205        // Step 2: run all the queries for one iteration
206        let query_rules = rules
207            .iter()
208            .filter_map(|(rule_id, _rule)| {
209                let rule_info = record.rule_info.get(rule_id).unwrap();
210
211                if rule_info.should_seek {
212                    Some(rule_info.query_rule)
213                } else {
214                    None
215                }
216            })
217            .collect::<Vec<_>>();
218
219        let query_iter_report = self
220            .backend
221            .run_rules(&query_rules)
222            .map_err(|e| Error::BackendError(e.to_string()))?;
223
224        // Step 3: let the scheduler decide which matches need to be kept
225        self.backend.with_execution_state(|state| {
226            for (rule_id, _rule) in rules.iter() {
227                let rule_info = record.rule_info.get_mut(rule_id).unwrap();
228
229                let matches: Vec<Value> =
230                    std::mem::take(rule_info.matches.lock().unwrap().as_mut());
231                let mut matches = Matches::new(matches, rule_info.free_vars.clone());
232                rule_info.should_seek =
233                    record
234                        .scheduler
235                        .filter_matches(rule_id, ruleset, &mut matches);
236                let table_action = TableAction::new(&self.backend, rule_info.decided);
237                *rule_info.matches.lock().unwrap() = matches.instantiate(state, table_action);
238            }
239        });
240        self.backend.flush_updates();
241
242        // Step 4: run the action rules
243        let action_rules = rules
244            .iter()
245            .map(|(rule_id, _rule)| {
246                let rule_info = record.rule_info.get(rule_id).unwrap();
247                rule_info.action_rule
248            })
249            .collect::<Vec<_>>();
250        let action_iter_report = self
251            .backend
252            .run_rules(&action_rules)
253            .map_err(|e| Error::BackendError(e.to_string()))?;
254
255        // Step 5: combine the reports
256        let per_ruleset = |x| [(ruleset.to_owned(), x)].into_iter().collect();
257        let mut report = RunReport::default();
258        report.updated = action_iter_report.changed || {
259            let rule_ids = rules.iter().map(|(id, _)| id.as_str()).collect::<Vec<_>>();
260            !record.scheduler.can_stop(&rule_ids, ruleset)
261        };
262
263        report.search_and_apply_time_per_ruleset = per_ruleset(
264            query_iter_report.search_and_apply_time + action_iter_report.search_and_apply_time,
265        );
266        report.merge_time_per_ruleset =
267            per_ruleset(query_iter_report.merge_time + action_iter_report.merge_time);
268        report.rebuild_time_per_ruleset =
269            per_ruleset(query_iter_report.rebuild_time + action_iter_report.rebuild_time);
270
271        report.search_and_apply_time_per_rule = {
272            let mut map = HashMap::default();
273            for (rule, report) in query_iter_report.rule_reports.iter() {
274                *map.entry(rule.as_str().into())
275                    .or_insert_with(|| Duration::from_nanos(0)) += report.search_and_apply_time;
276            }
277            for (rule, report) in action_iter_report.rule_reports.iter() {
278                *map.entry(rule.as_str().into())
279                    .or_insert_with(|| Duration::from_nanos(0)) += report.search_and_apply_time;
280            }
281            map
282        };
283        report.num_matches_per_rule = action_iter_report
284            .rule_reports
285            .iter()
286            .map(|(rule, report)| (rule.as_str().into(), report.num_matches))
287            .collect();
288
289        self.rulesets = rulesets;
290        self.schedulers = schedulers;
291
292        Ok(report)
293    }
294}
295
296#[derive(Clone)]
297pub(crate) struct SchedulerRecord {
298    scheduler: Box<dyn Scheduler>,
299    rule_info: HashMap<String, SchedulerRuleInfo>,
300}
301
302/// To enable scheduling without modifying the backend,
303/// we split a rule (rule query action) into a worklist relation
304/// two rules (rule query (worklist vars false)) and
305/// (rule (worklist vars false) (action ... (delete (worklist vars false))))
306#[derive(Clone)]
307struct SchedulerRuleInfo {
308    matches: Arc<Mutex<Vec<Value>>>,
309    should_seek: bool,
310    decided: FunctionId,
311    query_rule: RuleId,
312    action_rule: RuleId,
313    free_vars: Vec<ResolvedVar>,
314}
315
316struct CollectMatches {
317    matches: Arc<Mutex<Vec<Value>>>,
318}
319
320impl Clone for CollectMatches {
321    fn clone(&self) -> Self {
322        Self {
323            matches: Arc::new(Mutex::new(self.matches.lock().unwrap().clone())),
324        }
325    }
326}
327
328impl CollectMatches {
329    fn new(matches: Arc<Mutex<Vec<Value>>>) -> Self {
330        Self { matches }
331    }
332}
333
334impl ExternalFunction for CollectMatches {
335    fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
336        self.matches.lock().unwrap().extend(args.iter().copied());
337        Some(state.base_values().get(()))
338    }
339}
340
341impl SchedulerRuleInfo {
342    fn new(egraph: &mut EGraph, rule: &ResolvedCoreRule, name: &str) -> SchedulerRuleInfo {
343        let free_vars = rule.head.get_free_vars().into_iter().collect::<Vec<_>>();
344        let unit_type = egraph.backend.base_values().get_ty::<()>();
345        let unit = egraph.backend.base_values().get(());
346        let unit_entry = egraph.backend.base_value_constant(());
347
348        let matches = Arc::new(Mutex::new(Vec::new()));
349        let collect_matches = egraph
350            .backend
351            .register_external_func(CollectMatches::new(matches.clone()));
352        let schema = free_vars
353            .iter()
354            .map(|v| v.sort.column_ty(&egraph.backend))
355            .chain(std::iter::once(ColumnTy::Base(unit_type)))
356            .collect();
357        let decided = egraph.backend.add_table(FunctionConfig {
358            schema,
359            default: DefaultVal::Const(unit),
360            merge: MergeFn::AssertEq,
361            name: "backend".to_string(),
362            can_subsume: false,
363        });
364
365        // Step 1: build the query rule
366        let mut qrule_builder = BackendRule::new(
367            egraph.backend.new_rule(name, true),
368            &egraph.functions,
369            &egraph.type_info,
370        );
371        qrule_builder.query(&rule.body, true);
372        let entries = free_vars
373            .iter()
374            .map(|fv| qrule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
375            .collect::<Vec<_>>();
376        let _var = qrule_builder.rb.call_external_func(
377            collect_matches,
378            &entries,
379            ColumnTy::Base(unit_type),
380            || "collect_matches".to_string(),
381        );
382        let qrule_id = qrule_builder.build();
383
384        // Step 2: build the action rule
385        let mut arule_builder = BackendRule::new(
386            egraph.backend.new_rule(name, false),
387            &egraph.functions,
388            &egraph.type_info,
389        );
390        let mut entries = free_vars
391            .iter()
392            .map(|fv| arule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
393            .collect::<Vec<_>>();
394        entries.push(unit_entry);
395        arule_builder
396            .rb
397            .query_table(decided, &entries, None)
398            .unwrap();
399        arule_builder.actions(&rule.head).unwrap();
400        // Remove the entry as it's now done
401        entries.pop();
402        arule_builder.rb.remove(decided, &entries);
403        let arule_id = arule_builder.build();
404
405        SchedulerRuleInfo {
406            free_vars,
407            query_rule: qrule_id,
408            action_rule: arule_id,
409            matches,
410            decided,
411            should_seek: true,
412        }
413    }
414}
415
#[cfg(test)]
mod test {
    use super::*;

    /// Fires at most `n` matches per call, and asks for fresh matches only while fewer than
    /// `2 * n` are pending.
    #[derive(Clone)]
    struct FirstNScheduler {
        n: usize,
    }

    impl Scheduler for FirstNScheduler {
        fn filter_matches(&mut self, _rule: &str, _ruleset: &str, matches: &mut Matches) -> bool {
            let pending = matches.match_size();
            if pending > self.n {
                (0..self.n).for_each(|i| matches.choose(i));
            } else {
                matches.choose_all();
            }
            pending < 2 * self.n
        }
    }

    #[test]
    fn test_first_n_scheduler() {
        let mut egraph = EGraph::default();
        let scheduler_id = egraph.add_scheduler(Box::new(FirstNScheduler { n: 10 }));
        let input = r#"
        (relation R (i64))
        (R 0)
        (rule ((R x) (< x 100)) ((R (+ x 1))))
        (run-schedule (saturate (run)))

        (ruleset test)
        (relation S (i64))
        (rule ((R x)) ((S x)) :ruleset test :name "test-rule")
        "#;
        egraph.parse_and_run_program(None, input).unwrap();
        assert_eq!(egraph.get_size("R"), 101);

        let mut iterations = 0;
        loop {
            let report = egraph
                .step_rules_with_scheduler(scheduler_id, "test")
                .unwrap();
            iterations += 1;
            assert_eq!(egraph.get_size("S"), std::cmp::min(iterations * 10, 101));

            // 10 matches fire per round until the 101st leftover fires in round 11.
            let expected_matches = if iterations > 10 { 12 - iterations } else { 10 };
            assert_eq!(
                report.num_matches_per_rule.iter().collect::<Vec<_>>(),
                [(&"test-rule".to_owned(), &expected_matches)]
            );

            // Because of semi-naive evaluation, more backend rules than `test-rule` itself run.
            let only_test_rules = report
                .search_and_apply_time_per_rule
                .keys()
                .all(|k| k.as_str().starts_with("test-rule"));
            assert!(only_test_rules);
            assert_eq!(
                report.merge_time_per_ruleset.keys().collect::<Vec<_>>(),
                ["test"]
            );
            assert_eq!(
                report
                    .search_and_apply_time_per_ruleset
                    .keys()
                    .collect::<Vec<_>>(),
                ["test"]
            );

            if !report.updated {
                break;
            }
        }

        assert_eq!(iterations, 12);
    }
}