use std::vec;
use crate::{
basic_block::BasicBlock,
common_traits::Named,
context::{Context, Ptr},
graph::traversals::region::post_order,
identifier::{Identifier, underscore},
irbuild::{
inserter::{BlockInsertionPoint, IRInserter, Inserter, OpInsertionPoint},
listener::RewriteListener,
},
linked_list::{ContainsLinkedList, LinkedList},
location::Located,
op::Op,
operation::Operation,
region::Region,
r#type::{TypeObj, Typed},
value::Value,
};
/// Interface for mutating IR (operations, blocks, regions and values), layered on
/// top of the insertion capabilities of [`Inserter`].
///
/// Besides the IR mutation methods, a rewriter exposes a "modified" flag and a
/// cloneable configuration object.
pub trait Rewriter<L: RewriteListener>: Inserter<L> {
/// Configuration type of this rewriter. It must be cloneable.
type RewriterConfig: Clone;
/// Replace `op` with `new_op`.
///
/// In [`IRRewriter`], uses of `op`'s results are redirected to the results of
/// `new_op` and `op` is then erased.
fn replace_operation(&mut self, ctx: &mut Context, op: Ptr<Operation>, new_op: Ptr<Operation>);
/// Replace `op` with the values in `new_values`.
///
/// [`IRRewriter`] requires `new_values` to contain exactly one value per result
/// of `op` and erases `op` afterwards.
fn replace_operation_with_values(
&mut self,
ctx: &mut Context,
op: Ptr<Operation>,
new_values: Vec<Value>,
);
/// Redirect the uses of `old_value` to `new_value`.
fn replace_value_uses_with(&mut self, ctx: &Context, old_value: Value, new_value: Value);
/// Erase `op`.
fn erase_operation(&mut self, ctx: &mut Context, op: Ptr<Operation>);
/// Erase `block`.
fn erase_block(&mut self, ctx: &mut Context, block: Ptr<BasicBlock>);
/// Erase `region`.
fn erase_region(&mut self, ctx: &mut Context, region: Ptr<Region>);
/// Unlink `op` from where it currently sits, without erasing it.
fn unlink_operation(&mut self, ctx: &Context, op: Ptr<Operation>);
/// Unlink `block` from where it currently sits, without erasing it.
fn unlink_block(&mut self, ctx: &Context, block: Ptr<BasicBlock>);
/// Move `op` to `new_point`.
fn move_operation(&mut self, ctx: &Context, op: Ptr<Operation>, new_point: OpInsertionPoint);
/// Move `block` to `new_point`.
fn move_block(&mut self, ctx: &Context, block: Ptr<BasicBlock>, new_point: BlockInsertionPoint);
/// Split `block` at `position` and return the newly created block.
fn split_block(
&mut self,
ctx: &mut Context,
block: Ptr<BasicBlock>,
position: OpInsertionPoint,
) -> Ptr<BasicBlock>;
/// Move the blocks of `src_region` to `dest_insertion_point`.
fn inline_region(
&mut self,
ctx: &Context,
src_region: Ptr<Region>,
dest_insertion_point: BlockInsertionPoint,
);
/// Change the type of `value` to `new_type`.
fn set_value_type(&mut self, ctx: &Context, value: Value, new_type: Ptr<TypeObj>);
/// Report the current state of the "modified" flag.
fn is_modified(&self) -> bool;
/// Raise the "modified" flag.
fn set_modified(&mut self);
/// Reset the "modified" flag.
fn clear_modified(&mut self);
/// Shared access to the rewriter's configuration.
fn get_config(&self) -> &Self::RewriterConfig;
/// Mutable access to the rewriter's configuration.
fn get_config_mut(&mut self) -> &mut Self::RewriterConfig;
}
/// A [`Rewriter`] built around an [`Inserter`] `I` ([`IRInserter`] unless stated
/// otherwise). It records whether a rewrite took place and carries an
/// [`IRRewriterConfig`].
pub struct IRRewriter<L: RewriteListener, I: Inserter<L> = IRInserter<L>> {
/// Inserter that every [`Inserter`] call on this rewriter is forwarded to.
inserter: I,
/// The "modified" flag: raised by `set_modified`, reset by `clear_modified`.
modified: bool,
/// Options consulted by the rewrite methods.
config: IRRewriterConfig,
/// Ties the listener type `L` to the struct; no `L` value is stored here.
_phantom: std::marker::PhantomData<L>,
}
/// An [`IRRewriter`] can be default-constructed whenever its inserter can.
impl<L, I> Default for IRRewriter<L, I>
where
    L: RewriteListener,
    I: Inserter<L> + Default,
{
    /// Start from a default inserter, a default configuration and a cleared
    /// "modified" flag.
    fn default() -> Self {
        let inserter = I::default();
        let config = IRRewriterConfig::default();
        Self {
            inserter,
            modified: false,
            config,
            _phantom: std::marker::PhantomData,
        }
    }
}
/// Insertion is not implemented here: each method hands its arguments straight to
/// the wrapped inserter and returns whatever that inserter returns.
impl<L: RewriteListener, I: Inserter<L>> Inserter<L> for IRRewriter<L, I> {
    /// Forwarded to the wrapped inserter.
    fn append_operation(&mut self, ctx: &Context, operation: Ptr<Operation>) {
        I::append_operation(&mut self.inserter, ctx, operation)
    }

    /// Forwarded to the wrapped inserter.
    fn append_op(&mut self, ctx: &Context, op: impl Op) {
        I::append_op(&mut self.inserter, ctx, op)
    }

    /// Forwarded to the wrapped inserter.
    fn insert_operation(&mut self, ctx: &Context, operation: Ptr<Operation>) {
        I::insert_operation(&mut self.inserter, ctx, operation)
    }

    /// Forwarded to the wrapped inserter.
    fn insert_op(&mut self, ctx: &Context, op: impl Op) {
        I::insert_op(&mut self.inserter, ctx, op)
    }

    /// Forwarded to the wrapped inserter.
    fn insert_block(
        &mut self,
        ctx: &Context,
        insertion_point: BlockInsertionPoint,
        block: Ptr<BasicBlock>,
    ) {
        I::insert_block(&mut self.inserter, ctx, insertion_point, block)
    }

    /// Forwarded to the wrapped inserter; yields the block it created.
    fn create_block(
        &mut self,
        ctx: &mut Context,
        insertion_point: BlockInsertionPoint,
        label: Option<Identifier>,
        arg_types: Vec<Ptr<TypeObj>>,
    ) -> Ptr<BasicBlock> {
        I::create_block(&mut self.inserter, ctx, insertion_point, label, arg_types)
    }

    /// Insertion point currently held by the wrapped inserter.
    fn get_insertion_point(&self) -> OpInsertionPoint {
        I::get_insertion_point(&self.inserter)
    }

    /// Insertion block as reported by the wrapped inserter.
    fn get_insertion_block(&self, ctx: &Context) -> Option<Ptr<BasicBlock>> {
        I::get_insertion_block(&self.inserter, ctx)
    }

    /// Forwarded to the wrapped inserter.
    fn set_insertion_point(&mut self, point: OpInsertionPoint) {
        I::set_insertion_point(&mut self.inserter, point)
    }

    /// Forwarded to the wrapped inserter.
    fn set_listener(&mut self, listener: L) {
        I::set_listener(&mut self.inserter, listener)
    }

    /// Listener held by the wrapped inserter.
    fn get_listener(&self) -> &L {
        I::get_listener(&self.inserter)
    }

    /// Mutable access to the listener held by the wrapped inserter.
    fn get_listener_mut(&mut self) -> &mut L {
        I::get_listener_mut(&mut self.inserter)
    }
}
/// Options controlling the behaviour of [`IRRewriter`].
#[derive(Clone)]
pub struct IRRewriterConfig {
/// When `true`, `IRRewriter::replace_operation` copies the location of the
/// replaced operation onto the replacement operation (unless both are the same
/// operation).
pub set_loc_on_operation_replacement: bool,
}
impl Default for IRRewriterConfig {
fn default() -> Self {
Self {
set_loc_on_operation_replacement: true,
}
}
}
/// [`Rewriter`] implementation for [`IRRewriter`].
///
/// The mutating methods below notify the listener (reached through
/// `get_listener_mut`) before touching the IR, and raise the "modified" flag
/// once they have acted.
impl<L: RewriteListener, I: Inserter<L>> Rewriter<L> for IRRewriter<L, I> {
    type RewriterConfig = IRRewriterConfig;

    /// Replace `op` with `new_op`: the results of `new_op` take over the uses of
    /// the results of `op`, and `op` is erased.
    ///
    /// When `set_loc_on_operation_replacement` is enabled and the two operations
    /// differ, `new_op` first receives the location of `op`.
    ///
    /// # Panics
    /// Panics if the two operations do not have the same number of results
    /// (see [`Rewriter::replace_operation_with_values`]).
    fn replace_operation(&mut self, ctx: &mut Context, op: Ptr<Operation>, new_op: Ptr<Operation>) {
        if op != new_op && self.config.set_loc_on_operation_replacement {
            new_op.deref_mut(ctx).set_loc(op.deref(ctx).loc());
        }
        let new_values = new_op.deref(ctx).results().collect();
        self.replace_operation_with_values(ctx, op, new_values);
    }

    /// Replace every result of `op` with the value at the same position in
    /// `new_values`, then erase `op`.
    ///
    /// The listener is told about each (old, new) pair before the uses are
    /// redirected.
    ///
    /// # Panics
    /// Panics if `new_values.len()` differs from the number of results of `op`.
    fn replace_operation_with_values(
        &mut self,
        ctx: &mut Context,
        op: Ptr<Operation>,
        new_values: Vec<Value>,
    ) {
        assert!(
            op.deref(ctx).get_num_results() == new_values.len(),
            "Replacement values must match the number of results of the original operation."
        );
        // Collect the pairs up front so that no borrow of `op` is held while the
        // listener runs and the uses are rewritten.
        let results: Vec<_> = op.deref(ctx).results().zip(new_values).collect();
        for (res, new_res) in results {
            self.get_listener_mut()
                .notify_value_use_replacement(ctx, res, new_res);
            res.replace_all_uses_with(ctx, &new_res);
        }
        self.erase_operation(ctx, op);
        self.set_modified();
    }

    /// Redirect all uses of `old_value` to `new_value`.
    ///
    /// Does nothing (no notification, flag untouched) when both values are equal.
    fn replace_value_uses_with(&mut self, ctx: &Context, old_value: Value, new_value: Value) {
        if old_value == new_value {
            return;
        }
        self.get_listener_mut()
            .notify_value_use_replacement(ctx, old_value, new_value);
        old_value.replace_all_uses_with(ctx, &new_value);
        self.set_modified();
    }

    /// Erase `op`, after erasing the regions it contains (last region first).
    fn erase_operation(&mut self, ctx: &mut Context, op: Ptr<Operation>) {
        // Snapshot the regions: erasing them mutates `op` while we iterate.
        let regions = op.deref(ctx).regions().collect::<Vec<_>>();
        for region in regions.into_iter().rev() {
            self.erase_region(ctx, region);
        }
        self.get_listener_mut().notify_operation_erasure(ctx, op);
        Operation::erase(op, ctx);
        self.set_modified();
    }

    /// Erase `block`, after erasing the operations it contains (last operation
    /// first).
    fn erase_block(&mut self, ctx: &mut Context, block: Ptr<BasicBlock>) {
        // Snapshot the operations: erasing them mutates the block's list.
        let operations = block.deref(ctx).iter(ctx).collect::<Vec<_>>();
        for op in operations.into_iter().rev() {
            self.erase_operation(ctx, op);
        }
        self.get_listener_mut().notify_block_erasure(ctx, block);
        BasicBlock::erase(block, ctx);
        self.set_modified();
    }

    /// Erase `region` from its parent operation.
    ///
    /// Works on the block list returned by `post_order`: first the operations of
    /// those blocks are erased (blocks taken in reverse of that list, operations
    /// last-to-first within each block), then the blocks themselves, and finally
    /// the region is removed from its parent operation.
    fn erase_region(&mut self, ctx: &mut Context, region: Ptr<Region>) {
        let blocks = post_order(ctx, &region);
        // Pass 1: empty every block before any block is erased.
        for block in blocks.iter().rev() {
            let operations = block.deref(ctx).iter(ctx).collect::<Vec<_>>();
            for op in operations.into_iter().rev() {
                self.erase_operation(ctx, op);
            }
        }
        // Pass 2: erase the (now empty) blocks.
        for block in blocks {
            self.erase_block(ctx, block);
        }
        self.get_listener_mut().notify_region_erasure(ctx, region);
        let index_in_parent = region.deref(ctx).find_index_in_parent(ctx);
        let parent_op = region.deref(ctx).get_parent_op();
        Operation::erase_region(parent_op, ctx, index_in_parent);
        self.set_modified();
    }

    /// Unlink `op` from its current position; the operation itself is not erased.
    fn unlink_operation(&mut self, ctx: &Context, op: Ptr<Operation>) {
        self.get_listener_mut().notify_operation_unlinking(ctx, op);
        op.unlink(ctx);
        self.set_modified();
    }

    /// Unlink `block` from its current position; the block itself is not erased.
    fn unlink_block(&mut self, ctx: &Context, block: Ptr<BasicBlock>) {
        self.get_listener_mut().notify_block_unlinking(ctx, block);
        block.unlink(ctx);
        self.set_modified();
    }

    /// Move `op` to `new_point` by unlinking it and inserting it again.
    ///
    /// The insertion goes through a [`ScopedRewriter`], so this rewriter's own
    /// insertion point is put back once the move is done.
    fn move_operation(&mut self, ctx: &Context, op: Ptr<Operation>, new_point: OpInsertionPoint) {
        self.unlink_operation(ctx, op);
        ScopedRewriter::new(self, new_point).insert_operation(ctx, op);
    }

    /// Move `block` to `new_point` by unlinking it and inserting it again.
    fn move_block(
        &mut self,
        ctx: &Context,
        block: Ptr<BasicBlock>,
        new_point: BlockInsertionPoint,
    ) {
        self.unlink_block(ctx, block);
        self.insert_block(ctx, new_point, block);
    }

    /// Split `block` at `position` and return the new block.
    ///
    /// A block without arguments is created right after `block`. If `block` has
    /// a given name, the new block's label is that name followed by
    /// `underscore()` and `"split"`. Every operation from `position` onwards is
    /// then moved, in order, to the end of the new block.
    ///
    /// # Panics
    /// Panics if `position` is [`OpInsertionPoint::Unset`].
    fn split_block(
        &mut self,
        ctx: &mut Context,
        block: Ptr<BasicBlock>,
        position: OpInsertionPoint,
    ) -> Ptr<BasicBlock> {
        // Work through a scoped rewriter so the caller's insertion point is
        // restored when this function returns.
        let mut rewriter = ScopedRewriter::new(self, OpInsertionPoint::Unset);
        let label = block
            .deref(ctx)
            .given_name(ctx)
            .map(|label| label + underscore() + "split".try_into().unwrap());
        let new_block =
            rewriter.create_block(ctx, BlockInsertionPoint::AfterBlock(block), label, vec![]);
        // First operation that has to leave `block`, or `None` when there is
        // nothing to move.
        let first_op_opt = match position {
            OpInsertionPoint::AtBlockStart(target_block) => {
                target_block.deref(ctx).iter(ctx).next()
            }
            OpInsertionPoint::AtBlockEnd(_target_block) => None,
            OpInsertionPoint::BeforeOperation(op) => Some(op),
            OpInsertionPoint::AfterOperation(op) => op.deref(ctx).get_next(),
            OpInsertionPoint::Unset => panic!("Cannot split block at unset insertion point."),
        };
        let mut current_op_opt = first_op_opt;
        while let Some(current_op) = current_op_opt {
            // Read the successor before the move changes the operation's links.
            let next_op = current_op.deref(ctx).get_next();
            rewriter.move_operation(ctx, current_op, OpInsertionPoint::AtBlockEnd(new_block));
            current_op_opt = next_op;
        }
        new_block
    }

    /// Move all blocks of `src_region` to `dest_insertion_point`, keeping their
    /// relative order.
    ///
    /// # Panics
    /// Panics if `dest_insertion_point` is not inside a region, or if that region
    /// is `src_region` itself.
    fn inline_region(
        &mut self,
        ctx: &Context,
        src_region: Ptr<Region>,
        dest_insertion_point: BlockInsertionPoint,
    ) {
        assert!(
            src_region
                != dest_insertion_point
                    .get_insertion_region(ctx)
                    .expect("Insertion point itself is not in a Region"),
            "Cannot inline a region into itself."
        );
        // Snapshot the blocks: moving them mutates the source region's list.
        let blocks: Vec<_> = src_region.deref(ctx).iter(ctx).collect();
        let mut insertion_pt = dest_insertion_point;
        for block in blocks {
            self.move_block(ctx, block, insertion_pt);
            // Each subsequent block goes right after the one just moved.
            insertion_pt = BlockInsertionPoint::AfterBlock(block);
        }
    }

    /// Change the type of `value` to `new_type`.
    ///
    /// Does nothing (no notification, flag untouched) when the type is already
    /// `new_type`.
    fn set_value_type(&mut self, ctx: &Context, value: Value, new_type: Ptr<TypeObj>) {
        let old_type = value.get_type(ctx);
        if old_type == new_type {
            return;
        }
        self.get_listener_mut()
            .notify_value_type_change(ctx, value, old_type, new_type);
        value.set_type(ctx, new_type);
        self.set_modified();
    }

    /// Current state of the "modified" flag.
    fn is_modified(&self) -> bool {
        self.modified
    }

    /// Raise the "modified" flag.
    fn set_modified(&mut self) {
        self.modified = true;
    }

    /// Reset the "modified" flag.
    fn clear_modified(&mut self) {
        self.modified = false;
    }

    /// Shared access to the configuration.
    fn get_config(&self) -> &Self::RewriterConfig {
        &self.config
    }

    /// Mutable access to the configuration.
    fn get_config_mut(&mut self) -> &mut Self::RewriterConfig {
        &mut self.config
    }
}
/// A guard around a mutable borrow of a [`Rewriter`] that temporarily changes its
/// insertion point.
///
/// The insertion point and configuration captured in `ScopedRewriter::new` are
/// written back to the underlying rewriter when the guard is dropped.
pub struct ScopedRewriter<'a, L: RewriteListener, R: Rewriter<L>> {
/// The borrowed rewriter that all calls are forwarded to.
rewriter: &'a mut R,
/// Insertion point of `rewriter` at the time the guard was created.
prev_insertion_point: OpInsertionPoint,
/// Copy of the configuration of `rewriter` at the time the guard was created.
prev_config: R::RewriterConfig,
/// Ties the listener type `L` to the struct; no `L` value is stored here.
_phantom: std::marker::PhantomData<L>,
}
impl<'a, L: RewriteListener, R: Rewriter<L>> ScopedRewriter<'a, L, R> {
pub fn new(rewriter: &'a mut R, insertion_point: OpInsertionPoint) -> Self {
let prev_insertion_point = rewriter.get_insertion_point();
let prev_config = rewriter.get_config().clone();
rewriter.set_insertion_point(insertion_point);
Self {
rewriter,
prev_insertion_point,
prev_config,
_phantom: std::marker::PhantomData,
}
}
}
impl<'a, L: RewriteListener, R: Rewriter<L>> Drop for ScopedRewriter<'a, L, R> {
    /// Reinstate the insertion point and the configuration that the underlying
    /// rewriter had when this guard was created.
    fn drop(&mut self) {
        let restored_point = self.prev_insertion_point;
        self.rewriter.set_insertion_point(restored_point);

        let restored_config = self.prev_config.clone();
        *self.rewriter.get_config_mut() = restored_config;
    }
}
/// The guard adds no insertion logic of its own: each method hands its arguments
/// to the borrowed rewriter and returns whatever that rewriter returns.
impl<'a, L: RewriteListener, R: Rewriter<L>> Inserter<L> for ScopedRewriter<'a, L, R> {
    /// Forwarded to the borrowed rewriter.
    fn append_operation(&mut self, ctx: &Context, operation: Ptr<Operation>) {
        R::append_operation(&mut *self.rewriter, ctx, operation)
    }

    /// Forwarded to the borrowed rewriter.
    fn append_op(&mut self, ctx: &Context, op: impl Op) {
        R::append_op(&mut *self.rewriter, ctx, op)
    }

    /// Forwarded to the borrowed rewriter.
    fn insert_operation(&mut self, ctx: &Context, operation: Ptr<Operation>) {
        R::insert_operation(&mut *self.rewriter, ctx, operation)
    }

    /// Forwarded to the borrowed rewriter.
    fn insert_op(&mut self, ctx: &Context, op: impl Op) {
        R::insert_op(&mut *self.rewriter, ctx, op)
    }

    /// Forwarded to the borrowed rewriter.
    fn insert_block(
        &mut self,
        ctx: &Context,
        insertion_point: BlockInsertionPoint,
        block: Ptr<BasicBlock>,
    ) {
        R::insert_block(&mut *self.rewriter, ctx, insertion_point, block)
    }

    /// Forwarded to the borrowed rewriter; yields the block it created.
    fn create_block(
        &mut self,
        ctx: &mut Context,
        insertion_point: BlockInsertionPoint,
        label: Option<Identifier>,
        arg_types: Vec<Ptr<TypeObj>>,
    ) -> Ptr<BasicBlock> {
        R::create_block(&mut *self.rewriter, ctx, insertion_point, label, arg_types)
    }

    /// Insertion point currently held by the borrowed rewriter.
    fn get_insertion_point(&self) -> OpInsertionPoint {
        R::get_insertion_point(&*self.rewriter)
    }

    /// Insertion block as reported by the borrowed rewriter.
    fn get_insertion_block(&self, ctx: &Context) -> Option<Ptr<BasicBlock>> {
        R::get_insertion_block(&*self.rewriter, ctx)
    }

    /// Forwarded to the borrowed rewriter.
    fn set_insertion_point(&mut self, point: OpInsertionPoint) {
        R::set_insertion_point(&mut *self.rewriter, point)
    }

    /// Forwarded to the borrowed rewriter.
    fn set_listener(&mut self, listener: L) {
        R::set_listener(&mut *self.rewriter, listener)
    }

    /// Listener of the borrowed rewriter.
    fn get_listener(&self) -> &L {
        R::get_listener(&*self.rewriter)
    }

    /// Mutable access to the listener of the borrowed rewriter.
    fn get_listener_mut(&mut self) -> &mut L {
        R::get_listener_mut(&mut *self.rewriter)
    }
}
/// The guard adds no rewriting logic of its own: each method hands its arguments
/// to the borrowed rewriter and returns whatever that rewriter returns.
impl<'a, L: RewriteListener, R: Rewriter<L>> Rewriter<L> for ScopedRewriter<'a, L, R> {
    /// Same configuration type as the borrowed rewriter.
    type RewriterConfig = R::RewriterConfig;

    /// Forwarded to the borrowed rewriter.
    fn replace_operation(&mut self, ctx: &mut Context, op: Ptr<Operation>, new_op: Ptr<Operation>) {
        R::replace_operation(&mut *self.rewriter, ctx, op, new_op)
    }

    /// Forwarded to the borrowed rewriter.
    fn replace_operation_with_values(
        &mut self,
        ctx: &mut Context,
        op: Ptr<Operation>,
        new_values: Vec<Value>,
    ) {
        R::replace_operation_with_values(&mut *self.rewriter, ctx, op, new_values)
    }

    /// Forwarded to the borrowed rewriter.
    fn replace_value_uses_with(&mut self, ctx: &Context, old_value: Value, new_value: Value) {
        R::replace_value_uses_with(&mut *self.rewriter, ctx, old_value, new_value)
    }

    /// Forwarded to the borrowed rewriter.
    fn erase_operation(&mut self, ctx: &mut Context, op: Ptr<Operation>) {
        R::erase_operation(&mut *self.rewriter, ctx, op)
    }

    /// Forwarded to the borrowed rewriter.
    fn erase_block(&mut self, ctx: &mut Context, block: Ptr<BasicBlock>) {
        R::erase_block(&mut *self.rewriter, ctx, block)
    }

    /// Forwarded to the borrowed rewriter.
    fn erase_region(&mut self, ctx: &mut Context, region: Ptr<Region>) {
        R::erase_region(&mut *self.rewriter, ctx, region)
    }

    /// Forwarded to the borrowed rewriter.
    fn unlink_operation(&mut self, ctx: &Context, op: Ptr<Operation>) {
        R::unlink_operation(&mut *self.rewriter, ctx, op)
    }

    /// Forwarded to the borrowed rewriter.
    fn unlink_block(&mut self, ctx: &Context, block: Ptr<BasicBlock>) {
        R::unlink_block(&mut *self.rewriter, ctx, block)
    }

    /// Forwarded to the borrowed rewriter.
    fn move_operation(&mut self, ctx: &Context, op: Ptr<Operation>, new_point: OpInsertionPoint) {
        R::move_operation(&mut *self.rewriter, ctx, op, new_point)
    }

    /// Forwarded to the borrowed rewriter.
    fn move_block(
        &mut self,
        ctx: &Context,
        block: Ptr<BasicBlock>,
        new_point: BlockInsertionPoint,
    ) {
        R::move_block(&mut *self.rewriter, ctx, block, new_point)
    }

    /// Forwarded to the borrowed rewriter; yields the block it created.
    fn split_block(
        &mut self,
        ctx: &mut Context,
        block: Ptr<BasicBlock>,
        position: OpInsertionPoint,
    ) -> Ptr<BasicBlock> {
        R::split_block(&mut *self.rewriter, ctx, block, position)
    }

    /// Forwarded to the borrowed rewriter.
    fn inline_region(
        &mut self,
        ctx: &Context,
        src_region: Ptr<Region>,
        dest_insertion_point: BlockInsertionPoint,
    ) {
        R::inline_region(&mut *self.rewriter, ctx, src_region, dest_insertion_point)
    }

    /// Forwarded to the borrowed rewriter.
    fn set_value_type(&mut self, ctx: &Context, value: Value, new_type: Ptr<TypeObj>) {
        R::set_value_type(&mut *self.rewriter, ctx, value, new_type)
    }

    /// "Modified" flag of the borrowed rewriter.
    fn is_modified(&self) -> bool {
        R::is_modified(&*self.rewriter)
    }

    /// Raises the "modified" flag of the borrowed rewriter.
    fn set_modified(&mut self) {
        R::set_modified(&mut *self.rewriter)
    }

    /// Resets the "modified" flag of the borrowed rewriter.
    fn clear_modified(&mut self) {
        R::clear_modified(&mut *self.rewriter)
    }

    /// Configuration of the borrowed rewriter.
    fn get_config(&self) -> &Self::RewriterConfig {
        R::get_config(&*self.rewriter)
    }

    /// Mutable access to the configuration of the borrowed rewriter.
    fn get_config_mut(&mut self) -> &mut Self::RewriterConfig {
        R::get_config_mut(&mut *self.rewriter)
    }
}