hugr_core/extension/
op_def.rs

1use std::cmp::min;
2use std::collections::HashMap;
3use std::collections::btree_map::Entry;
4use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, Weak};
6
7use serde_with::serde_as;
8
9use super::{
10    ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
11    SignatureError,
12};
13
14use crate::Hugr;
15use crate::envelope::serde_with::AsBinaryEnvelope;
16use crate::ops::{OpName, OpNameRef};
17use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
18use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
19mod serialize_signature_func;
20
21/// Trait necessary for binary computations of `OpDef` signature
22pub trait CustomSignatureFunc: Send + Sync {
23    /// Compute signature of node given
24    /// values for the type parameters,
25    /// the operation definition and the extension registry.
26    fn compute_signature<'o, 'a: 'o>(
27        &'a self,
28        arg_values: &[TypeArg],
29        def: &'o OpDef,
30    ) -> Result<PolyFuncTypeRV, SignatureError>;
31    /// The declared type parameters which require values in order for signature to
32    /// be computed.
33    fn static_params(&self) -> &[TypeParam];
34}
35
36/// Compute signature of `OpDef` given type arguments.
37pub trait SignatureFromArgs: Send + Sync {
38    /// Compute signature of node given
39    /// values for the type parameters.
40    fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError>;
41    /// The declared type parameters which require values in order for signature to
42    /// be computed.
43    fn static_params(&self) -> &[TypeParam];
44}
45
46impl<T: SignatureFromArgs> CustomSignatureFunc for T {
47    #[inline]
48    fn compute_signature<'o, 'a: 'o>(
49        &'a self,
50        arg_values: &[TypeArg],
51        _def: &'o OpDef,
52    ) -> Result<PolyFuncTypeRV, SignatureError> {
53        SignatureFromArgs::compute_signature(self, arg_values)
54    }
55
56    #[inline]
57    fn static_params(&self) -> &[TypeParam] {
58        SignatureFromArgs::static_params(self)
59    }
60}
61
62/// Trait for validating type arguments to a `PolyFuncTypeRV` beyond conformation to
63/// declared type parameter (which should have been checked beforehand).
64pub trait ValidateTypeArgs: Send + Sync {
65    /// Validate the type arguments of node given
66    /// values for the type parameters,
67    /// the operation definition and the extension registry.
68    fn validate<'o, 'a: 'o>(
69        &self,
70        arg_values: &[TypeArg],
71        def: &'o OpDef,
72    ) -> Result<(), SignatureError>;
73}
74
75/// Trait for validating type arguments to a `PolyFuncTypeRV` beyond conformation to
76/// declared type parameter (which should have been checked beforehand), given just the arguments.
77pub trait ValidateJustArgs: Send + Sync {
78    /// Validate the type arguments of node given
79    /// values for the type parameters.
80    fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError>;
81}
82
83impl<T: ValidateJustArgs> ValidateTypeArgs for T {
84    #[inline]
85    fn validate<'o, 'a: 'o>(
86        &self,
87        arg_values: &[TypeArg],
88        _def: &'o OpDef,
89    ) -> Result<(), SignatureError> {
90        ValidateJustArgs::validate(self, arg_values)
91    }
92}
93
94/// Trait for Extensions to provide custom binary code that can lower an operation to
95/// a Hugr using only a limited set of other extensions. That is, trait
96/// implementations can return a Hugr that implements the operation using only
97/// those extensions and that can be used to replace the operation node. This may be
98/// useful for third-party Extensions or as a fallback for tools that do not support
99/// the operation natively.
100///
101/// This trait allows the Hugr to be varied according to the operation's [`TypeArg`]s;
102/// if this is not necessary then a single Hugr can be provided instead via
103/// [`LowerFunc::FixedHugr`].
104pub trait CustomLowerFunc: Send + Sync {
105    /// Return a Hugr that implements the node using only the specified available extensions;
106    /// may fail.
107    /// TODO: some error type to indicate Extensions required?
108    fn try_lower(
109        &self,
110        name: &OpNameRef,
111        arg_values: &[TypeArg],
112        misc: &HashMap<String, serde_json::Value>,
113        available_extensions: &ExtensionSet,
114    ) -> Option<Hugr>;
115}
116
117/// Encode a signature as [`PolyFuncTypeRV`] but with additional validation of type
118/// arguments via a custom binary. The binary cannot be serialized so will be
119/// lost over a serialization round-trip.
120pub struct CustomValidator {
121    poly_func: PolyFuncTypeRV,
122    /// Custom function for validating type arguments before returning the signature.
123    pub(crate) validate: Box<dyn ValidateTypeArgs>,
124}
125
126impl CustomValidator {
127    /// Encode a signature using a `PolyFuncTypeRV`, with a custom function for
128    /// validating type arguments before returning the signature.
129    pub fn new(
130        poly_func: impl Into<PolyFuncTypeRV>,
131        validate: impl ValidateTypeArgs + 'static,
132    ) -> Self {
133        Self {
134            poly_func: poly_func.into(),
135            validate: Box::new(validate),
136        }
137    }
138
139    /// Return a mutable reference to the `PolyFuncType`.
140    pub(super) fn poly_func_mut(&mut self) -> &mut PolyFuncTypeRV {
141        &mut self.poly_func
142    }
143}
144
145/// The ways in which an `OpDef` may compute the Signature of each operation node.
146pub enum SignatureFunc {
147    /// An explicit polymorphic function type.
148    PolyFuncType(PolyFuncTypeRV),
149    /// A polymorphic function type (like [`Self::PolyFuncType`] but also with a custom binary for validating type arguments.
150    CustomValidator(CustomValidator),
151    /// Serialized declaration specified a custom validate binary but it was not provided.
152    MissingValidateFunc(PolyFuncTypeRV),
153    /// A custom binary which computes a polymorphic function type given values
154    /// for its static type parameters.
155    CustomFunc(Box<dyn CustomSignatureFunc>),
156    /// Serialized declaration specified a custom compute binary but it was not provided.
157    MissingComputeFunc,
158}
159
160impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {
161    fn from(v: T) -> Self {
162        Self::CustomFunc(Box::new(v))
163    }
164}
165
166impl From<PolyFuncType> for SignatureFunc {
167    fn from(value: PolyFuncType) -> Self {
168        Self::PolyFuncType(value.into())
169    }
170}
171
172impl From<PolyFuncTypeRV> for SignatureFunc {
173    fn from(v: PolyFuncTypeRV) -> Self {
174        Self::PolyFuncType(v)
175    }
176}
177
178impl From<FuncValueType> for SignatureFunc {
179    fn from(v: FuncValueType) -> Self {
180        Self::PolyFuncType(v.into())
181    }
182}
183
184impl From<Signature> for SignatureFunc {
185    fn from(v: Signature) -> Self {
186        Self::PolyFuncType(FuncValueType::from(v).into())
187    }
188}
189
190impl From<CustomValidator> for SignatureFunc {
191    fn from(v: CustomValidator) -> Self {
192        Self::CustomValidator(v)
193    }
194}
195
196impl SignatureFunc {
197    fn static_params(&self) -> Result<&[TypeParam], SignatureError> {
198        Ok(match self {
199            SignatureFunc::PolyFuncType(ts)
200            | SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. })
201            | SignatureFunc::MissingValidateFunc(ts) => ts.params(),
202            SignatureFunc::CustomFunc(func) => func.static_params(),
203            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
204        })
205    }
206
207    /// If the signature is missing a custom validation function, ignore and treat as
208    /// self-contained type scheme (with no custom validation).
209    pub fn ignore_missing_validation(&mut self) {
210        if let SignatureFunc::MissingValidateFunc(ts) = self {
211            *self = SignatureFunc::PolyFuncType(ts.clone());
212        }
213    }
214
215    /// Compute the concrete signature ([`FuncValueType`]).
216    ///
217    /// # Panics
218    ///
219    /// Panics if `self` is a [`SignatureFunc::CustomFunc`] and there are not enough type
220    /// arguments provided to match the number of static parameters.
221    ///
222    /// # Errors
223    ///
224    /// This function will return an error if the type arguments are invalid or
225    /// there is some error in type computation.
226    pub fn compute_signature(
227        &self,
228        def: &OpDef,
229        args: &[TypeArg],
230    ) -> Result<Signature, SignatureError> {
231        let temp: PolyFuncTypeRV; // to keep alive
232        let (pf, args) = match &self {
233            SignatureFunc::CustomValidator(custom) => {
234                custom.validate.validate(args, def)?;
235                (&custom.poly_func, args)
236            }
237            SignatureFunc::PolyFuncType(ts) => (ts, args),
238            SignatureFunc::CustomFunc(func) => {
239                let static_params = func.static_params();
240                let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
241
242                check_term_types(static_args, static_params)?;
243                temp = func.compute_signature(static_args, def)?;
244                (&temp, other_args)
245            }
246            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
247            // TODO raise warning: https://github.com/CQCL/hugr/issues/1432
248            SignatureFunc::MissingValidateFunc(ts) => (ts, args),
249        };
250        let res = pf.instantiate(args)?;
251
252        // If there are any row variables left, this will fail with an error:
253        res.try_into()
254    }
255}
256
257impl Debug for SignatureFunc {
258    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
259        match self {
260            Self::CustomValidator(ts) => ts.poly_func.fmt(f),
261            Self::PolyFuncType(ts) => ts.fmt(f),
262            Self::CustomFunc { .. } => f.write_str("<custom sig>"),
263            Self::MissingComputeFunc => f.write_str("<missing custom sig>"),
264            Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"),
265        }
266    }
267}
268
269/// Different ways that an [OpDef] can lower operation nodes i.e. provide a Hugr
270/// that implements the operation using a set of other extensions.
271///
272/// Does not implement [`serde::Deserialize`] directly since the serde error for
273/// untagged enums is unhelpful. Use [`deserialize_lower_funcs`] with
274/// [`serde(deserialize_with = "deserialize_lower_funcs")] instead.
275#[serde_as]
276#[derive(serde::Serialize)]
277#[serde(untagged)]
278pub enum LowerFunc {
279    /// Lowering to a fixed Hugr. Since this cannot depend upon the [TypeArg]s,
280    /// this will generally only be applicable if the [OpDef] has no [TypeParam]s.
281    FixedHugr {
282        /// The extensions required by the [`Hugr`]
283        extensions: ExtensionSet,
284        /// The [`Hugr`] to be used to replace [ExtensionOp]s matching the parent
285        /// [OpDef]
286        ///
287        /// [ExtensionOp]: crate::ops::ExtensionOp
288        #[serde_as(as = "Box<AsBinaryEnvelope>")]
289        hugr: Box<Hugr>,
290    },
291    /// Custom binary function that can (fallibly) compute a Hugr
292    /// for the particular instance and set of available extensions.
293    #[serde(skip)]
294    CustomFunc(Box<dyn CustomLowerFunc>),
295}
296
297/// A function for deserializing sequences of [`LowerFunc::FixedHugr`].
298///
299/// We could let serde deserialize [`LowerFunc`] as-is, but if the LowerFunc
300/// deserialization fails it just returns an opaque "data did not match any
301/// variant of untagged enum LowerFunc" error. This function will return the
302/// internal errors instead.
303pub fn deserialize_lower_funcs<'de, D>(deserializer: D) -> Result<Vec<LowerFunc>, D::Error>
304where
305    D: serde::Deserializer<'de>,
306{
307    #[serde_as]
308    #[derive(serde::Deserialize)]
309    struct FixedHugrDeserializer {
310        pub extensions: ExtensionSet,
311        #[serde_as(as = "Box<AsBinaryEnvelope>")]
312        pub hugr: Box<Hugr>,
313    }
314
315    let funcs: Vec<FixedHugrDeserializer> = serde::Deserialize::deserialize(deserializer)?;
316    Ok(funcs
317        .into_iter()
318        .map(|f| LowerFunc::FixedHugr {
319            extensions: f.extensions,
320            hugr: f.hugr,
321        })
322        .collect())
323}
324
325impl Debug for LowerFunc {
326    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        match self {
328            Self::FixedHugr { .. } => write!(f, "FixedHugr"),
329            Self::CustomFunc(_) => write!(f, "<custom lower>"),
330        }
331    }
332}
333
334/// Serializable definition for dynamically loaded operations.
335///
336/// TODO: Define a way to construct new `OpDef`'s from a serialized definition.
337#[derive(Debug, serde::Serialize, serde::Deserialize)]
338pub struct OpDef {
339    /// The unique Extension owning this `OpDef` (of which this `OpDef` is a member)
340    extension: ExtensionId,
341    /// A weak reference to the extension defining this operation.
342    #[serde(skip)]
343    extension_ref: Weak<Extension>,
344    /// Unique identifier of the operation. Used to look up `OpDefs` in the registry
345    /// when deserializing nodes (which store only the name).
346    name: OpName,
347    /// Human readable description of the operation.
348    description: String,
349    /// Miscellaneous data associated with the operation.
350    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
351    misc: HashMap<String, serde_json::Value>,
352
353    #[serde(with = "serialize_signature_func", flatten)]
354    signature_func: SignatureFunc,
355    // Some operations cannot lower themselves and tools that do not understand them
356    // can only treat them as opaque/black-box ops.
357    #[serde(
358        default,
359        skip_serializing_if = "Vec::is_empty",
360        deserialize_with = "deserialize_lower_funcs"
361    )]
362    pub(crate) lower_funcs: Vec<LowerFunc>,
363
364    /// Operations can optionally implement [`ConstFold`] to implement constant folding.
365    #[serde(skip)]
366    constant_folder: Option<Box<dyn ConstFold>>,
367}
368
369impl OpDef {
370    /// Check provided type arguments are valid against their extensions,
371    /// against parameters, and that no type variables are used as static arguments
372    /// (to [`compute_signature`][CustomSignatureFunc::compute_signature])
373    pub fn validate_args(
374        &self,
375        args: &[TypeArg],
376        var_decls: &[TypeParam],
377    ) -> Result<(), SignatureError> {
378        let temp: PolyFuncTypeRV; // to keep alive
379        let (pf, args) = match &self.signature_func {
380            SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args),
381            SignatureFunc::PolyFuncType(ts) => (ts, args),
382            SignatureFunc::CustomFunc(custom) => {
383                let (static_args, other_args) =
384                    args.split_at(min(custom.static_params().len(), args.len()));
385                static_args.iter().try_for_each(|ta| ta.validate(&[]))?;
386                check_term_types(static_args, custom.static_params())?;
387                temp = custom.compute_signature(static_args, self)?;
388                (&temp, other_args)
389            }
390            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
391            SignatureFunc::MissingValidateFunc(_) => {
392                return Err(SignatureError::MissingValidateFunc);
393            }
394        };
395        args.iter().try_for_each(|ta| ta.validate(var_decls))?;
396        check_term_types(args, pf.params())?;
397        Ok(())
398    }
399
400    /// Computes the signature of a node, i.e. an instantiation of this
401    /// `OpDef` with statically-provided [`TypeArg`]s.
402    pub fn compute_signature(&self, args: &[TypeArg]) -> Result<Signature, SignatureError> {
403        self.signature_func.compute_signature(self, args)
404    }
405
406    /// Fallibly returns a Hugr that may replace an instance of this `OpDef`
407    /// given a set of available extensions that may be used in the Hugr.
408    #[must_use]
409    pub fn try_lower(&self, args: &[TypeArg], available_extensions: &ExtensionSet) -> Option<Hugr> {
410        // TODO test this
411        self.lower_funcs
412            .iter()
413            .filter_map(|f| match f {
414                LowerFunc::FixedHugr { extensions, hugr } => {
415                    if available_extensions.is_superset(extensions) {
416                        Some(hugr.as_ref().clone())
417                    } else {
418                        None
419                    }
420                }
421                LowerFunc::CustomFunc(f) => {
422                    f.try_lower(&self.name, args, &self.misc, available_extensions)
423                }
424            })
425            .next()
426    }
427
428    /// Returns a reference to the name of this [`OpDef`].
429    #[must_use]
430    pub fn name(&self) -> &OpName {
431        &self.name
432    }
433
434    /// Returns a reference to the extension id of this [`OpDef`].
435    #[must_use]
436    pub fn extension_id(&self) -> &ExtensionId {
437        &self.extension
438    }
439
440    /// Returns a weak reference to the extension defining this operation.
441    #[must_use]
442    pub fn extension(&self) -> Weak<Extension> {
443        self.extension_ref.clone()
444    }
445
446    /// Returns a mutable reference to the weak extension pointer in the operation definition.
447    pub(super) fn extension_mut(&mut self) -> &mut Weak<Extension> {
448        &mut self.extension_ref
449    }
450
451    /// Returns a reference to the description of this [`OpDef`].
452    #[must_use]
453    pub fn description(&self) -> &str {
454        self.description.as_ref()
455    }
456
457    /// Returns a reference to the params of this [`OpDef`].
458    pub fn params(&self) -> Result<&[TypeParam], SignatureError> {
459        self.signature_func.static_params()
460    }
461
462    pub(super) fn validate(&self) -> Result<(), SignatureError> {
463        // TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
464        // for both type scheme and custom binary
465        if let SignatureFunc::CustomValidator(ts) = &self.signature_func {
466            // The type scheme may contain row variables so be of variable length;
467            // these will have to be substituted to fixed-length concrete types when
468            // the OpDef is instantiated into an actual OpType.
469            ts.poly_func.validate()?;
470        }
471        Ok(())
472    }
473
474    /// Add a lowering function to the [`OpDef`]
475    pub fn add_lower_func(&mut self, lower: LowerFunc) {
476        self.lower_funcs.push(lower);
477    }
478
479    /// Insert miscellaneous data `v` to the [`OpDef`], keyed by `k`.
480    pub fn add_misc(
481        &mut self,
482        k: impl ToString,
483        v: serde_json::Value,
484    ) -> Option<serde_json::Value> {
485        self.misc.insert(k.to_string(), v)
486    }
487
488    /// Iterate over all miscellaneous data in the [`OpDef`].
489    #[allow(unused)] // Unused when no features are enabled
490    pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &serde_json::Value)> {
491        self.misc.iter().map(|(k, v)| (k.as_str(), v))
492    }
493
494    /// Set the constant folding function for this Op, which can evaluate it
495    /// given constant inputs.
496    pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
497        self.constant_folder = Some(Box::new(fold));
498    }
499
500    /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
501    /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s.
502    #[must_use]
503    pub fn constant_fold(
504        &self,
505        type_args: &[TypeArg],
506        consts: &[(crate::IncomingPort, crate::ops::Value)],
507    ) -> ConstFoldResult {
508        (self.constant_folder.as_ref())?.fold(type_args, consts)
509    }
510
511    /// Returns a reference to the signature function of this [`OpDef`].
512    #[must_use]
513    pub fn signature_func(&self) -> &SignatureFunc {
514        &self.signature_func
515    }
516
517    /// Returns a mutable reference to the signature function of this [`OpDef`].
518    pub(super) fn signature_func_mut(&mut self) -> &mut SignatureFunc {
519        &mut self.signature_func
520    }
521}
522
523impl Extension {
524    /// Add an operation definition to the extension. Must be a type scheme
525    /// (defined by a [`PolyFuncTypeRV`]), a type scheme along with binary
526    /// validation for type arguments ([`CustomValidator`]), or a custom binary
527    /// function for computing the signature given type arguments (implementing
528    /// `[CustomSignatureFunc]`).
529    ///
530    /// This method requires a [`Weak`] reference to the [`Arc`] containing the
531    /// extension being defined. The intended way to call this method is inside
532    /// the closure passed to [`Extension::new_arc`] when defining the extension.
533    ///
534    /// # Example
535    ///
536    /// ```
537    /// # use hugr_core::types::Signature;
538    /// # use hugr_core::extension::{Extension, ExtensionId, Version};
539    /// Extension::new_arc(
540    ///     ExtensionId::new_unchecked("my.extension"),
541    ///     Version::new(0, 1, 0),
542    ///     |ext, extension_ref| {
543    ///         ext.add_op(
544    ///             "MyOp".into(),
545    ///             "Some operation".into(),
546    ///             Signature::new_endo(vec![]),
547    ///             extension_ref,
548    ///         );
549    ///     },
550    /// );
551    /// ```
552    pub fn add_op(
553        &mut self,
554        name: OpName,
555        description: String,
556        signature_func: impl Into<SignatureFunc>,
557        extension_ref: &Weak<Extension>,
558    ) -> Result<&mut OpDef, ExtensionBuildError> {
559        let op = OpDef {
560            extension: self.name.clone(),
561            extension_ref: extension_ref.clone(),
562            name,
563            description,
564            signature_func: signature_func.into(),
565            misc: Default::default(),
566            lower_funcs: Default::default(),
567            constant_folder: Default::default(),
568        };
569
570        match self.operations.entry(op.name.clone()) {
571            Entry::Occupied(_) => Err(ExtensionBuildError::OpDefExists(op.name)),
572            // Just made the arc so should only be one reference to it, can get_mut,
573            Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()),
574        }
575    }
576}
577
578#[cfg(test)]
579pub(super) mod test {
580    use std::num::NonZeroU64;
581
582    use itertools::Itertools;
583
584    use super::SignatureFromArgs;
585    use crate::builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig};
586    use crate::extension::SignatureError;
587    use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc};
588    use crate::extension::prelude::usize_t;
589    use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
590    use crate::ops::OpName;
591    use crate::std_extensions::collections::list;
592    use crate::types::type_param::{TermTypeError, TypeParam};
593    use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
594    use crate::{Extension, const_extension_ids};
595
596    const_extension_ids! {
597        const EXT_ID: ExtensionId = "MyExt";
598    }
599
600    /// A dummy wrapper over an operation definition.
601    #[derive(serde::Serialize, serde::Deserialize, Debug)]
602    pub struct SimpleOpDef(OpDef);
603
604    impl SimpleOpDef {
605        /// Create a new dummy opdef.
606        #[must_use]
607        pub fn new(op_def: OpDef) -> Self {
608            assert!(op_def.constant_folder.is_none());
609            assert!(matches!(
610                op_def.signature_func,
611                SignatureFunc::PolyFuncType(_)
612            ));
613            assert!(
614                op_def
615                    .lower_funcs
616                    .iter()
617                    .all(|lf| matches!(lf, LowerFunc::FixedHugr { .. }))
618            );
619            Self(op_def)
620        }
621    }
622
623    impl From<SimpleOpDef> for OpDef {
624        fn from(value: SimpleOpDef) -> Self {
625            value.0
626        }
627    }
628
629    impl PartialEq for SimpleOpDef {
630        fn eq(&self, other: &Self) -> bool {
631            let OpDef {
632                extension,
633                extension_ref: _,
634                name,
635                description,
636                misc,
637                signature_func,
638                lower_funcs,
639                constant_folder: _,
640            } = &self.0;
641            let OpDef {
642                extension: other_extension,
643                extension_ref: _,
644                name: other_name,
645                description: other_description,
646                misc: other_misc,
647                signature_func: other_signature_func,
648                lower_funcs: other_lower_funcs,
649                constant_folder: _,
650            } = &other.0;
651
652            let get_sig = |sf: &_| match sf {
653                // if SignatureFunc or CustomValidator are changed we should get
654                // a compile error here. To fix: modify the fields matched on here,
655                // maintaining the lack of `..` and, for each part that is
656                // serializable, ensure we are checking it for equality below.
657                SignatureFunc::CustomValidator(CustomValidator {
658                    poly_func,
659                    validate: _,
660                })
661                | SignatureFunc::PolyFuncType(poly_func)
662                | SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()),
663                SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
664            };
665
666            let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
667                lfs.iter()
668                    .map(|lf| match lf {
669                        // as with get_sig above, this should break if the hierarchy
670                        // is changed, update similarly.
671                        LowerFunc::FixedHugr { extensions, hugr } => {
672                            Some((extensions.clone(), hugr.clone()))
673                        }
674                        // This is ruled out by `new()` but leave it here for later.
675                        LowerFunc::CustomFunc(_) => None,
676                    })
677                    .collect_vec()
678            };
679
680            extension == other_extension
681                && name == other_name
682                && description == other_description
683                && misc == other_misc
684                && get_sig(signature_func) == get_sig(other_signature_func)
685                && get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs)
686        }
687    }
688
689    #[test]
690    fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
691        let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap();
692        const OP_NAME: OpName = OpName::new_inline("Reverse");
693
694        let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
695            const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
696            let list_of_var =
697                Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
698            let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var]));
699
700            let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?;
701            def.add_lower_func(LowerFunc::FixedHugr {
702                extensions: ExtensionSet::new(),
703                hugr: Box::new(crate::builder::test::simple_dfg_hugr()), // this is nonsense, but we are not testing the actual lowering here
704            });
705            def.add_misc("key", Default::default());
706            assert_eq!(def.description(), "desc");
707            assert_eq!(def.lower_funcs.len(), 1);
708            assert_eq!(def.misc.len(), 1);
709
710            Ok(())
711        })?;
712
713        let reg = ExtensionRegistry::new([PRELUDE.clone(), list::EXTENSION.clone(), ext]);
714        reg.validate()?;
715        let e = reg.get(&EXT_ID).unwrap();
716
717        let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?);
718        let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
719        let rev = dfg.add_dataflow_op(
720            e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()])
721                .unwrap(),
722            dfg.input_wires(),
723        )?;
724        dfg.finish_hugr_with_outputs(rev.outputs())?;
725
726        Ok(())
727    }
728
729    #[test]
730    fn binary_polyfunc() -> Result<(), Box<dyn std::error::Error>> {
731        // Test a custom binary `compute_signature` that returns a PolyFuncTypeRV
732        // where the latter declares more type params itself. In particular,
733        // we should be able to substitute (external) type variables into the latter,
734        // but not pass them into the former (custom binary function).
735        struct SigFun();
736        impl SignatureFromArgs for SigFun {
737            fn compute_signature(
738                &self,
739                arg_values: &[TypeArg],
740            ) -> Result<PolyFuncTypeRV, SignatureError> {
741                const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
742                let [TypeArg::BoundedNat(n)] = arg_values else {
743                    return Err(SignatureError::InvalidTypeArgs);
744                };
745                let n = *n as usize;
746                let tvs: Vec<Type> = (0..n)
747                    .map(|_| Type::new_var_use(0, TypeBound::Linear))
748                    .collect();
749                Ok(PolyFuncTypeRV::new(
750                    vec![TP.clone()],
751                    Signature::new(tvs.clone(), vec![Type::new_tuple(tvs)]),
752                ))
753            }
754
755            fn static_params(&self) -> &[TypeParam] {
756                const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()];
757                MAX_NAT
758            }
759        }
760        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
761            let def: &mut crate::extension::OpDef =
762                ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?;
763
764            // Base case, no type variables:
765            let args = [TypeArg::BoundedNat(3), usize_t().into()];
766            assert_eq!(
767                def.compute_signature(&args),
768                Ok(Signature::new(
769                    vec![usize_t(); 3],
770                    vec![Type::new_tuple(vec![usize_t(); 3])]
771                ))
772            );
773            assert_eq!(def.validate_args(&args, &[]), Ok(()));
774
775            // Second arg may be a variable (substitutable)
776            let tyvar = Type::new_var_use(0, TypeBound::Copyable);
777            let tyvars: Vec<Type> = vec![tyvar.clone(); 3];
778            let args = [TypeArg::BoundedNat(3), tyvar.clone().into()];
779            assert_eq!(
780                def.compute_signature(&args),
781                Ok(Signature::new(
782                    tyvars.clone(),
783                    vec![Type::new_tuple(tyvars)]
784                ))
785            );
786            def.validate_args(&args, &[TypeBound::Copyable.into()])
787                .unwrap();
788
789            // quick sanity check that we are validating the args - note changed bound:
790            assert_eq!(
791                def.validate_args(&args, &[TypeBound::Linear.into()]),
792                Err(SignatureError::TypeVarDoesNotMatchDeclaration {
793                    actual: Box::new(TypeBound::Linear.into()),
794                    cached: Box::new(TypeBound::Copyable.into())
795                })
796            );
797
798            // First arg must be concrete, not a variable
799            let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap());
800            let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()];
801            // We can't prevent this from getting into our compute_signature implementation:
802            assert_eq!(
803                def.compute_signature(&args),
804                Err(SignatureError::InvalidTypeArgs)
805            );
806            // But validation rules it out, even when the variable is declared:
807            assert_eq!(
808                def.validate_args(&args, &[kind]),
809                Err(SignatureError::FreeTypeVar {
810                    idx: 0,
811                    num_decls: 0
812                })
813            );
814
815            Ok(())
816        })?;
817
818        Ok(())
819    }
820
821    #[test]
822    fn type_scheme_instantiate_var() -> Result<(), Box<dyn std::error::Error>> {
823        // Check that we can instantiate a PolyFuncTypeRV-scheme with an (external)
824        // type variable
825        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
826            let def = ext.add_op(
827                "SimpleOp".into(),
828                String::new(),
829                PolyFuncTypeRV::new(
830                    vec![TypeBound::Linear.into()],
831                    Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Linear)]),
832                ),
833                extension_ref,
834            )?;
835            let tv = Type::new_var_use(0, TypeBound::Copyable);
836            let args = [tv.clone().into()];
837            let decls = [TypeBound::Copyable.into()];
838            def.validate_args(&args, &decls).unwrap();
839            assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv)));
840            // But not with an external row variable
841            let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into();
842            assert_eq!(
843                def.compute_signature(std::slice::from_ref(&arg)),
844                Err(SignatureError::TypeArgMismatch(
845                    TermTypeError::TypeMismatch {
846                        type_: Box::new(TypeBound::Linear.into()),
847                        term: Box::new(arg),
848                    }
849                ))
850            );
851            Ok(())
852        })?;
853        Ok(())
854    }
855
856    mod proptest {
857        use std::sync::Weak;
858
859        use super::SimpleOpDef;
860        use ::proptest::prelude::*;
861
862        use crate::{
863            builder::test::simple_dfg_hugr,
864            extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc, op_def::LowerFunc},
865            types::PolyFuncTypeRV,
866        };
867
868        impl Arbitrary for SignatureFunc {
869            type Parameters = ();
870            type Strategy = BoxedStrategy<Self>;
871            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
872                // TODO there is also  SignatureFunc::CustomFunc, but for now
873                // this is not serialized. When it is, we should generate
874                // examples here .
875                any::<PolyFuncTypeRV>()
876                    .prop_map(SignatureFunc::PolyFuncType)
877                    .boxed()
878            }
879        }
880
881        impl Arbitrary for LowerFunc {
882            type Parameters = ();
883            type Strategy = BoxedStrategy<Self>;
884            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
885                // TODO There is also LowerFunc::CustomFunc, but for now this is
886                // not serialized. When it is, we should generate examples here.
887                any::<ExtensionSet>()
888                    .prop_map(|extensions| LowerFunc::FixedHugr {
889                        extensions,
890                        hugr: Box::new(simple_dfg_hugr()),
891                    })
892                    .boxed()
893            }
894        }
895
896        impl Arbitrary for SimpleOpDef {
897            type Parameters = ();
898            type Strategy = BoxedStrategy<Self>;
899            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
900                use crate::proptest::{any_serde_json_value, any_smolstr, any_string};
901                use proptest::collection::{hash_map, vec};
902                let misc = hash_map(any_string(), any_serde_json_value(), 0..3);
903                (
904                    any::<ExtensionId>(),
905                    any_smolstr(),
906                    any_string(),
907                    misc,
908                    any::<SignatureFunc>(),
909                    vec(any::<LowerFunc>(), 0..2),
910                )
911                    .prop_map(
912                        |(extension, name, description, misc, signature_func, lower_funcs)| {
913                            Self::new(OpDef {
914                                extension,
915                                // Use a dead weak reference. Trying to access the extension will always return None.
916                                extension_ref: Weak::default(),
917                                name,
918                                description,
919                                misc,
920                                signature_func,
921                                lower_funcs,
922                                // TODO ``constant_folder` is not serialized, we should
923                                // generate examples once it is.
924                                constant_folder: None,
925                            })
926                        },
927                    )
928                    .boxed()
929            }
930        }
931    }
932}