1use std::cmp::min;
2use std::collections::HashMap;
3use std::collections::btree_map::Entry;
4use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, Weak};
6
7use serde_with::serde_as;
8
9use super::{
10 ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
11 SignatureError,
12};
13
14use crate::Hugr;
15use crate::envelope::serde_with::AsBinaryEnvelope;
16use crate::ops::{OpName, OpNameRef};
17use crate::package::Package;
18use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
19use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
20mod serialize_signature_func;
21
22pub trait CustomSignatureFunc: Send + Sync {
24 fn compute_signature<'o, 'a: 'o>(
28 &'a self,
29 arg_values: &[TypeArg],
30 def: &'o OpDef,
31 ) -> Result<PolyFuncTypeRV, SignatureError>;
32 fn static_params(&self) -> &[TypeParam];
35}
36
37pub trait SignatureFromArgs: Send + Sync {
39 fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError>;
42 fn static_params(&self) -> &[TypeParam];
45}
46
47impl<T: SignatureFromArgs> CustomSignatureFunc for T {
48 #[inline]
49 fn compute_signature<'o, 'a: 'o>(
50 &'a self,
51 arg_values: &[TypeArg],
52 _def: &'o OpDef,
53 ) -> Result<PolyFuncTypeRV, SignatureError> {
54 SignatureFromArgs::compute_signature(self, arg_values)
55 }
56
57 #[inline]
58 fn static_params(&self) -> &[TypeParam] {
59 SignatureFromArgs::static_params(self)
60 }
61}
62
63pub trait ValidateTypeArgs: Send + Sync {
66 fn validate<'o, 'a: 'o>(
70 &self,
71 arg_values: &[TypeArg],
72 def: &'o OpDef,
73 ) -> Result<(), SignatureError>;
74}
75
76pub trait ValidateJustArgs: Send + Sync {
79 fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError>;
82}
83
84impl<T: ValidateJustArgs> ValidateTypeArgs for T {
85 #[inline]
86 fn validate<'o, 'a: 'o>(
87 &self,
88 arg_values: &[TypeArg],
89 _def: &'o OpDef,
90 ) -> Result<(), SignatureError> {
91 ValidateJustArgs::validate(self, arg_values)
92 }
93}
94
95pub trait CustomLowerFunc: Send + Sync {
106 fn try_lower(
110 &self,
111 name: &OpNameRef,
112 arg_values: &[TypeArg],
113 misc: &HashMap<String, serde_json::Value>,
114 available_extensions: &ExtensionSet,
115 ) -> Option<Hugr>;
116}
117
118pub struct CustomValidator {
122 poly_func: PolyFuncTypeRV,
123 pub(crate) validate: Box<dyn ValidateTypeArgs>,
125}
126
127impl CustomValidator {
128 pub fn new(
131 poly_func: impl Into<PolyFuncTypeRV>,
132 validate: impl ValidateTypeArgs + 'static,
133 ) -> Self {
134 Self {
135 poly_func: poly_func.into(),
136 validate: Box::new(validate),
137 }
138 }
139
140 pub(crate) fn poly_func(&self) -> &PolyFuncTypeRV {
142 &self.poly_func
143 }
144
145 pub(super) fn poly_func_mut(&mut self) -> &mut PolyFuncTypeRV {
147 &mut self.poly_func
148 }
149}
150
151pub enum SignatureFunc {
153 PolyFuncType(PolyFuncTypeRV),
155 CustomValidator(CustomValidator),
157 MissingValidateFunc(PolyFuncTypeRV),
159 CustomFunc(Box<dyn CustomSignatureFunc>),
162 MissingComputeFunc,
164}
165
166impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {
167 fn from(v: T) -> Self {
168 Self::CustomFunc(Box::new(v))
169 }
170}
171
172impl From<PolyFuncType> for SignatureFunc {
173 fn from(value: PolyFuncType) -> Self {
174 Self::PolyFuncType(value.into())
175 }
176}
177
178impl From<PolyFuncTypeRV> for SignatureFunc {
179 fn from(v: PolyFuncTypeRV) -> Self {
180 Self::PolyFuncType(v)
181 }
182}
183
184impl From<FuncValueType> for SignatureFunc {
185 fn from(v: FuncValueType) -> Self {
186 Self::PolyFuncType(v.into())
187 }
188}
189
190impl From<Signature> for SignatureFunc {
191 fn from(v: Signature) -> Self {
192 Self::PolyFuncType(FuncValueType::from(v).into())
193 }
194}
195
196impl From<CustomValidator> for SignatureFunc {
197 fn from(v: CustomValidator) -> Self {
198 Self::CustomValidator(v)
199 }
200}
201
202impl SignatureFunc {
203 fn static_params(&self) -> Result<&[TypeParam], SignatureError> {
204 Ok(match self {
205 SignatureFunc::PolyFuncType(ts)
206 | SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. })
207 | SignatureFunc::MissingValidateFunc(ts) => ts.params(),
208 SignatureFunc::CustomFunc(func) => func.static_params(),
209 SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
210 })
211 }
212
213 pub fn ignore_missing_validation(&mut self) {
216 if let SignatureFunc::MissingValidateFunc(ts) = self {
217 *self = SignatureFunc::PolyFuncType(ts.clone());
218 }
219 }
220
221 pub(crate) fn poly_func_type(&self) -> Option<&PolyFuncTypeRV> {
223 match self {
224 SignatureFunc::PolyFuncType(ts) | SignatureFunc::MissingValidateFunc(ts) => Some(ts),
225 SignatureFunc::CustomValidator(custom) => Some(custom.poly_func()),
226 SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
227 }
228 }
229
230 pub fn compute_signature(
242 &self,
243 def: &OpDef,
244 args: &[TypeArg],
245 ) -> Result<Signature, SignatureError> {
246 let temp: PolyFuncTypeRV; let (pf, args) = match &self {
248 SignatureFunc::CustomValidator(custom) => {
249 custom.validate.validate(args, def)?;
250 (&custom.poly_func, args)
251 }
252 SignatureFunc::PolyFuncType(ts) => (ts, args),
253 SignatureFunc::CustomFunc(func) => {
254 let static_params = func.static_params();
255 let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
256
257 check_term_types(static_args, static_params)?;
258 temp = func.compute_signature(static_args, def)?;
259 (&temp, other_args)
260 }
261 SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
262 SignatureFunc::MissingValidateFunc(ts) => (ts, args),
264 };
265 let res = pf.instantiate(args)?;
266
267 res.try_into()
269 }
270}
271
272impl Debug for SignatureFunc {
273 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
274 match self {
275 Self::CustomValidator(ts) => ts.poly_func.fmt(f),
276 Self::PolyFuncType(ts) => ts.fmt(f),
277 Self::CustomFunc { .. } => f.write_str("<custom sig>"),
278 Self::MissingComputeFunc => f.write_str("<missing custom sig>"),
279 Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"),
280 }
281 }
282}
283
284#[serde_as]
291#[derive(serde::Serialize)]
292#[serde(untagged)]
293pub enum LowerFunc {
294 FixedHugr {
297 extensions: ExtensionSet,
299 #[serde_as(as = "Box<AsBinaryEnvelope>")]
311 #[serde(rename = "hugr")]
312 pkg: Box<Package>,
313 },
314 #[serde(skip)]
317 CustomFunc(Box<dyn CustomLowerFunc>),
318}
319
320pub fn deserialize_lower_funcs<'de, D>(deserializer: D) -> Result<Vec<LowerFunc>, D::Error>
327where
328 D: serde::Deserializer<'de>,
329{
330 #[serde_as]
331 #[derive(serde::Deserialize)]
332 struct FixedHugrDeserializer {
333 pub extensions: ExtensionSet,
334 #[serde_as(as = "Box<AsBinaryEnvelope>")]
335 pub hugr: Box<Package>,
336 }
337
338 let funcs: Vec<FixedHugrDeserializer> = serde::Deserialize::deserialize(deserializer)?;
339 Ok(funcs
340 .into_iter()
341 .map(|f| LowerFunc::FixedHugr {
342 extensions: f.extensions,
343 pkg: f.hugr,
344 })
345 .collect())
346}
347
348impl Debug for LowerFunc {
349 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 match self {
351 Self::FixedHugr { .. } => write!(f, "FixedHugr"),
352 Self::CustomFunc(_) => write!(f, "<custom lower>"),
353 }
354 }
355}
356
357#[derive(Debug, serde::Serialize, serde::Deserialize)]
361pub struct OpDef {
362 extension: ExtensionId,
364 #[serde(skip)]
366 extension_ref: Weak<Extension>,
367 name: OpName,
370 description: String,
372 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
374 misc: HashMap<String, serde_json::Value>,
375
376 #[serde(with = "serialize_signature_func", flatten)]
377 signature_func: SignatureFunc,
378 #[serde(
381 default,
382 skip_serializing_if = "Vec::is_empty",
383 deserialize_with = "deserialize_lower_funcs"
384 )]
385 pub(crate) lower_funcs: Vec<LowerFunc>,
386
387 #[serde(skip)]
389 constant_folder: Option<Box<dyn ConstFold>>,
390}
391
392impl OpDef {
393 pub fn validate_args(
397 &self,
398 args: &[TypeArg],
399 var_decls: &[TypeParam],
400 ) -> Result<(), SignatureError> {
401 let temp: PolyFuncTypeRV; let (pf, args) = match &self.signature_func {
403 SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args),
404 SignatureFunc::PolyFuncType(ts) => (ts, args),
405 SignatureFunc::CustomFunc(custom) => {
406 let (static_args, other_args) =
407 args.split_at(min(custom.static_params().len(), args.len()));
408 static_args.iter().try_for_each(|ta| ta.validate(&[]))?;
409 check_term_types(static_args, custom.static_params())?;
410 temp = custom.compute_signature(static_args, self)?;
411 (&temp, other_args)
412 }
413 SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
414 SignatureFunc::MissingValidateFunc(_) => {
415 return Err(SignatureError::MissingValidateFunc);
416 }
417 };
418 args.iter().try_for_each(|ta| ta.validate(var_decls))?;
419 check_term_types(args, pf.params())?;
420 Ok(())
421 }
422
423 pub fn compute_signature(&self, args: &[TypeArg]) -> Result<Signature, SignatureError> {
426 self.signature_func.compute_signature(self, args)
427 }
428
429 #[must_use]
432 pub fn try_lower(&self, args: &[TypeArg], available_extensions: &ExtensionSet) -> Option<Hugr> {
433 self.lower_funcs
435 .iter()
436 .filter_map(|f| match f {
437 LowerFunc::FixedHugr { extensions, pkg } => {
438 if available_extensions.is_superset(extensions) {
439 pkg.modules.first().cloned()
440 } else {
441 None
442 }
443 }
444 LowerFunc::CustomFunc(f) => {
445 f.try_lower(&self.name, args, &self.misc, available_extensions)
446 }
447 })
448 .next()
449 }
450
451 #[must_use]
453 pub fn name(&self) -> &OpName {
454 &self.name
455 }
456
457 #[must_use]
459 pub fn extension_id(&self) -> &ExtensionId {
460 &self.extension
461 }
462
463 #[must_use]
465 pub fn extension(&self) -> Weak<Extension> {
466 self.extension_ref.clone()
467 }
468
469 pub(super) fn extension_mut(&mut self) -> &mut Weak<Extension> {
471 &mut self.extension_ref
472 }
473
474 #[must_use]
476 pub fn description(&self) -> &str {
477 self.description.as_ref()
478 }
479
480 pub fn params(&self) -> Result<&[TypeParam], SignatureError> {
482 self.signature_func.static_params()
483 }
484
485 pub(super) fn validate(&self) -> Result<(), SignatureError> {
486 if let SignatureFunc::CustomValidator(ts) = &self.signature_func {
489 ts.poly_func.validate()?;
493 }
494 Ok(())
495 }
496
497 pub fn add_lower_func(&mut self, lower: LowerFunc) {
499 self.lower_funcs.push(lower);
500 }
501
502 pub fn add_misc(
504 &mut self,
505 k: impl ToString,
506 v: serde_json::Value,
507 ) -> Option<serde_json::Value> {
508 self.misc.insert(k.to_string(), v)
509 }
510
511 #[allow(unused)] pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &serde_json::Value)> {
514 self.misc.iter().map(|(k, v)| (k.as_str(), v))
515 }
516
517 pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
520 self.constant_folder = Some(Box::new(fold));
521 }
522
523 #[must_use]
526 pub fn constant_fold(
527 &self,
528 type_args: &[TypeArg],
529 consts: &[(crate::IncomingPort, crate::ops::Value)],
530 ) -> ConstFoldResult {
531 (self.constant_folder.as_ref())?.fold(type_args, consts)
532 }
533
534 #[must_use]
536 pub fn signature_func(&self) -> &SignatureFunc {
537 &self.signature_func
538 }
539
540 pub(super) fn signature_func_mut(&mut self) -> &mut SignatureFunc {
542 &mut self.signature_func
543 }
544}
545
546impl Extension {
547 pub fn add_op(
576 &mut self,
577 name: OpName,
578 description: String,
579 signature_func: impl Into<SignatureFunc>,
580 extension_ref: &Weak<Extension>,
581 ) -> Result<&mut OpDef, ExtensionBuildError> {
582 let op = OpDef {
583 extension: self.name.clone(),
584 extension_ref: extension_ref.clone(),
585 name,
586 description,
587 signature_func: signature_func.into(),
588 misc: Default::default(),
589 lower_funcs: Default::default(),
590 constant_folder: Default::default(),
591 };
592
593 match self.operations.entry(op.name.clone()) {
594 Entry::Occupied(_) => Err(ExtensionBuildError::OpDefExists(op.name)),
595 Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()),
597 }
598 }
599}
600
#[cfg(test)]
pub(super) mod test {
    //! Tests for operation definitions: type-scheme ops, custom signature
    //! computation, and type-argument validation, plus proptest generators for
    //! [`SimpleOpDef`].
    use std::num::NonZeroU64;

    use itertools::Itertools;

    use super::SignatureFromArgs;
    use crate::builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig};
    use crate::extension::SignatureError;
    use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc};
    use crate::extension::prelude::usize_t;
    use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
    use crate::ops::OpName;
    use crate::package::Package;
    use crate::std_extensions::collections::list;
    use crate::types::type_param::{TermTypeError, TypeParam};
    use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
    use crate::{Extension, const_extension_ids};

    const_extension_ids! {
        const EXT_ID: ExtensionId = "MyExt";
    }

    /// Wrapper around an [`OpDef`] that contains no custom (binary) code, so
    /// it can be compared with `PartialEq` and (de)serialized in tests.
    #[derive(serde::Serialize, serde::Deserialize, Debug)]
    pub struct SimpleOpDef(OpDef);

    impl SimpleOpDef {
        /// Wraps `op_def`.
        ///
        /// # Panics
        ///
        /// If `op_def` has a constant folder, a signature function other than
        /// [`SignatureFunc::PolyFuncType`], or any lowering function other than
        /// [`LowerFunc::FixedHugr`].
        #[must_use]
        pub fn new(op_def: OpDef) -> Self {
            assert!(op_def.constant_folder.is_none());
            assert!(matches!(
                op_def.signature_func,
                SignatureFunc::PolyFuncType(_)
            ));
            assert!(
                op_def
                    .lower_funcs
                    .iter()
                    .all(|lf| matches!(lf, LowerFunc::FixedHugr { .. }))
            );
            Self(op_def)
        }
    }

    impl From<SimpleOpDef> for OpDef {
        fn from(value: SimpleOpDef) -> Self {
            value.0
        }
    }

    /// Field-by-field equality that ignores what cannot be compared:
    /// `extension_ref`, `constant_folder`, and any custom code in the signature
    /// function or lowering functions.
    impl PartialEq for SimpleOpDef {
        fn eq(&self, other: &Self) -> bool {
            // Exhaustive destructuring: adding a field to `OpDef` stops this
            // from compiling until the comparison below is updated.
            let OpDef {
                extension,
                extension_ref: _,
                name,
                description,
                misc,
                signature_func,
                lower_funcs,
                constant_folder: _,
            } = &self.0;
            let OpDef {
                extension: other_extension,
                extension_ref: _,
                name: other_name,
                description: other_description,
                misc: other_misc,
                signature_func: other_signature_func,
                lower_funcs: other_lower_funcs,
                constant_folder: _,
            } = &other.0;

            // The stored type scheme, or `None` for custom/missing compute code.
            let get_sig = |sf: &_| match sf {
                SignatureFunc::CustomValidator(CustomValidator {
                    poly_func,
                    validate: _,
                })
                | SignatureFunc::PolyFuncType(poly_func)
                | SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()),
                SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
            };

            // Fixed lowerings as comparable tuples; custom lowerings become `None`.
            let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
                lfs.iter()
                    .map(|lf| match lf {
                        LowerFunc::FixedHugr { extensions, pkg } => {
                            Some((extensions.clone(), pkg.clone()))
                        }
                        LowerFunc::CustomFunc(_) => None,
                    })
                    .collect_vec()
            };

            extension == other_extension
                && name == other_name
                && description == other_description
                && misc == other_misc
                && get_sig(signature_func) == get_sig(other_signature_func)
                && get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs)
        }
    }

    /// Defines a polymorphic "Reverse" op over lists with a fixed lowering,
    /// registers it, and instantiates it at `usize` inside a DFG.
    #[test]
    fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
        let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap();
        const OP_NAME: OpName = OpName::new_inline("Reverse");

        let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
            // Scheme: for any Linear type T, List<T> -> List<T>.
            const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
            let list_of_var =
                Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
            let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo([list_of_var]));

            let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?;
            def.add_lower_func(LowerFunc::FixedHugr {
                extensions: ExtensionSet::new(),
                pkg: Box::new(Package::from_hugr(crate::builder::test::simple_dfg_hugr())),
            });
            def.add_misc("key", Default::default());
            assert_eq!(def.description(), "desc");
            assert_eq!(def.lower_funcs.len(), 1);
            assert_eq!(def.misc.len(), 1);

            Ok(())
        })?;

        let reg = ExtensionRegistry::new([PRELUDE.clone(), list::EXTENSION.clone(), ext]);
        reg.validate()?;
        let e = reg.get(&EXT_ID).unwrap();

        // Instantiate at T = usize and wire it between the DFG's inputs and outputs.
        let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?);
        let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
        let rev = dfg.add_dataflow_op(
            e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()])
                .unwrap(),
            dfg.input_wires(),
        )?;
        dfg.finish_hugr_with_outputs(rev.outputs())?;

        Ok(())
    }

    /// Exercises a custom signature function whose static parameter is a
    /// natural number `n`, together with `compute_signature`/`validate_args`.
    #[test]
    fn binary_polyfunc() -> Result<(), Box<dyn std::error::Error>> {
        // Given static arg `n`, returns the scheme
        // "for Linear T: T^n -> (T, ..., T)" (n inputs, one n-tuple output).
        struct SigFun();
        impl SignatureFromArgs for SigFun {
            fn compute_signature(
                &self,
                arg_values: &[TypeArg],
            ) -> Result<PolyFuncTypeRV, SignatureError> {
                const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
                // Anything other than exactly one concrete nat is rejected.
                let [TypeArg::BoundedNat(n)] = arg_values else {
                    return Err(SignatureError::InvalidTypeArgs);
                };
                let n = *n as usize;
                let tvs: Vec<Type> = (0..n)
                    .map(|_| Type::new_var_use(0, TypeBound::Linear))
                    .collect();
                Ok(PolyFuncTypeRV::new(
                    vec![TP.clone()],
                    Signature::new(tvs.clone(), vec![Type::new_tuple(tvs)]),
                ))
            }

            fn static_params(&self) -> &[TypeParam] {
                const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()];
                MAX_NAT
            }
        }
        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
            let def: &mut crate::extension::OpDef =
                ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?;

            // Concrete args: n = 3, T = usize.
            let args = [TypeArg::BoundedNat(3), usize_t().into()];
            assert_eq!(
                def.compute_signature(&args),
                Ok(Signature::new(
                    vec![usize_t(); 3],
                    vec![Type::new_tuple(vec![usize_t(); 3])]
                ))
            );
            assert_eq!(def.validate_args(&args, &[]), Ok(()));

            // T may be a type variable of the enclosing scope...
            let tyvar = Type::new_var_use(0, TypeBound::Copyable);
            let tyvars: Vec<Type> = vec![tyvar.clone(); 3];
            let args = [TypeArg::BoundedNat(3), tyvar.clone().into()];
            assert_eq!(
                def.compute_signature(&args),
                Ok(Signature::new(
                    tyvars.clone(),
                    vec![Type::new_tuple(tyvars)]
                ))
            );
            def.validate_args(&args, &[TypeBound::Copyable.into()])
                .unwrap();

            // ...but its bound must agree with the declaration in scope.
            assert_eq!(
                def.validate_args(&args, &[TypeBound::Linear.into()]),
                Err(SignatureError::TypeVarDoesNotMatchDeclaration {
                    actual: Box::new(TypeBound::Linear.into()),
                    cached: Box::new(TypeBound::Copyable.into())
                })
            );

            // A variable is not allowed as the *static* argument: the custom
            // function rejects it, and validate_args checks static args with
            // no variables in scope.
            let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap());
            let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()];
            assert_eq!(
                def.compute_signature(&args),
                Err(SignatureError::InvalidTypeArgs)
            );
            assert_eq!(
                def.validate_args(&args, &[kind]),
                Err(SignatureError::FreeTypeVar {
                    idx: 0,
                    num_decls: 0
                })
            );

            Ok(())
        })?;

        Ok(())
    }

    /// Instantiates a type-scheme op with a type variable, then checks that a
    /// row variable is rejected where a single type is expected.
    #[test]
    fn type_scheme_instantiate_var() -> Result<(), Box<dyn std::error::Error>> {
        let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
            // Scheme: for any Linear type T, T -> T.
            let def = ext.add_op(
                "SimpleOp".into(),
                String::new(),
                PolyFuncTypeRV::new(
                    vec![TypeBound::Linear.into()],
                    Signature::new_endo([Type::new_var_use(0, TypeBound::Linear)]),
                ),
                extension_ref,
            )?;
            // A Copyable type variable is accepted for the Linear parameter.
            let tv = Type::new_var_use(0, TypeBound::Copyable);
            let args = [tv.clone().into()];
            let decls = [TypeBound::Copyable.into()];
            def.validate_args(&args, &decls).unwrap();
            assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo([tv])));
            // A row variable does not match the single-type parameter.
            let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into();
            assert_eq!(
                def.compute_signature(std::slice::from_ref(&arg)),
                Err(SignatureError::TypeArgMismatch(
                    TermTypeError::TypeMismatch {
                        type_: Box::new(TypeBound::Linear.into()),
                        term: Box::new(arg),
                    }
                ))
            );
            Ok(())
        })?;
        Ok(())
    }

    /// Proptest strategies that only generate values accepted by
    /// [`SimpleOpDef::new`] (no custom code).
    mod proptest {
        use std::sync::Weak;

        use super::SimpleOpDef;
        use ::proptest::prelude::*;

        use crate::package::Package;
        use crate::{
            builder::test::simple_dfg_hugr,
            extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc, op_def::LowerFunc},
            types::PolyFuncTypeRV,
        };

        // Only plain type schemes are generated.
        impl Arbitrary for SignatureFunc {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;
            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                any::<PolyFuncTypeRV>()
                    .prop_map(SignatureFunc::PolyFuncType)
                    .boxed()
            }
        }

        // Only fixed lowerings, always wrapping the same simple DFG Hugr; only
        // the extension set varies.
        impl Arbitrary for LowerFunc {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;
            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                any::<ExtensionSet>()
                    .prop_map(|extensions| LowerFunc::FixedHugr {
                        extensions,
                        pkg: Box::new(Package::from_hugr(simple_dfg_hugr())),
                    })
                    .boxed()
            }
        }

        impl Arbitrary for SimpleOpDef {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;
            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                use crate::proptest::{any_serde_json_value, any_smolstr, any_string};
                use proptest::collection::{hash_map, vec};
                let misc = hash_map(any_string(), any_serde_json_value(), 0..3);
                (
                    any::<ExtensionId>(),
                    any_smolstr(),
                    any_string(),
                    misc,
                    any::<SignatureFunc>(),
                    vec(any::<LowerFunc>(), 0..2),
                )
                    .prop_map(
                        |(extension, name, description, misc, signature_func, lower_funcs)| {
                            Self::new(OpDef {
                                extension,
                                // Dangling weak ref: no owning extension exists here.
                                extension_ref: Weak::default(),
                                name,
                                description,
                                misc,
                                signature_func,
                                lower_funcs,
                                // SimpleOpDef::new requires no constant folder.
                                constant_folder: None,
                            })
                        },
                    )
                    .boxed()
            }
        }
    }
}