use std::sync::Arc;
use std::sync::Mutex;
use core_relations::{ExecutionState, ExternalFunction, Value};
use egglog_bridge::{
ColumnTy, DefaultVal, FunctionConfig, FunctionId, MergeFn, RuleId, TableAction,
};
use egglog_reports::RunReport;
use numeric_id::define_id;
use crate::{ast::ResolvedVar, core::GenericAtomTerm, core::ResolvedCoreRule, util::IndexMap, *};
/// A user-supplied policy that decides which matches of a rule get applied on
/// each scheduled step (see `EGraph::step_rules_with_scheduler`).
///
/// Implementors must be cloneable through `dyn_clone` and be `Send + Sync`.
pub trait Scheduler: dyn_clone::DynClone + Send + Sync {
/// Reports whether the scheduler is willing to stop running `ruleset`.
///
/// `rules` holds the names of the rules that were just stepped. When this
/// returns `false`, `step_rules_with_scheduler` marks its report as updated
/// even if applying the chosen matches changed nothing.
///
/// The default implementation ignores its arguments and returns `true`.
fn can_stop(&mut self, rules: &[&str], ruleset: &str) -> bool {
let _ = (rules, ruleset);
true
}
/// Selects which of the pending `matches` of `rule` should be applied now.
///
/// Matches are selected with `Matches::choose` / `Matches::choose_all`;
/// matches that are not selected are handed back on the next step.
///
/// The return value is stored as the rule's "should seek" flag: when it is
/// `true`, the rule's query is run again on the next step to collect
/// further matches; when `false`, the query is skipped.
fn filter_matches(&mut self, rule: &str, ruleset: &str, matches: &mut Matches) -> bool;
}
// Makes `Box<dyn Scheduler>` implement `Clone`.
dyn_clone::clone_trait_object!(Scheduler);
/// The pending matches of one rule, presented to a `Scheduler` so it can pick
/// which of them to apply.
pub struct Matches {
// Flat buffer of matched values: each match occupies `vars.len()`
// consecutive entries.
matches: Vec<Value>,
// Indices of the matches selected through `choose` (may contain duplicates).
chosen: Vec<usize>,
// The variables naming the columns of every match, in column order.
vars: Vec<ResolvedVar>,
// Set by `choose_all`; when true every match is selected and `chosen` is ignored.
all_chosen: bool,
}
/// A borrowed view of a single match inside a `Matches` collection.
pub struct Match<'a> {
// The values of this match, one per variable.
values: &'a [Value],
// The variables the values are bound to; `values[i]` belongs to `vars[i]`.
vars: &'a [ResolvedVar],
}
impl Match<'_> {
    /// Returns the value bound to the variable named `var` in this match.
    ///
    /// # Panics
    ///
    /// Panics if the match has no variable named `var`; the panic message
    /// names the missing variable.
    pub fn get_value(&self, var: &str) -> Value {
        let idx = self
            .vars
            .iter()
            .position(|v| v.name == var)
            .unwrap_or_else(|| panic!("no variable named `{}` in this match", var));
        self.values[idx]
    }
}
impl Matches {
    /// Wraps a flat buffer of matched values whose columns are named by `vars`.
    ///
    /// # Panics
    ///
    /// Panics if `vars` is empty (the number of matches cannot be derived from
    /// zero-width tuples) or if `matches.len()` is not a multiple of
    /// `vars.len()`.
    fn new(matches: Vec<Value>, vars: Vec<ResolvedVar>) -> Self {
        let total_len = matches.len();
        let tuple_len = vars.len();
        // Checked explicitly: without this the modulo below (and `match_size`)
        // would fail with an opaque divide-by-zero panic.
        assert!(
            tuple_len > 0,
            "cannot build scheduler matches over an empty list of variables"
        );
        assert!(
            total_len % tuple_len == 0,
            "match buffer length {} is not a multiple of the tuple length {}",
            total_len,
            tuple_len
        );
        Self {
            matches,
            vars,
            chosen: Vec::new(),
            all_chosen: false,
        }
    }

    /// Returns the number of matches held by this collection.
    pub fn match_size(&self) -> usize {
        self.matches.len() / self.vars.len()
    }

    /// Returns the number of values (variables) in each match.
    pub fn tuple_len(&self) -> usize {
        self.vars.len()
    }

    /// Returns a view of the match at position `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not smaller than [`Matches::match_size`].
    pub fn get_match(&self, idx: usize) -> Match<'_> {
        let width = self.tuple_len();
        Match {
            values: &self.matches[idx * width..(idx + 1) * width],
            vars: &self.vars,
        }
    }

    /// Selects the match at position `idx` for application.
    ///
    /// The index is only validated when the selection is applied; an index
    /// that is out of range causes a panic at that point unless
    /// [`Matches::choose_all`] was also called.
    pub fn choose(&mut self, idx: usize) {
        self.chosen.push(idx);
    }

    /// Selects every match for application, overriding individual choices.
    pub fn choose_all(&mut self) {
        self.all_chosen = true;
    }

    /// Applies the selection and returns the matches that were not selected.
    ///
    /// Every selected match is inserted through `table_action` as a row made
    /// of the match's values followed by the unit value. The returned flat
    /// buffer holds the remaining matches (empty when all were selected);
    /// their relative order is not preserved.
    ///
    /// # Panics
    ///
    /// Panics if an individually chosen index is out of range.
    fn instantiate(
        mut self,
        state: &mut ExecutionState<'_>,
        mut table_action: TableAction,
    ) -> Vec<Value> {
        let tuple_len = self.tuple_len();
        let unit = state.base_values().get(());

        if self.all_chosen {
            // `new` guarantees the buffer is a whole number of rows.
            for row in self.matches.chunks_exact(tuple_len) {
                table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
            }
            return Vec::new();
        }

        let num_matches = self.match_size();
        for &idx in &self.chosen {
            assert!(
                idx < num_matches,
                "chosen match index {} is out of bounds for {} matches",
                idx,
                num_matches
            );
            let row = &self.matches[idx * tuple_len..(idx + 1) * tuple_len];
            table_action.insert(state, row.iter().cloned().chain(std::iter::once(unit)));
        }

        // Remove the chosen rows by swapping them to the tail of the buffer.
        // Processing the distinct indices from largest to smallest guarantees
        // that the tail slot being swapped in always holds an unchosen row and
        // that `c <= keep` at every step.
        self.chosen.sort_unstable();
        self.chosen.dedup();
        let mut keep = num_matches;
        for c in self.chosen.into_iter().rev() {
            keep -= 1;
            if c != keep {
                // `c < keep`, so row `c` lies entirely in `front` and row
                // `keep` is the first row of `back`.
                let (front, back) = self.matches.split_at_mut(keep * tuple_len);
                front[c * tuple_len..(c + 1) * tuple_len]
                    .swap_with_slice(&mut back[..tuple_len]);
            }
        }
        self.matches.truncate(keep * tuple_len);
        self.matches
    }
}
// Declares the `SchedulerId` identifier type (backed by `u32`) through
// `numeric_id::define_id!`; values of this type are handed out by
// `EGraph::add_scheduler`.
define_id!(
pub SchedulerId, u32,
"A unique identifier for a scheduler in the EGraph."
);
impl EGraph {
pub fn add_scheduler(&mut self, scheduler: Box<dyn Scheduler>) -> SchedulerId {
self.schedulers.push(SchedulerRecord {
scheduler,
rule_info: Default::default(),
})
}
pub fn remove_scheduler(&mut self, scheduler_id: SchedulerId) -> Option<Box<dyn Scheduler>> {
self.schedulers.take(scheduler_id).map(|r| r.scheduler)
}
pub fn step_rules_with_scheduler(
&mut self,
scheduler_id: SchedulerId,
ruleset: &str,
) -> Result<RunReport, Error> {
fn collect_rules<'a>(
ruleset: &str,
rulesets: &'a IndexMap<String, Ruleset>,
ids: &mut Vec<(String, &'a ResolvedCoreRule)>,
) {
match &rulesets[ruleset] {
Ruleset::Rules(rules) => {
for (rule_name, (core_rule, _)) in rules.iter() {
ids.push((rule_name.clone(), core_rule));
}
}
Ruleset::Combined(sub_rulesets) => {
for sub_ruleset in sub_rulesets {
collect_rules(sub_ruleset, rulesets, ids);
}
}
}
}
let mut rules = Vec::new();
let rulesets = std::mem::take(&mut self.rulesets);
collect_rules(ruleset, &rulesets, &mut rules);
let mut schedulers = std::mem::take(&mut self.schedulers);
let record = &mut schedulers[scheduler_id];
rules.iter().for_each(|(id, rule)| {
record
.rule_info
.entry((*id).to_owned())
.or_insert_with(|| SchedulerRuleInfo::new(self, rule, id));
});
let query_rules = rules
.iter()
.filter_map(|(rule_id, _rule)| {
let rule_info = record.rule_info.get(rule_id).unwrap();
if rule_info.should_seek {
Some(rule_info.query_rule)
} else {
None
}
})
.collect::<Vec<_>>();
let query_iter_report = self
.backend
.run_rules(&query_rules)
.map_err(|e| Error::BackendError(e.to_string()))?;
self.backend.with_execution_state(|state| {
for (rule_id, _rule) in rules.iter() {
let rule_info = record.rule_info.get_mut(rule_id).unwrap();
let matches: Vec<Value> =
std::mem::take(rule_info.matches.lock().unwrap().as_mut());
let mut matches = Matches::new(matches, rule_info.free_vars.clone());
rule_info.should_seek =
record
.scheduler
.filter_matches(rule_id, ruleset, &mut matches);
let table_action = TableAction::new(&self.backend, rule_info.decided);
*rule_info.matches.lock().unwrap() = matches.instantiate(state, table_action);
}
});
self.backend.flush_updates();
let action_rules = rules
.iter()
.map(|(rule_id, _rule)| {
let rule_info = record.rule_info.get(rule_id).unwrap();
rule_info.action_rule
})
.collect::<Vec<_>>();
let action_iter_report = self
.backend
.run_rules(&action_rules)
.map_err(|e| Error::BackendError(e.to_string()))?;
let mut query_report = RunReport::singleton(ruleset, query_iter_report);
let mut action_report = RunReport::singleton(ruleset, action_iter_report);
query_report.updated = false;
query_report.num_matches_per_rule.clear();
action_report.updated = action_report.updated || {
let rule_ids = rules.iter().map(|(id, _)| id.as_str()).collect::<Vec<_>>();
!record.scheduler.can_stop(&rule_ids, ruleset)
};
query_report.union(action_report);
self.rulesets = rulesets;
self.schedulers = schedulers;
Ok(query_report)
}
}
/// A registered scheduler together with the bookkeeping it needs for every
/// rule it has been asked to step.
#[derive(Clone)]
pub(crate) struct SchedulerRecord {
// The user-supplied scheduling policy.
scheduler: Box<dyn Scheduler>,
// Per-rule state, keyed by rule name; entries are created the first time a
// rule is stepped with this scheduler.
rule_info: HashMap<String, SchedulerRuleInfo>,
}
/// Per-rule state of a scheduler: the original rule is split into a query rule
/// that records matches and an action rule that applies the matches the
/// scheduler has chosen.
#[derive(Clone)]
struct SchedulerRuleInfo {
// Flat buffer of pending matches; the same `Arc` is handed to the
// `CollectMatches` external function called by `query_rule`.
matches: Arc<Mutex<Vec<Value>>>,
// Whether `query_rule` is run on the next step; updated from the return
// value of `Scheduler::filter_matches`.
should_seek: bool,
// Table holding the chosen matches (free variables plus a unit column).
decided: FunctionId,
// Rule that runs the original body and records each match.
query_rule: RuleId,
// Rule that reads `decided`, runs the original head and removes the row.
action_rule: RuleId,
// Free variables of the rule's head; they name the columns of a match.
free_vars: Vec<ResolvedVar>,
}
/// External function that appends the arguments of every invocation to a
/// shared buffer of matches.
struct CollectMatches {
// Buffer receiving the recorded values, protected by a mutex.
matches: Arc<Mutex<Vec<Value>>>,
}
impl Clone for CollectMatches {
    /// Produces a collector with its own buffer, initialised with a copy of
    /// the values currently recorded. The clone does not share the `Arc` of
    /// the original.
    ///
    /// NOTE(review): a cloned collector therefore no longer writes into the
    /// buffer held by the `SchedulerRuleInfo` that created the original;
    /// confirm this is intended.
    fn clone(&self) -> Self {
        let snapshot = self.matches.lock().unwrap().clone();
        Self {
            matches: Arc::new(Mutex::new(snapshot)),
        }
    }
}
impl CollectMatches {
fn new(matches: Arc<Mutex<Vec<Value>>>) -> Self {
Self { matches }
}
}
impl ExternalFunction for CollectMatches {
    /// Appends `args` to the shared buffer and returns the unit value.
    fn invoke(&self, state: &mut core_relations::ExecutionState, args: &[Value]) -> Option<Value> {
        {
            // Scope the guard so the lock is released before touching `state`.
            let mut buffer = self.matches.lock().unwrap();
            buffer.extend_from_slice(args);
        }
        let unit = state.base_values().get(());
        Some(unit)
    }
}
impl SchedulerRuleInfo {
/// Builds the scheduling state for `rule`, registering everything it needs
/// in `egraph.backend`:
///
/// * an external function that records matches into a shared buffer,
/// * a `decided` table holding the matches chosen by the scheduler,
/// * a query rule running the rule's body and recording each match,
/// * an action rule running the rule's head for every row of `decided` and
///   removing that row afterwards.
///
/// `name` is the name passed to the backend for both generated rules. The
/// returned info starts with `should_seek` set to `true`.
///
/// # Panics
///
/// Panics if querying the `decided` table or translating the rule's actions
/// fails (both results are unwrapped).
fn new(egraph: &mut EGraph, rule: &ResolvedCoreRule, name: &str) -> SchedulerRuleInfo {
// The free variables of the head are the values a match has to carry.
let free_vars = rule.head.get_free_vars().into_iter().collect::<Vec<_>>();
let unit_type = egraph.backend.base_values().get_ty::<()>();
let unit = egraph.backend.base_values().get(());
// Constant unit entry, appended to the `decided` row pattern below.
let unit_entry = egraph.backend.base_value_constant(());
// Buffer shared between the returned info and the external function.
let matches = Arc::new(Mutex::new(Vec::new()));
let collect_matches = egraph
.backend
.register_external_func(Box::new(CollectMatches::new(matches.clone())));
// Schema of `decided`: one column per free variable plus a unit column.
let schema = free_vars
.iter()
.map(|v| v.sort.column_ty(&egraph.backend))
.chain(std::iter::once(ColumnTy::Base(unit_type)))
.collect();
let decided = egraph.backend.add_table(FunctionConfig {
schema,
default: DefaultVal::Const(unit),
merge: MergeFn::AssertEq,
name: "backend".to_string(),
can_subsume: false,
});
// Query rule: the original body, followed by a call to `collect_matches`
// with the free variables as arguments.
// NOTE(review): the meaning of the boolean flags passed to `new_rule` and
// `query` is not visible here — confirm against the backend API.
let mut qrule_builder = BackendRule::new(
egraph.backend.new_rule(name, true),
&egraph.functions,
&egraph.type_info,
);
qrule_builder.query(&rule.body, true);
let entries = free_vars
.iter()
.map(|fv| qrule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
.collect::<Vec<_>>();
let _var = qrule_builder.rb.call_external_func(
collect_matches,
&entries,
ColumnTy::Base(unit_type),
|| "collect_matches".to_string(),
);
let qrule_id = qrule_builder.build();
// Action rule: match rows of `decided`, run the original head, then remove
// the consumed row.
let mut arule_builder = BackendRule::new(
egraph.backend.new_rule(name, false),
&egraph.functions,
&egraph.type_info,
);
let mut entries = free_vars
.iter()
.map(|fv| arule_builder.entry(&GenericAtomTerm::Var(span!(), fv.clone())))
.collect::<Vec<_>>();
// The queried row includes the trailing unit column...
entries.push(unit_entry);
arule_builder
.rb
.query_table(decided, &entries, None)
.unwrap();
arule_builder.actions(&rule.head).unwrap();
// ...while the removal is keyed by the free variables only.
entries.pop();
arule_builder.rb.remove(decided, &entries);
let arule_id = arule_builder.build();
SchedulerRuleInfo {
free_vars,
query_rule: qrule_id,
action_rule: arule_id,
matches,
decided,
should_seek: true,
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Scheduler that applies at most `n` matches of a rule per step and keeps
    /// querying only while fewer than `2 * n` matches are pending.
    #[derive(Clone)]
    struct FirstNScheduler {
        n: usize,
    }

    impl Scheduler for FirstNScheduler {
        fn filter_matches(&mut self, _rule: &str, _ruleset: &str, matches: &mut Matches) -> bool {
            let available = matches.match_size();
            if available > self.n {
                (0..self.n).for_each(|i| matches.choose(i));
            } else {
                matches.choose_all();
            }
            available < self.n * 2
        }
    }

    /// Steps a rule with 101 matches ten matches at a time and checks the
    /// table size and the report after every step.
    #[test]
    fn test_first_n_scheduler() {
        let mut egraph = EGraph::default();
        let scheduler_id = egraph.add_scheduler(Box::new(FirstNScheduler { n: 10 }));
        let input = r#"
(relation R (i64))
(R 0)
(rule ((R x) (< x 100)) ((R (+ x 1))))
(run-schedule (saturate (run)))
(ruleset test)
(relation S (i64))
(rule ((R x)) ((S x)) :ruleset test :name "test-rule")
"#;
        egraph.parse_and_run_program(None, input).unwrap();
        assert_eq!(egraph.get_size("R"), 101);

        let mut step = 0;
        loop {
            let report = egraph
                .step_rules_with_scheduler(scheduler_id, "test")
                .unwrap();
            let table_size = egraph.get_size("S");
            step += 1;

            // Ten rows are added per step until all 101 matches are applied.
            assert_eq!(table_size, std::cmp::min(step * 10, 101));
            let expected_matches = if step > 10 { 12 - step } else { 10 };
            assert_eq!(
                report.num_matches_per_rule.iter().collect::<Vec<_>>(),
                [(&"test-rule".into(), &expected_matches)]
            );

            // Timings must only mention the scheduled rule and its ruleset.
            assert!(
                report
                    .search_and_apply_time_per_rule
                    .keys()
                    .all(|k| k.starts_with("test-rule"))
            );
            assert_eq!(
                report.merge_time_per_ruleset.keys().collect::<Vec<_>>(),
                [&"test".into()]
            );
            assert_eq!(
                report
                    .search_and_apply_time_per_ruleset
                    .keys()
                    .collect::<Vec<_>>(),
                [&"test".into()]
            );

            if !report.updated {
                break;
            }
        }
        assert_eq!(step, 12);
    }
}