hugr_core/ops/constant/
custom.rs

1//! Representation of custom constant values.
2//!
3//! These can be used as [`Const`] operations in HUGRs.
4//!
5//! [`Const`]: crate::ops::Const
6
7use std::any::Any;
8use std::hash::{Hash, Hasher};
9
10use downcast_rs::{impl_downcast, Downcast};
11use thiserror::Error;
12
13use crate::extension::resolution::{
14    resolve_type_extensions, ExtensionResolutionError, WeakExtensionRegistry,
15};
16use crate::extension::ExtensionSet;
17use crate::macros::impl_box_clone;
18use crate::types::{CustomCheckFailure, Type};
19use crate::IncomingPort;
20
21use super::{Value, ValueName};
22
/// Extensible constant values.
///
/// We use [typetag] to provide an `impl Serialize for dyn CustomConst`, and
/// similarly [serde::Deserialize]. When implementing this trait, include the
/// [`#[typetag::serde]`](typetag) attribute to enable serialization.
///
/// Note that when serializing through a `dyn CustomConst`, a dictionary will
/// be serialized with two attributes: `"c"`, the tag, and `"v"`, the
/// serialized `CustomConst`:
///
#[cfg_attr(not(miri), doc = "```")] // this doctest depends on typetag, so fails with miri
#[cfg_attr(miri, doc = "```ignore")]
/// use serde::{Serialize,Deserialize};
/// use hugr::{
///   types::Type,ops::constant::{OpaqueValue, ValueName, CustomConst},
///   extension::ExtensionSet, std_extensions::arithmetic::int_types};
/// use serde_json::json;
///
/// #[derive(std::fmt::Debug, Clone, Hash, Serialize,Deserialize)]
/// struct CC(i64);
///
/// #[typetag::serde]
/// impl CustomConst for CC {
///   fn name(&self) -> ValueName { "CC".into() }
///   fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::singleton(int_types::EXTENSION_ID) }
///   fn get_type(&self) -> Type { int_types::INT_TYPES[5].clone() }
/// }
///
/// assert_eq!(serde_json::to_value(CC(2)).unwrap(), json!(2));
/// assert_eq!(serde_json::to_value(&CC(2) as &dyn CustomConst).unwrap(), json!({
///   "c": "CC",
///   "v": 2
/// }));
/// ```
#[typetag::serde(tag = "c", content = "v")]
pub trait CustomConst:
    Send + Sync + std::fmt::Debug + TryHash + CustomConstBoxClone + Any + Downcast
{
    /// An identifier for the constant.
    fn name(&self) -> ValueName;

    /// The extension(s) defining the custom constant
    /// (a set to allow, say, a [List] of [USize]).
    ///
    /// [List]: crate::std_extensions::collections::list::LIST_TYPENAME
    /// [USize]: crate::extension::prelude::usize_t
    fn extension_reqs(&self) -> ExtensionSet;

    /// Check the value.
    ///
    /// The default implementation accepts every value (always returns `Ok(())`).
    fn validate(&self) -> Result<(), CustomCheckFailure> {
        Ok(())
    }

    /// Compare two constants for equality, using downcasting and comparing the definitions.
    ///
    /// If the type implements `PartialEq`, use [`downcast_equal_consts`] to compare the values.
    ///
    /// Note that this does not require any equivalent of [Eq]: it is permissible to return
    /// `false` if in doubt, and in particular, there is no requirement for reflexivity
    /// (i.e. `x.equal_consts(x)` can be `false`). However, we do expect both
    /// symmetry (`x.equal_consts(y) == y.equal_consts(x)`) and transitivity
    /// (if `x.equal_consts(y) && y.equal_consts(z)` then `x.equal_consts(z)`).
    ///
    /// The default implementation always returns `false`.
    fn equal_consts(&self, _other: &dyn CustomConst) -> bool {
        // false unless overridden
        false
    }

    /// Update the extensions associated with the internal values.
    ///
    /// This is used to ensure that any extension references in the type returned by
    /// [`CustomConst::get_type`] remain valid when serializing and deserializing the
    /// constant.
    ///
    /// The default implementation does nothing and always succeeds.
    ///
    /// See the helper methods in [`crate::extension::resolution`].
    fn update_extensions(
        &mut self,
        _extensions: &WeakExtensionRegistry,
    ) -> Result<(), ExtensionResolutionError> {
        Ok(())
    }

    /// Report the type.
    fn get_type(&self) -> Type;
}
106
/// A hash function that may decline to hash.
///
/// Every [`CustomConst`] must implement this trait. A type can opt out of hashing
/// with an empty `impl TryHash for MyType {}` (the default method reports "not
/// hashable"), or it can implement/derive [Hash], in which case the blanket
/// implementation for all `T: Hash` hashes it unconditionally.
pub trait TryHash {
    /// Feeds `self` into the hasher and returns `true`; or, if the value cannot be
    /// hashed, returns `false` without touching the hasher.
    ///
    /// The result must agree with [CustomConst::equal_consts] in the same way that
    /// [Hash] agrees with [Eq]:
    /// * if `x.equal_consts(y)`, then `x.try_hash(s)` behaves exactly like `y.try_hash(s)`;
    /// * if `x.try_hash(s)` behaves differently from `y.try_hash(s)`, then
    ///   `x.equal_consts(y)` must be `false`.
    ///
    /// Both rules hold trivially when either
    /// * `equal_consts` always returns `false`, or
    /// * `try_hash` always behaves the same way (for example by always returning
    ///   `false`, which is what the default implementation does).
    ///
    /// The hasher is taken as `&mut dyn Hasher` rather than a generic `H: Hasher`
    /// so that the trait remains object-safe.
    fn try_hash(&self, _hasher: &mut dyn Hasher) -> bool {
        // Not hashable unless overridden.
        false
    }
}

/// Any [Hash]able type is always hashable.
impl<T: Hash> TryHash for T {
    fn try_hash(&self, hasher: &mut dyn Hasher) -> bool {
        // `&mut dyn Hasher` is itself a sized `Hasher`, so it can be handed to the
        // generic `Hash::hash` by mutable reference.
        let mut sink = hasher;
        Hash::hash(self, &mut sink);
        true
    }
}
134
135impl PartialEq for dyn CustomConst {
136    fn eq(&self, other: &Self) -> bool {
137        (*self).equal_consts(other)
138    }
139}
140
141/// Const equality for types that have PartialEq
142pub fn downcast_equal_consts<T: CustomConst + PartialEq>(
143    constant: &T,
144    other: &dyn CustomConst,
145) -> bool {
146    if let Some(other) = other.as_any().downcast_ref::<T>() {
147        constant == other
148    } else {
149        false
150    }
151}
152
153/// Serialize any CustomConst using the `impl Serialize for &dyn CustomConst`.
154fn serialize_custom_const(cc: &dyn CustomConst) -> Result<serde_json::Value, serde_json::Error> {
155    serde_json::to_value(cc)
156}
157
158/// Deserialize a `Box<&dyn CustomConst>` and attempt to downcast it to `CC`;
159/// propagating failure.
160fn deserialize_custom_const<CC: CustomConst>(
161    value: serde_json::Value,
162) -> Result<CC, serde_json::Error> {
163    match deserialize_dyn_custom_const(value)?.downcast::<CC>() {
164        Ok(cc) => Ok(*cc),
165        Err(dyn_cc) => Err(<serde_json::Error as serde::de::Error>::custom(format!(
166            "Failed to deserialize [{}]: {:?}",
167            std::any::type_name::<CC>(),
168            dyn_cc
169        ))),
170    }
171}
172
173/// Deserialize a `Box<&dyn CustomConst>`.
174fn deserialize_dyn_custom_const(
175    value: serde_json::Value,
176) -> Result<Box<dyn CustomConst>, serde_json::Error> {
177    serde_json::from_value(value)
178}
179
// Adds the `downcast`/`downcast_ref` helpers on `dyn CustomConst` (from
// `downcast_rs`) that are used below to recover concrete constant types.
impl_downcast!(CustomConst);
// Defines the `CustomConstBoxClone` supertrait of `CustomConst`, which provides
// `clone_box` for cloning a constant behind `dyn CustomConst`.
impl_box_clone!(CustomConst, CustomConstBoxClone);
182
183#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
184/// A constant value stored as a serialized blob that can report its own type.
185pub struct CustomSerialized {
186    typ: Type,
187    value: serde_json::Value,
188    extensions: ExtensionSet,
189}
190
/// Error returned when a [`CustomConst`] cannot be serialized into a
/// [`CustomSerialized`].
///
/// Keeps both the underlying `serde_json` error and the constant that failed.
#[derive(Debug, Error)]
#[error("Error serializing value into CustomSerialized: err: {err}, value: {payload:?}")]
pub struct SerializeError {
    /// The underlying serialization error.
    #[source]
    err: serde_json::Error,
    /// The constant that could not be serialized.
    payload: Box<dyn CustomConst>,
}
198
/// Error returned by [`CustomSerialized::try_into_custom_const`] when the stored
/// JSON cannot be deserialized into the requested constant type.
#[derive(Debug, Error)]
#[error("Error deserializing value from CustomSerialized: err: {err}, value: {payload:?}")]
pub struct DeserializeError {
    /// The underlying deserialization error.
    #[source]
    err: serde_json::Error,
    /// The JSON value that could not be deserialized.
    payload: serde_json::Value,
}
206
207impl CustomSerialized {
208    /// Creates a new [`CustomSerialized`].
209    pub fn new(
210        typ: impl Into<Type>,
211        value: serde_json::Value,
212        exts: impl Into<ExtensionSet>,
213    ) -> Self {
214        Self {
215            typ: typ.into(),
216            value,
217            extensions: exts.into(),
218        }
219    }
220
221    /// Returns the inner value.
222    pub fn value(&self) -> &serde_json::Value {
223        &self.value
224    }
225
226    /// If `cc` is a [Self], returns a clone of `cc` coerced to [Self].
227    /// Otherwise, returns a [Self] with `cc` serialized in it's value.
228    pub fn try_from_custom_const_ref(cc: &impl CustomConst) -> Result<Self, SerializeError> {
229        Self::try_from_dyn_custom_const(cc)
230    }
231
232    /// If `cc` is a [Self], returns a clone of `cc` coerced to [Self].
233    /// Otherwise, returns a [Self] with `cc` serialized in it's value.
234    pub fn try_from_dyn_custom_const(cc: &dyn CustomConst) -> Result<Self, SerializeError> {
235        Ok(match cc.as_any().downcast_ref::<Self>() {
236            Some(cs) => cs.clone(),
237            None => Self::new(
238                cc.get_type(),
239                serialize_custom_const(cc).map_err(|err| SerializeError {
240                    err,
241                    payload: cc.clone_box(),
242                })?,
243                cc.extension_reqs(),
244            ),
245        })
246    }
247
248    /// If `cc` is a [Self], return `cc` coerced to [Self]. Otherwise,
249    /// returns a [Self] with `cc` serialized in it's value.
250    /// Never clones `cc` outside of error paths.
251    pub fn try_from_custom_const(cc: impl CustomConst) -> Result<Self, SerializeError> {
252        Self::try_from_custom_const_box(Box::new(cc))
253    }
254
255    /// If `cc` is a [Self], return `cc` coerced to [Self]. Otherwise,
256    /// returns a [Self] with `cc` serialized in it's value.
257    /// Never clones `cc` outside of error paths.
258    pub fn try_from_custom_const_box(cc: Box<dyn CustomConst>) -> Result<Self, SerializeError> {
259        match cc.downcast::<Self>() {
260            Ok(x) => Ok(*x),
261            Err(cc) => {
262                let (typ, extension_reqs) = (cc.get_type(), cc.extension_reqs());
263                let value = serialize_custom_const(cc.as_ref())
264                    .map_err(|err| SerializeError { err, payload: cc })?;
265                Ok(Self::new(typ, value, extension_reqs))
266            }
267        }
268    }
269
270    /// Attempts to deserialize the value in self into a `Box<dyn CustomConst>`.
271    /// This can fail, in particular when the `impl CustomConst` for the trait
272    /// is not linked into the running executable.
273    /// If deserialization fails, returns self in a box.
274    ///
275    /// Note that if the inner value is a [Self] we do not recursively
276    /// deserialize it.
277    pub fn into_custom_const_box(self) -> Box<dyn CustomConst> {
278        // ideally we would not have to clone, but serde_json does not allow us
279        // to recover the value from the error
280        deserialize_dyn_custom_const(self.value.clone()).unwrap_or_else(|_| Box::new(self))
281    }
282
283    /// Attempts to deserialize the value in self into a `CC`. Propagates failure.
284    ///
285    /// Note that if the inner value is a [Self] we do not recursively
286    /// deserialize it. In particular if that inner value were a [Self] whose
287    /// inner value were a `CC`, then we would still fail.
288    pub fn try_into_custom_const<CC: CustomConst>(self) -> Result<CC, DeserializeError> {
289        // ideally we would not have to clone, but serde_json does not allow us
290        // to recover the value from the error
291        deserialize_custom_const(self.value.clone()).map_err(|err| DeserializeError {
292            err,
293            payload: self.value,
294        })
295    }
296}
297
298impl TryHash for CustomSerialized {
299    fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
300        // Consistent with equality, same serialization <=> same hash.
301        self.value.to_string().hash(&mut st);
302        true
303    }
304}
305
306#[typetag::serde]
307impl CustomConst for CustomSerialized {
308    fn name(&self) -> ValueName {
309        format!("json:{:?}", self.value).into()
310    }
311
312    fn equal_consts(&self, other: &dyn CustomConst) -> bool {
313        Some(self) == other.downcast_ref()
314    }
315
316    fn extension_reqs(&self) -> ExtensionSet {
317        self.extensions.clone()
318    }
319    fn update_extensions(
320        &mut self,
321        extensions: &WeakExtensionRegistry,
322    ) -> Result<(), ExtensionResolutionError> {
323        resolve_type_extensions(&mut self.typ, extensions)
324    }
325    fn get_type(&self) -> Type {
326        self.typ.clone()
327    }
328}
329
330/// This module is used by the serde annotations on `super::OpaqueValue`
331pub(super) mod serde_extension_value {
332    use serde::{Deserializer, Serializer};
333
334    use super::{CustomConst, CustomSerialized};
335
336    pub fn deserialize<'de, D: Deserializer<'de>>(
337        deserializer: D,
338    ) -> Result<Box<dyn CustomConst>, D::Error> {
339        use serde::Deserialize;
340        // We deserialize a CustomSerialized, i.e. not a dyn CustomConst.
341        let cs = CustomSerialized::deserialize(deserializer)?;
342        // We return the inner serialized CustomConst if we can, otherwise the
343        // CustomSerialized itself.
344        Ok(cs.into_custom_const_box())
345    }
346
347    pub fn serialize<S: Serializer>(
348        konst: impl AsRef<dyn CustomConst>,
349        serializer: S,
350    ) -> Result<S::Ok, S::Error> {
351        use serde::Serialize;
352        // we create a CustomSerialized, then serialize it. Note we do not
353        // serialize it as a dyn CustomConst.
354        let cs = CustomSerialized::try_from_dyn_custom_const(konst.as_ref())
355            .map_err(<S::Error as serde::ser::Error>::custom)?;
356        cs.serialize(serializer)
357    }
358}
359
360/// Given a singleton list of constant operations, return the value.
361pub fn get_single_input_value<T: CustomConst>(consts: &[(IncomingPort, Value)]) -> Option<&T> {
362    let [(_, c)] = consts else {
363        return None;
364    };
365    c.get_custom_value()
366}
367
368/// Given a list of two constant operations, return the values.
369pub fn get_pair_of_input_values<T: CustomConst>(
370    consts: &[(IncomingPort, Value)],
371) -> Option<(&T, &T)> {
372    let [(_, c0), (_, c1)] = consts else {
373        return None;
374    };
375    Some((c0.get_custom_value()?, c1.get_custom_value()?))
376}
377
// These tests depend on the `typetag` crate, which does not work under miri, so
// the whole module is compiled out when running under miri.
#[cfg(all(test, not(miri)))]
mod test {

    use rstest::rstest;

    use crate::{
        extension::prelude::{usize_t, ConstUsize},
        ops::{constant::custom::serialize_custom_const, Value},
        std_extensions::collections::list::ListValue,
    };

    use super::{super::OpaqueValue, CustomConst, CustomConstBoxClone, CustomSerialized};

    /// A constant, the tag expected in the `"c"` field when it is serialized
    /// through `dyn CustomConst`, and its own (untagged) JSON serialization.
    struct SerializeCustomConstExample<CC: CustomConst + serde::Serialize + 'static> {
        cc: CC,
        tag: &'static str,
        json: serde_json::Value,
    }

    impl<CC: CustomConst + serde::Serialize + 'static> SerializeCustomConstExample<CC> {
        /// Builds an example, computing `json` by serializing `cc` directly
        /// (i.e. not through `dyn CustomConst`).
        fn new(cc: CC, tag: &'static str) -> Self {
            let json = serde_json::to_value(&cc).unwrap();
            Self { cc, tag, json }
        }
    }

    /// Example: a single `ConstUsize`.
    fn scce_usize() -> SerializeCustomConstExample<ConstUsize> {
        SerializeCustomConstExample::new(ConstUsize::new(12), "ConstUsize")
    }

    /// Example: a `ListValue` holding two `ConstUsize` elements.
    fn scce_list() -> SerializeCustomConstExample<ListValue> {
        let cc = ListValue::new(
            usize_t(),
            [ConstUsize::new(1), ConstUsize::new(2)]
                .into_iter()
                .map(Value::extension),
        );
        SerializeCustomConstExample::new(cc, "ListValue")
    }

    /// Checks that every conversion into a `CustomSerialized` yields the tagged
    /// form of the example, that every conversion back out recovers the original
    /// constant, and that an `OpaqueValue` serializes as that `CustomSerialized`
    /// and round-trips.
    #[rstest]
    #[cfg_attr(miri, ignore = "miri is incompatible with the typetag crate")]
    #[case(scce_usize())]
    #[case(scce_list())]
    fn test_custom_serialized_try_from<
        CC: CustomConst + serde::Serialize + Clone + PartialEq + 'static + Sized,
    >(
        #[case] example: SerializeCustomConstExample<CC>,
    ) {
        assert_eq!(example.json, serde_json::to_value(&example.cc).unwrap()); // sanity check
        // Serializing through `dyn CustomConst` wraps the plain JSON as
        // `{"c": tag, "v": json}`.
        let expected_json: serde_json::Value = [
            ("c".into(), example.tag.into()),
            ("v".into(), example.json.clone()),
        ]
        .into_iter()
        .collect::<serde_json::Map<String, serde_json::Value>>()
        .into();

        // check serialize_custom_const
        assert_eq!(expected_json, serialize_custom_const(&example.cc).unwrap());

        let expected_custom_serialized = CustomSerialized::new(
            example.cc.get_type(),
            expected_json,
            example.cc.extension_reqs(),
        );

        // check all the try_from/try_into/into variations
        assert_eq!(
            &expected_custom_serialized,
            &CustomSerialized::try_from_custom_const(example.cc.clone()).unwrap()
        );
        assert_eq!(
            &expected_custom_serialized,
            &CustomSerialized::try_from_custom_const_ref(&example.cc).unwrap()
        );
        assert_eq!(
            &expected_custom_serialized,
            &CustomSerialized::try_from_custom_const_box(example.cc.clone_box()).unwrap()
        );
        assert_eq!(
            &expected_custom_serialized,
            &CustomSerialized::try_from_dyn_custom_const(example.cc.clone_box().as_ref()).unwrap()
        );
        assert_eq!(
            &example.cc.clone_box(),
            &expected_custom_serialized.clone().into_custom_const_box()
        );
        assert_eq!(
            &example.cc,
            &expected_custom_serialized
                .clone()
                .try_into_custom_const()
                .unwrap()
        );

        // check OpaqueValue serializes/deserializes as a CustomSerialized
        let ev: OpaqueValue = example.cc.clone().into();
        let ev_val = serde_json::to_value(&ev).unwrap();
        assert_eq!(
            &ev_val,
            &serde_json::to_value(&expected_custom_serialized).unwrap()
        );
        assert_eq!(ev, serde_json::from_value(ev_val).unwrap());
    }

    /// A `ConstUsize` together with its `CustomSerialized` form.
    fn example_custom_serialized() -> (ConstUsize, CustomSerialized) {
        let inner = scce_usize().cc;
        (
            inner.clone(),
            CustomSerialized::try_from_custom_const(inner).unwrap(),
        )
    }

    /// A `CustomSerialized`, together with a second `CustomSerialized` whose value
    /// is the tagged serialization of the first (one level of nesting).
    fn example_nested_custom_serialized() -> (CustomSerialized, CustomSerialized) {
        let inner = example_custom_serialized().1;
        (
            inner.clone(),
            CustomSerialized::new(
                inner.get_type(),
                serialize_custom_const(&inner).unwrap(),
                inner.extension_reqs(),
            ),
        )
    }

    /// Checks that converting a `CustomSerialized` into a `CustomSerialized` returns
    /// it unchanged, and that converting back out unwraps exactly one level.
    #[rstest]
    #[cfg_attr(miri, ignore = "miri is incompatible with the typetag crate")]
    #[case(example_custom_serialized())]
    #[case(example_nested_custom_serialized())]
    fn test_try_from_custom_serialized_recursive<CC: CustomConst + PartialEq>(
        #[case] example: (CC, CustomSerialized),
    ) {
        let (inner, cs) = example;
        // check all the try_from/try_into/into variations

        assert_eq!(
            &cs,
            &CustomSerialized::try_from_custom_const(cs.clone()).unwrap()
        );
        assert_eq!(
            &cs,
            &CustomSerialized::try_from_custom_const_ref(&cs).unwrap()
        );
        assert_eq!(
            &cs,
            &CustomSerialized::try_from_custom_const_box(cs.clone_box()).unwrap()
        );
        assert_eq!(
            &cs,
            &CustomSerialized::try_from_dyn_custom_const(cs.clone_box().as_ref()).unwrap()
        );
        assert_eq!(&inner.clone_box(), &cs.clone().into_custom_const_box());
        assert_eq!(&inner, &cs.clone().try_into_custom_const().unwrap());

        let ev: OpaqueValue = cs.clone().into();
        // A serialization round-trip results in an OpaqueValue with the value of inner
        assert_eq!(
            OpaqueValue::new(inner),
            serde_json::from_value(serde_json::to_value(&ev).unwrap()).unwrap()
        );
    }
}
542
#[cfg(test)]
mod proptest {
    use ::proptest::prelude::*;

    use crate::{
        extension::ExtensionSet,
        ops::constant::CustomSerialized,
        proptest::{any_serde_json_value, any_string},
        types::Type,
    };

    /// Wraps `content` in the envelope produced by the
    /// `#[typetag::serde(tag = "c", content = "v")]` annotation on
    /// `trait CustomConst`, i.e. `{"c": tag, "v": content}`.
    fn tagged_json(
        tag: impl Into<serde_json::Value>,
        content: serde_json::Value,
    ) -> serde_json::Value {
        let mut envelope = serde_json::Map::new();
        envelope.insert("c".to_string(), tag.into());
        envelope.insert("v".to_string(), content);
        serde_json::Value::Object(envelope)
    }

    impl Arbitrary for CustomSerialized {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Generates a hand-built serialized `dyn CustomConst` with an arbitrary
        /// tag and content, plus an arbitrary type and extension set.
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            // TODO: the tag is an arbitrary string, so it could by accident be a
            // real typetag name (e.g. "ConstInt"). The `serde::Deserialize` impl for
            // that type would then try to interpret "v" and fail.
            let value = (any_serde_json_value(), any_string())
                .prop_map(|(content, tag)| tagged_json(tag, content));
            (any::<Type>(), value, any::<ExtensionSet>())
                .prop_map(|(typ, value, extensions)| Self {
                    typ,
                    value,
                    extensions,
                })
                .boxed()
        }
    }
}