use crate::convert::apply_rewrites;
use crate::convert::ChangedOp;
use crate::convert::Pass;
use crate::convert::Rewrite;
use crate::convert::RewriteResult;
use crate::dialect;
use crate::ir::Block;
use crate::ir::BlockArgument;
use crate::ir::BlockArgumentName;
use crate::ir::BlockName;
use crate::ir::Op;
use crate::ir::OpOperand;
use crate::ir::Operation;
use crate::ir::Region;
use crate::ir::Users;
use crate::ir::Value;
use crate::ir::Values;
use crate::shared::Shared;
use crate::shared::SharedExt;
use anyhow::Result;
/// Rewrite that replaces a `dialect::scf::IfOp` with a `dialect::cf::CondBranchOp`.
///
/// The lowering itself lives in the `Rewrite` implementation for this type.
struct IfLowering;
/// Converts a `scf::YieldOp` into a `cf::BranchOp` that jumps to `after`.
///
/// The yield's underlying operation is reused rather than rebuilt: the value
/// that was stored at operand index 0 moves to index 1, and index 0 receives a
/// block operand created from `after`.
///
/// # Panics
///
/// Panics if the yield operation has no operand at index 0.
fn lower_yield_op(op: &dialect::scf::YieldOp, after: Shared<Block>) -> Result<Shared<dyn Op>> {
    let yield_operation = op.operation();
    // Read the yielded value first; slot 0 is overwritten right below.
    let yielded = yield_operation.rd().operand(0).unwrap();
    let dest = Shared::new(OpOperand::from_block(after).into());
    {
        let mut operation = yield_operation.wr();
        operation.set_operand(0, dest);
        operation.set_operand(1, yielded);
    }
    let branch = dialect::cf::BranchOp::from_operation_arc(yield_operation.clone());
    let branch = Shared::new(branch.into());
    Ok(branch)
}
/// Creates a fresh `cf::BranchOp` whose destination is `after`.
///
/// The op is built on a default `Operation`; no parent is set here, so the
/// caller is responsible for attaching it to a block.
fn branch_op(after: Shared<Block>) -> Shared<dyn Op> {
    let mut branch = dialect::cf::BranchOp::from_operation(Operation::default());
    let dest = Shared::new(OpOperand::from_block(after).into());
    branch.set_dest(dest);
    Shared::new(branch.into())
}
/// Makes `block` end with a branch to `after`.
///
/// * If the last op of `block` is a `scf::YieldOp`, it is replaced by the
///   branch produced by [`lower_yield_op`].
/// * Otherwise, including when `block` contains no ops at all, a new branch
///   from [`branch_op`] is appended and its parent is set to `block`.
///
/// # Panics
///
/// Panics if [`lower_yield_op`] returns an error or itself panics.
fn add_branch_to_after(block: Shared<Block>, after: Shared<Block>) {
    // The write guard on `block` stays alive for the whole function body.
    let ops = &mut block.wr().ops;
    // Clone only the handle of the last op (not the whole op list) so that
    // `ops` can still be mutated while the last op is being inspected.
    let last_op = ops.last().cloned();
    let last_guard = last_op.as_ref().map(|last| last.rd());
    let yield_op = last_guard
        .as_ref()
        .and_then(|last| last.as_any().downcast_ref::<dialect::scf::YieldOp>());
    let new_op = match yield_op {
        Some(yield_op) => {
            let lowered = lower_yield_op(yield_op, after).unwrap();
            // The lowered branch takes the place of the yield.
            ops.pop();
            lowered
        }
        None => {
            let branch = branch_op(after);
            branch.rd().set_parent(block.clone());
            branch
        }
    };
    ops.push(new_op);
}
/// Moves every op that follows `op` in its parent block into `exit_block`.
///
/// The moved ops get `exit_block` as their new parent, and `exit_block`'s op
/// list is overwritten with exactly those ops. `op` itself and everything
/// before it stay in the original block.
///
/// # Panics
///
/// Panics if `op` has no parent block or cannot be located inside it.
fn move_successors_to_exit_block(op: &dialect::scf::IfOp, exit_block: Shared<Block>) -> Result<()> {
    let parent = op.operation().rd().parent().expect("Expected parent");
    let index = parent
        .rd()
        .index_of(&op.operation().rd())
        .expect("Expected index");
    let parent_ops = &mut parent.wr().ops;
    // Detach the tail after `op` in one step; `parent_ops` keeps `..=index`.
    let successors = parent_ops.split_off(index + 1);
    for successor in &successors {
        successor.rd().set_parent(exit_block.clone());
    }
    exit_block.wr().ops = successors;
    Ok(())
}
/// Creates the merge block used when the lowered `if` produces results.
///
/// A new block is requested from `parent_region` via
/// `add_empty_block_before(exit)` and attached to that region. It receives
/// one block argument per entry in `results` (see [`as_block_arguments`]),
/// an unset label, and a single `cf::BranchOp` to `exit` as its only op.
///
/// Returns the merge block together with its block arguments.
fn add_merge_block(
    parent_region: Shared<Region>,
    results: Values,
    exit: Shared<Block>,
) -> Result<(Shared<Block>, Values)> {
    // Keep these as two statements so the region's read guard is released
    // before `set_parent` runs.
    let detached = parent_region.rd().add_empty_block_before(exit.clone());
    let merge = detached.set_parent(Some(parent_region.clone()));
    let arguments = as_block_arguments(results, merge.clone())?;
    {
        let mut merge_block = merge.wr();
        merge_block.set_arguments(arguments.clone());
        merge_block.label = BlockName::Unset;
    }

    let mut operation = Operation::default();
    operation.set_parent(Some(merge.clone()));
    let mut branch = dialect::cf::BranchOp::from_operation(operation);
    let dest = Shared::new(OpOperand::from_block(exit).into());
    branch.set_dest(dest);
    let branch: Shared<dyn Op> = Shared::new(branch.into());
    merge.wr().ops = vec![branch];

    Ok((merge, arguments))
}
/// Creates the exit block and moves the ops following `op` into it.
///
/// The block is obtained from `parent_region.add_empty_block()`, attached to
/// that region, given an unset label, and then filled by
/// [`move_successors_to_exit_block`].
fn add_exit_block(op: &dialect::scf::IfOp, parent_region: Shared<Region>) -> Result<Shared<Block>> {
    // Two statements on purpose: the region's read guard must be released
    // before `set_parent` is called.
    let detached = parent_region.rd().add_empty_block();
    let exit_block = detached.set_parent(Some(parent_region));
    exit_block.wr().label = BlockName::Unset;
    move_successors_to_exit_block(op, exit_block.clone())?;
    Ok(exit_block)
}
/// Builds one block argument per value in `results`.
///
/// Each argument reuses the name and type of the corresponding result and has
/// `parent` set as its parent block. The arguments are returned in the same
/// order as `results`.
///
/// # Panics
///
/// Panics if a result has no type or no name.
fn as_block_arguments(results: Values, parent: Shared<Block>) -> Result<Values> {
    let mut arguments = Vec::new();
    for result in results.into_iter() {
        let value = result.rd();
        let name = value.name();
        // The type is unwrapped before the name, so a missing type is
        // reported first when both are absent.
        let typ = value.typ().unwrap();
        let mut argument = BlockArgument::new(BlockArgumentName::Name(name.unwrap()), typ);
        argument.set_parent(Some(parent.clone()));
        arguments.push(Shared::new(Value::BlockArgument(argument).into()));
    }
    Ok(Values::from_vec(arguments))
}
/// Returns the `users()` of every value in `results`, in the same order.
fn results_users(results: Values) -> Vec<Users> {
    results
        .into_iter()
        .map(|result| result.rd().users())
        .collect()
}
/// Builds the block structure that replaces an `scf::IfOp` and returns its
/// `(then, else)` entry blocks.
///
/// Steps, in order:
/// 1. The users of the op's results are collected.
/// 2. An exit block is created and receives all ops that followed `op`
///    (see [`add_exit_block`]).
/// 3. The first block of the `then` and `else` regions gets an unset label,
///    and each region is passed to `exit.inline_region_before(..)`.
/// 4. If `op` has results, a merge block is added (see [`add_merge_block`])
///    and every previously collected user is pointed at the matching merge
///    block argument.
/// 5. Both entry blocks are terminated with a branch to the merge block, or
///    to the exit block when there are no results.
///
/// # Panics
///
/// Panics if `op` lacks a `then` or `else` region, if either region has no
/// blocks, or if the number of merge block arguments differs from the number
/// of results.
fn add_blocks(
    op: &dialect::scf::IfOp,
    parent_region: Shared<Region>,
) -> Result<(Shared<Block>, Shared<Block>)> {
    let results = op.operation().rd().results();
    // Users are collected before any block is created or moved.
    let users_per_result = results_users(results.clone());
    let exit = add_exit_block(op, parent_region.clone())?;
    let has_results = !results.is_empty();

    let then_region = op.then().expect("Expected `then` region");
    let then = then_region.rd().blocks().into_iter().next().unwrap();
    then.wr().label = BlockName::Unset;
    exit.rd().inline_region_before(then_region.clone());

    let else_region = op.els().expect("Expected `else` region");
    let els = else_region.rd().blocks().into_iter().next().unwrap();
    els.wr().label = BlockName::Unset;
    exit.rd().inline_region_before(else_region.clone());

    let after = if has_results {
        let (merge, merge_block_arguments) = add_merge_block(parent_region, results, exit.clone())?;
        let merge_block_arguments = merge_block_arguments.vec();
        let merge_block_arguments = merge_block_arguments.rd();
        // `zip` would silently truncate on a length mismatch, so check first.
        assert!(users_per_result.len() == merge_block_arguments.len());
        for (result_users, arg) in users_per_result.iter().zip(merge_block_arguments.iter()) {
            match result_users {
                Users::OpOperands(operands) => {
                    for operand in operands.iter() {
                        operand.wr().set_value(arg.clone());
                    }
                }
                // Nothing to rewire for this result.
                Users::HasNoOpResults => {}
            }
        }
        merge
    } else {
        exit.clone()
    };

    add_branch_to_after(then.clone(), after.clone());
    add_branch_to_after(els.clone(), after);
    Ok((then, els))
}
impl Rewrite for IfLowering {
    /// Identifier of this rewrite.
    fn name(&self) -> &'static str {
        "scf_to_cf::IfLowering"
    }

    /// This rewrite reports itself as not parallelizable.
    fn parallelizable(&self) -> bool {
        false
    }

    /// Replaces an `scf::IfOp` with a `cf::CondBranchOp`.
    ///
    /// Any op that is not an `scf::IfOp` is reported as
    /// `RewriteResult::Unchanged`. For an `IfOp`, the surrounding blocks are
    /// prepared by [`add_blocks`], and the new conditional branch gets three
    /// operands: the `IfOp`'s operand 0, the `then` block, and the `else`
    /// block.
    ///
    /// # Panics
    ///
    /// Panics if the `IfOp` has no parent block, the block has no parent
    /// region, or the `IfOp` has no operand at index 0.
    fn rewrite(&self, op: Shared<dyn Op>) -> Result<RewriteResult> {
        let guard = op.rd();
        let Some(if_op) = guard.as_any().downcast_ref::<dialect::scf::IfOp>() else {
            return Ok(RewriteResult::Unchanged);
        };

        let parent_block = if_op.operation().rd().parent().expect("no parent block");
        let parent_region = parent_block.rd().parent.clone().expect("no parent region");
        let (then, els) = add_blocks(if_op, parent_region)?;

        let mut operation = Operation::default();
        operation.set_parent(Some(parent_block.clone()));
        let condition = if_op.operation().rd().operand(0).unwrap();
        operation.set_operand(0, condition);
        let then_dest = Shared::new(OpOperand::from_block(then).into());
        operation.set_operand(1, then_dest);
        let else_dest = Shared::new(OpOperand::from_block(els).into());
        operation.set_operand(2, else_dest);

        let cond_branch = dialect::cf::CondBranchOp::from_operation(operation);
        let cond_branch = Shared::new(cond_branch.into());
        if_op.replace(cond_branch.clone());
        // NOTE(review): results are reset after `replace`, presumably because
        // `replace` hands the `IfOp`'s results to the new op — confirm.
        cond_branch
            .rd()
            .operation()
            .wr()
            .set_results(Values::default());
        Ok(RewriteResult::Changed(ChangedOp::new(cond_branch)))
    }
}
/// Pass registered under the name `convert-scf-to-cf`.
///
/// It applies the `IfLowering` rewrite to the given op through
/// `apply_rewrites`.
pub struct ConvertSCFToCF;

impl Pass for ConvertSCFToCF {
    const NAME: &'static str = "convert-scf-to-cf";

    /// Runs every rewrite of this pass (currently only `IfLowering`) on `op`.
    fn convert(op: Shared<dyn Op>) -> Result<RewriteResult> {
        let if_lowering: &dyn Rewrite = &IfLowering;
        let rewrites = vec![if_lowering];
        apply_rewrites(op, &rewrites)
    }
}