use std::sync::Arc;
use thiserror::Error;
#[cfg(test)]
use {
crate::proptest::{any_nonempty_smolstr, any_nonempty_string},
::proptest_derive::Arbitrary,
};
use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError};
use crate::hugr::internal::HugrMutInternals;
use crate::hugr::{HugrView, NodeType};
use crate::types::EdgeKind;
use crate::types::{type_param::TypeArg, FunctionType};
use crate::{ops, Hugr, IncomingPort, Node};
use super::dataflow::DataflowOpTrait;
use super::tag::OpTag;
use super::{NamedOp, OpName, OpNameRef, OpTrait, OpType};
/// An operation provided by an extension.
///
/// The operation is held either as a resolved [`ExtensionOp`] or as an
/// [`OpaqueOp`]. Serialization and deserialization always go through
/// [`OpaqueOp`] (see the `serde(into, from)` attribute below), so a
/// deserialized value is always the `Opaque` variant.
#[derive(Clone, Debug, Eq, serde::Serialize, serde::Deserialize)]
#[serde(into = "OpaqueOp", from = "OpaqueOp")]
pub enum CustomOp {
/// A resolved operation, boxed.
Extension(Box<ExtensionOp>),
/// An unresolved operation, boxed.
Opaque(Box<OpaqueOp>),
}
impl PartialEq for CustomOp {
    /// Compares two custom ops.
    ///
    /// Ops of the same variant are compared directly. When the variants
    /// differ, the resolved op is converted with [`ExtensionOp::make_opaque`]
    /// and the two opaque forms are compared.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Opaque(lhs), Self::Opaque(rhs)) => lhs == rhs,
            (Self::Extension(lhs), Self::Extension(rhs)) => lhs == rhs,
            (Self::Extension(resolved), Self::Opaque(opaque)) => {
                resolved.make_opaque() == **opaque
            }
            (Self::Opaque(opaque), Self::Extension(resolved)) => {
                **opaque == resolved.make_opaque()
            }
        }
    }
}
impl CustomOp {
pub fn new_extension(op: ExtensionOp) -> Self {
Self::Extension(Box::new(op))
}
pub fn new_opaque(op: OpaqueOp) -> Self {
Self::Opaque(Box::new(op))
}
pub fn args(&self) -> &[TypeArg] {
match self {
Self::Opaque(op) => op.args(),
Self::Extension(op) => op.args(),
}
}
pub fn extension(&self) -> &ExtensionId {
match self {
Self::Opaque(op) => op.extension(),
Self::Extension(op) => op.def.extension(),
}
}
pub fn as_extension_op(&self) -> Option<&ExtensionOp> {
match self {
Self::Extension(e) => Some(e),
Self::Opaque(_) => None,
}
}
pub fn into_opaque(self) -> OpaqueOp {
match self {
Self::Opaque(op) => *op,
Self::Extension(op) => (*op).into(),
}
}
pub fn is_extension_op(&self) -> bool {
matches!(self, Self::Extension(_))
}
pub fn is_opaque(&self) -> bool {
matches!(self, Self::Opaque(_))
}
}
impl NamedOp for CustomOp {
    /// Returns the operation name qualified by its extension id, in the form
    /// `<extension>.<op_name>` (see `qualify_name`).
    fn name(&self) -> OpName {
        match self {
            Self::Opaque(op) => qualify_name(&op.extension, &op.op_name),
            Self::Extension(ext) => qualify_name(ext.def.extension(), ext.def.name()),
        }
    }
}
impl DataflowOpTrait for CustomOp {
const TAG: OpTag = OpTag::Leaf;
fn description(&self) -> &str {
match self {
Self::Opaque(op) => DataflowOpTrait::description(op.as_ref()),
Self::Extension(ext_op) => DataflowOpTrait::description(ext_op.as_ref()),
}
}
fn signature(&self) -> FunctionType {
match self {
Self::Opaque(op) => op.signature.clone(),
Self::Extension(ext_op) => ext_op.signature(),
}
}
fn other_input(&self) -> Option<EdgeKind> {
Some(EdgeKind::StateOrder)
}
fn other_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::StateOrder)
}
}
impl From<OpaqueOp> for CustomOp {
fn from(op: OpaqueOp) -> Self {
Self::new_opaque(op)
}
}
impl From<CustomOp> for OpaqueOp {
fn from(value: CustomOp) -> Self {
value.into_opaque()
}
}
impl From<ExtensionOp> for CustomOp {
fn from(op: ExtensionOp) -> Self {
Self::new_extension(op)
}
}
/// An operation bound to its [`OpDef`], together with the type arguments it
/// was instantiated with and the signature computed from them.
#[derive(Clone, Debug)]
pub struct ExtensionOp {
/// Shared handle to the operation definition.
def: Arc<OpDef>,
/// Type arguments given at construction.
args: Vec<TypeArg>,
/// Signature returned by `def.compute_signature` in [`ExtensionOp::new`].
signature: FunctionType, }
impl ExtensionOp {
    /// Creates a new `ExtensionOp` from a definition and type arguments.
    ///
    /// # Errors
    ///
    /// Returns the [`SignatureError`] produced by `def.compute_signature` if
    /// the signature cannot be computed for `args`.
    pub fn new(
        def: Arc<OpDef>,
        args: impl Into<Vec<TypeArg>>,
        exts: &ExtensionRegistry,
    ) -> Result<Self, SignatureError> {
        let args: Vec<TypeArg> = args.into();
        let computed = def.compute_signature(&args, exts);
        computed.map(|signature| Self {
            def,
            args,
            signature,
        })
    }

    /// Returns the type arguments of this operation.
    pub fn args(&self) -> &[TypeArg] {
        self.args.as_slice()
    }

    /// Returns the definition this operation refers to.
    pub fn def(&self) -> &OpDef {
        &self.def
    }

    /// Delegates constant folding to the definition, passing this op's type
    /// arguments and the supplied constant inputs.
    pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
        self.def.constant_fold(&self.args, consts)
    }

    /// Builds an [`OpaqueOp`] from this operation without consuming it.
    ///
    /// The extension id, name and description come from the definition; the
    /// type arguments and signature are cloned from `self`.
    pub fn make_opaque(&self) -> OpaqueOp {
        let def = self.def();
        OpaqueOp {
            extension: def.extension().clone(),
            op_name: def.name().clone(),
            description: def.description().into(),
            args: self.args.clone(),
            signature: self.signature.clone(),
        }
    }
}
impl From<ExtensionOp> for OpaqueOp {
    /// Converts a resolved op into its opaque form.
    ///
    /// The extension id, name and description are read from the definition;
    /// the type arguments and signature are moved out of `op`.
    fn from(op: ExtensionOp) -> Self {
        OpaqueOp {
            extension: op.def.extension().clone(),
            op_name: op.def.name().clone(),
            description: op.def.description().into(),
            args: op.args,
            signature: op.signature,
        }
    }
}
impl From<ExtensionOp> for OpType {
fn from(value: ExtensionOp) -> Self {
OpType::CustomOp(value.into())
}
}
impl PartialEq for ExtensionOp {
fn eq(&self, other: &Self) -> bool {
Arc::<OpDef>::ptr_eq(&self.def, &other.def) && self.args == other.args
}
}
// Marker impl: equality for `ExtensionOp` is the hand-written `PartialEq`
// (definition pointer identity plus equal type arguments).
impl Eq for ExtensionOp {}
impl DataflowOpTrait for ExtensionOp {
const TAG: OpTag = OpTag::Leaf;
fn description(&self) -> &str {
self.def().description()
}
fn signature(&self) -> FunctionType {
self.signature.clone()
}
}
/// A self-contained description of an extension operation: extension id,
/// operation name, description, type arguments and signature.
///
/// This is the serialized form of [`CustomOp`]. Under `cfg(test)` it also
/// derives `Arbitrary` for property testing.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct OpaqueOp {
/// Id of the extension the operation belongs to.
extension: ExtensionId,
/// Unqualified operation name (generated as non-empty in property tests).
#[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
op_name: OpName,
// The line below declares two fields: `description` (human-readable text,
// generated as non-empty in property tests) and `args` (the type arguments).
#[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))]
description: String, args: Vec<TypeArg>,
/// Signature stored with the operation.
signature: FunctionType,
}
fn qualify_name(res_id: &ExtensionId, op_name: &OpNameRef) -> OpName {
format!("{}.{}", res_id, op_name).into()
}
impl OpaqueOp {
    /// Creates a new `OpaqueOp` from its parts.
    ///
    /// `op_name` and `args` accept anything convertible into [`OpName`] and
    /// `Vec<TypeArg>` respectively; no validation is performed here.
    pub fn new(
        extension: ExtensionId,
        op_name: impl Into<OpName>,
        description: String,
        args: impl Into<Vec<TypeArg>>,
        signature: FunctionType,
    ) -> Self {
        let op_name: OpName = op_name.into();
        let args: Vec<TypeArg> = args.into();
        Self {
            extension,
            op_name,
            description,
            args,
            signature,
        }
    }
}
impl OpaqueOp {
    /// Returns the unqualified operation name.
    pub fn name(&self) -> &OpName {
        let Self { op_name, .. } = self;
        op_name
    }

    /// Returns the type arguments of the operation.
    pub fn args(&self) -> &[TypeArg] {
        self.args.as_slice()
    }

    /// Returns the id of the extension the operation belongs to.
    pub fn extension(&self) -> &ExtensionId {
        let Self { extension, .. } = self;
        extension
    }
}
impl From<OpaqueOp> for OpType {
fn from(value: OpaqueOp) -> Self {
OpType::CustomOp(value.into())
}
}
impl DataflowOpTrait for OpaqueOp {
const TAG: OpTag = OpTag::Leaf;
fn description(&self) -> &str {
&self.description
}
fn signature(&self) -> FunctionType {
self.signature.clone()
}
}
/// Replaces opaque custom ops in `h` with resolved [`ExtensionOp`]s wherever
/// the registry knows their extension.
///
/// All nodes are scanned first and replacements are applied only after the
/// scan succeeds, so `h` is left unmodified when an error is returned.
/// Opaque ops whose extension is not in the registry are left as they are.
///
/// # Errors
///
/// Returns the first [`CustomOpError`] reported by [`resolve_opaque_op`].
///
/// # Panics
///
/// Panics if `replace_op` fails for a node being replaced.
pub fn resolve_extension_ops(
    h: &mut Hugr,
    extension_registry: &ExtensionRegistry,
) -> Result<(), CustomOpError> {
    // Phase 1: read-only scan, stopping at the first resolution error.
    let resolved_ops = h
        .nodes()
        .filter_map(|node| {
            let OpType::CustomOp(CustomOp::Opaque(opaque)) = h.get_optype(node) else {
                return None;
            };
            resolve_opaque_op(node, opaque, extension_registry)
                .map(|maybe_op| maybe_op.map(|ext_op| (node, ext_op)))
                .transpose()
        })
        .collect::<Result<Vec<_>, _>>()?;

    // Phase 2: swap in the resolved ops, keeping each node's input extensions.
    for (node, ext_op) in resolved_ops {
        let input_extensions = h.get_nodetype(node).input_extensions().cloned();
        let node_type = NodeType::new(ext_op, input_extensions);
        debug_assert_eq!(h.get_optype(node).tag(), OpTag::Leaf);
        debug_assert_eq!(node_type.tag(), OpTag::Leaf);
        h.replace_op(node, node_type).unwrap();
    }
    Ok(())
}
pub fn resolve_opaque_op(
_n: Node,
opaque: &OpaqueOp,
extension_registry: &ExtensionRegistry,
) -> Result<Option<ExtensionOp>, CustomOpError> {
if let Some(r) = extension_registry.get(&opaque.extension) {
let Some(def) = r.get_op(&opaque.op_name) else {
return Err(CustomOpError::OpNotFoundInExtension(
opaque.op_name.clone(),
r.name().clone(),
));
};
let ext_op =
ExtensionOp::new(def.clone(), opaque.args.clone(), extension_registry).unwrap();
if opaque.signature != ext_op.signature {
return Err(CustomOpError::SignatureMismatch {
extension: opaque.extension.clone(),
op: def.name().clone(),
computed: ext_op.signature.clone(),
stored: opaque.signature.clone(),
});
};
Ok(Some(ext_op))
} else {
Ok(None)
}
}
#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
pub enum CustomOpError {
#[error("Operation {0} not found in Extension {1}")]
OpNotFoundInExtension(OpName, ExtensionId),
#[error("Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}")]
#[allow(missing_docs)]
SignatureMismatch {
extension: ExtensionId,
op: OpName,
stored: FunctionType,
computed: FunctionType,
},
}
// Unit tests and property-test support for custom operations.
#[cfg(test)]
mod test {
use crate::extension::prelude::{QB_T, USIZE_T};
use super::*;
// Builds a `CustomOp` from an `OpaqueOp` and checks that the qualified name,
// description, type arguments, signature and variant predicates all reflect
// the values passed to `OpaqueOp::new`.
#[test]
fn new_opaque_op() {
let sig = FunctionType::new_endo(vec![QB_T]);
let op: CustomOp = OpaqueOp::new(
"res".try_into().unwrap(),
"op",
"desc".into(),
vec![TypeArg::Type { ty: USIZE_T }],
sig.clone(),
)
.into();
// The name is qualified as `<extension>.<op_name>`.
assert_eq!(op.name(), "res.op");
assert_eq!(DataflowOpTrait::description(&op), "desc");
assert_eq!(op.args(), &[TypeArg::Type { ty: USIZE_T }]);
assert_eq!(op.signature(), sig);
assert!(op.is_opaque());
assert!(!op.is_extension_op());
}
// Property-test support: `Arbitrary` for `CustomOp`.
mod proptest {
use ::proptest::prelude::*;
// Generates only the opaque form: every value is produced by converting an
// arbitrary `OpaqueOp` into a `CustomOp`.
impl Arbitrary for super::super::CustomOp {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
any::<super::super::OpaqueOp>().prop_map_into().boxed()
}
}
}
}