hugr_core/ops/
custom.rs

1//! Extensible operations.
2
3use itertools::Itertools;
4use std::borrow::Cow;
5use std::sync::Arc;
6use thiserror::Error;
7#[cfg(test)]
8use {
9    crate::extension::test::SimpleOpDef, crate::proptest::any_nonempty_smolstr,
10    ::proptest::prelude::*, ::proptest_derive::Arbitrary,
11};
12
13use crate::core::HugrNode;
14use crate::extension::simple_op::MakeExtensionOp;
15use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
16use crate::types::{Signature, type_param::TypeArg};
17use crate::{IncomingPort, ops};
18
19use super::dataflow::DataflowOpTrait;
20use super::tag::OpTag;
21use super::{NamedOp, OpName, OpNameRef};
22
23/// An operation defined by an [`OpDef`] from a loaded [Extension].
24///
25/// Extension ops are not serializable. They must be downgraded into an [`OpaqueOp`] instead.
26/// See [`ExtensionOp::make_opaque`].
27///
28/// [Extension]: crate::Extension
29#[derive(Clone, Debug, serde::Serialize)]
30#[serde(into = "OpaqueOp")]
31#[cfg_attr(test, derive(Arbitrary))]
32pub struct ExtensionOp {
33    #[cfg_attr(
34        test,
35        proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
36    )]
37    def: Arc<OpDef>,
38    args: Vec<TypeArg>,
39    signature: Signature, // Cache
40}
41
42impl ExtensionOp {
43    /// Create a new `ExtensionOp` given the type arguments and specified input extensions
44    pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
45        let args: Vec<TypeArg> = args.into();
46        let signature = def.compute_signature(&args)?;
47        Ok(Self {
48            def,
49            args,
50            signature,
51        })
52    }
53
54    /// If `OpDef` is missing binary computation, trust the cached signature.
55    pub(crate) fn new_with_cached(
56        def: Arc<OpDef>,
57        args: impl IntoIterator<Item = TypeArg>,
58        opaque: &OpaqueOp,
59    ) -> Result<Self, SignatureError> {
60        let args: Vec<TypeArg> = args.into_iter().collect();
61        // TODO skip computation depending on config
62        // see https://github.com/CQCL/hugr/issues/1363
63        let signature = match def.compute_signature(&args) {
64            Ok(sig) => sig,
65            Err(SignatureError::MissingComputeFunc) => {
66                // TODO raise warning: https://github.com/CQCL/hugr/issues/1432
67                opaque.signature().into_owned()
68            }
69            Err(e) => return Err(e),
70        };
71        Ok(Self {
72            def,
73            args,
74            signature,
75        })
76    }
77
78    /// Return the argument values for this operation.
79    #[must_use]
80    pub fn args(&self) -> &[TypeArg] {
81        &self.args
82    }
83
84    /// Returns a reference to the [`OpDef`] of this [`ExtensionOp`].
85    #[must_use]
86    pub fn def(&self) -> &OpDef {
87        self.def.as_ref()
88    }
89
90    /// Gets an Arc to the [`OpDef`] of this instance, i.e. usable to create
91    /// new instances.
92    #[must_use]
93    pub fn def_arc(&self) -> &Arc<OpDef> {
94        &self.def
95    }
96
97    /// Attempt to evaluate this operation. See [`OpDef::constant_fold`].
98    #[must_use]
99    pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
100        self.def().constant_fold(self.args(), consts)
101    }
102
103    /// Creates a new [`OpaqueOp`] as a downgraded version of this
104    /// [`ExtensionOp`].
105    ///
106    /// Regenerating the [`ExtensionOp`] back from the [`OpaqueOp`] requires a
107    /// registry with the appropriate extension.
108    ///
109    /// For a non-cloning version of this operation, use [`OpaqueOp::from`].
110    #[must_use]
111    pub fn make_opaque(&self) -> OpaqueOp {
112        OpaqueOp {
113            extension: self.def.extension_id().clone(),
114            name: self.def.name().clone(),
115            args: self.args.clone(),
116            signature: self.signature.clone(),
117        }
118    }
119
120    /// Returns a mutable reference to the cached signature of the operation.
121    pub fn signature_mut(&mut self) -> &mut Signature {
122        &mut self.signature
123    }
124
125    /// Returns a mutable reference to the type arguments of the operation.
126    pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
127        self.args.as_mut_slice()
128    }
129
130    /// Cast the operation to an specific extension op.
131    ///
132    /// Returns `None` if the operation is not of the requested type.
133    #[must_use]
134    pub fn cast<T: MakeExtensionOp>(&self) -> Option<T> {
135        T::from_extension_op(self).ok()
136    }
137
138    /// Returns the extension id of the operation.
139    #[must_use]
140    pub fn extension_id(&self) -> &ExtensionId {
141        self.def.extension_id()
142    }
143
144    /// Returns the unqualified id of the operation. e.g. 'iadd'
145    ///
146    #[must_use]
147    pub fn unqualified_id(&self) -> &OpNameRef {
148        self.def.name()
149    }
150
151    /// Returns the qualified id of the operation. e.g. 'arithmetic.iadd'
152    #[must_use]
153    pub fn qualified_id(&self) -> OpName {
154        qualify_name(self.extension_id(), self.unqualified_id())
155    }
156}
157
158impl From<ExtensionOp> for OpaqueOp {
159    fn from(op: ExtensionOp) -> Self {
160        let ExtensionOp {
161            def,
162            args,
163            signature,
164        } = op;
165        OpaqueOp {
166            extension: def.extension_id().clone(),
167            name: def.name().clone(),
168            args,
169            signature,
170        }
171    }
172}
173
174impl PartialEq for ExtensionOp {
175    fn eq(&self, other: &Self) -> bool {
176        if Arc::<OpDef>::ptr_eq(&self.def, &other.def) {
177            // If the OpDef is exactly the same, we can skip some checks.
178            self.args() == other.args()
179        } else {
180            self.args() == other.args()
181                && self.signature() == other.signature()
182                && self.def.name() == other.def.name()
183                && self.def.extension_id() == other.def.extension_id()
184        }
185    }
186}
187
188impl Eq for ExtensionOp {}
189
190impl NamedOp for ExtensionOp {
191    /// The name of the operation.
192    fn name(&self) -> OpName {
193        self.qualified_id()
194    }
195}
196
197impl DataflowOpTrait for ExtensionOp {
198    const TAG: OpTag = OpTag::Leaf;
199
200    fn description(&self) -> &str {
201        self.def().description()
202    }
203
204    fn signature(&self) -> Cow<'_, Signature> {
205        Cow::Borrowed(&self.signature)
206    }
207
208    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
209        let args = self
210            .args
211            .iter()
212            .map(|ta| ta.substitute(subst))
213            .collect::<Vec<_>>();
214        let signature = self.signature.substitute(subst);
215        Self {
216            def: self.def.clone(),
217            args,
218            signature,
219        }
220    }
221}
222
223/// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`].
224///
225/// [`ExtensionOp`]s are serialised as `OpaqueOp`s.
226///
227/// The signature of a [`ExtensionOp`] always includes that op's extension. We do not
228/// require that the `signature` field of [`OpaqueOp`] contains `extension`,
229/// instead we are careful to add it whenever we look at the `signature` of an
230/// `OpaqueOp`. This is a small efficiency in serialisation and allows us to
231/// be more liberal in deserialisation.
232#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
233#[cfg_attr(test, derive(Arbitrary))]
234pub struct OpaqueOp {
235    extension: ExtensionId,
236    #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
237    name: OpName,
238    args: Vec<TypeArg>,
239    // note that the `signature` field might not include `extension`. Thus this must
240    // remain private, and should be accessed through
241    // `DataflowOpTrait::signature`.
242    signature: Signature,
243}
244
245fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
246    format!("{res_id}.{name}").into()
247}
248
249impl OpaqueOp {
250    /// Creates a new `OpaqueOp` from all the fields we'd expect to serialize.
251    pub fn new(
252        extension: ExtensionId,
253        name: impl Into<OpName>,
254        args: impl Into<Vec<TypeArg>>,
255        signature: Signature,
256    ) -> Self {
257        Self {
258            extension,
259            name: name.into(),
260            args: args.into(),
261            signature,
262        }
263    }
264
265    /// Returns a mutable reference to the signature of the operation.
266    pub fn signature_mut(&mut self) -> &mut Signature {
267        &mut self.signature
268    }
269}
270
271impl NamedOp for OpaqueOp {
272    fn name(&self) -> OpName {
273        format!("OpaqueOp:{}", self.qualified_id()).into()
274    }
275}
276
277impl OpaqueOp {
278    /// Unique name of the operation.
279    #[must_use]
280    pub fn unqualified_id(&self) -> &OpName {
281        &self.name
282    }
283
284    /// Unique name of the operation.
285    #[must_use]
286    pub fn qualified_id(&self) -> OpName {
287        qualify_name(self.extension(), self.unqualified_id())
288    }
289
290    /// Type arguments.
291    #[must_use]
292    pub fn args(&self) -> &[TypeArg] {
293        &self.args
294    }
295
296    /// Parent extension.
297    #[must_use]
298    pub fn extension(&self) -> &ExtensionId {
299        &self.extension
300    }
301
302    /// Returns a mutable reference to the type arguments of the operation.
303    pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
304        self.args.as_mut_slice()
305    }
306}
307
308impl DataflowOpTrait for OpaqueOp {
309    const TAG: OpTag = OpTag::Leaf;
310
311    fn description(&self) -> &str {
312        "Opaque operation"
313    }
314
315    fn signature(&self) -> Cow<'_, Signature> {
316        Cow::Borrowed(&self.signature)
317    }
318
319    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
320        Self {
321            args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
322            signature: self.signature.substitute(subst),
323            ..self.clone()
324        }
325    }
326}
327
328/// Errors that arise after loading a Hugr containing opaque ops (serialized just as their names)
329/// when trying to resolve the serialized names against a registry of known Extensions.
330#[derive(Clone, Debug, Error, PartialEq)]
331#[non_exhaustive]
332pub enum OpaqueOpError<N: HugrNode> {
333    /// The Extension was found but did not contain the expected `OpDef`
334    #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
335            available_ops.iter().join(", ")
336    )]
337    OpNotFoundInExtension {
338        /// The node where the error occurred.
339        node: N,
340        /// The missing operation.
341        op: OpName,
342        /// The extension where the operation was expected.
343        extension: ExtensionId,
344        /// The available operations in the extension.
345        available_ops: Vec<OpName>,
346    },
347    /// Extension and `OpDef` found, but computed signature did not match stored
348    #[error(
349        "Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}"
350    )]
351    #[allow(missing_docs)]
352    SignatureMismatch {
353        node: N,
354        extension: ExtensionId,
355        op: OpName,
356        stored: Signature,
357        computed: Signature,
358    },
359    /// An error in computing the signature of the `ExtensionOp`
360    #[error("Error in signature of operation '{name}' in {node}: {cause}")]
361    #[allow(missing_docs)]
362    SignatureError {
363        node: N,
364        name: OpName,
365        #[source]
366        cause: SignatureError,
367    },
368    /// Unresolved operation encountered during validation.
369    #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
370    UnresolvedOp(N, OpName, ExtensionId),
371    /// Error updating the extension registry in the Hugr while resolving opaque ops.
372    #[error("Error updating extension registry: {0}")]
373    ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
374}
375
#[cfg(test)]
mod test {

    use ops::OpType;

    use crate::Node;
    use crate::extension::ExtensionRegistry;
    use crate::extension::resolution::resolve_op_extensions;
    use crate::std_extensions::STD_REG;
    use crate::std_extensions::arithmetic::conversions;
    use crate::{
        Extension,
        extension::{
            SignatureFunc,
            prelude::{bool_t, qb_t, usize_t},
        },
        std_extensions::arithmetic::int_types::INT_TYPES,
        types::FuncValueType,
    };

    use super::*;

    /// Unwrap the [`OpDef`] of an op that `resolve_op_extensions` has resolved
    /// into an [`ExtensionOp`]. Panics if the op is not an extension op.
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            vec![TypeArg::Type { ty: usize_t() }],
            sig.clone(),
        );
        assert_eq!(op.name(), "OpaqueOp:res.op");
        assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]);
        assert_eq!(op.signature().as_ref(), &sig);
    }

    #[test]
    fn resolve_opaque_op() {
        let registry = &STD_REG;
        let i0 = &INT_TYPES[0];
        let opaque = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            vec![],
            Signature::new(i0.clone(), bool_t()),
        );
        let mut resolved = opaque.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved,
            registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved).name(), "itobool");
    }

    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());

        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                String::new(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();

            ext.add_op(
                comp_name.into(),
                String::new(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();

        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();
        let opaque_val = OpaqueOp::new(ext_id.clone(), val_name, vec![], endo_sig.clone());
        let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, vec![], endo_sig);
        let mut resolved_val = opaque_val.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved_val,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_val).name(), val_name);

        // With no compute function, resolution falls back to the stored signature.
        let mut resolved_comp = opaque_comp.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(2)),
            &mut resolved_comp,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name);
    }
}