1use std::{error, fmt, str};
2
3use byteorder::{BigEndian, ByteOrder};
4
5use ::{Config, STANDARD};
6use tables;
7
/// Number of input symbols the decoder processes per chunk (padding is not handled in chunks).
const INPUT_CHUNK_LEN: usize = 8;
/// Bytes of real output produced by one full chunk: 8 symbols * 6 bits = 48 bits = 6 bytes.
const DECODED_CHUNK_LEN: usize = 6;
/// A chunk is decoded into a `u64` that is written out whole, so each chunk write touches 2
/// bytes beyond its 6 bytes of real output; those bytes must still exist in the output slice.
const DECODED_CHUNK_SUFFIX: usize = 2;

/// Number of chunks decoded per iteration of the unrolled fast loop.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
/// Input consumed by one iteration of the unrolled fast loop.
const INPUT_BLOCK_LEN: usize = INPUT_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;
/// Output touched by one iteration of the unrolled fast loop, including the trailing
/// `DECODED_CHUNK_SUFFIX` bytes of its final `u64` write.
const DECODED_BLOCK_LEN: usize =
    DECODED_CHUNK_SUFFIX + DECODED_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;
22
23#[derive(Clone, Debug, PartialEq, Eq)]
25pub enum DecodeError {
26 InvalidByte(usize, u8),
28 InvalidLength,
30 InvalidLastSymbol(usize, u8),
35}
36
37impl fmt::Display for DecodeError {
38 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39 match *self {
40 DecodeError::InvalidByte(index, byte) => {
41 write!(f, "Invalid byte {}, offset {}.", byte, index)
42 }
43 DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
44 DecodeError::InvalidLastSymbol(index, byte) => {
45 write!(f, "Invalid last symbol {}, offset {}.", byte, index)
46 }
47 }
48 }
49}
50
51impl error::Error for DecodeError {
52 fn description(&self) -> &str {
53 match *self {
54 DecodeError::InvalidByte(_, _) => "invalid byte",
55 DecodeError::InvalidLength => "invalid length",
56 DecodeError::InvalidLastSymbol(_, _) => "invalid last symbol",
57 }
58 }
59
60 fn cause(&self) -> Option<&error::Error> {
61 None
62 }
63}
64
65pub fn decode<T: ?Sized + AsRef<[u8]>>(input: &T) -> Result<Vec<u8>, DecodeError> {
80 decode_config(input, STANDARD)
81}
82
83pub fn decode_config<T: ?Sized + AsRef<[u8]>>(
100 input: &T,
101 config: Config,
102) -> Result<Vec<u8>, DecodeError> {
103 let mut buffer = Vec::<u8>::with_capacity(input.as_ref().len() * 4 / 3);
104
105 decode_config_buf(input, config, &mut buffer).map(|_| buffer)
106}
107
108pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(
130 input: &T,
131 config: Config,
132 buffer: &mut Vec<u8>,
133) -> Result<(), DecodeError> {
134 let input_bytes = input.as_ref();
135
136 let starting_output_len = buffer.len();
137
138 let num_chunks = num_chunks(input_bytes);
139 let decoded_len_estimate = num_chunks
140 .checked_mul(DECODED_CHUNK_LEN)
141 .and_then(|p| p.checked_add(starting_output_len))
142 .expect("Overflow when calculating output buffer length");
143 buffer.resize(decoded_len_estimate, 0);
144
145 let bytes_written;
146 {
147 let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..];
148 bytes_written = decode_helper(input_bytes, num_chunks, config, buffer_slice)?;
149 }
150
151 buffer.truncate(starting_output_len + bytes_written);
152
153 Ok(())
154}
155
156pub fn decode_config_slice<T: ?Sized + AsRef<[u8]>>(
166 input: &T,
167 config: Config,
168 output: &mut [u8],
169) -> Result<usize, DecodeError> {
170 let input_bytes = input.as_ref();
171
172 decode_helper(input_bytes, num_chunks(input_bytes), config, output)
173}
174
175fn num_chunks(input: &[u8]) -> usize {
177 input
178 .len()
179 .checked_add(INPUT_CHUNK_LEN - 1)
180 .expect("Overflow when calculating number of chunks in input")
181 / INPUT_CHUNK_LEN
182}
183
/// Decodes `input` into `output` using the alphabet from `config`, returning the number of
/// bytes written to `output`.
///
/// `num_chunks` is the number of 8-symbol input chunks, counting a final partial chunk. Callers
/// in this file pass `num_chunks(input)` and compute it only once.
///
/// Decoding runs in four stages:
/// 1. an unrolled loop over `INPUT_BLOCK_LEN`-byte blocks,
/// 2. a chunk-at-a-time loop,
/// 3. exact-width decoding of any full chunks the fast loops held back,
/// 4. a symbol-at-a-time pass over the final, possibly partial or padded, chunk.
///
/// # Panics
///
/// `output` is indexed directly, so this panics if it is too small for the decoded data.
#[inline]
fn decode_helper(
    input: &[u8],
    num_chunks: usize,
    config: Config,
    output: &mut [u8],
) -> Result<usize, DecodeError> {
    let char_set = config.char_set;
    let decode_table = char_set.decode_table();

    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // The fast loops write a whole u64 (8 bytes) per chunk, but only 6 of those bytes are real
    // output. They must therefore stop early enough that at least DECODED_CHUNK_SUFFIX more
    // bytes of real data are written afterwards. That overwrites the junk suffix and keeps
    // every write inside the real decoded length.
    let trailing_bytes_to_skip = match remainder_len {
        // Input is a whole number of chunks: hold back the last chunk, since it may contain
        // padding, which the fast loops (via decode_chunk) reject as an invalid byte.
        0 => INPUT_CHUNK_LEN,
        // 1 or 5 trailing symbols leave a lone 6-bit remainder that cannot form a byte.
        1 | 5 => return Err(DecodeError::InvalidLength),
        // 2 trailing symbols decode to a single byte, which cannot cover the 2-byte suffix,
        // so the preceding full chunk is held back too.
        2 => INPUT_CHUNK_LEN + 2,
        // 3 unpadded symbols decode to 2 bytes, but they may really be 2 symbols + 1 pad
        // (1 byte, or an error). Holding back the preceding chunk lets malformed input fail
        // with an error in stage 4 rather than index past the end of `output`.
        3 => INPUT_CHUNK_LEN + 3,
        // 4 symbols may be 2 symbols + 2 pads, which decode to only 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // 6 or 7 trailing symbols: padding is not required, and any valid group of that size
        // decodes to at least 2 bytes.
        _ => remainder_len,
    };

    // Counts every chunk still to decode, including a final partial one.
    let mut remaining_chunks = num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Prefix of the input that the fast loops (stages 1 and 2) are allowed to touch.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Stage 1: decode CHUNKS_PER_FAST_LOOP_BLOCK chunks per iteration, manually unrolled so
        // the input/output slices are bounds-checked once per block instead of once per chunk.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                // Each chunk writes 8 bytes at a 6-byte stride, so each one's 2-byte junk
                // suffix is overwritten by the next chunk's real data.
                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Advance past the real output only; the last chunk's junk suffix is
                // overwritten by whatever is decoded next.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Stage 2: one chunk at a time for whatever stage 1 did not cover. The strict `<`
        // is conservative: a full chunk left behind here is decoded in stage 3 instead.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3: decode every remaining chunk except the last, which is left for stage 4.
    // decode_chunk_precise writes exactly DECODED_CHUNK_LEN bytes, so nothing lands past the
    // real output.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // Exactly one (possibly partial) chunk of at most 8 symbols should remain, unless the
    // input was empty.
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4: decode the leftover symbols one by one, validating '=' padding.
    // Decoded morsels are packed into a u64 from the most significant bit downward.
    let mut leftover_bits: u64 = 0;
    let mut morsels_in_leftover = 0;
    let mut padding_bytes = 0;
    let mut first_padding_index: usize = 0;
    let mut last_symbol = 0_u8;
    let start_of_leftovers = input_index;
    for (i, b) in input[start_of_leftovers..].iter().enumerate() {
        // 0x3D is '='.
        if *b == 0x3D {
            // Padding is rejected when it:
            // 1 - is followed by a non-padding symbol (checked below), or
            // 2 - appears in the first two positions of a 4-symbol group (checked here).
            // Runs of 3+ padding symbols are caught by rule 2, in this group or the next.
            if i % 4 < 2 {
                let bad_padding_index = start_of_leftovers
                    + if padding_bytes > 0 {
                        // Padding was already seen: report the first padding symbol, like
                        // the fast loops, which fail on the first '=' they encounter.
                        first_padding_index
                    } else {
                        // First padding symbol: report it.
                        i
                    };
                return Err(DecodeError::InvalidByte(bad_padding_index, *b));
            }

            if padding_bytes == 0 {
                first_padding_index = i;
            }

            padding_bytes += 1;
            continue;
        }

        // Rule 1: a non-padding symbol after padding. Report the first padding symbol.
        if padding_bytes > 0 {
            return Err(DecodeError::InvalidByte(
                start_of_leftovers + first_padding_index,
                0x3D,
            ));
        }
        last_symbol = *b;

        // Place this morsel in the next free 6 bits, counting down from the top of the u64.
        let shift = 64 - (morsels_in_leftover + 1) * 6;
        // The table is indexed by a u8, so the lookup cannot go out of bounds.
        let morsel = decode_table[*b as usize];
        if morsel == tables::INVALID_VALUE {
            return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
        }

        leftover_bits |= (morsel as u64) << shift;
        morsels_in_leftover += 1;
    }

    // Number of whole bits (a multiple of 8) the leftover morsels decode to.
    // 1 or 5 morsels cannot occur: those lengths are rejected above, either as
    // InvalidLength or as misplaced padding.
    let leftover_bits_ready_to_append = match morsels_in_leftover {
        0 => 0,
        2 => 8,
        3 => 16,
        4 => 24,
        6 => 32,
        7 => 40,
        8 => 48,
        _ => unreachable!(
            "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
        ),
    };

    // Any set bit below the bytes being emitted would be silently discarded. Reject it
    // unless the config allows trailing bits.
    let mask = !0 >> leftover_bits_ready_to_append;
    if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
        // The last morsel is at offset `morsels_in_leftover - 1` within the leftovers.
        return Err(DecodeError::InvalidLastSymbol(
            start_of_leftovers + morsels_in_leftover - 1,
            last_symbol,
        ));
    }

    // Emit the decoded leftover bytes, most significant first.
    let mut leftover_bits_appended_to_buf = 0;
    while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
        // `as u8` keeps only the low 8 bits of the shifted value.
        let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
        output[output_index] = selected_bits;
        output_index += 1;

        leftover_bits_appended_to_buf += 8;
    }

    Ok(output_index)
}
423
424#[inline(always)]
435fn decode_chunk(
436 input: &[u8],
437 index_at_start_of_input: usize,
438 decode_table: &[u8; 256],
439 output: &mut [u8],
440) -> Result<(), DecodeError> {
441 let mut accum: u64;
442
443 let morsel = decode_table[input[0] as usize];
444 if morsel == tables::INVALID_VALUE {
445 return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
446 }
447 accum = (morsel as u64) << 58;
448
449 let morsel = decode_table[input[1] as usize];
450 if morsel == tables::INVALID_VALUE {
451 return Err(DecodeError::InvalidByte(
452 index_at_start_of_input + 1,
453 input[1],
454 ));
455 }
456 accum |= (morsel as u64) << 52;
457
458 let morsel = decode_table[input[2] as usize];
459 if morsel == tables::INVALID_VALUE {
460 return Err(DecodeError::InvalidByte(
461 index_at_start_of_input + 2,
462 input[2],
463 ));
464 }
465 accum |= (morsel as u64) << 46;
466
467 let morsel = decode_table[input[3] as usize];
468 if morsel == tables::INVALID_VALUE {
469 return Err(DecodeError::InvalidByte(
470 index_at_start_of_input + 3,
471 input[3],
472 ));
473 }
474 accum |= (morsel as u64) << 40;
475
476 let morsel = decode_table[input[4] as usize];
477 if morsel == tables::INVALID_VALUE {
478 return Err(DecodeError::InvalidByte(
479 index_at_start_of_input + 4,
480 input[4],
481 ));
482 }
483 accum |= (morsel as u64) << 34;
484
485 let morsel = decode_table[input[5] as usize];
486 if morsel == tables::INVALID_VALUE {
487 return Err(DecodeError::InvalidByte(
488 index_at_start_of_input + 5,
489 input[5],
490 ));
491 }
492 accum |= (morsel as u64) << 28;
493
494 let morsel = decode_table[input[6] as usize];
495 if morsel == tables::INVALID_VALUE {
496 return Err(DecodeError::InvalidByte(
497 index_at_start_of_input + 6,
498 input[6],
499 ));
500 }
501 accum |= (morsel as u64) << 22;
502
503 let morsel = decode_table[input[7] as usize];
504 if morsel == tables::INVALID_VALUE {
505 return Err(DecodeError::InvalidByte(
506 index_at_start_of_input + 7,
507 input[7],
508 ));
509 }
510 accum |= (morsel as u64) << 16;
511
512 BigEndian::write_u64(output, accum);
513
514 Ok(())
515}
516
517#[inline]
520fn decode_chunk_precise(
521 input: &[u8],
522 index_at_start_of_input: usize,
523 decode_table: &[u8; 256],
524 output: &mut [u8],
525) -> Result<(), DecodeError> {
526 let mut tmp_buf = [0_u8; 8];
527
528 decode_chunk(
529 input,
530 index_at_start_of_input,
531 decode_table,
532 &mut tmp_buf[..],
533 )?;
534
535 output[0..6].copy_from_slice(&tmp_buf[0..6]);
536
537 Ok(())
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    use rand::{FromEntropy, Rng};
    use rand::distributions::{Distribution, Uniform};

    use encode::encode_config_buf;
    use encode::encode_config_slice;
    use tests::{assert_encode_sanity, random_config};

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        // "Zm9vYmFy" is the encoding of "foobar".
        let input = b"Zm9vYmFy";
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk_precise(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        // "Zm9vYmFy" is the encoding of "foobar".
        let input = b"Zm9vYmFy";
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
    }

    #[test]
    fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decoded_with_prefix = Vec::new();
        let mut decoded_without_prefix = Vec::new();
        let mut prefix = Vec::new();

        let prefix_len_range = Uniform::new(0, 1000);
        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decoded_with_prefix.clear();
            decoded_without_prefix.clear();
            prefix.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            let prefix_len = prefix_len_range.sample(&mut rng);

            // Random prefix bytes that decoding must leave intact.
            for _ in 0..prefix_len {
                prefix.push(rng.gen());
            }

            decoded_with_prefix.extend_from_slice(&prefix);

            // Decode into the buffer that already holds the prefix...
            decode_config_buf(&encoded_data, config, &mut decoded_with_prefix).unwrap();
            // ...and into an empty buffer for comparison.
            decode_config_buf(&encoded_data, config, &mut decoded_without_prefix).unwrap();

            assert_eq!(
                prefix_len + decoded_without_prefix.len(),
                decoded_with_prefix.len()
            );
            assert_eq!(orig_data, decoded_without_prefix);

            // The prefixed buffer must equal the prefix followed by the plain decode.
            prefix.append(&mut decoded_without_prefix);

            assert_eq!(prefix, decoded_with_prefix);
        }
    }

    #[test]
    fn decode_into_slice_doesnt_clobber_existing_prefix_or_suffix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();
        let mut decode_buf_copy: Vec<u8> = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();
            decode_buf_copy.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            // Fill the destination with random bytes so any stray write is detectable.
            for _ in 0..5000 {
                decode_buf.push(rng.gen());
            }

            // Snapshot of the destination before decoding.
            decode_buf_copy.extend(decode_buf.iter());

            let offset = 1000;

            // Decode into the middle of the buffer.
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[offset..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(
                orig_data,
                &decode_buf[offset..(offset + decode_bytes_written)]
            );
            assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]);
            assert_eq!(
                &decode_buf_copy[offset + decode_bytes_written..],
                &decode_buf[offset + decode_bytes_written..]
            );
        }
    }

    #[test]
    fn decode_into_slice_fits_in_precisely_sized_slice() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            // Exactly as large as the decoded output, no slack.
            decode_buf.resize(input_len, 0);

            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(orig_data, decode_buf);
        }
    }

    #[test]
    fn detect_invalid_last_symbol_two_bytes() {
        let decode =
            |input, forgiving| decode_config(input, STANDARD.decode_allow_trailing_bits(forgiving));

        // 'U' is 0b010100: its low 2 bits are zero, so there are no trailing bits.
        assert!(decode("iYU=", false).is_ok());
        // 'V' is 0b010101: trailing bits 0b01.
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'V')),
            decode("iYV=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
        // 'W' is 0b010110: trailing bits 0b10.
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'W')),
            decode("iYW=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYW=", true));
        // 'X' is 0b010111: trailing bits 0b11.
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'X')),
            decode("iYX=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYX=", true));

        // Same check when the bad symbol is not in the first chunk.
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(6, b'X')),
            decode("AAAAiYX=", false)
        );
        assert_eq!(Ok(vec![0, 0, 0, 137, 133]), decode("AAAAiYX=", true));
    }

    #[test]
    fn detect_invalid_last_symbol_one_byte() {
        // 0xFF encodes to "/w==", so any second symbol above 'w' carries trailing bits.
        assert!(decode("/w==").is_ok());
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//=="));

        // Same check when the bad symbol is not in the first chunk.
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(5, b'x')),
            decode("AAAA/x==")
        );
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_three_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        // Map the canonical encoding of every 2-byte input back to its bytes.
        let mut bytes = [0_u8; 2];
        for b1 in 0_u16..256 {
            bytes[0] = b1 as u8;
            for b2 in 0_u16..256 {
                bytes[1] = b2 as u8;
                let mut b64 = vec![0_u8; 4];
                assert_eq!(4, encode_config_slice(&bytes, STANDARD, &mut b64[..]));
                let mut v = ::std::vec::Vec::with_capacity(2);
                v.extend_from_slice(&bytes[..]);

                assert!(base64_to_bytes.insert(b64, v).is_none());
            }
        }

        // Every 3-symbol + '=' combination is either a canonical encoding or must be
        // rejected as having an invalid last symbol.
        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                for &s3 in STANDARD.char_set.encode_table().iter() {
                    symbols[2] = s3;
                    symbols[3] = b'=';

                    match base64_to_bytes.get(&symbols[..]) {
                        Some(bytes) => {
                            assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                        }
                        None => assert_eq!(
                            Err(DecodeError::InvalidLastSymbol(2, s3)),
                            decode_config(&symbols[..], STANDARD)
                        ),
                    }
                }
            }
        }
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_two_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        // Map the canonical encoding of every 1-byte input back to its byte.
        for b in 0_u16..256 {
            let mut b64 = vec![0_u8; 4];
            assert_eq!(4, encode_config_slice(&[b as u8], STANDARD, &mut b64[..]));
            let mut v = ::std::vec::Vec::with_capacity(1);
            v.push(b as u8);

            assert!(base64_to_bytes.insert(b64, v).is_none());
        }

        // Every 2-symbol + "==" combination is either a canonical encoding or must be
        // rejected as having an invalid last symbol.
        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                symbols[2] = b'=';
                symbols[3] = b'=';

                match base64_to_bytes.get(&symbols[..]) {
                    Some(bytes) => {
                        assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                    }
                    None => assert_eq!(
                        Err(DecodeError::InvalidLastSymbol(1, s2)),
                        decode_config(&symbols[..], STANDARD)
                    ),
                }
            }
        }
    }
}