p2panda_rs/serde/u64_str.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2
3use std::convert::{TryFrom, TryInto};
4use std::marker::PhantomData;
5use std::str::FromStr;
6
7use serde::de::Visitor;
8use serde::Deserialize;
9
/// Visitor which can be used to deserialize a `String` or `u64` integer to a type `T`.
///
/// The visitor never stores a `T`; the marker only fixes the output type. It therefore uses
/// `PhantomData<fn() -> T>` (no ownership of `T`, `Send`/`Sync` regardless of `T`) and
/// hand-written `Debug`/`Default` impls, because `#[derive]` would add `T: Debug` /
/// `T: Default` bounds that the visitor does not actually need.
pub struct StringOrU64<T>(PhantomData<fn() -> T>);

impl<T> StringOrU64<T> {
    /// Returns temporary type to deserialize a string or u64 into T.
    pub fn new() -> Self {
        Self(PhantomData)
    }
}

impl<T> Default for StringOrU64<T> {
    /// Same as [`StringOrU64::new`]; available for every `T`.
    fn default() -> Self {
        Self::new()
    }
}

impl<T> std::fmt::Debug for StringOrU64<T> {
    /// Prints the visitor together with the name of the target type, without requiring
    /// `T: Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "StringOrU64<{}>", std::any::type_name::<T>())
    }
}
20
21impl<'de, T> Visitor<'de> for StringOrU64<T>
22where
23    T: Deserialize<'de> + FromStr + TryFrom<u64>,
24{
25    type Value = T;
26
27    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
28        formatter.write_str("string or u64 integer")
29    }
30
31    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
32    where
33        E: serde::de::Error,
34    {
35        let result = FromStr::from_str(value)
36            .map_err(|_| serde::de::Error::custom("Invalid string value"))?;
37
38        Ok(result)
39    }
40
41    fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
42    where
43        E: serde::de::Error,
44    {
45        let result = TryInto::<Self::Value>::try_into(value)
46            .map_err(|_| serde::de::Error::custom("Invalid u64 value"))?;
47
48        Ok(result)
49    }
50}
51
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use serde::{Deserialize, Serialize};

    use super::StringOrU64;

    /// Minimal newtype that deserializes through `StringOrU64`.
    #[derive(PartialEq, Eq, Debug)]
    struct Test(u64);

    impl From<u64> for Test {
        fn from(value: u64) -> Self {
            Self(value)
        }
    }

    impl FromStr for Test {
        type Err = Box<dyn std::error::Error>;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(Test(u64::from_str(s)?))
        }
    }

    impl<'de> Deserialize<'de> for Test {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserializer.deserialize_any(StringOrU64::<Test>::new())
        }
    }

    /// Encodes `value` as CBOR and decodes it back as `Test`; `None` means decoding failed.
    fn cbor_roundtrip<V: Serialize + ?Sized>(value: &V) -> Option<Test> {
        let mut cbor_bytes = Vec::new();
        ciborium::ser::into_writer(value, &mut cbor_bytes).unwrap();
        ciborium::de::from_reader(&cbor_bytes[..]).ok()
    }

    #[test]
    fn deserialize_str_and_u64() {
        // String path (`visit_str`).
        assert_eq!(cbor_roundtrip("12"), Some(Test(12)));
        // Integer path (`visit_u64`), including the top of the range.
        assert_eq!(cbor_roundtrip(&12u64), Some(Test(12)));
        assert_eq!(cbor_roundtrip(&u64::MAX), Some(Test(u64::MAX)));
    }

    #[test]
    fn reject_invalid_string() {
        assert_eq!(cbor_roundtrip("not a number"), None);
    }
}