use std::sync::{Arc, Weak};
use strum::IntoEnumIterator;
use crate::ops::{ExtensionOp, OpName, OpNameRef};
use crate::{Extension, ops::OpType, types::TypeArg};
use super::{ExtensionBuildError, ExtensionId, OpDef, SignatureError, op_def::SignatureFunc};
use crate::ops::custom::qualify_name;
use delegate::delegate;
use thiserror::Error;
/// Error loading an operation, e.g. from an [`OpDef`] (see
/// [`MakeOpDef::from_def`]) or from its name (see `try_from_name`).
#[derive(Debug, Error, PartialEq, Clone)]
#[error("{0}")]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum OpLoadError {
    /// The name did not parse to any operation of the set; carries the name.
    /// `try_from_name` produces this when `FromStr` fails.
    #[error("Op with name {0} is not a member of this set.")]
    NotMember(String),
    /// The type arguments were rejected; wraps the underlying [`SignatureError`].
    #[error("Type args invalid: {0}.")]
    InvalidArgs(#[from] SignatureError),
    /// The definition's extension did not match the expected one.
    /// Fields: (extension the `OpDef` belongs to, expected extension).
    #[error("OpDef belongs to extension {0}, expected {1}.")]
    WrongExtension(ExtensionId, ExtensionId),
}
pub trait MakeOpDef {
fn opdef_id(&self) -> OpName;
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
where
Self: Sized;
fn extension(&self) -> ExtensionId;
fn qualified_opdef_id(&self) -> OpName {
qualify_name(&self.extension(), &self.opdef_id())
}
fn extension_ref(&self) -> Weak<Extension>;
fn init_signature(&self, extension_ref: &Weak<Extension>) -> SignatureFunc;
fn signature(&self) -> SignatureFunc {
self.init_signature(&self.extension_ref())
}
fn description(&self) -> String {
self.opdef_id().to_string()
}
fn post_opdef(&self, _def: &mut OpDef) {}
fn add_to_extension(
&self,
extension: &mut Extension,
extension_ref: &Weak<Extension>,
) -> Result<(), ExtensionBuildError> {
let def = extension.add_op(
self.opdef_id(),
self.description(),
self.init_signature(extension_ref),
extension_ref,
)?;
self.post_opdef(def);
Ok(())
}
fn load_all_ops(
extension: &mut Extension,
extension_ref: &Weak<Extension>,
) -> Result<(), ExtensionBuildError>
where
Self: IntoEnumIterator,
{
for op in Self::iter() {
op.add_to_extension(extension, extension_ref)?;
}
Ok(())
}
fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
where
Self: Sized + std::str::FromStr,
{
Self::from_extension_op(ext_op)
}
}
/// An operation definition that can be instantiated into a concrete
/// operation by supplying type arguments.
pub trait HasConcrete: MakeOpDef {
    /// The concrete operation produced by [`HasConcrete::instantiate`].
    type Concrete: MakeExtensionOp;

    /// Instantiate this definition with `type_args`.
    ///
    /// # Errors
    ///
    /// Returns an [`OpLoadError`] (e.g. [`OpLoadError::InvalidArgs`]) when the
    /// arguments are not acceptable; the exact conditions are up to the implementor.
    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError>;
}
/// A concrete operation that knows its generic definition type.
///
/// The definition instantiates back to `Self` (`HasConcrete<Concrete = Self>`).
pub trait HasDef: MakeExtensionOp {
    /// The definition this operation is an instance of.
    type Def: HasConcrete<Concrete = Self> + std::str::FromStr;

    /// Load this operation from an [`ExtensionOp`]. This is a thin wrapper
    /// around [`MakeExtensionOp::from_extension_op`].
    fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
    where
        Self: Sized,
    {
        <Self as MakeExtensionOp>::from_extension_op(ext_op)
    }
}
/// Conversion between a Rust type and [`ExtensionOp`]s, which pair an
/// operation definition with type arguments.
pub trait MakeExtensionOp {
    /// The name of the operation.
    fn op_id(&self) -> OpName;

    /// Try to recover a value of this type from an [`ExtensionOp`].
    ///
    /// # Errors
    ///
    /// Returns an [`OpLoadError`] if `ext_op` does not correspond to this type.
    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
    where
        Self: Sized;

    /// Try to recover a value of this type from an [`OpType`].
    ///
    /// Returns `None` if `op` is not an extension op, or if
    /// [`MakeExtensionOp::from_extension_op`] fails (the error is discarded).
    #[must_use]
    fn from_optype(op: &OpType) -> Option<Self>
    where
        Self: Sized,
    {
        op.as_extension_op()
            .and_then(|ext_op| Self::from_extension_op(ext_op).ok())
    }

    /// The type arguments this operation is instantiated with.
    fn type_args(&self) -> Vec<TypeArg>;

    /// Bundle this operation with the extension that defines it.
    fn to_registered(
        self,
        extension_id: ExtensionId,
        extension: Arc<Extension>,
    ) -> RegisteredOp<Self>
    where
        Self: Sized,
    {
        RegisteredOp {
            op: self,
            extension,
            extension_id,
        }
    }
}
/// Every [`MakeOpDef`] type is a [`MakeExtensionOp`] without type arguments.
/// Its op id is its definition id, and it is loaded directly from the
/// definition carried by the [`ExtensionOp`].
impl<T: MakeOpDef> MakeExtensionOp for T {
    fn op_id(&self) -> OpName {
        <T as MakeOpDef>::opdef_id(self)
    }

    #[inline]
    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
    where
        Self: Sized,
    {
        let def = ext_op.def();
        T::from_def(def)
    }

    #[inline]
    fn type_args(&self) -> Vec<TypeArg> {
        Vec::new()
    }
}
/// Parse `name` into an operation of type `T`, checking that the parsed
/// operation belongs to `def_extension`.
///
/// # Errors
///
/// * [`OpLoadError::NotMember`] if `T::from_str` rejects `name`.
/// * [`OpLoadError::WrongExtension`] if the parsed op reports an extension
///   different from `def_extension`.
pub fn try_from_name<T>(name: &OpNameRef, def_extension: &ExtensionId) -> Result<T, OpLoadError>
where
    T: std::str::FromStr + MakeOpDef,
{
    let Ok(parsed) = T::from_str(name) else {
        return Err(OpLoadError::NotMember(name.to_string()));
    };
    let expected_extension = parsed.extension();
    if *def_extension == expected_extension {
        Ok(parsed)
    } else {
        Err(OpLoadError::WrongExtension(
            def_extension.clone(),
            expected_extension,
        ))
    }
}
/// An operation bundled with the extension that defines it.
///
/// Built by [`MakeExtensionOp::to_registered`], or via `From`/`Into` for
/// [`MakeRegisteredOp`] types. Because it holds an `Arc<Extension>`, its
/// `to_extension_op` method can look up the op's definition directly.
#[derive(Clone, Debug)]
pub struct RegisteredOp<T> {
    /// The id of the extension that defines the op.
    pub extension_id: ExtensionId,
    /// The extension that defines the op; searched by name in `to_extension_op`.
    extension: Arc<Extension>,
    /// The wrapped operation.
    op: T,
}
impl<T> RegisteredOp<T> {
    /// Consume the wrapper and return the operation it holds.
    pub fn to_inner(self) -> T {
        let Self { op, .. } = self;
        op
    }
}
impl<T: MakeExtensionOp> RegisteredOp<T> {
    /// Build an [`ExtensionOp`] for the wrapped operation.
    ///
    /// The definition is looked up by [`RegisteredOp::op_id`] in the held
    /// extension and instantiated with [`RegisteredOp::type_args`].
    ///
    /// # Errors
    ///
    /// Propagates the [`SignatureError`] returned by `ExtensionOp::new`.
    ///
    /// # Panics
    ///
    /// Panics if the held extension has no operation named `op_id()`.
    pub fn to_extension_op(&self) -> Result<ExtensionOp, SignatureError> {
        // Evaluate the id once, so the panic message names exactly the id that
        // was looked up. It is not recomputed on the failure path.
        let op_id = self.op_id();
        let op_def = self.extension.get_op(&op_id).unwrap_or_else(|| {
            panic!("Extension::get_op() called with an invalid name ({op_id}).")
        });
        ExtensionOp::new(op_def.clone(), self.type_args())
    }

    delegate! {
        to self.op {
            /// The name of the wrapped operation.
            pub fn op_id(&self) -> OpName;
            /// The type arguments of the wrapped operation.
            pub fn type_args(&self) -> Vec<TypeArg>;
        }
    }
}
pub trait MakeRegisteredOp: MakeExtensionOp {
fn extension_id(&self) -> ExtensionId;
fn extension_ref(&self) -> Arc<Extension>;
fn to_extension_op(self) -> Result<ExtensionOp, SignatureError>
where
Self: Sized,
{
let registered: RegisteredOp<_> = self.into();
registered.to_extension_op()
}
}
/// Wrap a [`MakeRegisteredOp`] together with the extension it reports for itself.
impl<T: MakeRegisteredOp> From<T> for RegisteredOp<T> {
    fn from(ext_op: T) -> Self {
        // Query both before `to_registered` consumes `ext_op`.
        let (extension_id, extension) = (ext_op.extension_id(), ext_op.extension_ref());
        ext_op.to_registered(extension_id, extension)
    }
}
/// Convert a [`MakeRegisteredOp`] into an [`OpType`] via
/// [`MakeRegisteredOp::to_extension_op`].
///
/// # Panics
///
/// `From` is infallible, so this panics in two cases:
/// - building the [`ExtensionOp`] fails with a [`SignatureError`];
/// - the op is missing from its extension (see [`RegisteredOp::to_extension_op`]).
impl<T: MakeRegisteredOp> From<T> for OpType {
    fn from(ext_op: T) -> Self {
        ext_op
            .to_extension_op()
            .expect("registered operation could not be converted to an ExtensionOp")
            .into()
    }
}
#[cfg(test)]
mod test {
    use std::sync::{Arc, LazyLock};

    use crate::{
        const_extension_ids, type_row,
        types::{Signature, Term},
    };

    use super::*;
    use strum::{EnumIter, EnumString, IntoStaticStr};

    /// A one-variant op set used to exercise the traits in this module.
    #[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
    enum DummyEnum {
        Dumb,
    }

    impl MakeOpDef for DummyEnum {
        fn opdef_id(&self) -> OpName {
            <&'static str>::from(self).into()
        }

        fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
            Signature::new_endo(type_row![]).into()
        }

        fn extension_ref(&self) -> Weak<Extension> {
            Arc::downgrade(&EXT)
        }

        fn from_def(_op_def: &OpDef) -> Result<Self, OpLoadError> {
            Ok(Self::Dumb)
        }

        fn extension(&self) -> ExtensionId {
            EXT_ID.clone()
        }
    }

    impl HasConcrete for DummyEnum {
        type Concrete = Self;

        fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
            // The dummy op is not polymorphic, so it accepts no type arguments.
            if type_args.is_empty() {
                Ok(self.clone())
            } else {
                Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
            }
        }
    }

    const_extension_ids! {
        const EXT_ID: ExtensionId = "DummyExt";
        const OTHER_EXT_ID: ExtensionId = "OtherDummyExt";
    }

    static EXT: LazyLock<Arc<Extension>> = LazyLock::new(|| {
        Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| {
            DummyEnum::Dumb
                .add_to_extension(ext, extension_ref)
                .unwrap();
        })
    });

    impl MakeRegisteredOp for DummyEnum {
        fn extension_id(&self) -> ExtensionId {
            EXT_ID.clone()
        }

        fn extension_ref(&self) -> Arc<Extension> {
            EXT.clone()
        }
    }

    #[test]
    fn test_dummy_enum() {
        let o = DummyEnum::Dumb;
        assert_eq!(
            DummyEnum::from_def(EXT.get_op(&o.opdef_id()).unwrap()).unwrap(),
            o
        );
        assert_eq!(
            DummyEnum::from_optype(&o.clone().to_extension_op().unwrap().into()).unwrap(),
            o
        );
        assert_eq!(format!("{EXT_ID}.Dumb"), o.qualified_opdef_id());
        let registered: RegisteredOp<_> = o.clone().into();
        assert_eq!(registered.to_inner(), o);
        assert_eq!(o.instantiate(&[]), Ok(o.clone()));
        assert_eq!(
            o.instantiate(&[Term::from(1u64)]),
            Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs))
        );
    }

    #[test]
    fn test_try_from_name() {
        // A known name in the right extension parses successfully.
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &EXT_ID),
            Ok(DummyEnum::Dumb)
        );
        // An unknown name is reported as not a member of the set.
        assert_eq!(
            try_from_name::<DummyEnum>("NotAnOp", &EXT_ID),
            Err(OpLoadError::NotMember("NotAnOp".to_string()))
        );
        // A known name paired with the wrong extension is rejected.
        assert_eq!(
            try_from_name::<DummyEnum>("Dumb", &OTHER_EXT_ID),
            Err(OpLoadError::WrongExtension(OTHER_EXT_ID, EXT_ID))
        );
    }
}