use alloc::rc::Rc;
use midenc_hir::{adt::SmallSet, *};
use crate::{
builders::{DefaultInstBuilder, InstBuilder},
ops::While,
Condition, Constant, HirDialect, If,
};
/// Rewrite pattern registered against `hir.while` operations under the name
/// `convert-do-while-to-while-true`.
///
/// NOTE(review): judging by the name, the goal is to turn do-while shaped loops into a
/// `while true` form; the rewrite itself is still unfinished (it ends in `todo!()`), so
/// confirm the intended semantics before relying on this description.
pub struct ConvertDoWhileToWhileTrue {
// Pattern metadata (name, target operation, benefit) handed to the rewrite driver.
info: PatternInfo,
}
impl ConvertDoWhileToWhileTrue {
    /// Creates the pattern, anchoring it on the `hir.while` operation.
    ///
    /// The HIR dialect is fetched from (or registered in) `context` so that the
    /// operation name of [`While`] can be looked up.
    ///
    /// # Panics
    ///
    /// Panics if the HIR dialect has no registered name for [`While`].
    pub fn new(context: Rc<Context>) -> Self {
        let target = context
            .get_or_register_dialect::<HirDialect>()
            .registered_name::<While>()
            .expect("hir.while is not registered");
        let info = PatternInfo::new(
            context,
            "convert-do-while-to-while-true",
            PatternKind::Operation(target),
            PatternBenefit::MAX,
        );
        Self { info }
    }
}
impl Pattern for ConvertDoWhileToWhileTrue {
    /// Returns the metadata this pattern was constructed with.
    fn info(&self) -> &PatternInfo {
        let Self { info } = self;
        info
    }
}
impl RewritePattern for ConvertDoWhileToWhileTrue {
    /// Not supported separately; use [`RewritePattern::match_and_rewrite`].
    ///
    /// # Panics
    ///
    /// Always panics.
    fn matches(&self, _op: OperationRef) -> Result<bool, Report> {
        panic!("call match_and_rewrite")
    }

    /// Not supported separately; use [`RewritePattern::match_and_rewrite`].
    ///
    /// # Panics
    ///
    /// Always panics.
    fn rewrite(&self, _op: OperationRef, _rewriter: &mut dyn Rewriter) {
        panic!("call match_and_rewrite")
    }

    /// Checks whether `operation` is a [`While`] this pattern applies to and, if so,
    /// starts rewriting it.
    ///
    /// `Ok(false)` is returned, without touching the IR, when any of the following holds:
    ///
    /// * `operation` is not a [`While`];
    /// * the value consumed by the loop's condition op has no defining operation, or
    ///   that operation is not an [`If`];
    /// * [`eval_condition`] cannot reduce the condition to a boolean constant;
    /// * [`transformation_is_safe`] rejects the block containing the condition op.
    ///
    /// Otherwise a new `while` op with the same inits and result types is created right
    /// before `operation`, and the entry-block arguments of both regions are collected.
    ///
    /// # Panics
    ///
    /// * If either loop region has no entry block, or the `after` block has no
    ///   terminator.
    /// * The remainder of the rewrite is not implemented yet: once every check above
    ///   passes, execution reaches `todo!()` and panics.
    fn match_and_rewrite(
        &self,
        operation: OperationRef,
        rewriter: &mut dyn Rewriter,
    ) -> Result<bool, Report> {
        let op = operation.borrow();
        let Some(while_op) = op.downcast_ref::<While>() else {
            return Ok(false);
        };

        let before_block = while_op.before().entry_block_ref().unwrap();
        let after_block = while_op.after().entry_block_ref().unwrap();
        let after = after_block.borrow();
        let after_term = after.terminator().unwrap();
        // True when no operation precedes the terminator of the `after` block.
        let after_only_yields = after_term.prev().is_none();

        // The loop condition must be produced by an `If` operation.
        let condition = while_op.condition_op();
        let condition_op = condition.borrow();
        let condition_value = condition_op.condition().as_value_ref();
        let Some(condition_owner) = condition_value.borrow().get_defining_op() else {
            return Ok(false);
        };
        let condition_owner_op = condition_owner.borrow();
        let Some(if_op) = condition_owner_op.downcast_ref::<If>() else {
            return Ok(false);
        };
        let Some(condition_constant) = eval_condition(condition_value, if_op) else {
            return Ok(false);
        };
        if !transformation_is_safe(&condition_op) {
            return Ok(false);
        }

        // Build the replacement loop right before the original one, reusing its inits
        // and result types.
        let span = while_op.span();
        let result_types = while_op
            .results()
            .iter()
            .map(|r| r.borrow().ty().clone())
            .collect::<SmallVec<[_; 4]>>();
        rewriter.set_insertion_point_before(operation);
        let loop_inits = while_op.inits().into_iter().map(|o| o.borrow().as_value_ref());
        let new_while =
            DefaultInstBuilder::new(rewriter).r#while(loop_inits, &result_types, span)?;

        let before_args = while_op
            .before()
            .entry()
            .arguments()
            .iter()
            .map(|arg| arg.borrow().as_value_ref())
            .collect::<SmallVec<[_; 4]>>();
        // These must come from the `after` region; reading `before` here would just
        // duplicate `before_args`.
        let after_args = while_op
            .after()
            .entry()
            .arguments()
            .iter()
            .map(|arg| arg.borrow().as_value_ref())
            .collect::<SmallVec<[_; 4]>>();

        // The actual rewrite (populating `new_while` and replacing the original op) is
        // not implemented yet.
        todo!();
        Ok(true)
    }
}
/// Returns `true` when the block containing `condition` holds nothing but `condition`
/// itself and the operations in that same block that (transitively) define its
/// operands.
///
/// Operand definitions living outside the block are ignored and not followed further.
///
/// # Panics
///
/// Panics if `condition`, or an operation defining one of the visited operands, has no
/// parent block.
fn transformation_is_safe(condition: &Condition) -> bool {
    let parent_block = condition.parent().unwrap();

    // Operations that are allowed to appear in `parent_block`.
    let mut allowed = SmallSet::<OperationRef, 8>::default();
    allowed.insert(condition.as_operation_ref());

    let mut worklist = SmallVec::<[_; 4]>::from_iter(condition.operands().iter().copied());
    while let Some(operand) = worklist.pop() {
        let Some(defining_op) = operand.borrow().value().get_defining_op() else {
            continue;
        };
        if defining_op.parent().unwrap() != parent_block {
            continue;
        }
        // An operation reachable through several operand paths only needs to be
        // expanded once; re-expanding it on every visit makes the walk exponential on
        // DAG-shaped operand chains.
        if allowed.contains(&defining_op) {
            continue;
        }
        allowed.insert(defining_op);
        worklist.extend(defining_op.borrow().operands().iter().copied());
    }

    // Walk the block from its last operation to its first; any operation outside the
    // allowed set makes the transformation unsafe.
    let mut next_op = parent_block.borrow().body().back().as_pointer();
    while let Some(op) = next_op.take() {
        next_op = op.prev();
        if !allowed.contains(&op) {
            return false;
        }
    }
    true
}
/// Tries to reduce `value`, a result of `if_op`, to a boolean decided by the branch
/// taken.
///
/// The operand yielded at `value`'s result index is inspected in both the `then` and the
/// `else` branch. Each must be defined by a [`Constant`] whose value converts to a
/// `bool`; otherwise `None` is returned. `None` is also returned when both branches
/// yield the same boolean. When they differ, the boolean yielded by the `then` branch is
/// returned.
///
/// This function does not itself check that `value` is produced by `if_op`.
///
/// # Panics
///
/// Panics if `value` is not an [`OpResult`].
fn eval_condition(value: ValueRef, if_op: &If) -> Option<bool> {
    /// Reads `yielded` as a boolean constant, if it is defined by a [`Constant`] op.
    fn constant_bool(yielded: ValueRef) -> Option<bool> {
        let defining_op = yielded.borrow().get_defining_op()?;
        let defining_op = defining_op.borrow();
        let constant = defining_op.downcast_ref::<Constant>()?;
        let flag = constant.value().as_bool()?;
        Some(flag)
    }

    let value = value.borrow();
    let index = value.downcast_ref::<OpResult>().unwrap().index();

    // The `then` branch is fully resolved before the `else` branch is looked at.
    let then_operand = if_op.then_yield().borrow().yielded()[index].borrow().as_value_ref();
    let then_value = constant_bool(then_operand)?;

    let else_operand = if_op.else_yield().borrow().yielded()[index].borrow().as_value_ref();
    let else_value = constant_bool(else_operand)?;

    (then_value != else_value).then_some(then_value)
}