use std::collections::VecDeque;
use rustc_hash::FxHashSet;
use crate::{
context::{Context, Ptr},
graph::walkers::{IRNode, WALKCONFIG_PREORDER_FORWARD, uninterruptible::immutable::walk_op},
irbuild::{
inserter::{Inserter, OpInsertionPoint},
listener::{Recorder, RecorderEvent},
rewriter::IRRewriter,
},
operation::Operation,
result::Result,
};
/// The rewriter handed to [`MatchRewrite::rewrite`].
///
/// It is an [`IRRewriter`] whose listener is a [`Recorder`]. [`apply_match_rewrite`] inspects the
/// recorded events after each rewrite to learn which operations were erased or inserted.
pub type MatchRewriter = IRRewriter<Recorder>;
/// A single "match, then rewrite" transformation, driven by [`apply_match_rewrite`].
pub trait MatchRewrite {
    /// Returns `true` if `op` should be handed to [`MatchRewrite::rewrite`].
    ///
    /// [`apply_match_rewrite`] calls this in two situations:
    /// - for each operation visited by its initial IR walk;
    /// - for each operation inserted by a previous rewrite that has not been recorded as erased.
    fn r#match(&mut self, ctx: &Context, op: Ptr<Operation>) -> bool;

    /// Rewrites `op`, which an earlier call to [`MatchRewrite::r#match`] accepted.
    ///
    /// When called from [`apply_match_rewrite`], `rewriter`'s insertion point is set to just
    /// before `op`. That driver learns about erased and inserted operations only through
    /// `rewriter`'s [`Recorder`], so structural changes should be made via `rewriter`.
    ///
    /// # Errors
    ///
    /// Any error returned here aborts [`apply_match_rewrite`] and is propagated to its caller.
    fn rewrite(
        &mut self,
        ctx: &mut Context,
        rewriter: &mut MatchRewriter,
        op: Ptr<Operation>,
    ) -> Result<()>;
}
/// Applies `match_rewrite` to the IR rooted at `op`.
///
/// The driver works in two phases.
///
/// 1. **Collect.** A pre-order, forward walk rooted at `op` visits operations. Every operation
///    for which [`MatchRewrite::r#match`] returns `true` is queued, in visitation order.
/// 2. **Rewrite.** Queued operations are popped in FIFO order. For each one:
///    - skip it if an earlier rewrite erased it;
///    - otherwise, set the rewriter's insertion point to just before it and call
///      [`MatchRewrite::rewrite`];
///    - afterwards, read the rewriter's [`Recorder`] events. Erased operations are remembered.
///      Inserted operations that were not recorded as erased are matched, and appended to the
///      queue if they match.
///
/// NOTE(review): an operation is matched only when it is queued. It is not re-matched before
/// being rewritten, even if an earlier rewrite changed it in the meantime.
///
/// # Errors
///
/// Returns the first error produced by [`MatchRewrite::rewrite`]. Rewrites already applied
/// are not rolled back, and the remaining queued operations are not processed.
pub fn apply_match_rewrite<M: MatchRewrite>(
    ctx: &mut Context,
    mut match_rewrite: M,
    op: Ptr<Operation>,
) -> Result<()> {
    /// State threaded through the IR walk: the matcher and the queue it fills.
    struct WalkerState<'a, M> {
        match_rewrite: &'a mut M,
        to_rewrite: &'a mut VecDeque<Ptr<Operation>>,
    }

    /// Queues every visited operation that the matcher accepts.
    fn walker_callback<M: MatchRewrite>(
        ctx: &Context,
        state: &mut WalkerState<M>,
        node: IRNode,
    ) {
        if let IRNode::Operation(op) = node
            && state.match_rewrite.r#match(ctx, op)
        {
            state.to_rewrite.push_back(op);
        }
    }

    // Phase 1: build the initial worklist.
    let mut to_rewrite = VecDeque::new();
    let mut state = WalkerState {
        match_rewrite: &mut match_rewrite,
        to_rewrite: &mut to_rewrite,
    };
    walk_op(
        ctx,
        &mut state,
        &WALKCONFIG_PREORDER_FORWARD,
        op,
        walker_callback,
    );

    // Phase 2: drain the worklist.
    // Every operation any rewrite has erased so far. Queued operations in this set are skipped.
    // NOTE(review): if `Ptr` slots can be reused after erasure, a new op allocated at an erased
    // op's address would be treated as erased here. Confirm against `Context` allocation.
    let mut erased = FxHashSet::<Ptr<Operation>>::default();
    let mut rewriter = MatchRewriter::default();
    rewriter.set_listener(Recorder::default());

    while let Some(op) = to_rewrite.pop_front() {
        if erased.contains(&op) {
            continue;
        }

        rewriter.set_insertion_point(OpInsertionPoint::BeforeOperation(op));
        match_rewrite.rewrite(ctx, &mut rewriter, op)?;

        let listener = rewriter.get_listener_mut();

        // Record all erasures before looking at insertions. An op that this rewrite inserted
        // and then erased is therefore never matched or queued.
        erased.extend(listener.events.iter().filter_map(|event| match event {
            RecorderEvent::ErasedOperation(erased_op) => Some(*erased_op),
            _ => None,
        }));

        for event in &listener.events {
            match event {
                RecorderEvent::InsertedOperation(new_op) => {
                    if !erased.contains(new_op) && match_rewrite.r#match(ctx, *new_op) {
                        to_rewrite.push_back(*new_op);
                    }
                }
                // Listed explicitly (no `_` arm) so that a new event kind forces a decision
                // about whether this driver needs to react to it.
                RecorderEvent::ErasedOperation(_)
                | RecorderEvent::ReplacedValueUses { .. }
                | RecorderEvent::InsertedBlock(_)
                | RecorderEvent::ErasedBlock(_)
                | RecorderEvent::ErasedRegion(_)
                | RecorderEvent::ValueTypeChanged { .. }
                | RecorderEvent::UnlinkedOperation(..)
                | RecorderEvent::UnlinkedBlock(..) => {}
            }
        }

        // Start the next rewrite with an empty event log.
        listener.clear();
    }

    Ok(())
}