1mod custom;
4
5use std::borrow::Cow;
6use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher};
8
9use super::{NamedOp, OpName, OpTrait, StaticTag};
10use super::{OpTag, OpType};
11use crate::extension::ExtensionSet;
12use crate::types::{CustomType, EdgeKind, Signature, SumType, SumTypeError, Type, TypeRow};
13use crate::{Hugr, HugrView};
14
15use delegate::delegate;
16use itertools::Itertools;
17use serde::{Deserialize, Serialize};
18use smol_str::SmolStr;
19use thiserror::Error;
20
21pub use custom::{
22 downcast_equal_consts, get_pair_of_input_values, get_single_input_value, CustomConst,
23 CustomSerialized, TryHash,
24};
25
/// A node operation holding a constant [`Value`].
///
/// Its static output edge carries the type of the wrapped value (see the
/// `OpTrait::static_output` implementation for this type).
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Const {
    /// The [`Value`] of the constant (serialized under the key `"v"`).
    #[serde(rename = "v")]
    pub value: Value,
}
37
38impl Const {
39 pub fn new(value: Value) -> Self {
41 Self { value }
42 }
43
44 pub fn value(&self) -> &Value {
46 &self.value
47 }
48
49 delegate! {
50 to self.value {
51 pub fn get_type(&self) -> Type;
53 pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T>;
56
57 pub fn validate(&self) -> Result<(), ConstTypeError>;
59 }
60 }
61}
62
63impl From<Value> for Const {
64 fn from(value: Value) -> Self {
65 Self::new(value)
66 }
67}
68
69impl NamedOp for Const {
70 fn name(&self) -> OpName {
71 self.value().name()
72 }
73}
74
75impl StaticTag for Const {
76 const TAG: OpTag = OpTag::Const;
77}
78
79impl OpTrait for Const {
80 fn description(&self) -> &str {
81 "Constant value"
82 }
83
84 fn extension_delta(&self) -> ExtensionSet {
85 self.value().extension_reqs()
86 }
87
88 fn tag(&self) -> OpTag {
89 <Self as StaticTag>::TAG
90 }
91
92 fn static_output(&self) -> Option<EdgeKind> {
93 Some(EdgeKind::Const(self.get_type()))
94 }
95
96 }
98
99impl From<Const> for Value {
100 fn from(konst: Const) -> Self {
101 konst.value
102 }
103}
104
105impl AsRef<Value> for Const {
106 fn as_ref(&self) -> &Value {
107 self.value()
108 }
109}
110
/// Serialized form of [`Sum`]; [`Sum`] converts to and from it via
/// `#[serde(try_from = "SerialSum", into = "SerialSum")]`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct SerialSum {
    /// Index of the selected variant; 0 when absent from the input.
    #[serde(default)]
    tag: usize,
    /// Values of the selected variant (key `"vs"`).
    #[serde(rename = "vs")]
    values: Vec<Value>,
    /// The full sum type (key `"typ"`). It may be omitted on input only for
    /// tag 0, in which case a tuple type is inferred from `values`.
    #[serde(default, rename = "typ")]
    sum_type: Option<SumType>,
}
120
/// A Sum constant: a tag selecting one variant, that variant's values, and
/// the full sum type.
///
/// (De)serialized through [`SerialSum`].
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(try_from = "SerialSum")]
#[serde(into = "SerialSum")]
pub struct Sum {
    /// The index of the selected variant.
    pub tag: usize,
    /// The values of the selected variant. A variant is a row of values,
    /// hence a `Vec`.
    pub values: Vec<Value>,
    /// The full type of the Sum, including the variants that were not
    /// selected.
    pub sum_type: SumType,
}
136
137impl Sum {
138 pub fn as_tuple(&self) -> Option<&[Value]> {
140 self.sum_type.as_tuple().map(|_| self.values.as_ref())
142 }
143
144 fn try_hash<H: Hasher>(&self, st: &mut H) -> bool {
145 maybe_hash_values(&self.values, st) && {
146 st.write_usize(self.tag);
147 self.sum_type.hash(st);
148 true
149 }
150 }
151}
152
153pub(crate) fn maybe_hash_values<H: Hasher>(vals: &[Value], st: &mut H) -> bool {
154 let mut hasher = DefaultHasher::new();
157 vals.iter().all(|e| e.try_hash(&mut hasher)) && {
158 st.write_u64(hasher.finish());
159 true
160 }
161}
162
163impl TryFrom<SerialSum> for Sum {
164 type Error = &'static str;
165
166 fn try_from(value: SerialSum) -> Result<Self, Self::Error> {
167 let SerialSum {
168 tag,
169 values,
170 sum_type,
171 } = value;
172
173 let sum_type = if let Some(sum_type) = sum_type {
174 sum_type
175 } else {
176 if tag != 0 {
177 return Err("Sum type must be provided if tag is not 0");
178 }
179 SumType::new_tuple(values.iter().map(Value::get_type).collect_vec())
180 };
181
182 Ok(Self {
183 tag,
184 values,
185 sum_type,
186 })
187 }
188}
189
190impl From<Sum> for SerialSum {
191 fn from(value: Sum) -> Self {
192 Self {
193 tag: value.tag,
194 values: value.values,
195 sum_type: Some(value.sum_type),
196 }
197 }
198}
199
/// A value that can be stored as a static constant: an extension constant, a
/// function defined by a Hugr, or a Sum.
///
/// Serialized as an internally tagged enum, with the variant name under the
/// key `"v"`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "v")]
pub enum Value {
    /// An extension constant value, wrapping a [`CustomConst`].
    Extension {
        /// The custom constant value; its fields are flattened into the
        /// serialized object.
        #[serde(flatten)]
        e: OpaqueValue,
    },
    /// A higher-order function value, defined by a Hugr.
    Function {
        /// The Hugr defining the function. [`Value::function`] checks that it
        /// defines a monomorphic function; constructing this variant directly
        /// skips that check.
        hugr: Box<Hugr>,
    },
    /// A Sum constant (see [`Sum`]).
    ///
    /// When deserializing, the tag `"Tuple"` is also accepted for this variant.
    #[serde(alias = "Tuple")]
    Sum(Sum),
}
222
223#[cfg_attr(not(miri), doc = "```")] #[cfg_attr(miri, doc = "```ignore")]
246#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct OpaqueValue {
274 #[serde(flatten, with = "self::custom::serde_extension_value")]
275 v: Box<dyn CustomConst>,
276}
277
278impl OpaqueValue {
279 pub fn new(cc: impl CustomConst) -> Self {
281 Self { v: Box::new(cc) }
282 }
283
284 pub fn value(&self) -> &dyn CustomConst {
286 self.v.as_ref()
287 }
288
289 pub(crate) fn value_mut(&mut self) -> &mut dyn CustomConst {
291 self.v.as_mut()
292 }
293
294 delegate! {
295 to self.value() {
296 pub fn get_type(&self) -> Type;
298 pub fn name(&self) -> ValueName;
300 pub fn extension_reqs(&self) -> ExtensionSet;
302 }
303 }
304}
305
306impl<CC: CustomConst> From<CC> for OpaqueValue {
307 fn from(x: CC) -> Self {
308 Self::new(x)
309 }
310}
311
312impl From<Box<dyn CustomConst>> for OpaqueValue {
313 fn from(value: Box<dyn CustomConst>) -> Self {
314 Self { v: value }
315 }
316}
317
318impl PartialEq for OpaqueValue {
319 fn eq(&self, other: &Self) -> bool {
320 self.value().equal_consts(other.value())
321 }
322}
323
/// Errors raised when checking a custom (extension) constant.
#[derive(Clone, Debug, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum CustomCheckFailure {
    /// The value's type differs from the expected custom type.
    #[error("Expected type: {expected} but value was of type: {found}")]
    TypeMismatch {
        /// The expected custom type.
        expected: CustomType,
        /// The type the value actually has.
        found: Type,
    },
    /// Any other failure, described by a free-form message.
    #[error("{0}")]
    Message(String),
}
340
/// Errors found when checking that a constant value is well-formed or matches
/// its type.
#[derive(Clone, Debug, PartialEq, Error)]
#[non_exhaustive]
pub enum ConstTypeError {
    /// A Sum constant's tag or values were rejected by its sum type.
    #[error("{0}")]
    SumType(#[from] SumTypeError),
    /// The Hugr given for a function constant does not define a monomorphic
    /// function.
    #[error(
        "A function constant cannot be defined using a Hugr with root of type {hugr_root_type}. Must be a monomorphic function.",
    )]
    NotMonomorphicFunction {
        /// The operation at the root of the offending Hugr.
        hugr_root_type: OpType,
    },
    /// A value does not match the expected type (expected type, value).
    #[error("Value {1:?} does not match expected type {0}")]
    ConstCheckFail(Type, Value),
    /// A custom constant failed its own check.
    #[error("Error when checking custom type: {0}")]
    CustomCheckFail(#[from] CustomCheckFailure),
}
363
364fn mono_fn_type(h: &Hugr) -> Result<Cow<'_, Signature>, ConstTypeError> {
366 let err = || ConstTypeError::NotMonomorphicFunction {
367 hugr_root_type: h.root_type().clone(),
368 };
369 if let Some(pf) = h.poly_func_type() {
370 match pf.try_into() {
371 Ok(sig) => return Ok(Cow::Owned(sig)),
372 Err(_) => return Err(err()),
373 };
374 }
375
376 h.inner_function_type().ok_or_else(err)
377}
378
379impl Value {
380 pub fn get_type(&self) -> Type {
382 match self {
383 Self::Extension { e } => e.get_type(),
384 Self::Sum(Sum { sum_type, .. }) => sum_type.clone().into(),
385 Self::Function { hugr } => {
386 let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e));
387 Type::new_function(func_type.into_owned())
388 }
389 }
390 }
391
392 pub fn sum(
396 tag: usize,
397 items: impl IntoIterator<Item = Value>,
398 typ: SumType,
399 ) -> Result<Self, ConstTypeError> {
400 let values: Vec<Value> = items.into_iter().collect();
401 typ.check_type(tag, &values)?;
402 Ok(Self::Sum(Sum {
403 tag,
404 values,
405 sum_type: typ,
406 }))
407 }
408
409 pub fn tuple(items: impl IntoIterator<Item = Value>) -> Self {
411 let vs = items.into_iter().collect_vec();
412 let tys = vs.iter().map(Self::get_type).collect_vec();
413
414 Self::sum(0, vs, SumType::new_tuple(tys)).expect("Tuple type is valid")
415 }
416
417 pub fn function(hugr: impl Into<Hugr>) -> Result<Self, ConstTypeError> {
423 let hugr = hugr.into();
424 mono_fn_type(&hugr)?;
425 Ok(Self::Function {
426 hugr: Box::new(hugr),
427 })
428 }
429
430 pub const fn unit() -> Self {
432 Self::Sum(Sum {
433 tag: 0,
434 values: vec![],
435 sum_type: SumType::Unit { size: 1 },
436 })
437 }
438
439 pub fn unit_sum(tag: usize, size: u8) -> Result<Self, ConstTypeError> {
441 Self::sum(tag, [], SumType::Unit { size })
442 }
443
444 pub fn unary_unit_sum() -> Self {
446 Self::unit_sum(0, 1).expect("0 < 1")
447 }
448
449 pub fn true_val() -> Self {
451 Self::unit_sum(1, 2).expect("1 < 2")
452 }
453
454 pub fn false_val() -> Self {
456 Self::unit_sum(0, 2).expect("0 < 2")
457 }
458
459 pub fn some<V: Into<Value>>(values: impl IntoIterator<Item = V>) -> Self {
462 let values: Vec<Value> = values.into_iter().map(Into::into).collect_vec();
463 let value_types: Vec<Type> = values.iter().map(|v| v.get_type()).collect_vec();
464 let sum_type = SumType::new_option(value_types);
465 Self::sum(1, values, sum_type).unwrap()
466 }
467
468 pub fn none(value_types: impl Into<TypeRow>) -> Self {
471 Self::sum(0, [], SumType::new_option(value_types)).unwrap()
472 }
473
474 pub fn from_bool(b: bool) -> Self {
478 if b {
479 Self::true_val()
480 } else {
481 Self::false_val()
482 }
483 }
484
485 pub fn extension(custom_const: impl CustomConst) -> Self {
487 Self::Extension {
488 e: OpaqueValue::new(custom_const),
489 }
490 }
491
492 pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
494 if let Self::Extension { e } = self {
495 e.v.downcast_ref()
496 } else {
497 None
498 }
499 }
500
501 fn name(&self) -> OpName {
502 match self {
503 Self::Extension { e } => format!("const:custom:{}", e.name()),
504 Self::Function { hugr: h } => {
505 let Ok(t) = mono_fn_type(h) else {
506 panic!("HUGR root node isn't a valid function parent.");
507 };
508 format!("const:function:[{}]", t)
509 }
510 Self::Sum(Sum {
511 tag,
512 values,
513 sum_type,
514 }) => {
515 if sum_type.as_tuple().is_some() {
516 let names: Vec<_> = values.iter().map(Value::name).collect();
517 format!("const:seq:{{{}}}", names.iter().join(", "))
518 } else {
519 format!("const:sum:{{tag:{tag}, vals:{values:?}}}")
520 }
521 }
522 }
523 .into()
524 }
525
526 pub fn extension_reqs(&self) -> ExtensionSet {
528 match self {
529 Self::Extension { e } => e.extension_reqs().clone(),
530 Self::Function { .. } => ExtensionSet::new(), Self::Sum(Sum { values, .. }) => {
532 ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs()))
533 }
534 }
535 }
536
537 pub fn validate(&self) -> Result<(), ConstTypeError> {
539 match self {
540 Self::Extension { e } => Ok(e.value().validate()?),
541 Self::Function { hugr } => {
542 mono_fn_type(hugr)?;
543 Ok(())
544 }
545 Self::Sum(Sum {
546 tag,
547 values,
548 sum_type,
549 }) => {
550 sum_type.check_type(*tag, values)?;
551 Ok(())
552 }
553 }
554 }
555
556 pub fn as_tuple(&self) -> Option<&[Value]> {
558 if let Self::Sum(sum) = self {
559 sum.as_tuple()
560 } else {
561 None
562 }
563 }
564
565 pub fn try_hash<H: Hasher>(&self, st: &mut H) -> bool {
569 match self {
570 Value::Extension { e } => e.value().try_hash(&mut *st),
571 Value::Function { .. } => false,
572 Value::Sum(s) => s.try_hash(st),
573 }
574 }
575}
576
577impl<T> From<T> for Value
578where
579 T: CustomConst,
580{
581 fn from(value: T) -> Self {
582 Self::extension(value)
583 }
584}
585
/// The name of a constant value, e.g. as returned by [`OpaqueValue::name`].
pub type ValueName = SmolStr;

/// The borrowed (string slice) counterpart of [`ValueName`].
pub type ValueNameRef = str;
591
#[cfg(test)]
pub(crate) mod test {
    use std::collections::HashSet;
    use std::sync::{Arc, Weak};

    use super::Value;
    use crate::builder::inout_sig;
    use crate::builder::test::simple_dfg_hugr;
    use crate::extension::prelude::{bool_t, usize_custom_t};
    use crate::extension::resolution::{
        resolve_custom_type_extensions, resolve_typearg_extensions, ExtensionResolutionError,
        WeakExtensionRegistry,
    };
    use crate::extension::PRELUDE;
    use crate::std_extensions::arithmetic::int_types::ConstInt;
    use crate::std_extensions::collections::array::{array_type, ArrayValue};
    use crate::{
        builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
        extension::{
            prelude::{usize_t, ConstUsize},
            ExtensionId,
        },
        std_extensions::arithmetic::float_types::{float64_type, ConstF64},
        type_row,
        types::type_param::TypeArg,
        types::{Type, TypeBound, TypeRow},
    };
    use cool_asserts::assert_matches;
    use rstest::{fixture, rstest};

    use super::*;

    /// A test-only [`CustomConst`] whose type is the wrapped [`CustomType`].
    #[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
    pub(crate) struct CustomTestValue(pub CustomType);

    #[typetag::serde]
    impl CustomConst for CustomTestValue {
        fn name(&self) -> ValueName {
            format!("CustomTestValue({:?})", self.0).into()
        }

        fn extension_reqs(&self) -> ExtensionSet {
            ExtensionSet::singleton(self.0.extension().clone())
        }

        fn update_extensions(
            &mut self,
            extensions: &WeakExtensionRegistry,
        ) -> Result<(), ExtensionResolutionError> {
            resolve_custom_type_extensions(&mut self.0, extensions)?;
            // Also resolve each type argument individually, exercising
            // `resolve_typearg_extensions`. NOTE(review): possibly redundant
            // with the call above — confirm.
            for arg in self.0.args_mut() {
                resolve_typearg_extensions(arg, extensions)?;
            }
            Ok(())
        }

        fn get_type(&self) -> Type {
            self.0.clone().into()
        }

        fn equal_consts(&self, other: &dyn CustomConst) -> bool {
            crate::ops::constant::downcast_equal_consts(self, other)
        }
    }

    /// Returns `f` as a float constant, converted to its [`CustomSerialized`]
    /// form and wrapped in a [`Value`].
    pub(crate) fn serialized_float(f: f64) -> Value {
        CustomSerialized::try_from_custom_const(ConstF64::new(f))
            .unwrap()
            .into()
    }

    /// Builds DFGs that load well-typed Sum constants of both variants.
    #[test]
    fn test_sum() -> Result<(), BuildError> {
        use crate::builder::Container;
        let pred_rows = vec![vec![usize_t(), float64_type()].into(), Type::EMPTY_TYPEROW];
        let pred_ty = SumType::new(pred_rows.clone());

        let mut b = DFGBuilder::new(inout_sig(
            type_row![],
            TypeRow::from(vec![pred_ty.clone().into()]),
        ))?;
        let usize_custom_t = usize_custom_t(&Arc::downgrade(&PRELUDE));
        let c = b.add_constant(Value::sum(
            0,
            [
                CustomTestValue(usize_custom_t.clone()).into(),
                ConstF64::new(5.1).into(),
            ],
            pred_ty.clone(),
        )?);
        let w = b.load_const(&c);
        b.finish_hugr_with_outputs([w]).unwrap();

        let mut b = DFGBuilder::new(Signature::new(
            type_row![],
            TypeRow::from(vec![pred_ty.clone().into()]),
        ))?;
        let c = b.add_constant(Value::sum(1, [], pred_ty.clone())?);
        let w = b.load_const(&c);
        b.finish_hugr_with_outputs([w]).unwrap();

        Ok(())
    }

    /// `Value::sum` reports wrong variant length, invalid tag, and wrong value
    /// type as distinct `SumTypeError`s.
    #[test]
    fn test_bad_sum() {
        let pred_ty = SumType::new([vec![usize_t(), float64_type()].into(), type_row![]]);

        let good_sum = const_usize();
        println!("{}", serde_json::to_string_pretty(&good_sum).unwrap());

        let good_sum =
            Value::sum(0, [const_usize(), serialized_float(5.1)], pred_ty.clone()).unwrap();
        println!("{}", serde_json::to_string_pretty(&good_sum).unwrap());

        let res = Value::sum(0, [], pred_ty.clone());
        assert_matches!(
            res,
            Err(ConstTypeError::SumType(SumTypeError::WrongVariantLength {
                tag: 0,
                expected: 2,
                found: 0
            }))
        );

        let res = Value::sum(4, [], pred_ty.clone());
        assert_matches!(
            res,
            Err(ConstTypeError::SumType(SumTypeError::InvalidTag {
                tag: 4,
                num_variants: 2
            }))
        );

        let res = Value::sum(0, [const_usize(), const_usize()], pred_ty);
        assert_matches!(
            res,
            Err(ConstTypeError::SumType(SumTypeError::InvalidValueType {
                tag: 0,
                index: 1,
                expected,
                found,
            })) if expected == float64_type() && found == const_usize()
        );
    }

    /// A function constant built from a simple DFG Hugr has the DFG's
    /// signature as its type.
    #[rstest]
    fn function_value(simple_dfg_hugr: Hugr) {
        let v = Value::function(simple_dfg_hugr).unwrap();

        let correct_type = Type::new_function(Signature::new_endo(vec![bool_t()]));

        assert_eq!(v.get_type(), correct_type);
        assert!(v.name().starts_with("const:function:"))
    }

    #[fixture]
    fn const_usize() -> Value {
        ConstUsize::new(257).into()
    }

    /// The same usize constant as `const_usize`, in `CustomSerialized` form.
    #[fixture]
    fn const_serialized_usize() -> Value {
        CustomSerialized::try_from_custom_const(ConstUsize::new(257))
            .unwrap()
            .into()
    }

    #[fixture]
    fn const_tuple() -> Value {
        Value::tuple([const_usize(), Value::true_val()])
    }

    /// Like `const_tuple`, but with the usize element in `CustomSerialized`
    /// form.
    #[fixture]
    fn const_tuple_serialized() -> Value {
        Value::tuple([const_serialized_usize(), Value::true_val()])
    }

    #[fixture]
    fn const_array_bool() -> Value {
        ArrayValue::new(bool_t(), [Value::true_val(), Value::false_val()]).into()
    }

    #[fixture]
    fn const_array_options() -> Value {
        let some_true = Value::some([Value::true_val()]);
        let none = Value::none(vec![bool_t()]);
        let elem_ty = SumType::new_option(vec![bool_t()]);
        ArrayValue::new(elem_ty.into(), [some_true, none]).into()
    }

    /// Each constant reports the expected type and a name with the expected
    /// prefix.
    #[rstest]
    #[case(Value::unit(), Type::UNIT, "const:seq:{}")]
    #[case(const_usize(), usize_t(), "const:custom:ConstUsize(")]
    #[case(serialized_float(17.4), float64_type(), "const:custom:json:Object")]
    #[case(const_tuple(), Type::new_tuple(vec![usize_t(), bool_t()]), "const:seq:{")]
    #[case(const_array_bool(), array_type(2, bool_t()), "const:custom:array")]
    #[case(
        const_array_options(),
        array_type(2, SumType::new_option(vec![bool_t()]).into()),
        "const:custom:array"
    )]
    fn const_type(
        #[case] const_value: Value,
        #[case] expected_type: Type,
        #[case] name_prefix: &str,
    ) {
        assert_eq!(const_value.get_type(), expected_type);
        let name = const_value.name();
        assert!(
            name.starts_with(name_prefix),
            "{name} does not start with {name_prefix}"
        );
    }

    /// Serializing then deserializing yields `expected_value`. Serialized-form
    /// constants are expected to come back as their concrete counterparts.
    #[rstest]
    #[case(Value::unit(), Value::unit())]
    #[case(const_usize(), const_usize())]
    #[case(const_serialized_usize(), const_usize())]
    #[case(const_tuple_serialized(), const_tuple())]
    #[case(const_array_bool(), const_array_bool())]
    #[case(const_array_options(), const_array_options())]
    // Ignored under Miri. NOTE(review): presumably because the typetag-based
    // (de)serialization is not supported there — confirm.
    #[cfg_attr(miri, ignore)]
    fn const_serde_roundtrip(#[case] const_value: Value, #[case] expected_value: Value) {
        let serialized = serde_json::to_string(&const_value).unwrap();
        let deserialized: Value = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized, expected_value);
    }

    /// `get_custom_value` only succeeds for an extension value of exactly the
    /// requested type.
    #[rstest]
    fn const_custom_value(const_usize: Value, const_tuple: Value) {
        assert_eq!(
            const_usize.get_custom_value::<ConstUsize>(),
            Some(&ConstUsize::new(257))
        );
        assert_eq!(const_usize.get_custom_value::<ConstInt>(), None);
        assert_eq!(const_tuple.get_custom_value::<ConstUsize>(), None);
        assert_eq!(const_tuple.get_custom_value::<ConstInt>(), None);
    }

    /// A `CustomSerialized` constant reports the custom type it was built with,
    /// and differs from a type with the same name but different arguments.
    #[test]
    fn test_json_const() {
        let ex_id: ExtensionId = "my_extension".try_into().unwrap();
        let typ_int = CustomType::new(
            "my_type",
            vec![TypeArg::BoundedNat { n: 8 }],
            ex_id.clone(),
            TypeBound::Copyable,
            &Weak::default(),
        );
        let json_const: Value =
            CustomSerialized::new(typ_int.clone(), 6.into(), ex_id.clone()).into();
        let classic_t = Type::new_extension(typ_int.clone());
        assert_matches!(classic_t.least_upper_bound(), TypeBound::Copyable);
        assert_eq!(json_const.get_type(), classic_t);

        let typ_qb = CustomType::new(
            "my_type",
            vec![],
            ex_id,
            TypeBound::Copyable,
            &Weak::default(),
        );
        let t = Type::new_extension(typ_qb.clone());
        assert_ne!(json_const.get_type(), t);
    }

    /// All of these values are hashable and hash to pairwise-distinct results.
    #[rstest]
    fn hash_tuple(const_tuple: Value) {
        let vals = [
            Value::unit(),
            Value::true_val(),
            Value::false_val(),
            ConstUsize::new(13).into(),
            Value::tuple([ConstUsize::new(13).into()]),
            Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(14).into()]),
            Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(15).into()]),
            const_tuple,
        ];

        let num_vals = vals.len();
        let hashes = vals.map(|v| {
            let mut h = DefaultHasher::new();
            // Panics if `v` is not hashable.
            v.try_hash(&mut h).then_some(()).unwrap();
            h.finish()
        });
        // All hashes are distinct.
        assert_eq!(HashSet::from(hashes).len(), num_vals);
    }

    /// A tuple containing an unhashable element (here a float constant) is
    /// unhashable, and the failed attempt leaves the hasher untouched.
    #[test]
    fn unhashable_tuple() {
        let tup = Value::tuple([ConstUsize::new(5).into(), ConstF64::new(4.97).into()]);
        let mut h1 = DefaultHasher::new();
        let r = tup.try_hash(&mut h1);
        assert!(!r);

        // If the failed try_hash wrote nothing into h1, then h1 and a fresh
        // h2 fed the same input must finish identically.
        h1.write_usize(5);
        let mut h2 = DefaultHasher::new();
        h2.write_usize(5);
        assert_eq!(h1.finish(), h2.finish());
    }

    mod proptest {
        use super::super::{OpaqueValue, Sum};
        use crate::{
            ops::{constant::CustomSerialized, Value},
            std_extensions::arithmetic::int_types::ConstInt,
            std_extensions::collections::list::ListValue,
            types::{SumType, Type},
        };
        use ::proptest::{collection::vec, prelude::*};
        impl Arbitrary for OpaqueValue {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;
            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                // Leaves: integer constants and serialized custom constants.
                prop_oneof![
                    any::<ConstInt>().prop_map_into(),
                    any::<CustomSerialized>().prop_map_into()
                ]
                .prop_recursive(
                    3,  // depth: at most 3 levels of recursion
                    32, // desired_size: aim for around 32 nodes in total
                    3,  // expected_branch_size: expected children per node
                    |child_strat| {
                        // Lists of 0..3 (i.e. 0 to 2) children. NOTE: the element
                        // type `typ` is generated independently of the children,
                        // so generated lists may be ill-typed.
                        (any::<Type>(), vec(child_strat, 0..3)).prop_map(|(typ, children)| {
                            Self::new(ListValue::new(
                                typ,
                                children.into_iter().map(|e| Value::Extension { e }),
                            ))
                        })
                    },
                )
                .boxed()
            }
        }

        impl Arbitrary for Value {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;
            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                use ::proptest::collection::vec;
                let leaf_strat = prop_oneof![
                    any::<OpaqueValue>().prop_map(|e| Self::Extension { e }),
                    crate::proptest::any_hugr().prop_map(|x| Value::function(x).unwrap())
                ];
                leaf_strat
                    .prop_recursive(
                        3,  // depth: at most 3 levels of recursion
                        32, // desired_size: aim for around 32 nodes in total
                        3,  // expected_branch_size: expected children per node
                        |element| {
                            prop_oneof![
                                vec(element.clone(), 0..3).prop_map(Self::tuple),
                                (
                                    any::<usize>(),
                                    vec(element.clone(), 0..3),
                                    // NOTE(review): the meaning of the `1` parameter is
                                    // defined by SumType's Arbitrary impl — confirm.
                                    any_with::<SumType>(1.into())
                                )
                                    // Tag and values are not checked against
                                    // `sum_type`, so these Sums may be ill-typed.
                                    .prop_map(
                                        |(tag, values, sum_type)| {
                                            Self::Sum(Sum {
                                                tag,
                                                values,
                                                sum_type,
                                            })
                                        }
                                    ),
                            ]
                        },
                    )
                    .boxed()
            }
        }
    }

    /// JSON using the `"Tuple"` variant tag deserializes to the equivalent
    /// `Value::tuple`.
    #[test]
    fn test_tuple_deserialize() {
        let json = r#"
        {
            "v": "Tuple",
            "vs": [
                {
                    "v": "Sum",
                    "tag": 0,
                    "typ": {
                        "t": "Sum",
                        "s": "Unit",
                        "size": 1
                    },
                    "vs": []
                },
                {
                    "v": "Sum",
                    "tag": 1,
                    "typ": {
                        "t": "Sum",
                        "s": "General",
                        "rows": [
                            [
                                {
                                    "t": "Sum",
                                    "s": "Unit",
                                    "size": 1
                                }
                            ],
                            [
                                {
                                    "t": "Sum",
                                    "s": "Unit",
                                    "size": 2
                                }
                            ]
                        ]
                    },
                    "vs": [
                        {
                            "v": "Sum",
                            "tag": 1,
                            "typ": {
                                "t": "Sum",
                                "s": "Unit",
                                "size": 2
                            },
                            "vs": []
                        }
                    ]
                }
            ]
}
        "#;

        let v: Value = serde_json::from_str(json).unwrap();
        assert_eq!(
            v,
            Value::tuple([
                Value::unit(),
                Value::sum(
                    1,
                    [Value::true_val()],
                    SumType::new([
                        type_row![Type::UNIT],
                        vec![Value::true_val().get_type()].into()
                    ]),
                )
                .unwrap()
            ])
        );
    }
}