use strum::IntoEnumIterator;
use crate::ops::{ExtensionOp, OpName, OpNameRef};
use crate::{
ops::{NamedOp, OpType},
types::TypeArg,
Extension,
};
use super::{
op_def::SignatureFunc, ExtensionBuildError, ExtensionId, ExtensionRegistry, OpDef,
SignatureError,
};
use delegate::delegate;
use thiserror::Error;
/// Error produced when an operation cannot be loaded from a name, an [`OpDef`],
/// or an [`ExtensionOp`].
#[derive(Debug, Error, PartialEq, Clone)]
#[error("{0}")]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum OpLoadError {
    /// The given name is not recognised as a member of the operation set
    /// (e.g. parsing it into the op type failed).
    #[error("Op with name {0} is not a member of this set.")]
    NotMember(String),
    /// The supplied type arguments were rejected; wraps the underlying [`SignatureError`].
    #[error("Type args invalid: {0}.")]
    InvalidArgs(#[from] SignatureError),
    /// The definition's extension (first field) differs from the extension the
    /// op expects to belong to (second field).
    #[error("OpDef belongs to extension {0}, expected {1}.")]
    WrongExtension(ExtensionId, ExtensionId),
}
/// Any type whose references convert into a `&'static str` (for example via
/// strum's `IntoStaticStr` derive) is a [`NamedOp`] whose name is that string.
impl<T> NamedOp for T
where
    for<'a> &'a T: Into<&'static str>,
{
    fn name(&self) -> OpName {
        let static_name: &'static str = self.into();
        static_name.into()
    }
}
pub trait MakeOpDef: NamedOp {
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
where
Self: Sized;
fn signature(&self) -> SignatureFunc;
fn extension(&self) -> ExtensionId;
fn description(&self) -> String {
self.name().to_string()
}
fn post_opdef(&self, _def: &mut OpDef) {}
fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> {
let def = extension.add_op(self.name(), self.description(), self.signature())?;
self.post_opdef(def);
Ok(())
}
fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError>
where
Self: IntoEnumIterator,
{
for op in Self::iter() {
op.add_to_extension(extension)?;
}
Ok(())
}
fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
where
Self: Sized + std::str::FromStr,
{
Self::from_extension_op(ext_op)
}
}
/// A [`MakeOpDef`] (definition) type that can be instantiated with type
/// arguments to produce a concrete operation.
pub trait HasConcrete: MakeOpDef {
    /// The concrete operation type produced by [`HasConcrete::instantiate`].
    type Concrete: MakeExtensionOp;
    /// Attempt to instantiate the concrete operation from this definition and `type_args`.
    /// NOTE(review): implementors presumably return `OpLoadError::InvalidArgs` for
    /// unsuitable arguments — this is not enforced by the trait itself.
    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError>;
}
/// A concrete [`MakeExtensionOp`] type that knows its definition type.
///
/// The definition type is a [`HasConcrete`] whose `Concrete` type is `Self`.
pub trait HasDef: MakeExtensionOp {
    /// The definition type which instantiates to `Self`.
    type Def: HasConcrete<Concrete = Self> + std::str::FromStr;

    /// Load this concrete operation from an [`ExtensionOp`].
    /// Delegates to [`MakeExtensionOp::from_extension_op`].
    fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
    where
        Self: Sized,
    {
        <Self as MakeExtensionOp>::from_extension_op(ext_op)
    }
}
/// Types which can be loaded from an [`ExtensionOp`] (an instance of an
/// [`OpDef`] together with type arguments).
pub trait MakeExtensionOp: NamedOp {
    /// Try to load one of the operations of this set from an [`ExtensionOp`].
    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
    where
        Self: Sized;

    /// Try to load this operation from an [`OpType`].
    ///
    /// Returns `None` if `op` is not an extension op, or if
    /// [`MakeExtensionOp::from_extension_op`] fails (the error is discarded).
    fn from_optype(op: &OpType) -> Option<Self>
    where
        Self: Sized,
    {
        op.as_extension_op()
            .and_then(|inner| Self::from_extension_op(inner).ok())
    }

    /// The type arguments that define this operation.
    fn type_args(&self) -> Vec<TypeArg>;

    /// Pair this operation with the ID of its extension and a registry,
    /// producing a [`RegisteredOp`] borrowing `registry`.
    fn to_registered(
        self,
        extension_id: ExtensionId,
        registry: &ExtensionRegistry,
    ) -> RegisteredOp<'_, Self>
    where
        Self: Sized,
    {
        RegisteredOp {
            op: self,
            registry,
            extension_id,
        }
    }
}
/// Every [`MakeOpDef`] is loadable from an [`ExtensionOp`] by reading back its
/// definition. Such ops carry no type arguments.
impl<T: MakeOpDef> MakeExtensionOp for T {
    #[inline]
    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
    where
        Self: Sized,
    {
        let op_def = ext_op.def();
        T::from_def(op_def)
    }

    #[inline]
    fn type_args(&self) -> Vec<TypeArg> {
        Vec::new()
    }
}
/// Parse `name` as an operation of type `T` and check that it belongs to `def_extension`.
///
/// # Errors
/// - [`OpLoadError::NotMember`] if `name` does not parse as a `T`.
/// - [`OpLoadError::WrongExtension`] (`def_extension`, expected) if the parsed op
///   reports a different extension via [`MakeOpDef::extension`].
pub fn try_from_name<T>(name: &OpNameRef, def_extension: &ExtensionId) -> Result<T, OpLoadError>
where
    T: std::str::FromStr + MakeOpDef,
{
    let parsed = T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))?;
    match parsed.extension() {
        own_ext if *def_extension == own_ext => Ok(parsed),
        own_ext => Err(OpLoadError::WrongExtension(def_extension.clone(), own_ext)),
    }
}
/// An operation paired with the ID of its extension and a borrowed
/// [`ExtensionRegistry`], so that it can be resolved into an [`ExtensionOp`].
/// Built by [`MakeExtensionOp::to_registered`] or via `From` for [`MakeRegisteredOp`] types.
#[derive(Clone, Debug)]
pub struct RegisteredOp<'r, T> {
    /// ID of the extension looked up in `registry` for this op's definition.
    extension_id: ExtensionId,
    /// Registry consulted when resolving the definition and building an [`ExtensionOp`].
    registry: &'r ExtensionRegistry,
    /// The wrapped operation.
    op: T,
}
impl<T> RegisteredOp<'_, T> {
    /// Consume the wrapper, returning the wrapped operation and discarding the
    /// extension ID and registry reference.
    pub fn to_inner(self) -> T {
        let Self { op, .. } = self;
        op
    }
}
impl<T: MakeExtensionOp> RegisteredOp<'_, T> {
    /// Resolve the wrapped op against the registry and build an [`ExtensionOp`].
    ///
    /// Returns `None` in any of these cases:
    /// - the extension is missing from the registry;
    /// - the extension has no op with this name;
    /// - `ExtensionOp::new` returns an error.
    pub fn to_extension_op(&self) -> Option<ExtensionOp> {
        let extension = self.registry.get(&self.extension_id)?;
        let op_def = extension.get_op(&self.name())?.clone();
        ExtensionOp::new(op_def, self.type_args(), self.registry).ok()
    }

    delegate! {
        to self.op {
            /// Name of the wrapped operation.
            pub fn name(&self) -> OpName;
            /// Type arguments of the wrapped operation.
            pub fn type_args(&self) -> Vec<TypeArg>;
        }
    }
}
pub trait MakeRegisteredOp: MakeExtensionOp {
fn extension_id(&self) -> ExtensionId;
fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry;
fn to_extension_op(self) -> Option<ExtensionOp>
where
Self: Sized,
{
let registered: RegisteredOp<_> = self.into();
registered.to_extension_op()
}
}
/// Wrap a [`MakeRegisteredOp`] using its own extension ID and registry.
impl<T: MakeRegisteredOp> From<T> for RegisteredOp<'_, T> {
    fn from(ext_op: T) -> Self {
        let (ext_id, reg) = (ext_op.extension_id(), ext_op.registry());
        ext_op.to_registered(ext_id, reg)
    }
}
/// Convert a [`MakeRegisteredOp`] into an [`OpType`] by resolving it against
/// its own registry.
///
/// # Panics
///
/// Panics if [`MakeRegisteredOp::to_extension_op`] returns `None`, i.e. when
/// resolution against the op's registry fails. That happens when:
/// - the extension is missing from the registry;
/// - the op name is not defined in that extension;
/// - `ExtensionOp::new` rejects the type arguments.
impl<T: MakeRegisteredOp> From<T> for OpType {
    fn from(ext_op: T) -> Self {
        ext_op
            .to_extension_op()
            .expect("op should be resolvable in the registry given by MakeRegisteredOp::registry")
            .into()
    }
}
#[cfg(test)]
mod test {
    use crate::{const_extension_ids, type_row, types::Signature};

    use super::*;
    use lazy_static::lazy_static;
    use strum_macros::{EnumIter, EnumString, IntoStaticStr};

    /// Minimal single-member op set used to exercise the traits in this module.
    #[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
    enum DummyEnum {
        Dumb,
    }

    impl MakeOpDef for DummyEnum {
        fn signature(&self) -> SignatureFunc {
            Signature::new_endo(type_row![]).into()
        }

        fn from_def(_op_def: &OpDef) -> Result<Self, OpLoadError> {
            Ok(Self::Dumb)
        }

        fn extension(&self) -> ExtensionId {
            EXT_ID.to_owned()
        }
    }

    /// `DummyEnum` is its own concrete type and accepts only empty type args.
    impl HasConcrete for DummyEnum {
        type Concrete = Self;

        fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
            if type_args.is_empty() {
                Ok(self.clone())
            } else {
                Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
            }
        }
    }

    const_extension_ids! {
        const EXT_ID: ExtensionId = "DummyExt";
        const OTHER_EXT_ID: ExtensionId = "OtherDummyExt";
    }

    lazy_static! {
        static ref EXT: Extension = {
            let mut e = Extension::new_test(EXT_ID.clone());
            DummyEnum::Dumb.add_to_extension(&mut e).unwrap();
            e
        };
        static ref DUMMY_REG: ExtensionRegistry =
            ExtensionRegistry::try_new([EXT.to_owned()]).unwrap();
    }

    impl MakeRegisteredOp for DummyEnum {
        fn extension_id(&self) -> ExtensionId {
            EXT_ID.to_owned()
        }

        fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
            &DUMMY_REG
        }
    }

    #[test]
    fn test_dummy_enum() {
        let o = DummyEnum::Dumb;

        assert_eq!(
            DummyEnum::from_def(EXT.get_op(&o.name()).unwrap()).unwrap(),
            o
        );

        assert_eq!(
            DummyEnum::from_optype(&o.clone().to_extension_op().unwrap().into()).unwrap(),
            o
        );
        let registered: RegisteredOp<_> = o.clone().into();
        assert_eq!(registered.to_inner(), o);

        assert_eq!(o.instantiate(&[]), Ok(o.clone()));
        assert_eq!(
            o.instantiate(&[TypeArg::BoundedNat { n: 1 }]),
            Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
        );
    }

    /// Cover the three outcomes of `try_from_name`: success, unknown name, wrong extension.
    #[test]
    fn test_try_from_name() {
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &EXT_ID),
            Ok(DummyEnum::Dumb)
        );
        assert_eq!(
            try_from_name::<DummyEnum>("NotAnOp", &EXT_ID),
            Err(OpLoadError::NotMember("NotAnOp".to_string()))
        );
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &OTHER_EXT_ID),
            Err(OpLoadError::WrongExtension(OTHER_EXT_ID, EXT_ID))
        );
    }
}