hugr_core/ops/
custom.rs

1//! Extensible operations.
2
3use itertools::Itertools;
4use std::borrow::Cow;
5use std::sync::Arc;
6use thiserror::Error;
7#[cfg(test)]
8use {
9    crate::extension::test::SimpleOpDef,
10    crate::proptest::{any_nonempty_smolstr, any_nonempty_string},
11    ::proptest::prelude::*,
12    ::proptest_derive::Arbitrary,
13};
14
15use crate::extension::simple_op::MakeExtensionOp;
16use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError};
17use crate::types::{type_param::TypeArg, Signature};
18use crate::{ops, IncomingPort, Node};
19
20use super::dataflow::DataflowOpTrait;
21use super::tag::OpTag;
22use super::{NamedOp, OpName, OpNameRef};
23
24/// An operation defined by an [OpDef] from a loaded [Extension].
25///
26/// Extension ops are not serializable. They must be downgraded into an [OpaqueOp] instead.
27/// See [ExtensionOp::make_opaque].
28///
29/// [Extension]: crate::Extension
30#[derive(Clone, Debug, serde::Serialize)]
31#[serde(into = "OpaqueOp")]
32#[cfg_attr(test, derive(Arbitrary))]
33pub struct ExtensionOp {
34    #[cfg_attr(
35        test,
36        proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
37    )]
38    def: Arc<OpDef>,
39    args: Vec<TypeArg>,
40    signature: Signature, // Cache
41}
42
43impl ExtensionOp {
44    /// Create a new ExtensionOp given the type arguments and specified input extensions
45    pub fn new(def: Arc<OpDef>, args: impl Into<Vec<TypeArg>>) -> Result<Self, SignatureError> {
46        let args: Vec<TypeArg> = args.into();
47        let signature = def.compute_signature(&args)?;
48        Ok(Self {
49            def,
50            args,
51            signature,
52        })
53    }
54
55    /// If OpDef is missing binary computation, trust the cached signature.
56    pub(crate) fn new_with_cached(
57        def: Arc<OpDef>,
58        args: impl IntoIterator<Item = TypeArg>,
59        opaque: &OpaqueOp,
60    ) -> Result<Self, SignatureError> {
61        let args: Vec<TypeArg> = args.into_iter().collect();
62        // TODO skip computation depending on config
63        // see https://github.com/CQCL/hugr/issues/1363
64        let signature = match def.compute_signature(&args) {
65            Ok(sig) => sig,
66            Err(SignatureError::MissingComputeFunc) => {
67                // TODO raise warning: https://github.com/CQCL/hugr/issues/1432
68                opaque.signature().into_owned()
69            }
70            Err(e) => return Err(e),
71        };
72        Ok(Self {
73            def,
74            args,
75            signature,
76        })
77    }
78
79    /// Return the argument values for this operation.
80    pub fn args(&self) -> &[TypeArg] {
81        &self.args
82    }
83
84    /// Returns a reference to the [`OpDef`] of this [`ExtensionOp`].
85    pub fn def(&self) -> &OpDef {
86        self.def.as_ref()
87    }
88
89    /// Gets an Arc to the [`OpDef`] of this instance, i.e. usable to create
90    /// new instances.
91    pub fn def_arc(&self) -> &Arc<OpDef> {
92        &self.def
93    }
94
95    /// Attempt to evaluate this operation. See [`OpDef::constant_fold`].
96    pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Value)]) -> ConstFoldResult {
97        self.def().constant_fold(self.args(), consts)
98    }
99
100    /// Creates a new [`OpaqueOp`] as a downgraded version of this
101    /// [`ExtensionOp`].
102    ///
103    /// Regenerating the [`ExtensionOp`] back from the [`OpaqueOp`] requires a
104    /// registry with the appropriate extension. See
105    /// [`crate::Hugr::resolve_extension_defs`].
106    ///
107    /// For a non-cloning version of this operation, use [`OpaqueOp::from`].
108    pub fn make_opaque(&self) -> OpaqueOp {
109        OpaqueOp {
110            extension: self.def.extension_id().clone(),
111            name: self.def.name().clone(),
112            description: self.def.description().into(),
113            args: self.args.clone(),
114            signature: self.signature.clone(),
115        }
116    }
117
118    /// Returns a mutable reference to the cached signature of the operation.
119    pub fn signature_mut(&mut self) -> &mut Signature {
120        &mut self.signature
121    }
122
123    /// Returns a mutable reference to the type arguments of the operation.
124    pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
125        self.args.as_mut_slice()
126    }
127
128    /// Cast the operation to an specific extension op.
129    ///
130    /// Returns `None` if the operation is not of the requested type.
131    pub fn cast<T: MakeExtensionOp>(&self) -> Option<T> {
132        T::from_extension_op(self).ok()
133    }
134
135    /// Returns the extension id of the operation.
136    pub fn extension_id(&self) -> &ExtensionId {
137        self.def.extension_id()
138    }
139}
140
141impl From<ExtensionOp> for OpaqueOp {
142    fn from(op: ExtensionOp) -> Self {
143        let ExtensionOp {
144            def,
145            args,
146            signature,
147        } = op;
148        OpaqueOp {
149            extension: def.extension_id().clone(),
150            name: def.name().clone(),
151            description: def.description().into(),
152            args,
153            signature,
154        }
155    }
156}
157
158impl PartialEq for ExtensionOp {
159    fn eq(&self, other: &Self) -> bool {
160        Arc::<OpDef>::ptr_eq(&self.def, &other.def) && self.args == other.args
161    }
162}
163
164impl Eq for ExtensionOp {}
165
166impl NamedOp for ExtensionOp {
167    /// The name of the operation.
168    fn name(&self) -> OpName {
169        qualify_name(self.def.extension_id(), self.def.name())
170    }
171}
172
173impl DataflowOpTrait for ExtensionOp {
174    const TAG: OpTag = OpTag::Leaf;
175
176    fn description(&self) -> &str {
177        self.def().description()
178    }
179
180    fn signature(&self) -> Cow<'_, Signature> {
181        Cow::Borrowed(&self.signature)
182    }
183
184    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
185        let args = self
186            .args
187            .iter()
188            .map(|ta| ta.substitute(subst))
189            .collect::<Vec<_>>();
190        let signature = self.signature.substitute(subst);
191        Self {
192            def: self.def.clone(),
193            args,
194            signature,
195        }
196    }
197}
198
199/// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`].
200///
201/// [ExtensionOp]s are serialised as `OpaqueOp`s.
202///
203/// The signature of a [ExtensionOp] always includes that op's extension. We do not
204/// require that the `signature` field of [OpaqueOp] contains `extension`,
205/// instead we are careful to add it whenever we look at the `signature` of an
206/// `OpaqueOp`. This is a small efficiency in serialisation and allows us to
207/// be more liberal in deserialisation.
208#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
209#[cfg_attr(test, derive(Arbitrary))]
210pub struct OpaqueOp {
211    extension: ExtensionId,
212    #[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
213    name: OpName,
214    #[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))]
215    description: String, // cache in advance so description() can return &str
216    args: Vec<TypeArg>,
217    // note that the `signature` field might not include `extension`. Thus this must
218    // remain private, and should be accessed through
219    // `DataflowOpTrait::signature`.
220    signature: Signature,
221}
222
223fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName {
224    format!("{}.{}", res_id, name).into()
225}
226
227impl OpaqueOp {
228    /// Creates a new OpaqueOp from all the fields we'd expect to serialize.
229    pub fn new(
230        extension: ExtensionId,
231        name: impl Into<OpName>,
232        description: String,
233        args: impl Into<Vec<TypeArg>>,
234        signature: Signature,
235    ) -> Self {
236        let signature = signature.with_extension_delta(extension.clone());
237        Self {
238            extension,
239            name: name.into(),
240            description,
241            args: args.into(),
242            signature,
243        }
244    }
245
246    /// Returns a mutable reference to the signature of the operation.
247    pub fn signature_mut(&mut self) -> &mut Signature {
248        &mut self.signature
249    }
250}
251
252impl NamedOp for OpaqueOp {
253    /// The name of the operation.
254    fn name(&self) -> OpName {
255        qualify_name(&self.extension, &self.name)
256    }
257}
258impl OpaqueOp {
259    /// Unique name of the operation.
260    pub fn op_name(&self) -> &OpName {
261        &self.name
262    }
263
264    /// Type arguments.
265    pub fn args(&self) -> &[TypeArg] {
266        &self.args
267    }
268
269    /// Parent extension.
270    pub fn extension(&self) -> &ExtensionId {
271        &self.extension
272    }
273
274    /// Returns a mutable reference to the type arguments of the operation.
275    pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] {
276        self.args.as_mut_slice()
277    }
278}
279
280impl DataflowOpTrait for OpaqueOp {
281    const TAG: OpTag = OpTag::Leaf;
282
283    fn description(&self) -> &str {
284        &self.description
285    }
286
287    fn signature(&self) -> Cow<'_, Signature> {
288        Cow::Borrowed(&self.signature)
289    }
290
291    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
292        Self {
293            args: self.args.iter().map(|ta| ta.substitute(subst)).collect(),
294            signature: self.signature.substitute(subst),
295            ..self.clone()
296        }
297    }
298}
299
300/// Errors that arise after loading a Hugr containing opaque ops (serialized just as their names)
301/// when trying to resolve the serialized names against a registry of known Extensions.
302#[derive(Clone, Debug, Error, PartialEq)]
303#[non_exhaustive]
304pub enum OpaqueOpError {
305    /// The Extension was found but did not contain the expected OpDef
306    #[error("Operation '{op}' in {node} not found in Extension {extension}. Available operations: {}",
307            available_ops.iter().join(", ")
308    )]
309    OpNotFoundInExtension {
310        /// The node where the error occurred.
311        node: Node,
312        /// The missing operation.
313        op: OpName,
314        /// The extension where the operation was expected.
315        extension: ExtensionId,
316        /// The available operations in the extension.
317        available_ops: Vec<OpName>,
318    },
319    /// Extension and OpDef found, but computed signature did not match stored
320    #[error("Conflicting signature: resolved {op} in extension {extension} to a concrete implementation which computed {computed} but stored signature was {stored}")]
321    #[allow(missing_docs)]
322    SignatureMismatch {
323        node: Node,
324        extension: ExtensionId,
325        op: OpName,
326        stored: Signature,
327        computed: Signature,
328    },
329    /// An error in computing the signature of the ExtensionOp
330    #[error("Error in signature of operation '{name}' in {node}: {cause}")]
331    #[allow(missing_docs)]
332    SignatureError {
333        node: Node,
334        name: OpName,
335        #[source]
336        cause: SignatureError,
337    },
338    /// Unresolved operation encountered during validation.
339    #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")]
340    UnresolvedOp(Node, OpName, ExtensionId),
341    /// Error updating the extension registry in the Hugr while resolving opaque ops.
342    #[error("Error updating extension registry: {0}")]
343    ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError),
344}
345
#[cfg(test)]
mod test {
    use crate::extension::prelude::{bool_t, qb_t, usize_t};
    use crate::extension::resolution::resolve_op_extensions;
    use crate::extension::{ExtensionRegistry, SignatureFunc};
    use crate::ops::OpType;
    use crate::std_extensions::arithmetic::conversions;
    use crate::std_extensions::arithmetic::int_types::INT_TYPES;
    use crate::std_extensions::STD_REG;
    use crate::types::FuncValueType;
    use crate::Extension;

    use super::*;

    /// Extracts the [`OpDef`] of an op that resolution turned into an
    /// [`ExtensionOp`], panicking if it is anything else.
    fn resolve_res_definition(res: &OpType) -> &OpDef {
        res.as_extension_op().unwrap().def()
    }

    /// Resolves `op` in place against `registry`, reporting it as the node with
    /// index `index`. Panics if resolution fails.
    fn resolve_at(index: usize, op: &mut OpType, registry: &ExtensionRegistry) {
        let node = Node::from(portgraph::NodeIndex::new(index));
        resolve_op_extensions(node, op, registry).unwrap();
    }

    #[test]
    fn new_opaque_op() {
        let sig = Signature::new_endo(vec![qb_t()]);
        let usize_arg = TypeArg::Type { ty: usize_t() };
        let op = OpaqueOp::new(
            "res".try_into().unwrap(),
            "op",
            "desc".into(),
            vec![usize_arg.clone()],
            sig.clone(),
        );

        assert_eq!(op.name(), "res.op");
        assert_eq!(DataflowOpTrait::description(&op), "desc");
        assert_eq!(op.args(), &[usize_arg]);

        let expected_sig = sig.with_extension_delta(op.extension().clone());
        assert_eq!(op.signature().as_ref(), &expected_sig);
    }

    #[test]
    fn resolve_opaque_op() {
        let int_ty = INT_TYPES[0].clone();
        let mut op: OpType = OpaqueOp::new(
            conversions::EXTENSION_ID,
            "itobool",
            "description".into(),
            vec![],
            Signature::new(int_ty, bool_t()),
        )
        .into();

        resolve_at(1, &mut op, &STD_REG);
        assert_eq!(resolve_res_definition(&op).name(), "itobool");
    }

    #[test]
    fn resolve_missing() {
        let val_name = "missing_val";
        let comp_name = "missing_comp";
        let endo_sig = Signature::new_endo(bool_t());

        let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op(
                val_name.into(),
                "".to_string(),
                SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()),
                extension_ref,
            )
            .unwrap();
            ext.add_op(
                comp_name.into(),
                "".to_string(),
                SignatureFunc::MissingComputeFunc,
                extension_ref,
            )
            .unwrap();
        });
        let ext_id = ext.name().clone();

        let registry = ExtensionRegistry::new([ext]);
        registry.validate().unwrap();

        let cases = [(1, val_name, endo_sig.clone()), (2, comp_name, endo_sig)];
        for (index, op_name, sig) in cases {
            let mut op: OpType =
                OpaqueOp::new(ext_id.clone(), op_name, "".into(), vec![], sig).into();
            resolve_at(index, &mut op, &registry);
            assert_eq!(resolve_res_definition(&op).name(), op_name);
        }
    }
}