use itertools::Itertools;
use std::borrow::Cow;
use std::sync::Arc;
use thiserror::Error;
#[cfg(test)]
use {
crate::extension::test::SimpleOpDef,
crate::proptest::{any_nonempty_smolstr, any_nonempty_string},
::proptest::prelude::*,
::proptest_derive::Arbitrary,
};
use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
use crate::types::{type_param::TypeArg, Signature};
use crate::{ops, IncomingPort, Node};
use super::dataflow::DataflowOpTrait;
use super::tag::OpTag;
use super::{NamedOp, OpName, OpNameRef};
/// An operation instantiating an [`OpDef`] with concrete type arguments,
/// together with its signature.
///
/// Serialization goes through [`OpaqueOp`] (see `#[serde(into = "OpaqueOp")]`),
/// so the link to the concrete definition is not written out.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(into = "OpaqueOp")]
#[cfg_attr(test, derive(Arbitrary))]
pub struct ExtensionOp {
    /// The definition this operation instantiates.
    #[cfg_attr(
        test,
        proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
    )]
    def: Arc<OpDef>,
    /// Type arguments applied to `def`.
    args: Vec<TypeArg>,
    /// Cached signature; how it is obtained depends on the constructor used.
    signature: Signature,
}
impl ExtensionOp {
    /// Creates a new [`ExtensionOp`] instantiating `def` with `args`.
    ///
    /// The signature is computed eagerly from `def` and `args`.
    ///
    /// # Errors
    ///
    /// Returns any [`SignatureError`] produced while computing the signature.
    pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
        let args: Vec<TypeArg> = args.into();
        let signature = def.compute_signature(&args)?;
        Ok(Self {
            def,
            args,
            signature,
        })
    }

    /// Creates a new [`ExtensionOp`], trusting the signature cached in `opaque`
    /// when `def` cannot compute one itself.
    ///
    /// If computing the signature fails with `SignatureError::MissingComputeFunc`,
    /// the signature stored in `opaque` is used instead.
    ///
    /// # Errors
    ///
    /// Returns any other [`SignatureError`] produced while computing the signature.
    pub(crate) fn new_with_cached(
        def: Arc<OpDef>,
        args: impl IntoIterator<Item = TypeArg>,
        opaque: &OpaqueOp,
    ) -> Result<Self, SignatureError> {
        let args: Vec<TypeArg> = args.into_iter().collect();
        let signature = match def.compute_signature(&args) {
            Ok(sig) => sig,
            // The definition has no way to compute the signature: fall back on
            // the one recorded in the opaque op.
            Err(SignatureError::MissingComputeFunc) => opaque.signature().into_owned(),
            Err(e) => return Err(e),
        };
        Ok(Self {
            def,
            args,
            signature,
        })
    }

    /// The type arguments this operation was instantiated with.
    pub fn args(&self) -> &[TypeArg] {
        &self.args
    }

    /// The definition this operation instantiates.
    pub fn def(&self) -> &OpDef {
        self.def.as_ref()
    }

    /// Attempts to constant-fold this operation on the given constant inputs.
    ///
    /// Delegates to the definition's `constant_fold`, passing this op's type arguments.
    pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
        self.def().constant_fold(self.args(), consts)
    }

    /// Returns an [`OpaqueOp`] carrying this operation's extension id, name,
    /// description, type arguments and signature.
    pub fn make_opaque(&self) -> OpaqueOp {
        // Delegate to `From<ExtensionOp> for OpaqueOp` so the field mapping is
        // defined in exactly one place. Cloning `self` costs only an extra `Arc`
        // refcount bump on top of the `args`/`signature` clones needed anyway.
        self.clone().into()
    }

    /// Mutable access to the cached signature.
    ///
    /// Nothing in this type re-checks the signature against the definition
    /// after it is modified through this reference.
    pub fn signature_mut(&mut self) -> &mut Signature {
        &mut self.signature
    }

    /// Mutable access to the type arguments.
    ///
    /// Returned as a slice, so the number of arguments cannot change. The cached
    /// signature is not recomputed when arguments are modified.
    pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
        self.args.as_mut_slice()
    }
}
impl From<ExtensionOp> for OpaqueOp {
fn from(op: ExtensionOp) -> Self {
let ExtensionOp {
def,
args,
signature,
} = op;
OpaqueOp {
extension: def.extension_id().clone(),
name: def.name().clone(),
description: def.description().into(),
args,
signature,
}
}
}
impl PartialEq for ExtensionOp {
    /// Two ops are equal when they point at the *same* `OpDef` allocation
    /// (pointer identity, not structural equality) and have equal type arguments.
    ///
    /// The cached signature is not compared.
    fn eq(&self, other: &Self) -> bool {
        let same_definition = Arc::ptr_eq(&self.def, &other.def);
        same_definition && self.args == other.args
    }
}

impl Eq for ExtensionOp {}
impl NamedOp for ExtensionOp {
    /// The operation name qualified by its extension id, both taken from the definition.
    fn name(&self) -> OpName {
        let definition = self.def();
        qualify_name(definition.extension_id(), definition.name())
    }
}
impl DataflowOpTrait for ExtensionOp {
const TAG: OpTag = OpTag::Leaf;
fn description(&self) -> &str {
self.def().description()
}
fn signature(&self) -> Cow<'_, Signature> {
Cow::Borrowed(&self.signature)
}
fn substitute(&self, subst: &crate::types::Substitution) -> Self {
let args = self
.args
.iter()
.map(|ta| ta.substitute(subst))
.collect::<Vec<_>>();
let signature = self.signature.substitute(subst);
Self {
def: self.def.clone(),
args,
signature,
}
}
}
/// An extension operation that stores its extension id, name, description,
/// type arguments and signature by value, rather than referencing an [`OpDef`].
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct OpaqueOp {
    /// Identifier of the extension the operation belongs to.
    extension: ExtensionId,
    /// Operation name, not qualified by the extension id.
    #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
    name: OpName,
    /// Human-readable description of the operation.
    #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))]
    description: String,
    /// Type arguments the operation is instantiated with.
    args: Vec<TypeArg>,
    /// Stored signature of the operation.
    signature: Signature,
}
/// Builds the fully-qualified operation name `"<extension id>.<op name>"`.
fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
    let qualified: String = format!("{res_id}.{name}");
    qualified.into()
}
impl OpaqueOp {
    /// Creates a new opaque operation.
    ///
    /// Before the signature is stored, `extension` is added to its extension delta.
    pub fn new(
        extension: ExtensionId,
        name: impl Into<OpName>,
        description: String,
        args: impl Into<Vec<TypeArg>>,
        signature: Signature,
    ) -> Self {
        let delta_extension = extension.clone();
        Self {
            signature: signature.with_extension_delta(delta_extension),
            args: args.into(),
            description,
            name: name.into(),
            extension,
        }
    }

    /// Mutable access to the stored signature.
    pub fn signature_mut(&mut self) -> &mut Signature {
        &mut self.signature
    }
}
impl NamedOp for OpaqueOp {
    /// The stored operation name qualified by the stored extension id.
    fn name(&self) -> OpName {
        let Self {
            extension, name, ..
        } = self;
        qualify_name(extension, name)
    }
}
impl OpaqueOp {
    /// The operation name without the extension prefix.
    pub fn op_name(&self) -> &OpName {
        &self.name
    }

    /// The type arguments this operation is instantiated with.
    pub fn args(&self) -> &[TypeArg] {
        &self.args
    }

    /// Identifier of the extension this operation belongs to.
    pub fn extension(&self) -> &ExtensionId {
        &self.extension
    }

    /// Mutable access to the type arguments.
    ///
    /// Returned as a slice, so the number of arguments cannot change.
    pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
        &mut self.args
    }
}
impl DataflowOpTrait for OpaqueOp {
    const TAG: OpTag = OpTag::Leaf;

    /// The description stored in this opaque op.
    fn description(&self) -> &str {
        &self.description
    }

    /// The stored signature, borrowed without copying.
    fn signature(&self) -> Cow<'_, Signature> {
        Cow::Borrowed(&self.signature)
    }

    /// Applies `subst` to every type argument and to the stored signature.
    /// The extension id, name and description are copied unchanged.
    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
        // Build each field explicitly instead of using `..self.clone()`, which
        // would deep-clone `args` and `signature` only to discard them. Listing
        // every field also means a newly added field must be handled here.
        Self {
            extension: self.extension.clone(),
            name: self.name.clone(),
            description: self.description.clone(),
            args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
            signature: self.signature.substitute(subst),
        }
    }
}
/// Errors raised when resolving an [`OpaqueOp`] to a concrete operation.
#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
pub enum OpaqueOpError {
    /// The named operation does not exist in the referenced extension.
    #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
            available_ops.iter().join(", ")
    )]
    OpNotFoundInExtension {
        /// The node holding the operation.
        node: Node,
        /// The operation name that was looked up.
        op: OpName,
        /// The extension that was searched.
        extension: ExtensionId,
        /// Names of the operations the extension does provide; these are
        /// listed in the error message.
        available_ops: Vec<OpName>,
    },
    /// The signature computed from the resolved definition differs from the
    /// signature stored on the operation.
    #[error("Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}")]
    #[allow(missing_docs)]
    SignatureMismatch {
        node: Node,
        extension: ExtensionId,
        op: OpName,
        stored: Signature,
        computed: Signature,
    },
    /// Computing or checking the operation's signature failed. The underlying
    /// [`SignatureError`] is exposed as the error source.
    #[error("Error in signature of operation '{name}' in {node}: {cause}")]
    #[allow(missing_docs)]
    SignatureError {
        node: Node,
        name: OpName,
        #[source]
        cause: SignatureError,
    },
    /// An opaque operation was found where a resolved one was expected.
    /// Fields are: the node, the operation name, and its extension id.
    #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
    UnresolvedOp(Node, OpName, ExtensionId),
    /// Wraps an error reported by the extension registry.
    #[error("Error updating extension registry: {0}")]
    ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
}
#[cfg(test)]
mod test {
    use ops::OpType;

    use crate::extension::resolution::resolve_op_extensions;
    use crate::extension::ExtensionRegistry;
    use crate::std_extensions::arithmetic::conversions::{self};
    use crate::std_extensions::STD_REG;
    use crate::{
        extension::{
            prelude::{bool_t, qb_t, usize_t},
            SignatureFunc,
        },
        std_extensions::arithmetic::int_types::INT_TYPES,
        types::FuncValueType,
        Extension,
    };

    use super::*;

    /// Returns the definition backing `res`; panics if `res` is not an extension op.
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    /// `OpaqueOp::new` stores its inputs and adds the extension to the
    /// signature's extension delta.
    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            "desc".into(),
            vec![TypeArg::Type { ty: usize_t() }],
            sig.clone(),
        );
        assert_eq!(op.name(), "res.op");
        assert_eq!(DataflowOpTrait::description(&op), "desc");
        assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]);
        assert_eq!(
            op.signature().as_ref(),
            &sig.with_extension_delta(op.extension().clone())
        );
    }

    /// An opaque `itobool` op resolves against the standard registry.
    #[test]
    fn resolve_opaque_op() {
        let registry = &STD_REG;
        let i0 = &INT_TYPES[0];
        let opaque = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            "description".into(),
            vec![],
            Signature::new(i0.clone(), bool_t()),
        );
        let mut resolved = opaque.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved,
            registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved).name(), "itobool");
    }

    /// Ops whose definitions lack a validation function or a compute function
    /// still resolve to those definitions.
    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());
        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                "".to_string(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();
            ext.add_op(
                comp_name.into(),
                "".to_string(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();
        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();
        let opaque_val = OpaqueOp::new(
            ext_id.clone(),
            val_name,
            "".into(),
            vec![],
            endo_sig.clone(),
        );
        let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig);
        let mut resolved_val = opaque_val.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved_val,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_val).name(), val_name);
        let mut resolved_comp = opaque_comp.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(2)),
            &mut resolved_comp,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name);
    }
}