broxus_util/
serde_helpers.rs

1use std::borrow::Cow;
2use std::convert::TryInto;
3use std::fmt;
4use std::str::FromStr;
5use std::time::Duration;
6
7use serde::de::{Error, Visitor};
8use serde::{Deserialize, Serialize};
9
/// Generates `pub const fn NAME<const N: TY>() -> TY` helpers that simply return `N`.
///
/// Useful wherever a function path is required instead of a literal value,
/// e.g. `#[serde(default = "const_u32::<10>")]`.
macro_rules! declare_const_helpers {
    ($($name:ident => $ty:ty),* $(,)?) => {
        $(
            #[doc = concat!("Returns the const generic argument `N` as a `", stringify!($ty), "`.")]
            pub const fn $name<const N: $ty>() -> $ty {
                N
            }
        )*
    };
}

// One helper per type that is allowed as a const generic parameter.
declare_const_helpers!(
    const_bool => bool,
    const_char => char,
    const_usize => usize,
    const_isize => isize,
    const_i8 => i8,
    const_u8 => u8,
    const_i16 => i16,
    const_u16 => u16,
    const_i32 => i32,
    const_u32 => u32,
    const_i64 => i64,
    const_u64 => u64,
    const_i128 => i128,
    const_u128 => u128,
);
32
/// Returns a [`Duration`] of exactly `N` whole seconds.
///
/// Being a `const fn` with a const generic, it can be referenced by path wherever a
/// function (rather than a value) is expected.
pub const fn const_duration_sec<const N: u64>() -> Duration {
    Duration::new(N, 0)
}

/// Returns a [`Duration`] of exactly `N` milliseconds.
///
/// Being a `const fn` with a const generic, it can be referenced by path wherever a
/// function (rather than a value) is expected.
pub const fn const_duration_ms<const N: u64>() -> Duration {
    // Split the millisecond count into whole seconds plus the sub-second remainder in nanos.
    Duration::new(N / 1_000, ((N % 1_000) * 1_000_000) as u32)
}
40
/// Largest integer magnitude an IEEE-754 double represents exactly: `2^53 - 1`
/// (JavaScript's `Number.MAX_SAFE_INTEGER`).
const JS_MAX_SAFE_INTEGER: u64 = 0x1f_ffff_ffff_ffff;

/// Decides whether an integer may be written as a plain JSON number without
/// risking precision loss in consumers that parse numbers as doubles.
pub trait JsonNumberRepr {
    /// Returns `true` if the value fits into the exactly-representable double range
    /// (`-(2^53 - 1) ..= 2^53 - 1`). The default is `true`, which is correct for every
    /// type of 32 bits or fewer.
    #[inline]
    fn fits_into_number(&self) -> bool {
        true
    }
}

impl<T: JsonNumberRepr> JsonNumberRepr for &T {
    #[inline]
    fn fits_into_number(&self) -> bool {
        <T as JsonNumberRepr>::fits_into_number(*self)
    }
}

impl JsonNumberRepr for u8 {}
impl JsonNumberRepr for i8 {}
impl JsonNumberRepr for u16 {}
impl JsonNumberRepr for i16 {}
impl JsonNumberRepr for u32 {}
impl JsonNumberRepr for i32 {}

impl JsonNumberRepr for u64 {
    #[inline]
    fn fits_into_number(&self) -> bool {
        *self <= JS_MAX_SAFE_INTEGER
    }
}

impl JsonNumberRepr for i64 {
    #[inline]
    fn fits_into_number(&self) -> bool {
        // Both bounds matter: large negative values lose precision just like large positive ones.
        self.unsigned_abs() <= JS_MAX_SAFE_INTEGER
    }
}

impl JsonNumberRepr for u128 {
    #[inline]
    fn fits_into_number(&self) -> bool {
        *self <= u128::from(JS_MAX_SAFE_INTEGER)
    }
}

impl JsonNumberRepr for i128 {
    #[inline]
    fn fits_into_number(&self) -> bool {
        self.unsigned_abs() <= u128::from(JS_MAX_SAFE_INTEGER)
    }
}

impl JsonNumberRepr for usize {
    #[inline]
    fn fits_into_number(&self) -> bool {
        // `usize` is at most 64 bits wide on every target Rust supports, so this is lossless.
        (*self as u64).fits_into_number()
    }
}

impl JsonNumberRepr for isize {
    #[inline]
    fn fits_into_number(&self) -> bool {
        // `isize` is at most 64 bits wide on every target Rust supports, so this is lossless.
        (*self as i64).fits_into_number()
    }
}
84
85struct StringOrNumber<T>(T);
86
87impl<T> Serialize for StringOrNumber<T>
88where
89    T: JsonNumberRepr + Serialize + fmt::Display,
90{
91    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92    where
93        S: serde::Serializer,
94    {
95        if !serializer.is_human_readable() || self.0.fits_into_number() {
96            self.0.serialize(serializer)
97        } else {
98            serializer.serialize_str(&self.0.to_string())
99        }
100    }
101}
102
103impl<'de, T> Deserialize<'de> for StringOrNumber<T>
104where
105    T: FromStr + Deserialize<'de>,
106{
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        #[derive(Deserialize)]
112        #[serde(untagged)]
113        enum Value<'a, T> {
114            String(#[serde(borrow)] Cow<'a, str>),
115            Number(T),
116        }
117
118        if deserializer.is_human_readable() {
119            match Value::deserialize(deserializer)? {
120                Value::String(str) => T::from_str(str.as_ref())
121                    .map(Self)
122                    .map_err(|_| Error::custom("Invalid number")),
123                Value::Number(value) => Ok(Self(value)),
124            }
125        } else {
126            T::deserialize(deserializer).map(StringOrNumber)
127        }
128    }
129}
130
131pub mod serde_string_or_number {
132    use super::*;
133
134    pub fn serialize<S, T>(data: &T, serializer: S) -> Result<S::Ok, S::Error>
135    where
136        S: serde::Serializer,
137        T: JsonNumberRepr + Serialize + fmt::Display,
138    {
139        StringOrNumber(data).serialize(serializer)
140    }
141
142    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
143    where
144        D: serde::Deserializer<'de>,
145        T: FromStr + Deserialize<'de>,
146    {
147        StringOrNumber::<T>::deserialize(deserializer).map(|StringOrNumber(x)| x)
148    }
149}
150
151pub mod serde_optional_string_or_number {
152    use super::*;
153
154    pub fn serialize<S, T>(data: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
155    where
156        S: serde::Serializer,
157        T: JsonNumberRepr + Serialize + fmt::Display,
158    {
159        match data {
160            Some(data) => serializer.serialize_some(&StringOrNumber(data)),
161            None => serializer.serialize_none(),
162        }
163    }
164
165    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
166    where
167        D: serde::Deserializer<'de>,
168        T: FromStr + Deserialize<'de>,
169    {
170        Option::<StringOrNumber<T>>::deserialize(deserializer).map(|x| x.map(|StringOrNumber(x)| x))
171    }
172}
173
174pub mod serde_duration_sec {
175    use super::*;
176
177    pub fn serialize<S>(data: &Duration, serializer: S) -> Result<S::Ok, S::Error>
178    where
179        S: serde::Serializer,
180    {
181        StringOrNumber(data.as_secs()).serialize(serializer)
182    }
183
184    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
185    where
186        D: serde::Deserializer<'de>,
187    {
188        StringOrNumber::deserialize(deserializer).map(|StringOrNumber(x)| Duration::from_secs(x))
189    }
190}
191
192pub mod serde_duration_ms {
193    use super::*;
194
195    pub fn serialize<S>(data: &Duration, serializer: S) -> Result<S::Ok, S::Error>
196    where
197        S: serde::Serializer,
198    {
199        StringOrNumber(data.as_millis() as u64).serialize(serializer)
200    }
201
202    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
203    where
204        D: serde::Deserializer<'de>,
205    {
206        StringOrNumber::deserialize(deserializer).map(|StringOrNumber(x)| Duration::from_millis(x))
207    }
208}
209
210pub mod serde_base64_array {
211    use super::*;
212
213    pub fn serialize<S>(data: &dyn AsRef<[u8]>, serializer: S) -> Result<S::Ok, S::Error>
214    where
215        S: serde::Serializer,
216    {
217        serde_base64_bytes::serialize(data, serializer)
218    }
219
220    pub fn deserialize<'de, D, const N: usize>(deserializer: D) -> Result<[u8; N], D::Error>
221    where
222        D: serde::Deserializer<'de>,
223    {
224        let data = serde_base64_bytes::deserialize(deserializer)?;
225        data.try_into()
226            .map_err(|_| Error::custom(format!("Invalid array length, expected: {N}")))
227    }
228}
229
230pub mod serde_hex_array {
231    use super::*;
232
233    pub fn serialize<S>(data: &dyn AsRef<[u8]>, serializer: S) -> Result<S::Ok, S::Error>
234    where
235        S: serde::Serializer,
236    {
237        serde_hex_bytes::serialize(data, serializer)
238    }
239
240    pub fn deserialize<'de, D, const N: usize>(deserializer: D) -> Result<[u8; N], D::Error>
241    where
242        D: serde::Deserializer<'de>,
243    {
244        let data = serde_hex_bytes::deserialize(deserializer)?;
245        data.try_into()
246            .map_err(|_| Error::custom(format!("Invalid array length, expected: {N}")))
247    }
248}
249
250pub mod serde_optional_hex_array {
251    use super::*;
252
253    pub fn serialize<S, T>(data: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
254    where
255        T: AsRef<[u8]> + Sized,
256        S: serde::Serializer,
257    {
258        #[derive(Serialize)]
259        struct Wrapper<'a>(#[serde(with = "serde_hex_bytes")] &'a [u8]);
260
261        match data {
262            Some(data) => serializer.serialize_some(&Wrapper(data.as_ref())),
263            None => serializer.serialize_none(),
264        }
265    }
266
267    pub fn deserialize<'de, D, const N: usize>(deserializer: D) -> Result<Option<[u8; N]>, D::Error>
268    where
269        D: serde::Deserializer<'de>,
270    {
271        #[derive(Deserialize)]
272        struct Wrapper(#[serde(with = "serde_hex_bytes")] Vec<u8>);
273
274        let data = Option::<Wrapper>::deserialize(deserializer)?;
275        Ok(match data {
276            Some(data) => Some(
277                data.0
278                    .try_into()
279                    .map_err(|_| Error::custom(format!("Invalid array length, expected: {}", N)))?,
280            ),
281            None => None,
282        })
283    }
284}
285
286pub mod serde_string {
287    use super::*;
288
289    pub fn serialize<S>(data: &dyn fmt::Display, serializer: S) -> Result<S::Ok, S::Error>
290    where
291        S: serde::Serializer,
292    {
293        data.to_string().serialize(serializer)
294    }
295
296    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
297    where
298        D: serde::Deserializer<'de>,
299        T: FromStr,
300        T::Err: fmt::Display,
301    {
302        <BorrowedStr>::deserialize(deserializer)
303            .and_then(|data| T::from_str(data.0.as_ref()).map_err(Error::custom))
304    }
305}
306
307pub mod serde_optional_string {
308    use super::*;
309
310    pub fn serialize<S, T>(data: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
311    where
312        S: serde::Serializer,
313        T: fmt::Display,
314    {
315        data.as_ref().map(ToString::to_string).serialize(serializer)
316    }
317
318    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
319    where
320        D: serde::Deserializer<'de>,
321        T: FromStr,
322        T::Err: fmt::Display,
323    {
324        Option::<BorrowedStr>::deserialize(deserializer).and_then(|data| {
325            data.map(|data| T::from_str(data.0.as_ref()).map_err(Error::custom))
326                .transpose()
327        })
328    }
329}
330
331pub mod serde_string_array {
332    use super::*;
333
334    pub fn serialize<S, T>(data: &[T], serializer: S) -> Result<S::Ok, S::Error>
335    where
336        S: serde::Serializer,
337        T: fmt::Display,
338    {
339        data.iter()
340            .map(ToString::to_string)
341            .collect::<Vec<_>>()
342            .join(",")
343            .serialize(serializer)
344    }
345
346    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
347    where
348        T: Deserialize<'de> + FromStr,
349        D: serde::Deserializer<'de>,
350        <T as FromStr>::Err: fmt::Display,
351    {
352        let BorrowedStr(s) = <_>::deserialize(deserializer)?;
353        if s.contains(',') {
354            let mut v = Vec::new();
355            for url in s.split(',') {
356                v.push(T::from_str(url).map_err(Error::custom)?);
357            }
358            Ok(v)
359        } else {
360            Ok(vec![T::from_str(s.as_ref()).map_err(Error::custom)?])
361        }
362    }
363}
364
365struct BytesVisitor;
366
367impl<'de> Visitor<'de> for BytesVisitor {
368    type Value = Vec<u8>;
369
370    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
371        formatter.write_str("byte array")
372    }
373
374    fn visit_bytes<E: Error>(self, value: &[u8]) -> Result<Self::Value, E> {
375        Ok(value.to_vec())
376    }
377}
378
379pub mod serde_hex_bytes {
380    use std::fmt;
381
382    use serde::de::Unexpected;
383
384    use super::*;
385
386    pub fn serialize<S>(data: &dyn AsRef<[u8]>, serializer: S) -> Result<S::Ok, S::Error>
387    where
388        S: serde::Serializer,
389    {
390        if serializer.is_human_readable() {
391            serializer.serialize_str(hex::encode(data).as_str())
392        } else {
393            serializer.serialize_bytes(data.as_ref())
394        }
395    }
396
397    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
398    where
399        D: serde::Deserializer<'de>,
400    {
401        struct HexVisitor;
402
403        impl<'de> Visitor<'de> for HexVisitor {
404            type Value = Vec<u8>;
405
406            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
407                formatter.write_str("hex-encoded byte array")
408            }
409
410            fn visit_str<E: Error>(self, value: &str) -> Result<Self::Value, E> {
411                hex::decode(value).map_err(|_| E::invalid_type(Unexpected::Str(value), &self))
412            }
413
414            // See the `deserializing_flattened_field` test for an example why this is needed.
415            fn visit_bytes<E: Error>(self, value: &[u8]) -> Result<Self::Value, E> {
416                Ok(value.to_vec())
417            }
418        }
419
420        if deserializer.is_human_readable() {
421            deserializer.deserialize_str(HexVisitor)
422        } else {
423            deserializer.deserialize_bytes(BytesVisitor)
424        }
425    }
426}
427
428pub mod serde_optional_hex_bytes {
429    use super::*;
430
431    pub fn serialize<S, T>(data: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
432    where
433        S: serde::Serializer,
434        T: AsRef<[u8]>,
435    {
436        #[derive(serde::Serialize)]
437        #[serde(transparent)]
438        struct Wrapper<'a>(#[serde(with = "serde_hex_bytes")] &'a [u8]);
439
440        match data {
441            Some(data) => serializer.serialize_some(&Wrapper(data.as_ref())),
442            None => serializer.serialize_none(),
443        }
444    }
445
446    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
447    where
448        D: serde::Deserializer<'de>,
449    {
450        #[derive(serde::Deserialize)]
451        #[serde(transparent)]
452        struct Wrapper(#[serde(with = "serde_hex_bytes")] Vec<u8>);
453
454        Option::<Wrapper>::deserialize(deserializer).map(|wrapper| wrapper.map(|data| data.0))
455    }
456}
457
458pub mod serde_base64_bytes {
459    use serde::de::Unexpected;
460
461    use super::*;
462
463    pub fn serialize<S>(data: &dyn AsRef<[u8]>, serializer: S) -> Result<S::Ok, S::Error>
464    where
465        S: serde::Serializer,
466    {
467        if serializer.is_human_readable() {
468            serializer.serialize_str(base64::encode(data).as_str())
469        } else {
470            serializer.serialize_bytes(data.as_ref())
471        }
472    }
473
474    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
475    where
476        D: serde::Deserializer<'de>,
477    {
478        struct Base64Visitor;
479
480        impl<'de> Visitor<'de> for Base64Visitor {
481            type Value = Vec<u8>;
482
483            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
484                formatter.write_str("base64-encoded byte array")
485            }
486
487            fn visit_str<E: Error>(self, value: &str) -> Result<Self::Value, E> {
488                base64::decode(value).map_err(|_| E::invalid_type(Unexpected::Str(value), &self))
489            }
490
491            // See the `deserializing_flattened_field` test for an example why this is needed.
492            fn visit_bytes<E: Error>(self, value: &[u8]) -> Result<Self::Value, E> {
493                Ok(value.to_vec())
494            }
495        }
496
497        if deserializer.is_human_readable() {
498            deserializer.deserialize_str(Base64Visitor)
499        } else {
500            deserializer.deserialize_bytes(BytesVisitor)
501        }
502    }
503}
504
505pub mod serde_optional_base64_bytes {
506    use super::*;
507
508    pub fn serialize<S, T>(data: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
509    where
510        S: serde::Serializer,
511        T: AsRef<[u8]>,
512    {
513        #[derive(serde::Serialize)]
514        #[serde(transparent)]
515        struct Wrapper<'a>(#[serde(with = "serde_base64_bytes")] &'a [u8]);
516
517        match data {
518            Some(data) => serializer.serialize_some(&Wrapper(data.as_ref())),
519            None => serializer.serialize_none(),
520        }
521    }
522
523    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
524    where
525        D: serde::Deserializer<'de>,
526    {
527        #[derive(serde::Deserialize)]
528        #[serde(transparent)]
529        struct Wrapper(#[serde(with = "serde_base64_bytes")] Vec<u8>);
530
531        Option::<Wrapper>::deserialize(deserializer).map(|wrapper| wrapper.map(|data| data.0))
532    }
533}
534
535pub mod serde_iter {
536    pub fn serialize<S, T, V>(iter: &T, serializer: S) -> Result<S::Ok, S::Error>
537    where
538        S: serde::Serializer,
539        T: IntoIterator<Item = V> + Clone,
540        V: serde::Serialize,
541    {
542        use serde::ser::SerializeSeq;
543
544        let iter = iter.clone().into_iter();
545        let mut seq = serializer.serialize_seq(Some(iter.size_hint().0))?;
546        for value in iter {
547            seq.serialize_element(&value)?;
548        }
549        seq.end()
550    }
551}
552
553#[derive(Deserialize)]
554struct BorrowedStr<'a>(#[serde(borrow)] Cow<'a, str>);
555
#[cfg(test)]
mod test {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[test]
    fn test_string_or_number() {
        #[derive(Serialize, Deserialize, Eq, PartialEq, Debug)]
        struct Test {
            #[serde(with = "serde_string_or_number")]
            value: u64,
        }

        let test = Test { value: 123123 };
        let data = serde_json::to_string(&test).unwrap();
        assert_eq!(data, r#"{"value":123123}"#);
        assert_eq!(serde_json::from_str::<Test>(&data).unwrap(), test);

        let data = r#"{"value":"123123"}"#;
        assert_eq!(serde_json::from_str::<Test>(data).unwrap(), test);

        let test = Test {
            value: 0xffffffffffffffff,
        };
        let data = serde_json::to_string(&test).unwrap();
        assert_eq!(data, r#"{"value":"18446744073709551615"}"#);
        // The string form must round-trip as well.
        assert_eq!(serde_json::from_str::<Test>(&data).unwrap(), test);
    }

    #[test]
    fn test_changed_string() {
        #[derive(Debug, Serialize, Deserialize)]
        struct Test {
            #[serde(with = "serde_string_array")]
            value: Vec<String>,
        }

        // The escaped quote presumably prevents borrowing straight from the input,
        // exercising the owned branch of `BorrowedStr`'s `Cow`.
        let test: Test = serde_json::from_str("{\"value\":\"\\\"\"}").unwrap();
        assert_eq!(test.value, vec!["\"".to_owned()]);

        let test: Test = serde_json::from_str(r#"{"value":"a,b"}"#).unwrap();
        assert_eq!(test.value, vec!["a".to_owned(), "b".to_owned()]);
    }

    #[test]
    fn test_hex() {
        #[derive(Serialize, Deserialize, Eq, PartialEq, Debug)]
        struct Test {
            #[serde(with = "serde_hex_array")]
            key: [u8; 32],
        }
        let test = Test { key: [1; 32] };
        let data = serde_json::to_string(&test).unwrap();
        assert_eq!(
            data,
            r#"{"key":"0101010101010101010101010101010101010101010101010101010101010101"}"#
        );
        assert_eq!(serde_json::from_str::<Test>(&data).unwrap(), test);
        let data = bincode::serialize(&test).unwrap();
        assert!(data.len() < 64);
        assert_eq!(bincode::deserialize::<Test>(&data).unwrap(), test);
    }

    #[test]
    fn test_optional() {
        #[derive(Serialize, Deserialize, Eq, PartialEq, Debug)]
        struct Test {
            #[serde(with = "serde_optional_base64_bytes")]
            key: Option<Vec<u8>>,
        }

        let data = Test {
            key: Some(vec![1; 32]),
        };
        let res = serde_json::to_string(&data).unwrap();
        assert_eq!(
            r#"{"key":"AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE="}"#,
            res
        );
        assert_eq!(data, serde_json::from_str(&res).unwrap());

        let data = Test { key: None };
        let res = serde_json::to_string(&data).unwrap();
        assert_eq!(r#"{"key":null}"#, res);
        assert_eq!(data, serde_json::from_str(&res).unwrap())
    }

    #[test]
    fn test_optional_hex_array() {
        #[derive(Serialize, Deserialize)]
        struct Test {
            #[serde(with = "serde_optional_hex_array")]
            field: Option<[u8; 32]>,
        }

        let target: [u8; 32] =
            hex::decode("0101010101010101010101010101010101010101010101010101010101010101")
                .unwrap()
                .try_into()
                .unwrap();

        let serialized = serde_json::to_string(&Test {
            field: Some(target),
        })
        .unwrap();
        let deserialized: Test = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.field, Some(target));

        let serialized = serde_json::to_string(&Test { field: None }).unwrap();
        let deserialized: Test = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.field, None);
    }
}