use std::fmt;
use indexmap::IndexMap;
use instant::Instant;
use log::*;
use crate::{EGraph, Id, Language, Metadata, RecExpr, Rewrite, SearchMatches};
/// Statistics describing one pass of `Runner::step`.
///
/// A value of this type is built at the end of `step` and handed to
/// `Runner::post_step`; `Runner::run` collects one per completed step.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize))]
#[non_exhaustive]
pub struct Iteration {
/// `EGraph::total_size()` sampled at the very start of the step, before any
/// rule has been searched or applied.
pub egraph_nodes: usize,
/// `EGraph::number_of_classes()` sampled at the very start of the step.
pub egraph_classes: usize,
/// Per-rule application counters, keyed by rewrite name. A rule only gets an
/// entry when its application reported a non-zero count during this step, so
/// an empty map means no rule changed anything.
pub applied: IndexMap<String, usize>,
/// Wall-clock seconds spent in the search phase (all rules).
pub search_time: f64,
/// Wall-clock seconds spent in the apply phase (all rules).
pub apply_time: f64,
/// Wall-clock seconds spent in `EGraph::rebuild` at the end of the step.
pub rebuild_time: f64,
}
/// Summary of a complete `Runner::run_expr` invocation.
///
/// `L` is the language of the initial expression and `E` is the runner's
/// `Runner::Error` type, i.e. the reason the run stopped.
#[derive(Debug, Clone)]
#[cfg_attr(
feature = "serde-1",
derive(serde::Serialize),
serde(bound(serialize = "
L: Language + std::fmt::Display,
E: serde::Serialize
"))
)]
#[non_exhaustive]
pub struct RunReport<L, E> {
/// The expression that was passed to `run_expr` and used to seed the e-graph.
pub initial_expr: RecExpr<L>,
/// The `Id` returned by `EGraph::from_expr` for `initial_expr`.
pub initial_expr_eclass: Id,
/// One entry per step that `Runner::run` completed, in execution order.
pub iterations: Vec<Iteration>,
/// Wall-clock seconds spent inside `Runner::run` (excludes building the
/// initial e-graph).
pub rules_time: f64,
/// The error value with which the run terminated.
pub stop_reason: E,
}
pub trait Runner<L, M>
where
L: Language,
M: Metadata<L>,
{
type Error: fmt::Debug;
fn pre_step(&mut self, _egraph: &mut EGraph<L, M>) -> Result<(), Self::Error> {
Ok(())
}
fn post_step(
&mut self,
_iteration: &Iteration,
_egraph: &mut EGraph<L, M>,
) -> Result<(), Self::Error> {
Ok(())
}
fn during_step(&mut self, _egraph: &EGraph<L, M>) -> Result<(), Self::Error> {
Ok(())
}
fn search_rewrite(
&mut self,
egraph: &mut EGraph<L, M>,
rewrite: &Rewrite<L, M>,
) -> Vec<SearchMatches> {
rewrite.search(egraph)
}
fn apply_rewrite(
&mut self,
egraph: &mut EGraph<L, M>,
rewrite: &Rewrite<L, M>,
matches: Vec<SearchMatches>,
) -> usize {
rewrite.apply(egraph, &matches).len()
}
fn step(
&mut self,
egraph: &mut EGraph<L, M>,
rules: &[Rewrite<L, M>],
) -> Result<Iteration, Self::Error> {
let egraph_nodes = egraph.total_size();
let egraph_classes = egraph.number_of_classes();
trace!("EGraph {:?}", egraph.dump());
let search_time = Instant::now();
let mut matches = Vec::new();
for rule in rules.iter() {
let ms = self.search_rewrite(egraph, rule);
matches.push(ms);
self.during_step(egraph)?
}
let search_time = search_time.elapsed().as_secs_f64();
info!("Search time: {}", search_time);
let apply_time = Instant::now();
let mut applied = IndexMap::new();
for (rw, ms) in rules.iter().zip(matches) {
let total_matches: usize = ms.iter().map(|m| m.mappings.len()).sum();
if total_matches == 0 {
continue;
}
debug!("Applying {} {} times", rw.name(), total_matches);
let actually_matched = self.apply_rewrite(egraph, rw, ms);
if actually_matched > 0 {
if let Some(count) = applied.get_mut(rw.name()) {
*count += 1;
} else {
applied.insert(rw.name().to_owned(), 1);
}
debug!("Applied {} {} times", rw.name(), actually_matched);
}
self.during_step(egraph)?
}
let apply_time = apply_time.elapsed().as_secs_f64();
info!("Apply time: {}", apply_time);
let rebuild_time = Instant::now();
egraph.rebuild();
let rebuild_time = rebuild_time.elapsed().as_secs_f64();
info!("Rebuild time: {}", rebuild_time);
info!(
"Size: n={}, e={}",
egraph.total_size(),
egraph.number_of_classes()
);
trace!("Running post_step...");
Ok(Iteration {
applied,
egraph_nodes,
egraph_classes,
search_time,
apply_time,
rebuild_time,
})
}
fn run(
&mut self,
egraph: &mut EGraph<L, M>,
rules: &[Rewrite<L, M>],
) -> (Vec<Iteration>, Self::Error) {
let mut iterations = vec![];
let mut fn_loop = || -> Result<(), Self::Error> {
loop {
trace!("Running pre_step...");
self.pre_step(egraph)?;
trace!("Running step...");
iterations.push(self.step(egraph, rules)?);
trace!("Running post_step...");
self.post_step(iterations.last().unwrap(), egraph)?;
}
};
let stop_reason = fn_loop().unwrap_err();
info!("Stopping {:?}", stop_reason);
(iterations, stop_reason)
}
fn run_expr(
&mut self,
initial_expr: RecExpr<L>,
rules: &[Rewrite<L, M>],
) -> (EGraph<L, M>, RunReport<L, Self::Error>) {
let (mut egraph, initial_expr_eclass) = EGraph::from_expr(&initial_expr);
let rules_time = Instant::now();
let (iterations, stop_reason) = self.run(&mut egraph, rules);
let rules_time = rules_time.elapsed().as_secs_f64();
let report = RunReport {
iterations,
rules_time,
stop_reason,
initial_expr,
initial_expr_eclass,
};
(egraph, report)
}
}
/// A `Runner` with an iteration limit, a node limit, and per-rule banning of
/// rewrites that produce too many matches.
///
/// Construct it with `SimpleRunner::default()` and adjust it with the
/// `with_*` builder methods.
pub struct SimpleRunner {
/// `pre_step` stops the run once `i >= iter_limit`.
iter_limit: usize,
/// `during_step` stops the run once `EGraph::total_size()` exceeds this.
node_limit: usize,
/// Index of the current iteration; incremented in `post_step`.
i: usize,
/// Per-rule bookkeeping, keyed by rewrite name.
stats: IndexMap<String, RuleStats>,
/// Base match threshold; shifted left by a rule's `times_banned` to get the
/// threshold above which that rule is banned.
initial_match_limit: usize,
/// Base ban duration in iterations; shifted left by a rule's `times_banned`
/// to get the length of its next ban.
ban_length: usize,
}
/// Bookkeeping that `SimpleRunner` keeps for a single rewrite rule.
struct RuleStats {
/// Incremented by `search_rewrite` each time it checks the rule's matches
/// against the threshold and lets them through.
times_applied: usize,
/// The rule is skipped while the runner's iteration index is strictly less
/// than this value.
banned_until: usize,
/// How many times the rule has been banned so far; used as the shift amount
/// for both the match threshold and the ban length.
times_banned: usize,
}
impl Default for SimpleRunner {
fn default() -> Self {
Self {
iter_limit: 30,
node_limit: 10_000,
i: 0,
stats: Default::default(),
initial_match_limit: 1_000,
ban_length: 5,
}
}
}
impl SimpleRunner {
    /// Returns the runner with its iteration limit replaced by `iter_limit`;
    /// every other setting is carried over unchanged.
    pub fn with_iter_limit(mut self, iter_limit: usize) -> Self {
        self.iter_limit = iter_limit;
        self
    }

    /// Returns the runner with its node limit replaced by `node_limit`;
    /// every other setting is carried over unchanged.
    pub fn with_node_limit(mut self, node_limit: usize) -> Self {
        self.node_limit = node_limit;
        self
    }

    /// Returns the runner with its base match limit replaced by
    /// `initial_match_limit`; every other setting is carried over unchanged.
    pub fn with_initial_match_limit(mut self, initial_match_limit: usize) -> Self {
        self.initial_match_limit = initial_match_limit;
        self
    }
}
/// The reasons a `SimpleRunner` run can stop.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize))]
pub enum SimpleRunnerError {
/// A step applied no rewrites while no rule was banned.
Saturated,
/// The iteration limit was reached; carries the iteration index at that point.
IterationLimit(usize),
/// The e-graph grew past the node limit; carries the observed
/// `EGraph::total_size()`.
NodeLimit(usize),
}
/// `Runner` implementation that enforces the iteration and node limits and
/// temporarily bans rules whose searches return too many matches.
impl<L, M> Runner<L, M> for SimpleRunner
where
    L: Language,
    M: Metadata<L>,
{
    type Error = SimpleRunnerError;

    /// Logs the iteration header and enforces the iteration limit.
    ///
    /// # Errors
    /// Returns `IterationLimit(i)` once the iteration index `i` has reached
    /// `iter_limit`.
    fn pre_step(&mut self, egraph: &mut EGraph<L, M>) -> Result<(), Self::Error> {
        info!(
            "\n\nIteration {}, n={}, e={}",
            self.i,
            egraph.total_size(),
            egraph.number_of_classes()
        );
        if self.i >= self.iter_limit {
            Err(SimpleRunnerError::IterationLimit(self.i))
        } else {
            Ok(())
        }
    }

    /// Enforces the node limit while a step is in progress.
    ///
    /// # Errors
    /// Returns `NodeLimit(size)` when `EGraph::total_size()` is strictly
    /// greater than `node_limit`.
    fn during_step(&mut self, egraph: &EGraph<L, M>) -> Result<(), Self::Error> {
        let size = egraph.total_size();
        if size > self.node_limit {
            Err(SimpleRunnerError::NodeLimit(size))
        } else {
            Ok(())
        }
    }

    /// Advances the iteration counter and detects saturation.
    ///
    /// # Errors
    /// Returns `Saturated` when the step applied nothing and no rule was
    /// banned during it. If some rule was banned, an empty step is not
    /// conclusive, so the run continues.
    fn post_step(
        &mut self,
        iteration: &Iteration,
        _egraph: &mut EGraph<L, M>,
    ) -> Result<(), Self::Error> {
        // Bans are evaluated against the iteration that just ran, i.e. before
        // the counter is bumped.
        let is_banned = |s: &RuleStats| s.banned_until > self.i;
        let any_bans = self.stats.values().any(is_banned);
        self.i += 1;
        if !any_bans && iteration.applied.is_empty() {
            Err(SimpleRunnerError::Saturated)
        } else {
            Ok(())
        }
    }

    /// Searches `rewrite`, unless it is currently banned, and bans it when the
    /// search yields more mappings than its current threshold.
    ///
    /// The threshold is `initial_match_limit << times_banned` and the ban
    /// lasts `ban_length << times_banned` iterations, so both double with each
    /// successive ban. A banned rule, or a rule that has just been banned,
    /// contributes an empty match list.
    ///
    /// The stats entry is created before the search so that the very first
    /// search of a rule is checked against the threshold like every later one.
    fn search_rewrite(
        &mut self,
        egraph: &mut EGraph<L, M>,
        rewrite: &Rewrite<L, M>,
    ) -> Vec<SearchMatches> {
        // Look the rule up by borrowed name first so that an owned key is only
        // allocated the first time the rule is seen.
        if !self.stats.contains_key(rewrite.name()) {
            self.stats.insert(
                rewrite.name().into(),
                RuleStats {
                    times_applied: 0,
                    banned_until: 0,
                    times_banned: 0,
                },
            );
        }
        let limit = self
            .stats
            .get_mut(rewrite.name())
            .expect("stats entry for this rule was ensured just above");

        if self.i < limit.banned_until {
            debug!(
                "Skipping {} ({}-{}), banned until {}...",
                rewrite.name(),
                limit.times_applied,
                limit.times_banned,
                limit.banned_until,
            );
            return vec![];
        }

        let matches = rewrite.search(egraph);
        let total_len: usize = matches.iter().map(|m| m.mappings.len()).sum();
        let threshold = self.initial_match_limit << limit.times_banned;
        if total_len > threshold {
            // The ban length is computed from the ban count *before* this ban,
            // so the first ban lasts exactly `ban_length` iterations.
            let ban_length = self.ban_length << limit.times_banned;
            limit.times_banned += 1;
            limit.banned_until = self.i + ban_length;
            info!(
                "Banning {} ({}-{}) for {} iters: {} < {}",
                rewrite.name(),
                limit.times_applied,
                limit.times_banned,
                ban_length,
                threshold,
                total_len,
            );
            vec![]
        } else {
            limit.times_applied += 1;
            matches
        }
    }
}