1#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
12#![warn(
13 missing_debug_implementations,
14 clippy::explicit_iter_loop,
15 clippy::use_self,
16 clippy::clone_on_ref_ptr,
17 clippy::future_not_send
18)]
19
20#[doc(hidden)]
21pub mod private {
22 pub use base64;
24
25 use base64::Engine;
26 use base64::engine::DecodePaddingMode;
27 use base64::engine::{GeneralPurpose, GeneralPurposeConfig};
28 use serde::Deserialize;
29 use serde::de::Visitor;
30 use std::borrow::Cow;
31 use std::str::FromStr;
32
33 #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
35 pub struct NumberDeserialize<T>(pub T);
36
37 #[derive(Deserialize)]
38 #[serde(untagged)]
39 enum Content<'a, T> {
40 #[serde(borrow)]
41 Str(Cow<'a, str>),
42 Number(T),
43 }
44
45 impl<'de, T> serde::Deserialize<'de> for NumberDeserialize<T>
46 where
47 T: FromStr + serde::Deserialize<'de>,
48 <T as FromStr>::Err: std::error::Error,
49 {
50 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
51 where
52 D: serde::Deserializer<'de>,
53 {
54 let content = Content::deserialize(deserializer)?;
55 Ok(Self(match content {
56 Content::Str(v) => v.parse().map_err(serde::de::Error::custom)?,
57 Content::Number(v) => v,
58 }))
59 }
60 }
61
62 struct Base64Visitor;
63
64 impl<'de> Visitor<'de> for Base64Visitor {
65 type Value = Vec<u8>;
66
67 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 formatter.write_str("a base64 string")
69 }
70
71 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
72 where
73 E: serde::de::Error,
74 {
75 const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
76 .with_decode_padding_mode(DecodePaddingMode::Indifferent);
77 const STANDARD_INDIFFERENT_PAD: GeneralPurpose =
78 GeneralPurpose::new(&base64::alphabet::STANDARD, INDIFFERENT_PAD);
79 const URL_SAFE_INDIFFERENT_PAD: GeneralPurpose =
80 GeneralPurpose::new(&base64::alphabet::URL_SAFE, INDIFFERENT_PAD);
81
82 let decoded = STANDARD_INDIFFERENT_PAD
83 .decode(s)
84 .or_else(|e| match e {
85 base64::DecodeError::InvalidByte(_, c) if c == b'-' || c == b'_' => {
92 URL_SAFE_INDIFFERENT_PAD.decode(s)
93 }
94 _ => Err(e),
95 })
96 .map_err(serde::de::Error::custom)?;
97 Ok(decoded)
98 }
99 }
100
101 #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
102 pub struct BytesDeserialize<T>(pub T);
103
104 impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
105 where
106 T: From<Vec<u8>>,
107 {
108 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
109 where
110 D: serde::Deserializer<'de>,
111 {
112 Ok(Self(deserializer.deserialize_str(Base64Visitor)?.into()))
113 }
114 }
115
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    use bytes::Bytes;
    use rand::prelude::*;
    use serde::de::value::{BorrowedStrDeserializer, Error};

    /// Round-trips random payloads through the standard and URL-safe engines,
    /// padded and unpadded, and checks that `BytesDeserialize` recovers the
    /// original bytes as both `Bytes` and `Vec<u8>`.
    #[test]
    fn test_bytes() {
        let engines = [
            base64::engine::general_purpose::STANDARD,
            base64::engine::general_purpose::STANDARD_NO_PAD,
            base64::engine::general_purpose::URL_SAFE,
            base64::engine::general_purpose::URL_SAFE_NO_PAD,
        ];

        for _ in 0..20 {
            let mut rng = rand::rng();
            let len = rng.random_range(50..100);
            let payload: Vec<u8> = std::iter::repeat_with(|| rng.random())
                .take(len)
                .collect();

            for engine in &engines {
                let encoded = engine.encode(&payload);

                // The deserializer is `Copy`, so it can feed both target types.
                let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);
                let as_bytes: Bytes = BytesDeserialize::deserialize(deserializer).unwrap().0;
                let as_vec: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;

                assert_eq!(payload.as_slice(), &as_bytes);
                assert_eq!(payload.as_slice(), &as_vec);
            }
        }
    }
}
151}