1#![no_std]
12#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)]
13#![warn(
14 missing_debug_implementations,
15 clippy::explicit_iter_loop,
16 clippy::use_self,
17 clippy::clone_on_ref_ptr,
18 clippy::future_not_send
19)]
20
21extern crate alloc;
22
23#[doc(hidden)]
24pub mod private {
25 pub use base64;
27
28 use alloc::borrow::Cow;
29 use alloc::str::FromStr;
30 use alloc::vec::Vec;
31
32 use base64::engine::DecodePaddingMode;
33 use base64::engine::{GeneralPurpose, GeneralPurposeConfig};
34 use base64::Engine;
35 use serde::de::Visitor;
36 use serde::Deserialize;
37
38 #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
40 pub struct NumberDeserialize<T>(pub T);
41
42 #[derive(Deserialize)]
43 #[serde(untagged)]
44 enum Content<'a, T> {
45 #[serde(borrow)]
46 Str(Cow<'a, str>),
47 Number(T),
48 }
49
50 impl<'de, T> serde::Deserialize<'de> for NumberDeserialize<T>
51 where
52 T: FromStr + serde::Deserialize<'de>,
53 <T as FromStr>::Err: core::fmt::Display, {
55 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
56 where
57 D: serde::Deserializer<'de>,
58 {
59 let content = Content::deserialize(deserializer)?;
60 Ok(Self(match content {
61 Content::Str(v) => v.parse().map_err(serde::de::Error::custom)?,
62 Content::Number(v) => v,
63 }))
64 }
65 }
66
67 struct Base64Visitor;
68
69 impl<'de> Visitor<'de> for Base64Visitor {
70 type Value = Vec<u8>;
71
72 fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
73 formatter.write_str("a base64 string")
74 }
75
76 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
77 where
78 E: serde::de::Error,
79 {
80 const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
81 .with_decode_padding_mode(DecodePaddingMode::Indifferent);
82 const STANDARD_INDIFFERENT_PAD: GeneralPurpose =
83 GeneralPurpose::new(&base64::alphabet::STANDARD, INDIFFERENT_PAD);
84 const URL_SAFE_INDIFFERENT_PAD: GeneralPurpose =
85 GeneralPurpose::new(&base64::alphabet::URL_SAFE, INDIFFERENT_PAD);
86
87 let decoded = STANDARD_INDIFFERENT_PAD
88 .decode(s)
89 .or_else(|e| match e {
90 base64::DecodeError::InvalidByte(_, c) if c == b'-' || c == b'_' => {
97 URL_SAFE_INDIFFERENT_PAD.decode(s)
98 }
99 _ => Err(e),
100 })
101 .map_err(serde::de::Error::custom)?;
102 Ok(decoded)
103 }
104 }
105
106 #[derive(Debug, Copy, Clone, PartialOrd, PartialEq, Hash, Ord, Eq)]
107 pub struct BytesDeserialize<T>(pub T);
108
109 impl<'de, T> Deserialize<'de> for BytesDeserialize<T>
110 where
111 T: From<Vec<u8>>,
112 {
113 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
114 where
115 D: serde::Deserializer<'de>,
116 {
117 Ok(Self(deserializer.deserialize_str(Base64Visitor)?.into()))
118 }
119 }
120
121 #[cfg(test)]
122 mod tests {
123 use super::*;
124 use base64::Engine;
125 use bytes::Bytes;
126 use rand::prelude::*;
127 use serde::de::value::{BorrowedStrDeserializer, Error};
128
129 #[test]
130 fn test_bytes() {
131 for _ in 0..20 {
132 let mut rng = thread_rng();
133 let len = rng.gen_range(50..100);
134 let raw: Vec<_> = core::iter::from_fn(|| Some(rng.gen())).take(len).collect();
135
136 for config in [
137 base64::engine::general_purpose::STANDARD,
138 base64::engine::general_purpose::STANDARD_NO_PAD,
139 base64::engine::general_purpose::URL_SAFE,
140 base64::engine::general_purpose::URL_SAFE_NO_PAD,
141 ] {
142 let encoded = config.encode(&raw);
143
144 let deserializer = BorrowedStrDeserializer::<'_, Error>::new(&encoded);
145 let a: Bytes = BytesDeserialize::deserialize(deserializer).unwrap().0;
146 let b: Vec<u8> = BytesDeserialize::deserialize(deserializer).unwrap().0;
147
148 assert_eq!(raw.as_slice(), &a);
149 assert_eq!(raw.as_slice(), &b);
150 }
151 }
152 }
153 }
154}