hugr_core/ops/custom.rs

1//! Extensible operations.
2
3use itertools::Itertools;
4use std::borrow::Cow;
5use std::sync::Arc;
6use thiserror::Error;
7#[cfg(test)]
8use {
9    crate::extension::test::SimpleOpDef, crate::proptest::any_nonempty_smolstr,
10    crate::types::proptest_utils::any_serde_type_arg_vec, ::proptest::prelude::*,
11    ::proptest_derive::Arbitrary,
12};
13
14use crate::core::HugrNode;
15use crate::extension::simple_op::MakeExtensionOp;
16use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
17use crate::types::{Signature, type_param::TypeArg};
18use crate::{IncomingPort, ops};
19
20use super::dataflow::DataflowOpTrait;
21use super::tag::OpTag;
22use super::{NamedOp, OpName, OpNameRef};
23
24/// An operation defined by an [`OpDef`] from a loaded [Extension].
25///
26/// Extension ops are not serializable. They must be downgraded into an [`OpaqueOp`] instead.
27/// See [`ExtensionOp::make_opaque`].
28///
29/// [Extension]: crate::Extension
30#[derive(Clone, Debug, serde::Serialize)]
31#[serde(into = "OpaqueOp")]
32#[cfg_attr(test, derive(Arbitrary))]
33pub struct ExtensionOp {
34    #[cfg_attr(
35        test,
36        proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
37    )]
38    def: Arc<OpDef>,
39    #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
40    args: Vec<TypeArg>,
41    signature: Signature, // Cache
42}
43
44impl ExtensionOp {
45    /// Create a new `ExtensionOp` given the type arguments and specified input extensions
46    pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
47        let args: Vec<TypeArg> = args.into();
48        let signature = def.compute_signature(&args)?;
49        Ok(Self {
50            def,
51            args,
52            signature,
53        })
54    }
55
56    /// If `OpDef` is missing binary computation, trust the cached signature.
57    pub(crate) fn new_with_cached(
58        def: Arc<OpDef>,
59        args: impl IntoIterator<Item = TypeArg>,
60        opaque: &OpaqueOp,
61    ) -> Result<Self, SignatureError> {
62        let args: Vec<TypeArg> = args.into_iter().collect();
63        // TODO skip computation depending on config
64        // see https://github.com/CQCL/hugr/issues/1363
65        let signature = match def.compute_signature(&args) {
66            Ok(sig) => sig,
67            Err(SignatureError::MissingComputeFunc) => {
68                // TODO raise warning: https://github.com/CQCL/hugr/issues/1432
69                opaque.signature().into_owned()
70            }
71            Err(e) => return Err(e),
72        };
73        Ok(Self {
74            def,
75            args,
76            signature,
77        })
78    }
79
80    /// Return the argument values for this operation.
81    #[must_use]
82    pub fn args(&self) -> &[TypeArg] {
83        &self.args
84    }
85
86    /// Returns a reference to the [`OpDef`] of this [`ExtensionOp`].
87    #[must_use]
88    pub fn def(&self) -> &OpDef {
89        self.def.as_ref()
90    }
91
92    /// Gets an Arc to the [`OpDef`] of this instance, i.e. usable to create
93    /// new instances.
94    #[must_use]
95    pub fn def_arc(&self) -> &Arc<OpDef> {
96        &self.def
97    }
98
99    /// Attempt to evaluate this operation. See [`OpDef::constant_fold`].
100    #[must_use]
101    pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
102        self.def().constant_fold(self.args(), consts)
103    }
104
105    /// Creates a new [`OpaqueOp`] as a downgraded version of this
106    /// [`ExtensionOp`].
107    ///
108    /// Regenerating the [`ExtensionOp`] back from the [`OpaqueOp`] requires a
109    /// registry with the appropriate extension.
110    ///
111    /// For a non-cloning version of this operation, use [`OpaqueOp::from`].
112    #[must_use]
113    pub fn make_opaque(&self) -> OpaqueOp {
114        OpaqueOp {
115            extension: self.def.extension_id().clone(),
116            name: self.def.name().clone(),
117            args: self.args.clone(),
118            signature: self.signature.clone(),
119        }
120    }
121
122    /// Returns a mutable reference to the cached signature of the operation.
123    pub fn signature_mut(&mut self) -> &mut Signature {
124        &mut self.signature
125    }
126
127    /// Returns a mutable reference to the type arguments of the operation.
128    pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
129        self.args.as_mut_slice()
130    }
131
132    /// Cast the operation to an specific extension op.
133    ///
134    /// Returns `None` if the operation is not of the requested type.
135    #[must_use]
136    pub fn cast<T: MakeExtensionOp>(&self) -> Option<T> {
137        T::from_extension_op(self).ok()
138    }
139
140    /// Returns the extension id of the operation.
141    #[must_use]
142    pub fn extension_id(&self) -> &ExtensionId {
143        self.def.extension_id()
144    }
145
146    /// Returns the unqualified id of the operation. e.g. 'iadd'
147    ///
148    #[must_use]
149    pub fn unqualified_id(&self) -> &OpNameRef {
150        self.def.name()
151    }
152
153    /// Returns the qualified id of the operation. e.g. 'arithmetic.iadd'
154    #[must_use]
155    pub fn qualified_id(&self) -> OpName {
156        qualify_name(self.extension_id(), self.unqualified_id())
157    }
158}
159
160impl From<ExtensionOp> for OpaqueOp {
161    fn from(op: ExtensionOp) -> Self {
162        let ExtensionOp {
163            def,
164            args,
165            signature,
166        } = op;
167        OpaqueOp {
168            extension: def.extension_id().clone(),
169            name: def.name().clone(),
170            args,
171            signature,
172        }
173    }
174}
175
176impl PartialEq for ExtensionOp {
177    fn eq(&self, other: &Self) -> bool {
178        if Arc::<OpDef>::ptr_eq(&self.def, &other.def) {
179            // If the OpDef is exactly the same, we can skip some checks.
180            self.args() == other.args()
181        } else {
182            self.args() == other.args()
183                && self.signature() == other.signature()
184                && self.def.name() == other.def.name()
185                && self.def.extension_id() == other.def.extension_id()
186        }
187    }
188}
189
190impl Eq for ExtensionOp {}
191
192impl NamedOp for ExtensionOp {
193    /// The name of the operation.
194    fn name(&self) -> OpName {
195        self.qualified_id()
196    }
197}
198
199impl DataflowOpTrait for ExtensionOp {
200    const TAG: OpTag = OpTag::Leaf;
201
202    fn description(&self) -> &str {
203        self.def().description()
204    }
205
206    fn signature(&self) -> Cow<'_, Signature> {
207        Cow::Borrowed(&self.signature)
208    }
209
210    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
211        let args = self
212            .args
213            .iter()
214            .map(|ta| ta.substitute(subst))
215            .collect::<Vec<_>>();
216        let signature = self.signature.substitute(subst);
217        Self {
218            def: self.def.clone(),
219            args,
220            signature,
221        }
222    }
223}
224
225/// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`].
226///
227/// [`ExtensionOp`]s are serialised as `OpaqueOp`s.
228///
229/// The signature of a [`ExtensionOp`] always includes that op's extension. We do not
230/// require that the `signature` field of [`OpaqueOp`] contains `extension`,
231/// instead we are careful to add it whenever we look at the `signature` of an
232/// `OpaqueOp`. This is a small efficiency in serialisation and allows us to
233/// be more liberal in deserialisation.
234#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
235#[cfg_attr(test, derive(Arbitrary))]
236pub struct OpaqueOp {
237    extension: ExtensionId,
238    #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
239    name: OpName,
240    #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
241    args: Vec<TypeArg>,
242    // note that the `signature` field might not include `extension`. Thus this must
243    // remain private, and should be accessed through
244    // `DataflowOpTrait::signature`.
245    signature: Signature,
246}
247
248fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
249    format!("{res_id}.{name}").into()
250}
251
252impl OpaqueOp {
253    /// Creates a new `OpaqueOp` from all the fields we'd expect to serialize.
254    pub fn new(
255        extension: ExtensionId,
256        name: impl Into<OpName>,
257        args: impl Into<Vec<TypeArg>>,
258        signature: Signature,
259    ) -> Self {
260        Self {
261            extension,
262            name: name.into(),
263            args: args.into(),
264            signature,
265        }
266    }
267
268    /// Returns a mutable reference to the signature of the operation.
269    pub fn signature_mut(&mut self) -> &mut Signature {
270        &mut self.signature
271    }
272}
273
274impl NamedOp for OpaqueOp {
275    fn name(&self) -> OpName {
276        format!("OpaqueOp:{}", self.qualified_id()).into()
277    }
278}
279
280impl OpaqueOp {
281    /// Unique name of the operation.
282    #[must_use]
283    pub fn unqualified_id(&self) -> &OpName {
284        &self.name
285    }
286
287    /// Unique name of the operation.
288    #[must_use]
289    pub fn qualified_id(&self) -> OpName {
290        qualify_name(self.extension(), self.unqualified_id())
291    }
292
293    /// Type arguments.
294    #[must_use]
295    pub fn args(&self) -> &[TypeArg] {
296        &self.args
297    }
298
299    /// Parent extension.
300    #[must_use]
301    pub fn extension(&self) -> &ExtensionId {
302        &self.extension
303    }
304
305    /// Returns a mutable reference to the type arguments of the operation.
306    pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
307        self.args.as_mut_slice()
308    }
309}
310
311impl DataflowOpTrait for OpaqueOp {
312    const TAG: OpTag = OpTag::Leaf;
313
314    fn description(&self) -> &str {
315        "Opaque operation"
316    }
317
318    fn signature(&self) -> Cow<'_, Signature> {
319        Cow::Borrowed(&self.signature)
320    }
321
322    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
323        Self {
324            args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
325            signature: self.signature.substitute(subst),
326            ..self.clone()
327        }
328    }
329}
330
/// Errors that arise after loading a Hugr containing opaque ops (serialized just as their names)
/// when trying to resolve the serialized names against a registry of known Extensions.
///
/// `N` is the node type used to report where in the Hugr the error occurred.
#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
pub enum OpaqueOpError<N: HugrNode> {
    /// The Extension was found but did not contain the expected `OpDef`
    #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
            available_ops.iter().join(", ")
    )]
    OpNotFoundInExtension {
        /// The node where the error occurred.
        node: N,
        /// The missing operation.
        op: OpName,
        /// The extension where the operation was expected.
        extension: ExtensionId,
        /// The available operations in the extension.
        available_ops: Vec<OpName>,
    },
    /// Extension and `OpDef` found, but the computed signature did not match the stored one.
    #[error(
        "Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}"
    )]
    #[allow(missing_docs)]
    SignatureMismatch {
        /// The node where the error occurred.
        node: N,
        /// The extension containing the operation.
        extension: ExtensionId,
        /// The name of the operation.
        op: OpName,
        /// The signature stored in the serialized op.
        stored: Box<Signature>,
        /// The signature computed from the resolved `OpDef`.
        computed: Box<Signature>,
    },
    /// An error in computing the signature of the `ExtensionOp`
    #[error("Error in signature of operation '{name}' in {node}: {cause}")]
    #[allow(missing_docs)]
    SignatureError {
        /// The node where the error occurred.
        node: N,
        /// The name of the operation whose signature could not be computed.
        name: OpName,
        /// The underlying signature error.
        #[source]
        cause: SignatureError,
    },
    /// Unresolved operation encountered during validation.
    ///
    /// Fields are, in order: the node, the operation name, and the extension id.
    #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
    UnresolvedOp(N, OpName, ExtensionId),
    /// Error updating the extension registry in the Hugr while resolving opaque ops.
    #[error("Error updating extension registry: {0}")]
    ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
}
378
#[cfg(test)]
mod test {

    use ops::OpType;

    use crate::Node;
    use crate::extension::ExtensionRegistry;
    use crate::extension::resolution::resolve_op_extensions;
    use crate::std_extensions::STD_REG;
    use crate::std_extensions::arithmetic::conversions::{self};
    use crate::{
        Extension,
        extension::{
            SignatureFunc,
            prelude::{bool_t, qb_t, usize_t},
        },
        std_extensions::arithmetic::int_types::INT_TYPES,
        types::FuncValueType,
    };

    use super::*;

    /// Returns the [`OpDef`] of an op that has been resolved into an
    /// [`ExtensionOp`] (e.g. by `resolve_op_extensions`).
    ///
    /// Panics if `res` is not an extension op.
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    /// An `OpaqueOp` built directly reports its name, args and signature.
    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            vec![usize_t().into()],
            sig.clone(),
        );
        assert_eq!(op.name(), "OpaqueOp:res.op");
        assert_eq!(op.args(), &[usize_t().into()]);
        assert_eq!(op.signature().as_ref(), &sig);
    }

    /// An opaque op referring to a standard extension resolves to its `OpDef`.
    #[test]
    fn resolve_opaque_op() {
        let registry = &STD_REG;
        let i0 = &INT_TYPES[0];
        let opaque = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            vec![],
            Signature::new(i0.clone(), bool_t()),
        );
        let mut resolved = opaque.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved,
            registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved).name(), "itobool");
    }

    /// Opaque ops whose definitions are missing a validate or compute
    /// function can still be resolved.
    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());

        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                String::new(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();

            ext.add_op(
                comp_name.into(),
                String::new(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();

        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();
        let opaque_val = OpaqueOp::new(ext_id.clone(), val_name, vec![], endo_sig.clone());
        let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, vec![], endo_sig);
        let mut resolved_val = opaque_val.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(1)),
            &mut resolved_val,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_val).name(), val_name);

        let mut resolved_comp = opaque_comp.into();
        resolve_op_extensions(
            Node::from(portgraph::NodeIndex::new(2)),
            &mut resolved_comp,
            &registry,
        )
        .unwrap();
        assert_eq!(resolve_res_definition(&resolved_comp).name(), comp_name);
    }
}