1#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
8mod avx2;
9mod lut_align64;
10
11use alloc::vec::Vec;
12use core::cmp;
13use core::fmt;
14
/// What a [`Decoder`] reports after decoding one block in place.
#[must_use]
struct BlockResult {
    /// How many leading bytes of the block now hold decoded values; exactly this
    /// prefix is forwarded to the packer by `decode64`.
    out_length: u8,
    /// Position, relative to the start of the block, of the first byte the decoder
    /// rejected, or `None` when every byte was accepted.
    first_invalid: Option<u8>,
}
20
/// Reasons a base64 input can be rejected by the decoder.
///
/// Besides the original `Debug`/`Clone`/`Copy`, the type derives `PartialEq`, `Eq`
/// and `Hash`, so callers can compare errors directly (e.g. `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Error {
    /// The decoded symbols end with exactly one leftover value in the final group
    /// of four. That value could not form a whole output byte.
    InvalidLength,
    /// `=` padding was present, but the final group's symbols plus its padding do
    /// not add up to a multiple of four characters.
    InvalidTrailer,
    /// Byte offset into the input of either of two bytes:
    /// a byte the decoder rejected that is not the start of `=` padding, or
    /// a non-whitespace byte that is not allowed after the padding.
    InvalidCharacter(usize),
}
32
33impl fmt::Display for Error {
34 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35 fmt::Debug::fmt(&self, f)
36 }
37}
38
39trait Decoder: Copy {
40 type Block: AsRef<[u8]> + AsMut<[u8]>;
41
42 fn decode_block(self, block: &mut Self::Block) -> BlockResult;
43 fn zero_block() -> Self::Block;
44}
45
/// Packs groups of decoded base64 values into output bytes.
trait Packer: Copy {
    /// Fixed-size group of decoded values handed to `pack_block` in one call.
    type Input: Default + AsRef<[u8]> + AsMut<[u8]>;

    /// Minimum length of the `output` slice given to `pack_block`. It may exceed the
    /// number of bytes that are actually kept.
    const OUT_BUF_LEN: usize;

    /// Packs the whole of `input` into the front of `output`.
    fn pack_block(self, input: &Self::Input, output: &mut [u8]);
}
53
/// Portable scalar packer: turns each group of four decoded base64 values into
/// three bytes.
#[derive(Clone, Copy)]
struct Simple;
56
57impl Packer for Simple {
58 type Input = [u8; 4];
59 const OUT_BUF_LEN: usize = 3;
60
61 #[inline]
62 fn pack_block(self, input: &Self::Input, output: &mut [u8]) {
63 output[0] = (input[0] << 2) | (input[1] >> 4);
64 output[1] = (input[1] << 4) | (input[2] >> 2);
65 output[2] = (input[2] << 6) | (input[3] >> 0);
66 }
67}
68
69struct PackState<P: Packer> {
70 packer: P,
71 cache: P::Input,
72 pos: usize,
73}
74
75impl<P: Packer> PackState<P> {
76 fn extend(&mut self, mut input: &[u8], out: &mut Vec<u8>) {
77 while !input.is_empty() {
78 let (_, cache_end) = self.cache.as_mut().split_at_mut(self.pos);
79 let (input_start, input_rest) = input.split_at(cmp::min(input.len(), cache_end.len()));
80 input = input_rest;
81 cache_end[..input_start.len()].copy_from_slice(input_start);
82 if input_start.len() != cache_end.len() {
83 self.pos += input_start.len();
84 } else {
85 let out_start = out.len();
86 out.resize(out_start + P::OUT_BUF_LEN, 0);
87 self.packer.pack_block(&self.cache, &mut out[out_start..]);
88 out.truncate(out_start + (core::mem::size_of::<P::Input>() / 4 * 3));
89 self.pos = 0;
90 }
91 }
92 }
93
94 fn flush(&mut self, out: &mut Vec<u8>, trailer_length: Option<usize>) -> Result<(), Error> {
95 if self.pos % 4 == 1 {
96 return Err(Error::InvalidLength);
97 }
98
99 if let Some(trailer_length) = trailer_length {
100 if (self.pos + trailer_length) % 4 != 0 {
101 return Err(Error::InvalidTrailer);
102 }
103 }
104
105 self.cache.as_mut()[self.pos] = 0;
106 let out_start = out.len();
107 out.resize(out.len() + P::OUT_BUF_LEN, 0);
108 self.packer.pack_block(&self.cache, &mut out[out_start..]);
109 out.truncate(out_start + (self.pos * 3 / 4));
110 Ok(())
111 }
112}
113
114fn decode64<D: Decoder, P: Packer>(input: &[u8], decoder: D, packer: P) -> Result<Vec<u8>, Error> {
115 if input.is_empty() {
116 return Ok(Vec::new());
117 }
118
119 let p_in_len = core::mem::size_of::<P::Input>();
120 let p_out_len = p_in_len / 4 * 3;
121 let cap =
122 crate::misc::div_roundup(input.len(), p_in_len) * p_out_len - p_out_len + P::OUT_BUF_LEN;
123 let mut out = Vec::with_capacity(cap);
124
125 let mut packer = PackState::<P> {
126 packer,
127 cache: P::Input::default(),
128 pos: 0,
129 };
130
131 let mut trailer_length = None;
132 for (chunk, chunk_start) in input
133 .chunks(core::mem::size_of::<D::Block>())
134 .zip((0..).step_by(core::mem::size_of::<D::Block>()))
135 {
136 let mut block = D::zero_block();
137 block.as_mut()[..chunk.len()].copy_from_slice(chunk);
138 let result = decoder.decode_block(&mut block);
139
140 if let Some(idx) = result.first_invalid {
141 let idx = idx as usize;
142 if input[chunk_start + idx] == b'=' {
143 let rest_start = chunk_start + idx + 1;
144 let rest = &input[rest_start..];
145 let mut iter = rest
146 .iter()
147 .enumerate()
148 .filter(|(_, c)| !c.is_ascii_whitespace());
149 trailer_length = match (iter.next(), iter.next()) {
150 (None, _) => Some(1),
151 (Some((_, b'=')), None) => Some(2),
152 (Some((_, b'=')), Some((i, _))) | (Some((i, _)), _) => {
153 return Err(Error::InvalidCharacter(rest_start + i))
154 }
155 };
156 } else {
157 return Err(Error::InvalidCharacter(chunk_start + idx));
158 }
159 }
160
161 packer.extend(&block.as_ref()[..(result.out_length as _)], &mut out);
162
163 if trailer_length.is_some() {
164 break;
165 }
166 }
167
168 packer.flush(&mut out, trailer_length)?;
169
170 Ok(out)
171}
172
173pub(super) fn decode64_arch(input: &[u8]) -> Result<Vec<u8>, Error> {
174 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
175 unsafe {
176 if is_x86_feature_detected!("avx2")
177 && is_x86_feature_detected!("bmi1")
178 && is_x86_feature_detected!("sse4.2")
179 && is_x86_feature_detected!("popcnt")
180 {
181 let avx2 = avx2::Avx2::new();
182 return decode64(input, avx2, avx2);
183 }
184 }
185 decode64(input, lut_align64::LutAlign64, Simple)
186}
187
// Unit tests that run every decoder/packer combination against shared test bodies.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_support::rand_base64_size;
    use crate::ToBase64;

    // Builds the AVX2 decoder/packer for tests and benches.
    // SAFETY: NOTE(review): no CPU feature detection happens here. Callers presumably
    // only run on AVX2-capable machines; confirm before relying on this in CI.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    pub(super) fn test_avx2() -> avx2::Avx2 {
        unsafe { avx2::Avx2::new() }
    }

    // Presumably instantiates each function listed under `tests` once per
    // decoder x packer pair listed above it. TODO: confirm against the
    // `generate_tests!` macro definition.
    generate_tests![
        decoders<D>: {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] avx2, test_avx2();
            lut_align64, lut_align64::LutAlign64;
        },
        packers<P>: {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] avx2, test_avx2();
            simple, Simple;
        },
        tests: {
            decode,
            decode_equivalency,
            decode_error,
            cmp_rand_1kb,
            whitespace_skipped,
            all_bytes,
            wrapping_base64,
        },
    ];

    // Known-answer tests: padded, unpadded and whitespace-containing inputs.
    fn decode<D: Decoder, P: Packer>(decoder: D, packer: P) {
        static DECODE_TESTS: &[(&[u8], &[u8])] = &[
            (b"", b""),
            (b"Zg==", b"f"),
            (b"Zm8=", b"fo"),
            (b"Zm9v", b"foo"),
            (b"Zm9vYg==", b"foob"),
            (b"Zm9vYmE=", b"fooba"),
            (b"Zm9vYmFy", b"foobar"),
            (b"Zm9v\r\nYmFy", b"foobar"),
            (b"Zm9vYg==\r\n", b"foob"),
            (b"Zm9v\nYmFy", b"foobar"),
            (b"Zm9vYg==\n", b"foob"),
            // Whitespace is also allowed between and after padding characters.
            (b"Zm9vYg = = ", b"foob"),
        ];

        for (input, expected) in DECODE_TESTS {
            let output = decode64(input, decoder, packer).unwrap();
            if &output != expected {
                panic!(
                    "Test failed. Expected specific output. \n\nInput: {}\nOutput: {:02x?}\nExpected output:{:02x?}\n\n",
                    std::str::from_utf8(input).unwrap(),
                    output,
                    expected
                );
            }
        }
    }

    // Inputs in each pair must decode identically. URL-safe `-`/`_` are accepted as
    // aliases of `+`/`/`, and unpadded input is accepted.
    fn decode_equivalency<D: Decoder, P: Packer>(decoder: D, packer: P) {
        static DECODE_EQUIVALENCY_TESTS: &[(&[u8], &[u8])] = &[
            (b"-_8", b"+/8="),
        ];

        for (input1, input2) in DECODE_EQUIVALENCY_TESTS {
            let output1 = decode64(input1, decoder, packer).unwrap();
            let output2 = decode64(input2, decoder, packer).unwrap();
            if output1 != output2 {
                panic!(
                    "Test failed. Expected same output.\n\nInput 1: {}\nInput 2: {}\nOutput 1: {:02x?}\nOutput 2:{:02x?}\n\n",
                    std::str::from_utf8(input1).unwrap(),
                    std::str::from_utf8(input2).unwrap(),
                    output1,
                    output2
                );
            }
        }
    }

    // Each input must be rejected: a bad symbol, trailing junk after padding, and
    // too much padding.
    fn decode_error<D: Decoder, P: Packer>(decoder: D, packer: P) {
        #[rustfmt::skip]
        static DECODE_ERROR_TESTS: &[&[u8]] = &[
            b"Zm$=",
            b"Zg==$",
            b"Z===",
        ];

        for input in DECODE_ERROR_TESTS {
            if decode64(input, decoder, packer).is_ok() {
                panic!(
                    "Test failed. Expected error.\n\nInput: {}\n\n",
                    std::str::from_utf8(input).unwrap(),
                );
            }
        }
    }

    // Cross-checks the combination under test against the portable
    // LutAlign64 + Simple pair on random input.
    fn cmp_rand_1kb<D: Decoder, P: Packer>(decoder: D, packer: P) {
        let input = rand_base64_size(1024);

        let output1 = decode64(&input, decoder, packer).unwrap();
        let output2 = decode64(&input, lut_align64::LutAlign64, Simple).unwrap();
        if output1 != output2 {
            panic!(
                "Test failed. Expected same output.\n\nInput: {}\nOutput 1: {:02x?}\nOutput 2:{:02x?}\n\n",
                std::str::from_utf8(&input).unwrap(),
                output1,
                output2
            );
        }
    }

    // Inserting a space after every character (including padding) must not change
    // the decoded output.
    fn whitespace_skipped<D: Decoder, P: Packer>(decoder: D, packer: P) {
        let input1 = rand_base64_size(32);
        use core::iter::once;
        let input2 = input1
            .iter()
            .flat_map(|&c| once(c).chain(once(b' ')))
            .collect::<Vec<_>>();

        let output1 = decode64(&input1, decoder, packer).unwrap();
        let output2 = decode64(&input2, decoder, packer).unwrap();
        if output1 != output2 {
            panic!(
                "Test failed. Expected same output.\n\nInput 1: {}\nInput 2: {}\nOutput 1: {:02x?}\nOutput 2:{:02x?}\n\n",
                std::str::from_utf8(&input1).unwrap(),
                std::str::from_utf8(&input2).unwrap(),
                output1,
                output2
            );
        }
    }

    // Feeds every byte value twice (`[b, b]`) and checks how it is classified:
    // - `Ok(Some(v))`: symbol with value `v`; the single output byte's top 6 bits are `v`.
    // - `Ok(None)`: whitespace, giving empty output.
    // - `Err(())`: everything else must fail to decode.
    fn all_bytes<D: Decoder, P: Packer>(decoder: D, packer: P) {
        let mut set = std::vec![Err(()); 256];
        // Presumably `LUT_STANDARD[i]` is the standard-alphabet character for value `i`.
        for (i, &b) in crate::misc::LUT_STANDARD.iter().enumerate() {
            set[b as usize] = Ok(Some(i as u8));
        }
        // URL-safe aliases.
        set[b'-' as usize] = Ok(Some(62));
        set[b'_' as usize] = Ok(Some(63));
        // Whitespace that must be skipped (0x0c is form feed).
        set[b' ' as usize] = Ok(None);
        set[b'\n' as usize] = Ok(None);
        set[b'\t' as usize] = Ok(None);
        set[b'\r' as usize] = Ok(None);
        set[0x0c] = Ok(None);

        for (i, &expected) in set.iter().enumerate() {
            let output = match decode64(&[i as u8, i as u8], decoder, packer)
                .as_ref()
                .map(|v| &v[..])
            {
                Ok(&[]) => Ok(None),
                Ok(&[v]) => Ok(Some(v >> 2)),
                Ok(_) => panic!("Result is more than 1 byte long"),
                Err(_) => Err(()),
            };
            assert_eq!(output, expected);
        }
    }

    // Round-trips data through the crate's encoder using 64-column LF-wrapped,
    // padded output.
    fn wrapping_base64<D: Decoder, P: Packer>(decoder: D, packer: P) {
        const BASE64_PEM_WRAP: usize = 64;

        static BASE64_PEM: crate::Config = crate::Config {
            char_set: crate::CharacterSet::Standard,
            newline: crate::Newline::LF,
            pad: true,
            line_length: Some(BASE64_PEM_WRAP),
        };

        // All-zero inputs of every length up to two full lines (0..96 bytes).
        let mut v: Vec<u8> = vec![];
        let bytes_per_line = BASE64_PEM_WRAP * 3 / 4;
        for _i in 0..(2 * bytes_per_line) {
            let encoded = v.to_base64(BASE64_PEM);
            let decoded = decode64(encoded.as_bytes(), decoder, packer).unwrap();
            assert_eq!(v, decoded);
            v.push(0);
        }

        // Random inputs of every length from 0 to 999 bytes.
        v = vec![];
        for _i in 0..1000 {
            let encoded = v.to_base64(BASE64_PEM);
            let decoded = decode64(encoded.as_bytes(), decoder, packer).unwrap();
            assert_eq!(v, decoded);
            v.push(rand::random::<u8>());
        }
    }

    // Smoke test: every error variant can be formatted with `Display`.
    #[test]
    fn display_errors() {
        println!("Invalid length is {}", Error::InvalidLength);
        println!("Invalid trailer is {}", Error::InvalidTrailer);
        println!("Invalid character is {}", Error::InvalidCharacter(0));
    }
}
393
// Nightly-only throughput benchmarks for the AVX2 and portable decoder paths.
#[cfg(all(test, feature = "nightly"))]
mod benches {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    use super::tests::test_avx2;
    use super::*;

    use test::Bencher;

    use crate::test_support::rand_base64_size;

    /// Times `decode64` over `input`.
    ///
    /// `make` builds a fresh decoder/packer pair inside every timed iteration, so
    /// constructing them is part of the measured work.
    fn bench_decode<D: Decoder, P: Packer>(
        b: &mut Bencher,
        input: &[u8],
        make: impl Fn() -> (D, P),
    ) {
        b.iter(|| {
            let (decoder, packer) = make();
            let decoded = decode64(input, decoder, packer).unwrap();
            std::hint::black_box(decoded);
        });
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[bench]
    fn avx2_1mb(b: &mut Bencher) {
        let input = rand_base64_size(1024 * 1024);
        bench_decode(b, &input, || (test_avx2(), test_avx2()));
    }

    #[bench]
    fn lut_align64_1mb(b: &mut Bencher) {
        let input = rand_base64_size(1024 * 1024);
        bench_decode(b, &input, || (lut_align64::LutAlign64, Simple));
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[bench]
    fn avx2_1kb(b: &mut Bencher) {
        let input = rand_base64_size(1024);
        bench_decode(b, &input, || (test_avx2(), test_avx2()));
    }

    #[bench]
    fn lut_align64_1kb(b: &mut Bencher) {
        let input = rand_base64_size(1024);
        bench_decode(b, &input, || (lut_align64::LutAlign64, Simple));
    }
}