use alloc::rc::Rc;
use midenc_hir::{
patterns::{Pattern, PatternBenefit, PatternInfo, PatternKind, RewritePattern},
*,
};
use crate::*;
/// Rewrite pattern rooted on `scf.while` that drops arguments of the loop's `before` block
/// which keep their initial value on every iteration.
pub struct RemoveLoopInvariantArgsFromBeforeBlock {
    /// Pattern metadata: name, root operation kind, and benefit.
    info: PatternInfo,
}

impl RemoveLoopInvariantArgsFromBeforeBlock {
    /// Builds the pattern with [`PatternBenefit::MAX`], rooted on the `scf.while` operation.
    ///
    /// The SCF dialect is looked up in `context`, and registered there first if it is absent.
    ///
    /// # Panics
    ///
    /// Panics if the SCF dialect has no registered `scf.while` operation.
    pub fn new(context: Rc<Context>) -> Self {
        let scf = context.get_or_register_dialect::<ScfDialect>();
        let root = scf.registered_name::<While>().expect("scf.while is not registered");
        let kind = PatternKind::Operation(root);
        let info = PatternInfo::new(
            context,
            "remove-loop-invariant-args-from-before-block",
            kind,
            PatternBenefit::MAX,
        );
        Self { info }
    }
}

impl Pattern for RemoveLoopInvariantArgsFromBeforeBlock {
    /// Returns the metadata supplied when this pattern was constructed.
    fn info(&self) -> &PatternInfo {
        let Self { info } = self;
        info
    }
}
impl RewritePattern for RemoveLoopInvariantArgsFromBeforeBlock {
    /// Removes loop-invariant arguments from the `before` block of an `scf.while`.
    ///
    /// The `i`-th `before` block argument is considered loop-invariant when either:
    ///
    /// 1. the `i`-th operand of the `after` block's `scf.yield` is the loop's `i`-th init value;
    ///    or
    /// 2. the `i`-th yielded value is the `k`-th argument *of the `after` block*, and the `k`-th
    ///    value forwarded by `scf.condition` is either the `i`-th `before` block argument or the
    ///    `i`-th init value.
    ///
    /// In both cases the argument can only ever hold its init value, so its uses are replaced
    /// by that value. The remaining (non-invariant) arguments are kept as follows:
    ///
    /// - the `scf.yield` is rebuilt to yield only the remaining values;
    /// - a new `scf.while` is created with the remaining init values and the old result types;
    /// - the old `before` block is merged into a new block with one argument per remaining value;
    /// - the old `after` region is inlined into the new loop;
    /// - the old op is replaced by the new op's results.
    ///
    /// Returns `Ok(false)` without modifying the IR when `operation` is not an `scf.while`, or
    /// when none of its `before` block arguments is loop-invariant.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the replacement `scf.yield` or `scf.while`.
    fn match_and_rewrite(
        &self,
        operation: OperationRef,
        rewriter: &mut dyn Rewriter,
    ) -> Result<bool, Report> {
        let op = operation.borrow();
        let Some(while_op) = op.downcast_ref::<While>() else {
            return Ok(false);
        };

        let before_block = while_op.before().entry_block_ref().unwrap();
        let after_block = while_op.after().entry_block_ref().unwrap();
        let before_args = before_block
            .borrow()
            .arguments()
            .iter()
            .map(|arg| arg.borrow().as_value_ref())
            .collect::<SmallVec<[_; 4]>>();
        let cond_op = while_op.condition_op();
        let cond_op_args = cond_op
            .borrow()
            .forwarded()
            .into_iter()
            .map(|o| o.borrow().as_value_ref())
            .collect::<SmallVec<[_; 4]>>();
        let yield_op = while_op.yield_op();
        let yield_op_args = yield_op
            .borrow()
            .yielded()
            .into_iter()
            .map(|o| o.borrow().as_value_ref())
            .collect::<SmallVec<[_; 4]>>();

        // Decides whether the `index`-th `before` block argument is loop-invariant, given the
        // loop's `index`-th init value and the `index`-th operand of the `after` block's yield.
        let is_loop_invariant = |index: usize, init_value: ValueRef, yield_arg: ValueRef| -> bool {
            // Case 1: the init value itself is yielded back into the same position.
            if yield_arg == init_value {
                return true;
            }
            // Case 2: an `after` block argument is yielded back; inspect what `scf.condition`
            // forwards into that argument.
            let Ok(block_arg) = yield_arg.try_downcast_value::<BlockArgument>() else {
                return false;
            };
            let block_arg = block_arg.borrow();
            // The yield may also forward a block argument of an enclosing region (the loop is
            // not isolated from above). Such an argument's index has no relation to the
            // operands of `scf.condition`. Indexing with it would panic or spuriously match,
            // so only arguments of the `after` block are considered.
            if block_arg.owner() != after_block {
                return false;
            }
            cond_op_args.get(block_arg.index()).is_some_and(|&cond_op_arg| {
                cond_op_arg == before_args[index] || cond_op_arg == init_value
            })
        };

        // For every `before` block argument: `Some(init)` if it is loop-invariant (and will be
        // replaced by `init`), `None` if it survives as an iteration argument.
        let mut before_block_init_val_map = SmallVec::<[Option<ValueRef>; 8]>::default();
        before_block_init_val_map.resize(yield_op_args.len(), None);
        let mut new_init_args = SmallVec::<[ValueRef; 4]>::default();
        let mut new_yield_args = SmallVec::<[ValueRef; 4]>::default();
        for (index, (init_value, yield_arg)) in while_op
            .inits()
            .into_iter()
            .map(|o| o.borrow().as_value_ref())
            .zip(yield_op_args.iter().copied())
            .enumerate()
        {
            if is_loop_invariant(index, init_value, yield_arg) {
                before_block_init_val_map[index] = Some(init_value);
            } else {
                new_init_args.push(init_value);
                new_yield_args.push(yield_arg);
            }
        }

        // Nothing to remove: leave the IR untouched.
        if before_block_init_val_map.iter().all(Option::is_none) {
            return Ok(false);
        }

        // Rewrite the `after` block terminator so it only yields the surviving arguments.
        {
            let mut guard = InsertionGuard::new(rewriter);
            let yield_op = yield_op.as_operation_ref();
            guard.set_insertion_point_before(yield_op);
            let new_yield = guard.r#yield(new_yield_args.iter().copied(), yield_op.span())?;
            guard.replace_op(yield_op, new_yield.as_operation_ref());
        }

        // The loop results are produced by `scf.condition`, which is unchanged, so the new loop
        // keeps the old result types. Only the `before` block signature shrinks.
        let mut result_types = while_op
            .results()
            .iter()
            .map(|r| r.borrow().ty().clone())
            .collect::<SmallVec<[_; 4]>>();
        let new_while =
            rewriter.r#while(new_init_args.iter().copied(), &result_types, while_op.span())?;
        let new_before_region = new_while.borrow().before().as_region_ref();
        result_types.clear();
        result_types.extend(new_yield_args.iter().map(|arg| arg.borrow().ty().clone()));
        let new_before_block = rewriter.create_block(new_before_region, None, &result_types);

        // Replacement values for the old `before` block arguments, in order. Invariant
        // arguments map to their init value; the others take the new block's arguments in order.
        let num_before_block_args = before_block.borrow().num_arguments();
        let mut new_before_block_args =
            SmallVec::<[Option<ValueRef>; 4]>::with_capacity(num_before_block_args);
        {
            let new_before_block = new_before_block.borrow();
            let mut fresh_args = new_before_block.arguments().iter().map(|arg| *arg as ValueRef);
            for &init_value in &before_block_init_val_map[..num_before_block_args] {
                new_before_block_args.push(init_value.or_else(|| fresh_args.next()));
            }
        }

        let after_region = while_op.after().as_region_ref();
        // Release the borrow of the old op before the rewriter mutates or erases it.
        drop(op);
        rewriter.merge_blocks(before_block, new_before_block, &new_before_block_args);
        rewriter.inline_region_before(after_region, new_while.borrow().after().as_region_ref());
        let replacements = new_while
            .borrow()
            .results()
            .all()
            .into_iter()
            .map(|r| Some(*r as ValueRef))
            .collect::<SmallVec<[_; 4]>>();
        rewriter.replace_op_with_values(operation, &replacements);
        Ok(true)
    }
}