pbjson/
lib.rs

//! `pbjson` is a set of crates to automatically generate [`serde::Serialize`] and
//! [`serde::Deserialize`] implementations for [prost][1]-generated structs that
//! are compliant with the [protobuf JSON mapping][2].
//!
//! See [pbjson-build][3] for usage instructions.
//!
//! [1]: https://github.com/tokio-rs/prost
//! [2]: https://developers.google.com/protocol-buffers/docs/proto3#json
//! [3]: https://docs.rs/pbjson-build
//!
11#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
12#![warn(
13    missing_debug_implementations,
14    clippy::explicit_iter_loop,
15    clippy::use_self,
16    clippy::clone_on_ref_ptr,
17    clippy::future_not_send
18)]
19
20#[doc(hidden)]
21pub mod private {
22    /// Re-export base64
23    pub use base64;
24
25    use base64::Engine;
26    use base64::engine::DecodePaddingMode;
27    use base64::engine::{GeneralPurpose, GeneralPurposeConfig};
28    use serde::Deserialize;
29    use serde::de::Visitor;
30    use std::borrow::Cow;
31    use std::str::FromStr;
32
33    /// Used to parse a number from either a string or its raw representation
34    #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
35    pub struct NumberDeserialize<T>(pub T);
36
37    #[derive(Deserialize)]
38    #[serde(untagged)]
39    enum Content<'a, T> {
40        #[serde(borrow)]
41        Str(Cow<'a, str>),
42        Number(T),
43    }
44
45    impl<'de, T> serde::Deserialize<'de> for NumberDeserialize<T>
46    where
47        T: FromStr + serde::Deserialize<'de>,
48        <T as FromStr>::Err: std::error::Error,
49    {
50        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51        where
52            D: serde::Deserializer<'de>,
53        {
54            let content = Content::deserialize(deserializer)?;
55            Ok(Self(match content {
56                Content::Str(v) => v.parse().map_err(serde::de::Error::custom)?,
57                Content::Number(v) => v,
58            }))
59        }
60    }
61
62    struct Base64Visitor;
63
64    impl<'de> Visitor<'de> for Base64Visitor {
65        type Value = Vec<u8>;
66
67        fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68            formatter.write_str("a base64 string")
69        }
70
71        fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
72        where
73            E: serde::de::Error,
74        {
75            const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
76                .with_decode_padding_mode(DecodePaddingMode::Indifferent);
77            const STANDARD_INDIFFERENT_PAD: GeneralPurpose =
78                GeneralPurpose::new(&base64::alphabet::STANDARD, INDIFFERENT_PAD);
79            const URL_SAFE_INDIFFERENT_PAD: GeneralPurpose =
80                GeneralPurpose::new(&base64::alphabet::URL_SAFE, INDIFFERENT_PAD);
81
82            let decoded = STANDARD_INDIFFERENT_PAD
83                .decode(s)
84                .or_else(|e| match e {
85                    // Either standard or URL-safe base64 encoding are accepted
86                    //
87                    // The difference being URL-safe uses `-` and `_` instead of `+` and `/`
88                    //
89                    // Therefore if we error out on those characters, try again with
90                    // the URL-safe character set
91                    base64::DecodeError::InvalidByte(_, c) if c == b'-' || c == b'_' => {
92                        URL_SAFE_INDIFFERENT_PAD.decode(s)
93                    }
94                    _ => Err(e),
95                })
96                .map_err(serde::de::Error::custom)?;
97            Ok(decoded)
98        }
99    }
100
101    #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
102    pub struct BytesDeserialize<T>(pub T);
103
104    impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
105    where
106        T: From<Vec<u8>>,
107    {
108        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
109        where
110            D: serde::Deserializer<'de>,
111        {
112            Ok(Self(deserializer.deserialize_str(Base64Visitor)?.into()))
113        }
114    }
115
116    #[cfg(test)]
117    mod tests {
118        use super::*;
119        use base64::Engine;
120        use bytes::Bytes;
121        use rand::prelude::*;
122        use serde::de::value::{BorrowedStrDeserializer, Error};
123
124        #[test]
125        fn test_bytes() {
126            for _ in 0..20 {
127                let mut rng = rand::rng();
128                let len = rng.random_range(50..100);
129                let raw: Vec<_> = std::iter::from_fn(|| Some(rng.random()))
130                    .take(len)
131                    .collect();
132
133                for config in [
134                    base64::engine::general_purpose::STANDARD,
135                    base64::engine::general_purpose::STANDARD_NO_PAD,
136                    base64::engine::general_purpose::URL_SAFE,
137                    base64::engine::general_purpose::URL_SAFE_NO_PAD,
138                ] {
139                    let encoded = config.encode(&raw);
140
141                    let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);
142                    let a: Bytes = BytesDeserialize::deserialize(deserializer).unwrap().0;
143                    let b: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;
144
145                    assert_eq!(raw.as_slice(), &a);
146                    assert_eq!(raw.as_slice(), &b);
147                }
148            }
149        }
150    }
151}