Skip to main content

hugr_core/extension/
op_def.rs

1use std::cmp::min;
2use std::collections::HashMap;
3use std::collections::btree_map::Entry;
4use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, Weak};
6
7use serde_with::serde_as;
8
9use super::{
10    ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
11    SignatureError,
12};
13
14use crate::Hugr;
15use crate::envelope::serde_with::AsBinaryEnvelope;
16use crate::ops::{OpName, OpNameRef};
17use crate::package::Package;
18use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
19use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
20mod serialize_signature_func;
21
22/// Trait necessary for binary computations of `OpDef` signature
23pub trait CustomSignatureFunc: Send + Sync {
24    /// Compute signature of node given
25    /// values for the type parameters,
26    /// the operation definition and the extension registry.
27    fn compute_signature<'o, 'a: 'o>(
28        &'a self,
29        arg_values: &[TypeArg],
30        def: &'o OpDef,
31    ) -> Result<PolyFuncTypeRV, SignatureError>;
32    /// The declared type parameters which require values in order for signature to
33    /// be computed.
34    fn static_params(&self) -> &[TypeParam];
35}
36
37/// Compute signature of `OpDef` given type arguments.
38pub trait SignatureFromArgs: Send + Sync {
39    /// Compute signature of node given
40    /// values for the type parameters.
41    fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError>;
42    /// The declared type parameters which require values in order for signature to
43    /// be computed.
44    fn static_params(&self) -> &[TypeParam];
45}
46
47impl<T: SignatureFromArgs> CustomSignatureFunc for T {
48    #[inline]
49    fn compute_signature<'o, 'a: 'o>(
50        &'a self,
51        arg_values: &[TypeArg],
52        _def: &'o OpDef,
53    ) -> Result<PolyFuncTypeRV, SignatureError> {
54        SignatureFromArgs::compute_signature(self, arg_values)
55    }
56
57    #[inline]
58    fn static_params(&self) -> &[TypeParam] {
59        SignatureFromArgs::static_params(self)
60    }
61}
62
63/// Trait for validating type arguments to a `PolyFuncTypeRV` beyond conformation to
64/// declared type parameter (which should have been checked beforehand).
65pub trait ValidateTypeArgs: Send + Sync {
66    /// Validate the type arguments of node given
67    /// values for the type parameters,
68    /// the operation definition and the extension registry.
69    fn validate<'o, 'a: 'o>(
70        &self,
71        arg_values: &[TypeArg],
72        def: &'o OpDef,
73    ) -> Result<(), SignatureError>;
74}
75
76/// Trait for validating type arguments to a `PolyFuncTypeRV` beyond conformation to
77/// declared type parameter (which should have been checked beforehand), given just the arguments.
78pub trait ValidateJustArgs: Send + Sync {
79    /// Validate the type arguments of node given
80    /// values for the type parameters.
81    fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError>;
82}
83
84impl<T: ValidateJustArgs> ValidateTypeArgs for T {
85    #[inline]
86    fn validate<'o, 'a: 'o>(
87        &self,
88        arg_values: &[TypeArg],
89        _def: &'o OpDef,
90    ) -> Result<(), SignatureError> {
91        ValidateJustArgs::validate(self, arg_values)
92    }
93}
94
95/// Trait for Extensions to provide custom binary code that can lower an operation to
96/// a Hugr using only a limited set of other extensions. That is, trait
97/// implementations can return a Hugr that implements the operation using only
98/// those extensions and that can be used to replace the operation node. This may be
99/// useful for third-party Extensions or as a fallback for tools that do not support
100/// the operation natively.
101///
102/// This trait allows the Hugr to be varied according to the operation's [`TypeArg`]s;
103/// if this is not necessary then a single Hugr can be provided instead via
104/// [`LowerFunc::FixedHugr`].
105pub trait CustomLowerFunc: Send + Sync {
106    /// Return a Hugr that implements the node using only the specified available extensions;
107    /// may fail.
108    /// TODO: some error type to indicate Extensions required?
109    fn try_lower(
110        &self,
111        name: &OpNameRef,
112        arg_values: &[TypeArg],
113        misc: &HashMap<String, serde_json::Value>,
114        available_extensions: &ExtensionSet,
115    ) -> Option<Hugr>;
116}
117
118/// Encode a signature as [`PolyFuncTypeRV`] but with additional validation of type
119/// arguments via a custom binary. The binary cannot be serialized so will be
120/// lost over a serialization round-trip.
121pub struct CustomValidator {
122    poly_func: PolyFuncTypeRV,
123    /// Custom function for validating type arguments before returning the signature.
124    pub(crate) validate: Box<dyn ValidateTypeArgs>,
125}
126
127impl CustomValidator {
128    /// Encode a signature using a `PolyFuncTypeRV`, with a custom function for
129    /// validating type arguments before returning the signature.
130    pub fn new(
131        poly_func: impl Into<PolyFuncTypeRV>,
132        validate: impl ValidateTypeArgs + 'static,
133    ) -> Self {
134        Self {
135            poly_func: poly_func.into(),
136            validate: Box::new(validate),
137        }
138    }
139
140    /// Return a reference to the `PolyFuncTypeRV` used by this validator.
141    pub(crate) fn poly_func(&self) -> &PolyFuncTypeRV {
142        &self.poly_func
143    }
144
145    /// Return a mutable reference to the `PolyFuncType`.
146    pub(super) fn poly_func_mut(&mut self) -> &mut PolyFuncTypeRV {
147        &mut self.poly_func
148    }
149}
150
151/// The ways in which an `OpDef` may compute the Signature of each operation node.
152pub enum SignatureFunc {
153    /// An explicit polymorphic function type.
154    PolyFuncType(PolyFuncTypeRV),
155    /// A polymorphic function type (like [`Self::PolyFuncType`] but also with a custom binary for validating type arguments.
156    CustomValidator(CustomValidator),
157    /// Serialized declaration specified a custom validate binary but it was not provided.
158    MissingValidateFunc(PolyFuncTypeRV),
159    /// A custom binary which computes a polymorphic function type given values
160    /// for its static type parameters.
161    CustomFunc(Box<dyn CustomSignatureFunc>),
162    /// Serialized declaration specified a custom compute binary but it was not provided.
163    MissingComputeFunc,
164}
165
166impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {
167    fn from(v: T) -> Self {
168        Self::CustomFunc(Box::new(v))
169    }
170}
171
172impl From<PolyFuncType> for SignatureFunc {
173    fn from(value: PolyFuncType) -> Self {
174        Self::PolyFuncType(value.into())
175    }
176}
177
178impl From<PolyFuncTypeRV> for SignatureFunc {
179    fn from(v: PolyFuncTypeRV) -> Self {
180        Self::PolyFuncType(v)
181    }
182}
183
184impl From<FuncValueType> for SignatureFunc {
185    fn from(v: FuncValueType) -> Self {
186        Self::PolyFuncType(v.into())
187    }
188}
189
190impl From<Signature> for SignatureFunc {
191    fn from(v: Signature) -> Self {
192        Self::PolyFuncType(FuncValueType::from(v).into())
193    }
194}
195
196impl From<CustomValidator> for SignatureFunc {
197    fn from(v: CustomValidator) -> Self {
198        Self::CustomValidator(v)
199    }
200}
201
202impl SignatureFunc {
203    fn static_params(&self) -> Result<&[TypeParam], SignatureError> {
204        Ok(match self {
205            SignatureFunc::PolyFuncType(ts)
206            | SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. })
207            | SignatureFunc::MissingValidateFunc(ts) => ts.params(),
208            SignatureFunc::CustomFunc(func) => func.static_params(),
209            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
210        })
211    }
212
213    /// If the signature is missing a custom validation function, ignore and treat as
214    /// self-contained type scheme (with no custom validation).
215    pub fn ignore_missing_validation(&mut self) {
216        if let SignatureFunc::MissingValidateFunc(ts) = self {
217            *self = SignatureFunc::PolyFuncType(ts.clone());
218        }
219    }
220
221    /// Return the underlying poly function type when available.
222    pub(crate) fn poly_func_type(&self) -> Option<&PolyFuncTypeRV> {
223        match self {
224            SignatureFunc::PolyFuncType(ts) | SignatureFunc::MissingValidateFunc(ts) => Some(ts),
225            SignatureFunc::CustomValidator(custom) => Some(custom.poly_func()),
226            SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
227        }
228    }
229
230    /// Compute the concrete signature ([`FuncValueType`]).
231    ///
232    /// # Panics
233    ///
234    /// Panics if `self` is a [`SignatureFunc::CustomFunc`] and there are not enough type
235    /// arguments provided to match the number of static parameters.
236    ///
237    /// # Errors
238    ///
239    /// This function will return an error if the type arguments are invalid or
240    /// there is some error in type computation.
241    pub fn compute_signature(
242        &self,
243        def: &OpDef,
244        args: &[TypeArg],
245    ) -> Result<Signature, SignatureError> {
246        let temp: PolyFuncTypeRV; // to keep alive
247        let (pf, args) = match &self {
248            SignatureFunc::CustomValidator(custom) => {
249                custom.validate.validate(args, def)?;
250                (&custom.poly_func, args)
251            }
252            SignatureFunc::PolyFuncType(ts) => (ts, args),
253            SignatureFunc::CustomFunc(func) => {
254                let static_params = func.static_params();
255                let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
256
257                check_term_types(static_args, static_params)?;
258                temp = func.compute_signature(static_args, def)?;
259                (&temp, other_args)
260            }
261            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
262            // TODO raise warning: https://github.com/CQCL/hugr/issues/1432
263            SignatureFunc::MissingValidateFunc(ts) => (ts, args),
264        };
265        let res = pf.instantiate(args)?;
266
267        // If there are any row variables left, this will fail with an error:
268        res.try_into()
269    }
270}
271
272impl Debug for SignatureFunc {
273    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
274        match self {
275            Self::CustomValidator(ts) => ts.poly_func.fmt(f),
276            Self::PolyFuncType(ts) => ts.fmt(f),
277            Self::CustomFunc { .. } => f.write_str("<custom sig>"),
278            Self::MissingComputeFunc => f.write_str("<missing custom sig>"),
279            Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"),
280        }
281    }
282}
283
284/// Different ways that an [OpDef] can lower operation nodes i.e. provide a Hugr
285/// that implements the operation using a set of other extensions.
286///
287/// Does not implement [`serde::Deserialize`] directly since the serde error for
288/// untagged enums is unhelpful. Use [`deserialize_lower_funcs`] with
289/// [`serde(deserialize_with = "deserialize_lower_funcs")] instead.
290#[serde_as]
291#[derive(serde::Serialize)]
292#[serde(untagged)]
293pub enum LowerFunc {
294    /// Lowering to a fixed Hugr. Since this cannot depend upon the [TypeArg]s,
295    /// this will generally only be applicable if the [OpDef] has no [TypeParam]s.
296    FixedHugr {
297        /// The extensions required by the [`Hugr`]
298        extensions: ExtensionSet,
299        /// The [`Hugr`] to be used to replace [ExtensionOp]s matching the
300        /// parent [OpDef]
301        ///
302        /// We store it as a single-module package here to keep any encoded
303        /// extensions required to define the Hugr alive.
304        ///
305        /// The Package should contain any non-std extension required to define
306        /// the Hugr. Otherwise, we will not be able to resolve the extensions
307        /// when loading the Hugr.
308        ///
309        /// [ExtensionOp]: crate::ops::ExtensionOp
310        #[serde_as(as = "Box<AsBinaryEnvelope>")]
311        #[serde(rename = "hugr")]
312        pkg: Box<Package>,
313    },
314    /// Custom binary function that can (fallibly) compute a Hugr
315    /// for the particular instance and set of available extensions.
316    #[serde(skip)]
317    CustomFunc(Box<dyn CustomLowerFunc>),
318}
319
320/// A function for deserializing sequences of [`LowerFunc::FixedHugr`].
321///
322/// We could let serde deserialize [`LowerFunc`] as-is, but if the LowerFunc
323/// deserialization fails it just returns an opaque "data did not match any
324/// variant of untagged enum LowerFunc" error. This function will return the
325/// internal errors instead.
326pub fn deserialize_lower_funcs<'de, D>(deserializer: D) -> Result<Vec<LowerFunc>, D::Error>
327where
328    D: serde::Deserializer<'de>,
329{
330    #[serde_as]
331    #[derive(serde::Deserialize)]
332    struct FixedHugrDeserializer {
333        pub extensions: ExtensionSet,
334        #[serde_as(as = "Box<AsBinaryEnvelope>")]
335        pub hugr: Box<Package>,
336    }
337
338    let funcs: Vec<FixedHugrDeserializer> = serde::Deserialize::deserialize(deserializer)?;
339    Ok(funcs
340        .into_iter()
341        .map(|f| LowerFunc::FixedHugr {
342            extensions: f.extensions,
343            pkg: f.hugr,
344        })
345        .collect())
346}
347
348impl Debug for LowerFunc {
349    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        match self {
351            Self::FixedHugr { .. } => write!(f, "FixedHugr"),
352            Self::CustomFunc(_) => write!(f, "<custom lower>"),
353        }
354    }
355}
356
357/// Serializable definition for dynamically loaded operations.
358///
359/// TODO: Define a way to construct new `OpDef`'s from a serialized definition.
360#[derive(Debug, serde::Serialize, serde::Deserialize)]
361pub struct OpDef {
362    /// The unique Extension owning this `OpDef` (of which this `OpDef` is a member)
363    extension: ExtensionId,
364    /// A weak reference to the extension defining this operation.
365    #[serde(skip)]
366    extension_ref: Weak<Extension>,
367    /// Unique identifier of the operation. Used to look up `OpDefs` in the registry
368    /// when deserializing nodes (which store only the name).
369    name: OpName,
370    /// Human readable description of the operation.
371    description: String,
372    /// Miscellaneous data associated with the operation.
373    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
374    misc: HashMap<String, serde_json::Value>,
375
376    #[serde(with = "serialize_signature_func", flatten)]
377    signature_func: SignatureFunc,
378    // Some operations cannot lower themselves and tools that do not understand them
379    // can only treat them as opaque/black-box ops.
380    #[serde(
381        default,
382        skip_serializing_if = "Vec::is_empty",
383        deserialize_with = "deserialize_lower_funcs"
384    )]
385    pub(crate) lower_funcs: Vec<LowerFunc>,
386
387    /// Operations can optionally implement [`ConstFold`] to implement constant folding.
388    #[serde(skip)]
389    constant_folder: Option<Box<dyn ConstFold>>,
390}
391
392impl OpDef {
393    /// Check provided type arguments are valid against their extensions,
394    /// against parameters, and that no type variables are used as static arguments
395    /// (to [`compute_signature`][CustomSignatureFunc::compute_signature])
396    pub fn validate_args(
397        &self,
398        args: &[TypeArg],
399        var_decls: &[TypeParam],
400    ) -> Result<(), SignatureError> {
401        let temp: PolyFuncTypeRV; // to keep alive
402        let (pf, args) = match &self.signature_func {
403            SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args),
404            SignatureFunc::PolyFuncType(ts) => (ts, args),
405            SignatureFunc::CustomFunc(custom) => {
406                let (static_args, other_args) =
407                    args.split_at(min(custom.static_params().len(), args.len()));
408                static_args.iter().try_for_each(|ta| ta.validate(&[]))?;
409                check_term_types(static_args, custom.static_params())?;
410                temp = custom.compute_signature(static_args, self)?;
411                (&temp, other_args)
412            }
413            SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
414            SignatureFunc::MissingValidateFunc(_) => {
415                return Err(SignatureError::MissingValidateFunc);
416            }
417        };
418        args.iter().try_for_each(|ta| ta.validate(var_decls))?;
419        check_term_types(args, pf.params())?;
420        Ok(())
421    }
422
423    /// Computes the signature of a node, i.e. an instantiation of this
424    /// `OpDef` with statically-provided [`TypeArg`]s.
425    pub fn compute_signature(&self, args: &[TypeArg]) -> Result<Signature, SignatureError> {
426        self.signature_func.compute_signature(self, args)
427    }
428
429    /// Fallibly returns a Hugr that may replace an instance of this `OpDef`
430    /// given a set of available extensions that may be used in the Hugr.
431    #[must_use]
432    pub fn try_lower(&self, args: &[TypeArg], available_extensions: &ExtensionSet) -> Option<Hugr> {
433        // TODO test this
434        self.lower_funcs
435            .iter()
436            .filter_map(|f| match f {
437                LowerFunc::FixedHugr { extensions, pkg } => {
438                    if available_extensions.is_superset(extensions) {
439                        pkg.modules.first().cloned()
440                    } else {
441                        None
442                    }
443                }
444                LowerFunc::CustomFunc(f) => {
445                    f.try_lower(&self.name, args, &self.misc, available_extensions)
446                }
447            })
448            .next()
449    }
450
451    /// Returns a reference to the name of this [`OpDef`].
452    #[must_use]
453    pub fn name(&self) -> &OpName {
454        &self.name
455    }
456
457    /// Returns a reference to the extension id of this [`OpDef`].
458    #[must_use]
459    pub fn extension_id(&self) -> &ExtensionId {
460        &self.extension
461    }
462
463    /// Returns a weak reference to the extension defining this operation.
464    #[must_use]
465    pub fn extension(&self) -> Weak<Extension> {
466        self.extension_ref.clone()
467    }
468
469    /// Returns a mutable reference to the weak extension pointer in the operation definition.
470    pub(super) fn extension_mut(&mut self) -> &mut Weak<Extension> {
471        &mut self.extension_ref
472    }
473
474    /// Returns a reference to the description of this [`OpDef`].
475    #[must_use]
476    pub fn description(&self) -> &str {
477        self.description.as_ref()
478    }
479
480    /// Returns a reference to the params of this [`OpDef`].
481    pub fn params(&self) -> Result<&[TypeParam], SignatureError> {
482        self.signature_func.static_params()
483    }
484
485    pub(super) fn validate(&self) -> Result<(), SignatureError> {
486        // TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
487        // for both type scheme and custom binary
488        if let SignatureFunc::CustomValidator(ts) = &self.signature_func {
489            // The type scheme may contain row variables so be of variable length;
490            // these will have to be substituted to fixed-length concrete types when
491            // the OpDef is instantiated into an actual OpType.
492            ts.poly_func.validate()?;
493        }
494        Ok(())
495    }
496
497    /// Add a lowering function to the [`OpDef`]
498    pub fn add_lower_func(&mut self, lower: LowerFunc) {
499        self.lower_funcs.push(lower);
500    }
501
502    /// Insert miscellaneous data `v` to the [`OpDef`], keyed by `k`.
503    pub fn add_misc(
504        &mut self,
505        k: impl ToString,
506        v: serde_json::Value,
507    ) -> Option<serde_json::Value> {
508        self.misc.insert(k.to_string(), v)
509    }
510
511    /// Iterate over all miscellaneous data in the [`OpDef`].
512    #[allow(unused)] // Unused when no features are enabled
513    pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &serde_json::Value)> {
514        self.misc.iter().map(|(k, v)| (k.as_str(), v))
515    }
516
517    /// Set the constant folding function for this Op, which can evaluate it
518    /// given constant inputs.
519    pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
520        self.constant_folder = Some(Box::new(fold));
521    }
522
523    /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given
524    /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s.
525    #[must_use]
526    pub fn constant_fold(
527        &self,
528        type_args: &[TypeArg],
529        consts: &[(crate::IncomingPort, crate::ops::Value)],
530    ) -> ConstFoldResult {
531        (self.constant_folder.as_ref())?.fold(type_args, consts)
532    }
533
534    /// Returns a reference to the signature function of this [`OpDef`].
535    #[must_use]
536    pub fn signature_func(&self) -> &SignatureFunc {
537        &self.signature_func
538    }
539
540    /// Returns a mutable reference to the signature function of this [`OpDef`].
541    pub(super) fn signature_func_mut(&mut self) -> &mut SignatureFunc {
542        &mut self.signature_func
543    }
544}
545
546impl Extension {
547    /// Add an operation definition to the extension. Must be a type scheme
548    /// (defined by a [`PolyFuncTypeRV`]), a type scheme along with binary
549    /// validation for type arguments ([`CustomValidator`]), or a custom binary
550    /// function for computing the signature given type arguments (implementing
551    /// `[CustomSignatureFunc]`).
552    ///
553    /// This method requires a [`Weak`] reference to the [`Arc`] containing the
554    /// extension being defined. The intended way to call this method is inside
555    /// the closure passed to [`Extension::new_arc`] when defining the extension.
556    ///
557    /// # Example
558    ///
559    /// ```
560    /// # use hugr_core::types::Signature;
561    /// # use hugr_core::extension::{Extension, ExtensionId, Version};
562    /// Extension::new_arc(
563    ///     ExtensionId::new_unchecked("my.extension"),
564    ///     Version::new(0, 1, 0),
565    ///     |ext, extension_ref| {
566    ///         ext.add_op(
567    ///             "MyOp".into(),
568    ///             "Some operation".into(),
569    ///             Signature::new_endo([]),
570    ///             extension_ref,
571    ///         );
572    ///     },
573    /// );
574    /// ```
575    pub fn add_op(
576        &mut self,
577        name: OpName,
578        description: String,
579        signature_func: impl Into<SignatureFunc>,
580        extension_ref: &Weak<Extension>,
581    ) -> Result<&mut OpDef, ExtensionBuildError> {
582        let op = OpDef {
583            extension: self.name.clone(),
584            extension_ref: extension_ref.clone(),
585            name,
586            description,
587            signature_func: signature_func.into(),
588            misc: Default::default(),
589            lower_funcs: Default::default(),
590            constant_folder: Default::default(),
591        };
592
593        match self.operations.entry(op.name.clone()) {
594            Entry::Occupied(_) => Err(ExtensionBuildError::OpDefExists(op.name)),
595            // Just made the arc so should only be one reference to it, can get_mut,
596            Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()),
597        }
598    }
599}
600
601#[cfg(test)]
602pub(super) mod test {
603    use std::num::NonZeroU64;
604
605    use itertools::Itertools;
606
607    use super::SignatureFromArgs;
608    use crate::builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig};
609    use crate::extension::SignatureError;
610    use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc};
611    use crate::extension::prelude::usize_t;
612    use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
613    use crate::ops::OpName;
614    use crate::package::Package;
615    use crate::std_extensions::collections::list;
616    use crate::types::type_param::{TermTypeError, TypeParam};
617    use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
618    use crate::{Extension, const_extension_ids};
619
620    const_extension_ids! {
621        const EXT_ID: ExtensionId = "MyExt";
622    }
623
624    /// A dummy wrapper over an operation definition.
625    #[derive(serde::Serialize, serde::Deserialize, Debug)]
626    pub struct SimpleOpDef(OpDef);
627
628    impl SimpleOpDef {
629        /// Create a new dummy opdef.
630        #[must_use]
631        pub fn new(op_def: OpDef) -> Self {
632            assert!(op_def.constant_folder.is_none());
633            assert!(matches!(
634                op_def.signature_func,
635                SignatureFunc::PolyFuncType(_)
636            ));
637            assert!(
638                op_def
639                    .lower_funcs
640                    .iter()
641                    .all(|lf| matches!(lf, LowerFunc::FixedHugr { .. }))
642            );
643            Self(op_def)
644        }
645    }
646
647    impl From<SimpleOpDef> for OpDef {
648        fn from(value: SimpleOpDef) -> Self {
649            value.0
650        }
651    }
652
653    impl PartialEq for SimpleOpDef {
654        fn eq(&self, other: &Self) -> bool {
655            let OpDef {
656                extension,
657                extension_ref: _,
658                name,
659                description,
660                misc,
661                signature_func,
662                lower_funcs,
663                constant_folder: _,
664            } = &self.0;
665            let OpDef {
666                extension: other_extension,
667                extension_ref: _,
668                name: other_name,
669                description: other_description,
670                misc: other_misc,
671                signature_func: other_signature_func,
672                lower_funcs: other_lower_funcs,
673                constant_folder: _,
674            } = &other.0;
675
676            let get_sig = |sf: &_| match sf {
677                // if SignatureFunc or CustomValidator are changed we should get
678                // a compile error here. To fix: modify the fields matched on here,
679                // maintaining the lack of `..` and, for each part that is
680                // serializable, ensure we are checking it for equality below.
681                SignatureFunc::CustomValidator(CustomValidator {
682                    poly_func,
683                    validate: _,
684                })
685                | SignatureFunc::PolyFuncType(poly_func)
686                | SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()),
687                SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
688            };
689
690            let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
691                lfs.iter()
692                    .map(|lf| match lf {
693                        // as with get_sig above, this should break if the hierarchy
694                        // is changed, update similarly.
695                        LowerFunc::FixedHugr { extensions, pkg } => {
696                            Some((extensions.clone(), pkg.clone()))
697                        }
698                        // This is ruled out by `new()` but leave it here for later.
699                        LowerFunc::CustomFunc(_) => None,
700                    })
701                    .collect_vec()
702            };
703
704            extension == other_extension
705                && name == other_name
706                && description == other_description
707                && misc == other_misc
708                && get_sig(signature_func) == get_sig(other_signature_func)
709                && get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs)
710        }
711    }
712
713    #[test]
714    fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
715        let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap();
716        const OP_NAME: OpName = OpName::new_inline("Reverse");
717
718        let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
719            const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
720            let list_of_var =
721                Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
722            let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo([list_of_var]));
723
724            let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?;
725            def.add_lower_func(LowerFunc::FixedHugr {
726                extensions: ExtensionSet::new(),
727                pkg: Box::new(Package::from_hugr(crate::builder::test::simple_dfg_hugr())), // this is nonsense, but we are not testing the actual lowering here
728            });
729            def.add_misc("key", Default::default());
730            assert_eq!(def.description(), "desc");
731            assert_eq!(def.lower_funcs.len(), 1);
732            assert_eq!(def.misc.len(), 1);
733
734            Ok(())
735        })?;
736
737        let reg = ExtensionRegistry::new([PRELUDE.clone(), list::EXTENSION.clone(), ext]);
738        reg.validate()?;
739        let e = reg.get(&EXT_ID).unwrap();
740
741        let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?);
742        let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
743        let rev = dfg.add_dataflow_op(
744            e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()])
745                .unwrap(),
746            dfg.input_wires(),
747        )?;
748        dfg.finish_hugr_with_outputs(rev.outputs())?;
749
750        Ok(())
751    }
752
753    #[test]
754    fn binary_polyfunc() -> Result<(), Box<dyn std::error::Error>> {
755        // Test a custom binary `compute_signature` that returns a PolyFuncTypeRV
756        // where the latter declares more type params itself. In particular,
757        // we should be able to substitute (external) type variables into the latter,
758        // but not pass them into the former (custom binary function).
759        struct SigFun();
760        impl SignatureFromArgs for SigFun {
761            fn compute_signature(
762                &self,
763                arg_values: &[TypeArg],
764            ) -> Result<PolyFuncTypeRV, SignatureError> {
765                const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
766                let [TypeArg::BoundedNat(n)] = arg_values else {
767                    return Err(SignatureError::InvalidTypeArgs);
768                };
769                let n = *n as usize;
770                let tvs: Vec<Type> = (0..n)
771                    .map(|_| Type::new_var_use(0, TypeBound::Linear))
772                    .collect();
773                Ok(PolyFuncTypeRV::new(
774                    vec![TP.clone()],
775                    Signature::new(tvs.clone(), vec![Type::new_tuple(tvs)]),
776                ))
777            }
778
779            fn static_params(&self) -> &[TypeParam] {
780                const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()];
781                MAX_NAT
782            }
783        }
784        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
785            let def: &mut crate::extension::OpDef =
786                ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?;
787
788            // Base case, no type variables:
789            let args = [TypeArg::BoundedNat(3), usize_t().into()];
790            assert_eq!(
791                def.compute_signature(&args),
792                Ok(Signature::new(
793                    vec![usize_t(); 3],
794                    vec![Type::new_tuple(vec![usize_t(); 3])]
795                ))
796            );
797            assert_eq!(def.validate_args(&args, &[]), Ok(()));
798
799            // Second arg may be a variable (substitutable)
800            let tyvar = Type::new_var_use(0, TypeBound::Copyable);
801            let tyvars: Vec<Type> = vec![tyvar.clone(); 3];
802            let args = [TypeArg::BoundedNat(3), tyvar.clone().into()];
803            assert_eq!(
804                def.compute_signature(&args),
805                Ok(Signature::new(
806                    tyvars.clone(),
807                    vec![Type::new_tuple(tyvars)]
808                ))
809            );
810            def.validate_args(&args, &[TypeBound::Copyable.into()])
811                .unwrap();
812
813            // quick sanity check that we are validating the args - note changed bound:
814            assert_eq!(
815                def.validate_args(&args, &[TypeBound::Linear.into()]),
816                Err(SignatureError::TypeVarDoesNotMatchDeclaration {
817                    actual: Box::new(TypeBound::Linear.into()),
818                    cached: Box::new(TypeBound::Copyable.into())
819                })
820            );
821
822            // First arg must be concrete, not a variable
823            let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap());
824            let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()];
825            // We can't prevent this from getting into our compute_signature implementation:
826            assert_eq!(
827                def.compute_signature(&args),
828                Err(SignatureError::InvalidTypeArgs)
829            );
830            // But validation rules it out, even when the variable is declared:
831            assert_eq!(
832                def.validate_args(&args, &[kind]),
833                Err(SignatureError::FreeTypeVar {
834                    idx: 0,
835                    num_decls: 0
836                })
837            );
838
839            Ok(())
840        })?;
841
842        Ok(())
843    }
844
845    #[test]
846    fn type_scheme_instantiate_var() -> Result<(), Box<dyn std::error::Error>> {
847        // Check that we can instantiate a PolyFuncTypeRV-scheme with an (external)
848        // type variable
849        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
850            let def = ext.add_op(
851                "SimpleOp".into(),
852                String::new(),
853                PolyFuncTypeRV::new(
854                    vec![TypeBound::Linear.into()],
855                    Signature::new_endo([Type::new_var_use(0, TypeBound::Linear)]),
856                ),
857                extension_ref,
858            )?;
859            let tv = Type::new_var_use(0, TypeBound::Copyable);
860            let args = [tv.clone().into()];
861            let decls = [TypeBound::Copyable.into()];
862            def.validate_args(&args, &decls).unwrap();
863            assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo([tv])));
864            // But not with an external row variable
865            let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into();
866            assert_eq!(
867                def.compute_signature(std::slice::from_ref(&arg)),
868                Err(SignatureError::TypeArgMismatch(
869                    TermTypeError::TypeMismatch {
870                        type_: Box::new(TypeBound::Linear.into()),
871                        term: Box::new(arg),
872                    }
873                ))
874            );
875            Ok(())
876        })?;
877        Ok(())
878    }
879
880    mod proptest {
881        use std::sync::Weak;
882
883        use super::SimpleOpDef;
884        use ::proptest::prelude::*;
885
886        use crate::package::Package;
887        use crate::{
888            builder::test::simple_dfg_hugr,
889            extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc, op_def::LowerFunc},
890            types::PolyFuncTypeRV,
891        };
892
893        impl Arbitrary for SignatureFunc {
894            type Parameters = ();
895            type Strategy = BoxedStrategy<Self>;
896            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
897                // TODO there is also  SignatureFunc::CustomFunc, but for now
898                // this is not serialized. When it is, we should generate
899                // examples here .
900                any::<PolyFuncTypeRV>()
901                    .prop_map(SignatureFunc::PolyFuncType)
902                    .boxed()
903            }
904        }
905
906        impl Arbitrary for LowerFunc {
907            type Parameters = ();
908            type Strategy = BoxedStrategy<Self>;
909            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
910                // TODO There is also LowerFunc::CustomFunc, but for now this is
911                // not serialized. When it is, we should generate examples here.
912                any::<ExtensionSet>()
913                    .prop_map(|extensions| LowerFunc::FixedHugr {
914                        extensions,
915                        pkg: Box::new(Package::from_hugr(simple_dfg_hugr())),
916                    })
917                    .boxed()
918            }
919        }
920
921        impl Arbitrary for SimpleOpDef {
922            type Parameters = ();
923            type Strategy = BoxedStrategy<Self>;
924            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
925                use crate::proptest::{any_serde_json_value, any_smolstr, any_string};
926                use proptest::collection::{hash_map, vec};
927                let misc = hash_map(any_string(), any_serde_json_value(), 0..3);
928                (
929                    any::<ExtensionId>(),
930                    any_smolstr(),
931                    any_string(),
932                    misc,
933                    any::<SignatureFunc>(),
934                    vec(any::<LowerFunc>(), 0..2),
935                )
936                    .prop_map(
937                        |(extension, name, description, misc, signature_func, lower_funcs)| {
938                            Self::new(OpDef {
939                                extension,
940                                // Use a dead weak reference. Trying to access the extension will always return None.
941                                extension_ref: Weak::default(),
942                                name,
943                                description,
944                                misc,
945                                signature_func,
946                                lower_funcs,
947                                // TODO ``constant_folder` is not serialized, we should
948                                // generate examples once it is.
949                                constant_folder: None,
950                            })
951                        },
952                    )
953                    .boxed()
954            }
955        }
956    }
957}