hugr_core/ops/constant/custom.rs

1//! Representation of custom constant values.
2//!
3//! These can be used as [`Const`] operations in HUGRs.
4//!
5//! [`Const`]: crate::ops::Const
6
7use std::any::Any;
8use std::hash::{Hash, Hasher};
9
10use downcast_rs::{Downcast, impl_downcast};
11use thiserror::Error;
12
13use crate::IncomingPort;
14use crate::extension::resolution::{
15    ExtensionResolutionError, WeakExtensionRegistry, resolve_type_extensions,
16};
17use crate::macros::impl_box_clone;
18use crate::types::{CustomCheckFailure, Type};
19
20use super::{Value, ValueName};
21
22/// Extensible constant values.
23///
24/// We use [typetag] to provide an `impl Serialize for dyn CustomConst`, and
25/// similarly [serde::Deserialize]. When implementing this trait, include the
26/// [`#[typetag::serde]`](typetag) attribute to enable serialization.
27///
28/// Note that when serializing through the [`dyn CustomConst`] a dictionary will
29/// be serialized with two attributes, `"c"`  the tag and `"v"` the
30/// `CustomConst`:
31///
32#[cfg_attr(not(miri), doc = "```")] // this doctest depends on typetag, so fails with miri
33#[cfg_attr(miri, doc = "```ignore")]
34/// use serde::{Serialize,Deserialize};
35/// use hugr::{
36///   types::Type,ops::constant::{OpaqueValue, ValueName, CustomConst},
37///   extension::ExtensionSet, std_extensions::arithmetic::int_types};
38/// use serde_json::json;
39///
40/// #[derive(std::fmt::Debug, Clone, Hash, Serialize,Deserialize)]
41/// struct CC(i64);
42///
43/// #[typetag::serde]
44/// impl CustomConst for CC {
45///   fn name(&self) -> ValueName { "CC".into() }
46///   fn get_type(&self) -> Type { int_types::INT_TYPES[5].clone() }
47/// }
48///
49/// assert_eq!(serde_json::to_value(CC(2)).unwrap(), json!(2));
50/// assert_eq!(serde_json::to_value(&CC(2) as &dyn CustomConst).unwrap(), json!({
51///   "c": "CC",
52///   "v": 2
53/// }));
54/// ```
55#[typetag::serde(tag = "c", content = "v")]
56pub trait CustomConst:
57    Send + Sync + std::fmt::Debug + TryHash + CustomConstBoxClone + Any + Downcast
58{
59    /// An identifier for the constant.
60    fn name(&self) -> ValueName;
61
62    /// Check the value.
63    fn validate(&self) -> Result<(), CustomCheckFailure> {
64        Ok(())
65    }
66
67    /// Compare two constants for equality, using downcasting and comparing the definitions.
68    ///
69    /// If the type implements `PartialEq`, use [`downcast_equal_consts`] to compare the values.
70    ///
71    /// Note that this does not require any equivalent of [Eq]: it is permissible to return
72    /// `false` if in doubt, and in particular, there is no requirement for reflexivity
73    /// (i.e. `x.equal_consts(x)` can be `false`). However, we do expect both
74    /// symmetry (`x.equal_consts(y) == y.equal_consts(x)`) and transitivity
75    /// (if `x.equal_consts(y) && y.equal_consts(z)` then `x.equal_consts(z)`).
76    fn equal_consts(&self, _other: &dyn CustomConst) -> bool {
77        // false unless overridden
78        false
79    }
80
81    /// Update the extensions associated with the internal values.
82    ///
83    /// This is used to ensure that any extension reference [`CustomConst::get_type`] remains
84    /// valid when serializing and deserializing the constant.
85    ///
86    /// See the helper methods in [`crate::extension::resolution`].
87    fn update_extensions(
88        &mut self,
89        _extensions: &WeakExtensionRegistry,
90    ) -> Result<(), ExtensionResolutionError> {
91        Ok(())
92    }
93
94    /// Report the type.
95    fn get_type(&self) -> Type;
96}
97
/// Fallible hash function.
///
/// Prerequisite for `CustomConst`. A custom hash function may be supplied, but
/// the simplest options are an empty `impl TryHash for ... {}` (meaning "not
/// hashable"), or implementing/deriving [Hash], which provides `TryHash`
/// through the blanket implementation.
pub trait TryHash {
    /// Hashes the value if possible and returns `true`. Otherwise returns `false`
    /// without touching the `Hasher`.
    ///
    /// This relates to [`CustomConst::equal_consts`] just as [Hash] relates to [Eq]:
    /// * if `x.equal_consts(y)`, then `x.try_hash(s)` behaves equivalently to
    ///   `y.try_hash(s)`;
    /// * if `x.try_hash(s)` behaves differently from `y.try_hash(s)`, then
    ///   `x.equal_consts(y) == false`.
    ///
    /// As with [Hash], either of the following trivially satisfies these requirements:
    /// * `equal_consts` always returning `false`, or
    /// * `try_hash` always behaving the same way (e.g. returning `false`, as the
    ///   default implementation does).
    ///
    /// Takes `&mut dyn Hasher` rather than being generic over `<H: Hasher>` so that
    /// the trait remains object-safe.
    fn try_hash(&self, _state: &mut dyn Hasher) -> bool {
        // Default: decline to hash, leaving the hasher untouched.
        false
    }
}
118
119impl<T: Hash> TryHash for T {
120    fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
121        Hash::hash(self, &mut st);
122        true
123    }
124}
125
126impl PartialEq for dyn CustomConst {
127    fn eq(&self, other: &Self) -> bool {
128        (*self).equal_consts(other)
129    }
130}
131
132/// Const equality for types that have `PartialEq`
133pub fn downcast_equal_consts<T: CustomConst + PartialEq>(
134    constant: &T,
135    other: &dyn CustomConst,
136) -> bool {
137    if let Some(other) = other.as_any().downcast_ref::<T>() {
138        constant == other
139    } else {
140        false
141    }
142}
143
144/// Serialize any `CustomConst` using the `impl Serialize for &dyn CustomConst`.
145fn serialize_custom_const(cc: &dyn CustomConst) -> Result<serde_json::Value, serde_json::Error> {
146    serde_json::to_value(cc)
147}
148
149/// Deserialize a `Box<&dyn CustomConst>` and attempt to downcast it to `CC`;
150/// propagating failure.
151fn deserialize_custom_const<CC: CustomConst>(
152    value: serde_json::Value,
153) -> Result<CC, serde_json::Error> {
154    match deserialize_dyn_custom_const(value)?.downcast::<CC>() {
155        Ok(cc) => Ok(*cc),
156        Err(dyn_cc) => Err(<serde_json::Error as serde::de::Error>::custom(format!(
157            "Failed to deserialize [{}]: {:?}",
158            std::any::type_name::<CC>(),
159            dyn_cc
160        ))),
161    }
162}
163
164/// Deserialize a `Box<&dyn CustomConst>`.
165fn deserialize_dyn_custom_const(
166    value: serde_json::Value,
167) -> Result<Box<dyn CustomConst>, serde_json::Error> {
168    serde_json::from_value(value)
169}
170
// Generates the downcasting methods on `dyn CustomConst` (via `downcast_rs`),
// such as the `downcast` and `downcast_ref` calls used in this file.
impl_downcast!(CustomConst);
// Provides `CustomConstBoxClone` (a supertrait of `CustomConst`), whose
// `clone_box` clones a constant into a `Box<dyn CustomConst>`.
impl_box_clone!(CustomConst, CustomConstBoxClone);
173
174#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
175/// A constant value stored as a serialized blob that can report its own type.
176pub struct CustomSerialized {
177    typ: Type,
178    value: serde_json::Value,
179}
180
181#[derive(Debug, Error)]
182#[error("Error serializing value into CustomSerialized: err: {err}, value: {payload:?}")]
183pub struct SerializeError {
184    #[source]
185    err: serde_json::Error,
186    payload: Box<dyn CustomConst>,
187}
188
189#[derive(Debug, Error)]
190#[error("Error deserializing value from CustomSerialized: err: {err}, value: {payload:?}")]
191pub struct DeserializeError {
192    #[source]
193    err: serde_json::Error,
194    payload: serde_json::Value,
195}
196
197impl CustomSerialized {
198    /// Creates a new [`CustomSerialized`].
199    pub fn new(typ: impl Into<Type>, value: serde_json::Value) -> Self {
200        Self {
201            typ: typ.into(),
202            value,
203        }
204    }
205
206    /// Returns the inner value.
207    #[must_use]
208    pub fn value(&self) -> &serde_json::Value {
209        &self.value
210    }
211
212    /// If `cc` is a [Self], returns a clone of `cc` coerced to [Self].
213    /// Otherwise, returns a [Self] with `cc` serialized in it's value.
214    pub fn try_from_custom_const_ref(cc: &impl CustomConst) -> Result<Self, SerializeError> {
215        Self::try_from_dyn_custom_const(cc)
216    }
217
218    /// If `cc` is a [Self], returns a clone of `cc` coerced to [Self].
219    /// Otherwise, returns a [Self] with `cc` serialized in it's value.
220    pub fn try_from_dyn_custom_const(cc: &dyn CustomConst) -> Result<Self, SerializeError> {
221        Ok(match cc.as_any().downcast_ref::<Self>() {
222            Some(cs) => cs.clone(),
223            None => Self::new(
224                cc.get_type(),
225                serialize_custom_const(cc).map_err(|err| SerializeError {
226                    err,
227                    payload: cc.clone_box(),
228                })?,
229            ),
230        })
231    }
232
233    /// If `cc` is a [Self], return `cc` coerced to [Self]. Otherwise,
234    /// returns a [Self] with `cc` serialized in it's value.
235    /// Never clones `cc` outside of error paths.
236    pub fn try_from_custom_const(cc: impl CustomConst) -> Result<Self, SerializeError> {
237        Self::try_from_custom_const_box(Box::new(cc))
238    }
239
240    /// If `cc` is a [Self], return `cc` coerced to [Self]. Otherwise,
241    /// returns a [Self] with `cc` serialized in it's value.
242    /// Never clones `cc` outside of error paths.
243    pub fn try_from_custom_const_box(cc: Box<dyn CustomConst>) -> Result<Self, SerializeError> {
244        match cc.downcast::<Self>() {
245            Ok(x) => Ok(*x),
246            Err(cc) => {
247                let typ = cc.get_type();
248                let value = serialize_custom_const(cc.as_ref())
249                    .map_err(|err| SerializeError { err, payload: cc })?;
250                Ok(Self::new(typ, value))
251            }
252        }
253    }
254
255    /// Attempts to deserialize the value in self into a `Box<dyn CustomConst>`.
256    /// This can fail, in particular when the `impl CustomConst` for the trait
257    /// is not linked into the running executable.
258    /// If deserialization fails, returns self in a box.
259    ///
260    /// Note that if the inner value is a [Self] we do not recursively
261    /// deserialize it.
262    #[must_use]
263    pub fn into_custom_const_box(self) -> Box<dyn CustomConst> {
264        // ideally we would not have to clone, but serde_json does not allow us
265        // to recover the value from the error
266        deserialize_dyn_custom_const(self.value.clone()).unwrap_or_else(|_| Box::new(self))
267    }
268
269    /// Attempts to deserialize the value in self into a `CC`. Propagates failure.
270    ///
271    /// Note that if the inner value is a [Self] we do not recursively
272    /// deserialize it. In particular if that inner value were a [Self] whose
273    /// inner value were a `CC`, then we would still fail.
274    pub fn try_into_custom_const<CC: CustomConst>(self) -> Result<CC, DeserializeError> {
275        // ideally we would not have to clone, but serde_json does not allow us
276        // to recover the value from the error
277        deserialize_custom_const(self.value.clone()).map_err(|err| DeserializeError {
278            err,
279            payload: self.value,
280        })
281    }
282}
283
284impl TryHash for CustomSerialized {
285    fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
286        // Consistent with equality, same serialization <=> same hash.
287        self.value.to_string().hash(&mut st);
288        true
289    }
290}
291
292#[typetag::serde]
293impl CustomConst for CustomSerialized {
294    fn name(&self) -> ValueName {
295        format!("json:{:?}", self.value).into()
296    }
297
298    fn equal_consts(&self, other: &dyn CustomConst) -> bool {
299        Some(self) == other.downcast_ref()
300    }
301
302    fn update_extensions(
303        &mut self,
304        extensions: &WeakExtensionRegistry,
305    ) -> Result<(), ExtensionResolutionError> {
306        resolve_type_extensions(&mut self.typ, extensions)
307    }
308    fn get_type(&self) -> Type {
309        self.typ.clone()
310    }
311}
312
313/// This module is used by the serde annotations on `super::OpaqueValue`
314pub(super) mod serde_extension_value {
315    use serde::{Deserializer, Serializer};
316
317    use super::{CustomConst, CustomSerialized};
318
319    pub fn deserialize<'de, D: Deserializer<'de>>(
320        deserializer: D,
321    ) -> Result<Box<dyn CustomConst>, D::Error> {
322        use serde::Deserialize;
323        // We deserialize a CustomSerialized, i.e. not a dyn CustomConst.
324        let cs = CustomSerialized::deserialize(deserializer)?;
325        // We return the inner serialized CustomConst if we can, otherwise the
326        // CustomSerialized itself.
327        Ok(cs.into_custom_const_box())
328    }
329
330    pub fn serialize<S: Serializer>(
331        konst: impl AsRef<dyn CustomConst>,
332        serializer: S,
333    ) -> Result<S::Ok, S::Error> {
334        use serde::Serialize;
335        // we create a CustomSerialized, then serialize it. Note we do not
336        // serialize it as a dyn CustomConst.
337        let cs = CustomSerialized::try_from_dyn_custom_const(konst.as_ref())
338            .map_err(<S::Error as serde::ser::Error>::custom)?;
339        cs.serialize(serializer)
340    }
341}
342
343/// Given a singleton list of constant operations, return the value.
344#[must_use]
345pub fn get_single_input_value<T: CustomConst>(consts: &[(IncomingPort, Value)]) -> Option<&T> {
346    let [(_, c)] = consts else {
347        return None;
348    };
349    c.get_custom_value()
350}
351
352/// Given a list of two constant operations, return the values.
353#[must_use]
354pub fn get_pair_of_input_values<T: CustomConst>(
355    consts: &[(IncomingPort, Value)],
356) -> Option<(&T, &T)> {
357    let [(_, c0), (_, c1)] = consts else {
358        return None;
359    };
360    Some((c0.get_custom_value()?, c1.get_custom_value()?))
361}
362
// These tests depend on the `typetag` crate, which is incompatible with miri.
#[cfg(all(test, not(miri)))]
mod test {

    use rstest::rstest;

    use crate::{
        extension::prelude::{ConstUsize, usize_t},
        ops::{Value, constant::custom::serialize_custom_const},
        std_extensions::collections::list::ListValue,
    };

    use super::{super::OpaqueValue, CustomConst, CustomConstBoxClone, CustomSerialized};

    /// A custom constant, the typetag tag it is expected to serialize under, and
    /// its plain (untagged) JSON serialization.
    struct SerializeCustomConstExample<CC: CustomConst + serde::Serialize + 'static> {
        cc: CC,
        tag: &'static str,
        json: serde_json::Value,
    }

    impl<CC: CustomConst + serde::Serialize + 'static> SerializeCustomConstExample<CC> {
        /// Builds an example, computing `json` by serializing `cc` directly.
        fn new(cc: CC, tag: &'static str) -> Self {
            let json = serde_json::to_value(&cc).unwrap();
            Self { cc, tag, json }
        }
    }

    /// Example: a single `ConstUsize`.
    fn scce_usize() -> SerializeCustomConstExample<ConstUsize> {
        SerializeCustomConstExample::new(ConstUsize::new(12), "ConstUsize")
    }

    /// Example: a `ListValue` of two `ConstUsize` elements.
    fn scce_list() -> SerializeCustomConstExample<ListValue> {
        let cc = ListValue::new(
            usize_t(),
            [ConstUsize::new(1), ConstUsize::new(2)]
                .into_iter()
                .map(Value::extension),
        );
        SerializeCustomConstExample::new(cc, "ListValue")
    }

    /// Checks that serializing through `dyn CustomConst` yields `{"c": tag, "v": json}`,
    /// that every `CustomSerialized` conversion agrees with that, and that an
    /// `OpaqueValue` serializes as a `CustomSerialized` and round-trips.
    #[rstest]
    #[cfg_attr(miri, ignore = "miri is incompatible with the typetag crate")]
    #[case(scce_usize())]
    #[case(scce_list())]
    fn test_custom_serialized_try_from<
        CC: CustomConst + serde::Serialize + Clone + PartialEq + 'static + Sized,
    >(
        #[case] example: SerializeCustomConstExample<CC>,
    ) {
        assert_eq!(example.json, serde_json::to_value(&example.cc).unwrap()); // sanity check
        // The adjacently-tagged form produced by `#[typetag::serde(tag = "c", content = "v")]`.
        let expected_json: serde_json::Value = [
            ("c".into(), example.tag.into()),
            ("v".into(), example.json.clone()),
        ]
        .into_iter()
        .collect::<serde_json::Map<String, serde_json::Value>>()
        .into();

        // check serialize_custom_const
        assert_eq!(expected_json, serialize_custom_const(&example.cc).unwrap());

        let expected_custom_serialized =
            CustomSerialized::new(example.cc.get_type(), expected_json);

        // check all the try_from/try_into/into variations
        assert_eq!(
            &expected_custom_serialized,
            &CustomSerialized::try_from_custom_const(example.cc.clone()).unwrap()
        );
        assert_eq!(
            &expected_custom_serialized,
            &CustomSerialized::try_from_custom_const_ref(&example.cc).unwrap()
        );
        assert_eq!(
            &expected_custom_serialized,
            &CustomSerialized::try_from_custom_const_box(example.cc.clone_box()).unwrap()
        );
        assert_eq!(
            &expected_custom_serialized,
            &CustomSerialized::try_from_dyn_custom_const(example.cc.clone_box().as_ref()).unwrap()
        );
        assert_eq!(
            &example.cc.clone_box(),
            &expected_custom_serialized.clone().into_custom_const_box()
        );
        assert_eq!(
            &example.cc,
            &expected_custom_serialized
                .clone()
                .try_into_custom_const()
                .unwrap()
        );

        // check OpaqueValue serializes/deserializes as a CustomSerialized
        let ev: OpaqueValue = example.cc.clone().into();
        let ev_val = serde_json::to_value(&ev).unwrap();
        assert_eq!(
            &ev_val,
            &serde_json::to_value(&expected_custom_serialized).unwrap()
        );
        assert_eq!(ev, serde_json::from_value(ev_val).unwrap());
    }

    /// A `ConstUsize` paired with the `CustomSerialized` wrapping it.
    fn example_custom_serialized() -> (ConstUsize, CustomSerialized) {
        let inner = scce_usize().cc;
        (
            inner.clone(),
            CustomSerialized::try_from_custom_const(inner).unwrap(),
        )
    }

    /// A `CustomSerialized` paired with an outer `CustomSerialized` whose value is
    /// the tagged serialization of the inner one.
    fn example_nested_custom_serialized() -> (CustomSerialized, CustomSerialized) {
        let inner = example_custom_serialized().1;
        (
            inner.clone(),
            CustomSerialized::new(inner.get_type(), serialize_custom_const(&inner).unwrap()),
        )
    }

    /// Checks that converting a `CustomSerialized` into a `CustomSerialized` is the
    /// identity, and that unwrapping it recovers `inner` (exactly one level deep).
    #[rstest]
    #[cfg_attr(miri, ignore = "miri is incompatible with the typetag crate")]
    #[case(example_custom_serialized())]
    #[case(example_nested_custom_serialized())]
    fn test_try_from_custom_serialized_recursive<CC: CustomConst + PartialEq>(
        #[case] example: (CC, CustomSerialized),
    ) {
        let (inner, cs) = example;
        // check all the try_from/try_into/into variations

        assert_eq!(
            &cs,
            &CustomSerialized::try_from_custom_const(cs.clone()).unwrap()
        );
        assert_eq!(
            &cs,
            &CustomSerialized::try_from_custom_const_ref(&cs).unwrap()
        );
        assert_eq!(
            &cs,
            &CustomSerialized::try_from_custom_const_box(cs.clone_box()).unwrap()
        );
        assert_eq!(
            &cs,
            &CustomSerialized::try_from_dyn_custom_const(cs.clone_box().as_ref()).unwrap()
        );
        assert_eq!(&inner.clone_box(), &cs.clone().into_custom_const_box());
        assert_eq!(&inner, &cs.clone().try_into_custom_const().unwrap());

        let ev: OpaqueValue = cs.clone().into();
        // A serialization round-trip results in an OpaqueValue with the value of inner
        assert_eq!(
            OpaqueValue::new(inner),
            serde_json::from_value(serde_json::to_value(&ev).unwrap()).unwrap()
        );
    }
}
520
#[cfg(test)]
mod proptest {
    use ::proptest::prelude::*;

    use crate::{
        ops::constant::CustomSerialized,
        proptest::{any_serde_json_value, any_string},
        types::Type,
    };

    /// Prefix added to every generated typetag tag.
    ///
    /// `-` and `:` cannot occur in a Rust type name, which is what
    /// `#[typetag::serde]` uses as the tag by default. Prefixed tags therefore
    /// never select a real `CustomConst` implementation during deserialization.
    const UNREGISTERED_TAG_PREFIX: &str = "arbitrary-unregistered:";

    impl Arbitrary for CustomSerialized {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            let typ = any::<Type>();
            // Here we manually construct a serialized `dyn CustomConst`.
            // The "c" and "v" come from the `typetag::serde` annotation on
            // `trait CustomConst`.
            // The tag is prefixed so that we never accidentally generate a
            // valid tag (e.g. "ConstInt"). Otherwise the serde::Deserialize
            // impl for that type would try to interpret the arbitrary "v"
            // and fail.
            let value = (any_serde_json_value(), any_string()).prop_map(|(content, tag)| {
                let tag = format!("{UNREGISTERED_TAG_PREFIX}{tag}");
                [("c".into(), tag.into()), ("v".into(), content)]
                    .into_iter()
                    .collect::<serde_json::Map<String, _>>()
                    .into()
            });
            (typ, value)
                .prop_map(|(typ, value)| CustomSerialized { typ, value })
                .boxed()
        }
    }
}