use combine::{
Parser,
parser::{self, char::spaces},
token,
};
use downcast_rs::{Downcast, impl_downcast};
use dyn_clone::DynClone;
use rustc_hash::FxHashMap;
use std::{
fmt::{self, Display},
hash::Hash,
ops::Deref,
sync::LazyLock,
};
use thiserror::Error;
use crate::{
attribute::AttributeDict,
builtin::{type_interfaces::FunctionTypeInterface, types::FunctionType},
common_traits::{Named, Verify},
context::{Context, Ptr, collect_deduped_interface_verifiers},
dialect::{Dialect, DialectName},
identifier::Identifier,
impl_printable_for_display, input_err,
irfmt::{
parsers::{
block_opd_parser, delimited_list_parser, location, process_parsed_ssa_defs, spaced,
ssa_opd_parser, zero_or_more_parser,
},
printers::{functional_type, iter_with_sep},
},
location::{Located, Location},
operation::{Operation, verify_operation},
parsable::{IntoParseResult, Parsable, ParseResult, StateStream},
printable::{self, Printable},
region::Region,
result::Result,
r#type::Typed,
};
/// Name of an operation, i.e. the `name` part of a `dialect.name` op identifier.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct OpName(String);

impl OpName {
    /// Creates an [`OpName`] holding an owned copy of `name`.
    pub fn new(name: &str) -> OpName {
        Self(String::from(name))
    }
}

impl Deref for OpName {
    type Target = String;

    /// Exposes the underlying name string.
    fn deref(&self) -> &String {
        let OpName(inner) = self;
        inner
    }
}
impl_printable_for_display!(OpName);

impl Parsable for OpName {
    type Arg = ();
    type Parsed = OpName;

    /// Parses an [`Identifier`] and converts its text into an [`OpName`].
    fn parse<'a>(
        state_stream: &mut crate::parsable::StateStream<'a>,
        _arg: Self::Arg,
    ) -> ParseResult<'a, Self::Parsed>
    where
        Self: Sized,
    {
        let mut name_parser = Identifier::parser(()).map(|ident| OpName::new(&ident));
        name_parser.parse_stream(state_stream).into()
    }
}
impl Display for OpName {
    /// Writes the bare operation name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Fully qualified operation identifier: a dialect name plus an op name.
/// Its `Display` form and its parser both use the `dialect.name` syntax.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct OpId {
    /// Dialect the operation belongs to.
    pub dialect: DialectName,
    /// Name of the operation within `dialect`.
    pub name: OpName,
}
impl_printable_for_display!(OpId);
impl Parsable for OpId {
    type Arg = ();
    type Parsed = OpId;

    /// Parses `dialect.name`: a [`DialectName`], a literal `.`, and then an [`OpName`].
    fn parse<'a>(
        state_stream: &mut StateStream<'a>,
        _arg: Self::Arg,
    ) -> ParseResult<'a, Self::Parsed>
    where
        Self: Sized,
    {
        let dialect_and_dot = DialectName::parser(()).skip(parser::char::char('.'));
        let mut opid_parser = dialect_and_dot
            .and(OpName::parser(()))
            .map(|(dialect, name)| OpId { dialect, name });
        opid_parser.parse_stream(state_stream).into()
    }
}
impl Display for OpId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}.{}", self.dialect, self.name)
}
}
/// What is needed to materialize a concrete op object for an `Operation`:
/// the op type's wrapping constructor, paired with that op type's `TypeId`.
pub(crate) type ConcreteOpInfo = (fn(Ptr<Operation>) -> OpObj, std::any::TypeId);

/// Common interface of all operations: a typed view of an underlying
/// `Ptr<Operation>` that can be verified, printed, cloned and downcast.
pub trait Op: Downcast + Verify + Printable + DynClone {
    /// The underlying operation this op object refers to.
    fn get_operation(&self) -> Ptr<Operation>;

    /// Wraps `op` into a type-erased op object of the implementing type.
    #[doc(hidden)]
    fn wrap_operation(op: Ptr<Operation>) -> OpObj
    where
        Self: Sized;

    /// Builds the implementing op type directly from `op`.
    #[doc(hidden)]
    fn from_operation(op: Ptr<Operation>) -> Self
    where
        Self: Sized;

    /// Pairs this op type's [`Op::wrap_operation`] with its `TypeId`.
    fn get_concrete_op_info() -> ConcreteOpInfo
    where
        Self: Sized,
    {
        let type_id = std::any::TypeId::of::<Self>();
        (Self::wrap_operation, type_id)
    }

    /// Identifier (`dialect.name`) of this op.
    fn get_opid(&self) -> OpId;

    /// Op identifier obtainable without an instance (used by [`Op::register`]).
    fn get_opid_static() -> OpId
    where
        Self: Sized;

    /// Hook for verifying this op against the op interfaces it implements.
    #[doc(hidden)]
    fn verify_interfaces(&self, ctx: &Context) -> Result<()>;

    /// Registers this op type's parser under its `OpId` with the dialect named
    /// by the op's dialect, as obtained from `Dialect::register`.
    fn register(ctx: &mut Context)
    where
        Self: Sized + Parsable<Arg = Vec<(Identifier, Location)>, Parsed = OpBox>,
    {
        let opid = Self::get_opid_static();
        // A non-capturing closure, so it coerces to the `fn` pointer inside `OpParserFn`.
        let parser_fn: OpParserFn = Box::new(|&(), args| Self::parser(args));
        Dialect::register(ctx, &opid.dialect).add_op(opid.clone(), parser_fn);
    }

    /// Location of the underlying operation.
    fn loc(&self, ctx: &Context) -> Location {
        self.get_operation().deref(ctx).loc()
    }
}
impl_downcast!(Op);
dyn_clone::clone_trait_object!(Op);
/// Parser constructor stored for each op in its dialect (see `Op::register`).
/// Given the already-parsed result names and their locations, it returns a
/// boxed parser that produces the op object.
pub(crate) type OpParserFn = Box<
    for<'a> fn(
        &'a (),
        Vec<(Identifier, Location)>,
    ) -> Box<dyn Parser<StateStream<'a>, Output = OpObj, PartialState = ()> + 'a>,
>;
/// Type-erased handle to an [`Op`]; an alias for [`OpBox`].
pub type OpObj = OpBox;
impl PartialEq for OpObj {
    /// Two op objects are equal when they refer to the same underlying `Operation`.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.as_ref().get_operation();
        let rhs = other.as_ref().get_operation();
        lhs.eq(&rhs)
    }
}

/// Verifies the `Operation` underlying `op` by delegating to [`verify_operation`].
pub fn verify_op(op: &dyn Op, ctx: &Context) -> Result<()> {
    let operation = op.get_operation();
    verify_operation(operation, ctx)
}

impl Eq for OpObj {}

impl Hash for OpObj {
    /// Hashes the underlying `Operation` pointer, matching the `PartialEq` impl.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let operation = self.as_ref().get_operation();
        operation.hash(state);
    }
}
/// Marker for op-interface types; bounds the target of [`op_cast`] and [`op_impls`].
#[diagnostic::on_unimplemented(
    message = "`{Self}` not an op interface.",
    label = "If `{Self}` is a trait, annotate it with #[op_interface] to be able to cast to it from a `&dyn Op`",
    note = "If you want to cast to a concrete `Op`, use `downcast_ref` instead."
)]
pub trait OpInterfaceMarker {}

/// Casts `op` to the op interface `T` through its `Any` view; `None` if the cast fails.
pub fn op_cast<T: ?Sized + OpInterfaceMarker + 'static>(op: &dyn Op) -> Option<&T> {
    let as_any = op.as_any();
    crate::utils::trait_cast::any_to_trait::<T>(as_any)
}

/// Whether [`op_cast`] of `op` to the op interface `T` succeeds.
pub fn op_impls<T: ?Sized + OpInterfaceMarker + 'static>(op: &dyn Op) -> bool {
    let cast = op_cast::<T>(op);
    cast.is_some()
}

/// A single op-interface verifier.
pub type OpInterfaceVerifier = fn(&dyn Op, &Context) -> Result<()>;
/// A function producing a list of [`OpInterfaceVerifier`]s.
pub type OpInterfaceAllVerifiers = fn() -> Vec<OpInterfaceVerifier>;
/// Verifier-producing function keyed by a `TypeId`.
#[doc(hidden)]
type OpInterfaceVerifierInfo = (std::any::TypeId, OpInterfaceAllVerifiers);
// Registry of op-interface verifier entries. One of two implementations is
// compiled depending on the target: a `linkme` distributed slice on non-wasm
// targets and an `inventory` collection on wasm. Both expose the same
// `get_op_interface_verifiers` function, re-exported below.
#[cfg(not(target_family = "wasm"))]
pub mod statics {
    use super::*;
    // Entries registered into this `linkme` distributed slice are gathered here.
    #[::pliron::linkme::distributed_slice]
    pub static OP_INTERFACE_VERIFIERS: [LazyLock<OpInterfaceVerifierInfo>] = [..];
    /// Iterates over all registered op-interface verifier entries.
    pub fn get_op_interface_verifiers()
    -> impl Iterator<Item = &'static LazyLock<OpInterfaceVerifierInfo>> {
        OP_INTERFACE_VERIFIERS.iter()
    }
}
#[cfg(target_family = "wasm")]
pub mod statics {
    use super::*;
    use crate::utils::inventory::LazyLockWrapper;
    // Declares `LazyLockWrapper<OpInterfaceVerifierInfo>` as an `inventory`
    // collection (presumably because `linkme` is not usable on wasm — confirm).
    ::pliron::inventory::collect!(LazyLockWrapper<OpInterfaceVerifierInfo>);
    /// Iterates over all registered op-interface verifier entries, unwrapping
    /// each `LazyLockWrapper` into the `LazyLock` it holds.
    pub fn get_op_interface_verifiers()
    -> impl Iterator<Item = &'static LazyLock<OpInterfaceVerifierInfo>> {
        ::pliron::inventory::iter::<LazyLockWrapper<OpInterfaceVerifierInfo>>().map(|llw| llw.0)
    }
}
pub use statics::*;
/// Op-interface verifiers grouped by `TypeId`. The map is built on first access
/// by passing every entry from `get_op_interface_verifiers` to
/// `collect_deduped_interface_verifiers`.
#[doc(hidden)]
pub static OP_INTERFACE_VERIFIERS_MAP: LazyLock<
    FxHashMap<std::any::TypeId, Vec<OpInterfaceVerifier>>,
> = LazyLock::new(|| {
    let registered = get_op_interface_verifiers();
    collect_deduped_interface_verifiers(registered)
});
/// Prints `op` in the canonical (generic) op syntax:
///
/// `results = dialect.name (operands) [^successors] attributes: functional-type`
///
/// The `results = ` prefix is written only when the op has results. The op's
/// regions are newline-separated and printed with `state`; they follow only
/// when there is at least one. Attributes are printed from
/// `clone_skip_outlined()`.
pub fn canonical_syntax_print(
    op: OpObj,
    ctx: &Context,
    state: &printable::State,
    f: &mut fmt::Formatter<'_>,
) -> fmt::Result {
    let comma = printable::ListSeparator::CharSpace(',');
    let opid = op.as_ref().get_opid();
    let operation = op.as_ref().get_operation().deref(ctx);

    if operation.get_num_results() != 0 {
        let result_list = iter_with_sep(operation.results(), comma);
        write!(f, "{} = ", result_list.disp(ctx))?;
    }

    let operand_list = iter_with_sep(operation.operands(), comma);
    let successor_list = iter_with_sep(
        operation
            .successors()
            .map(|succ| "^".to_string() + &succ.unique_name(ctx)),
        comma,
    );
    let signature = functional_type(
        iter_with_sep(operation.operands().map(|opd| opd.get_type(ctx)), comma),
        iter_with_sep(operation.results().map(|res| res.get_type(ctx)), comma),
    );
    write!(
        f,
        "{} ({}) [{}] {}: {}",
        opid.disp(ctx),
        operand_list.disp(ctx),
        successor_list.disp(ctx),
        operation.attributes.clone_skip_outlined().disp(ctx),
        signature.disp(ctx),
    )?;

    if operation.num_regions() == 0 {
        return Ok(());
    }
    iter_with_sep(operation.regions(), printable::ListSeparator::Newline).fmt(ctx, state, f)
}
/// Errors reported by `canonical_syntax_parse` when the parsed functional type
/// disagrees with the number of results or operands actually written.
#[derive(Error, Debug)]
pub enum CanonicalSyntaxParseError {
    /// The functional type lists `num_res_ty` results, but `num_res` result names were given.
    #[error("Type specifies {num_res_ty} results, but operation has {num_res} results")]
    ResultsMismatch { num_res_ty: usize, num_res: usize },
    /// The functional type lists `num_opd_ty` operands, but `num_opd` operands were given.
    #[error("Type specifies {num_opd_ty} operands, but operation has {num_opd} operands")]
    OperandsMismatch { num_opd_ty: usize, num_opd: usize },
}
/// Parses the part of an op in canonical (generic) syntax that follows
/// `results = dialect.name`:
///
/// `(operands) [successors] attributes : functional-type`, followed by zero or
/// more regions.
///
/// `results` holds the already-parsed result names and their locations.
///
/// # Errors
/// Reports [`CanonicalSyntaxParseError`] at the functional type's location
/// when its result or operand count does not match the number of results or
/// operands written. Errors from the sub-parsers, from
/// `process_parsed_ssa_defs`, and from region parsing are propagated.
///
/// The `Operation` is created with `T`'s concrete op info and zero regions.
/// It is given the parsed attributes, has its result names processed, then
/// receives the parsed regions, and is finally wrapped as a `T` op object.
pub fn canonical_syntax_parse<'a, T: Op>(
    state_stream: &mut StateStream<'a>,
    results: Vec<(Identifier, Location)>,
) -> ParseResult<'a, OpObj> {
    let mut without_regions = delimited_list_parser('(', ')', ',', ssa_opd_parser())
        .and(spaces().with(delimited_list_parser('[', ']', ',', block_opd_parser())))
        .and(spaces().with(AttributeDict::parser(())))
        .skip(spaced(token(':')))
        .and((location(), FunctionType::parser(())))
        .then(
            move |(((operands, successors), attr_dict), (fty_loc, fty))| {
                // `then` may invoke this closure more than once, so it cannot give
                // away its captured `results`; the inner parser gets its own copy.
                let results = results.clone();
                combine::parser(move |parsable_state: &mut StateStream<'a>| {
                    let ctx = &mut parsable_state.state.ctx;
                    let results_types = fty.deref(ctx).res_types().to_vec();
                    // Only the operand count is checked, so don't copy the types.
                    let num_opd_ty = fty.deref(ctx).arg_types().len();
                    if results_types.len() != results.len() {
                        input_err!(
                            fty_loc.clone(),
                            CanonicalSyntaxParseError::ResultsMismatch {
                                num_res_ty: results_types.len(),
                                num_res: results.len()
                            }
                        )?
                    }
                    if operands.len() != num_opd_ty {
                        input_err!(
                            fty_loc.clone(),
                            CanonicalSyntaxParseError::OperandsMismatch {
                                num_opd_ty,
                                num_opd: operands.len()
                            }
                        )?
                    }
                    // This parser is FnMut, so the captured parse results are
                    // cloned rather than moved into the new operation.
                    let opr = Operation::new(
                        ctx,
                        T::get_concrete_op_info(),
                        results_types,
                        operands.clone(),
                        successors.clone(),
                        0,
                    );
                    opr.deref_mut(ctx).attributes = attr_dict.clone();
                    process_parsed_ssa_defs(parsable_state, &results, opr)?;
                    Ok(opr).into_parse_result()
                })
            },
        );
    let op = without_regions.parse_stream(state_stream).into_result()?.0;
    zero_or_more_parser(Region::parser(op))
        .parse_stream(state_stream)
        .into_result()?;
    let op = T::wrap_operation(op);
    Ok(op).into_parse_result()
}
/// Wraps [`canonical_syntax_parse`] for op type `T` as a boxed parser.
///
/// `results` are the already-parsed result names and locations; every run of
/// the returned parser receives its own clone of them.
pub fn canonical_syntax_parser<'a, T: Op>(
    results: Vec<(Identifier, Location)>,
) -> Box<dyn Parser<StateStream<'a>, Output = OpObj, PartialState = ()> + 'a> {
    let op_parser = combine::parser(move |stream: &mut StateStream<'a>| {
        let result_names = results.clone();
        canonical_syntax_parse::<T>(stream, result_names)
    });
    op_parser.boxed()
}
/// The only state an `OpBox` stores for the op it holds.
///
/// `OpBox` reinterprets a reference to this struct as a reference to the
/// concrete op type it was built from. That op type must therefore be
/// layout-compatible with it. `repr(transparent)` fixes this struct's layout to
/// exactly that of `Ptr<Operation>`.
#[derive(Clone)]
#[repr(transparent)]
struct OpData {
    // Only read through the pointer cast in `OpBox::new`, never by field name.
    #[allow(unused)]
    op: Ptr<Operation>,
}

/// Type-erased handle to an [`Op`] that does not allocate.
///
/// Instead of boxing the op, `OpBox` stores the op's `Ptr<Operation>` inline.
/// Next to it is a monomorphized accessor that turns that storage back into a
/// `&dyn Op` of the original concrete type. This avoids any reliance on the
/// (unspecified) layout of trait-object fat pointers.
#[derive(Clone)]
pub struct OpBox {
    data: OpData,
    /// Views `data` as the concrete op type the box was created from.
    as_dyn: fn(&OpData) -> &dyn Op,
    /// Keeps `OpBox` `!Send` and `!Sync`; no thread-safety is assumed for boxed ops.
    _not_send_sync: std::marker::PhantomData<*const ()>,
}

impl OpBox {
    /// Type-erases `op` into an `OpBox`.
    ///
    /// `T` must be a thin wrapper around the `Ptr<Operation>` returned by
    /// `get_operation`. In particular, its size must equal that of `OpData` and
    /// its alignment must not exceed it. Both conditions are checked at compile
    /// time when `new` is instantiated for `T`.
    pub fn new<T: Op>(op: T) -> Self {
        struct LayoutCheck<S>(S);
        impl<S> LayoutCheck<S> {
            const ASSERTION: () = {
                assert!(
                    std::mem::size_of::<OpData>() == std::mem::size_of::<S>(),
                    "OpBox can only box Op objects"
                );
                assert!(
                    std::mem::align_of::<S>() <= std::mem::align_of::<OpData>(),
                    "OpBox can only box Op objects"
                );
            };
        }
        // Referencing the constant forces its (compile-time) evaluation for `T`.
        let _: () = LayoutCheck::<T>::ASSERTION;

        // Views `data` as a `U` and coerces it to `&dyn Op`, attaching `U`'s vtable.
        fn as_dyn_op<U: Op>(data: &OpData) -> &dyn Op {
            // SAFETY: this function is stored only by `OpBox::new::<U>`. That
            // function (a) checked at compile time that `U` has the same size as
            // `OpData` and no stricter alignment, and (b) filled `data` with `U`'s
            // `Ptr<Operation>`. Op types are required to be thin wrappers around
            // that pointer, so `data`'s bytes form a valid, aligned `U`.
            let typed: &U = unsafe { &*(data as *const OpData).cast::<U>() };
            typed
        }

        OpBox {
            data: OpData {
                op: op.get_operation(),
            },
            as_dyn: as_dyn_op::<T>,
            _not_send_sync: std::marker::PhantomData,
        }
    }

    /// Returns the boxed op as a trait object of its original concrete type.
    pub fn op_ref(&self) -> &dyn Op {
        (self.as_dyn)(&self.data)
    }

    /// Converts back into the concrete op type `T`, or `None` if the boxed op is not a `T`.
    pub fn downcast<T: Op>(self) -> Option<T> {
        self.as_ref()
            .downcast_ref::<T>()
            .map(|op| T::from_operation(op.get_operation()))
    }
}

impl AsRef<dyn Op> for OpBox {
    /// Same as [`OpBox::op_ref`].
    fn as_ref(&self) -> &dyn Op {
        self.op_ref()
    }
}

impl Deref for OpBox {
    type Target = dyn Op;

    /// Same as [`OpBox::op_ref`], so `Op` methods can be called on an `OpBox` directly.
    fn deref(&self) -> &Self::Target {
        self.as_ref()
    }
}