hugr_core/extension/
op_def.rs

1use std::cmp::min;
2use std::collections::HashMap;
3use std::collections::btree_map::Entry;
4use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, Weak};
6
7use serde_with::serde_as;
8
9use super::{
10    ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
11    SignatureError,
12};
13
14use crate::Hugr;
15use crate::envelope::serde_with::AsStringEnvelope;
16use crate::ops::{OpName, OpNameRef};
17use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
18use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
19mod serialize_signature_func;
20
21/// Trait necessary for binary computations of `OpDef` signature
22pub trait CustomSignatureFunc: Send + Sync {
23    /// Compute signature of node given
24    /// values for the type parameters,
25    /// the operation definition and the extension registry.
26    fn compute_signature<'o, 'a: 'o>(
27        &'a self,
28        arg_values: &[TypeArg],
29        def: &'o OpDef,
30    ) -> Result<PolyFuncTypeRV, SignatureError>;
31    /// The declared type parameters which require values in order for signature to
32    /// be computed.
33    fn static_params(&self) -> &[TypeParam];
34}
35
36/// Compute signature of `OpDef` given type arguments.
37pub trait SignatureFromArgs: Send + Sync {
38    /// Compute signature of node given
39    /// values for the type parameters.
40    fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError>;
41    /// The declared type parameters which require values in order for signature to
42    /// be computed.
43    fn static_params(&self) -> &[TypeParam];
44}
45
46impl<T: SignatureFromArgs> CustomSignatureFunc for T {
47    #[inline]
48    fn compute_signature<'o, 'a: 'o>(
49        &'a self,
50        arg_values: &[TypeArg],
51        _def: &'o OpDef,
52    ) -> Result<PolyFuncTypeRV, SignatureError> {
53        SignatureFromArgs::compute_signature(self, arg_values)
54    }
55
56    #[inline]
57    fn static_params(&self) -> &[TypeParam] {
58        SignatureFromArgs::static_params(self)
59    }
60}
61
62/// Trait for validating type arguments to a `PolyFuncTypeRV` beyond conformation to
63/// declared type parameter (which should have been checked beforehand).
64pub trait ValidateTypeArgs: Send + Sync {
65    /// Validate the type arguments of node given
66    /// values for the type parameters,
67    /// the operation definition and the extension registry.
68    fn validate<'o, 'a: 'o>(
69        &self,
70        arg_values: &[TypeArg],
71        def: &'o OpDef,
72    ) -> Result<(), SignatureError>;
73}
74
75/// Trait for validating type arguments to a `PolyFuncTypeRV` beyond conformation to
76/// declared type parameter (which should have been checked beforehand), given just the arguments.
77pub trait ValidateJustArgs: Send + Sync {
78    /// Validate the type arguments of node given
79    /// values for the type parameters.
80    fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError>;
81}
82
83impl<T: ValidateJustArgs> ValidateTypeArgs for T {
84    #[inline]
85    fn validate<'o, 'a: 'o>(
86        &self,
87        arg_values: &[TypeArg],
88        _def: &'o OpDef,
89    ) -> Result<(), SignatureError> {
90        ValidateJustArgs::validate(self, arg_values)
91    }
92}
93
94/// Trait for Extensions to provide custom binary code that can lower an operation to
95/// a Hugr using only a limited set of other extensions. That is, trait
96/// implementations can return a Hugr that implements the operation using only
97/// those extensions and that can be used to replace the operation node. This may be
98/// useful for third-party Extensions or as a fallback for tools that do not support
99/// the operation natively.
100///
101/// This trait allows the Hugr to be varied according to the operation's [`TypeArg`]s;
102/// if this is not necessary then a single Hugr can be provided instead via
103/// [`LowerFunc::FixedHugr`].
104pub trait CustomLowerFunc: Send + Sync {
105    /// Return a Hugr that implements the node using only the specified available extensions;
106    /// may fail.
107    /// TODO: some error type to indicate Extensions required?
108    fn try_lower(
109        &self,
110        name: &OpNameRef,
111        arg_values: &[TypeArg],
112        misc: &HashMap<String, serde_json::Value>,
113        available_extensions: &ExtensionSet,
114    ) -> Option<Hugr>;
115}
116
117/// Encode a signature as [`PolyFuncTypeRV`] but with additional validation of type
118/// arguments via a custom binary. The binary cannot be serialized so will be
119/// lost over a serialization round-trip.
120pub struct CustomValidator {
121    poly_func: PolyFuncTypeRV,
122    /// Custom function for validating type arguments before returning the signature.
123    pub(crate) validate: Box<dyn ValidateTypeArgs>,
124}
125
126impl CustomValidator {
127    /// Encode a signature using a `PolyFuncTypeRV`, with a custom function for
128    /// validating type arguments before returning the signature.
129    pub fn new(
130        poly_func: impl Into<PolyFuncTypeRV>,
131        validate: impl ValidateTypeArgs + 'static,
132    ) -> Self {
133        Self {
134            poly_func: poly_func.into(),
135            validate: Box::new(validate),
136        }
137    }
138
139    /// Return a mutable reference to the `PolyFuncType`.
140    pub(super) fn poly_func_mut(&mut self) -> &mut PolyFuncTypeRV {
141        &mut self.poly_func
142    }
143}
144
145/// The ways in which an `OpDef` may compute the Signature of each operation node.
146pub enum SignatureFunc {
147    /// An explicit polymorphic function type.
148    PolyFuncType(PolyFuncTypeRV),
149    /// A polymorphic function type (like [`Self::PolyFuncType`] but also with a custom binary for validating type arguments.
150    CustomValidator(CustomValidator),
151    /// Serialized declaration specified a custom validate binary but it was not provided.
152    MissingValidateFunc(PolyFuncTypeRV),
153    /// A custom binary which computes a polymorphic function type given values
154    /// for its static type parameters.
155    CustomFunc(Box<dyn CustomSignatureFunc>),
156    /// Serialized declaration specified a custom compute binary but it was not provided.
157    MissingComputeFunc,
158}
159
160impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {
161    fn from(v: T) -> Self {
162        Self::CustomFunc(Box::new(v))
163    }
164}
165
166impl From<PolyFuncType> for SignatureFunc {
167    fn from(value: PolyFuncType) -> Self {
168        Self::PolyFuncType(value.into())
169    }
170}
171
172impl From<PolyFuncTypeRV> for SignatureFunc {
173    fn from(v: PolyFuncTypeRV) -> Self {
174        Self::PolyFuncType(v)
175    }
176}
177
178impl From<FuncValueType> for SignatureFunc {
179    fn from(v: FuncValueType) -> Self {
180        Self::PolyFuncType(v.into())
181    }
182}
183
184impl From<Signature> for SignatureFunc {
185    fn from(v: Signature) -> Self {
186        Self::PolyFuncType(FuncValueType::from(v).into())
187    }
188}
189
190impl From<CustomValidator> for SignatureFunc {
191    fn from(v: CustomValidator) -> Self {
192        Self::CustomValidator(v)
193    }
194}
195
196impl SignatureFunc {
197    fn static_params(&self) -> Result<&[TypeParam], SignatureError> {
198        Ok(match self {
199            SignatureFunc::PolyFuncType(ts)
200            | SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. })
201            | SignatureFunc::MissingValidateFunc(ts) => ts.params(),
202            SignatureFunc::CustomFunc(func) => func.static_params(),
203            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
204        })
205    }
206
207    /// If the signature is missing a custom validation function, ignore and treat as
208    /// self-contained type scheme (with no custom validation).
209    pub fn ignore_missing_validation(&mut self) {
210        if let SignatureFunc::MissingValidateFunc(ts) = self {
211            *self = SignatureFunc::PolyFuncType(ts.clone());
212        }
213    }
214
215    /// Compute the concrete signature ([`FuncValueType`]).
216    ///
217    /// # Panics
218    ///
219    /// Panics if `self` is a [`SignatureFunc::CustomFunc`] and there are not enough type
220    /// arguments provided to match the number of static parameters.
221    ///
222    /// # Errors
223    ///
224    /// This function will return an error if the type arguments are invalid or
225    /// there is some error in type computation.
226    pub fn compute_signature(
227        &self,
228        def: &OpDef,
229        args: &[TypeArg],
230    ) -> Result<Signature, SignatureError> {
231        let temp: PolyFuncTypeRV; // to keep alive
232        let (pf, args) = match &self {
233            SignatureFunc::CustomValidator(custom) => {
234                custom.validate.validate(args, def)?;
235                (&custom.poly_func, args)
236            }
237            SignatureFunc::PolyFuncType(ts) => (ts, args),
238            SignatureFunc::CustomFunc(func) => {
239                let static_params = func.static_params();
240                let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
241
242                check_term_types(static_args, static_params)?;
243                temp = func.compute_signature(static_args, def)?;
244                (&temp, other_args)
245            }
246            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
247            // TODO raise warning: https://github.com/CQCL/hugr/issues/1432
248            SignatureFunc::MissingValidateFunc(ts) => (ts, args),
249        };
250        let res = pf.instantiate(args)?;
251
252        // If there are any row variables left, this will fail with an error:
253        res.try_into()
254    }
255}
256
257impl Debug for SignatureFunc {
258    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
259        match self {
260            Self::CustomValidator(ts) => ts.poly_func.fmt(f),
261            Self::PolyFuncType(ts) => ts.fmt(f),
262            Self::CustomFunc { .. } => f.write_str("<custom sig>"),
263            Self::MissingComputeFunc => f.write_str("<missing custom sig>"),
264            Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"),
265        }
266    }
267}
268
/// Different ways that an [OpDef] can lower operation nodes, i.e. provide a Hugr
/// that implements the operation using a set of other extensions.
///
/// (De)serialized `untagged`, so variant order matters for deserialization.
/// Only [`LowerFunc::FixedHugr`] participates in serde; [`LowerFunc::CustomFunc`]
/// is marked `#[serde(skip)]`.
#[serde_as]
#[derive(serde::Deserialize, serde::Serialize)]
#[serde(untagged)]
pub enum LowerFunc {
    /// Lowering to a fixed Hugr. Since this cannot depend upon the [TypeArg]s,
    /// this will generally only be applicable if the [OpDef] has no [TypeParam]s.
    FixedHugr {
        /// The extensions required by the [`Hugr`]
        extensions: ExtensionSet,
        /// The [`Hugr`] to be used to replace [ExtensionOp]s matching the parent
        /// [OpDef]
        ///
        /// [ExtensionOp]: crate::ops::ExtensionOp
        // Serialized through the `AsStringEnvelope` serde_with adapter.
        #[serde_as(as = "Box<AsStringEnvelope>")]
        hugr: Box<Hugr>,
    },
    /// Custom binary function that can (fallibly) compute a Hugr
    /// for the particular instance and set of available extensions.
    ///
    /// Skipped by serde, so this variant cannot be (de)serialized.
    #[serde(skip)]
    CustomFunc(Box<dyn CustomLowerFunc>),
}
292
293impl Debug for LowerFunc {
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        match self {
296            Self::FixedHugr { .. } => write!(f, "FixedHugr"),
297            Self::CustomFunc(_) => write!(f, "<custom lower>"),
298        }
299    }
300}
301
/// Serializable definition for dynamically loaded operations.
///
/// TODO: Define a way to construct new `OpDef`'s from a serialized definition.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct OpDef {
    /// The unique Extension owning this `OpDef` (of which this `OpDef` is a member)
    extension: ExtensionId,
    /// A weak reference to the extension defining this operation.
    ///
    /// Skipped by serde, so a freshly deserialized `OpDef` holds `Weak::default()`
    /// (a dangling reference) here; `extension_mut` allows it to be replaced.
    #[serde(skip)]
    extension_ref: Weak<Extension>,
    /// Unique identifier of the operation. Used to look up `OpDefs` in the registry
    /// when deserializing nodes (which store only the name).
    name: OpName,
    /// Human readable description of the operation.
    description: String,
    /// Miscellaneous data associated with the operation.
    ///
    /// Omitted from the serialized form when empty, and defaults to empty when absent.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    misc: HashMap<String, serde_json::Value>,

    /// How the signature of each operation node is computed. (De)serialized via the
    /// `serialize_signature_func` module and flattened into this struct.
    #[serde(with = "serialize_signature_func", flatten)]
    signature_func: SignatureFunc,
    /// Ways of lowering this operation, tried in order by `OpDef::try_lower`.
    ///
    /// Some operations cannot lower themselves and tools that do not understand them
    /// can only treat them as opaque/black-box ops.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub(crate) lower_funcs: Vec<LowerFunc>,

    /// Operations can optionally implement [`ConstFold`] to implement constant folding.
    ///
    /// Not serialized.
    #[serde(skip)]
    constant_folder: Option<Box<dyn ConstFold>>,
}
332
333impl OpDef {
334    /// Check provided type arguments are valid against their extensions,
335    /// against parameters, and that no type variables are used as static arguments
336    /// (to [`compute_signature`][CustomSignatureFunc::compute_signature])
337    pub fn validate_args(
338        &self,
339        args: &[TypeArg],
340        var_decls: &[TypeParam],
341    ) -> Result<(), SignatureError> {
342        let temp: PolyFuncTypeRV; // to keep alive
343        let (pf, args) = match &self.signature_func {
344            SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args),
345            SignatureFunc::PolyFuncType(ts) => (ts, args),
346            SignatureFunc::CustomFunc(custom) => {
347                let (static_args, other_args) =
348                    args.split_at(min(custom.static_params().len(), args.len()));
349                static_args.iter().try_for_each(|ta| ta.validate(&[]))?;
350                check_term_types(static_args, custom.static_params())?;
351                temp = custom.compute_signature(static_args, self)?;
352                (&temp, other_args)
353            }
354            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
355            SignatureFunc::MissingValidateFunc(_) => {
356                return Err(SignatureError::MissingValidateFunc);
357            }
358        };
359        args.iter().try_for_each(|ta| ta.validate(var_decls))?;
360        check_term_types(args, pf.params())?;
361        Ok(())
362    }
363
364    /// Computes the signature of a node, i.e. an instantiation of this
365    /// `OpDef` with statically-provided [`TypeArg`]s.
366    pub fn compute_signature(&self, args: &[TypeArg]) -> Result<Signature, SignatureError> {
367        self.signature_func.compute_signature(self, args)
368    }
369
370    /// Fallibly returns a Hugr that may replace an instance of this `OpDef`
371    /// given a set of available extensions that may be used in the Hugr.
372    #[must_use]
373    pub fn try_lower(&self, args: &[TypeArg], available_extensions: &ExtensionSet) -> Option<Hugr> {
374        // TODO test this
375        self.lower_funcs
376            .iter()
377            .filter_map(|f| match f {
378                LowerFunc::FixedHugr { extensions, hugr } => {
379                    if available_extensions.is_superset(extensions) {
380                        Some(hugr.as_ref().clone())
381                    } else {
382                        None
383                    }
384                }
385                LowerFunc::CustomFunc(f) => {
386                    f.try_lower(&self.name, args, &self.misc, available_extensions)
387                }
388            })
389            .next()
390    }
391
392    /// Returns a reference to the name of this [`OpDef`].
393    #[must_use]
394    pub fn name(&self) -> &OpName {
395        &self.name
396    }
397
398    /// Returns a reference to the extension id of this [`OpDef`].
399    #[must_use]
400    pub fn extension_id(&self) -> &ExtensionId {
401        &self.extension
402    }
403
404    /// Returns a weak reference to the extension defining this operation.
405    #[must_use]
406    pub fn extension(&self) -> Weak<Extension> {
407        self.extension_ref.clone()
408    }
409
410    /// Returns a mutable reference to the weak extension pointer in the operation definition.
411    pub(super) fn extension_mut(&mut self) -> &mut Weak<Extension> {
412        &mut self.extension_ref
413    }
414
415    /// Returns a reference to the description of this [`OpDef`].
416    #[must_use]
417    pub fn description(&self) -> &str {
418        self.description.as_ref()
419    }
420
421    /// Returns a reference to the params of this [`OpDef`].
422    pub fn params(&self) -> Result<&[TypeParam], SignatureError> {
423        self.signature_func.static_params()
424    }
425
426    pub(super) fn validate(&self) -> Result<(), SignatureError> {
427        // TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
428        // for both type scheme and custom binary
429        if let SignatureFunc::CustomValidator(ts) = &self.signature_func {
430            // The type scheme may contain row variables so be of variable length;
431            // these will have to be substituted to fixed-length concrete types when
432            // the OpDef is instantiated into an actual OpType.
433            ts.poly_func.validate()?;
434        }
435        Ok(())
436    }
437
438    /// Add a lowering function to the [`OpDef`]
439    pub fn add_lower_func(&mut self, lower: LowerFunc) {
440        self.lower_funcs.push(lower);
441    }
442
443    /// Insert miscellaneous data `v` to the [`OpDef`], keyed by `k`.
444    pub fn add_misc(
445        &mut self,
446        k: impl ToString,
447        v: serde_json::Value,
448    ) -> Option<serde_json::Value> {
449        self.misc.insert(k.to_string(), v)
450    }
451
452    /// Iterate over all miscellaneous data in the [`OpDef`].
453    #[allow(unused)] // Unused when no features are enabled
454    pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &serde_json::Value)> {
455        self.misc.iter().map(|(k, v)| (k.as_str(), v))
456    }
457
458    /// Set the constant folding function for this Op, which can evaluate it
459    /// given constant inputs.
460    pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
461        self.constant_folder = Some(Box::new(fold));
462    }
463
464    /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
465    /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s.
466    #[must_use]
467    pub fn constant_fold(
468        &self,
469        type_args: &[TypeArg],
470        consts: &[(crate::IncomingPort, crate::ops::Value)],
471    ) -> ConstFoldResult {
472        (self.constant_folder.as_ref())?.fold(type_args, consts)
473    }
474
475    /// Returns a reference to the signature function of this [`OpDef`].
476    #[must_use]
477    pub fn signature_func(&self) -> &SignatureFunc {
478        &self.signature_func
479    }
480
481    /// Returns a mutable reference to the signature function of this [`OpDef`].
482    pub(super) fn signature_func_mut(&mut self) -> &mut SignatureFunc {
483        &mut self.signature_func
484    }
485}
486
487impl Extension {
488    /// Add an operation definition to the extension. Must be a type scheme
489    /// (defined by a [`PolyFuncTypeRV`]), a type scheme along with binary
490    /// validation for type arguments ([`CustomValidator`]), or a custom binary
491    /// function for computing the signature given type arguments (implementing
492    /// `[CustomSignatureFunc]`).
493    ///
494    /// This method requires a [`Weak`] reference to the [`Arc`] containing the
495    /// extension being defined. The intended way to call this method is inside
496    /// the closure passed to [`Extension::new_arc`] when defining the extension.
497    ///
498    /// # Example
499    ///
500    /// ```
501    /// # use hugr_core::types::Signature;
502    /// # use hugr_core::extension::{Extension, ExtensionId, Version};
503    /// Extension::new_arc(
504    ///     ExtensionId::new_unchecked("my.extension"),
505    ///     Version::new(0, 1, 0),
506    ///     |ext, extension_ref| {
507    ///         ext.add_op(
508    ///             "MyOp".into(),
509    ///             "Some operation".into(),
510    ///             Signature::new_endo(vec![]),
511    ///             extension_ref,
512    ///         );
513    ///     },
514    /// );
515    /// ```
516    pub fn add_op(
517        &mut self,
518        name: OpName,
519        description: String,
520        signature_func: impl Into<SignatureFunc>,
521        extension_ref: &Weak<Extension>,
522    ) -> Result<&mut OpDef, ExtensionBuildError> {
523        let op = OpDef {
524            extension: self.name.clone(),
525            extension_ref: extension_ref.clone(),
526            name,
527            description,
528            signature_func: signature_func.into(),
529            misc: Default::default(),
530            lower_funcs: Default::default(),
531            constant_folder: Default::default(),
532        };
533
534        match self.operations.entry(op.name.clone()) {
535            Entry::Occupied(_) => Err(ExtensionBuildError::OpDefExists(op.name)),
536            // Just made the arc so should only be one reference to it, can get_mut,
537            Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()),
538        }
539    }
540}
541
542#[cfg(test)]
543pub(super) mod test {
544    use std::num::NonZeroU64;
545
546    use itertools::Itertools;
547
548    use super::SignatureFromArgs;
549    use crate::builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig};
550    use crate::extension::SignatureError;
551    use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc};
552    use crate::extension::prelude::usize_t;
553    use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
554    use crate::ops::OpName;
555    use crate::std_extensions::collections::list;
556    use crate::types::type_param::{TermTypeError, TypeParam};
557    use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
558    use crate::{Extension, const_extension_ids};
559
560    const_extension_ids! {
561        const EXT_ID: ExtensionId = "MyExt";
562    }
563
564    /// A dummy wrapper over an operation definition.
565    #[derive(serde::Serialize, serde::Deserialize, Debug)]
566    pub struct SimpleOpDef(OpDef);
567
568    impl SimpleOpDef {
569        /// Create a new dummy opdef.
570        #[must_use]
571        pub fn new(op_def: OpDef) -> Self {
572            assert!(op_def.constant_folder.is_none());
573            assert!(matches!(
574                op_def.signature_func,
575                SignatureFunc::PolyFuncType(_)
576            ));
577            assert!(
578                op_def
579                    .lower_funcs
580                    .iter()
581                    .all(|lf| matches!(lf, LowerFunc::FixedHugr { .. }))
582            );
583            Self(op_def)
584        }
585    }
586
587    impl From<SimpleOpDef> for OpDef {
588        fn from(value: SimpleOpDef) -> Self {
589            value.0
590        }
591    }
592
593    impl PartialEq for SimpleOpDef {
594        fn eq(&self, other: &Self) -> bool {
595            let OpDef {
596                extension,
597                extension_ref: _,
598                name,
599                description,
600                misc,
601                signature_func,
602                lower_funcs,
603                constant_folder: _,
604            } = &self.0;
605            let OpDef {
606                extension: other_extension,
607                extension_ref: _,
608                name: other_name,
609                description: other_description,
610                misc: other_misc,
611                signature_func: other_signature_func,
612                lower_funcs: other_lower_funcs,
613                constant_folder: _,
614            } = &other.0;
615
616            let get_sig = |sf: &_| match sf {
617                // if SignatureFunc or CustomValidator are changed we should get
618                // a compile error here. To fix: modify the fields matched on here,
619                // maintaining the lack of `..` and, for each part that is
620                // serializable, ensure we are checking it for equality below.
621                SignatureFunc::CustomValidator(CustomValidator {
622                    poly_func,
623                    validate: _,
624                })
625                | SignatureFunc::PolyFuncType(poly_func)
626                | SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()),
627                SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
628            };
629
630            let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
631                lfs.iter()
632                    .map(|lf| match lf {
633                        // as with get_sig above, this should break if the hierarchy
634                        // is changed, update similarly.
635                        LowerFunc::FixedHugr { extensions, hugr } => {
636                            Some((extensions.clone(), hugr.clone()))
637                        }
638                        // This is ruled out by `new()` but leave it here for later.
639                        LowerFunc::CustomFunc(_) => None,
640                    })
641                    .collect_vec()
642            };
643
644            extension == other_extension
645                && name == other_name
646                && description == other_description
647                && misc == other_misc
648                && get_sig(signature_func) == get_sig(other_signature_func)
649                && get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs)
650        }
651    }
652
653    #[test]
654    fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
655        let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap();
656        const OP_NAME: OpName = OpName::new_inline("Reverse");
657
658        let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
659            const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
660            let list_of_var =
661                Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
662            let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var]));
663
664            let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?;
665            def.add_lower_func(LowerFunc::FixedHugr {
666                extensions: ExtensionSet::new(),
667                hugr: Box::new(crate::builder::test::simple_dfg_hugr()), // this is nonsense, but we are not testing the actual lowering here
668            });
669            def.add_misc("key", Default::default());
670            assert_eq!(def.description(), "desc");
671            assert_eq!(def.lower_funcs.len(), 1);
672            assert_eq!(def.misc.len(), 1);
673
674            Ok(())
675        })?;
676
677        let reg = ExtensionRegistry::new([PRELUDE.clone(), list::EXTENSION.clone(), ext]);
678        reg.validate()?;
679        let e = reg.get(&EXT_ID).unwrap();
680
681        let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?);
682        let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
683        let rev = dfg.add_dataflow_op(
684            e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()])
685                .unwrap(),
686            dfg.input_wires(),
687        )?;
688        dfg.finish_hugr_with_outputs(rev.outputs())?;
689
690        Ok(())
691    }
692
693    #[test]
694    fn binary_polyfunc() -> Result<(), Box<dyn std::error::Error>> {
695        // Test a custom binary `compute_signature` that returns a PolyFuncTypeRV
696        // where the latter declares more type params itself. In particular,
697        // we should be able to substitute (external) type variables into the latter,
698        // but not pass them into the former (custom binary function).
699        struct SigFun();
700        impl SignatureFromArgs for SigFun {
701            fn compute_signature(
702                &self,
703                arg_values: &[TypeArg],
704            ) -> Result<PolyFuncTypeRV, SignatureError> {
705                const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
706                let [TypeArg::BoundedNat(n)] = arg_values else {
707                    return Err(SignatureError::InvalidTypeArgs);
708                };
709                let n = *n as usize;
710                let tvs: Vec<Type> = (0..n)
711                    .map(|_| Type::new_var_use(0, TypeBound::Linear))
712                    .collect();
713                Ok(PolyFuncTypeRV::new(
714                    vec![TP.clone()],
715                    Signature::new(tvs.clone(), vec![Type::new_tuple(tvs)]),
716                ))
717            }
718
719            fn static_params(&self) -> &[TypeParam] {
720                const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()];
721                MAX_NAT
722            }
723        }
724        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
725            let def: &mut crate::extension::OpDef =
726                ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?;
727
728            // Base case, no type variables:
729            let args = [TypeArg::BoundedNat(3), usize_t().into()];
730            assert_eq!(
731                def.compute_signature(&args),
732                Ok(Signature::new(
733                    vec![usize_t(); 3],
734                    vec![Type::new_tuple(vec![usize_t(); 3])]
735                ))
736            );
737            assert_eq!(def.validate_args(&args, &[]), Ok(()));
738
739            // Second arg may be a variable (substitutable)
740            let tyvar = Type::new_var_use(0, TypeBound::Copyable);
741            let tyvars: Vec<Type> = vec![tyvar.clone(); 3];
742            let args = [TypeArg::BoundedNat(3), tyvar.clone().into()];
743            assert_eq!(
744                def.compute_signature(&args),
745                Ok(Signature::new(
746                    tyvars.clone(),
747                    vec![Type::new_tuple(tyvars)]
748                ))
749            );
750            def.validate_args(&args, &[TypeBound::Copyable.into()])
751                .unwrap();
752
753            // quick sanity check that we are validating the args - note changed bound:
754            assert_eq!(
755                def.validate_args(&args, &[TypeBound::Linear.into()]),
756                Err(SignatureError::TypeVarDoesNotMatchDeclaration {
757                    actual: Box::new(TypeBound::Linear.into()),
758                    cached: Box::new(TypeBound::Copyable.into())
759                })
760            );
761
762            // First arg must be concrete, not a variable
763            let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap());
764            let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()];
765            // We can't prevent this from getting into our compute_signature implementation:
766            assert_eq!(
767                def.compute_signature(&args),
768                Err(SignatureError::InvalidTypeArgs)
769            );
770            // But validation rules it out, even when the variable is declared:
771            assert_eq!(
772                def.validate_args(&args, &[kind]),
773                Err(SignatureError::FreeTypeVar {
774                    idx: 0,
775                    num_decls: 0
776                })
777            );
778
779            Ok(())
780        })?;
781
782        Ok(())
783    }
784
785    #[test]
786    fn type_scheme_instantiate_var() -> Result<(), Box<dyn std::error::Error>> {
787        // Check that we can instantiate a PolyFuncTypeRV-scheme with an (external)
788        // type variable
789        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
790            let def = ext.add_op(
791                "SimpleOp".into(),
792                String::new(),
793                PolyFuncTypeRV::new(
794                    vec![TypeBound::Linear.into()],
795                    Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]),
796                ),
797                extension_ref,
798            )?;
799            let tv = Type::new_var_use(0, TypeBound::Copyable);
800            let args = [tv.clone().into()];
801            let decls = [TypeBound::Copyable.into()];
802            def.validate_args(&args, &decls).unwrap();
803            assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv)));
804            // But not with an external row variable
805            let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into();
806            assert_eq!(
807                def.compute_signature(&[arg.clone()]),
808                Err(SignatureError::TypeArgMismatch(
809                    TermTypeError::TypeMismatch {
810                        type_: Box::new(TypeBound::Linear.into()),
811                        term: Box::new(arg),
812                    }
813                ))
814            );
815            Ok(())
816        })?;
817        Ok(())
818    }
819
820    mod proptest {
821        use std::sync::Weak;
822
823        use super::SimpleOpDef;
824        use ::proptest::prelude::*;
825
826        use crate::{
827            builder::test::simple_dfg_hugr,
828            extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc, op_def::LowerFunc},
829            types::PolyFuncTypeRV,
830        };
831
832        impl Arbitrary for SignatureFunc {
833            type Parameters = ();
834            type Strategy = BoxedStrategy<Self>;
835            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
836                // TODO there is also  SignatureFunc::CustomFunc, but for now
837                // this is not serialized. When it is, we should generate
838                // examples here .
839                any::<PolyFuncTypeRV>()
840                    .prop_map(SignatureFunc::PolyFuncType)
841                    .boxed()
842            }
843        }
844
845        impl Arbitrary for LowerFunc {
846            type Parameters = ();
847            type Strategy = BoxedStrategy<Self>;
848            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
849                // TODO There is also LowerFunc::CustomFunc, but for now this is
850                // not serialized. When it is, we should generate examples here.
851                any::<ExtensionSet>()
852                    .prop_map(|extensions| LowerFunc::FixedHugr {
853                        extensions,
854                        hugr: Box::new(simple_dfg_hugr()),
855                    })
856                    .boxed()
857            }
858        }
859
860        impl Arbitrary for SimpleOpDef {
861            type Parameters = ();
862            type Strategy = BoxedStrategy<Self>;
863            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
864                use crate::proptest::{any_serde_json_value, any_smolstr, any_string};
865                use proptest::collection::{hash_map, vec};
866                let misc = hash_map(any_string(), any_serde_json_value(), 0..3);
867                (
868                    any::<ExtensionId>(),
869                    any_smolstr(),
870                    any_string(),
871                    misc,
872                    any::<SignatureFunc>(),
873                    vec(any::<LowerFunc>(), 0..2),
874                )
875                    .prop_map(
876                        |(extension, name, description, misc, signature_func, lower_funcs)| {
877                            Self::new(OpDef {
878                                extension,
879                                // Use a dead weak reference. Trying to access the extension will always return None.
880                                extension_ref: Weak::default(),
881                                name,
882                                description,
883                                misc,
884                                signature_func,
885                                lower_funcs,
886                                // TODO ``constant_folder` is not serialized, we should
887                                // generate examples once it is.
888                                constant_folder: None,
889                            })
890                        },
891                    )
892                    .boxed()
893            }
894        }
895    }
896}