pbjson/
lib.rs

//! `pbjson` is a set of crates to automatically generate [`serde::Serialize`] and
//! [`serde::Deserialize`] implementations for [prost][1]-generated structs that
//! are compliant with the [protobuf JSON mapping][2].
//!
//! See [pbjson-build][3] for usage instructions.
//!
//! [1]: https://github.com/tokio-rs/prost
//! [2]: https://developers.google.com/protocol-buffers/docs/proto3#json
//! [3]: https://docs.rs/pbjson-build
//!
11#![no_std]
12#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
13#![warn(
14    missing_debug_implementations,
15    clippy::explicit_iter_loop,
16    clippy::use_self,
17    clippy::clone_on_ref_ptr,
18    clippy::future_not_send
19)]
20
21extern crate alloc;
22
23#[doc(hidden)]
24pub mod private {
25    /// Re-export base64
26    pub use base64;
27
28    use alloc::borrow::Cow;
29    use alloc::str::FromStr;
30    use alloc::vec::Vec;
31
32    use base64::engine::DecodePaddingMode;
33    use base64::engine::{GeneralPurpose, GeneralPurposeConfig};
34    use base64::Engine;
35    use serde::de::Visitor;
36    use serde::Deserialize;
37
38    /// Used to parse a number from either a string or its raw representation
39    #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
40    pub struct NumberDeserialize<T>(pub T);
41
42    #[derive(Deserialize)]
43    #[serde(untagged)]
44    enum Content<'a, T> {
45        #[serde(borrow)]
46        Str(Cow<'a, str>),
47        Number(T),
48    }
49
50    impl<'de, T> serde::Deserialize<'de> for NumberDeserialize<T>
51    where
52        T: FromStr + serde::Deserialize<'de>,
53        <T as FromStr>::Err: core::fmt::Display, // std::error::Error,
54    {
55        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
56        where
57            D: serde::Deserializer<'de>,
58        {
59            let content = Content::deserialize(deserializer)?;
60            Ok(Self(match content {
61                Content::Str(v) => v.parse().map_err(serde::de::Error::custom)?,
62                Content::Number(v) => v,
63            }))
64        }
65    }
66
67    struct Base64Visitor;
68
69    impl<'de> Visitor<'de> for Base64Visitor {
70        type Value = Vec<u8>;
71
72        fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
73            formatter.write_str("a base64 string")
74        }
75
76        fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
77        where
78            E: serde::de::Error,
79        {
80            const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
81                .with_decode_padding_mode(DecodePaddingMode::Indifferent);
82            const STANDARD_INDIFFERENT_PAD: GeneralPurpose =
83                GeneralPurpose::new(&base64::alphabet::STANDARD, INDIFFERENT_PAD);
84            const URL_SAFE_INDIFFERENT_PAD: GeneralPurpose =
85                GeneralPurpose::new(&base64::alphabet::URL_SAFE, INDIFFERENT_PAD);
86
87            let decoded = STANDARD_INDIFFERENT_PAD
88                .decode(s)
89                .or_else(|e| match e {
90                    // Either standard or URL-safe base64 encoding are accepted
91                    //
92                    // The difference being URL-safe uses `-` and `_` instead of `+` and `/`
93                    //
94                    // Therefore if we error out on those characters, try again with
95                    // the URL-safe character set
96                    base64::DecodeError::InvalidByte(_, c) if c == b'-' || c == b'_' => {
97                        URL_SAFE_INDIFFERENT_PAD.decode(s)
98                    }
99                    _ => Err(e),
100                })
101                .map_err(serde::de::Error::custom)?;
102            Ok(decoded)
103        }
104    }
105
106    #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
107    pub struct BytesDeserialize<T>(pub T);
108
109    impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
110    where
111        T: From<Vec<u8>>,
112    {
113        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
114        where
115            D: serde::Deserializer<'de>,
116        {
117            Ok(Self(deserializer.deserialize_str(Base64Visitor)?.into()))
118        }
119    }
120
121    #[cfg(test)]
122    mod tests {
123        use super::*;
124        use base64::Engine;
125        use bytes::Bytes;
126        use rand::prelude::*;
127        use serde::de::value::{BorrowedStrDeserializer, Error};
128
129        #[test]
130        fn test_bytes() {
131            for _ in 0..20 {
132                let mut rng = thread_rng();
133                let len = rng.gen_range(50..100);
134                let raw: Vec<_> = core::iter::from_fn(|| Some(rng.gen())).take(len).collect();
135
136                for config in [
137                    base64::engine::general_purpose::STANDARD,
138                    base64::engine::general_purpose::STANDARD_NO_PAD,
139                    base64::engine::general_purpose::URL_SAFE,
140                    base64::engine::general_purpose::URL_SAFE_NO_PAD,
141                ] {
142                    let encoded = config.encode(&raw);
143
144                    let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);
145                    let a: Bytes = BytesDeserialize::deserialize(deserializer).unwrap().0;
146                    let b: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;
147
148                    assert_eq!(raw.as_slice(), &a);
149                    assert_eq!(raw.as_slice(), &b);
150                }
151            }
152        }
153    }
154}