use alloc::{boxed::Box, format, rc::Rc};
use core::ops::{Deref, DerefMut};
use midenc_session::diagnostics::PrintDiagnostic;
use smallvec::SmallVec;
use crate::{
BlockRef, Builder, Context, InsertionGuard, Listener, ListenerType, OpBuilder, OpOperandImpl,
OperationRef, PostOrderBlockIter, ProgramPoint, RegionRef, Report, SourceSpan, Usable,
ValueRef,
formatter::{DisplayOptional, DisplayValues},
patterns::Pattern,
};
pub trait Rewriter: Builder + RewriterListener {
/// Returns `true` when a listener is attached to this rewriter.
///
/// Several default methods of this trait consult this value to choose between a direct
/// fast path and a step-by-step path that notifies the listener about each change.
fn has_listener(&self) -> bool;
/// Replace the results of `op` with `values`, then erase `op`.
///
/// A `None` entry in `values` leaves the uses of the corresponding result untouched; note
/// that [`Rewriter::erase_op`] asserts that `op` has no uses left.
///
/// # Panics
///
/// Panics if `values` does not contain exactly one entry per result of `op`.
fn replace_op_with_values(&mut self, op: OperationRef, values: &[Option<ValueRef>]) {
    let num_results = op.borrow().num_results();
    assert_eq!(num_results, values.len());

    self.replace_all_op_uses_with_values(op, values);
    self.erase_op(op);
}
/// Replace the results of `op` with the results of `new_op`, then erase `op`.
///
/// # Panics
///
/// Panics if the two operations do not have the same number of results.
fn replace_op(&mut self, op: OperationRef, new_op: OperationRef) {
    let old_count = op.borrow().num_results();
    let new_count = new_op.borrow().num_results();
    assert_eq!(old_count, new_count);

    self.replace_all_op_uses_with(op, new_op);
    self.erase_op(op);
}
/// Erase `op`, which must have no remaining uses.
///
/// When no listener is attached the op is erased directly. Otherwise every operation and
/// block nested inside `op` is erased individually, so that the listener is notified about
/// each of them, and `op` itself is erased last.
///
/// # Panics
///
/// Panics if `op` is still in use.
fn erase_op(&mut self, mut op: OperationRef) {
    assert!(!op.borrow().is_used(), "expected op to have no uses");

    // Fast path: nobody needs to hear about the nested erasures.
    if !self.has_listener() {
        op.borrow_mut().erase();
        return;
    }

    // Notifies the listener about `operation`, then drops its uses and erases it.
    // The regions of `operation` are expected to have been emptied already.
    fn erase_single_op<R: ?Sized + RewriterListener>(
        mut operation: OperationRef,
        rewrite_listener: &mut R,
    ) {
        let op = operation.borrow();
        if cfg!(debug_assertions) {
            assert!(op.regions().iter().all(|r| r.is_empty()), "expected empty regions");
            // Remaining uses are only tolerated when the parent region may be a graph region.
            if op.is_used()
                && let Some(region) = op.parent_region()
            {
                assert!(region.borrow().may_be_graph_region(), "expected that op has no uses");
            }
        }
        rewrite_listener.notify_operation_erased(operation);
        // Release the shared borrow before taking the mutable one below.
        drop(op);
        let mut op = operation.borrow_mut();
        op.drop_all_uses();
        op.erase();
    }

    // Erases everything nested in `op` (ops first, then their blocks), then `op` itself.
    fn erase_tree<R: ?Sized + Rewriter>(op: OperationRef, rewriter: &mut R) {
        let mut next_region = op.borrow().regions().front().as_pointer();
        while let Some(region) = next_region.take() {
            next_region = region.next();
            let mut erased_blocks = SmallVec::<[BlockRef; 4]>::default();
            // Repeat the traversal from the region's current entry block until the region
            // no longer has one.
            let mut region_entry = region.borrow().entry_block_ref();
            while let Some(entry) = region_entry.take() {
                erased_blocks.clear();
                for block in PostOrderBlockIter::new(entry) {
                    // Visit the ops of a block from last to first, so that users are erased
                    // before the definitions they use: `erase_single_op` asserts (in debug
                    // builds) that an op outside of a graph region has no uses left. This
                    // mirrors the order used by `erase_block`.
                    let mut next_op = block.borrow().body().back().as_pointer();
                    while let Some(op) = next_op.take() {
                        next_op = op.prev();
                        erase_tree(op, rewriter);
                    }
                    erased_blocks.push(block);
                }
                for mut block in erased_blocks.drain(..) {
                    // Clear argument uses and block uses explicitly before erasing, since
                    // `erase_block` asserts that the block is unused.
                    for arg in block.borrow_mut().arguments_mut() {
                        arg.borrow_mut().uses_mut().clear();
                    }
                    block.borrow_mut().drop_all_uses();
                    rewriter.erase_block(block);
                }
                region_entry = region.borrow().entry_block_ref();
            }
        }
        erase_single_op(op, rewriter);
    }

    erase_tree(op, self);
}
/// Erase `block` together with all of the operations it contains.
///
/// Operations are erased from the last one to the first via [`Rewriter::erase_op`]; the
/// listener then receives `notify_block_erased` before the block itself is erased.
///
/// # Panics
///
/// Panics if `block`, or any operation visited while walking it backwards, still has uses.
fn erase_block(&mut self, mut block: BlockRef) {
    assert!(!block.borrow().is_used(), "expected 'block' to be unused");

    let mut remaining = block.borrow().body().back().as_pointer();
    while let Some(current) = remaining {
        // Step to the previous op first, since `current` is erased below.
        remaining = current.prev();
        assert!(!current.borrow().is_used(), "expected 'op' to be unused");
        self.erase_op(current);
    }

    self.notify_block_erased(block);
    block.borrow_mut().erase();
}
/// Move all blocks of `region` into the region `ip`, ahead of the blocks already in `ip`.
///
/// Without a listener the whole block list is spliced in front of `ip`'s first block in one
/// step. With a listener, blocks are moved one at a time through
/// [`Rewriter::move_block_before`] so that each move is reported.
///
/// # Panics
///
/// Panics if `region` and `ip` are the same region, or if a listener is attached and `ip`
/// has no entry block to insert before.
fn inline_region_before(&mut self, mut region: RegionRef, mut ip: RegionRef) {
    assert!(!RegionRef::ptr_eq(&region, &ip), "cannot inline a region into itself");
    log::trace!(target: "rewriter", "inlining blocks of {region} into {ip}");

    // Detach the blocks from `region` up front; they are re-attached to `ip` below.
    let region_body = region.borrow_mut().body_mut().take();
    if !self.has_listener() {
        let mut parent_region = ip.borrow_mut();
        let parent_body = parent_region.body_mut();
        let mut cursor = parent_body.front_mut();
        cursor.splice_before(region_body);
    } else {
        // Every block is inserted before the entry block `ip` had when inlining started.
        let ip = ip
            .borrow()
            .entry_block_ref()
            .expect("expected 'ip' to have an entry block to insert before");
        for block in region_body {
            self.move_block_before(block, ip);
        }
    }
}
/// Move the operations of `src` into `dest` at the position described by `ip`, replace the
/// uses of `src`'s block arguments with `args`, and erase `src`.
///
/// `ip` must be an operation in `dest`, or `None`. The insertion is treated as "at the end of
/// `dest`" when `ip` is `None` or when `ip` has no next operation.
///
/// # Panics
///
/// Panics if `args` does not have one entry per argument of `src`, if `src` has
/// predecessors, if `ip` is not attached to `dest`, if `dest` has successors while inserting
/// at its end, or if `src` has successors while inserting elsewhere.
///
/// NOTE(review): when `ip` is the last operation of `dest`, the listener path appends after
/// it (`insert_op_at_end`) while the non-listener path calls `splice_block_before` with `ip`;
/// confirm that both produce the same operation order.
fn inline_block_before(
&mut self,
mut src: BlockRef,
mut dest: BlockRef,
ip: Option<OperationRef>,
args: &[Option<ValueRef>],
) {
// One replacement slot is required per block argument of `src`.
assert!(
args.len() == src.borrow().num_arguments(),
"incorrect # of argument replacement values"
);
// `src` is erased at the end, so nothing may still branch to it.
assert!(!src.borrow().has_predecessors(), "expected 'src' to have no predecessors");
let insert_at_block_end = if let Some(ip) = ip {
let ip_block = ip.parent().expect("expected 'ip' to belong to a block");
assert_eq!(ip_block, dest, "invalid insertion point: must be an op in 'dest'");
ip.next().is_none()
} else {
true
};
if insert_at_block_end {
assert!(!dest.borrow().has_successors(), "expected 'dest' to have no successors");
} else {
assert!(!src.borrow().has_successors(), "expected 'src' to have no successors");
}
// Redirect uses of each block argument; a `None` replacement leaves that argument's uses alone.
for (arg, replacement) in src.borrow().arguments().iter().copied().zip(args.iter().copied())
{
if let Some(replacement) = replacement {
self.replace_all_uses_of_value_with(arg.upcast(), replacement);
}
}
if self.has_listener() {
// Move the ops one at a time through the `insert_op_*` methods, which notify the listener.
let mut src_mut = src.borrow_mut();
let mut src_cursor = src_mut.body_mut().front_mut();
while let Some(op) = src_cursor.remove() {
if insert_at_block_end {
self.insert_op_at_end(op, dest);
} else {
self.insert_op_before(op, ip.unwrap());
}
}
} else {
// No listener: splice the contents of `src` into `dest` directly.
let mut dest_block = dest.borrow_mut();
if let Some(ip) = ip {
dest_block.splice_block_before(&mut src.borrow_mut(), ip);
} else {
dest_block.splice_block(&mut src.borrow_mut());
}
}
assert!(src.borrow().body().is_empty(), "expected 'src' to be empty");
self.erase_block(src);
}
/// Merge `src` into `dest`, replacing the uses of `src`'s block arguments with `args`.
///
/// This is [`Rewriter::inline_block_before`] with the last operation of `dest` (if any) as
/// the insertion point.
fn merge_blocks(&mut self, src: BlockRef, dest: BlockRef, args: &[Option<ValueRef>]) {
    log::trace!(
        target: "rewriter",
        "merging {src} into {dest} replacing uses of its block arguments with {}",
        DisplayValues::new(args.iter().map(|v| DisplayOptional(v.as_ref())))
    );

    let last_op_in_dest = dest.borrow().body().back().as_pointer();
    self.inline_block_before(src, dest, last_op_in_dest, args);
}
/// Split `block` at the operation `ip` and return the newly created block.
///
/// Without a listener this delegates to the block's own `split_block`. With a listener, a
/// new block is created through the builder (under an [`InsertionGuard`]) and operations are
/// moved into it one at a time with `insert_op_before`, so that the listener is notified.
///
/// # Panics
///
/// On the listener path, panics if `ip` is not attached to `block`, or if `block` is not
/// attached to a region.
///
/// NOTE(review): on the listener path `ip` is re-bound to the first operation of `new_block`,
/// which has just been created; if that block is still empty at this point the `unwrap()`
/// panics. Confirm the intended insertion point and loop termination.
fn split_block(&mut self, mut block: BlockRef, ip: OperationRef) -> BlockRef {
// Fast path: no listener to notify, let the block split itself.
if !self.has_listener() {
return block.borrow_mut().split_block(ip);
}
assert_eq!(
block,
ip.parent().expect("expected 'ip' to be attached to a block"),
"expected 'ip' to be in 'block'"
);
let region =
block.parent().expect("cannot split a block which is not attached to a region");
let mut guard = InsertionGuard::new(self);
let new_block = guard.create_block(region, Some(block), &[]);
// `ip` is the last operation of `block`: nothing is moved.
if ip.next().is_none() {
return new_block;
}
// Take operations from the back of `block` until the one equal to `ip` has been moved.
let mut block_mut = block.borrow_mut();
let mut cursor = block_mut.body_mut().back_mut();
let ip = new_block.borrow().body().front().as_pointer().unwrap();
while let Some(op) = cursor.remove() {
let is_last_move = OperationRef::ptr_eq(&op, &ip);
guard.insert_op_before(op, ip);
if is_last_move {
break;
}
}
new_block
}
/// Place `block` immediately before the block `ip` and notify the listener.
///
/// A block without a parent region is inserted; a block that already has one is moved. The
/// listener receives the region `block` belonged to beforehand (if any) and `ip`.
fn move_block_before(&mut self, mut block: BlockRef, ip: BlockRef) {
    let previous_region = block.parent();
    match previous_region {
        None => {
            block.borrow_mut().insert_before(ip);
        }
        Some(_) => {
            block.borrow_mut().move_before(ip);
        }
    }
    self.notify_block_inserted(block, previous_region, Some(ip));
}
/// Move `op` so that it sits immediately before `ip`, then notify the listener.
///
/// The program point recorded before the move is passed to `notify_operation_inserted`.
fn move_op_before(&mut self, mut op: OperationRef, ip: OperationRef) {
    let origin = ProgramPoint::before(op);
    op.borrow_mut().move_to(ProgramPoint::before(ip));
    self.notify_operation_inserted(op, origin);
}
/// Move `op` so that it sits immediately after `ip`, then notify the listener.
///
/// The program point recorded before the move is passed to `notify_operation_inserted`.
fn move_op_after(&mut self, mut op: OperationRef, ip: OperationRef) {
    let origin = ProgramPoint::before(op);
    op.borrow_mut().move_to(ProgramPoint::after(ip));
    self.notify_operation_inserted(op, origin);
}
/// Move `op` to the end of the block `ip`, then notify the listener.
///
/// The program point recorded before the move is passed to `notify_operation_inserted`.
fn move_op_to_end(&mut self, mut op: OperationRef, ip: BlockRef) {
    let origin = ProgramPoint::before(op);
    op.borrow_mut().move_to(ProgramPoint::at_end_of(ip));
    self.notify_operation_inserted(op, origin);
}
/// Insert `op` immediately before `ip`, then notify the listener.
///
/// The program point computed for `op` before the insertion is passed to
/// `notify_operation_inserted`.
fn insert_op_before(&mut self, mut op: OperationRef, ip: OperationRef) {
    let origin = ProgramPoint::before(op);
    op.borrow_mut().as_operation_ref().insert_before(ip);
    self.notify_operation_inserted(op, origin);
}
/// Insert `op` immediately after `ip`, then notify the listener.
///
/// The program point computed for `op` before the insertion is passed to
/// `notify_operation_inserted`.
fn insert_op_after(&mut self, mut op: OperationRef, ip: OperationRef) {
    let origin = ProgramPoint::before(op);
    op.borrow_mut().as_operation_ref().insert_after(ip);
    self.notify_operation_inserted(op, origin);
}
/// Insert `op` at the end of the block `ip`, then notify the listener.
///
/// The program point computed for `op` before the insertion is passed to
/// `notify_operation_inserted`.
fn insert_op_at_end(&mut self, op: OperationRef, ip: BlockRef) {
    let origin = ProgramPoint::before(op);
    op.insert_at_end(ip);
    self.notify_operation_inserted(op, origin);
}
/// Redirect every use of `from` to `to`.
///
/// Each use is unlinked from `from`, pointed at `to`, and added to the uses of `to`. The
/// listener receives `notify_operation_modification_started` before, and
/// `notify_operation_modified` after, each individual use is rewritten.
fn replace_all_uses_of_value_with(&mut self, mut from: ValueRef, mut to: ValueRef) {
    let mut source = from.borrow_mut();
    let mut cursor = source.uses_mut().front_mut();
    loop {
        // `remove` hands back the use under the cursor; `None` means the list is exhausted.
        let Some(mut operand) = cursor.remove() else {
            break;
        };
        let user = operand.borrow().owner;
        self.notify_operation_modification_started(&user);
        operand.borrow_mut().value = Some(to);
        to.borrow_mut().insert_use(operand);
        self.notify_operation_modified(user);
    }
}
/// Move every use of the block `from` over to the block `to`.
///
/// The listener receives `notify_operation_modification_started` before, and
/// `notify_operation_modified` after, each individual use is moved.
///
/// NOTE(review): unlike `replace_all_uses_of_value_with`, the operand itself is not updated
/// here, it is only re-linked into the use list of `to`. Confirm that a block operand does
/// not store its target block separately.
fn replace_all_uses_of_block_with(&mut self, mut from: BlockRef, mut to: BlockRef) {
    let mut source = from.borrow_mut();
    let mut cursor = source.uses_mut().front_mut();
    loop {
        // `remove` hands back the use under the cursor; `None` means the list is exhausted.
        let Some(operand) = cursor.remove() else {
            break;
        };
        let user = operand.borrow().owner;
        self.notify_operation_modification_started(&user);
        to.borrow_mut().insert_use(operand);
        self.notify_operation_modified(user);
    }
}
/// Replace the uses of each value in `from` with the value at the same index in `to`.
///
/// # Panics
///
/// Panics if `from` and `to` differ in length.
fn replace_all_uses_with(&mut self, from: &[ValueRef], to: &[Option<ValueRef>]) {
    assert_eq!(from.len(), to.len(), "incorrect number of replacements");

    for (&old, &new) in from.iter().zip(to.iter()) {
        // A `None` replacement means the uses of `old` are left alone.
        let Some(new) = new else {
            continue;
        };
        self.replace_all_uses_of_value_with(old, new);
    }
}
/// Replace the uses of the results of `from` with the values in `to`.
///
/// The listener receives `notify_operation_replaced_with_values` before any use is
/// rewritten. `to` must hold one entry per result of `from` (checked by
/// [`Rewriter::replace_all_uses_with`]).
fn replace_all_op_uses_with_values(&mut self, from: OperationRef, to: &[Option<ValueRef>]) {
    self.notify_operation_replaced_with_values(from, to);

    let old_values: SmallVec<[ValueRef; 2]> =
        from.borrow().results().all().iter().map(|&result| result as ValueRef).collect();
    self.replace_all_uses_with(&old_values, to);
}
fn replace_all_op_uses_with(&mut self, from: OperationRef, to: OperationRef) {
self.notify_operation_replaced(from, to);
let from_results = from
.borrow()
.results()
.all()
.iter()
.copied()
.map(|result| result as ValueRef)
.collect::<SmallVec<[ValueRef; 2]>>();
let to_results = to
.borrow()
.results()
.all()
.iter()
.copied()
.map(|result| Some(result as ValueRef))
.collect::<SmallVec<[Option<ValueRef>; 2]>>();
self.replace_all_uses_with(&from_results, &to_results);
}
/// Replace the uses of the results of `from` that occur within `block` with `to`.
///
/// A use counts as being within `block` when the operation that owns `block` (obtained via
/// `block.grandparent()`) is a proper ancestor of the using operation. If `block` has no
/// such owning operation, no use is replaced.
///
/// Returns `true` if all uses of the results of `from` were replaced, otherwise `false`.
fn replace_op_uses_within_block(
    &mut self,
    from: OperationRef,
    to: &[ValueRef],
    block: BlockRef,
) -> bool {
    let parent_op = block.grandparent();
    // The predicate selects the uses to REPLACE, so it must accept users nested under the
    // block's parent op. Negating it would rewrite exactly the uses outside of the block.
    self.maybe_replace_op_uses_with(from, to, |operand| {
        parent_op
            .as_ref()
            .is_some_and(|op| op.borrow().is_proper_ancestor_of(&operand.owner.borrow()))
    })
}
fn replace_all_uses_except(
&mut self,
from: ValueRef,
to: ValueRef,
exceptions: &[OperationRef],
) {
self.maybe_replace_uses_of_value_with(from, to, |operand| {
!exceptions.contains(&operand.owner)
});
}
}
pub trait RewriterExt: Rewriter {
/// Start an in-place modification of `op`.
///
/// The returned guard notifies the listener that the modification started when it is
/// created, and that it was completed or canceled when it is dropped.
fn modify_op_in_place(&mut self, op: OperationRef) -> InPlaceModificationGuard<'_, Self> {
    let guard = InPlaceModificationGuard::new(self, op);
    guard
}
/// Replace the uses of `from` for which `should_replace` returns `true` with `to`.
///
/// Every replaced use is unlinked from `from`, pointed at `to`, and added to the uses of
/// `to`. The listener receives `notify_operation_modification_started` before, and
/// `notify_operation_modified` after, each replaced use.
///
/// Returns `true` if every use of `from` was replaced, otherwise `false`.
fn maybe_replace_uses_of_value_with<P>(
    &mut self,
    mut from: ValueRef,
    mut to: ValueRef,
    should_replace: P,
) -> bool
where
    P: Fn(&OpOperandImpl) -> bool,
{
    let mut all_replaced = true;
    let mut from = from.borrow_mut();
    let from_uses = from.uses_mut();
    let mut cursor = from_uses.front_mut();
    while let Some(user) = cursor.as_pointer() {
        if should_replace(&user.borrow()) {
            let owner = user.borrow().owner;
            self.notify_operation_modification_started(&owner);
            // Removing the use moves the cursor along, so it is not advanced explicitly here.
            let mut operand = cursor.remove().unwrap();
            // Keep the operand consistent with the use list it is about to join, exactly as
            // `Rewriter::replace_all_uses_of_value_with` does for unconditional replacement.
            operand.borrow_mut().value = Some(to);
            to.borrow_mut().insert_use(operand);
            self.notify_operation_modified(owner);
        } else {
            all_replaced = false;
            cursor.move_next();
        }
    }
    all_replaced
}
/// Apply [`RewriterExt::maybe_replace_uses_of_value_with`] to each pair of values taken
/// from `from` and `to` at the same index.
///
/// Returns `true` only if every use of every value in `from` was replaced.
///
/// # Panics
///
/// Panics if `from` and `to` differ in length.
fn maybe_replace_uses_with<P>(
    &mut self,
    from: &[ValueRef],
    to: &[ValueRef],
    should_replace: P,
) -> bool
where
    P: Fn(&OpOperandImpl) -> bool,
{
    assert_eq!(from.len(), to.len(), "incorrect number of replacements");

    // Every pair is processed, even after an earlier pair was only partially replaced.
    from.iter().zip(to.iter()).fold(true, |all_replaced, (&old, &new)| {
        let replaced = self.maybe_replace_uses_of_value_with(old, new, &should_replace);
        all_replaced & replaced
    })
}
fn maybe_replace_op_uses_with<P>(
&mut self,
from: OperationRef,
to: &[ValueRef],
should_replace: P,
) -> bool
where
P: Fn(&OpOperandImpl) -> bool,
{
let results = SmallVec::<[ValueRef; 2]>::from_iter(
from.borrow().results.all().iter().cloned().map(|result| result as ValueRef),
);
self.maybe_replace_uses_with(&results, to, should_replace)
}
}
// Blanket implementation: every `Rewriter`, sized or not, gets the `RewriterExt` methods.
impl<R: ?Sized + Rewriter> RewriterExt for R {}
#[allow(unused_variables)]
pub trait RewriterListener: Listener {
fn notify_block_erased(&self, block: BlockRef) {}
fn notify_operation_modification_started(&self, op: &OperationRef) {}
fn notify_operation_modification_canceled(&self, op: &OperationRef) {}
fn notify_operation_modified(&self, op: OperationRef) {}
fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
let replacement = replacement.borrow();
let values = replacement
.results()
.all()
.iter()
.cloned()
.map(|result| Some(result as ValueRef))
.collect::<SmallVec<[Option<ValueRef>; 2]>>();
self.notify_operation_replaced_with_values(op, &values);
}
fn notify_operation_replaced_with_values(
&self,
op: OperationRef,
replacement: &[Option<ValueRef>],
) {
}
fn notify_operation_erased(&self, op: OperationRef) {}
fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {}
fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {}
fn notify_match_failure(&self, span: SourceSpan, reason: Report) {}
}
/// An optional listener: `Some` forwards every event to the wrapped listener, `None` ignores
/// them all.
impl<L: RewriterListener> RewriterListener for Option<L> {
    fn notify_block_erased(&self, block: BlockRef) {
        if let Some(listener) = self {
            listener.notify_block_erased(block);
        }
    }

    fn notify_operation_modification_started(&self, op: &OperationRef) {
        if let Some(listener) = self {
            listener.notify_operation_modification_started(op);
        }
    }

    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
        if let Some(listener) = self {
            listener.notify_operation_modification_canceled(op);
        }
    }

    fn notify_operation_modified(&self, op: OperationRef) {
        if let Some(listener) = self {
            listener.notify_operation_modified(op);
        }
    }

    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
        if let Some(listener) = self {
            listener.notify_operation_replaced(op, replacement);
        }
    }

    fn notify_operation_replaced_with_values(
        &self,
        op: OperationRef,
        replacement: &[Option<ValueRef>],
    ) {
        if let Some(listener) = self {
            listener.notify_operation_replaced_with_values(op, replacement);
        }
    }

    fn notify_operation_erased(&self, op: OperationRef) {
        if let Some(listener) = self {
            listener.notify_operation_erased(op);
        }
    }

    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
        if let Some(listener) = self {
            listener.notify_pattern_begin(pattern, op);
        }
    }

    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
        if let Some(listener) = self {
            listener.notify_pattern_end(pattern, success);
        }
    }

    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
        if let Some(listener) = self {
            listener.notify_match_failure(span, reason);
        }
    }
}
/// A boxed listener forwards every event to the listener it owns.
impl<L: ?Sized + RewriterListener> RewriterListener for Box<L> {
    fn notify_block_erased(&self, block: BlockRef) {
        L::notify_block_erased(self, block);
    }

    fn notify_operation_modification_started(&self, op: &OperationRef) {
        L::notify_operation_modification_started(self, op);
    }

    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
        L::notify_operation_modification_canceled(self, op);
    }

    fn notify_operation_modified(&self, op: OperationRef) {
        L::notify_operation_modified(self, op);
    }

    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
        L::notify_operation_replaced(self, op, replacement);
    }

    fn notify_operation_replaced_with_values(
        &self,
        op: OperationRef,
        replacement: &[Option<ValueRef>],
    ) {
        L::notify_operation_replaced_with_values(self, op, replacement);
    }

    fn notify_operation_erased(&self, op: OperationRef) {
        L::notify_operation_erased(self, op)
    }

    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
        L::notify_pattern_begin(self, pattern, op);
    }

    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
        L::notify_pattern_end(self, pattern, success);
    }

    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
        L::notify_match_failure(self, span, reason);
    }
}
/// A reference-counted listener forwards every event to the shared listener.
impl<L: ?Sized + RewriterListener> RewriterListener for Rc<L> {
    fn notify_block_erased(&self, block: BlockRef) {
        L::notify_block_erased(self, block);
    }

    fn notify_operation_modification_started(&self, op: &OperationRef) {
        L::notify_operation_modification_started(self, op);
    }

    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
        L::notify_operation_modification_canceled(self, op);
    }

    fn notify_operation_modified(&self, op: OperationRef) {
        L::notify_operation_modified(self, op);
    }

    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
        L::notify_operation_replaced(self, op, replacement);
    }

    fn notify_operation_replaced_with_values(
        &self,
        op: OperationRef,
        replacement: &[Option<ValueRef>],
    ) {
        L::notify_operation_replaced_with_values(self, op, replacement);
    }

    fn notify_operation_erased(&self, op: OperationRef) {
        L::notify_operation_erased(self, op)
    }

    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
        L::notify_pattern_begin(self, pattern, op);
    }

    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
        L::notify_pattern_end(self, pattern, success);
    }

    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
        L::notify_match_failure(self, span, reason);
    }
}
/// A listener that ignores every event.
///
/// It is the default listener type parameter of [`PatternRewriter`] and [`RewriterImpl`].
pub struct NoopRewriterListener;
impl Listener for NoopRewriterListener {
#[inline]
fn kind(&self) -> ListenerType {
ListenerType::Rewriter
}
// Operation insertions are ignored.
#[inline(always)]
fn notify_operation_inserted(&self, _op: OperationRef, _prev: ProgramPoint) {}
// Block insertions are ignored.
#[inline(always)]
fn notify_block_inserted(
&self,
_block: BlockRef,
_prev: Option<RegionRef>,
_ip: Option<BlockRef>,
) {
}
}
impl RewriterListener for NoopRewriterListener {
// Overridden with an empty body so that the trait's default, which collects the results of
// the replacement and forwards them to `notify_operation_replaced_with_values`, is not run.
#[inline(always)]
fn notify_operation_replaced(&self, _op: OperationRef, _replacement: OperationRef) {}
}
/// A listener that reports rewriter events through the `log` crate, using the `rewriter`
/// target.
pub struct TracingRewriterListener;

impl Listener for TracingRewriterListener {
    #[inline]
    fn kind(&self) -> ListenerType {
        ListenerType::Rewriter
    }

    /// Logs the insertion of `_op`, or its move when `_prev` is a valid program point.
    #[inline]
    fn notify_operation_inserted(&self, _op: OperationRef, _prev: ProgramPoint) {
        if log::log_enabled!(target: "rewriter", log::Level::Trace) {
            let name = _op.name();
            let (event, direction) = if _prev.is_valid() {
                ("moved", "from")
            } else {
                ("inserted", "at")
            };
            // Symbol operations additionally carry their symbol name as a structured field.
            if let Some(symbol) = _op.borrow().as_symbol() {
                log::trace!(
                    target: "rewriter",
                    symbol = symbol.name().as_str(),
                    dialect = name.dialect().as_str(),
                    op = name.name().as_str(),
                    rewrite_event = event;
                    "{event} '{name}' {direction} {_prev}"
                );
            } else {
                log::trace!(
                    target: "rewriter",
                    dialect = name.dialect().as_str(),
                    op = name.name().as_str(),
                    rewrite_event = event;
                    "{event} '{name}' {direction} {_prev}",
                );
            }
        }
    }

    /// Logs the creation, insertion or move of `_block`, depending on which of `_prev` and
    /// `_ip` are present.
    #[inline]
    fn notify_block_inserted(
        &self,
        _block: BlockRef,
        _prev: Option<RegionRef>,
        _ip: Option<BlockRef>,
    ) {
        if log::log_enabled!(target: "rewriter", log::Level::Trace) {
            match (_prev, _ip) {
                (None, None) => {
                    log::trace!(
                        target: "rewriter",
                        rewrite_event = "created";
                        "created {_block}"
                    );
                }
                (None, Some(ip)) => {
                    log::trace!(
                        target: "rewriter",
                        rewrite_event = "inserted";
                        "inserted {_block} at {ip}"
                    );
                }
                (Some(prev), Some(ip)) => {
                    log::trace!(
                        target: "rewriter",
                        rewrite_event = "moved";
                        "moved {_block} from {prev} to {ip}"
                    );
                }
                // A tracing hook must never abort the rewrite it observes: log what is
                // known about this combination instead of panicking.
                (Some(prev), None) => {
                    log::trace!(
                        target: "rewriter",
                        rewrite_event = "moved";
                        "moved {_block} from {prev}"
                    );
                }
            }
        }
    }
}
impl RewriterListener for TracingRewriterListener {
fn notify_operation_replaced(&self, _op: OperationRef, _replacement: OperationRef) {
if log::log_enabled!(target: "rewriter", log::Level::Trace) {
log::trace!(
target: "rewriter",
rewrite_event = "replaced";
"replaced {_op} with {_replacement}"
);
}
}
fn notify_block_erased(&self, _block: BlockRef) {
if log::log_enabled!(target: "rewriter", log::Level::Trace) {
log::trace!(
target: "rewriter",
rewrite_event = "erased";
"erased {_block}"
);
}
}
fn notify_operation_modification_started(&self, _op: &OperationRef) {
if log::log_enabled!(target: "rewriter", log::Level::Trace) {
let name = _op.name();
log::trace!(
target: "rewriter",
dialect = name.dialect().as_str(),
op = name.name().as_str(),
rewrite_event = "modification-started";
"starting modification of {_op}"
);
}
}
fn notify_operation_modification_canceled(&self, _op: &OperationRef) {
if log::log_enabled!(target: "rewriter", log::Level::Trace) {
let name = _op.name();
log::trace!(
target: "rewriter",
dialect = name.dialect().as_str(),
op = name.name().as_str(),
rewrite_event = "modification-canceled";
"canceled modification"
);
}
}
fn notify_operation_modified(&self, _op: OperationRef) {
if log::log_enabled!(target: "rewriter", log::Level::Trace) {
let name = _op.name();
log::trace!(
target: "rewriter",
dialect = name.dialect().as_str(),
op = name.name().as_str(),
rewrite_event = "modified";
"completed modification of {_op}"
);
}
}
fn notify_operation_replaced_with_values(
&self,
_op: OperationRef,
_replacement: &[Option<ValueRef>],
) {
if log::log_enabled!(target: "rewriter", log::Level::Trace) {
let name = _op.name();
log::trace!(
target: "rewriter",
dialect = name.dialect().as_str(),
op = name.name().as_str(),
rewrite_event = "replaced";
"replaced op with {}: {_op}",
DisplayValues::new(_replacement.iter().map(|v| {
DisplayOptional(v.as_ref())
}))
);
}
}
fn notify_operation_erased(&self, _op: OperationRef) {
if log::log_enabled!(target: "rewriter", log::Level::Trace) {
let name = _op.name();
log::trace!(
target: "rewriter",
dialect = name.dialect().as_str(),
op = name.name().as_str(),
rewrite_event = "erased";
"erased op {_op}"
);
}
}
fn notify_pattern_begin(&self, _pattern: &dyn Pattern, _op: OperationRef) {
if log::log_enabled!(target: "rewriter", log::Level::Trace) {
let name = _op.name();
log::trace!(
target: "rewriter",
dialect = name.dialect().as_str(),
op = name.name().as_str(),
pattern = _pattern.name(),
rewrite_event = "pattern-begin";
"attempting pattern against {_op}"
);
}
}
fn notify_pattern_end(&self, _pattern: &dyn Pattern, _success: bool) {
if log::log_enabled!(target: "rewriter", log::Level::Trace) {
let outcome = if _success {
"matched successfully"
} else {
"failed"
};
log::trace!(
target: "rewriter",
pattern = _pattern.name(),
rewrite_event = "pattern-end";
"pattern {outcome}"
);
}
}
fn notify_match_failure(&self, _span: SourceSpan, _reason: Report) {
if log::log_enabled!(target: "rewriter", log::Level::Error) {
let diag = PrintDiagnostic::new(_reason);
log::error!(
target: "rewriter",
rewrite_event = "match-failure";
"match failed with: {diag}",
);
}
}
}
/// A pair of listeners that both receive every event, `base` first and `derived` second.
pub struct ForwardingListener<Base, Derived> {
    /// Receives each event first.
    base: Base,
    /// Receives each event second, and determines the reported listener kind.
    derived: Derived,
}

impl<Base, Derived> ForwardingListener<Base, Derived> {
    /// Combine `base` and `derived` into a single listener.
    pub fn new(base: Base, derived: Derived) -> Self {
        ForwardingListener { base, derived }
    }
}
impl<Base: Listener, Derived: Listener> Listener for ForwardingListener<Base, Derived> {
    /// The kind of the combined listener is the kind of the `derived` listener.
    fn kind(&self) -> ListenerType {
        let Self { derived, .. } = self;
        derived.kind()
    }

    fn notify_block_inserted(
        &self,
        block: BlockRef,
        prev: Option<RegionRef>,
        ip: Option<BlockRef>,
    ) {
        let Self { base, derived } = self;
        base.notify_block_inserted(block, prev, ip);
        derived.notify_block_inserted(block, prev, ip);
    }

    fn notify_operation_inserted(&self, op: OperationRef, prev: ProgramPoint) {
        let Self { base, derived } = self;
        base.notify_operation_inserted(op, prev);
        derived.notify_operation_inserted(op, prev);
    }
}
/// Every event is delivered to `base` first and to `derived` second.
impl<Base: RewriterListener, Derived: RewriterListener> RewriterListener
    for ForwardingListener<Base, Derived>
{
    fn notify_block_erased(&self, block: BlockRef) {
        let Self { base, derived } = self;
        base.notify_block_erased(block);
        derived.notify_block_erased(block);
    }

    fn notify_operation_modification_started(&self, op: &OperationRef) {
        let Self { base, derived } = self;
        base.notify_operation_modification_started(op);
        derived.notify_operation_modification_started(op);
    }

    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
        let Self { base, derived } = self;
        base.notify_operation_modification_canceled(op);
        derived.notify_operation_modification_canceled(op);
    }

    fn notify_operation_modified(&self, op: OperationRef) {
        let Self { base, derived } = self;
        base.notify_operation_modified(op);
        derived.notify_operation_modified(op);
    }

    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
        let Self { base, derived } = self;
        base.notify_operation_replaced(op, replacement);
        derived.notify_operation_replaced(op, replacement);
    }

    fn notify_operation_replaced_with_values(
        &self,
        op: OperationRef,
        replacement: &[Option<ValueRef>],
    ) {
        let Self { base, derived } = self;
        base.notify_operation_replaced_with_values(op, replacement);
        derived.notify_operation_replaced_with_values(op, replacement);
    }

    fn notify_operation_erased(&self, op: OperationRef) {
        let Self { base, derived } = self;
        base.notify_operation_erased(op);
        derived.notify_operation_erased(op);
    }

    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
        let Self { base, derived } = self;
        base.notify_pattern_begin(pattern, op);
        derived.notify_pattern_begin(pattern, op);
    }

    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
        let Self { base, derived } = self;
        base.notify_pattern_end(pattern, success);
        derived.notify_pattern_end(pattern, success);
    }

    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
        let Self { base, derived } = self;
        // `reason` is consumed by `base`, so `derived` gets a fresh report built from its
        // rendered message.
        let reason_for_derived = Report::msg(format!("{reason}"));
        base.notify_match_failure(span, reason);
        derived.notify_match_failure(span, reason_for_derived);
    }
}
/// A guard that brackets an in-place modification of an operation.
///
/// Creating the guard notifies the rewriter that the modification started. Dropping it
/// notifies the rewriter that the operation was modified, unless the guard was canceled, in
/// which case the cancellation is reported instead.
pub struct InPlaceModificationGuard<'a, R: ?Sized + Rewriter> {
    /// The rewriter that receives the notifications.
    rewriter: &'a mut R,
    /// The operation being modified.
    op: OperationRef,
    /// Whether the modification was canceled.
    canceled: bool,
}

impl<'a, R> InPlaceModificationGuard<'a, R>
where
    R: ?Sized + Rewriter,
{
    /// Start a modification of `op`, notifying `rewriter` right away.
    pub fn new(rewriter: &'a mut R, op: OperationRef) -> Self {
        rewriter.notify_operation_modification_started(&op);
        InPlaceModificationGuard {
            rewriter,
            op,
            canceled: false,
        }
    }

    /// The rewriter this guard was created from.
    #[inline]
    pub fn rewriter(&mut self) -> &mut R {
        &mut *self.rewriter
    }

    /// The operation being modified.
    #[inline]
    pub fn op(&self) -> &OperationRef {
        &self.op
    }

    /// Abandon the modification; the rewriter is told that it was canceled.
    pub fn cancel(mut self) {
        self.canceled = true;
        // The notification is sent by `Drop`.
        drop(self);
    }

    /// Complete the modification; the rewriter is told that the operation was modified.
    pub fn finalize(self) {
        // The notification is sent by `Drop`.
        drop(self);
    }
}
/// The guard dereferences to the rewriter it wraps.
impl<R: ?Sized + Rewriter> core::ops::Deref for InPlaceModificationGuard<'_, R> {
    type Target = R;

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        &*self.rewriter
    }
}

impl<R: ?Sized + Rewriter> core::ops::DerefMut for InPlaceModificationGuard<'_, R> {
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.rewriter
    }
}
/// Dropping the guard reports the outcome of the modification to the rewriter.
impl<R: ?Sized + Rewriter> Drop for InPlaceModificationGuard<'_, R> {
    fn drop(&mut self) {
        if self.canceled {
            self.rewriter.notify_operation_modification_canceled(&self.op);
            return;
        }
        self.rewriter.notify_operation_modified(self.op);
    }
}
/// A rewriter wrapper around [`RewriterImpl`] that additionally records whether a failed
/// rewrite can be recovered from.
pub struct PatternRewriter<L = NoopRewriterListener> {
    /// The wrapped rewriter; exposed through `Deref`/`DerefMut`.
    rewriter: RewriterImpl<L>,
    /// Reported by `can_recover_from_rewrite_failure`; every constructor sets it to `false`.
    recoverable: bool,
}

impl PatternRewriter {
    /// Create a rewriter for `context` without a listener.
    pub fn new(context: Rc<Context>) -> Self {
        Self {
            rewriter: RewriterImpl::new(context),
            recoverable: false,
        }
    }

    /// Create a rewriter from `builder`, keeping its context and insertion point.
    ///
    /// The listener of `builder`, if any, is discarded.
    pub fn from_builder(builder: OpBuilder) -> Self {
        let (context, _, ip) = builder.into_parts();
        let mut inner = RewriterImpl::new(context);
        inner.restore_insertion_point(ip);
        Self {
            rewriter: inner,
            recoverable: false,
        }
    }
}

impl<L: RewriterListener> PatternRewriter<L> {
    /// Create a rewriter for `context` with `listener` attached.
    pub fn new_with_listener(context: Rc<Context>, listener: L) -> Self {
        let inner = RewriterImpl::<NoopRewriterListener>::new(context).with_listener(listener);
        Self {
            rewriter: inner,
            recoverable: false,
        }
    }

    /// Whether a rewrite failure can be recovered from.
    #[inline]
    pub const fn can_recover_from_rewrite_failure(&self) -> bool {
        self.recoverable
    }
}

impl<L> Deref for PatternRewriter<L> {
    type Target = RewriterImpl<L>;

    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        let Self { rewriter, .. } = self;
        rewriter
    }
}

impl<L> DerefMut for PatternRewriter<L> {
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { rewriter, .. } = self;
        rewriter
    }
}
/// The basic rewriter: a context, an optional listener and the current insertion point.
pub struct RewriterImpl<L = NoopRewriterListener> {
    /// The context this rewriter operates in.
    context: Rc<Context>,
    /// The attached listener, if any.
    listener: Option<L>,
    /// The current insertion point.
    ip: ProgramPoint,
}

impl<L> RewriterImpl<L> {
    /// Create a rewriter for `context` with no listener and a default insertion point.
    pub fn new(context: Rc<Context>) -> Self {
        RewriterImpl {
            context,
            listener: None,
            ip: ProgramPoint::default(),
        }
    }

    /// Attach `listener`, replacing any previous listener and changing the listener type.
    ///
    /// The context and insertion point are carried over.
    pub fn with_listener<L2>(self, listener: L2) -> RewriterImpl<L2>
    where
        L2: Listener,
    {
        let Self { context, ip, .. } = self;
        RewriterImpl {
            context,
            listener: Some(listener),
            ip,
        }
    }
}
impl<L: RewriterListener> From<OpBuilder<L>> for RewriterImpl<L> {
#[inline]
fn from(builder: OpBuilder<L>) -> Self {
let (context, listener, ip) = builder.into_parts();
Self {
context,
listener,
ip,
}
}
}
impl<L: Listener> Builder for RewriterImpl<L> {
#[inline(always)]
fn context(&self) -> &Context {
&self.context
}
#[inline(always)]
fn context_rc(&self) -> Rc<Context> {
self.context.clone()
}
#[inline(always)]
fn insertion_point(&self) -> &ProgramPoint {
&self.ip
}
#[inline(always)]
fn clear_insertion_point(&mut self) -> ProgramPoint {
let ip = self.ip;
self.ip = ProgramPoint::Invalid;
ip
}
#[inline(always)]
fn restore_insertion_point(&mut self, ip: ProgramPoint) {
self.ip = ip;
}
#[inline(always)]
fn set_insertion_point(&mut self, ip: ProgramPoint) {
self.ip = ip;
}
}
/// `RewriterImpl` relies on the default `Rewriter` methods; it only reports whether a
/// listener is attached.
impl<L: RewriterListener> Rewriter for RewriterImpl<L> {
    #[inline(always)]
    fn has_listener(&self) -> bool {
        let Self { listener, .. } = self;
        listener.is_some()
    }
}
/// Insertion events are forwarded to the attached listener, if there is one.
impl<L: Listener> Listener for RewriterImpl<L> {
    fn kind(&self) -> ListenerType {
        ListenerType::Rewriter
    }

    fn notify_operation_inserted(&self, op: OperationRef, prev: ProgramPoint) {
        if let Some(listener) = &self.listener {
            listener.notify_operation_inserted(op, prev);
        }
    }

    fn notify_block_inserted(
        &self,
        block: BlockRef,
        prev: Option<RegionRef>,
        ip: Option<BlockRef>,
    ) {
        if let Some(listener) = &self.listener {
            listener.notify_block_inserted(block, prev, ip);
        }
    }
}
/// Rewrite events are forwarded to the attached listener, if there is one.
impl<L: RewriterListener> RewriterListener for RewriterImpl<L> {
    fn notify_block_erased(&self, block: BlockRef) {
        if let Some(listener) = &self.listener {
            listener.notify_block_erased(block);
        }
    }

    fn notify_operation_modification_started(&self, op: &OperationRef) {
        if let Some(listener) = &self.listener {
            listener.notify_operation_modification_started(op);
        }
    }

    fn notify_operation_modification_canceled(&self, op: &OperationRef) {
        if let Some(listener) = &self.listener {
            listener.notify_operation_modification_canceled(op);
        }
    }

    fn notify_operation_modified(&self, op: OperationRef) {
        if let Some(listener) = &self.listener {
            listener.notify_operation_modified(op);
        }
    }

    /// With a listener attached, the results of `replacement` are collected and reported
    /// through `notify_operation_replaced_with_values`; the listener's own
    /// `notify_operation_replaced` is not invoked from here.
    fn notify_operation_replaced(&self, op: OperationRef, replacement: OperationRef) {
        // Skip collecting the results when nobody is listening.
        if self.listener.is_none() {
            return;
        }
        let replacement = replacement.borrow();
        let values: SmallVec<[Option<ValueRef>; 2]> = replacement
            .results()
            .all()
            .iter()
            .cloned()
            .map(|result| Some(result.upcast()))
            .collect();
        self.notify_operation_replaced_with_values(op, &values);
    }

    fn notify_operation_replaced_with_values(
        &self,
        op: OperationRef,
        replacement: &[Option<ValueRef>],
    ) {
        if let Some(listener) = &self.listener {
            listener.notify_operation_replaced_with_values(op, replacement);
        }
    }

    fn notify_operation_erased(&self, op: OperationRef) {
        if let Some(listener) = &self.listener {
            listener.notify_operation_erased(op);
        }
    }

    fn notify_pattern_begin(&self, pattern: &dyn Pattern, op: OperationRef) {
        if let Some(listener) = &self.listener {
            listener.notify_pattern_begin(pattern, op);
        }
    }

    fn notify_pattern_end(&self, pattern: &dyn Pattern, success: bool) {
        if let Some(listener) = &self.listener {
            listener.notify_pattern_end(pattern, success);
        }
    }

    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
        if let Some(listener) = &self.listener {
            listener.notify_match_failure(span, reason);
        }
    }
}