pbjson_any/
lib.rs

1//! `pbjson` is a set of crates to automatically generate [`serde::Serialize`] and
2//! [`serde::Deserialize`] implementations for [prost][1] generated structs that
3//! are compliant with the [protobuf JSON mapping][2]
4//!
5//! See [pbjson-build][3] for usage instructions
6//!
7//! [1]: https://github.com/tokio-rs/prost
8//! [2]: https://developers.google.com/protocol-buffers/docs/proto3#json
9//! [3]: https://docs.rs/pbjson-build
10//!
11#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
12#![warn(
13    missing_debug_implementations,
14    clippy::explicit_iter_loop,
15    clippy::use_self,
16    clippy::clone_on_ref_ptr,
17    clippy::future_not_send
18)]
19
20#[doc(hidden)]
21pub use prost_wkt;
22#[doc(hidden)]
23pub use typetag::serde as typetag_serde;
24#[doc(hidden)]
25pub use typetag;
26
27#[doc(hidden)]
28pub mod private {
29    /// Re-export base64
30    pub use base64;
31
32    use serde::Deserialize;
33    use std::str::FromStr;
34
35    /// Used to parse a number from either a string or its raw representation
36    #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
37    pub struct NumberDeserialize<T>(pub T);
38
39    #[derive(Deserialize)]
40    #[serde(untagged)]
41    enum Content<'a, T> {
42        Str(&'a str),
43        Number(T),
44    }
45
46    impl<'de, T> serde::Deserialize<'de> for NumberDeserialize<T>
47    where
48        T: FromStr + serde::Deserialize<'de>,
49        <T as FromStr>::Err: std::error::Error,
50    {
51        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
52        where
53            D: serde::Deserializer<'de>,
54        {
55            let content = Content::deserialize(deserializer)?;
56            Ok(Self(match content {
57                Content::Str(v) => v.parse().map_err(serde::de::Error::custom)?,
58                Content::Number(v) => v,
59            }))
60        }
61    }
62
63    #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
64    pub struct BytesDeserialize<T>(pub T);
65
66    impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
67    where
68        T: From<Vec<u8>>,
69    {
70        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
71        where
72            D: serde::Deserializer<'de>,
73        {
74            let s: &str = Deserialize::deserialize(deserializer)?;
75
76            let decoded = base64::decode_config(s, base64::STANDARD)
77                .or_else(|e| match e {
78                    // Either standard or URL-safe base64 encoding are accepted
79                    //
80                    // The difference being URL-safe uses `-` and `_` instead of `+` and `/`
81                    //
82                    // Therefore if we error out on those characters, try again with
83                    // the URL-safe character set
84                    base64::DecodeError::InvalidByte(_, c) if c == b'-' || c == b'_' => {
85                        base64::decode_config(s, base64::URL_SAFE)
86                    }
87                    _ => Err(e),
88                })
89                .map_err(serde::de::Error::custom)?;
90
91            Ok(Self(decoded.into()))
92        }
93    }
94
95    #[cfg(test)]
96    mod tests {
97        use super::*;
98        use bytes::Bytes;
99        use rand::prelude::*;
100        use serde::de::value::{BorrowedStrDeserializer, Error};
101
102        #[test]
103        fn test_bytes() {
104            for _ in 0..20 {
105                let mut rng = thread_rng();
106                let len = rng.gen_range(50..100);
107                let raw: Vec<_> = std::iter::from_fn(|| Some(rng.gen())).take(len).collect();
108
109                for config in [
110                    base64::STANDARD,
111                    base64::STANDARD_NO_PAD,
112                    base64::URL_SAFE,
113                    base64::URL_SAFE_NO_PAD,
114                ] {
115                    let encoded = base64::encode_config(&raw, config);
116
117                    let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);
118                    let a: Bytes = BytesDeserialize::deserialize(deserializer).unwrap().0;
119                    let b: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;
120
121                    assert_eq!(raw.as_slice(), &a);
122                    assert_eq!(raw.as_slice(), &b);
123                }
124            }
125        }
126    }
127}