/// Standard base64 alphabet (RFC 4648, `+` and `/` variant): the byte at index
/// `i` is the character that encodes the 6-bit value `i`.
const ENCODING_TABLE: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
28const DECODING_TABLE: [u8; 1 << 8] = {
29 let mut table = [u8::MAX; 1 << 8];
30 let mut i = 0;
31 while i < ENCODING_TABLE.len() {
32 table[ENCODING_TABLE[i] as usize] = i as u8;
33 i += 1;
34 }
35 table
36};
37
/// Reasons [`Base64Binary::from_encoded`] can reject its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Base64Error {
    /// `byte` is not part of the base64 alphabet, or is a misplaced `=` padding
    /// character. `position` is the index of that byte in the raw input,
    /// counting any skipped whitespace.
    MalformedByte { byte: u8, position: usize },
    /// The number of significant input characters is not a multiple of four.
    InsufficientPadding,
}

impl std::fmt::Display for Base64Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Base64Error::MalformedByte { byte, position } => {
                write!(f, "malformed base64 byte {byte:#04x} at position {position}")
            }
            Base64Error::InsufficientPadding => {
                f.write_str("base64 input length is not a multiple of 4 (insufficient padding)")
            }
        }
    }
}

// Lets the error be boxed as `dyn Error` and propagated with `?` by callers.
impl std::error::Error for Base64Error {}
43
/// Owned base64 text in the standard alphabet, with `=` padding.
///
/// The data is kept in its *encoded* form and decoded lazily by
/// [`Base64Binary::decode`]. Equality and hashing therefore compare the encoded
/// characters, not the decoded bytes.
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct Base64Binary {
    /// Encoded ASCII characters; the length is always a multiple of four.
    bin: Vec<u8>,
}
49
50impl Base64Binary {
51 pub fn encode(iter: impl IntoIterator<Item = u8>) -> Self {
61 iter.into_iter().collect()
62 }
63
64 pub fn decode(&self) -> impl Iterator<Item = u8> + '_ {
75 assert!(self.bin.len() % 4 == 0);
76 self.bin.chunks_exact(4).flat_map(|chunk| {
77 let b0 = DECODING_TABLE[chunk[0] as usize];
78 let b1 = DECODING_TABLE[chunk[1] as usize];
79 let b2 = DECODING_TABLE[chunk[2] as usize];
80 let b3 = DECODING_TABLE[chunk[3] as usize];
81 let mut r0 = Some((b0 << 2) | (b1 >> 4));
82 let mut r1 = (b2 != u8::MAX).then_some(b1.wrapping_shl(4) | (b2 >> 2));
83 let mut r2 = (b3 != u8::MAX).then_some(b2.wrapping_shl(6) | b3);
84 std::iter::from_fn(move || r0.take().or_else(|| r1.take()).or_else(|| r2.take()))
85 })
86 }
87
88 pub fn from_encoded(
95 iter: impl IntoIterator<Item = u8>,
96 allow_whitespace: bool,
97 ) -> Result<Self, Base64Error> {
98 let mut bin = vec![];
99 let mut pad = None;
100 for (position, byte) in iter.into_iter().enumerate() {
101 if allow_whitespace && byte.is_ascii_whitespace() {
102 continue;
103 }
104 if byte == b'=' {
105 pad.get_or_insert((bin.len(), position));
106 } else if DECODING_TABLE[byte as usize] == u8::MAX {
107 return Err(Base64Error::MalformedByte { byte, position });
108 }
109 bin.push(byte);
110 }
111
112 if bin.len() % 4 != 0 {
113 return Err(Base64Error::InsufficientPadding);
114 }
115
116 if let Some((pad, position)) = pad
117 && (bin.len() - pad > 2 || bin[pad..].iter().any(|&b| b != b'='))
118 {
119 return Err(Base64Error::MalformedByte {
120 byte: b'=',
121 position,
122 });
123 }
124
125 Ok(Base64Binary { bin })
126 }
127}
128
129impl FromIterator<u8> for Base64Binary {
130 fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
131 let mut iter = iter.into_iter();
132 let mut bin = vec![];
133 while let Some(b0) = iter.next() {
134 bin.push(ENCODING_TABLE[(b0 >> 2) as usize]);
135 match iter.next() {
136 Some(b1) => {
137 bin.push(ENCODING_TABLE[(((b0 & 0x3) << 4) | (b1 >> 4)) as usize]);
138 match iter.next() {
139 Some(b2) => {
140 bin.push(ENCODING_TABLE[(((b1 & 0xF) << 2) | (b2 >> 6)) as usize]);
141 bin.push(ENCODING_TABLE[(b2 & 0x3F) as usize]);
142 }
143 None => {
144 bin.push(ENCODING_TABLE[((b1 & 0xF) << 2) as usize]);
145 bin.push(b'=');
146 }
147 }
148 }
149 None => {
150 bin.push(ENCODING_TABLE[((b0 & 0x3) << 4) as usize]);
151 bin.push(b'=');
152 bin.push(b'=');
153 }
154 }
155 }
156 Base64Binary { bin }
157 }
158}
159
160impl From<&str> for Base64Binary {
161 fn from(value: &str) -> Self {
162 value.bytes().collect()
163 }
164}
165
166macro_rules! impl_from_str_for_base64_binary {
167 ( $( $t:ty ),* ) => {
168 $(
169 impl From<$t> for Base64Binary {
170 fn from(value: $t) -> Self {
171 value.bytes().collect()
172 }
173 }
174 impl From<&$t> for Base64Binary {
175 fn from(value: &$t) -> Self {
176 value.bytes().collect()
177 }
178 }
179 )*
180 };
181}
182impl_from_str_for_base64_binary!(
183 String,
184 Box<str>,
185 std::rc::Rc<str>,
186 std::sync::Arc<str>,
187 std::borrow::Cow<'_, str>
188);
189
190impl From<&[u8]> for Base64Binary {
191 fn from(value: &[u8]) -> Self {
192 value.iter().copied().collect()
193 }
194}
195macro_rules! impl_from_bytes_for_base64_binary {
196 ( $( $t:ty ),* ) => {
197 $(
198 impl From<$t> for Base64Binary {
199 fn from(value: $t) -> Self {
200 value.iter().copied().collect()
201 }
202 }
203 impl From<&$t> for Base64Binary {
204 fn from(value: &$t) -> Self {
205 value.iter().copied().collect()
206 }
207 }
208 )*
209 };
210}
211impl_from_bytes_for_base64_binary!(
212 Vec<u8>,
213 Box<[u8]>,
214 std::rc::Rc<[u8]>,
215 std::sync::Arc<[u8]>,
216 std::borrow::Cow<'_, [u8]>
217);
218
219impl std::fmt::Debug for Base64Binary {
220 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221 write!(f, "{}", self)
222 }
223}
224
225impl std::fmt::Display for Base64Binary {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 unsafe {
228 write!(f, "{}", std::str::from_utf8_unchecked(&self.bin))
232 }
233 }
234}
235
#[cfg(test)]
mod tests {
    use std::hash::{BuildHasher, Hasher, RandomState};

    use super::*;

    /// Xorshift32 pseudo-random generator, deterministic for a given `seed`.
    fn xor_shift32(seed: u64) -> impl Iterator<Item = u32> {
        // Fold in the high half so the whole seed matters. Also avoid the
        // all-zero state: xorshift maps 0 to 0 forever.
        let mut random = ((seed ^ (seed >> 32)) as u32).max(1);

        std::iter::repeat_with(move || {
            random ^= random << 13;
            random ^= random >> 17;
            random ^= random << 5;
            random
        })
    }

    /// Endless stream of pseudo-random bytes derived from `seed`.
    fn bytes(seed: u64) -> impl Iterator<Item = u8> {
        let mut generator = xor_shift32(seed);
        let mut buf = [0u8; 4];
        // Start with the buffer marked as consumed, so the first byte already
        // comes from the generator rather than from the zero-initialised buffer.
        let mut counter = buf.len();
        std::iter::from_fn(move || {
            if counter == buf.len() {
                buf = generator.next()?.to_le_bytes();
                counter = 0;
            }
            let ret = buf[counter];
            counter += 1;
            Some(ret)
        })
    }

    /// Fresh seed for each run. Failure messages include it so a run can be replayed.
    fn random_seed() -> u64 {
        RandomState::new().build_hasher().finish()
    }

    #[test]
    fn regression_tests() {
        let seed = random_seed();
        let mut bytes = bytes(seed);
        for _ in 0..10000 {
            let len = bytes.next().unwrap() as usize;
            let bytes = (0..len).map(|_| bytes.next().unwrap()).collect::<Vec<_>>();

            let encoded = bytes.iter().copied().collect::<Base64Binary>();
            let decoded = encoded.decode().collect::<Vec<_>>();

            assert_eq!(
                bytes,
                decoded,
                "seed: {seed}, len: {},{}",
                bytes.len(),
                decoded.len()
            );
            assert_eq!(
                encoded.bin.len(),
                bytes.len().div_ceil(3) * 4,
                "seed: {seed}"
            );
            let pad = encoded.bin.iter().filter(|c| **c == b'=').count();
            match encoded.bin.as_slice() {
                [.., b'=', b'='] => assert_eq!(pad, 2, "seed: {seed}"),
                [.., b'='] => assert_eq!(pad, 1, "seed: {seed}"),
                [..] => assert_eq!(pad, 0, "seed: {seed}"),
            }

            let encoded2 = Base64Binary::from_encoded(encoded.to_string().bytes(), false)
                .unwrap_or_else(|e| panic!("seed: {seed}, input: {encoded:?}, error: {e:?}"));
            assert_eq!(encoded, encoded2, "seed: {seed}");
        }
    }

    #[test]
    fn encoded_bytes_tests() {
        assert!(Base64Binary::from_encoded(*b"", false).is_ok());
        let seed = random_seed();
        let mut bytes = bytes(seed);
        for _ in 0..10000 {
            let len = bytes.next().unwrap() as usize;
            let len = len.div_ceil(4) * 4;
            let bytes = bytes
                .by_ref()
                .filter(|b| DECODING_TABLE[*b as usize] != u8::MAX)
                .take(len)
                .collect::<Vec<_>>();

            let encoded = Base64Binary::from_encoded(bytes.iter().copied(), false);
            assert!(
                encoded.is_ok(),
                "seed: {seed}, input: {:?}, result: {encoded:?}",
                String::from_utf8_lossy(&bytes)
            );
        }
    }

    #[test]
    fn erroneous_encoded_bytes_tests() {
        assert!(Base64Binary::from_encoded(*b"a", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaaaa", false).is_err());

        assert!(Base64Binary::from_encoded(*b"=", false).is_err());
        assert!(Base64Binary::from_encoded(*b"==", false).is_err());
        assert!(Base64Binary::from_encoded(*b"===", false).is_err());
        assert!(Base64Binary::from_encoded(*b"====", false).is_err());

        assert!(Base64Binary::from_encoded(*b"a=", false).is_err());
        assert!(Base64Binary::from_encoded(*b"a==", false).is_err());
        assert!(Base64Binary::from_encoded(*b"a===", false).is_err());

        // Exact error variants and reported positions.
        assert_eq!(
            Base64Binary::from_encoded(*b"ab=c", false),
            Err(Base64Error::MalformedByte { byte: b'=', position: 2 })
        );
        assert_eq!(
            Base64Binary::from_encoded(*b"a!bc", false),
            Err(Base64Error::MalformedByte { byte: b'!', position: 1 })
        );
        assert_eq!(
            Base64Binary::from_encoded(*b"abc", false),
            Err(Base64Error::InsufficientPadding)
        );
    }

    #[test]
    fn whitespace_handling() {
        assert_eq!(
            Base64Binary::from_encoded(*b"QQ==\n", false),
            Err(Base64Error::MalformedByte { byte: b'\n', position: 4 })
        );
        let parsed = Base64Binary::from_encoded(*b"Q Q\r\n=\t=", true).unwrap();
        assert_eq!(parsed.to_string(), "QQ==");
        assert_eq!(parsed.decode().collect::<Vec<_>>(), b"A".to_vec());
    }

    #[test]
    fn rfc4648_vectors() {
        let vectors: [(&str, &str); 7] = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (plain, encoded) in vectors {
            let binary = Base64Binary::from(plain);
            assert_eq!(binary.to_string(), encoded);
            assert_eq!(format!("{binary:?}"), encoded);
            assert_eq!(binary.decode().collect::<Vec<_>>(), plain.as_bytes());
            assert_eq!(Base64Binary::from_encoded(encoded.bytes(), false), Ok(binary));
        }
    }
}
293
294 #[test]
295 fn encoded_bytes_tests() {
296 assert!(Base64Binary::from_encoded(*b"", false).is_ok());
297 let state = RandomState::new().build_hasher();
298 let seed = state.finish();
299 let mut bytes = bytes(seed);
300 for _ in 0..10000 {
301 let len = bytes.next().unwrap() as usize;
302 let len = len.div_ceil(4) * 4;
303 let bytes = bytes
304 .by_ref()
305 .filter(|b| DECODING_TABLE[*b as usize] != u8::MAX)
306 .take(len)
307 .collect::<Vec<_>>();
308
309 let encoded = Base64Binary::from_encoded(bytes, false);
310 assert!(encoded.is_ok());
311 }
312 }
313
314 #[test]
315 fn erroneous_encoded_bytes_tests() {
316 assert!(Base64Binary::from_encoded(*b"a", false).is_err());
317 assert!(Base64Binary::from_encoded(*b"aa", false).is_err());
318 assert!(Base64Binary::from_encoded(*b"aaa", false).is_err());
319 assert!(Base64Binary::from_encoded(*b"aaaaa", false).is_err());
320 assert!(Base64Binary::from_encoded(*b"aaaaaa", false).is_err());
321 assert!(Base64Binary::from_encoded(*b"aaaaaaa", false).is_err());
322
323 assert!(Base64Binary::from_encoded(*b"=", false).is_err());
324 assert!(Base64Binary::from_encoded(*b"==", false).is_err());
325 assert!(Base64Binary::from_encoded(*b"===", false).is_err());
326 assert!(Base64Binary::from_encoded(*b"====", false).is_err());
327
328 assert!(Base64Binary::from_encoded(*b"a=", false).is_err());
329 assert!(Base64Binary::from_encoded(*b"a==", false).is_err());
330 assert!(Base64Binary::from_encoded(*b"a===", false).is_err());
331 }
332}