use alloc::rc::Rc;
use midenc_hir::{
derive::{EffectOpInterface, OpParser, OpPrinter, operation},
dialects::builtin::attributes::U32ArrayAttr,
effects::*,
parse::ParserExt,
patterns::RewritePatternSet,
print::AsmPrinter,
traits::*,
*,
};
use crate::ScfDialect;
/// Conditional operation of the SCF dialect, made of a boolean `condition` operand, a `then`
/// region, an `else` region, and an arbitrary list of results.
///
/// The operation is declared with the `SingleBlock`, `NoRegionArguments` and
/// `HasRecursiveMemoryEffects` traits, and implements `RegionBranchOpInterface`.
#[derive(OpPrinter, OpParser)]
#[operation(
dialect = ScfDialect,
traits(SingleBlock, NoRegionArguments, HasRecursiveMemoryEffects),
implements(RegionBranchOpInterface, OpPrinter)
)]
pub struct If {
// Boolean operand; the region-branch interface enters `then` when it is (or may be) true and
// takes the `else` path when it is (or may be) false.
#[operand]
condition: Bool,
// Region registered under the name "then".
#[region(name = "then")]
then_body: Region,
// Region registered under the name "else"; the region-branch interface treats an empty `else`
// region as control returning directly to the parent operation.
#[region(name = "else")]
else_body: Region,
// Results produced by the operation; any type is accepted.
#[results]
returns: AnyType,
}
impl If {
    /// Returns the [`Yield`] operation terminating the entry block of the `then` region.
    ///
    /// # Panics
    ///
    /// Panics if the entry block of the `then` region has no terminator, or if that terminator
    /// is not a [`Yield`].
    pub fn then_yield(&self) -> UnsafeIntrusiveEntityRef<Yield> {
        let then_terminator = self.then_body().entry().terminator().unwrap();
        then_terminator
            .try_downcast_op::<Yield>()
            .expect("invalid hir.if then terminator: expected yield")
    }

    /// Returns the [`Yield`] operation terminating the entry block of the `else` region.
    ///
    /// # Panics
    ///
    /// Panics if the entry block of the `else` region has no terminator, or if that terminator
    /// is not a [`Yield`].
    pub fn else_yield(&self) -> UnsafeIntrusiveEntityRef<Yield> {
        let else_terminator = self.else_body().entry().terminator().unwrap();
        else_terminator
            .try_downcast_op::<Yield>()
            .expect("invalid hir.if else terminator: expected yield")
    }
}
impl Canonicalizable for If {
fn get_canonicalization_patterns(rewrites: &mut RewritePatternSet, context: Rc<Context>) {
rewrites.push(crate::canonicalization::ConvertTrivialIfToSelect::new(context.clone()));
rewrites.push(crate::canonicalization::IfRemoveUnusedResults::new(context.clone()));
rewrites.push(crate::canonicalization::FoldRedundantYields::new(context));
}
}
impl RegionBranchOpInterface for If {
    /// Computes the regions that may be entered from the parent operation, narrowing the answer
    /// when the condition operand is a known boolean constant.
    ///
    /// `operands[0]` is the (optional) constant value of the condition. When the `else` path is
    /// possible but the `else` region is empty, control returns straight to the parent with the
    /// operation's results.
    fn get_entry_successor_regions(
        &self,
        operands: &[Option<AttributeRef>],
    ) -> RegionSuccessorIter<'_> {
        let condition = operands[0].as_ref().and_then(|v| v.borrow().as_bool());
        // An unknown condition keeps both paths alive; a known one selects exactly one.
        let (then_reachable, else_reachable) = match condition {
            Some(true) => (true, false),
            Some(false) => (false, true),
            None => (true, true),
        };

        let mut successors = SmallVec::<[RegionSuccessorInfo; 2]>::default();
        if then_reachable {
            successors.push(RegionSuccessorInfo::Entering(self.then_body().as_region_ref()));
        }
        if else_reachable {
            let else_successor = if self.else_body().is_empty() {
                // No `else` region to enter: the false path falls through to the parent.
                RegionSuccessorInfo::Returning(
                    self.results().all().iter().map(|v| v.borrow().as_value_ref()).collect(),
                )
            } else {
                RegionSuccessorInfo::Entering(self.else_body().as_region_ref())
            };
            successors.push(else_successor);
        }
        RegionSuccessorIter::new(self.as_operation(), successors)
    }

    /// Computes the possible successors of the given branch point.
    ///
    /// From the parent, the `then` region and (when non-empty) the `else` region may be entered.
    /// From either child region, control returns to the parent with the operation's results.
    fn get_successor_regions(&self, point: RegionBranchPoint) -> RegionSuccessorIter<'_> {
        match point {
            RegionBranchPoint::Child(_) => {
                let back_to_parent = RegionSuccessorInfo::Returning(
                    self.results().all().iter().map(|v| v.borrow().as_value_ref()).collect(),
                );
                RegionSuccessorIter::new(self.as_operation(), [back_to_parent])
            }
            RegionBranchPoint::Parent => {
                let mut successors = SmallVec::<[RegionSuccessorInfo; 2]>::new();
                successors.push(RegionSuccessorInfo::Entering(self.then_body().as_region_ref()));
                if !self.else_body().is_empty() {
                    successors
                        .push(RegionSuccessorInfo::Entering(self.else_body().as_region_ref()));
                }
                RegionSuccessorIter::new(self.as_operation(), successors)
            }
        }
    }

    /// Returns how many times each region (`then` first, `else` second) is invoked.
    ///
    /// With a known condition exactly one region runs once and the other never runs; otherwise
    /// each region runs at most once.
    fn get_region_invocation_bounds(
        &self,
        operands: &[Option<AttributeRef>],
    ) -> SmallVec<[InvocationBounds; 1]> {
        match operands[0].as_ref().and_then(|v| v.borrow().as_bool()) {
            Some(true) => smallvec![InvocationBounds::Exact(1), InvocationBounds::Never],
            Some(false) => smallvec![InvocationBounds::Never, InvocationBounds::Exact(1)],
            None => smallvec![InvocationBounds::NoMoreThan(1); 2],
        }
    }

    /// No region of an [`If`] is executed repeatedly.
    #[inline(always)]
    fn is_repetitive_region(&self, _index: usize) -> bool {
        false
    }

    /// An [`If`] never forms a loop.
    #[inline(always)]
    fn has_loop(&self) -> bool {
        false
    }
}
/// Loop operation of the SCF dialect, made of a `before` region and an `after` region.
///
/// The accessors on this type expect the `before` region to be terminated by a `Condition`
/// operation and the `after` region to be terminated by a `Yield` operation. The operation is
/// declared with the `SingleBlock` and `HasRecursiveMemoryEffects` traits, and implements both
/// `RegionBranchOpInterface` and `LoopLikeOpInterface`.
#[derive(OpPrinter, OpParser)]
#[operation(
dialect = ScfDialect,
traits(SingleBlock, HasRecursiveMemoryEffects),
implements(RegionBranchOpInterface, LoopLikeOpInterface, OpPrinter)
)]
pub struct While {
// Operands of the loop; any type is accepted. They are exposed as the loop inits through
// `LoopLikeOpInterface::get_inits_mut`.
#[operands]
inits: AnyType,
// Region entered first from the parent operation; it is reported as the loop header region.
#[region]
before: Region,
// Region entered from `before`; it always branches back to `before`.
#[region]
after: Region,
}
impl While {
    /// Returns the [`Condition`] operation terminating the entry block of the `before` region.
    ///
    /// # Panics
    ///
    /// Panics if the `before` region has no terminator, or if the terminator is not a
    /// [`Condition`].
    pub fn condition_op(&self) -> UnsafeIntrusiveEntityRef<Condition> {
        let before_terminator = self
            .before()
            .entry()
            .terminator()
            .expect("expected before region to have a terminator");
        before_terminator
            .try_downcast_op::<Condition>()
            .expect("expected before region to terminate with hir.condition")
    }

    /// Returns the [`Yield`] operation terminating the entry block of the `after` region.
    ///
    /// # Panics
    ///
    /// Panics if the `after` region has no terminator, or if the terminator is not a [`Yield`].
    pub fn yield_op(&self) -> UnsafeIntrusiveEntityRef<Yield> {
        let after_terminator = self
            .after()
            .entry()
            .terminator()
            .expect("expected after region to have a terminator");
        after_terminator
            .try_downcast_op::<Yield>()
            .expect("expected after region to terminate with hir.yield")
    }
}
impl Canonicalizable for While {
    /// Registers the canonicalization patterns that apply to [`While`] operations.
    ///
    /// Each pattern receives its own handle to the shared `context`. The final pattern takes
    /// ownership of the handle passed in rather than cloning it, since `context` is not used
    /// afterwards (this also matches the other `Canonicalizable` implementations in this file).
    fn get_canonicalization_patterns(rewrites: &mut RewritePatternSet, context: Rc<Context>) {
        rewrites.push(crate::canonicalization::RemoveLoopInvariantArgsFromBeforeBlock::new(
            context.clone(),
        ));
        rewrites.push(crate::canonicalization::WhileConditionTruth::new(context.clone()));
        rewrites.push(crate::canonicalization::WhileUnusedResult::new(context.clone()));
        rewrites.push(crate::canonicalization::WhileRemoveDuplicatedResults::new(context.clone()));
        rewrites.push(crate::canonicalization::WhileRemoveUnusedArgs::new(context));
    }
}
impl LoopLikeOpInterface for While {
    /// Returns the arguments of the entry block of the `before` region, or `None` when that
    /// region has no entry block.
    fn get_region_iter_args(&self) -> Option<EntityRef<'_, [BlockArgumentRef]>> {
        let header_block = self.before().entry_block_ref();
        header_block.map(|block_ref| EntityRef::map(block_ref.borrow(), |block| block.arguments()))
    }

    /// The `before` region is the loop header.
    fn get_loop_header_region(&self) -> RegionRef {
        self.before().as_region_ref()
    }

    /// Returns both regions of the loop, `before` first and `after` second.
    fn get_loop_regions(&self) -> SmallVec<[RegionRef; 2]> {
        let header = self.before().as_region_ref();
        let body = self.after().as_region_ref();
        smallvec![header, body]
    }

    /// Returns mutable access to the loop's init operands.
    fn get_inits_mut(&mut self) -> OpOperandRangeMut<'_> {
        self.inits_mut()
    }

    /// Returns mutable access to operand group 0 of the terminator of the `after` region.
    ///
    /// # Panics
    ///
    /// Panics if the `after` region has no terminator.
    fn get_yielded_values_mut(&mut self) -> Option<EntityProjectionMut<'_, OpOperandRangeMut<'_>>> {
        let mut body_terminator = self
            .after()
            .entry()
            .terminator()
            .expect("invalid `while`: expected loop body to be terminated");
        let yielded = EntityMut::project(body_terminator.borrow_mut(), |op| {
            op.operands_mut().group_mut(0)
        });
        Some(yielded)
    }
}
impl RegionBranchOpInterface for While {
#[inline]
fn get_entry_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> {
SuccessorOperandRange::forward(self.operands().all())
}
fn get_successor_regions(&self, point: RegionBranchPoint) -> RegionSuccessorIter<'_> {
match point {
RegionBranchPoint::Parent => {
RegionSuccessorIter::new(
self.as_operation(),
[RegionSuccessorInfo::Entering(self.before().as_region_ref())],
)
}
RegionBranchPoint::Child(region) => {
let before_region = self.before().as_region_ref();
let after_region = self.after().as_region_ref();
assert!(region == before_region || region == after_region);
if region == after_region {
RegionSuccessorIter::new(
self.as_operation(),
[RegionSuccessorInfo::Entering(before_region)],
)
} else {
RegionSuccessorIter::new(
self.as_operation(),
[
RegionSuccessorInfo::Returning(
self.results()
.all()
.iter()
.map(|r| r.borrow().as_value_ref())
.collect(),
),
RegionSuccessorInfo::Entering(after_region),
],
)
}
}
}
}
#[inline]
fn get_region_invocation_bounds(
&self,
_operands: &[Option<AttributeRef>],
) -> SmallVec<[InvocationBounds; 1]> {
smallvec![InvocationBounds::Unknown; self.num_regions()]
}
#[inline(always)]
fn is_repetitive_region(&self, _index: usize) -> bool {
true
}
#[inline(always)]
fn has_loop(&self) -> bool {
true
}
}
/// Multi-way switch operation of the SCF dialect, dispatching on a `u32` selector.
///
/// Only the default region is declared as a field; the per-case regions are additional regions
/// of the operation and are reached through `get_case_region`, which skips over the first
/// region of the operation before counting cases.
#[operation(
dialect = ScfDialect,
traits(SingleBlock, HasRecursiveMemoryEffects),
implements(RegionBranchOpInterface, OpPrinter)
)]
pub struct IndexSwitch {
// The `u32` value compared against `cases`.
#[operand]
selector: UInt32,
// The selector values that have a dedicated case region; the position of a value in this
// array is the case index used by `get_case_region` and `get_case_block`.
#[attr]
cases: U32ArrayAttr,
// Region entered when the selector matches none of `cases`.
#[region]
default_region: Region,
}
impl OpPrinter for IndexSwitch {
    /// Prints the custom assembly form of the operation.
    ///
    /// The output consists of the selector operand, followed by one `case <value> <region>`
    /// entry per case (each on its own line), a `default <region>` entry, the attribute
    /// dictionary when the operation has attributes, and finally the colon-prefixed list of
    /// result types.
    fn print(&self, printer: &mut AsmPrinter<'_>) {
        use alloc::borrow::Cow;

        use formatter::*;

        printer.print_space();
        printer.print_value_uses(ValueRange::<1>::Operands(&[self.selector().as_operand_ref()]));
        printer.print_space();

        // The position of a case value in `cases` is its case index, so enumerate instead of
        // searching the array again for every case.
        for (index, case) in self.cases().iter().enumerate() {
            let region = self.get_case_region(index);
            *printer += nl() + const_text("case ") + display(*case) + const_text(" ");
            printer.print_region(&region.borrow());
        }

        *printer += nl() + const_text("default ");
        printer.print_region(&self.default_region());

        if self.op.has_attributes() {
            printer.print_space();
            printer.print_attribute_dictionary(
                self.op.attributes().iter().map(|attr| *attr.as_named_attribute()),
            );
        }

        printer.print_space();
        printer.print_colon_type_list(
            self.results().iter().map(|r| Cow::Owned(r.borrow().ty().clone())),
        );
    }
}
impl OpParser for IndexSwitch {
    /// Parses the custom assembly form of the operation.
    ///
    /// The expected input is the selector operand, any number of `case <u32> <region>` entries,
    /// a mandatory `default <region>` entry, an optional attribute dictionary, and a
    /// colon-prefixed list of result types.
    ///
    /// # Errors
    ///
    /// Returns an error if any component fails to parse, or if the same case value appears more
    /// than once.
    fn parse(state: &mut OperationState, parser: &mut dyn OpAsmParser<'_>) -> ParseResult {
        use alloc::{format, vec};

        use midenc_hir::{
            diagnostics::{LabeledSpan, RelatedError, Report, Severity, miette::diagnostic},
            dialects::builtin::attributes::Array,
            parse::ParserError,
        };

        let selector = parser.parse_operand(true)?;
        let selector = parser.resolve_operand(selector, Type::U32)?;
        state.add_operand(selector);

        let mut cases = Array::<u32>::default();
        let mut regions = SmallVec::<[RegionRef; 2]>::default();
        while parser.parse_optional_custom_keyword("case")?.is_some() {
            let case_value = parser.parse_decimal_integer::<u32>()?;
            if cases.contains(&case_value) {
                return Err(ParserError::Report(RelatedError::new(Report::from(diagnostic!(
                    severity = Severity::Error,
                    labels = vec![LabeledSpan::at(
                        case_value.span(),
                        "this case selector has already been used"
                    )],
                    "invalid scf.index_switch operation"
                )))));
            }
            let region = parser.context().create_region();
            parser.parse_region(region, &[], false)?;
            cases.push(case_value.into_inner());
            regions.push(region);
        }

        parser.parse_custom_keyword("default")?;
        let fallback_region = parser.context().create_region();
        parser.parse_region(fallback_region, &[], false)?;

        state
            .add_attribute("cases", parser.context_rc().create_attribute::<U32ArrayAttr, _>(cases));

        // The default region must be the first region of the operation: `get_case_region` skips
        // one region before counting cases, and the invocation bounds use slot 0 for the default
        // region. The case regions follow, in the same order as their values in `cases`.
        state.add_region(fallback_region);
        for region in regions {
            state.add_region(region);
        }

        parser.parse_optional_attribute_dict(&mut state.attrs)?;
        parser.parse_colon_type_list(&mut state.results)?;
        Ok(())
    }
}
impl IndexSwitch {
    /// Returns the number of explicit cases, i.e. the length of the `cases` attribute.
    pub fn num_cases(&self) -> usize {
        let cases = self.cases();
        cases.len()
    }

    /// Returns the entry block of the default region.
    ///
    /// # Panics
    ///
    /// Panics if the default region has no blocks.
    pub fn get_default_block(&self) -> BlockRef {
        let entry = self.default_region().entry_block_ref();
        entry.expect("default region has no blocks")
    }

    /// Returns the case index whose value equals `selector`, or `None` when no case matches.
    pub fn get_case_index_for_selector(&self, selector: u32) -> Option<usize> {
        self.cases().iter().position(|candidate| *candidate == selector)
    }

    /// Returns the entry block of the region for case `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds, or if the case region has no blocks.
    #[track_caller]
    pub fn get_case_block(&self, index: usize) -> BlockRef {
        let Some(block) = self.get_case_region(index).borrow().entry_block_ref() else {
            panic!("region for case {index} has no blocks");
        };
        block
    }

    /// Returns the region for case `index`.
    ///
    /// The first region of the operation is skipped, so case `index` maps to the region at
    /// position `index + 1`.
    ///
    /// # Panics
    ///
    /// Panics if there is no region for case `index`.
    #[track_caller]
    pub fn get_case_region(&self, index: usize) -> RegionRef {
        // Position of the wanted region within the operation's region list.
        let wanted = index + 1;
        let mut cursor = self.regions().front().as_pointer();
        let mut position = 0;
        loop {
            match cursor.take() {
                Some(region) if position == wanted => return region,
                Some(region) => {
                    cursor = region.next();
                    position += 1;
                }
                None => panic!("invalid region index `{}`: out of bounds", index),
            }
        }
    }
}
impl RegionBranchOpInterface for IndexSwitch {
    /// Computes the regions that may be entered from the parent operation, narrowing the answer
    /// when the selector operand is a known `u32` constant.
    ///
    /// With an unknown selector every region is a possible successor. With a known selector the
    /// only successor is the matching case region, or the default region when no case matches.
    fn get_entry_successor_regions(
        &self,
        operands: &[Option<AttributeRef>],
    ) -> RegionSuccessorIter<'_> {
        let selector = operands[0].as_ref().and_then(|v| v.borrow().as_u32());
        // Outer `None`: selector unknown. Inner `None`: selector known but matches no case.
        let selected = selector.map(|s| self.get_case_index_for_selector(s));
        match selected {
            None => {
                let infos =
                    self.regions().iter().map(|r| RegionSuccessorInfo::Entering(r.as_region_ref()));
                RegionSuccessorIter::new(self.as_operation(), infos)
            }
            Some(Some(selected)) => RegionSuccessorIter::new(
                self.as_operation(),
                [RegionSuccessorInfo::Entering(self.get_case_region(selected))],
            ),
            Some(None) => RegionSuccessorIter::new(
                self.as_operation(),
                [RegionSuccessorInfo::Entering(self.default_region().as_region_ref())],
            ),
        }
    }

    /// Computes the possible successors of the given branch point.
    ///
    /// From the parent, any region may be entered. From any child region, control returns to the
    /// parent with the operation's results.
    fn get_successor_regions(&self, point: RegionBranchPoint) -> RegionSuccessorIter<'_> {
        match point {
            RegionBranchPoint::Parent => {
                let infos =
                    self.regions().iter().map(|r| RegionSuccessorInfo::Entering(r.as_region_ref()));
                RegionSuccessorIter::new(self.as_operation(), infos)
            }
            RegionBranchPoint::Child(_) => RegionSuccessorIter::new(
                self.as_operation(),
                [RegionSuccessorInfo::Returning(
                    self.results().all().iter().map(|v| v.borrow().as_value_ref()).collect(),
                )],
            ),
        }
    }

    /// Returns one invocation bound per region of the operation.
    ///
    /// Slot 0 is used for the default region and slot `i + 1` for case `i`. With a known
    /// selector exactly one region runs once and all others never run; otherwise every region
    /// runs at most once.
    fn get_region_invocation_bounds(
        &self,
        operands: &[Option<AttributeRef>],
    ) -> SmallVec<[InvocationBounds; 1]> {
        // There is one more region than there are cases (the default region), so the result must
        // be sized by the region count; sizing it by the case count would make the write below
        // go out of bounds for the last case, and for every selector when there are no cases.
        let num_regions = self.num_regions();
        let selector = operands[0].as_ref().and_then(|v| v.borrow().as_u32());
        if let Some(selector) = selector {
            let mut bounds = smallvec![InvocationBounds::Never; num_regions];
            let selected =
                self.get_case_index_for_selector(selector).map(|idx| idx + 1).unwrap_or(0);
            bounds[selected] = InvocationBounds::Exact(1);
            bounds
        } else {
            smallvec![InvocationBounds::NoMoreThan(1); num_regions]
        }
    }

    /// No region of an [`IndexSwitch`] is executed repeatedly.
    #[inline(always)]
    fn is_repetitive_region(&self, _index: usize) -> bool {
        false
    }

    /// An [`IndexSwitch`] never forms a loop.
    #[inline(always)]
    fn has_loop(&self) -> bool {
        false
    }
}
impl Canonicalizable for IndexSwitch {
fn get_canonicalization_patterns(rewrites: &mut RewritePatternSet, context: Rc<Context>) {
rewrites.push(crate::canonicalization::FoldConstantIndexSwitch::new(context.clone()));
rewrites.push(crate::canonicalization::FoldRedundantYields::new(context));
}
}
/// Terminator operation carrying a boolean `condition` and a list of forwarded values.
///
/// Its `RegionBranchTerminatorOpInterface` implementation expects the parent operation to be a
/// `While`: a true condition enters the `after` region of that loop, a false condition returns
/// to the loop with its results.
#[derive(EffectOpInterface, OpPrinter, OpParser)]
#[operation(
dialect = ScfDialect,
traits(Terminator, ReturnLike),
implements(RegionBranchTerminatorOpInterface, MemoryEffectOpInterface, OpPrinter)
)]
pub struct Condition {
// Boolean operand deciding which successor is taken.
#[operand]
condition: Bool,
// Values handed to whichever successor is taken; any type is accepted.
#[operands]
forwarded: AnyType,
}
impl RegionBranchTerminatorOpInterface for Condition {
    /// The forwarded operands are passed on to the successor, whichever it is.
    #[inline]
    fn get_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> {
        SuccessorOperandRange::forward(self.forwarded())
    }

    /// Mutable counterpart of `get_successor_operands`.
    #[inline]
    fn get_mutable_successor_operands(
        &mut self,
        _point: RegionBranchPoint,
    ) -> SuccessorOperandRangeMut<'_> {
        SuccessorOperandRangeMut::forward(self.forwarded_mut())
    }

    /// Computes the possible successors of this terminator.
    ///
    /// `operands[0]` is the (optional) constant value of the condition. Unless the condition is
    /// known to be false, the `after` region of the parent [`While`] may be entered; unless it
    /// is known to be true, control may return to the parent with the loop results. When both
    /// are possible, the `after` region is listed first.
    ///
    /// # Panics
    ///
    /// Panics if this operation has no parent, or if the parent is not a [`While`].
    fn get_successor_regions(
        &self,
        operands: &[Option<AttributeRef>],
    ) -> SmallVec<[RegionSuccessorInfo; 2]> {
        let cond = operands[0].as_ref().and_then(|v| v.borrow().as_bool());
        let mut successors = SmallVec::<[RegionSuccessorInfo; 2]>::default();

        let parent_op = self.parent_op().unwrap();
        let parent_op = parent_op.borrow();
        let while_op = parent_op
            .downcast_ref::<While>()
            .expect("expected `Condition` op to be a child of a `While` op");
        let after_region = while_op.after().as_region_ref();

        let may_continue = cond != Some(false);
        let may_exit = cond != Some(true);
        if may_continue {
            successors.push(RegionSuccessorInfo::Entering(after_region));
        }
        if may_exit {
            successors.push(RegionSuccessorInfo::Returning(
                while_op.results().all().iter().map(|r| r.borrow().as_value_ref()).collect(),
            ));
        }
        successors
    }
}
/// Terminator operation that hands a list of values to its successor.
///
/// Its `RegionBranchTerminatorOpInterface` implementation supports `If`, `IndexSwitch` and
/// `While` as parent operations. The operation is declared `Pure` and `AlwaysSpeculatable`.
#[derive(EffectOpInterface, OpPrinter, OpParser)]
#[operation(
dialect = ScfDialect,
traits(Terminator, ReturnLike, Pure, AlwaysSpeculatable),
implements(
RegionBranchTerminatorOpInterface,
MemoryEffectOpInterface,
ConditionallySpeculatable,
OpPrinter,
)
)]
pub struct Yield {
// Values forwarded to the successor; any type is accepted.
#[operands]
yielded: AnyType,
}
impl RegionBranchTerminatorOpInterface for Yield {
    /// The yielded operands are passed on to the successor, whichever it is.
    #[inline]
    fn get_successor_operands(&self, _point: RegionBranchPoint) -> SuccessorOperandRange<'_> {
        SuccessorOperandRange::forward(self.yielded())
    }

    /// Mutable counterpart of `get_successor_operands`.
    fn get_mutable_successor_operands(
        &mut self,
        _point: RegionBranchPoint,
    ) -> SuccessorOperandRangeMut<'_> {
        SuccessorOperandRangeMut::forward(self.yielded_mut())
    }

    /// Computes the single successor of this terminator, based on the kind of its parent:
    ///
    /// * inside a [`While`], control enters the `before` region of the loop;
    /// * inside an [`If`] or [`IndexSwitch`], control returns to the parent with its results.
    ///
    /// # Panics
    ///
    /// Panics if this operation has no parent, or if the parent is none of the operations
    /// listed above.
    fn get_successor_regions(
        &self,
        _operands: &[Option<AttributeRef>],
    ) -> SmallVec<[RegionSuccessorInfo; 2]> {
        let parent_op = self.parent_op().unwrap();
        let parent_op = parent_op.borrow();

        if let Some(while_op) = parent_op.downcast_ref::<While>() {
            let before_region = while_op.before().as_region_ref();
            return smallvec![RegionSuccessorInfo::Entering(before_region)];
        }

        let returns_to_parent = parent_op.is::<If>() || parent_op.is::<IndexSwitch>();
        if !returns_to_parent {
            panic!("unsupported parent operation for '{}': '{}'", self.name(), parent_op.name())
        }

        smallvec![RegionSuccessorInfo::Returning(
            parent_op.results().all().iter().map(|v| v.borrow().as_value_ref()).collect()
        )]
    }
}
impl ConditionallySpeculatable for Yield {
/// Always reports the operation as speculatable, independent of its operands.
fn speculatability(&self) -> Speculatability {
Speculatability::Speculatable
}
}