hugr_core/ops/
constant.rs

1//! Constant value definitions.
2
3mod custom;
4
5use std::borrow::Cow;
6use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76.
7use std::hash::{Hash, Hasher};
8
9use super::{NamedOp, OpName, OpTrait, StaticTag};
10use super::{OpTag, OpType};
11use crate::envelope::serde_with::AsStringEnvelope;
12use crate::types::{CustomType, EdgeKind, Signature, SumType, SumTypeError, Type, TypeRow};
13use crate::{Hugr, HugrView};
14
15use delegate::delegate;
16use itertools::Itertools;
17use serde::{Deserialize, Serialize};
18use serde_with::serde_as;
19use smol_str::SmolStr;
20use thiserror::Error;
21
22pub use custom::{
23    downcast_equal_consts, get_pair_of_input_values, get_single_input_value, CustomConst,
24    CustomSerialized, TryHash,
25};
26
27#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
28/// An operation returning a constant value.
29///
30/// Represents core types and extension types.
31#[non_exhaustive]
32#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
33pub struct Const {
34    /// The [Value] of the constant.
35    #[serde(rename = "v")]
36    pub value: Value,
37}
38
39impl Const {
40    /// Create a new [`Const`] operation.
41    pub fn new(value: Value) -> Self {
42        Self { value }
43    }
44
45    /// The inner value of the [`Const`]
46    pub fn value(&self) -> &Value {
47        &self.value
48    }
49
50    delegate! {
51        to self.value {
52            /// Returns the type of this constant.
53            pub fn get_type(&self) -> Type;
54            /// For a Const holding a CustomConst, extract the CustomConst by
55            /// downcasting.
56            pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T>;
57
58            /// Check the value.
59            pub fn validate(&self) -> Result<(), ConstTypeError>;
60        }
61    }
62}
63
64impl From<Value> for Const {
65    fn from(value: Value) -> Self {
66        Self::new(value)
67    }
68}
69
70impl NamedOp for Const {
71    fn name(&self) -> OpName {
72        self.value().name()
73    }
74}
75
76impl StaticTag for Const {
77    const TAG: OpTag = OpTag::Const;
78}
79
80impl OpTrait for Const {
81    fn description(&self) -> &str {
82        "Constant value"
83    }
84
85    fn tag(&self) -> OpTag {
86        <Self as StaticTag>::TAG
87    }
88
89    fn static_output(&self) -> Option<EdgeKind> {
90        Some(EdgeKind::Const(self.get_type()))
91    }
92
93    // Constants cannot refer to TypeArgs of the enclosing Hugr, so no substitute().
94}
95
96impl From<Const> for Value {
97    fn from(konst: Const) -> Self {
98        konst.value
99    }
100}
101
102impl AsRef<Value> for Const {
103    fn as_ref(&self) -> &Value {
104        self.value()
105    }
106}
107
108#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
109struct SerialSum {
110    #[serde(default)]
111    tag: usize,
112    #[serde(rename = "vs")]
113    values: Vec<Value>,
114    #[serde(default, rename = "typ")]
115    sum_type: Option<SumType>,
116}
117
118#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
119#[serde(try_from = "SerialSum")]
120#[serde(into = "SerialSum")]
121/// A Sum variant, with a tag indicating the index of the variant and its
122/// value.
123pub struct Sum {
124    /// The tag index of the variant.
125    pub tag: usize,
126    /// The value of the variant.
127    ///
128    /// Sum variants are always a row of values, hence the Vec.
129    pub values: Vec<Value>,
130    /// The full type of the Sum, including the other variants.
131    pub sum_type: SumType,
132}
133
134impl Sum {
135    /// If value is a sum with a single row variant, return the row.
136    pub fn as_tuple(&self) -> Option<&[Value]> {
137        // For valid instances, the type row will not have any row variables.
138        self.sum_type.as_tuple().map(|_| self.values.as_ref())
139    }
140
141    fn try_hash<H: Hasher>(&self, st: &mut H) -> bool {
142        maybe_hash_values(&self.values, st) && {
143            st.write_usize(self.tag);
144            self.sum_type.hash(st);
145            true
146        }
147    }
148}
149
150pub(crate) fn maybe_hash_values<H: Hasher>(vals: &[Value], st: &mut H) -> bool {
151    // We can't mutate the Hasher with the first element
152    // if any element, even the last, fails.
153    let mut hasher = DefaultHasher::new();
154    vals.iter().all(|e| e.try_hash(&mut hasher)) && {
155        st.write_u64(hasher.finish());
156        true
157    }
158}
159
160impl TryFrom<SerialSum> for Sum {
161    type Error = &'static str;
162
163    fn try_from(value: SerialSum) -> Result<Self, Self::Error> {
164        let SerialSum {
165            tag,
166            values,
167            sum_type,
168        } = value;
169
170        let sum_type = if let Some(sum_type) = sum_type {
171            sum_type
172        } else {
173            if tag != 0 {
174                return Err("Sum type must be provided if tag is not 0");
175            }
176            SumType::new_tuple(values.iter().map(Value::get_type).collect_vec())
177        };
178
179        Ok(Self {
180            tag,
181            values,
182            sum_type,
183        })
184    }
185}
186
187impl From<Sum> for SerialSum {
188    fn from(value: Sum) -> Self {
189        Self {
190            tag: value.tag,
191            values: value.values,
192            sum_type: Some(value.sum_type),
193        }
194    }
195}
196
197#[serde_as]
198#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
199#[serde(tag = "v")]
200/// A value that can be stored as a static constant. Representing core types and
201/// extension types.
202pub enum Value {
203    /// An extension constant value, that can check it is of a given [CustomType].
204    Extension {
205        #[serde(flatten)]
206        /// The custom constant value.
207        e: OpaqueValue,
208    },
209    /// A higher-order function value.
210    Function {
211        /// A Hugr defining the function.
212        #[serde_as(as = "Box<AsStringEnvelope>")]
213        hugr: Box<Hugr>,
214    },
215    /// A Sum variant, with a tag indicating the index of the variant and its
216    /// value.
217    #[serde(alias = "Tuple")]
218    Sum(Sum),
219}
220
/// An opaque newtype around a [`Box<dyn CustomConst>`](CustomConst).
///
/// This type has special serialization behaviour in order to support
/// serialization and deserialization of unknown impls of [CustomConst].
///
/// During serialization we first serialize the internal [`dyn` CustomConst](CustomConst)
/// into a [serde_json::Value]. We then create a [CustomSerialized] wrapping
/// that value.  That [CustomSerialized] is then serialized in place of the
/// [OpaqueValue].
///
/// During deserialization, first we deserialize a [CustomSerialized]. We
/// attempt to deserialize the internal [serde_json::Value] using the [`Box<dyn
/// CustomConst>`](CustomConst) impl. This will fail if the appropriate `impl CustomConst`
/// is not linked into the running program, in which case we coerce the
/// [CustomSerialized] into a [`Box<dyn CustomConst>`](CustomConst). The [OpaqueValue] is
/// then produced from the [`Box<dyn CustomConst>`](CustomConst).
///
/// In the case where the internal serialized value of a `CustomSerialized`
/// is another `CustomSerialized` we do not attempt to recurse. This behaviour
/// may change in future.
///
#[cfg_attr(not(miri), doc = "```")] // this doctest depends on typetag, so fails with miri
#[cfg_attr(miri, doc = "```ignore")]
/// use serde::{Serialize,Deserialize};
/// use hugr::{
///   types::Type,ops::constant::{OpaqueValue, ValueName, CustomConst, CustomSerialized},
///   extension::{ExtensionSet, prelude::{usize_t, ConstUsize}},
///   std_extensions::arithmetic::int_types};
/// use serde_json::json;
///
/// let expected_json = json!({
///     "typ": usize_t(),
///     "value": {'c': "ConstUsize", 'v': 1}
/// });
/// let ev = OpaqueValue::new(ConstUsize::new(1));
/// assert_eq!(&serde_json::to_value(&ev).unwrap(), &expected_json);
/// assert_eq!(ev, serde_json::from_value(expected_json).unwrap());
///
/// let ev = OpaqueValue::new(CustomSerialized::new(usize_t().clone(), serde_json::Value::Null));
/// let expected_json = json!({
///     "typ": usize_t(),
///     "value": null
/// });
///
/// assert_eq!(&serde_json::to_value(ev.clone()).unwrap(), &expected_json);
/// assert_eq!(ev, serde_json::from_value(expected_json).unwrap());
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpaqueValue {
    // The boxed constant; (de)serialized through the custom
    // `serde_extension_value` module rather than a derived impl.
    #[serde(flatten, with = "self::custom::serde_extension_value")]
    v: Box<dyn CustomConst>,
}
273
274impl OpaqueValue {
275    /// Create a new [`OpaqueValue`] from any [`CustomConst`].
276    pub fn new(cc: impl CustomConst) -> Self {
277        Self { v: Box::new(cc) }
278    }
279
280    /// Returns a reference to the internal [`CustomConst`].
281    pub fn value(&self) -> &dyn CustomConst {
282        self.v.as_ref()
283    }
284
285    /// Returns a reference to the internal [`CustomConst`].
286    pub(crate) fn value_mut(&mut self) -> &mut dyn CustomConst {
287        self.v.as_mut()
288    }
289
290    delegate! {
291        to self.value() {
292            /// Returns the type of the internal [`CustomConst`].
293            pub fn get_type(&self) -> Type;
294            /// An identifier of the internal [`CustomConst`].
295            pub fn name(&self) -> ValueName;
296        }
297    }
298}
299
300impl<CC: CustomConst> From<CC> for OpaqueValue {
301    fn from(x: CC) -> Self {
302        Self::new(x)
303    }
304}
305
306impl From<Box<dyn CustomConst>> for OpaqueValue {
307    fn from(value: Box<dyn CustomConst>) -> Self {
308        Self { v: value }
309    }
310}
311
312impl PartialEq for OpaqueValue {
313    fn eq(&self, other: &Self) -> bool {
314        self.value().equal_consts(other.value())
315    }
316}
317
/// Errors reported when checking a custom constant value fails.
#[derive(Clone, Debug, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum CustomCheckFailure {
    /// The value had a specific type that was not what was expected
    #[error("Expected type: {expected} but value was of type: {found}")]
    TypeMismatch {
        /// The expected custom type.
        expected: CustomType,
        /// The custom type found when checking.
        found: Type,
    },
    /// Any other failure, described by a free-form message.
    #[error("{0}")]
    Message(String),
}
334
/// Errors that arise from typechecking constants
#[derive(Clone, Debug, PartialEq, Error)]
#[non_exhaustive]
pub enum ConstTypeError {
    /// Invalid sum type definition.
    #[error("{0}")]
    SumType(#[from] SumTypeError),
    /// The Hugr given for a function constant does not define a monomorphic
    /// function.
    #[error(
        "A function constant cannot be defined using a Hugr with root of type {hugr_root_type}. Must be a monomorphic function.",
    )]
    NotMonomorphicFunction {
        /// The root node type of the Hugr that (claims to) define the function constant.
        hugr_root_type: OpType,
    },
    /// A mismatch between the type expected and the value.
    ///
    /// The first field is the expected type, the second the offending value.
    #[error("Value {1:?} does not match expected type {0}")]
    ConstCheckFail(Type, Value),
    /// Error when checking a custom value.
    #[error("Error when checking custom type: {0}")]
    CustomCheckFail(#[from] CustomCheckFailure),
}
357
358/// Hugrs (even functions) inside Consts must be monomorphic
359fn mono_fn_type(h: &Hugr) -> Result<Cow<'_, Signature>, ConstTypeError> {
360    let err = || ConstTypeError::NotMonomorphicFunction {
361        hugr_root_type: h.entrypoint_optype().clone(),
362    };
363    if let Some(pf) = h.poly_func_type() {
364        match pf.try_into() {
365            Ok(sig) => return Ok(Cow::Owned(sig)),
366            Err(_) => return Err(err()),
367        };
368    }
369
370    h.inner_function_type().ok_or_else(err)
371}
372
373impl Value {
374    /// Returns the type of this [`Value`].
375    pub fn get_type(&self) -> Type {
376        match self {
377            Self::Extension { e } => e.get_type(),
378            Self::Sum(Sum { sum_type, .. }) => sum_type.clone().into(),
379            Self::Function { hugr } => {
380                let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e));
381                Type::new_function(func_type.into_owned())
382            }
383        }
384    }
385
386    /// Returns a Sum constant. The value is determined by `items` and is
387    /// type-checked `typ`. The `tag`th variant of `typ` should match the types
388    /// of `items`.
389    pub fn sum(
390        tag: usize,
391        items: impl IntoIterator<Item = Value>,
392        typ: SumType,
393    ) -> Result<Self, ConstTypeError> {
394        let values: Vec<Value> = items.into_iter().collect();
395        typ.check_type(tag, &values)?;
396        Ok(Self::Sum(Sum {
397            tag,
398            values,
399            sum_type: typ,
400        }))
401    }
402
403    /// Returns a tuple constant of constant values.
404    pub fn tuple(items: impl IntoIterator<Item = Value>) -> Self {
405        let vs = items.into_iter().collect_vec();
406        let tys = vs.iter().map(Self::get_type).collect_vec();
407
408        Self::sum(0, vs, SumType::new_tuple(tys)).expect("Tuple type is valid")
409    }
410
411    /// Returns a constant function defined by a Hugr.
412    ///
413    /// # Errors
414    ///
415    /// Returns an error if the Hugr root node does not define a function.
416    pub fn function(hugr: impl Into<Hugr>) -> Result<Self, ConstTypeError> {
417        let hugr = hugr.into();
418        mono_fn_type(&hugr)?;
419        Ok(Self::Function {
420            hugr: Box::new(hugr),
421        })
422    }
423
424    /// Returns a constant unit type (empty Tuple).
425    pub const fn unit() -> Self {
426        Self::Sum(Sum {
427            tag: 0,
428            values: vec![],
429            sum_type: SumType::Unit { size: 1 },
430        })
431    }
432
433    /// Returns a constant Sum over units. Used as branching values.
434    pub fn unit_sum(tag: usize, size: u8) -> Result<Self, ConstTypeError> {
435        Self::sum(tag, [], SumType::Unit { size })
436    }
437
438    /// Returns a constant Sum over units, with only one variant.
439    pub fn unary_unit_sum() -> Self {
440        Self::unit_sum(0, 1).expect("0 < 1")
441    }
442
443    /// Returns a constant "true" value, i.e. the second variant of Sum((), ()).
444    pub fn true_val() -> Self {
445        Self::unit_sum(1, 2).expect("1 < 2")
446    }
447
448    /// Returns a constant "false" value, i.e. the first variant of Sum((), ()).
449    pub fn false_val() -> Self {
450        Self::unit_sum(0, 2).expect("0 < 2")
451    }
452
453    /// Returns an optional with some values. This is a Sum with two variants, the
454    /// first being empty and the second being the values.
455    pub fn some<V: Into<Value>>(values: impl IntoIterator<Item = V>) -> Self {
456        let values: Vec<Value> = values.into_iter().map(Into::into).collect_vec();
457        let value_types: Vec<Type> = values.iter().map(|v| v.get_type()).collect_vec();
458        let sum_type = SumType::new_option(value_types);
459        Self::sum(1, values, sum_type).unwrap()
460    }
461
462    /// Returns an optional with no value. This is a Sum with two variants, the
463    /// first being empty and the second being the value.
464    pub fn none(value_types: impl Into<TypeRow>) -> Self {
465        Self::sum(0, [], SumType::new_option(value_types)).unwrap()
466    }
467
468    /// Returns a constant `bool` value.
469    ///
470    /// see [`Value::true_val`] and [`Value::false_val`].
471    pub fn from_bool(b: bool) -> Self {
472        if b {
473            Self::true_val()
474        } else {
475            Self::false_val()
476        }
477    }
478
479    /// Returns a [Value::Extension] holding `custom_const`.
480    pub fn extension(custom_const: impl CustomConst) -> Self {
481        Self::Extension {
482            e: OpaqueValue::new(custom_const),
483        }
484    }
485
486    /// For a [Value] holding a [CustomConst], extract the CustomConst by downcasting.
487    pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
488        if let Self::Extension { e } = self {
489            e.v.downcast_ref()
490        } else {
491            None
492        }
493    }
494
495    fn name(&self) -> OpName {
496        match self {
497            Self::Extension { e } => format!("const:custom:{}", e.name()),
498            Self::Function { hugr: h } => {
499                let Ok(t) = mono_fn_type(h) else {
500                    panic!("HUGR root node isn't a valid function parent.");
501                };
502                format!("const:function:[{}]", t)
503            }
504            Self::Sum(Sum {
505                tag,
506                values,
507                sum_type,
508            }) => {
509                if sum_type.as_tuple().is_some() {
510                    let names: Vec<_> = values.iter().map(Value::name).collect();
511                    format!("const:seq:{{{}}}", names.iter().join(", "))
512                } else {
513                    format!("const:sum:{{tag:{tag}, vals:{values:?}}}")
514                }
515            }
516        }
517        .into()
518    }
519
520    /// Check the value.
521    pub fn validate(&self) -> Result<(), ConstTypeError> {
522        match self {
523            Self::Extension { e } => Ok(e.value().validate()?),
524            Self::Function { hugr } => {
525                mono_fn_type(hugr)?;
526                Ok(())
527            }
528            Self::Sum(Sum {
529                tag,
530                values,
531                sum_type,
532            }) => {
533                sum_type.check_type(*tag, values)?;
534                Ok(())
535            }
536        }
537    }
538
539    /// If value is a sum with a single row variant, return the row.
540    pub fn as_tuple(&self) -> Option<&[Value]> {
541        if let Self::Sum(sum) = self {
542            sum.as_tuple()
543        } else {
544            None
545        }
546    }
547
548    /// Hashes this value, if possible. [Value::Extension]s are hashable according
549    /// to their implementation of [TryHash]; [Value::Function]s never are;
550    /// [Value::Sum]s are if their contents are.
551    pub fn try_hash<H: Hasher>(&self, st: &mut H) -> bool {
552        match self {
553            Value::Extension { e } => e.value().try_hash(&mut *st),
554            Value::Function { .. } => false,
555            Value::Sum(s) => s.try_hash(st),
556        }
557    }
558}
559
560impl<T> From<T> for Value
561where
562    T: CustomConst,
563{
564    fn from(value: T) -> Self {
565        Self::extension(value)
566    }
567}
568
/// A unique identifier for a constant value.
///
/// This is the owned form, stored as a [`SmolStr`].
pub type ValueName = SmolStr;

/// Slice of a [`ValueName`] constant value identifier.
///
/// This is the borrowed counterpart of [`ValueName`].
pub type ValueNameRef = str;
574
575#[cfg(test)]
576pub(crate) mod test {
577    use std::collections::HashSet;
578    use std::sync::{Arc, Weak};
579
580    use super::Value;
581    use crate::builder::inout_sig;
582    use crate::builder::test::simple_dfg_hugr;
583    use crate::extension::prelude::{bool_t, usize_custom_t};
584    use crate::extension::resolution::{
585        resolve_custom_type_extensions, resolve_typearg_extensions, ExtensionResolutionError,
586        WeakExtensionRegistry,
587    };
588    use crate::extension::PRELUDE;
589    use crate::std_extensions::arithmetic::int_types::ConstInt;
590    use crate::std_extensions::collections::array::{array_type, ArrayValue};
591    use crate::std_extensions::collections::value_array::{value_array_type, VArrayValue};
592    use crate::{
593        builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
594        extension::{
595            prelude::{usize_t, ConstUsize},
596            ExtensionId,
597        },
598        std_extensions::arithmetic::float_types::{float64_type, ConstF64},
599        type_row,
600        types::type_param::TypeArg,
601        types::{Type, TypeBound, TypeRow},
602    };
603    use cool_asserts::assert_matches;
604    use rstest::{fixture, rstest};
605
606    use super::*;
607
    #[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
    /// A custom constant value used in testing.
    ///
    /// It wraps the [`CustomType`] that it reports as its own type.
    pub(crate) struct CustomTestValue(pub CustomType);
611
612    #[typetag::serde]
613    impl CustomConst for CustomTestValue {
614        fn name(&self) -> ValueName {
615            format!("CustomTestValue({:?})", self.0).into()
616        }
617
618        fn update_extensions(
619            &mut self,
620            extensions: &WeakExtensionRegistry,
621        ) -> Result<(), ExtensionResolutionError> {
622            resolve_custom_type_extensions(&mut self.0, extensions)?;
623            // This loop is redundant, but we use it to test the public
624            // function.
625            for arg in self.0.args_mut() {
626                resolve_typearg_extensions(arg, extensions)?;
627            }
628            Ok(())
629        }
630
631        fn get_type(&self) -> Type {
632            self.0.clone().into()
633        }
634
635        fn equal_consts(&self, other: &dyn CustomConst) -> bool {
636            crate::ops::constant::downcast_equal_consts(self, other)
637        }
638    }
639
640    /// A [`CustomSerialized`] encoding a [`float64_type()`] float constant used in testing.
641    pub(crate) fn serialized_float(f: f64) -> Value {
642        CustomSerialized::try_from_custom_const(ConstF64::new(f))
643            .unwrap()
644            .into()
645    }
646
647    /// Constructs a DFG hugr defining a sum constant, and returning the loaded value.
648    #[test]
649    fn test_sum() -> Result<(), BuildError> {
650        use crate::builder::Container;
651        let pred_rows = vec![vec![usize_t(), float64_type()].into(), Type::EMPTY_TYPEROW];
652        let pred_ty = SumType::new(pred_rows.clone());
653
654        let mut b = DFGBuilder::new(inout_sig(
655            type_row![],
656            TypeRow::from(vec![pred_ty.clone().into()]),
657        ))?;
658        let usize_custom_t = usize_custom_t(&Arc::downgrade(&PRELUDE));
659        let c = b.add_constant(Value::sum(
660            0,
661            [
662                CustomTestValue(usize_custom_t.clone()).into(),
663                ConstF64::new(5.1).into(),
664            ],
665            pred_ty.clone(),
666        )?);
667        let w = b.load_const(&c);
668        b.finish_hugr_with_outputs([w]).unwrap();
669
670        let mut b = DFGBuilder::new(Signature::new(
671            type_row![],
672            TypeRow::from(vec![pred_ty.clone().into()]),
673        ))?;
674        let c = b.add_constant(Value::sum(1, [], pred_ty.clone())?);
675        let w = b.load_const(&c);
676        b.finish_hugr_with_outputs([w]).unwrap();
677
678        Ok(())
679    }
680
681    #[test]
682    fn test_bad_sum() {
683        let pred_ty = SumType::new([vec![usize_t(), float64_type()].into(), type_row![]]);
684
685        let good_sum = const_usize();
686        println!("{}", serde_json::to_string_pretty(&good_sum).unwrap());
687
688        let good_sum =
689            Value::sum(0, [const_usize(), serialized_float(5.1)], pred_ty.clone()).unwrap();
690        println!("{}", serde_json::to_string_pretty(&good_sum).unwrap());
691
692        let res = Value::sum(0, [], pred_ty.clone());
693        assert_matches!(
694            res,
695            Err(ConstTypeError::SumType(SumTypeError::WrongVariantLength {
696                tag: 0,
697                expected: 2,
698                found: 0
699            }))
700        );
701
702        let res = Value::sum(4, [], pred_ty.clone());
703        assert_matches!(
704            res,
705            Err(ConstTypeError::SumType(SumTypeError::InvalidTag {
706                tag: 4,
707                num_variants: 2
708            }))
709        );
710
711        let res = Value::sum(0, [const_usize(), const_usize()], pred_ty);
712        assert_matches!(
713            res,
714            Err(ConstTypeError::SumType(SumTypeError::InvalidValueType {
715                tag: 0,
716                index: 1,
717                expected,
718                found,
719            })) if expected == float64_type() && found == const_usize()
720        );
721    }
722
723    #[rstest]
724    fn function_value(simple_dfg_hugr: Hugr) {
725        let v = Value::function(simple_dfg_hugr).unwrap();
726
727        let correct_type = Type::new_function(Signature::new_endo(vec![bool_t()]));
728
729        assert_eq!(v.get_type(), correct_type);
730        assert!(v.name().starts_with("const:function:"))
731    }
732
733    #[fixture]
734    fn const_usize() -> Value {
735        ConstUsize::new(257).into()
736    }
737
738    #[fixture]
739    fn const_serialized_usize() -> Value {
740        CustomSerialized::try_from_custom_const(ConstUsize::new(257))
741            .unwrap()
742            .into()
743    }
744
745    #[fixture]
746    fn const_tuple() -> Value {
747        Value::tuple([const_usize(), Value::true_val()])
748    }
749
750    /// Equivalent to [`const_tuple`], but uses a non-resolved opaque op for the usize element.
751    #[fixture]
752    fn const_tuple_serialized() -> Value {
753        Value::tuple([const_serialized_usize(), Value::true_val()])
754    }
755
756    #[fixture]
757    fn const_array_bool() -> Value {
758        ArrayValue::new(bool_t(), [Value::true_val(), Value::false_val()]).into()
759    }
760
761    #[fixture]
762    fn const_value_array_bool() -> Value {
763        VArrayValue::new(bool_t(), [Value::true_val(), Value::false_val()]).into()
764    }
765
766    #[fixture]
767    fn const_array_options() -> Value {
768        let some_true = Value::some([Value::true_val()]);
769        let none = Value::none(vec![bool_t()]);
770        let elem_ty = SumType::new_option(vec![bool_t()]);
771        ArrayValue::new(elem_ty.into(), [some_true, none]).into()
772    }
773
774    #[fixture]
775    fn const_value_array_options() -> Value {
776        let some_true = Value::some([Value::true_val()]);
777        let none = Value::none(vec![bool_t()]);
778        let elem_ty = SumType::new_option(vec![bool_t()]);
779        VArrayValue::new(elem_ty.into(), [some_true, none]).into()
780    }
781
    /// Checks, for each case, that the value reports the expected type and
    /// that its name starts with the expected prefix.
    #[rstest]
    #[case(Value::unit(), Type::UNIT, "const:seq:{}")]
    #[case(const_usize(), usize_t(), "const:custom:ConstUsize(")]
    #[case(serialized_float(17.4), float64_type(), "const:custom:json:Object")]
    #[case(const_tuple(), Type::new_tuple(vec![usize_t(), bool_t()]), "const:seq:{")]
    #[case(const_array_bool(), array_type(2, bool_t()), "const:custom:array")]
    #[case(
        const_value_array_bool(),
        value_array_type(2, bool_t()),
        "const:custom:value_array"
    )]
    #[case(
        const_array_options(),
        array_type(2, SumType::new_option(vec![bool_t()]).into()),
        "const:custom:array"
    )]
    #[case(
        const_value_array_options(),
        value_array_type(2, SumType::new_option(vec![bool_t()]).into()),
        "const:custom:value_array"
    )]
    fn const_type(
        #[case] const_value: Value,
        #[case] expected_type: Type,
        #[case] name_prefix: &str,
    ) {
        assert_eq!(const_value.get_type(), expected_type);
        let name = const_value.name();
        // Only the prefix is compared, not the full name.
        assert!(
            name.starts_with(name_prefix),
            "{name} does not start with {name_prefix}"
        );
    }
815
    /// Serializes each value to JSON and checks that deserializing it again
    /// yields the expected value.
    #[rstest]
    #[case(Value::unit(), Value::unit())]
    #[case(const_usize(), const_usize())]
    #[case(const_serialized_usize(), const_usize())]
    #[case(const_tuple_serialized(), const_tuple())]
    #[case(const_array_bool(), const_array_bool())]
    #[case(const_value_array_bool(), const_value_array_bool())]
    #[case(const_array_options(), const_array_options())]
    #[case(const_value_array_options(), const_value_array_options())]
    // Opaque constants don't get resolved into concrete types when running miri,
    // as the `typetag` machinery is not available.
    #[cfg_attr(miri, ignore)]
    fn const_serde_roundtrip(#[case] const_value: Value, #[case] expected_value: Value) {
        let serialized = serde_json::to_string(&const_value).unwrap();
        let deserialized: Value = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized, expected_value);
    }
834
835    #[rstest]
836    fn const_custom_value(const_usize: Value, const_tuple: Value) {
837        assert_eq!(
838            const_usize.get_custom_value::<ConstUsize>(),
839            Some(&ConstUsize::new(257))
840        );
841        assert_eq!(const_usize.get_custom_value::<ConstInt>(), None);
842        assert_eq!(const_tuple.get_custom_value::<ConstUsize>(), None);
843        assert_eq!(const_tuple.get_custom_value::<ConstInt>(), None);
844    }
845
846    #[test]
847    fn test_json_const() {
848        let ex_id: ExtensionId = "my_extension".try_into().unwrap();
849        let typ_int = CustomType::new(
850            "my_type",
851            vec![TypeArg::BoundedNat { n: 8 }],
852            ex_id.clone(),
853            TypeBound::Copyable,
854            // Dummy extension reference.
855            &Weak::default(),
856        );
857        let json_const: Value = CustomSerialized::new(typ_int.clone(), 6.into()).into();
858        let classic_t = Type::new_extension(typ_int.clone());
859        assert_matches!(classic_t.least_upper_bound(), TypeBound::Copyable);
860        assert_eq!(json_const.get_type(), classic_t);
861
862        let typ_qb = CustomType::new(
863            "my_type",
864            vec![],
865            ex_id,
866            TypeBound::Copyable,
867            &Weak::default(),
868        );
869        let t = Type::new_extension(typ_qb.clone());
870        assert_ne!(json_const.get_type(), t);
871    }
872
873    #[rstest]
874    fn hash_tuple(const_tuple: Value) {
875        let vals = [
876            Value::unit(),
877            Value::true_val(),
878            Value::false_val(),
879            ConstUsize::new(13).into(),
880            Value::tuple([ConstUsize::new(13).into()]),
881            Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(14).into()]),
882            Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(15).into()]),
883            const_tuple,
884        ];
885
886        let num_vals = vals.len();
887        let hashes = vals.map(|v| {
888            let mut h = DefaultHasher::new();
889            v.try_hash(&mut h).then_some(()).unwrap();
890            h.finish()
891        });
892        assert_eq!(HashSet::from(hashes).len(), num_vals); // all distinct
893    }
894
895    #[test]
896    fn unhashable_tuple() {
897        let tup = Value::tuple([ConstUsize::new(5).into(), ConstF64::new(4.97).into()]);
898        let mut h1 = DefaultHasher::new();
899        let r = tup.try_hash(&mut h1);
900        assert!(!r);
901
902        // Check that didn't do anything, by checking the hasher behaves
903        // just like one which never saw the tuple
904        h1.write_usize(5);
905        let mut h2 = DefaultHasher::new();
906        h2.write_usize(5);
907        assert_eq!(h1.finish(), h2.finish());
908    }
909
910    mod proptest {
911        use super::super::{OpaqueValue, Sum};
912        use crate::{
913            ops::{constant::CustomSerialized, Value},
914            std_extensions::arithmetic::int_types::ConstInt,
915            std_extensions::collections::list::ListValue,
916            types::{SumType, Type},
917        };
918        use ::proptest::{collection::vec, prelude::*};
919        impl Arbitrary for OpaqueValue {
920            type Parameters = ();
921            type Strategy = BoxedStrategy<Self>;
922            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
923                // We intentionally do not include `ConstF64` because it does not
924                // roundtrip serialize
925                prop_oneof![
926                    any::<ConstInt>().prop_map_into(),
927                    any::<CustomSerialized>().prop_map_into()
928                ]
929                .prop_recursive(
930                    3,  // No more than 3 branch levels deep
931                    32, // Target around 32 total elements
932                    3,  // Each collection is up to 3 elements long
933                    |child_strat| {
934                        (any::<Type>(), vec(child_strat, 0..3)).prop_map(|(typ, children)| {
935                            Self::new(ListValue::new(
936                                typ,
937                                children.into_iter().map(|e| Value::Extension { e }),
938                            ))
939                        })
940                    },
941                )
942                .boxed()
943            }
944        }
945
946        impl Arbitrary for Value {
947            type Parameters = ();
948            type Strategy = BoxedStrategy<Self>;
949            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
950                use ::proptest::collection::vec;
951                let leaf_strat = prop_oneof![
952                    any::<OpaqueValue>().prop_map(|e| Self::Extension { e }),
953                    crate::proptest::any_hugr().prop_map(|x| Value::function(x).unwrap())
954                ];
955                leaf_strat
956                    .prop_recursive(
957                        3,  // No more than 3 branch levels deep
958                        32, // Target around 32 total elements
959                        3,  // Each collection is up to 3 elements long
960                        |element| {
961                            prop_oneof![
962                                vec(element.clone(), 0..3).prop_map(Self::tuple),
963                                (
964                                    any::<usize>(),
965                                    vec(element.clone(), 0..3),
966                                    any_with::<SumType>(1.into()) // for speed: don't generate large sum types for now
967                                )
968                                    .prop_map(
969                                        |(tag, values, sum_type)| {
970                                            Self::Sum(Sum {
971                                                tag,
972                                                values,
973                                                sum_type,
974                                            })
975                                        }
976                                    ),
977                            ]
978                        },
979                    )
980                    .boxed()
981            }
982        }
983    }
984
985    #[test]
986    fn test_tuple_deserialize() {
987        let json = r#"
988        {
989    "v": "Tuple",
990    "vs": [
991        {
992            "v": "Sum",
993            "tag": 0,
994            "typ": {
995                "t": "Sum",
996                "s": "Unit",
997                "size": 1
998            },
999            "vs": []
1000        },
1001        {
1002            "v": "Sum",
1003            "tag": 1,
1004            "typ": {
1005                "t": "Sum",
1006                "s": "General",
1007                "rows": [
1008                    [
1009                        {
1010                            "t": "Sum",
1011                            "s": "Unit",
1012                            "size": 1
1013                        }
1014                    ],
1015                    [
1016                        {
1017                            "t": "Sum",
1018                            "s": "Unit",
1019                            "size": 2
1020                        }
1021                    ]
1022                ]
1023            },
1024            "vs": [
1025                {
1026                    "v": "Sum",
1027                    "tag": 1,
1028                    "typ": {
1029                        "t": "Sum",
1030                        "s": "Unit",
1031                        "size": 2
1032                    },
1033                    "vs": []
1034                }
1035            ]
1036        }
1037    ]
1038}
1039        "#;
1040
1041        let v: Value = serde_json::from_str(json).unwrap();
1042        assert_eq!(
1043            v,
1044            Value::tuple([
1045                Value::unit(),
1046                Value::sum(
1047                    1,
1048                    [Value::true_val()],
1049                    SumType::new([
1050                        type_row![Type::UNIT],
1051                        vec![Value::true_val().get_type()].into()
1052                    ]),
1053                )
1054                .unwrap()
1055            ])
1056        );
1057    }
1058}