/// The base64 alphabet: the byte stored at index `v` is the character that
/// encodes the 6-bit value `v` (`A`-`Z`, `a`-`z`, `0`-`9`, `+`, `/`).
const ENCODING_TABLE: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
28const DECODING_TABLE: [u8; 1 << 8] = {
29 let mut table = [u8::MAX; 1 << 8];
30 let mut i = 0;
31 while i < ENCODING_TABLE.len() {
32 table[ENCODING_TABLE[i] as usize] = i as u8;
33 i += 1;
34 }
35 table
36};
37
/// Errors reported by [`Base64Binary::from_encoded`] when its input is not
/// well-formed base64.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Base64Error {
    /// `byte` is not acceptable at `position`, the zero-based index of the
    /// byte in the raw input (whitespace that was skipped is still counted).
    ///
    /// Reported for bytes outside the base64 alphabet and, with `byte` set to
    /// `b'='` and `position` pointing at the first `=`, for padding that is
    /// not confined to the last one or two encoded bytes.
    MalformedByte { byte: u8, position: usize },
    /// The number of encoded bytes (after any whitespace has been dropped) is
    /// not a multiple of four.
    InsufficientPadding,
}
43
44impl std::fmt::Display for Base64Error {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 write!(f, "{self:?}")
47 }
48}
49
// Marker impl: the trait's default methods are used unchanged.
impl std::error::Error for Base64Error {}
51
/// Binary data held in its base64-encoded form.
///
/// `bin` stores the encoded text as bytes (alphabet characters, optionally
/// followed by `=` padding), not the decoded payload. The derived equality
/// and hashing therefore compare the encoded bytes.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Base64Binary {
    // Encoded bytes; `decode` asserts that the length is a multiple of four.
    bin: Vec<u8>,
}
57
58impl Base64Binary {
59 pub fn encode(iter: impl IntoIterator<Item = u8>) -> Self {
69 iter.into_iter().collect()
70 }
71
72 pub fn decode(&self) -> impl Iterator<Item = u8> + '_ {
83 assert!(self.bin.len() % 4 == 0);
84 self.bin.chunks_exact(4).flat_map(|chunk| {
85 let b0 = DECODING_TABLE[chunk[0] as usize];
86 let b1 = DECODING_TABLE[chunk[1] as usize];
87 let b2 = DECODING_TABLE[chunk[2] as usize];
88 let b3 = DECODING_TABLE[chunk[3] as usize];
89 let mut r0 = Some((b0 << 2) | (b1 >> 4));
90 let mut r1 = (b2 != u8::MAX).then_some(b1.wrapping_shl(4) | (b2 >> 2));
91 let mut r2 = (b3 != u8::MAX).then_some(b2.wrapping_shl(6) | b3);
92 std::iter::from_fn(move || r0.take().or_else(|| r1.take()).or_else(|| r2.take()))
93 })
94 }
95
96 pub fn from_encoded(
103 iter: impl IntoIterator<Item = u8>,
104 allow_whitespace: bool,
105 ) -> Result<Self, Base64Error> {
106 let mut bin = vec![];
107 let mut pad = None;
108 for (position, byte) in iter.into_iter().enumerate() {
109 if allow_whitespace && byte.is_ascii_whitespace() {
110 continue;
111 }
112 if byte == b'=' {
113 pad.get_or_insert((bin.len(), position));
114 } else if DECODING_TABLE[byte as usize] == u8::MAX {
115 return Err(Base64Error::MalformedByte { byte, position });
116 }
117 bin.push(byte);
118 }
119
120 if bin.len() % 4 != 0 {
121 return Err(Base64Error::InsufficientPadding);
122 }
123
124 if let Some((pad, position)) = pad
125 && (bin.len() - pad > 2 || bin[pad..].iter().any(|&b| b != b'='))
126 {
127 return Err(Base64Error::MalformedByte {
128 byte: b'=',
129 position,
130 });
131 }
132
133 Ok(Base64Binary { bin })
134 }
135}
136
137impl FromIterator<u8> for Base64Binary {
138 fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
139 let mut iter = iter.into_iter();
140 let mut bin = vec![];
141 while let Some(b0) = iter.next() {
142 bin.push(ENCODING_TABLE[(b0 >> 2) as usize]);
143 match iter.next() {
144 Some(b1) => {
145 bin.push(ENCODING_TABLE[(((b0 & 0x3) << 4) | (b1 >> 4)) as usize]);
146 match iter.next() {
147 Some(b2) => {
148 bin.push(ENCODING_TABLE[(((b1 & 0xF) << 2) | (b2 >> 6)) as usize]);
149 bin.push(ENCODING_TABLE[(b2 & 0x3F) as usize]);
150 }
151 None => {
152 bin.push(ENCODING_TABLE[((b1 & 0xF) << 2) as usize]);
153 bin.push(b'=');
154 }
155 }
156 }
157 None => {
158 bin.push(ENCODING_TABLE[((b0 & 0x3) << 4) as usize]);
159 bin.push(b'=');
160 bin.push(b'=');
161 }
162 }
163 }
164 Base64Binary { bin }
165 }
166}
167
168impl From<&str> for Base64Binary {
169 fn from(value: &str) -> Self {
170 value.bytes().collect()
171 }
172}
173
174macro_rules! impl_from_str_for_base64_binary {
175 ( $( $t:ty ),* ) => {
176 $(
177 impl From<$t> for Base64Binary {
178 fn from(value: $t) -> Self {
179 value.bytes().collect()
180 }
181 }
182 impl From<&$t> for Base64Binary {
183 fn from(value: &$t) -> Self {
184 value.bytes().collect()
185 }
186 }
187 )*
188 };
189}
190impl_from_str_for_base64_binary!(
191 String,
192 Box<str>,
193 std::rc::Rc<str>,
194 std::sync::Arc<str>,
195 std::borrow::Cow<'_, str>
196);
197
198impl From<&[u8]> for Base64Binary {
199 fn from(value: &[u8]) -> Self {
200 value.iter().copied().collect()
201 }
202}
203macro_rules! impl_from_bytes_for_base64_binary {
204 ( $( $t:ty ),* ) => {
205 $(
206 impl From<$t> for Base64Binary {
207 fn from(value: $t) -> Self {
208 value.iter().copied().collect()
209 }
210 }
211 impl From<&$t> for Base64Binary {
212 fn from(value: &$t) -> Self {
213 value.iter().copied().collect()
214 }
215 }
216 )*
217 };
218}
219impl_from_bytes_for_base64_binary!(
220 Vec<u8>,
221 Box<[u8]>,
222 std::rc::Rc<[u8]>,
223 std::sync::Arc<[u8]>,
224 std::borrow::Cow<'_, [u8]>
225);
226
227impl std::fmt::Debug for Base64Binary {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 write!(f, "{}", self)
230 }
231}
232
233impl std::fmt::Display for Base64Binary {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 unsafe {
236 write!(f, "{}", std::str::from_utf8_unchecked(&self.bin))
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use std::hash::{BuildHasher, Hasher, RandomState};

    use super::*;

    /// Draws a fresh seed from the randomly keyed default hasher.
    fn random_seed() -> u64 {
        RandomState::new().build_hasher().finish()
    }

    /// Endless xorshift32 stream driven by the low 32 bits of `seed`.
    fn xor_shift32(seed: u64) -> impl Iterator<Item = u32> {
        // Zero is a fixed point of xorshift (every output would be zero), so
        // an arbitrary non-zero state is substituted for it.
        let mut random = match seed as u32 {
            0 => 0x9E37_79B9,
            state => state,
        };

        std::iter::repeat_with(move || {
            random ^= random << 13;
            random ^= random >> 17;
            random ^= random << 5;
            random
        })
    }

    /// Endless pseudo-random byte stream: every generated word is split into
    /// its four little-endian bytes.
    fn bytes(seed: u64) -> impl Iterator<Item = u8> {
        xor_shift32(seed).flat_map(u32::to_le_bytes)
    }

    /// Random payloads survive encode -> decode and encode -> parse, and the
    /// padding only ever appears at the end.
    #[test]
    fn regression_tests() {
        // The seed is part of every failure message so a run can be replayed.
        let seed = random_seed();
        let mut source = bytes(seed);
        for _ in 0..10000 {
            let len = source.next().unwrap() as usize;
            let plain = source.by_ref().take(len).collect::<Vec<_>>();

            let encoded = plain.iter().copied().collect::<Base64Binary>();
            let decoded = encoded.decode().collect::<Vec<_>>();

            assert_eq!(
                plain,
                decoded,
                "seed: {}, len: {},{}",
                seed,
                plain.len(),
                decoded.len()
            );
            let pad = encoded.bin.iter().filter(|c| **c == b'=').count();
            match encoded.bin.as_slice() {
                [.., b'=', b'='] => assert_eq!(pad, 2, "seed: {}", seed),
                [.., b'='] => assert_eq!(pad, 1, "seed: {}", seed),
                [..] => assert_eq!(pad, 0, "seed: {}", seed),
            }

            let encoded2 = Base64Binary::from_encoded(encoded.to_string().bytes(), false).unwrap();
            assert_eq!(encoded, encoded2, "seed: {}", seed);
        }
    }

    /// Any run of alphabet bytes whose length is a multiple of four parses.
    #[test]
    fn encoded_bytes_tests() {
        assert!(Base64Binary::from_encoded(*b"", false).is_ok());
        let seed = random_seed();
        let mut source = bytes(seed);
        for _ in 0..10000 {
            let len = source.next().unwrap() as usize;
            let len = len.div_ceil(4) * 4;
            let symbols = source
                .by_ref()
                .filter(|b| DECODING_TABLE[*b as usize] != u8::MAX)
                .take(len)
                .collect::<Vec<_>>();

            let encoded = Base64Binary::from_encoded(symbols, false);
            assert!(encoded.is_ok(), "seed: {}, result: {:?}", seed, encoded);
        }
    }

    /// Wrong lengths and misplaced or excessive padding are rejected.
    #[test]
    fn erroneous_encoded_bytes_tests() {
        assert!(Base64Binary::from_encoded(*b"a", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaaa", false).is_err());
        assert!(Base64Binary::from_encoded(*b"aaaaaaa", false).is_err());

        assert!(Base64Binary::from_encoded(*b"=", false).is_err());
        assert!(Base64Binary::from_encoded(*b"==", false).is_err());
        assert!(Base64Binary::from_encoded(*b"===", false).is_err());
        assert!(Base64Binary::from_encoded(*b"====", false).is_err());

        assert!(Base64Binary::from_encoded(*b"a=", false).is_err());
        assert!(Base64Binary::from_encoded(*b"a==", false).is_err());
        assert!(Base64Binary::from_encoded(*b"a===", false).is_err());
    }

    /// Fixed vectors from RFC 4648, section 10.
    #[test]
    fn known_answer_tests() {
        let vectors: [(&str, &str); 7] = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (plain, expected) in vectors.iter() {
            let encoded = Base64Binary::from(*plain);
            assert_eq!(encoded.to_string(), *expected);
            assert_eq!(encoded.decode().collect::<Vec<_>>(), plain.as_bytes());

            let parsed = Base64Binary::from_encoded(expected.bytes(), false).unwrap();
            assert_eq!(parsed, encoded);
        }
    }

    /// Whitespace is dropped only when the caller allows it, and error
    /// positions refer to the raw input.
    #[test]
    fn whitespace_tests() {
        let spaced = Base64Binary::from_encoded("Zm9v\nYmFy".bytes(), true).unwrap();
        assert_eq!(spaced.to_string(), "Zm9vYmFy");

        assert_eq!(
            Base64Binary::from_encoded("Zm9v\nYmFy".bytes(), false),
            Err(Base64Error::MalformedByte {
                byte: b'\n',
                position: 4,
            })
        );
        assert_eq!(
            Base64Binary::from_encoded("Zg =v".bytes(), true),
            Err(Base64Error::MalformedByte {
                byte: b'=',
                position: 3,
            })
        );
    }
}
340}