use alloc::rc::Rc;
use midenc_hir::{
patterns::{Pattern, PatternBenefit, PatternInfo, PatternKind, RewritePattern},
*,
};
use crate::*;
/// Rewrite pattern anchored on `scf.while` that removes loop results which are unused both
/// after the loop and inside the loop's `after` region.
pub struct WhileUnusedResult {
    /// Metadata (name, anchor op, benefit) built by [`WhileUnusedResult::new`].
    info: PatternInfo,
}
impl WhileUnusedResult {
    /// Build the pattern, anchored on the `scf.while` operation.
    ///
    /// The SCF dialect is registered in `context` if it is not already present.
    ///
    /// # Panics
    ///
    /// Panics if the SCF dialect does not have `scf.while` registered.
    pub fn new(context: Rc<Context>) -> Self {
        let dialect = context.get_or_register_dialect::<ScfDialect>();
        let anchor = dialect.registered_name::<While>().expect("scf.while is not registered");
        let kind = PatternKind::Operation(anchor);
        let info = PatternInfo::new(context, "while-unused-result", kind, PatternBenefit::MAX);
        Self { info }
    }
}
impl Pattern for WhileUnusedResult {
    /// Return the `PatternInfo` constructed in [`WhileUnusedResult::new`].
    fn info(&self) -> &PatternInfo {
        let Self { info } = self;
        info
    }
}
impl RewritePattern for WhileUnusedResult {
    /// Remove `scf.while` results that are dead both outside and inside the loop.
    ///
    /// The `i`-th loop-carried value is dropped when two conditions hold:
    /// - the `i`-th result of the `scf.while` op has no uses, and
    /// - the `i`-th entry-block argument of the `after` region has no uses.
    ///
    /// For every dropped value, the matching operand forwarded by the `scf.condition`
    /// terminator is removed too. The loop is then rebuilt with the reduced result list, and
    /// the original `before`/`after` bodies are moved into it.
    ///
    /// Returns `Ok(false)` if `op` is not an `scf.while` or if no result can be dropped.
    /// Returns `Ok(true)` after `op` has been replaced. Errors from building the new
    /// `scf.condition` or `scf.while` ops are propagated.
    fn match_and_rewrite(
        &self,
        op: OperationRef,
        rewriter: &mut dyn Rewriter,
    ) -> Result<bool, Report> {
        let operation = op.borrow();
        let Some(while_op) = operation.downcast_ref::<While>() else {
            return Ok(false);
        };
        let condition_operation = while_op.condition_op();
        let span = while_op.span();

        // Collect owned value handles up front rather than keeping the `after` entry block and
        // the terminator borrowed for the rest of the rewrite.
        let after_args = {
            while_op
                .after()
                .entry()
                .arguments()
                .iter()
                .map(|arg| arg.borrow().as_value_ref())
                .collect::<SmallVec<[_; 4]>>()
        };
        let forwarded = {
            condition_operation
                .borrow()
                .forwarded()
                .iter()
                .map(|o| o.borrow().as_value_ref())
                .collect::<SmallVec<[_; 4]>>()
        };

        // Split the loop-carried values into kept and dropped. Only the kept ones contribute
        // to the new result types and the new `scf.condition` operands.
        let mut new_results_indices = SmallVec::<[usize; 4]>::default();
        let mut new_result_types = SmallVec::<[Type; 4]>::default();
        let mut new_term_args = SmallVec::<[ValueRef; 4]>::default();
        let mut need_update = false;
        for (i, result) in while_op.results().iter().enumerate() {
            let result = result.borrow();
            let after_arg = after_args[i];
            let term_arg = forwarded[i];
            if !result.is_used() && !after_arg.borrow().is_used() {
                need_update = true;
            } else {
                new_results_indices.push(i);
                new_term_args.push(term_arg);
                new_result_types.push(result.ty().clone());
            }
        }
        if !need_update {
            return Ok(false);
        }

        // Replace the `scf.condition` terminator with one that forwards only the kept values.
        {
            let mut guard = InsertionGuard::new(rewriter);
            let (cond_span, condition, condition_op) = {
                let cond_op = condition_operation.borrow();
                let condition = cond_op.condition().as_value_ref();
                (cond_op.span(), condition, cond_op.as_operation_ref())
            };
            guard.set_insertion_point_before(condition_op);
            let new_condition = guard.condition(condition, new_term_args, cond_span)?;
            let new_condition_op = new_condition.as_operation_ref();
            guard.replace_op(condition_op, new_condition_op);
        }

        // Build the replacement loop with the same inits but only the kept result types.
        let new_while = {
            let inits = while_op.inits().into_iter().map(|o| o.borrow().as_value_ref());
            rewriter.r#while(inits, &new_result_types, span)?
        };
        let new_after_region = { new_while.borrow().after().as_region_ref() };
        let new_after_block = new_after_region.borrow().entry_block_ref().unwrap();

        // Map each original index to its replacement value. Dropped indices stay `None`; that
        // is only sound because the loop above verified those values have no uses.
        let num_results = while_op.num_results();
        let mut new_results: SmallVec<[_; 4]> = smallvec![None; num_results];
        let mut new_after_block_args: SmallVec<[_; 4]> = smallvec![None; num_results];
        {
            let new_while_op = new_while.borrow();
            let new_after_block = new_after_block.borrow();
            for (i, new_result_index) in new_results_indices.iter().copied().enumerate() {
                new_results[new_result_index] =
                    Some(new_while_op.results()[i].borrow().as_value_ref());
                new_after_block_args[new_result_index] =
                    Some(new_after_block.arguments()[i].borrow().as_value_ref());
            }
        }

        let before_region = while_op.before().as_region_ref();
        let (new_before_region, new_before_block) = {
            let new_while = new_while.borrow();
            (
                new_while.before().as_region_ref(),
                new_while.before().entry_block_ref().unwrap(),
            )
        };
        let after_block = while_op.after().entry_block_ref().unwrap();

        // Release the borrow of the original op before moving its regions and replacing it.
        drop(operation);

        // Move the original `before` body into the new loop and discard the placeholder entry
        // block that `r#while` created there.
        rewriter.inline_region_before(before_region, new_before_region);
        rewriter.erase_block(new_before_block);
        // Splice the original `after` body into the new `after` entry block, remapping its
        // arguments to the kept block arguments.
        rewriter.merge_blocks(after_block, new_after_block, &new_after_block_args);
        rewriter.replace_op_with_values(op, &new_results);
        Ok(true)
    }
}