p2panda_rs/serde/
u64_str.rs1use std::convert::{TryFrom, TryInto};
4use std::marker::PhantomData;
5use std::str::FromStr;
6
7use serde::de::Visitor;
8use serde::Deserialize;
9
/// Helper for `serde` to deserialize a `T` from either a string or a `u64` value.
///
/// Pass an instance to `Deserializer::deserialize_any`; its `Visitor` implementation decides how
/// each representation is turned into a `T`.
#[derive(Debug)]
pub struct StringOrU64<T>(PhantomData<T>);

impl<T> StringOrU64<T> {
    /// Returns a new visitor producing values of type `T`.
    pub const fn new() -> Self {
        // Use phantom data to assure the type checker that `T` is used by this struct.
        Self(PhantomData)
    }
}

// Implemented by hand: `#[derive(Default)]` would add a `T: Default` bound, even though no `T`
// value is ever stored in this struct.
impl<T> Default for StringOrU64<T> {
    fn default() -> Self {
        Self::new()
    }
}
20
21impl<'de, T> Visitor<'de> for StringOrU64<T>
22where
23 T: Deserialize<'de> + FromStr + TryFrom<u64>,
24{
25 type Value = T;
26
27 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
28 formatter.write_str("string or u64 integer")
29 }
30
31 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
32 where
33 E: serde::de::Error,
34 {
35 let result = FromStr::from_str(value)
36 .map_err(|_| serde::de::Error::custom("Invalid string value"))?;
37
38 Ok(result)
39 }
40
41 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
42 where
43 E: serde::de::Error,
44 {
45 let result = TryInto::<Self::Value>::try_into(value)
46 .map_err(|_| serde::de::Error::custom("Invalid u64 value"))?;
47
48 Ok(result)
49 }
50}
51
#[cfg(test)]
mod tests {
    use std::convert::TryFrom;
    use std::str::FromStr;

    use serde::de::DeserializeOwned;
    use serde::{Deserialize, Serialize};

    use super::StringOrU64;

    /// Accepts any `u64`, given either as an integer or as a decimal string.
    #[derive(PartialEq, Eq, Debug)]
    struct Test(u64);

    impl From<u64> for Test {
        fn from(value: u64) -> Self {
            Self(value)
        }
    }

    impl FromStr for Test {
        type Err = Box<dyn std::error::Error>;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(Test(u64::from_str(s)?))
        }
    }

    impl<'de> Deserialize<'de> for Test {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserializer.deserialize_any(StringOrU64::<Test>::new())
        }
    }

    /// Only accepts values that fit into a `u8`, so conversions from `u64` can fail.
    #[derive(PartialEq, Eq, Debug)]
    struct Small(u8);

    impl TryFrom<u64> for Small {
        type Error = std::num::TryFromIntError;

        fn try_from(value: u64) -> Result<Self, Self::Error> {
            Ok(Small(u8::try_from(value)?))
        }
    }

    impl FromStr for Small {
        type Err = std::num::ParseIntError;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(Small(u8::from_str(s)?))
        }
    }

    impl<'de> Deserialize<'de> for Small {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserializer.deserialize_any(StringOrU64::<Small>::new())
        }
    }

    /// Encodes `value` as CBOR and decodes it as `T`; `None` if decoding fails.
    fn decode<T: DeserializeOwned, V: Serialize>(value: V) -> Option<T> {
        let mut cbor_bytes = Vec::new();
        ciborium::ser::into_writer(&value, &mut cbor_bytes).unwrap();
        ciborium::de::from_reader(&cbor_bytes[..]).ok()
    }

    #[test]
    fn deserialize_str_and_u64() {
        assert_eq!(decode::<Test, _>("12"), Some(Test(12)));
        assert_eq!(decode::<Test, _>(12u64), Some(Test(12)));
        assert_eq!(decode::<Test, _>(u64::MAX), Some(Test(u64::MAX)));
    }

    #[test]
    fn deserialize_fallible_conversions() {
        assert_eq!(decode::<Small, _>("255"), Some(Small(255)));
        assert_eq!(decode::<Small, _>(255u64), Some(Small(255)));
        // Out of range for `u8` in both representations.
        assert_eq!(decode::<Small, _>("256"), None);
        assert_eq!(decode::<Small, _>(256u64), None);
    }

    #[test]
    fn reject_invalid_values() {
        // Strings which are not numbers.
        assert_eq!(decode::<Test, _>("twelve"), None);
        assert_eq!(decode::<Test, _>(""), None);
        // Values which are neither strings nor integers.
        assert_eq!(decode::<Test, _>(true), None);
    }
}