use crate::{UOp, UOpKey};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::pattern::{Matcher, RewriteResult};
/// Maximum number of pending entries on the rewrite stack.
///
/// `RewriteEngine::rewrite` panics once the stack grows past this size, treating the
/// growth as a runaway (infinite) rewrite.
const REWRITE_STACK_LIMIT: usize = 500 * 1_000;
/// A single unit of work on the rewrite stack.
#[derive(Clone)]
struct Entry {
    /// The node whose final replacement this entry is working towards.
    n: Arc<UOp>,
    /// The node currently standing in for `n` (identical to `n` when an entry is first
    /// pushed for visiting, or a rewritten node in later stages).
    new_n: Arc<UOp>,
    /// Processing phase: 0 = visit (optional bottom-up rewrite, schedule sources),
    /// 1 = rebuild from rewritten sources / apply `pm`, anything else = link the
    /// result of `new_n` back to `n`.
    stage: u8,
}
struct RewriteEngine<'a, PM, BPM, C>
where
PM: Matcher<C>,
BPM: Matcher<C>,
{
pm: Option<&'a PM>,
bpm: Option<&'a BPM>,
ctx: &'a mut C,
replace: HashMap<UOpKey, Arc<UOp>>,
bpm_cache: HashMap<UOpKey, Option<Arc<UOp>>>,
}
impl<'a, PM, BPM, C> RewriteEngine<'a, PM, BPM, C>
where
    PM: Matcher<C>,
    BPM: Matcher<C>,
{
    /// Creates an engine with empty `replace` and `bpm_cache` maps.
    fn new(pm: Option<&'a PM>, bpm: Option<&'a BPM>, ctx: &'a mut C) -> Self {
        Self { pm, bpm, ctx, replace: HashMap::new(), bpm_cache: HashMap::new() }
    }

    /// Applies the `pm` matcher to `x` exactly once (results are not cached).
    ///
    /// Returns `Some(new_node)` on `Rewritten`, and `None` when there is no `pm`
    /// matcher or it reports `NoMatch`. A `Gate` from `pm` is treated like `NoMatch`.
    #[inline]
    fn pm_rewrite(&mut self, x: &Arc<UOp>) -> Option<Arc<UOp>> {
        let pm = self.pm?;
        match pm.rewrite(x, self.ctx) {
            RewriteResult::Rewritten(new_node) => {
                debug_assert!(
                    !Arc::ptr_eq(&new_node, x),
                    "PM pattern returned Rewritten but produced the same node (id={}). \
                     This causes infinite loops. Return NoMatch instead.\nOp: {:?}",
                    x.id,
                    x.op().as_ref(),
                );
                Some(new_node)
            }
            RewriteResult::Gate(_) | RewriteResult::NoMatch => None,
        }
    }

    /// Applies the `bpm` matcher to `x` once, memoizing `Rewritten` and `NoMatch`.
    ///
    /// Returns `Ok(Some(node))` on a rewrite, `Ok(None)` on no match, and
    /// `Err(gate_node)` when the matcher gates. Gate outcomes are not cached, so a
    /// gated node is matched again if it is reached a second time.
    ///
    /// # Panics
    /// If the engine has no `bpm` matcher; `bpm_fixed_point` is only invoked when
    /// `self.bpm.is_some()`.
    #[inline]
    fn cached_bpm_rewrite(&mut self, x: &Arc<UOp>) -> Result<Option<Arc<UOp>>, Arc<UOp>> {
        let key = UOpKey(x.clone());
        if let Some(cached) = self.bpm_cache.get(&key) {
            return Ok(cached.clone());
        }
        let bpm = self.bpm.expect("cached_bpm_rewrite requires a bottom-up matcher");
        match bpm.rewrite(x, self.ctx) {
            RewriteResult::Rewritten(new_node) => {
                debug_assert!(
                    !Arc::ptr_eq(&new_node, x),
                    "BPM pattern returned Rewritten but produced the same node (id={}). \
                     This causes infinite loops. Return NoMatch instead.\nOp: {:?}",
                    x.id,
                    x.op().as_ref(),
                );
                self.bpm_cache.insert(key, Some(new_node.clone()));
                Ok(Some(new_node))
            }
            RewriteResult::Gate(gate_node) => Err(gate_node),
            RewriteResult::NoMatch => {
                self.bpm_cache.insert(key, None);
                Ok(None)
            }
        }
    }

    /// Stores `result` as the final replacement for `original`.
    ///
    /// When the result is a different node, the transform is also recorded in the
    /// thread-local provenance tracker.
    #[inline]
    fn record_replace(&mut self, original: &Arc<UOp>, result: Arc<UOp>) {
        if !Arc::ptr_eq(original, &result) {
            use crate::provenance::{PROVENANCE_TRACKER, PassName};
            PROVENANCE_TRACKER.with(|tracker| {
                tracker.borrow_mut().record_transform(result.id, original.id, PassName::RewritePattern);
            });
        }
        self.replace.insert(UOpKey(original.clone()), result);
    }

    /// Finalizes `n` with `result` and re-queues every entry that was waiting on `n`.
    ///
    /// `n_key` must be the key for `n`; it is passed in to avoid rebuilding it.
    // UOpKey wraps Arc<UOp>; clippy presumably flags it as a mutable key type. Keys are
    // not mutated while stored here (assumption carried over from `rewrite` — confirm).
    #[allow(clippy::mutable_key_type)]
    fn finish_node(
        &mut self,
        n: &Arc<UOp>,
        n_key: &UOpKey,
        result: Arc<UOp>,
        stack: &mut Vec<Entry>,
        waitlist: &mut HashMap<UOpKey, Vec<Entry>>,
    ) {
        self.record_replace(n, result);
        if let Some(entries) = waitlist.remove(n_key) {
            stack.extend(entries);
        }
    }

    /// Runs the `bpm` matcher from `start` until it stops matching.
    ///
    /// Returns `Ok(last)` with the final node (possibly `start` itself), or
    /// `Err(gate_node)` as soon as the matcher gates. `seen` is scratch space that is
    /// cleared on entry; it is passed in so its allocation is reused across nodes.
    ///
    /// # Panics
    /// If the rewrite chain revisits a node (an infinite rewrite cycle).
    #[allow(clippy::mutable_key_type)]
    fn bpm_fixed_point(
        &mut self,
        start: Arc<UOp>,
        seen: &mut HashSet<UOpKey>,
    ) -> Result<Arc<UOp>, Arc<UOp>> {
        seen.clear();
        let mut working = start;
        loop {
            if !seen.insert(UOpKey(working.clone())) {
                panic!(
                    "infinite loop in fixed_point_rewrite: node {:?} (id={}) seen twice",
                    working.op().as_ref(),
                    working.id
                );
            }
            match self.cached_bpm_rewrite(&working)? {
                Some(rewritten) => working = rewritten,
                None => return Ok(working),
            }
        }
    }

    /// Rewrites the graph rooted at `root` and returns the rewritten root.
    ///
    /// Uses an explicit stack instead of recursion. Each entry moves through three stages:
    /// * stage 0: optionally run `bpm` to a fixed point. On a gate, `n` resolves to the
    ///   gate node and its sources are not visited. Otherwise schedule stage 1 and push
    ///   every not-yet-scheduled source.
    /// * stage 1: once every source has a result in `replace`, rebuild the node from those
    ///   results if any changed. Otherwise try `pm`. If either yields a new node, rewrite
    ///   it (stage 0) and link its result back to `n` (stage 2). If neither does, `n` is
    ///   final. Entries whose sources are not ready wait in `waitlist`.
    /// * stage 2: once the new node has a result, it becomes `n`'s result.
    ///
    /// # Panics
    /// If the stack exceeds `REWRITE_STACK_LIMIT`, or a `bpm` rewrite chain cycles.
    /// In debug builds, also panics if work is left waiting on results that never
    /// resolve. The root would otherwise be returned un-rewritten.
    // UOpKey wraps Arc<UOp>; clippy presumably flags it as a mutable key type. Keys are
    // not mutated while stored in these maps/sets (assumption — confirm against UOp).
    #[allow(clippy::mutable_key_type)]
    fn rewrite(&mut self, root: Arc<UOp>) -> Arc<UOp> {
        let mut stack: Vec<Entry> = vec![Entry { n: root.clone(), stage: 0, new_n: root.clone() }];
        // Nodes already scheduled for a stage-0 visit, so shared sources are visited once.
        let mut on_stack: HashSet<UOpKey> = HashSet::new();
        on_stack.insert(UOpKey(root.clone()));
        // Entries blocked until the keyed node has a result in `replace`.
        let mut waitlist: HashMap<UOpKey, Vec<Entry>> = HashMap::new();
        // Scratch set for `bpm_fixed_point`, reused to avoid one allocation per node.
        let mut seen: HashSet<UOpKey> = HashSet::new();

        while let Some(Entry { n, stage, new_n }) = stack.pop() {
            if stack.len() > REWRITE_STACK_LIMIT {
                panic!(
                    "infinite loop in graph_rewrite (stack too big: {}). results cached: {}",
                    stack.len(),
                    self.replace.len(),
                );
            }
            let n_key = UOpKey(n.clone());
            if self.replace.contains_key(&n_key) {
                continue;
            }
            match stage {
                0 => {
                    let working = if self.bpm.is_some() {
                        match self.bpm_fixed_point(new_n, &mut seen) {
                            Ok(node) => node,
                            Err(gate_node) => {
                                // A gate stops traversal: `n` resolves to the gate node and
                                // its sources are never visited.
                                self.finish_node(&n, &n_key, gate_node, &mut stack, &mut waitlist);
                                continue;
                            }
                        }
                    } else {
                        new_n
                    };
                    stack.push(Entry { n: n.clone(), stage: 1, new_n: working.clone() });
                    let sources = working.op().sources();
                    // Pushed in reverse so the first source is popped (visited) first.
                    for child in sources.iter().rev() {
                        if on_stack.insert(UOpKey(child.clone())) {
                            stack.push(Entry { n: child.clone(), stage: 0, new_n: child.clone() });
                        }
                    }
                }
                1 => {
                    let sources = new_n.op().sources();
                    let mut resolved: Vec<Arc<UOp>> = Vec::with_capacity(sources.len());
                    let mut pending: Option<UOpKey> = None;
                    for src in &sources {
                        let src_key = UOpKey(src.clone());
                        match self.replace.get(&src_key) {
                            Some(rx) => resolved.push(rx.clone()),
                            None => {
                                pending = Some(src_key);
                                break;
                            }
                        }
                    }
                    if let Some(src_key) = pending {
                        // A source is not finished yet; retry this entry once it is.
                        waitlist.entry(src_key).or_default().push(Entry {
                            n: n.clone(),
                            stage: 1,
                            new_n: new_n.clone(),
                        });
                        continue;
                    }
                    let sources_changed =
                        resolved.iter().zip(sources.iter()).any(|(a, b)| !Arc::ptr_eq(a, b));
                    let node = if sources_changed { new_n.with_sources(resolved) } else { new_n.clone() };
                    // Only try `pm` when rebuilding produced the very same node.
                    let successor = if Arc::ptr_eq(&node, &new_n) {
                        self.pm_rewrite(&new_n)
                    } else {
                        Some(node)
                    };
                    match successor {
                        Some(successor) => {
                            stack.push(Entry { n: n.clone(), stage: 2, new_n: successor.clone() });
                            stack.push(Entry { n: successor.clone(), stage: 0, new_n: successor });
                        }
                        None => self.finish_node(&n, &n_key, new_n, &mut stack, &mut waitlist),
                    }
                }
                _ => {
                    let new_n_key = UOpKey(new_n.clone());
                    match self.replace.get(&new_n_key).cloned() {
                        Some(replaced_new_n) => {
                            self.finish_node(&n, &n_key, replaced_new_n, &mut stack, &mut waitlist)
                        }
                        None => waitlist.entry(new_n_key).or_default().push(Entry { n, stage: 2, new_n }),
                    }
                }
            }
        }

        debug_assert!(
            waitlist.is_empty(),
            "graph_rewrite finished with {} unresolved wait key(s); the root may be returned \
             un-rewritten (e.g. a rewrite produced a node that depends on its own input)",
            waitlist.len(),
        );
        self.replace.get(&UOpKey(root.clone())).cloned().unwrap_or(root)
    }
}
/// A matcher that never matches anything.
///
/// In this module it fills the unused matcher slot (`None::<&NoMatcher>`) when only
/// one of the two matchers is supplied.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoMatcher;

impl<C> Matcher<C> for NoMatcher {
    /// Always returns `RewriteResult::NoMatch`, ignoring the node and context.
    fn rewrite(&self, _node: &Arc<UOp>, _ctx: &mut C) -> RewriteResult {
        RewriteResult::NoMatch
    }
}
/// Rewrites the graph rooted at `root` using only `matcher` as the `pm` matcher.
///
/// `matcher` is applied to each node after its sources have been rewritten. Any node
/// it produces is rewritten in turn. Returns the rewritten root.
pub fn graph_rewrite<M: Matcher<C>, C>(matcher: &M, root: Arc<UOp>, ctx: &mut C) -> Arc<UOp> {
    let no_bpm: Option<&NoMatcher> = None;
    let mut engine = RewriteEngine::new(Some(matcher), no_bpm, ctx);
    engine.rewrite(root)
}
/// Rewrites the graph rooted at `root` using only `matcher` as the `bpm` matcher.
///
/// `matcher` is applied repeatedly to each node, until it stops matching, before that
/// node's sources are visited. A `Gate` result stops traversal below that node.
/// Returns the rewritten root.
pub fn graph_rewrite_bottom_up<M: Matcher<C>, C>(matcher: &M, root: Arc<UOp>, ctx: &mut C) -> Arc<UOp> {
    let no_pm: Option<&NoMatcher> = None;
    let mut engine = RewriteEngine::new(no_pm, Some(matcher), ctx);
    engine.rewrite(root)
}
/// Rewrites the graph rooted at `root` with both matchers active.
///
/// `bpm` runs on each node (to a fixed point) before its sources are visited. `pm` runs
/// after the node's sources have been rewritten. Returns the rewritten root.
pub fn graph_rewrite_with_bpm<PM, BPM, C>(pm: &PM, bpm: &BPM, root: Arc<UOp>, ctx: &mut C) -> Arc<UOp>
where
    PM: Matcher<C>,
    BPM: Matcher<C>,
{
    let mut engine = RewriteEngine::new(Some(pm), Some(bpm), ctx);
    engine.rewrite(root)
}