use std::marker::PhantomData;
use combine::{Parser, attempt, parser::char::spaces, token};
use thiserror::Error;
use crate::{
attribute::{AttributeDict, verify_attr},
basic_block::{BasicBlock, BasicBlockVerifyErr},
builtin::op_interfaces::{IsTerminatorInterface, SymbolOpInterface},
common_traits::{Named, RcShare, Verify},
context::{Arena, Context, Ptr, private::ArenaObj},
debug_info,
identifier::Identifier,
input_err,
irfmt::{
outlined::{self, parse_outlines, postparse_outline},
parsers::{list_parser, location, spaced},
printers::iter_with_sep,
},
linked_list::{LinkedList, private},
location::{Located, Location},
op::{ConcreteOpInfo, Op, OpId, OpObj, op_cast, op_impls},
parsable::{self, Parsable, ParseResult, StateStream},
printable::{self, Printable},
region::Region,
result::Result,
r#type::{TypeObj, Typed},
utils::vec_exns::VecExtns,
value::{DefNode, DefTrait, DefUseParticipant, Use, UseNode, Value},
verify_err,
};
/// A single result slot of an `Operation`.
pub(crate) struct OpResult {
/// Def node of this result; `Operation::has_use`, `num_uses` and `uses` query it.
pub(crate) def: DefNode<Value>,
/// The operation this result belongs to (set to the new operation in `Operation::new`).
def_op: Ptr<Operation>,
/// Position of this result in its operation's result list.
res_idx: usize,
/// Type of this result.
ty: Ptr<TypeObj>,
}
impl OpResult {
    /// Returns the type currently recorded for this result.
    pub fn get_type(&self) -> Ptr<TypeObj> {
        let Self { ty, .. } = self;
        *ty
    }

    /// Overwrites the type recorded for this result with `ty`.
    pub fn set_type(&mut self, ty: Ptr<TypeObj>) {
        let slot = &mut self.ty;
        *slot = ty;
    }
}
impl Typed for OpResult {
    /// Returns the stored result type; the context is not consulted.
    fn get_type(&self, _ctx: &Context) -> Ptr<TypeObj> {
        self.ty
    }
}
impl Printable for OpResult {
    /// Prints the result as its unique name.
    fn fmt(
        &self,
        ctx: &Context,
        _state: &printable::State,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        let name = self.unique_name(ctx);
        write!(f, "{}", name)
    }
}
impl From<&OpResult> for Value {
fn from(value: &OpResult) -> Self {
Value::OpResult {
op: value.def_op,
res_idx: value.res_idx,
}
}
}
impl Verify for OpResult {
fn verify(&self, ctx: &Context) -> Result<()> {
Into::<Value>::into(self).verify(ctx)
}
}
impl Named for OpResult {
    /// Looks up the name recorded for this result in the debug info, if any.
    fn given_name(&self, ctx: &Context) -> Option<Identifier> {
        let (op, idx) = (self.def_op, self.res_idx);
        debug_info::get_operation_result_name(ctx, op, idx)
    }

    /// Builds an identifier of the form `<op name>_res<index>`.
    ///
    /// # Panics
    /// Panics if the generated string cannot be converted to an `Identifier`.
    fn id(&self, _ctx: &Context) -> Identifier {
        let op_name = self.def_op.make_name("op");
        let text = format!("{}_res{}", op_name, self.res_idx);
        text.try_into().unwrap()
    }
}
/// Links tying an `Operation` to its parent block and to its neighbouring
/// operations. All links default to `None`.
#[derive(Default)]
struct BlockLinks {
/// Block containing the operation, if any.
parent_block: Option<Ptr<BasicBlock>>,
/// Operation following this one, if any.
next_op: Option<Ptr<Operation>>,
/// Operation preceding this one, if any.
prev_op: Option<Ptr<Operation>>,
}
impl BlockLinks {
pub fn new() -> BlockLinks {
BlockLinks::default()
}
}
/// An IR operation: results, operands, successor blocks, attributes and
/// nested regions, plus its position within a basic block.
pub struct Operation {
/// Pointer to this very operation (returned by `ArenaObj::get_self_ptr`).
self_ptr: Ptr<Operation>,
/// Information identifying the concrete `Op` this operation represents.
concrete_op: ConcreteOpInfo,
/// Results defined by this operation.
results: Vec<OpResult>,
/// Value operands used by this operation.
operands: Vec<Operand<Value>>,
/// Successor blocks referenced by this operation.
successors: Vec<Operand<Ptr<BasicBlock>>>,
/// Parent block and neighbouring operations.
block_links: BlockLinks,
/// Attributes attached to this operation.
pub attributes: AttributeDict,
/// Regions nested in this operation.
pub(crate) regions: Vec<Ptr<Region>>,
/// Source location (`Location::Unknown` on creation).
loc: Location,
}
impl PartialEq for Operation {
    /// Two operations are equal exactly when their self pointers are equal.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.self_ptr, other.self_ptr);
        lhs == rhs
    }
}
impl private::LinkedList for Operation {
    type ContainerType = BasicBlock;

    /// Records `next` as the operation following this one.
    fn set_next(&mut self, next: Option<Ptr<Self>>) {
        let links = &mut self.block_links;
        links.next_op = next;
    }

    /// Records `prev` as the operation preceding this one.
    fn set_prev(&mut self, prev: Option<Ptr<Self>>) {
        let links = &mut self.block_links;
        links.prev_op = prev;
    }

    /// Records `container` as the block holding this operation.
    fn set_container(&mut self, container: Option<Ptr<BasicBlock>>) {
        let links = &mut self.block_links;
        links.parent_block = container;
    }
}
impl LinkedList for Operation {
fn get_next(&self) -> Option<Ptr<Self>> {
self.block_links.next_op
}
fn get_prev(&self) -> Option<Ptr<Self>> {
self.block_links.prev_op
}
fn get_container(&self) -> Option<Ptr<BasicBlock>> {
self.block_links.parent_block
}
}
impl Operation {
    /// Allocates a new operation in `ctx` and returns a pointer to it.
    ///
    /// One result is created per entry of `result_types` (indexed in order),
    /// every entry of `operands` and `successors` is registered as a use of
    /// its definition, and the region list is built from `num_regions` via
    /// `Region::new`. The operation starts with default attributes, default
    /// (empty) block links and `Location::Unknown`.
    pub fn new(
        ctx: &mut Context,
        concrete_op: ConcreteOpInfo,
        result_types: Vec<Ptr<TypeObj>>,
        operands: Vec<Value>,
        successors: Vec<Ptr<BasicBlock>>,
        num_regions: usize,
    ) -> Ptr<Operation> {
        // Results, operands, successors and regions all refer back to the
        // operation's own pointer, so the operation is allocated with empty
        // lists first and the lists are filled in afterwards. The empty lists
        // are replaced wholesale below, so no capacity is reserved for them.
        let f = |self_ptr: Ptr<Operation>| Operation {
            self_ptr,
            concrete_op,
            results: Vec::new(),
            operands: Vec::new(),
            successors: Vec::new(),
            block_links: BlockLinks::new(),
            attributes: AttributeDict::default(),
            regions: Vec::new(),
            loc: Location::Unknown,
        };
        let newop = Self::alloc(ctx, f);

        let results = result_types
            .into_iter()
            .enumerate()
            .map(|(res_idx, ty)| OpResult {
                def: DefNode::new(),
                def_op: newop,
                ty,
                res_idx,
            })
            .collect();
        newop.deref_mut(ctx).results = results;

        // Each list is fully built before it is stored, so the operation is
        // not held mutably while the definitions record their new uses.
        let operands = operands
            .iter()
            .enumerate()
            .map(|(opd_idx, def)| Operand::new(ctx, *def, newop, opd_idx))
            .collect();
        newop.deref_mut(ctx).operands = operands;

        let successors = successors
            .iter()
            .enumerate()
            .map(|(succ_idx, def)| Operand::new(ctx, *def, newop, succ_idx))
            .collect();
        newop.deref_mut(ctx).successors = successors;

        newop.deref_mut(ctx).regions = Vec::new_init(num_regions, |_| Region::new(ctx, newop));
        newop
    }

    /// Returns the block containing this operation, if any.
    pub fn get_parent_block(&self) -> Option<Ptr<BasicBlock>> {
        self.block_links.parent_block
    }

    /// Returns the parent region of this operation's parent block, if both exist.
    pub fn get_parent_region(&self, ctx: &Context) -> Option<Ptr<Region>> {
        self.get_parent_block()
            .and_then(|block| block.deref(ctx).get_parent_region())
    }

    /// Returns the parent operation of this operation's parent block, if both exist.
    pub fn get_parent_op(&self, ctx: &Context) -> Option<Ptr<Operation>> {
        self.get_parent_block()
            .and_then(|block| block.deref(ctx).get_parent_op(ctx))
    }

    /// Returns the number of results of this operation.
    pub fn get_num_results(&self) -> usize {
        self.results.len()
    }

    /// Returns the `idx`-th result as a `Value`.
    ///
    /// # Panics
    /// Panics if `idx` is out of bounds.
    pub fn get_result(&self, idx: usize) -> Value {
        self.results
            .get(idx)
            .map(|res| res.into())
            .unwrap_or_else(|| panic!("Result index {idx} out of bounds"))
    }

    /// Iterates over all results as `Value`s, in index order.
    pub fn results(&self) -> impl Iterator<Item = Value> + Clone + '_ {
        self.results.iter().map(Into::into)
    }

    /// Returns `true` if at least one result of this operation is used.
    pub fn has_use(&self) -> bool {
        self.results.iter().any(|res| res.def.is_used())
    }

    /// Returns the total number of uses across all results.
    pub fn num_uses(&self) -> usize {
        self.results.iter().map(|res| res.def.num_uses()).sum()
    }

    /// Iterates over the uses of every result, result by result.
    pub fn uses(&self) -> impl Iterator<Item = Use<Value>> + '_ {
        self.results.iter().flat_map(|res| res.def.uses())
    }

    /// Returns the type of the `idx`-th result.
    ///
    /// # Panics
    /// Panics if `idx` is out of bounds.
    pub fn get_type(&self, idx: usize) -> Ptr<TypeObj> {
        self.results
            .get(idx)
            .map(|res| res.ty)
            .unwrap_or_else(|| panic!("Result index {idx} out of bounds"))
    }

    /// Iterates over the types of all results, in index order.
    pub fn result_types(&self) -> impl Iterator<Item = Ptr<TypeObj>> + Clone + '_ {
        self.results.iter().map(|res| res.ty)
    }

    /// Returns the number of operands of this operation.
    pub fn get_num_operands(&self) -> usize {
        self.operands.len()
    }

    /// Returns the value used by the `opd_idx`-th operand.
    ///
    /// # Panics
    /// Panics if `opd_idx` is out of bounds.
    pub fn get_operand(&self, opd_idx: usize) -> Value {
        self.operands
            .get(opd_idx)
            .map(|opd| opd.get_def())
            .unwrap_or_else(|| panic!("Operand index {opd_idx} out of bounds"))
    }

    /// Returns the `opd_idx`-th operand as a `Use` (user op + operand index).
    ///
    /// # Panics
    /// Panics if `opd_idx` is out of bounds.
    pub fn get_operand_as_use(&self, opd_idx: usize) -> Use<Value> {
        self.get_operand_ref(opd_idx).into()
    }

    /// Iterates over the values used by the operands, in index order.
    pub fn operands(&self) -> impl Iterator<Item = Value> + Clone + '_ {
        self.operands.iter().map(Operand::get_def)
    }

    /// Makes operand `opd_idx` of `this` use `other` instead of its current
    /// value, by asking the current value to replace that one use.
    ///
    /// # Panics
    /// Panics if `opd_idx` is out of bounds.
    pub fn replace_operand(this: Ptr<Operation>, ctx: &Context, opd_idx: usize, other: Value) {
        // The borrow of `this` is confined to this inner scope and released
        // before `replace_use_with` is called.
        let (cur_def, cur_use) = {
            let this_ref = this.deref(ctx);
            (
                this_ref.get_operand(opd_idx),
                this_ref.get_operand_as_use(opd_idx),
            )
        };
        cur_def.replace_use_with(ctx, cur_use, &other);
    }

    /// Returns the number of successors of this operation.
    pub fn get_num_successors(&self) -> usize {
        self.successors.len()
    }

    /// Returns the `succ_idx`-th successor block.
    ///
    /// # Panics
    /// Panics if `succ_idx` is out of bounds.
    pub fn get_successor(&self, succ_idx: usize) -> Ptr<BasicBlock> {
        self.successors
            .get(succ_idx)
            .map(|succ| succ.get_def())
            .unwrap_or_else(|| panic!("Successor index {succ_idx} out of bounds"))
    }

    /// Returns the `succ_idx`-th successor as a `Use` (user op + successor index).
    ///
    /// # Panics
    /// Panics if `succ_idx` is out of bounds.
    pub fn get_successor_as_use(&self, succ_idx: usize) -> Use<Ptr<BasicBlock>> {
        self.get_successor_ref(succ_idx).into()
    }

    /// Makes successor `succ_idx` of `this` target `other` instead of its
    /// current block, by asking the current block to retarget that one use.
    ///
    /// # Panics
    /// Panics if `succ_idx` is out of bounds.
    pub fn replace_successor(
        this: Ptr<Operation>,
        ctx: &Context,
        succ_idx: usize,
        other: Ptr<BasicBlock>,
    ) {
        // The borrow of `this` is confined to this inner scope and released
        // before `retarget_pred_to` is called.
        let (cur_target, cur_block_use) = {
            let this_ref = this.deref(ctx);
            (
                this_ref.get_successor(succ_idx),
                this_ref.get_successor_as_use(succ_idx),
            )
        };
        cur_target.retarget_pred_to(ctx, cur_block_use, other);
    }

    /// Iterates over the successor blocks, in index order.
    pub fn successors(&self) -> impl Iterator<Item = Ptr<BasicBlock>> + Clone + '_ {
        self.successors.iter().map(|opd| opd.get_def())
    }

    /// Builds the dynamically typed `Op` object for `ptr` by invoking the
    /// constructor stored as the first component of its `ConcreteOpInfo`.
    pub fn get_op_dyn(ptr: Ptr<Self>, ctx: &Context) -> OpObj {
        (ptr.deref(ctx).concrete_op.0)(ptr)
    }

    /// Returns `ptr` wrapped as the concrete op `T`, or `None` when the second
    /// component of the operation's `ConcreteOpInfo` differs from that of `T`.
    pub fn get_op<T: Op>(ptr: Ptr<Self>, ctx: &Context) -> Option<T> {
        // Lazy `then`: `T` is only constructed once the kinds are known to match.
        (ptr.deref(ctx).concrete_op.1 == T::get_concrete_op_info().1)
            .then(|| T::from_operation(ptr))
    }

    /// Returns the `OpId` of the op that `ptr` represents.
    pub fn get_opid(ptr: Ptr<Self>, ctx: &Context) -> OpId {
        Self::get_op_dyn(ptr, ctx).get_opid()
    }

    /// Returns the `reg_idx`-th region.
    ///
    /// # Panics
    /// Panics if `reg_idx` is out of bounds.
    pub fn get_region(&self, reg_idx: usize) -> Ptr<Region> {
        self.regions
            .get(reg_idx)
            .cloned()
            .unwrap_or_else(|| panic!("Region index {reg_idx} out of bounds"))
    }

    /// Returns the number of regions of this operation.
    pub fn num_regions(&self) -> usize {
        self.regions.len()
    }

    /// Creates a new region owned by `ptr`, appends it to the region list and
    /// returns it.
    pub fn add_region(ptr: Ptr<Self>, ctx: &mut Context) -> Ptr<Region> {
        let region = Region::new(ctx, ptr);
        ptr.deref_mut(ctx).regions.push(region);
        region
    }

    /// Removes the `reg_idx`-th region from `ptr`: drops all uses inside it,
    /// removes it from the region list and deallocates it.
    ///
    /// # Panics
    /// Panics if `reg_idx` is out of bounds.
    pub fn erase_region(ptr: Ptr<Self>, ctx: &mut Context, reg_idx: usize) {
        // Looked up through `get_region` so that an invalid index reports the
        // same "out of bounds" message as the other indexed accessors.
        let reg = ptr.deref(ctx).get_region(reg_idx);
        Region::drop_all_uses(reg, ctx);
        ptr.deref_mut(ctx).regions.remove(reg_idx);
        ArenaObj::dealloc(reg, ctx);
    }

    /// Iterates over the regions, in index order.
    pub fn regions(&self) -> impl Iterator<Item = Ptr<Region>> + Clone + '_ {
        self.regions.iter().cloned()
    }

    /// Drops every use made by this operation: its operands, its successors,
    /// and (via `Region::drop_all_uses`) everything inside its regions.
    ///
    /// The operand and successor lists are left empty afterwards.
    pub fn drop_all_uses(ptr: Ptr<Self>, ctx: &Context) {
        // Each list is moved out in its own statement, so the mutable access
        // to the operation ends before the individual uses are dropped.
        let operands = std::mem::take(&mut (ptr.deref_mut(ctx).operands));
        for opd in operands {
            opd.drop_use(ctx);
        }
        let successors = std::mem::take(&mut (ptr.deref_mut(ctx).successors));
        for succ in successors {
            succ.drop_use(ctx);
        }
        let regions = ptr.deref(ctx).regions.clone();
        for region in regions {
            Region::drop_all_uses(region, ctx);
        }
    }

    /// Erases the operation: drops all uses it makes, unlinks it if it is
    /// linked, and deallocates it.
    ///
    /// # Panics
    /// Panics if any result of the operation still has a use.
    pub fn erase(ptr: Ptr<Self>, ctx: &mut Context) {
        Self::drop_all_uses(ptr, ctx);
        assert!(
            !ptr.deref(ctx).has_use(),
            "Operation with use(s) being erased"
        );
        if ptr.is_linked(ctx) {
            ptr.unlink(ctx);
        }
        ArenaObj::dealloc(ptr, ctx);
    }

    /// Parses an operation with `look_for_outlined_attrs` enabled.
    pub fn top_level_parse<'a>(
        state_stream: &mut parsable::StateStream<'a>,
    ) -> ParseResult<'a, Ptr<Self>> {
        Operation::parse(
            state_stream,
            OperationParserConfig {
                look_for_outlined_attrs: true,
            },
        )
    }

    /// Returns a `combine` parser that calls [`Self::top_level_parse`].
    pub fn top_level_parser<'a>()
    -> impl Parser<StateStream<'a>, Output = Ptr<Self>, PartialState = ()> + 'a {
        combine::parser(move |parsable_state: &mut StateStream<'a>| {
            Self::top_level_parse(parsable_state)
        })
    }

    /// Returns a shared reference to the `idx`-th result.
    ///
    /// # Panics
    /// Panics if `idx` is out of bounds.
    pub(crate) fn get_result_ref(&self, idx: usize) -> &OpResult {
        self.results
            .get(idx)
            .unwrap_or_else(|| panic!("Result index {idx} out of bounds"))
    }

    /// Returns a mutable reference to the `idx`-th result.
    ///
    /// # Panics
    /// Panics if `idx` is out of bounds.
    pub(crate) fn get_result_mut(&mut self, idx: usize) -> &mut OpResult {
        self.results
            .get_mut(idx)
            .unwrap_or_else(|| panic!("Result index {idx} out of bounds"))
    }

    /// Returns a shared reference to the `opd_idx`-th operand.
    ///
    /// # Panics
    /// Panics if `opd_idx` is out of bounds.
    pub(crate) fn get_operand_ref(&self, opd_idx: usize) -> &Operand<Value> {
        self.operands
            .get(opd_idx)
            .unwrap_or_else(|| panic!("Operand index {opd_idx} out of bounds"))
    }

    /// Returns a mutable reference to the `opd_idx`-th operand.
    ///
    /// # Panics
    /// Panics if `opd_idx` is out of bounds.
    pub(crate) fn get_operand_mut(&mut self, opd_idx: usize) -> &mut Operand<Value> {
        self.operands
            .get_mut(opd_idx)
            .unwrap_or_else(|| panic!("Operand index {opd_idx} out of bounds"))
    }

    /// Returns a shared reference to the `succ_idx`-th successor.
    ///
    /// # Panics
    /// Panics if `succ_idx` is out of bounds.
    pub(crate) fn get_successor_ref(&self, succ_idx: usize) -> &Operand<Ptr<BasicBlock>> {
        self.successors
            .get(succ_idx)
            .unwrap_or_else(|| panic!("Successor index {succ_idx} out of bounds"))
    }

    /// Returns a mutable reference to the `succ_idx`-th successor.
    ///
    /// # Panics
    /// Panics if `succ_idx` is out of bounds.
    pub(crate) fn get_successor_mut(&mut self, succ_idx: usize) -> &mut Operand<Ptr<BasicBlock>> {
        self.successors
            .get_mut(succ_idx)
            .unwrap_or_else(|| panic!("Successor index {succ_idx} out of bounds"))
    }
}
impl ArenaObj for Operation {
fn get_arena(ctx: &Context) -> &Arena<Self> {
&ctx.operations
}
fn get_arena_mut(ctx: &mut Context) -> &mut Arena<Self> {
&mut ctx.operations
}
fn dealloc_sub_objects(ptr: Ptr<Self>, ctx: &mut Context) {
let regions = ptr.deref(ctx).regions.clone();
for region in regions {
ArenaObj::dealloc(region, ctx);
}
}
fn get_self_ptr(&self, _ctx: &Context) -> Ptr<Self> {
self.self_ptr
}
}
/// An operand slot of an operation: a use of a definition of type `T`
/// (a `Value` for ordinary operands, a block pointer for successors).
pub(crate) struct Operand<T: DefUseParticipant> {
/// Use node returned by the definition when this use was registered.
pub(crate) r#use: UseNode<T>,
/// Index of this slot within the user's operand (or successor) list.
pub(crate) opd_idx: usize,
/// The operation that owns this slot.
pub(crate) user_op: Ptr<Operation>,
}
impl<T: DefUseParticipant + DefTrait> Operand<T> {
    /// Returns the definition this operand refers to.
    fn get_def(&self) -> T {
        self.r#use.get_def()
    }

    /// Removes this operand from the use list of its definition.
    fn drop_use(&self, ctx: &Context) {
        let def = self.get_def();
        def.get_defnode_mut(ctx).remove_use(self.into());
    }

    /// Registers slot `opd_idx` of `user_op` as a use of `def` and returns the
    /// operand describing that use.
    fn new(ctx: &Context, def: T, user_op: Ptr<Operation>, opd_idx: usize) -> Operand<T> {
        let use_record = Use {
            op: user_op,
            opd_idx,
            _dummy: PhantomData,
        };
        let use_node = def.get_defnode_mut(ctx).add_use(def, use_record);
        Operand {
            r#use: use_node,
            user_op,
            opd_idx,
        }
    }
}
impl<T: DefUseParticipant> From<&Operand<T>> for Use<T> {
fn from(value: &Operand<T>) -> Self {
Use {
op: value.user_op,
opd_idx: value.opd_idx,
_dummy: PhantomData,
}
}
}
impl<T: DefUseParticipant + Named> Printable for Operand<T> {
    /// Prints the operand as the unique name of the definition it uses.
    fn fmt(
        &self,
        ctx: &Context,
        _state: &printable::State,
        f: &mut core::fmt::Formatter<'_>,
    ) -> core::fmt::Result {
        let def = self.r#use.get_def();
        write!(f, "{}", def.unique_name(ctx))
    }
}
impl<T: DefUseParticipant + Typed> Typed for Operand<T> {
    /// The type of an operand is the type of the definition it uses.
    fn get_type(&self, ctx: &Context) -> Ptr<TypeObj> {
        let def = self.r#use.get_def();
        def.get_type(ctx)
    }
}
/// Verification error reported when an operand's definition does not list
/// that operand among its uses.
#[derive(Error, Debug)]
#[error("operand is not a use of its def")]
pub struct DefUseVerifyErr;
impl<T: DefUseParticipant + DefTrait> Verify for Operand<T> {
    /// Checks that the definition used by this operand records this operand
    /// as one of its uses.
    ///
    /// # Errors
    /// Returns a `DefUseVerifyErr` located at the user operation otherwise.
    fn verify(&self, ctx: &Context) -> Result<()> {
        let def = self.r#use.get_def();
        if def.get_defnode_ref(ctx).has_use_of(&self.into()) {
            return Ok(());
        }
        let loc = self.user_op.deref(ctx).loc();
        verify_err!(loc, DefUseVerifyErr)
    }
}
impl Verify for Operation {
fn verify(&self, ctx: &Context) -> Result<()> {
fn verify_inner(opr: &Operation, ctx: &Context) -> Result<()> {
opr.attributes
.0
.values()
.try_for_each(|attr| verify_attr(&**attr, ctx))?;
opr.operands.iter().try_for_each(|opd| opd.verify(ctx))?;
opr.successors.iter().try_for_each(|opd| opd.verify(ctx))?;
opr.regions
.iter()
.try_for_each(|region| region.verify(ctx))?;
opr.results.iter().try_for_each(|res| res.verify(ctx))?;
let op = &*Operation::get_op_dyn(opr.self_ptr, ctx);
if op_impls::<dyn IsTerminatorInterface>(op) && opr.get_next().is_some() {
let loc = opr.loc.clone();
let parent_block = opr
.get_parent_block()
.expect("There's a next operation, so there must be a parent block");
verify_err!(
loc,
BasicBlockVerifyErr::TerminatorNotLast(
parent_block.unique_name(ctx).disp(ctx).to_string()
)
)?
}
op.verify_interfaces(ctx)?;
op.verify(ctx)
}
verify_inner(self, ctx).inspect_err(
|err| {
struct Helper(Ptr<Operation>);
impl Printable for Helper {
fn fmt(&self, ctx: &Context, _state: &printable::State, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
print_dbg(ctx, self.0, f)
}
}
let op = self.self_ptr;
log::error!(target: "verify_error","{} in operation:\n{}", err.disp(ctx), Helper(op).disp(ctx))
}
)
}
}
impl Printable for Operation {
    /// Prints the operation through its concrete op's printer, then hands it
    /// to the outline pre-printer. For an operation without a parent
    /// operation, the collected outlines are printed as well.
    fn fmt(
        &self,
        ctx: &Context,
        state: &printable::State,
        f: &mut core::fmt::Formatter<'_>,
    ) -> core::fmt::Result {
        let this = self.self_ptr;
        Self::get_op_dyn(this, ctx).fmt(ctx, state, f)?;
        outlined::preprint_outline(ctx, this, state.share(), f)?;

        let has_no_parent_op = self.get_parent_op(ctx).is_none();
        if has_no_parent_op {
            outlined::print_outlines(ctx, state.share(), f)?;
        }
        Ok(())
    }
}
impl Located for Operation {
    /// Returns a copy of the operation's location.
    fn loc(&self) -> Location {
        let Self { loc, .. } = self;
        loc.clone()
    }

    /// Replaces the operation's location with `loc`.
    fn set_loc(&mut self, loc: Location) {
        let slot = &mut self.loc;
        *slot = loc;
    }
}
/// Argument taken by the `Parsable` implementation of `Operation`.
#[derive(Clone)]
pub struct OperationParserConfig {
/// When `true`, `parse_outlines` is invoked after the operation itself
/// has been parsed.
pub look_for_outlined_attrs: bool,
}
/// Parses an operation of the form `[res0, res1, ... =] <op id> <op-specific syntax>`.
impl Parsable for Operation {
type Arg = OperationParserConfig;
type Parsed = Ptr<Operation>;
/// Parses an optional comma-separated list of result names followed by `=`,
/// then an op id, and then delegates to the parser registered for that op id
/// in its dialect. On success the parsed operation gets the location at
/// which parsing started.
///
/// # Errors
/// Fails with "Unregistered Op ..." if the dialect has no parser for the op id.
///
/// # Panics
/// Panics if the stream's location has no source, or if the parsed dialect
/// name is not registered in the context.
fn parse<'a>(
state_stream: &mut parsable::StateStream<'a>,
arg: Self::Arg,
) -> ParseResult<'a, Self::Parsed> {
let loc = state_stream.loc();
// The value is unused; the call only asserts that the location has a source.
let _src = loc
.source()
.expect("Location from Parsable must be Location::SrcPos");
// Optional `name, name, ... =` prefix (each name paired with its location),
// followed by the op id. `attempt` lets the parser back out of the prefix
// when it does not match.
let results_opid = spaces()
.with(combine::optional(attempt(
list_parser(',', (location(), Identifier::parser(()))).skip(spaced(token('='))),
)))
.and(spaced(OpId::parser(())));
results_opid
.then(|(results_opt, opid)| {
let loc = loc.clone();
// Swap each `(location, name)` pair into `(name, location)`; no prefix
// means no named results.
let results: Vec<_> = results_opt
.unwrap_or(vec![])
.into_iter()
.map(|(res_loc, id)| (id, (res_loc)))
.collect();
combine::parser(move |parsable_state: &mut StateStream<'a>| {
let state = &parsable_state.state;
let dialect = state
.ctx
.dialects
.get(&opid.dialect)
.expect("Dialect name parsed but dialect isn't registered");
let Some(opid_parser) = dialect.ops.get(&opid) else {
input_err!(loc.clone(), "Unregistered Op {}", opid.disp(state.ctx))?
};
// Run the op-specific parser and reduce the parsed op to its operation pointer.
let op = opid_parser(&(), results.clone())
.parse_stream(parsable_state)
.map(|op| op.get_operation())
.into();
if let Ok((op, _)) = op {
op.deref_mut(parsable_state.state.ctx).set_loc(loc.clone());
postparse_outline(parsable_state, op)?;
}
// Note: this sits outside the `if let` above, so it also runs when the
// op-specific parser failed.
if arg.look_for_outlined_attrs {
parse_outlines(parsable_state)?;
}
op
})
})
.parse_stream(state_stream)
.into()
}
}
/// Writes a generic debug rendering of `opr` to `f`, independent of the op's
/// own printer.
///
/// The output has the shape `<results> = <opid>[ @<symbol>] (<operands>)`,
/// followed by ` [^succ, ...]` when the operation has successors. When there
/// are no results, the operation's pointer (in `Debug` form) takes the place
/// of `<results> =`. The symbol is printed only if the op implements
/// `SymbolOpInterface`.
pub fn print_dbg(
    ctx: &Context,
    opr: Ptr<Operation>,
    f: &mut std::fmt::Formatter<'_>,
) -> core::fmt::Result {
    let sep = printable::ListSeparator::CharSpace(',');
    let op = Operation::get_op_dyn(opr, ctx);
    let opid = op.get_opid();
    let opr_ref = opr.deref(ctx);
    let operands = iter_with_sep(opr_ref.operands(), sep);

    let symbol_opt = op_cast::<dyn SymbolOpInterface>(&*op)
        .map(|sym_op| format!(" @{}", sym_op.get_symbol_name(ctx).disp(ctx)))
        .unwrap_or_default();

    // Left-hand side: the result list, or the op's pointer if it has no results.
    if opr_ref.get_num_results() > 0 {
        let results = iter_with_sep(opr_ref.results(), sep);
        write!(f, "{} = ", results.disp(ctx))?;
    } else {
        write!(f, "{:?} ", opr_ref.get_self_ptr(ctx))?;
    }

    write!(
        f,
        "{}{} ({})",
        opid.disp(ctx),
        symbol_opt,
        operands.disp(ctx)
    )?;

    if opr_ref.get_num_successors() == 0 {
        return Ok(());
    }
    let successor_names = opr_ref
        .successors()
        .map(|succ| format!("^{}", succ.unique_name(ctx)));
    let successors = iter_with_sep(successor_names, sep);
    write!(f, " [{}]", successors.disp(ctx))?;
    Ok(())
}