hugr_core/extension/
op_def.rs

1use std::cmp::min;
2use std::collections::btree_map::Entry;
3use std::collections::HashMap;
4use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, Weak};
6
7use super::{
8    ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
9    SignatureError,
10};
11
12use crate::ops::{OpName, OpNameRef};
13use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
14use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
15use crate::Hugr;
16mod serialize_signature_func;
17
18/// Trait necessary for binary computations of OpDef signature
19pub trait CustomSignatureFunc: Send + Sync {
20    /// Compute signature of node given
21    /// values for the type parameters,
22    /// the operation definition and the extension registry.
23    fn compute_signature<'o, 'a: 'o>(
24        &'a self,
25        arg_values: &[TypeArg],
26        def: &'o OpDef,
27    ) -> Result<PolyFuncTypeRV, SignatureError>;
28    /// The declared type parameters which require values in order for signature to
29    /// be computed.
30    fn static_params(&self) -> &[TypeParam];
31}
32
33/// Compute signature of `OpDef` given type arguments.
34pub trait SignatureFromArgs: Send + Sync {
35    /// Compute signature of node given
36    /// values for the type parameters.
37    fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError>;
38    /// The declared type parameters which require values in order for signature to
39    /// be computed.
40    fn static_params(&self) -> &[TypeParam];
41}
42
43impl<T: SignatureFromArgs> CustomSignatureFunc for T {
44    #[inline]
45    fn compute_signature<'o, 'a: 'o>(
46        &'a self,
47        arg_values: &[TypeArg],
48        _def: &'o OpDef,
49    ) -> Result<PolyFuncTypeRV, SignatureError> {
50        SignatureFromArgs::compute_signature(self, arg_values)
51    }
52
53    #[inline]
54    fn static_params(&self) -> &[TypeParam] {
55        SignatureFromArgs::static_params(self)
56    }
57}
58
59/// Trait for validating type arguments to a PolyFuncTypeRV beyond conformation to
60/// declared type parameter (which should have been checked beforehand).
61pub trait ValidateTypeArgs: Send + Sync {
62    /// Validate the type arguments of node given
63    /// values for the type parameters,
64    /// the operation definition and the extension registry.
65    fn validate<'o, 'a: 'o>(
66        &self,
67        arg_values: &[TypeArg],
68        def: &'o OpDef,
69    ) -> Result<(), SignatureError>;
70}
71
72/// Trait for validating type arguments to a PolyFuncTypeRV beyond conformation to
73/// declared type parameter (which should have been checked beforehand), given just the arguments.
74pub trait ValidateJustArgs: Send + Sync {
75    /// Validate the type arguments of node given
76    /// values for the type parameters.
77    fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError>;
78}
79
80impl<T: ValidateJustArgs> ValidateTypeArgs for T {
81    #[inline]
82    fn validate<'o, 'a: 'o>(
83        &self,
84        arg_values: &[TypeArg],
85        _def: &'o OpDef,
86    ) -> Result<(), SignatureError> {
87        ValidateJustArgs::validate(self, arg_values)
88    }
89}
90
91/// Trait for Extensions to provide custom binary code that can lower an operation to
92/// a Hugr using only a limited set of other extensions. That is, trait
93/// implementations can return a Hugr that implements the operation using only
94/// those extensions and that can be used to replace the operation node. This may be
95/// useful for third-party Extensions or as a fallback for tools that do not support
96/// the operation natively.
97///
98/// This trait allows the Hugr to be varied according to the operation's [TypeArg]s;
99/// if this is not necessary then a single Hugr can be provided instead via
100/// [LowerFunc::FixedHugr].
101pub trait CustomLowerFunc: Send + Sync {
102    /// Return a Hugr that implements the node using only the specified available extensions;
103    /// may fail.
104    /// TODO: some error type to indicate Extensions required?
105    fn try_lower(
106        &self,
107        name: &OpNameRef,
108        arg_values: &[TypeArg],
109        misc: &HashMap<String, serde_json::Value>,
110        available_extensions: &ExtensionSet,
111    ) -> Option<Hugr>;
112}
113
114/// Encode a signature as [PolyFuncTypeRV] but with additional validation of type
115/// arguments via a custom binary. The binary cannot be serialized so will be
116/// lost over a serialization round-trip.
117pub struct CustomValidator {
118    poly_func: PolyFuncTypeRV,
119    /// Custom function for validating type arguments before returning the signature.
120    pub(crate) validate: Box<dyn ValidateTypeArgs>,
121}
122
123impl CustomValidator {
124    /// Encode a signature using a `PolyFuncTypeRV`, with a custom function for
125    /// validating type arguments before returning the signature.
126    pub fn new(
127        poly_func: impl Into<PolyFuncTypeRV>,
128        validate: impl ValidateTypeArgs + 'static,
129    ) -> Self {
130        Self {
131            poly_func: poly_func.into(),
132            validate: Box::new(validate),
133        }
134    }
135
136    /// Return a mutable reference to the PolyFuncType.
137    pub(super) fn poly_func_mut(&mut self) -> &mut PolyFuncTypeRV {
138        &mut self.poly_func
139    }
140}
141
142/// The ways in which an OpDef may compute the Signature of each operation node.
143pub enum SignatureFunc {
144    /// An explicit polymorphic function type.
145    PolyFuncType(PolyFuncTypeRV),
146    /// A polymorphic function type (like [Self::PolyFuncType] but also with a custom binary for validating type arguments.
147    CustomValidator(CustomValidator),
148    /// Serialized declaration specified a custom validate binary but it was not provided.
149    MissingValidateFunc(PolyFuncTypeRV),
150    /// A custom binary which computes a polymorphic function type given values
151    /// for its static type parameters.
152    CustomFunc(Box<dyn CustomSignatureFunc>),
153    /// Serialized declaration specified a custom compute binary but it was not provided.
154    MissingComputeFunc,
155}
156
157impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {
158    fn from(v: T) -> Self {
159        Self::CustomFunc(Box::new(v))
160    }
161}
162
163impl From<PolyFuncType> for SignatureFunc {
164    fn from(value: PolyFuncType) -> Self {
165        Self::PolyFuncType(value.into())
166    }
167}
168
169impl From<PolyFuncTypeRV> for SignatureFunc {
170    fn from(v: PolyFuncTypeRV) -> Self {
171        Self::PolyFuncType(v)
172    }
173}
174
175impl From<FuncValueType> for SignatureFunc {
176    fn from(v: FuncValueType) -> Self {
177        Self::PolyFuncType(v.into())
178    }
179}
180
181impl From<Signature> for SignatureFunc {
182    fn from(v: Signature) -> Self {
183        Self::PolyFuncType(FuncValueType::from(v).into())
184    }
185}
186
187impl From<CustomValidator> for SignatureFunc {
188    fn from(v: CustomValidator) -> Self {
189        Self::CustomValidator(v)
190    }
191}
192
193impl SignatureFunc {
194    fn static_params(&self) -> Result<&[TypeParam], SignatureError> {
195        Ok(match self {
196            SignatureFunc::PolyFuncType(ts)
197            | SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. })
198            | SignatureFunc::MissingValidateFunc(ts) => ts.params(),
199            SignatureFunc::CustomFunc(func) => func.static_params(),
200            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
201        })
202    }
203
204    /// If the signature is missing a custom validation function, ignore and treat as
205    /// self-contained type scheme (with no custom validation).
206    pub fn ignore_missing_validation(&mut self) {
207        if let SignatureFunc::MissingValidateFunc(ts) = self {
208            *self = SignatureFunc::PolyFuncType(ts.clone());
209        }
210    }
211
212    /// Compute the concrete signature ([FuncValueType]).
213    ///
214    /// # Panics
215    ///
216    /// Panics if `self` is a [SignatureFunc::CustomFunc] and there are not enough type
217    /// arguments provided to match the number of static parameters.
218    ///
219    /// # Errors
220    ///
221    /// This function will return an error if the type arguments are invalid or
222    /// there is some error in type computation.
223    pub fn compute_signature(
224        &self,
225        def: &OpDef,
226        args: &[TypeArg],
227    ) -> Result<Signature, SignatureError> {
228        let temp: PolyFuncTypeRV; // to keep alive
229        let (pf, args) = match &self {
230            SignatureFunc::CustomValidator(custom) => {
231                custom.validate.validate(args, def)?;
232                (&custom.poly_func, args)
233            }
234            SignatureFunc::PolyFuncType(ts) => (ts, args),
235            SignatureFunc::CustomFunc(func) => {
236                let static_params = func.static_params();
237                let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
238
239                check_type_args(static_args, static_params)?;
240                temp = func.compute_signature(static_args, def)?;
241                (&temp, other_args)
242            }
243            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
244            // TODO raise warning: https://github.com/CQCL/hugr/issues/1432
245            SignatureFunc::MissingValidateFunc(ts) => (ts, args),
246        };
247        let mut res = pf.instantiate(args)?;
248
249        // Automatically add the extensions where the operation is defined to
250        // the runtime requirements of the op.
251        res.runtime_reqs.insert(def.extension.clone());
252
253        // If there are any row variables left, this will fail with an error:
254        res.try_into()
255    }
256}
257
258impl Debug for SignatureFunc {
259    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
260        match self {
261            Self::CustomValidator(ts) => ts.poly_func.fmt(f),
262            Self::PolyFuncType(ts) => ts.fmt(f),
263            Self::CustomFunc { .. } => f.write_str("<custom sig>"),
264            Self::MissingComputeFunc => f.write_str("<missing custom sig>"),
265            Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"),
266        }
267    }
268}
269
270/// Different ways that an [OpDef] can lower operation nodes i.e. provide a Hugr
271/// that implements the operation using a set of other extensions.
272#[derive(serde::Deserialize, serde::Serialize)]
273#[serde(untagged)]
274pub enum LowerFunc {
275    /// Lowering to a fixed Hugr. Since this cannot depend upon the [TypeArg]s,
276    /// this will generally only be applicable if the [OpDef] has no [TypeParam]s.
277    FixedHugr {
278        /// The extensions required by the [`Hugr`]
279        extensions: ExtensionSet,
280        /// The [`Hugr`] to be used to replace [ExtensionOp]s matching the parent
281        /// [OpDef]
282        ///
283        /// [ExtensionOp]: crate::ops::ExtensionOp
284        hugr: Hugr,
285    },
286    /// Custom binary function that can (fallibly) compute a Hugr
287    /// for the particular instance and set of available extensions.
288    #[serde(skip)]
289    CustomFunc(Box<dyn CustomLowerFunc>),
290}
291
292impl Debug for LowerFunc {
293    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294        match self {
295            Self::FixedHugr { .. } => write!(f, "FixedHugr"),
296            Self::CustomFunc(_) => write!(f, "<custom lower>"),
297        }
298    }
299}
300
301/// Serializable definition for dynamically loaded operations.
302///
303/// TODO: Define a way to construct new OpDef's from a serialized definition.
304#[derive(Debug, serde::Serialize, serde::Deserialize)]
305pub struct OpDef {
306    /// The unique Extension owning this OpDef (of which this OpDef is a member)
307    extension: ExtensionId,
308    /// A weak reference to the extension defining this operation.
309    #[serde(skip)]
310    extension_ref: Weak<Extension>,
311    /// Unique identifier of the operation. Used to look up OpDefs in the registry
312    /// when deserializing nodes (which store only the name).
313    name: OpName,
314    /// Human readable description of the operation.
315    description: String,
316    /// Miscellaneous data associated with the operation.
317    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
318    misc: HashMap<String, serde_json::Value>,
319
320    #[serde(with = "serialize_signature_func", flatten)]
321    signature_func: SignatureFunc,
322    // Some operations cannot lower themselves and tools that do not understand them
323    // can only treat them as opaque/black-box ops.
324    #[serde(default, skip_serializing_if = "Vec::is_empty")]
325    pub(crate) lower_funcs: Vec<LowerFunc>,
326
327    /// Operations can optionally implement [`ConstFold`] to implement constant folding.
328    #[serde(skip)]
329    constant_folder: Option<Box<dyn ConstFold>>,
330}
331
332impl OpDef {
333    /// Check provided type arguments are valid against their extensions,
334    /// against parameters, and that no type variables are used as static arguments
335    /// (to [compute_signature][CustomSignatureFunc::compute_signature])
336    pub fn validate_args(
337        &self,
338        args: &[TypeArg],
339        var_decls: &[TypeParam],
340    ) -> Result<(), SignatureError> {
341        let temp: PolyFuncTypeRV; // to keep alive
342        let (pf, args) = match &self.signature_func {
343            SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args),
344            SignatureFunc::PolyFuncType(ts) => (ts, args),
345            SignatureFunc::CustomFunc(custom) => {
346                let (static_args, other_args) =
347                    args.split_at(min(custom.static_params().len(), args.len()));
348                static_args.iter().try_for_each(|ta| ta.validate(&[]))?;
349                check_type_args(static_args, custom.static_params())?;
350                temp = custom.compute_signature(static_args, self)?;
351                (&temp, other_args)
352            }
353            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
354            SignatureFunc::MissingValidateFunc(_) => {
355                return Err(SignatureError::MissingValidateFunc)
356            }
357        };
358        args.iter().try_for_each(|ta| ta.validate(var_decls))?;
359        check_type_args(args, pf.params())?;
360        Ok(())
361    }
362
363    /// Computes the signature of a node, i.e. an instantiation of this
364    /// OpDef with statically-provided [TypeArg]s.
365    pub fn compute_signature(&self, args: &[TypeArg]) -> Result<Signature, SignatureError> {
366        self.signature_func.compute_signature(self, args)
367    }
368
369    /// Fallibly returns a Hugr that may replace an instance of this OpDef
370    /// given a set of available extensions that may be used in the Hugr.
371    pub fn try_lower(&self, args: &[TypeArg], available_extensions: &ExtensionSet) -> Option<Hugr> {
372        // TODO test this
373        self.lower_funcs
374            .iter()
375            .flat_map(|f| match f {
376                LowerFunc::FixedHugr { extensions, hugr } => {
377                    if available_extensions.is_superset(extensions) {
378                        Some(hugr.clone())
379                    } else {
380                        None
381                    }
382                }
383                LowerFunc::CustomFunc(f) => {
384                    f.try_lower(&self.name, args, &self.misc, available_extensions)
385                }
386            })
387            .next()
388    }
389
390    /// Returns a reference to the name of this [`OpDef`].
391    pub fn name(&self) -> &OpName {
392        &self.name
393    }
394
395    /// Returns a reference to the extension id of this [`OpDef`].
396    pub fn extension_id(&self) -> &ExtensionId {
397        &self.extension
398    }
399
400    /// Returns a weak reference to the extension defining this operation.
401    pub fn extension(&self) -> Weak<Extension> {
402        self.extension_ref.clone()
403    }
404
405    /// Returns a mutable reference to the weak extension pointer in the operation definition.
406    pub(super) fn extension_mut(&mut self) -> &mut Weak<Extension> {
407        &mut self.extension_ref
408    }
409
410    /// Returns a reference to the description of this [`OpDef`].
411    pub fn description(&self) -> &str {
412        self.description.as_ref()
413    }
414
415    /// Returns a reference to the params of this [`OpDef`].
416    pub fn params(&self) -> Result<&[TypeParam], SignatureError> {
417        self.signature_func.static_params()
418    }
419
420    pub(super) fn validate(&self) -> Result<(), SignatureError> {
421        // TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
422        // for both type scheme and custom binary
423        if let SignatureFunc::CustomValidator(ts) = &self.signature_func {
424            // The type scheme may contain row variables so be of variable length;
425            // these will have to be substituted to fixed-length concrete types when
426            // the OpDef is instantiated into an actual OpType.
427            ts.poly_func.validate()?;
428        }
429        Ok(())
430    }
431
432    /// Add a lowering function to the [OpDef]
433    pub fn add_lower_func(&mut self, lower: LowerFunc) {
434        self.lower_funcs.push(lower);
435    }
436
437    /// Insert miscellaneous data `v` to the [OpDef], keyed by `k`.
438    pub fn add_misc(
439        &mut self,
440        k: impl ToString,
441        v: serde_json::Value,
442    ) -> Option<serde_json::Value> {
443        self.misc.insert(k.to_string(), v)
444    }
445
446    /// Iterate over all miscellaneous data in the [OpDef].
447    #[allow(unused)] // Unused when no features are enabled
448    pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &serde_json::Value)> {
449        self.misc.iter().map(|(k, v)| (k.as_str(), v))
450    }
451
452    /// Set the constant folding function for this Op, which can evaluate it
453    /// given constant inputs.
454    pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
455        self.constant_folder = Some(Box::new(fold))
456    }
457
458    /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
459    /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s.
460    pub fn constant_fold(
461        &self,
462        type_args: &[TypeArg],
463        consts: &[(crate::IncomingPort, crate::ops::Value)],
464    ) -> ConstFoldResult {
465        (self.constant_folder.as_ref())?.fold(type_args, consts)
466    }
467
468    /// Returns a reference to the signature function of this [`OpDef`].
469    pub fn signature_func(&self) -> &SignatureFunc {
470        &self.signature_func
471    }
472
473    /// Returns a mutable reference to the signature function of this [`OpDef`].
474    pub(super) fn signature_func_mut(&mut self) -> &mut SignatureFunc {
475        &mut self.signature_func
476    }
477}
478
479impl Extension {
480    /// Add an operation definition to the extension. Must be a type scheme
481    /// (defined by a [`PolyFuncTypeRV`]), a type scheme along with binary
482    /// validation for type arguments ([`CustomValidator`]), or a custom binary
483    /// function for computing the signature given type arguments (implementing
484    /// `[CustomSignatureFunc]`).
485    ///
486    /// This method requires a [`Weak`] reference to the [`Arc`] containing the
487    /// extension being defined. The intended way to call this method is inside
488    /// the closure passed to [`Extension::new_arc`] when defining the extension.
489    ///
490    /// # Example
491    ///
492    /// ```
493    /// # use hugr_core::types::Signature;
494    /// # use hugr_core::extension::{Extension, ExtensionId, Version};
495    /// Extension::new_arc(
496    ///     ExtensionId::new_unchecked("my.extension"),
497    ///     Version::new(0, 1, 0),
498    ///     |ext, extension_ref| {
499    ///         ext.add_op(
500    ///             "MyOp".into(),
501    ///             "Some operation".into(),
502    ///             Signature::new_endo(vec![]),
503    ///             extension_ref,
504    ///         );
505    ///     },
506    /// );
507    /// ```
508    pub fn add_op(
509        &mut self,
510        name: OpName,
511        description: String,
512        signature_func: impl Into<SignatureFunc>,
513        extension_ref: &Weak<Extension>,
514    ) -> Result<&mut OpDef, ExtensionBuildError> {
515        let op = OpDef {
516            extension: self.name.clone(),
517            extension_ref: extension_ref.clone(),
518            name,
519            description,
520            signature_func: signature_func.into(),
521            misc: Default::default(),
522            lower_funcs: Default::default(),
523            constant_folder: Default::default(),
524        };
525
526        match self.operations.entry(op.name.clone()) {
527            Entry::Occupied(_) => Err(ExtensionBuildError::OpDefExists(op.name)),
528            // Just made the arc so should only be one reference to it, can get_mut,
529            Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()),
530        }
531    }
532}
533
534#[cfg(test)]
535pub(super) mod test {
536    use std::num::NonZeroU64;
537
538    use itertools::Itertools;
539
540    use super::SignatureFromArgs;
541    use crate::builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr};
542    use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc};
543    use crate::extension::prelude::usize_t;
544    use crate::extension::SignatureError;
545    use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
546    use crate::ops::OpName;
547    use crate::std_extensions::collections::list;
548    use crate::types::type_param::{TypeArgError, TypeParam};
549    use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
550    use crate::{const_extension_ids, Extension};
551
552    const_extension_ids! {
553        const EXT_ID: ExtensionId = "MyExt";
554    }
555
556    /// A dummy wrapper over an operation definition.
557    #[derive(serde::Serialize, serde::Deserialize, Debug)]
558    pub struct SimpleOpDef(OpDef);
559
560    impl SimpleOpDef {
561        /// Create a new dummy opdef.
562        pub fn new(op_def: OpDef) -> Self {
563            assert!(op_def.constant_folder.is_none());
564            assert!(matches!(
565                op_def.signature_func,
566                SignatureFunc::PolyFuncType(_)
567            ));
568            assert!(op_def
569                .lower_funcs
570                .iter()
571                .all(|lf| matches!(lf, LowerFunc::FixedHugr { .. })));
572            Self(op_def)
573        }
574    }
575
576    impl From<SimpleOpDef> for OpDef {
577        fn from(value: SimpleOpDef) -> Self {
578            value.0
579        }
580    }
581
582    impl PartialEq for SimpleOpDef {
583        fn eq(&self, other: &Self) -> bool {
584            let OpDef {
585                extension,
586                extension_ref: _,
587                name,
588                description,
589                misc,
590                signature_func,
591                lower_funcs,
592                constant_folder: _,
593            } = &self.0;
594            let OpDef {
595                extension: other_extension,
596                extension_ref: _,
597                name: other_name,
598                description: other_description,
599                misc: other_misc,
600                signature_func: other_signature_func,
601                lower_funcs: other_lower_funcs,
602                constant_folder: _,
603            } = &other.0;
604
605            let get_sig = |sf: &_| match sf {
606                // if SignatureFunc or CustomValidator are changed we should get
607                // a compile error here. To fix: modify the fields matched on here,
608                // maintaining the lack of `..` and, for each part that is
609                // serializable, ensure we are checking it for equality below.
610                SignatureFunc::CustomValidator(CustomValidator {
611                    poly_func,
612                    validate: _,
613                })
614                | SignatureFunc::PolyFuncType(poly_func)
615                | SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()),
616                SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
617            };
618
619            let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
620                lfs.iter()
621                    .map(|lf| match lf {
622                        // as with get_sig above, this should break if the hierarchy
623                        // is changed, update similarly.
624                        LowerFunc::FixedHugr { extensions, hugr } => {
625                            Some((extensions.clone(), hugr.clone()))
626                        }
627                        // This is ruled out by `new()` but leave it here for later.
628                        LowerFunc::CustomFunc(_) => None,
629                    })
630                    .collect_vec()
631            };
632
633            extension == other_extension
634                && name == other_name
635                && description == other_description
636                && misc == other_misc
637                && get_sig(signature_func) == get_sig(other_signature_func)
638                && get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs)
639        }
640    }
641
#[test]
fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
    // Define a polymorphic `Reverse: List<T> -> List<T>` op, then instantiate it
    // at `usize` inside a DFG.
    const OP_NAME: OpName = OpName::new_inline("Reverse");
    let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap();

    let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
        const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
        let list_of_var =
            Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
        let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var]));

        let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?;
        // The lowering Hugr is a placeholder; only the bookkeeping is under test.
        def.add_lower_func(LowerFunc::FixedHugr {
            extensions: ExtensionSet::new(),
            hugr: crate::builder::test::simple_dfg_hugr(),
        });
        def.add_misc("key", Default::default());
        assert_eq!(def.description(), "desc");
        assert_eq!(def.lower_funcs.len(), 1);
        assert_eq!(def.misc.len(), 1);

        Ok(())
    })?;

    let registry = ExtensionRegistry::new([PRELUDE.clone(), list::EXTENSION.clone(), ext]);
    registry.validate()?;
    let my_ext = registry.get(&EXT_ID).unwrap();

    let usize_arg = TypeArg::Type { ty: usize_t() };
    let list_usize = Type::new_extension(list_def.instantiate(vec![usize_arg.clone()])?);
    let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
    let reverse_op = my_ext
        .instantiate_extension_op(&OP_NAME, vec![usize_arg])
        .unwrap();
    let rev = dfg.add_dataflow_op(reverse_op, dfg.input_wires())?;
    dfg.finish_hugr_with_outputs(rev.outputs())?;

    Ok(())
}
682
#[test]
fn binary_polyfunc() -> Result<(), Box<dyn std::error::Error>> {
    // A custom binary `compute_signature` whose result is itself polymorphic
    // (declares its own type params). External type variables may be
    // substituted into that returned scheme, but must never reach the binary.
    struct SigFun();
    impl SignatureFromArgs for SigFun {
        fn compute_signature(
            &self,
            arg_values: &[TypeArg],
        ) -> Result<PolyFuncTypeRV, SignatureError> {
            const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
            // Expect exactly one bounded nat `n`: the op takes `n` values of the
            // scheme's type variable and returns them packed in a tuple.
            let [TypeArg::BoundedNat { n }] = arg_values else {
                return Err(SignatureError::InvalidTypeArgs);
            };
            let elems: Vec<Type> = vec![Type::new_var_use(0, TypeBound::Any); *n as usize];
            let outputs = vec![Type::new_tuple(elems.clone())];
            Ok(PolyFuncTypeRV::new(vec![TP], Signature::new(elems, outputs)))
        }

        fn static_params(&self) -> &[TypeParam] {
            const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat()];
            MAX_NAT
        }
    }

    let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
        let def: &mut crate::extension::OpDef =
            ext.add_op("MyOp".into(), "".to_string(), SigFun(), extension_ref)?;
        let three = TypeArg::BoundedNat { n: 3 };

        // Base case: no type variables anywhere.
        let args = [three.clone(), usize_t().into()];
        let expected = Signature::new(
            vec![usize_t(); 3],
            vec![Type::new_tuple(vec![usize_t(); 3])],
        )
        .with_extension_delta(EXT_ID);
        assert_eq!(def.compute_signature(&args), Ok(expected));
        assert_eq!(def.validate_args(&args, &[]), Ok(()));

        // The second argument may be a (substitutable) type variable.
        let tyvar = Type::new_var_use(0, TypeBound::Copyable);
        let tyvars: Vec<Type> = vec![tyvar.clone(); 3];
        let args = [three, tyvar.into()];
        let expected = Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)])
            .with_extension_delta(EXT_ID);
        assert_eq!(def.compute_signature(&args), Ok(expected));
        def.validate_args(&args, &[TypeBound::Copyable.into()])
            .unwrap();

        // Sanity check that the args really are validated: the declared bound differs.
        assert_eq!(
            def.validate_args(&args, &[TypeBound::Any.into()]),
            Err(SignatureError::TypeVarDoesNotMatchDeclaration {
                actual: TypeBound::Any.into(),
                cached: TypeBound::Copyable.into()
            })
        );

        // The first argument must be concrete, never a variable.
        let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap());
        let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()];
        // Nothing stops such a variable reaching our compute_signature...
        assert_eq!(
            def.compute_signature(&args),
            Err(SignatureError::InvalidTypeArgs)
        );
        // ...but validation rejects it, even when the variable is declared.
        assert_eq!(
            def.validate_args(&args, &[kind]),
            Err(SignatureError::FreeTypeVar {
                idx: 0,
                num_decls: 0
            })
        );

        Ok(())
    })?;

    Ok(())
}
775
776    #[test]
777    fn type_scheme_instantiate_var() -> Result<(), Box<dyn std::error::Error>> {
778        // Check that we can instantiate a PolyFuncTypeRV-scheme with an (external)
779        // type variable
780        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
781            let def = ext.add_op(
782                "SimpleOp".into(),
783                "".into(),
784                PolyFuncTypeRV::new(
785                    vec![TypeBound::Any.into()],
786                    Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]),
787                ),
788                extension_ref,
789            )?;
790            let tv = Type::new_var_use(1, TypeBound::Copyable);
791            let args = [TypeArg::Type { ty: tv.clone() }];
792            let decls = [TypeParam::Extensions, TypeBound::Copyable.into()];
793            def.validate_args(&args, &decls).unwrap();
794            assert_eq!(
795                def.compute_signature(&args),
796                Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID))
797            );
798            // But not with an external row variable
799            let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into();
800            assert_eq!(
801                def.compute_signature(&[arg.clone()]),
802                Err(SignatureError::TypeArgMismatch(
803                    TypeArgError::TypeMismatch {
804                        param: TypeBound::Any.into(),
805                        arg
806                    }
807                ))
808            );
809            Ok(())
810        })?;
811        Ok(())
812    }
813
814    #[test]
815    fn instantiate_extension_delta() -> Result<(), Box<dyn std::error::Error>> {
816        use crate::extension::prelude::bool_t;
817
818        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
819            let params: Vec<TypeParam> = vec![TypeParam::Extensions];
820            let db_set = ExtensionSet::type_var(0);
821            let fun_ty = Signature::new_endo(bool_t()).with_extension_delta(db_set);
822
823            let def = ext.add_op(
824                "SimpleOp".into(),
825                "".into(),
826                PolyFuncTypeRV::new(params.clone(), fun_ty),
827                extension_ref,
828            )?;
829
830            // Concrete extension set
831            let es = ExtensionSet::singleton(EXT_ID);
832            let exp_fun_ty = Signature::new_endo(bool_t()).with_extension_delta(es.clone());
833            let args = [TypeArg::Extensions { es }];
834
835            def.validate_args(&args, &params).unwrap();
836            assert_eq!(def.compute_signature(&args), Ok(exp_fun_ty));
837
838            Ok(())
839        })?;
840
841        Ok(())
842    }
843
844    mod proptest {
845        use std::sync::Weak;
846
847        use super::SimpleOpDef;
848        use ::proptest::prelude::*;
849
850        use crate::{
851            builder::test::simple_dfg_hugr,
852            extension::{op_def::LowerFunc, ExtensionId, ExtensionSet, OpDef, SignatureFunc},
853            types::PolyFuncTypeRV,
854        };
855
856        impl Arbitrary for SignatureFunc {
857            type Parameters = ();
858            type Strategy = BoxedStrategy<Self>;
859            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
860                // TODO there is also  SignatureFunc::CustomFunc, but for now
861                // this is not serialized. When it is, we should generate
862                // examples here .
863                any::<PolyFuncTypeRV>()
864                    .prop_map(SignatureFunc::PolyFuncType)
865                    .boxed()
866            }
867        }
868
869        impl Arbitrary for LowerFunc {
870            type Parameters = ();
871            type Strategy = BoxedStrategy<Self>;
872            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
873                // TODO There is also LowerFunc::CustomFunc, but for now this is
874                // not serialized. When it is, we should generate examples here.
875                any::<ExtensionSet>()
876                    .prop_map(|extensions| LowerFunc::FixedHugr {
877                        extensions,
878                        hugr: simple_dfg_hugr(),
879                    })
880                    .boxed()
881            }
882        }
883
884        impl Arbitrary for SimpleOpDef {
885            type Parameters = ();
886            type Strategy = BoxedStrategy<Self>;
887            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
888                use crate::proptest::{any_serde_json_value, any_smolstr, any_string};
889                use proptest::collection::{hash_map, vec};
890                let misc = hash_map(any_string(), any_serde_json_value(), 0..3);
891                (
892                    any::<ExtensionId>(),
893                    any_smolstr(),
894                    any_string(),
895                    misc,
896                    any::<SignatureFunc>(),
897                    vec(any::<LowerFunc>(), 0..2),
898                )
899                    .prop_map(
900                        |(extension, name, description, misc, signature_func, lower_funcs)| {
901                            Self::new(OpDef {
902                                extension,
903                                // Use a dead weak reference. Trying to access the extension will always return None.
904                                extension_ref: Weak::default(),
905                                name,
906                                description,
907                                misc,
908                                signature_func,
909                                lower_funcs,
910                                // TODO ``constant_folder` is not serialized, we should
911                                // generate examples once it is.
912                                constant_folder: None,
913                            })
914                        },
915                    )
916                    .boxed()
917            }
918        }
919    }
920}