1use std::sync::Arc;
2use std::sync::Mutex;
3use std::time::Duration;
4
5use core_relations::{ExecutionState, ExternalFunction, Value};
6use egglog_bridge::{
7 ColumnTy, DefaultVal, FunctionConfig, FunctionId, MergeFn, RuleId, TableAction,
8};
9use numeric_id::define_id;
10
11use crate::{ast::ResolvedVar, core::GenericAtomTerm, core::ResolvedCoreRule, util::IndexMap, *};
12
13pub trait Scheduler: dyn_clone::DynClone + Send + Sync {
18 fn can_stop(&mut self, rules: &[&str], ruleset: &str) -> bool {
24 let _ = (rules, ruleset);
25 true
26 }
27
28 fn filter_matches(&mut self, rule: &str, ruleset: &str, matches: &mut Matches) -> bool;
33}
34
35dyn_clone::clone_trait_object!(Scheduler);
36
37pub struct Matches {
40 matches: Vec<Value>,
41 chosen: Vec<usize>,
42 vars: Vec<ResolvedVar>,
43 all_chosen: bool,
44}
45
46pub struct Match<'a> {
49 values: &'a [Value],
50 vars: &'a [ResolvedVar],
51}
52
53impl Match<'_> {
54 pub fn get_value(&self, var: &str) -> Value {
56 let idx = self.vars.iter().position(|v| v.name == var).unwrap();
57 self.values[idx]
58 }
59}
60
61impl Matches {
62 fn new(matches: Vec<Value>, vars: Vec<ResolvedVar>) -> Self {
63 let total_len = matches.len();
64 let tuple_len = vars.len();
65 assert!(total_len % tuple_len == 0);
66 Self {
67 matches,
68 vars,
69 chosen: Vec::new(),
70 all_chosen: false,
71 }
72 }
73
74 pub fn match_size(&self) -> usize {
76 self.matches.len() / self.vars.len()
77 }
78
79 pub fn tuple_len(&self) -> usize {
81 self.vars.len()
82 }
83
84 pub fn get_match(&self, idx: usize) -> Match<'_> {
86 Match {
87 values: &self.matches[idx * self.tuple_len()..(idx + 1) * self.tuple_len()],
88 vars: &self.vars,
89 }
90 }
91
92 pub fn choose(&mut self, idx: usize) {
94 self.chosen.push(idx);
95 }
96
97 pub fn choose_all(&mut self) {
101 self.all_chosen = true;
102 }
103
104 fn instantiate(
106 mut self,
107 state: &mut ExecutionState<'_>,
108 mut table_action: TableAction,
109 ) -> Vec<Value> {
110 let tuple_len = self.tuple_len();
111 let unit = state.base_values().get(());
112
113 if self.all_chosen {
114 for row in self.matches.chunks(tuple_len) {
115 table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
116 }
117 vec![]
118 } else {
119 for idx in self.chosen.iter() {
120 let row = &self.matches[idx * tuple_len..(idx + 1) * tuple_len];
121 table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
122 }
123
124 self.chosen.sort_unstable();
126 self.chosen.dedup();
127 let mut p = self.match_size();
128 for c in self.chosen.into_iter().rev() {
129 p -= 1;
132 if c != p {
133 let idx_c = c * tuple_len;
134 let idx_p = p * tuple_len;
135 for i in 0..tuple_len {
136 self.matches.swap(idx_c + i, idx_p + i);
137 }
138 }
139 }
140 self.matches.truncate(p * tuple_len);
141
142 self.matches
143 }
144 }
145}
146
147define_id!(
148 pub SchedulerId, u32,
149 "A unique identifier for a scheduler in the EGraph."
150);
151
152impl EGraph {
153 pub fn add_scheduler(&mut self, scheduler: Box<dyn Scheduler>) -> SchedulerId {
155 self.schedulers.push(SchedulerRecord {
156 scheduler,
157 rule_info: Default::default(),
158 })
159 }
160
161 pub fn remove_scheduler(&mut self, scheduler_id: SchedulerId) -> Option<Box<dyn Scheduler>> {
163 self.schedulers.take(scheduler_id).map(|r| r.scheduler)
164 }
165
166 pub fn step_rules_with_scheduler(
168 &mut self,
169 scheduler_id: SchedulerId,
170 ruleset: &str,
171 ) -> Result<RunReport, Error> {
172 fn collect_rules<'a>(
173 ruleset: &str,
174 rulesets: &'a IndexMap<String, Ruleset>,
175 ids: &mut Vec<(String, &'a ResolvedCoreRule)>,
176 ) {
177 match &rulesets[ruleset] {
178 Ruleset::Rules(rules) => {
179 for (rule_name, (core_rule, _)) in rules.iter() {
180 ids.push((rule_name.clone(), core_rule));
181 }
182 }
183 Ruleset::Combined(sub_rulesets) => {
184 for sub_ruleset in sub_rulesets {
185 collect_rules(sub_ruleset, rulesets, ids);
186 }
187 }
188 }
189 }
190
191 let mut rules = Vec::new();
192 let rulesets = std::mem::take(&mut self.rulesets);
193 collect_rules(ruleset, &rulesets, &mut rules);
194 let mut schedulers = std::mem::take(&mut self.schedulers);
195
196 let record = &mut schedulers[scheduler_id];
198 rules.iter().for_each(|(id, rule)| {
199 record
200 .rule_info
201 .entry((*id).to_owned())
202 .or_insert_with(|| SchedulerRuleInfo::new(self, rule, id));
203 });
204
205 let query_rules = rules
207 .iter()
208 .filter_map(|(rule_id, _rule)| {
209 let rule_info = record.rule_info.get(rule_id).unwrap();
210
211 if rule_info.should_seek {
212 Some(rule_info.query_rule)
213 } else {
214 None
215 }
216 })
217 .collect::<Vec<_>>();
218
219 let query_iter_report = self
220 .backend
221 .run_rules(&query_rules)
222 .map_err(|e| Error::BackendError(e.to_string()))?;
223
224 self.backend.with_execution_state(|state| {
226 for (rule_id, _rule) in rules.iter() {
227 let rule_info = record.rule_info.get_mut(rule_id).unwrap();
228
229 let matches: Vec<Value> =
230 std::mem::take(rule_info.matches.lock().unwrap().as_mut());
231 let mut matches = Matches::new(matches, rule_info.free_vars.clone());
232 rule_info.should_seek =
233 record
234 .scheduler
235 .filter_matches(rule_id, ruleset, &mut matches);
236 let table_action = TableAction::new(&self.backend, rule_info.decided);
237 *rule_info.matches.lock().unwrap() = matches.instantiate(state, table_action);
238 }
239 });
240 self.backend.flush_updates();
241
242 let action_rules = rules
244 .iter()
245 .map(|(rule_id, _rule)| {
246 let rule_info = record.rule_info.get(rule_id).unwrap();
247 rule_info.action_rule
248 })
249 .collect::<Vec<_>>();
250 let action_iter_report = self
251 .backend
252 .run_rules(&action_rules)
253 .map_err(|e| Error::BackendError(e.to_string()))?;
254
255 let per_ruleset = |x| [(ruleset.to_owned(), x)].into_iter().collect();
257 let mut report = RunReport::default();
258 report.updated = action_iter_report.changed || {
259 let rule_ids = rules.iter().map(|(id, _)| id.as_str()).collect::<Vec<_>>();
260 !record.scheduler.can_stop(&rule_ids, ruleset)
261 };
262
263 report.search_and_apply_time_per_ruleset = per_ruleset(
264 query_iter_report.search_and_apply_time + action_iter_report.search_and_apply_time,
265 );
266 report.merge_time_per_ruleset =
267 per_ruleset(query_iter_report.merge_time + action_iter_report.merge_time);
268 report.rebuild_time_per_ruleset =
269 per_ruleset(query_iter_report.rebuild_time + action_iter_report.rebuild_time);
270
271 report.search_and_apply_time_per_rule = {
272 let mut map = HashMap::default();
273 for (rule, report) in query_iter_report.rule_reports.iter() {
274 *map.entry(rule.as_str().into())
275 .or_insert_with(|| Duration::from_nanos(0)) += report.search_and_apply_time;
276 }
277 for (rule, report) in action_iter_report.rule_reports.iter() {
278 *map.entry(rule.as_str().into())
279 .or_insert_with(|| Duration::from_nanos(0)) += report.search_and_apply_time;
280 }
281 map
282 };
283 report.num_matches_per_rule = action_iter_report
284 .rule_reports
285 .iter()
286 .map(|(rule, report)| (rule.as_str().into(), report.num_matches))
287 .collect();
288
289 self.rulesets = rulesets;
290 self.schedulers = schedulers;
291
292 Ok(report)
293 }
294}
295
296#[derive(Clone)]
297pub(crate) struct SchedulerRecord {
298 scheduler: Box<dyn Scheduler>,
299 rule_info: HashMap<String, SchedulerRuleInfo>,
300}
301
302#[derive(Clone)]
307struct SchedulerRuleInfo {
308 matches: Arc<Mutex<Vec<Value>>>,
309 should_seek: bool,
310 decided: FunctionId,
311 query_rule: RuleId,
312 action_rule: RuleId,
313 free_vars: Vec<ResolvedVar>,
314}
315
316struct CollectMatches {
317 matches: Arc<Mutex<Vec<Value>>>,
318}
319
320impl Clone for CollectMatches {
321 fn clone(&self) -> Self {
322 Self {
323 matches: Arc::new(Mutex::new(self.matches.lock().unwrap().clone())),
324 }
325 }
326}
327
328impl CollectMatches {
329 fn new(matches: Arc<Mutex<Vec<Value>>>) -> Self {
330 Self { matches }
331 }
332}
333
334impl ExternalFunction for CollectMatches {
335 fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
336 self.matches.lock().unwrap().extend(args.iter().copied());
337 Some(state.base_values().get(()))
338 }
339}
340
341impl SchedulerRuleInfo {
342 fn new(egraph: &mut EGraph, rule: &ResolvedCoreRule, name: &str) -> SchedulerRuleInfo {
343 let free_vars = rule.head.get_free_vars().into_iter().collect::<Vec<_>>();
344 let unit_type = egraph.backend.base_values().get_ty::<()>();
345 let unit = egraph.backend.base_values().get(());
346 let unit_entry = egraph.backend.base_value_constant(());
347
348 let matches = Arc::new(Mutex::new(Vec::new()));
349 let collect_matches = egraph
350 .backend
351 .register_external_func(CollectMatches::new(matches.clone()));
352 let schema = free_vars
353 .iter()
354 .map(|v| v.sort.column_ty(&egraph.backend))
355 .chain(std::iter::once(ColumnTy::Base(unit_type)))
356 .collect();
357 let decided = egraph.backend.add_table(FunctionConfig {
358 schema,
359 default: DefaultVal::Const(unit),
360 merge: MergeFn::AssertEq,
361 name: "backend".to_string(),
362 can_subsume: false,
363 });
364
365 let mut qrule_builder = BackendRule::new(
367 egraph.backend.new_rule(name, true),
368 &egraph.functions,
369 &egraph.type_info,
370 );
371 qrule_builder.query(&rule.body, true);
372 let entries = free_vars
373 .iter()
374 .map(|fv| qrule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
375 .collect::<Vec<_>>();
376 let _var = qrule_builder.rb.call_external_func(
377 collect_matches,
378 &entries,
379 ColumnTy::Base(unit_type),
380 || "collect_matches".to_string(),
381 );
382 let qrule_id = qrule_builder.build();
383
384 let mut arule_builder = BackendRule::new(
386 egraph.backend.new_rule(name, false),
387 &egraph.functions,
388 &egraph.type_info,
389 );
390 let mut entries = free_vars
391 .iter()
392 .map(|fv| arule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
393 .collect::<Vec<_>>();
394 entries.push(unit_entry);
395 arule_builder
396 .rb
397 .query_table(decided, &entries, None)
398 .unwrap();
399 arule_builder.actions(&rule.head).unwrap();
400 entries.pop();
402 arule_builder.rb.remove(decided, &entries);
403 let arule_id = arule_builder.build();
404
405 SchedulerRuleInfo {
406 free_vars,
407 query_rule: qrule_id,
408 action_rule: arule_id,
409 matches,
410 decided,
411 should_seek: true,
412 }
413 }
414}
415
#[cfg(test)]
mod test {
    use super::*;

    /// Test scheduler that lets at most `n` matches of a rule fire per step.
    #[derive(Clone)]
    struct FirstNScheduler {
        n: usize,
    }

    impl Scheduler for FirstNScheduler {
        fn filter_matches(&mut self, _rule: &str, _ruleset: &str, matches: &mut Matches) -> bool {
            let pending = matches.match_size();
            if pending > self.n {
                // Too many to fire at once: take the first `n`, keep the rest pending.
                (0..self.n).for_each(|i| matches.choose(i));
            } else {
                matches.choose_all();
            }
            // Ask for a new search only once the backlog is below two batches.
            pending < self.n * 2
        }
    }

    #[test]
    fn test_first_n_scheduler() {
        let mut egraph = EGraph::default();
        let scheduler_id = egraph.add_scheduler(Box::new(FirstNScheduler { n: 10 }));
        let program = r#"
            (relation R (i64))
            (R 0)
            (rule ((R x) (< x 100)) ((R (+ x 1))))
            (run-schedule (saturate (run)))

            (ruleset test)
            (relation S (i64))
            (rule ((R x)) ((S x)) :ruleset test :name "test-rule")
        "#;
        egraph.parse_and_run_program(None, program).unwrap();
        assert_eq!(egraph.get_size("R"), 101);

        let mut step = 0;
        loop {
            let report = egraph
                .step_rules_with_scheduler(scheduler_id, "test")
                .unwrap();
            let s_size = egraph.get_size("S");
            step += 1;
            // Each step copies at most ten more facts from R into S.
            assert_eq!(s_size, std::cmp::min(step * 10, 101));

            // Ten matches fire per step until the last fact, then none.
            let expected_matches = if step <= 10 { 10 } else { 12 - step };
            assert_eq!(
                report.num_matches_per_rule.iter().collect::<Vec<_>>(),
                [(&"test-rule".to_owned(), &expected_matches)]
            );

            // Reported rule names all derive from the user-facing rule name.
            assert!(
                report
                    .search_and_apply_time_per_rule
                    .keys()
                    .all(|k| k.as_str().starts_with("test-rule"))
            );
            assert_eq!(
                report.merge_time_per_ruleset.keys().collect::<Vec<_>>(),
                ["test"]
            );
            assert_eq!(
                report
                    .search_and_apply_time_per_ruleset
                    .keys()
                    .collect::<Vec<_>>(),
                ["test"]
            );

            if !report.updated {
                break;
            }
        }

        assert_eq!(step, 12);
    }
}