use std::fmt::{self, Debug, Formatter};
use log::*;
use crate::*;
/// Drives equality saturation: repeatedly searches for and applies rewrites on
/// an [`EGraph`] until a [`StopReason`] is reached.
///
/// Configure it with the `with_*` builder methods, then call [`Runner::run`].
pub struct Runner<L: Language, N: Analysis<L>, IterData = ()> {
/// The e-graph being rewritten.
pub egraph: EGraph<L, N>,
/// One entry per iteration performed by [`Runner::run`].
pub iterations: Vec<Iteration<IterData>>,
/// E-class ids of the expressions added through [`Runner::with_expr`].
pub roots: Vec<Id>,
/// Why the runner stopped; `None` until [`Runner::run`] has finished.
pub stop_reason: Option<StopReason>,
/// Callbacks run at the start of every iteration, before searching.
/// Returning `Err(msg)` stops the runner with [`StopReason::Other`].
#[allow(clippy::type_complexity)]
pub hooks: Vec<Box<dyn FnMut(&mut Self) -> Result<(), String>>>,
// Stop with `IterationLimit` once this many iterations have been recorded.
iter_limit: usize,
// Stop with `NodeLimit` once `EGraph::total_size` exceeds this.
node_limit: usize,
// Stop with `TimeLimit` once this much time has elapsed since `start_time`.
time_limit: Duration,
// Set when the first iteration starts; `None` before that.
start_time: Option<Instant>,
// Decides how each rewrite is searched and applied every iteration.
scheduler: Box<dyn RewriteScheduler<L, N>>,
}
impl<L, N> Default for Runner<L, N, ()>
where
    L: Language,
    N: Analysis<L> + Default,
{
    /// Builds a runner (with `Runner::new`'s default limits and scheduler)
    /// around a default-constructed analysis.
    fn default() -> Self {
        let analysis: N = Default::default();
        Self::new(analysis)
    }
}
impl<L, N, IterData> Debug for Runner<L, N, IterData>
where
    L: Language,
    N: Analysis<L>,
    IterData: Debug,
{
    /// Formats every field of the runner. The hooks and the scheduler are
    /// trait objects without a `Debug` bound, so placeholders are printed
    /// in their place (one placeholder per hook).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to `Runner` without
        // updating this impl becomes a compile error.
        let Self {
            egraph: graph,
            iterations: iters,
            roots: root_ids,
            stop_reason: reason,
            hooks: hook_list,
            iter_limit: max_iters,
            node_limit: max_nodes,
            time_limit: max_time,
            start_time: started,
            scheduler: _,
        } = self;

        let mut out = f.debug_struct("Runner");
        out.field("egraph", graph);
        out.field("iterations", iters);
        out.field("roots", root_ids);
        out.field("stop_reason", reason);
        out.field("hooks", &vec![format_args!("<dyn FnMut ..>"); hook_list.len()]);
        out.field("iter_limit", max_iters);
        out.field("node_limit", max_nodes);
        out.field("time_limit", max_time);
        out.field("start_time", started);
        out.field("scheduler", &format_args!("<dyn RewriteScheduler ..>"));
        out.finish()
    }
}
/// Why a `Runner` stopped.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize))]
pub enum StopReason {
    /// An iteration applied no rewrites, the e-graph size did not change,
    /// and the scheduler allowed stopping.
    Saturated,
    /// The iteration limit was reached; holds the number of iterations.
    IterationLimit(usize),
    /// The node limit was exceeded; holds the e-graph size at that point.
    NodeLimit(usize),
    /// The time limit was exceeded; holds the elapsed time in seconds.
    TimeLimit(f64),
    /// A hook returned an error; holds its message.
    Other(String),
}

/// Summary of a finished run, as produced by `Runner::report`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize))]
#[non_exhaustive]
#[allow(missing_docs)]
pub struct Report {
    /// Number of iterations performed.
    pub iterations: usize,
    /// Why the run stopped.
    pub stop_reason: StopReason,
    /// Total number of e-nodes in the e-graph at report time.
    pub egraph_nodes: usize,
    /// Number of e-classes in the e-graph at report time.
    pub egraph_classes: usize,
    /// `EGraph::total_size` at report time.
    pub memo_size: usize,
    /// Sum of the per-iteration rebuild counts.
    pub rebuilds: usize,
    /// Sum of the per-iteration total times, in seconds.
    pub total_time: f64,
    /// Sum of the per-iteration search times, in seconds.
    pub search_time: f64,
    /// Sum of the per-iteration apply times, in seconds.
    pub apply_time: f64,
    /// Sum of the per-iteration rebuild times, in seconds.
    pub rebuild_time: f64,
}

impl std::fmt::Display for Report {
    /// Writes a multi-line, human-readable summary.
    ///
    /// Each phase is shown with its share of `total_time`. When `total_time`
    /// is zero (e.g. a coarse clock), the share is printed as `0.00` instead
    /// of `NaN`.
    #[rustfmt::skip]
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Guard the division: 0.0 / 0.0 would otherwise print as NaN.
        let share = |part: f64| if self.total_time > 0.0 { part / self.total_time } else { 0.0 };
        writeln!(f, "Runner report")?;
        writeln!(f, "=============")?;
        writeln!(f, " Stop reason: {:?}", self.stop_reason)?;
        writeln!(f, " Iterations: {}", self.iterations)?;
        writeln!(f, " Egraph size: {} nodes, {} classes, {} memo", self.egraph_nodes, self.egraph_classes, self.memo_size)?;
        writeln!(f, " Rebuilds: {}", self.rebuilds)?;
        writeln!(f, " Total time: {}", self.total_time)?;
        writeln!(f, " Search: ({:.2}) {}", share(self.search_time), self.search_time)?;
        writeln!(f, " Apply: ({:.2}) {}", share(self.apply_time), self.apply_time)?;
        writeln!(f, " Rebuild: ({:.2}) {}", share(self.rebuild_time), self.rebuild_time)?;
        Ok(())
    }
}
/// Statistics recorded for a single iteration of `Runner::run`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize))]
#[non_exhaustive]
pub struct Iteration<IterData> {
/// `EGraph::total_size` at the start of the iteration, before hooks ran.
pub egraph_nodes: usize,
/// Number of e-classes at the start of the iteration, before hooks ran.
pub egraph_classes: usize,
/// Rule name -> number of applications; rules that applied nothing are absent.
pub applied: IndexMap<Symbol, usize>,
/// Seconds spent running hooks.
pub hook_time: f64,
/// Seconds spent searching for matches.
pub search_time: f64,
/// Seconds spent applying matches.
pub apply_time: f64,
/// Seconds spent rebuilding the e-graph (plus the explanation check in debug
/// builds when explanations are enabled).
pub rebuild_time: f64,
/// Seconds from the start of the search phase to the end of the iteration;
/// hook time is not included.
pub total_time: f64,
/// User data produced by `IterationData::make` at the end of the iteration.
pub data: IterData,
/// Value returned by `EGraph::rebuild` during this iteration.
pub n_rebuilds: usize,
/// Set if this iteration hit a stopping condition.
pub stop_reason: Option<StopReason>,
}
/// Internal result type: `Err` carries the reason the runner should stop.
type RunnerResult<T> = std::result::Result<T, StopReason>;
impl<L, N, IterData> Runner<L, N, IterData>
where
    L: Language,
    N: Analysis<L>,
    IterData: IterationData<L, N>,
{
    /// Creates a runner around a fresh e-graph built from `analysis`.
    ///
    /// The defaults are:
    /// - an iteration limit of 30;
    /// - a node limit of 10,000 (measured with `EGraph::total_size`);
    /// - a 5 second time limit;
    /// - a default [`BackoffScheduler`].
    pub fn new(analysis: N) -> Self {
        Self {
            iter_limit: 30,
            node_limit: 10_000,
            time_limit: Duration::from_secs(5),

            egraph: EGraph::new(analysis),
            roots: vec![],
            iterations: vec![],
            stop_reason: None,
            hooks: vec![],

            start_time: None,
            scheduler: Box::new(BackoffScheduler::default()),
        }
    }

    /// Sets the iteration limit: the run stops once this many iterations are recorded.
    pub fn with_iter_limit(self, iter_limit: usize) -> Self {
        Self { iter_limit, ..self }
    }

    /// Sets the node limit: the run stops once `EGraph::total_size` exceeds it.
    pub fn with_node_limit(self, node_limit: usize) -> Self {
        Self { node_limit, ..self }
    }

    /// Sets the wall-clock limit, measured from the start of the first iteration.
    pub fn with_time_limit(self, time_limit: Duration) -> Self {
        Self { time_limit, ..self }
    }

    /// Adds a hook that runs at the start of every iteration, before searching.
    ///
    /// If the hook returns `Err(msg)`, the runner stops with [`StopReason::Other`].
    pub fn with_hook<F>(mut self, hook: F) -> Self
    where
        F: FnMut(&mut Self) -> Result<(), String> + 'static,
    {
        self.hooks.push(Box::new(hook));
        self
    }

    /// Replaces the rewrite scheduler (the default is [`BackoffScheduler`]).
    pub fn with_scheduler(self, scheduler: impl RewriteScheduler<L, N> + 'static) -> Self {
        let scheduler = Box::new(scheduler);
        Self { scheduler, ..self }
    }

    /// Adds `expr` to the e-graph and records its e-class id in `roots`.
    pub fn with_expr(mut self, expr: &RecExpr<L>) -> Self {
        let id = self.egraph.add_expr(expr);
        self.roots.push(id);
        self
    }

    /// Replaces the e-graph. Ids already stored in `roots` are left untouched.
    pub fn with_egraph(self, egraph: EGraph<L, N>) -> Self {
        Self { egraph, ..self }
    }

    /// Runs `rules` iteration by iteration until a [`StopReason`] is reached.
    ///
    /// The reason is stored in `stop_reason`, and one [`Iteration`] is
    /// recorded per iteration.
    ///
    /// # Panics
    /// Panics if this runner has already been run, i.e. `stop_reason` is set.
    pub fn run<'a, R>(mut self, rules: R) -> Self
    where
        R: IntoIterator<Item = &'a Rewrite<L, N>>,
        L: 'a,
        N: 'a,
    {
        let rules: Vec<&Rewrite<L, N>> = rules.into_iter().collect();
        check_rules(&rules);
        self.egraph.rebuild();
        loop {
            let iter = self.run_one(&rules);
            let stop_reason = iter.stop_reason.clone();
            // Push before `check_limits` so the iteration count includes this one.
            self.iterations.push(iter);
            if let Some(stop_reason) = stop_reason.or_else(|| self.check_limits().err()) {
                info!("Stopping: {:?}", stop_reason);
                self.stop_reason = Some(stop_reason);
                break;
            }
        }

        assert!(!self.iterations.is_empty());
        assert!(self.stop_reason.is_some());
        self
    }

    /// Enables explanations; delegates to `EGraph::with_explanations_enabled`.
    pub fn with_explanations_enabled(mut self) -> Self {
        self.egraph = self.egraph.with_explanations_enabled();
        self
    }

    /// Delegates to `EGraph::without_explanation_length_optimization`.
    pub fn without_explanation_length_optimization(mut self) -> Self {
        self.egraph = self.egraph.without_explanation_length_optimization();
        self
    }

    /// Delegates to `EGraph::with_explanation_length_optimization`.
    pub fn with_explanation_length_optimization(mut self) -> Self {
        self.egraph = self.egraph.with_explanation_length_optimization();
        self
    }

    /// Disables explanations; delegates to `EGraph::with_explanations_disabled`.
    pub fn with_explanations_disabled(mut self) -> Self {
        self.egraph = self.egraph.with_explanations_disabled();
        self
    }

    /// Explains why `left` and `right` are equivalent; delegates to
    /// `EGraph::explain_equivalence`.
    pub fn explain_equivalence(&mut self, left: &RecExpr<L>, right: &RecExpr<L>) -> Explanation<L> {
        self.egraph.explain_equivalence(left, right)
    }

    /// Explains why `expr` is in the e-graph; delegates to `EGraph::explain_existance`.
    pub fn explain_existance(&mut self, expr: &RecExpr<L>) -> Explanation<L> {
        self.egraph.explain_existance(expr)
    }

    /// Explains why `pattern` instantiated with `subst` is in the e-graph;
    /// delegates to `EGraph::explain_existance_pattern`.
    pub fn explain_existance_pattern(
        &mut self,
        pattern: &PatternAst<L>,
        subst: &Subst,
    ) -> Explanation<L> {
        self.egraph.explain_existance_pattern(pattern, subst)
    }

    /// Explains why `left` matches `right` under `subst`; delegates to
    /// `EGraph::explain_matches`.
    pub fn explain_matches(
        &mut self,
        left: &RecExpr<L>,
        right: &PatternAst<L>,
        subst: &Subst,
    ) -> Explanation<L> {
        self.egraph.explain_matches(left, right, subst)
    }

    /// Prints [`Runner::report`] to stdout.
    ///
    /// # Panics
    /// Panics if the runner has not been run yet.
    pub fn print_report(&self) {
        println!("{}", self.report())
    }

    /// Summarizes the run: the stop reason, e-graph size, and times summed
    /// over all iterations.
    ///
    /// # Panics
    /// Panics if the runner has not been run yet (`stop_reason` is `None`).
    pub fn report(&self) -> Report {
        Report {
            stop_reason: self
                .stop_reason
                .clone()
                .expect("Runner::report requires a stop reason; call run() first"),
            iterations: self.iterations.len(),
            egraph_nodes: self.egraph.total_number_of_nodes(),
            egraph_classes: self.egraph.number_of_classes(),
            memo_size: self.egraph.total_size(),
            rebuilds: self.iterations.iter().map(|i| i.n_rebuilds).sum(),
            search_time: self.iterations.iter().map(|i| i.search_time).sum(),
            apply_time: self.iterations.iter().map(|i| i.apply_time).sum(),
            rebuild_time: self.iterations.iter().map(|i| i.rebuild_time).sum(),
            total_time: self.iterations.iter().map(|i| i.total_time).sum(),
        }
    }

    /// Performs one iteration and returns its statistics.
    ///
    /// An iteration runs these phases in order: hooks, search, apply, rebuild.
    /// A stopping condition hit in any phase skips the remaining hook/search/apply
    /// work. The rebuild always runs.
    fn run_one(&mut self, rules: &[&Rewrite<L, N>]) -> Iteration<IterData> {
        assert!(self.stop_reason.is_none());

        info!("\nIteration {}", self.iterations.len());

        self.try_start();
        let mut result = self.check_limits();

        // Sizes before hooks run; compared at the end to detect saturation.
        let egraph_nodes = self.egraph.total_size();
        let egraph_classes = self.egraph.number_of_classes();

        let hook_time = Instant::now();
        // Hooks take `&mut Self`, so move them out while they run, then restore them.
        let mut hooks = std::mem::take(&mut self.hooks);
        result = result.and_then(|_| {
            hooks
                .iter_mut()
                .try_for_each(|hook| hook(self).map_err(StopReason::Other))
        });
        self.hooks = hooks;
        let hook_time = hook_time.elapsed().as_secs_f64();

        let egraph_nodes_after_hooks = self.egraph.total_size();
        let egraph_classes_after_hooks = self.egraph.number_of_classes();

        let i = self.iterations.len();
        trace!("EGraph {:?}", self.egraph.dump());

        let start_time = Instant::now();

        let mut matches = Vec::new();
        let mut applied = IndexMap::default();
        result = result.and_then(|_| {
            rules.iter().try_for_each(|rw| {
                let ms = self.scheduler.search_rewrite(i, &self.egraph, rw);
                matches.push(ms);
                self.check_limits()
            })
        });

        let search_time = start_time.elapsed().as_secs_f64();
        info!("Search time: {}", search_time);

        let apply_time = Instant::now();

        result = result.and_then(|_| {
            rules.iter().zip(matches).try_for_each(|(rw, ms)| {
                let total_matches: usize = ms.iter().map(|m| m.substs.len()).sum();
                debug!("Applying {} {} times", rw.name, total_matches);

                let actually_matched = self.scheduler.apply_rewrite(i, &mut self.egraph, rw, ms);
                // Only record rules that did something, so `applied.is_empty()`
                // means "nothing applied" in the saturation check below.
                if actually_matched > 0 {
                    *applied.entry(rw.name).or_insert(0) += actually_matched;
                    debug!("Applied {} {} times", rw.name, actually_matched);
                }

                self.check_limits()
            })
        });

        let apply_time = apply_time.elapsed().as_secs_f64();
        info!("Apply time: {}", apply_time);

        let rebuild_time = Instant::now();
        let n_rebuilds = self.egraph.rebuild();
        if self.egraph.are_explanations_enabled() {
            debug_assert!(self.egraph.check_each_explain(rules));
        }

        let rebuild_time = rebuild_time.elapsed().as_secs_f64();
        info!("Rebuild time: {}", rebuild_time);
        info!(
            "Size: n={}, e={}",
            self.egraph.total_size(),
            self.egraph.number_of_classes()
        );

        // `can_stop` takes `&mut self` and may update scheduler state, so it is
        // consulted only when nothing applied, and before the size comparisons.
        let can_be_saturated = applied.is_empty()
            && self.scheduler.can_stop(i)
            && (egraph_nodes == egraph_nodes_after_hooks)
            && (egraph_classes == egraph_classes_after_hooks)
            && (egraph_nodes == self.egraph.total_size())
            && (egraph_classes == self.egraph.number_of_classes());

        if can_be_saturated {
            // `and` keeps any earlier stop reason in preference to `Saturated`.
            result = result.and(Err(StopReason::Saturated))
        }

        Iteration {
            applied,
            egraph_nodes,
            egraph_classes,
            hook_time,
            search_time,
            apply_time,
            rebuild_time,
            n_rebuilds,
            data: IterData::make(self),
            total_time: start_time.elapsed().as_secs_f64(),
            stop_reason: result.err(),
        }
    }

    /// Records the start time the first time it is called; later calls do nothing.
    fn try_start(&mut self) {
        self.start_time.get_or_insert_with(Instant::now);
    }

    /// Checks the time, node, and iteration limits, in that order.
    ///
    /// Returns the first one that is exceeded as `Err`.
    fn check_limits(&self) -> RunnerResult<()> {
        let elapsed = self
            .start_time
            .expect("check_limits requires try_start to have been called")
            .elapsed();
        if elapsed > self.time_limit {
            return Err(StopReason::TimeLimit(elapsed.as_secs_f64()));
        }

        let size = self.egraph.total_size();
        if size > self.node_limit {
            return Err(StopReason::NodeLimit(size));
        }

        if self.iterations.len() >= self.iter_limit {
            return Err(StopReason::IterationLimit(self.iterations.len()));
        }

        Ok(())
    }
}
/// Warns, on stderr and through `log`, about rule names that appear more than
/// once in `rules`.
///
/// Duplicate names share one entry in the per-rule statistics, so reporting
/// and scheduling cannot tell those rules apart.
fn check_rules<L, N>(rules: &[&Rewrite<L, N>]) {
    let mut tally = IndexMap::default();
    for rule in rules {
        *tally.entry(rule.name).or_insert(0usize) += 1;
    }

    // Insertion order is preserved, so warnings follow first-appearance order.
    let duplicates: Vec<_> = tally.into_iter().filter(|&(_, n)| n > 1).collect();
    if duplicates.is_empty() {
        return;
    }

    eprintln!("WARNING: Duplicated rule names may affect rule reporting and scheduling.");
    log::warn!("Duplicated rule names may affect rule reporting and scheduling.");
    for (name, count) in duplicates {
        eprintln!("Rule '{}' appears {} times", name, count);
        log::warn!("Rule '{}' appears {} times", name, count);
    }
}
/// Controls how a `Runner` searches for and applies each rewrite in every iteration.
///
/// The default methods search for and apply every match unconditionally.
// The default bodies ignore some parameters, hence the lint allowance.
#[allow(unused_variables)]
pub trait RewriteScheduler<L, N>
where
L: Language,
N: Analysis<L>,
{
/// Called by the runner when an iteration applied no rewrites.
///
/// Returning `false` keeps the runner from reporting `StopReason::Saturated`
/// for that iteration. The default always allows stopping.
fn can_stop(&mut self, iteration: usize) -> bool {
true
}
/// Searches `egraph` for matches of `rewrite` during `iteration`.
///
/// The default returns every match from `rewrite.search(egraph)`.
fn search_rewrite<'a>(
&mut self,
iteration: usize,
egraph: &EGraph<L, N>,
rewrite: &'a Rewrite<L, N>,
) -> Vec<SearchMatches<'a, L>> {
rewrite.search(egraph)
}
/// Applies `matches` of `rewrite` to `egraph`. Returns the count the runner
/// records in `Iteration::applied` (only counts > 0 are recorded).
///
/// The default returns the length of the result of `rewrite.apply`.
fn apply_rewrite(
&mut self,
iteration: usize,
egraph: &mut EGraph<L, N>,
rewrite: &Rewrite<L, N>,
matches: Vec<SearchMatches<L>>,
) -> usize {
rewrite.apply(egraph, &matches).len()
}
}
/// A scheduler that uses only the `RewriteScheduler` defaults: every rule is
/// searched for and applied in full every iteration, with no bans.
#[derive(Debug)]
pub struct SimpleScheduler;
// Relies entirely on the trait's default methods.
impl<L, N> RewriteScheduler<L, N> for SimpleScheduler
where
L: Language,
N: Analysis<L>,
{
}
/// A scheduler that temporarily bans rules that produce too many matches.
///
/// A rule whose match count exceeds its limit is skipped for a number of
/// iterations. Both its match limit and its ban length are scaled by
/// `2^times_banned`, so they grow each time the rule is banned.
#[derive(Debug)]
pub struct BackoffScheduler {
// Match limit for rules created without an explicit per-rule limit.
default_match_limit: usize,
// Ban length (in iterations) for rules created without an explicit setting.
default_ban_length: usize,
// Per-rule state, created lazily the first time a rule name is seen.
stats: IndexMap<Symbol, RuleStats>,
}
// Per-rule bookkeeping for `BackoffScheduler`.
#[derive(Debug)]
struct RuleStats {
// Number of searches whose matches were passed on (i.e. the rule was not banned).
times_applied: usize,
// The rule is skipped while the current iteration index is below this.
banned_until: usize,
// How many times the rule has been banned; scales the limit and ban length.
times_banned: usize,
// Base match limit, before scaling by `times_banned`.
match_limit: usize,
// Base ban length in iterations, before scaling by `times_banned`.
ban_length: usize,
}
impl BackoffScheduler {
    /// Sets the match limit for rules without a per-rule limit.
    ///
    /// Only affects rules whose stats have not been created yet.
    pub fn with_initial_match_limit(mut self, limit: usize) -> Self {
        self.default_match_limit = limit;
        self
    }

    /// Sets the initial ban length (in iterations) for rules without a
    /// per-rule setting.
    ///
    /// Only affects rules whose stats have not been created yet.
    pub fn with_ban_length(mut self, ban_length: usize) -> Self {
        self.default_ban_length = ban_length;
        self
    }

    /// Returns the stats for `name`, creating them from the current defaults
    /// on first use.
    fn rule_stats(&mut self, name: Symbol) -> &mut RuleStats {
        // One hash lookup via the entry API. `stats` and the default fields are
        // disjoint, so the defaults can be read while the entry is borrowed.
        self.stats.entry(name).or_insert(RuleStats {
            times_applied: 0,
            banned_until: 0,
            times_banned: 0,
            match_limit: self.default_match_limit,
            ban_length: self.default_ban_length,
        })
    }

    /// Sets `name`'s match limit to `usize::MAX`, so the rule is not banned
    /// for producing too many matches.
    pub fn do_not_ban(mut self, name: impl Into<Symbol>) -> Self {
        self.rule_stats(name.into()).match_limit = usize::MAX;
        self
    }

    /// Sets the base match limit for the rule `name`.
    pub fn rule_match_limit(mut self, name: impl Into<Symbol>, limit: usize) -> Self {
        self.rule_stats(name.into()).match_limit = limit;
        self
    }

    /// Sets the base ban length (in iterations) for the rule `name`.
    pub fn rule_ban_length(mut self, name: impl Into<Symbol>, length: usize) -> Self {
        self.rule_stats(name.into()).ban_length = length;
        self
    }
}
impl Default for BackoffScheduler {
fn default() -> Self {
Self {
stats: Default::default(),
default_match_limit: 1_000,
default_ban_length: 5,
}
}
}
impl<L, N> RewriteScheduler<L, N> for BackoffScheduler
where
    L: Language,
    N: Analysis<L>,
{
    /// Allows stopping only if no rule is currently banned.
    ///
    /// Otherwise, every active ban is shifted earlier so that the soonest ones
    /// expire at `iteration`, and `false` is returned so the runner keeps going
    /// instead of reporting saturation.
    fn can_stop(&mut self, iteration: usize) -> bool {
        let n_stats = self.stats.len();

        let mut banned: Vec<_> = self
            .stats
            .iter_mut()
            .filter(|(_, s)| s.banned_until > iteration)
            .collect();

        if banned.is_empty() {
            true
        } else {
            let min_ban = banned
                .iter()
                .map(|(_, s)| s.banned_until)
                .min()
                .expect("banned cannot be empty here");

            assert!(min_ban >= iteration);
            let delta = min_ban - iteration;

            let mut unbanned = vec![];
            for (name, s) in &mut banned {
                s.banned_until -= delta;
                if s.banned_until == iteration {
                    unbanned.push(name.as_str());
                }
            }

            assert!(!unbanned.is_empty());
            info!(
                "Banned {}/{}, fast-forwarded by {} to unban {}",
                banned.len(),
                n_stats,
                delta,
                unbanned.join(", "),
            );

            false
        }
    }

    /// Searches for `rewrite` unless it is banned.
    ///
    /// If the search finds more than `match_limit * 2^times_banned` matches,
    /// the matches are discarded and the rule is banned for
    /// `ban_length * 2^times_banned` iterations. All of this arithmetic
    /// saturates at `usize::MAX`, so `do_not_ban` (a limit of `usize::MAX`)
    /// and repeated bans cannot overflow.
    fn search_rewrite<'a>(
        &mut self,
        iteration: usize,
        egraph: &EGraph<L, N>,
        rewrite: &'a Rewrite<L, N>,
    ) -> Vec<SearchMatches<'a, L>> {
        // `value << shift`, clamped to `usize::MAX` instead of dropping high
        // bits or panicking when `shift` is at least the bit width.
        fn saturating_shl(value: usize, shift: usize) -> usize {
            if value == 0 {
                0
            } else if shift >= usize::BITS as usize || value > (usize::MAX >> shift) {
                usize::MAX
            } else {
                value << shift
            }
        }

        let stats = self.rule_stats(rewrite.name);

        if iteration < stats.banned_until {
            debug!(
                "Skipping {} ({}-{}), banned until {}...",
                rewrite.name, stats.times_applied, stats.times_banned, stats.banned_until,
            );
            return vec![];
        }

        let threshold = saturating_shl(stats.match_limit, stats.times_banned);
        // Ask for one more than the threshold so exceeding it is detectable.
        // Saturate: `threshold` is `usize::MAX` for rules marked `do_not_ban`.
        let matches = rewrite.search_with_limit(egraph, threshold.saturating_add(1));
        let total_len: usize = matches.iter().map(|m| m.substs.len()).sum();
        if total_len > threshold {
            let ban_length = saturating_shl(stats.ban_length, stats.times_banned);
            stats.times_banned += 1;
            stats.banned_until = iteration.saturating_add(ban_length);
            info!(
                "Banning {} ({}-{}) for {} iters: {} < {}",
                rewrite.name,
                stats.times_applied,
                stats.times_banned,
                ban_length,
                threshold,
                total_len,
            );
            vec![]
        } else {
            stats.times_applied += 1;
            matches
        }
    }
}
/// Custom data recorded in every `Iteration` (the `data` field).
pub trait IterationData<L, N>: Sized
where
L: Language,
N: Analysis<L>,
{
/// Produces the data for the iteration that is finishing. It is called at the
/// end of `run_one`, before the iteration is pushed onto `runner.iterations`.
fn make(runner: &Runner<L, N, Self>) -> Self;
}
/// The default: record no extra data per iteration.
impl<L, N> IterationData<L, N> for ()
where
L: Language,
N: Analysis<L>,
{
fn make(_: &Runner<L, N, Self>) -> Self {}
}