base64/
decode.rs

1use std::{error, fmt, str};
2
3use byteorder::{BigEndian, ByteOrder};
4
5use ::{Config, STANDARD};
6use tables;
7
// The decoder consumes unpadded input in chunks of 8 symbols, each symbol worth 6 bits.
const INPUT_CHUNK_LEN: usize = 8;
// 8 symbols * 6 bits = 48 bits = 6 decoded bytes per chunk.
const DECODED_CHUNK_LEN: usize = 6;
// Each chunk is emitted as a full big-endian u64, but only 6 of those 8 bytes are real output.
// The remaining 2 bytes are filler: they must fit inside the output slice, yet are not counted
// as written.
const DECODED_CHUNK_SUFFIX: usize = 2;

// Number of u64-sized input chunks decoded per iteration of the unrolled fast loop.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
const INPUT_BLOCK_LEN: usize = INPUT_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;
// Output span touched by one fast-loop iteration, counting the 2 filler bytes of its last write.
const DECODED_BLOCK_LEN: usize =
    DECODED_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK + DECODED_CHUNK_SUFFIX;
22
/// Everything that can go wrong while decoding base64 input.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// A byte that is not part of the configured alphabet (or an `=` in a position where padding
    /// is not allowed) was found. Carries the byte's offset in the input and the byte itself.
    InvalidByte(usize, u8),
    /// The input length cannot be decoded: it would leave a lone 6-bit symbol that does not
    /// complete a byte.
    InvalidLength,
    /// The final non-padding symbol is in the alphabet, but some of its 6 bits would be
    /// discarded and those bits are not zero. This usually means the base64 was corrupted or
    /// truncated.
    /// Contrast with `InvalidByte`, which is about symbols outside the alphabet; this variant is
    /// about valid symbols used in an encoding that makes no sense.
    InvalidLastSymbol(usize, u8),
}
36
37impl fmt::Display for DecodeError {
38    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39        match *self {
40            DecodeError::InvalidByte(index, byte) => {
41                write!(f, "Invalid byte {}, offset {}.", byte, index)
42            }
43            DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
44            DecodeError::InvalidLastSymbol(index, byte) => {
45                write!(f, "Invalid last symbol {}, offset {}.", byte, index)
46            }
47        }
48    }
49}
50
51impl error::Error for DecodeError {
52    fn description(&self) -> &str {
53        match *self {
54            DecodeError::InvalidByte(_, _) => "invalid byte",
55            DecodeError::InvalidLength => "invalid length",
56            DecodeError::InvalidLastSymbol(_, _) => "invalid last symbol",
57        }
58    }
59
60    fn cause(&self) -> Option<&error::Error> {
61        None
62    }
63}
64
65///Decode from string reference as octets.
66///Returns a Result containing a Vec<u8>.
67///Convenience `decode_config(input, base64::STANDARD);`.
68///
69///# Example
70///
71///```rust
72///extern crate base64;
73///
74///fn main() {
75///    let bytes = base64::decode("aGVsbG8gd29ybGQ=").unwrap();
76///    println!("{:?}", bytes);
77///}
78///```
79pub fn decode<T: ?Sized + AsRef<[u8]>>(input: &T) -> Result<Vec<u8>, DecodeError> {
80    decode_config(input, STANDARD)
81}
82
83///Decode from string reference as octets.
84///Returns a Result containing a Vec<u8>.
85///
86///# Example
87///
88///```rust
89///extern crate base64;
90///
91///fn main() {
92///    let bytes = base64::decode_config("aGVsbG8gd29ybGR+Cg==", base64::STANDARD).unwrap();
93///    println!("{:?}", bytes);
94///
95///    let bytes_url = base64::decode_config("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE).unwrap();
96///    println!("{:?}", bytes_url);
97///}
98///```
99pub fn decode_config<T: ?Sized + AsRef<[u8]>>(
100    input: &T,
101    config: Config,
102) -> Result<Vec<u8>, DecodeError> {
103    let mut buffer = Vec::<u8>::with_capacity(input.as_ref().len() * 4 / 3);
104
105    decode_config_buf(input, config, &mut buffer).map(|_| buffer)
106}
107
108///Decode from string reference as octets.
109///Writes into the supplied buffer to avoid allocation.
110///Returns a Result containing an empty tuple, aka ().
111///
112///# Example
113///
114///```rust
115///extern crate base64;
116///
117///fn main() {
118///    let mut buffer = Vec::<u8>::new();
119///    base64::decode_config_buf("aGVsbG8gd29ybGR+Cg==", base64::STANDARD, &mut buffer).unwrap();
120///    println!("{:?}", buffer);
121///
122///    buffer.clear();
123///
124///    base64::decode_config_buf("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE, &mut buffer)
125///        .unwrap();
126///    println!("{:?}", buffer);
127///}
128///```
129pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(
130    input: &T,
131    config: Config,
132    buffer: &mut Vec<u8>,
133) -> Result<(), DecodeError> {
134    let input_bytes = input.as_ref();
135
136    let starting_output_len = buffer.len();
137
138    let num_chunks = num_chunks(input_bytes);
139    let decoded_len_estimate = num_chunks
140        .checked_mul(DECODED_CHUNK_LEN)
141        .and_then(|p| p.checked_add(starting_output_len))
142        .expect("Overflow when calculating output buffer length");
143    buffer.resize(decoded_len_estimate, 0);
144
145    let bytes_written;
146    {
147        let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..];
148        bytes_written = decode_helper(input_bytes, num_chunks, config, buffer_slice)?;
149    }
150
151    buffer.truncate(starting_output_len + bytes_written);
152
153    Ok(())
154}
155
156/// Decode the input into the provided output slice.
157///
158/// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end).
159///
160/// If you don't know ahead of time what the decoded length should be, size your buffer with a
161/// conservative estimate for the decoded length of an input: 3 bytes of output for every 4 bytes of
162/// input, rounded up, or in other words `(input_len + 3) / 4 * 3`.
163///
164/// If the slice is not large enough, this will panic.
165pub fn decode_config_slice<T: ?Sized + AsRef<[u8]>>(
166    input: &T,
167    config: Config,
168    output: &mut [u8],
169) -> Result<usize, DecodeError> {
170    let input_bytes = input.as_ref();
171
172    decode_helper(input_bytes, num_chunks(input_bytes), config, output)
173}
174
175/// Return the number of input chunks (including a possibly partial final chunk) in the input
176fn num_chunks(input: &[u8]) -> usize {
177    input
178        .len()
179        .checked_add(INPUT_CHUNK_LEN - 1)
180        .expect("Overflow when calculating number of chunks in input")
181        / INPUT_CHUNK_LEN
182}
183
/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
///
/// `input` is the whole encoded input. `num_chunks` is expected to be `num_chunks(input)`: the
/// count of 8-byte input chunks, with a partial final chunk counted as one. Both callers in this
/// file pass exactly that. `config` supplies the alphabet (`char_set`) and the trailing-bits
/// policy. Decoded bytes are written to `output` starting at index 0.
///
/// Decoding runs in four stages:
/// 1. Blocks of `CHUNKS_PER_FAST_LOOP_BLOCK` chunks, each chunk written as a full u64
///    (6 data bytes + 2 filler bytes that the next write overwrites).
/// 2. Single chunks, still written as a full u64.
/// 3. Any remaining chunks except the last, written precisely (exactly 6 bytes each).
/// 4. The final, possibly partial and possibly padded chunk, decoded symbol by symbol.
///
/// # Errors
///
/// - `InvalidLength` if `input.len() % 8` is 1 or 5.
/// - `InvalidByte` for a symbol outside the alphabet, or for misplaced `=` padding.
/// - `InvalidLastSymbol` if the final symbol has nonzero bits that would be discarded and
///   `config.decode_allow_trailing_bits` is false.
///
/// # Panics
///
/// Panics (via slice indexing) if `output` is too small for the decoded data.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
fn decode_helper(
    input: &[u8],
    num_chunks: usize,
    config: Config,
    output: &mut [u8],
) -> Result<usize, DecodeError> {
    let char_set = config.char_set;
    let decode_table = char_set.decode_table();

    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop writes in groups of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => return Err(DecodeError::InvalidLength),
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 unpadded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    let mut remaining_chunks = num_chunks;

    // Next unread position in `input` and next unwritten position in `output`.
    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Prefix of the input that the two fast stages may consume; the skipped tail is left
        // for stages 3 and 4.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Advance by the real data only; the 2 filler bytes get overwritten next time.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        // The strict `<` leaves at least one chunk of the fast prefix for stage 3.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4
    // Finally, decode any leftovers that aren't a complete input block of 8 bytes.
    // Use a u64 as a stack-resident 8 byte buffer.
    let mut leftover_bits: u64 = 0;
    let mut morsels_in_leftover = 0;
    let mut padding_bytes = 0;
    // Offset of the first '=' relative to `start_of_leftovers`; only meaningful once
    // `padding_bytes > 0`.
    let mut first_padding_index: usize = 0;
    // Last non-padding symbol seen, reported by InvalidLastSymbol.
    let mut last_symbol = 0_u8;
    let start_of_leftovers = input_index;
    for (i, b) in input[start_of_leftovers..].iter().enumerate() {
        // '=' padding
        if *b == 0x3D {
            // There can be bad padding in a few ways:
            // 1 - Padding with non-padding characters after it
            // 2 - Padding after zero or one non-padding characters before it
            //     in the current quad.
            // 3 - More than two characters of padding. If 3 or 4 padding chars
            //     are in the same quad, that implies it will be caught by #2.
            //     If it spreads from one quad to another, it will be caught by
            //     #2 in the second quad.

            // `start_of_leftovers` is a multiple of 8, so `i % 4` is the position within the
            // current quad of the overall input.
            if i % 4 < 2 {
                // Check for case #2.
                let bad_padding_index = start_of_leftovers
                    + if padding_bytes > 0 {
                        // If we've already seen padding, report the first padding index.
                        // This is to be consistent with the faster logic above: it will report an
                        // error on the first padding character (since it doesn't expect to see
                        // anything but actual encoded data).
                        first_padding_index
                    } else {
                        // haven't seen padding before, just use where we are now
                        i
                    };
                return Err(DecodeError::InvalidByte(bad_padding_index, *b));
            }

            if padding_bytes == 0 {
                first_padding_index = i;
            }

            padding_bytes += 1;
            continue;
        }

        // Check for case #1.
        // To make '=' handling consistent with the main loop, don't allow
        // non-suffix '=' in trailing chunk either. Report error as first
        // erroneous padding.
        if padding_bytes > 0 {
            return Err(DecodeError::InvalidByte(
                start_of_leftovers + first_padding_index,
                0x3D,
            ));
        }
        last_symbol = *b;

        // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
        // To minimize shifts, pack the leftovers from left to right.
        let shift = 64 - (morsels_in_leftover + 1) * 6;
        // tables are all 256 elements, lookup with a u8 index always succeeds
        let morsel = decode_table[*b as usize];
        if morsel == tables::INVALID_VALUE {
            return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
        }

        leftover_bits |= (morsel as u64) << shift;
        morsels_in_leftover += 1;
    }

    // Number of leading bits of `leftover_bits` that form whole output bytes: the morsel count
    // times 6, rounded down to a multiple of 8 (e.g. 3 morsels = 18 bits -> 16 bits = 2 bytes).
    // Counts 1 and 5 cannot occur because those input lengths were rejected as InvalidLength.
    let leftover_bits_ready_to_append = match morsels_in_leftover {
        0 => 0,
        2 => 8,
        3 => 16,
        4 => 24,
        6 => 32,
        7 => 40,
        8 => 48,
        _ => unreachable!(
            "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
        ),
    };

    // if there are bits set outside the bits we care about, last symbol encodes trailing bits that
    // will not be included in the output
    let mask = !0 >> leftover_bits_ready_to_append;
    if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
        // last morsel is at `morsels_in_leftover` - 1
        return Err(DecodeError::InvalidLastSymbol(
            start_of_leftovers + morsels_in_leftover - 1,
            last_symbol,
        ));
    }

    // Emit the ready bits one byte at a time, most significant byte first.
    let mut leftover_bits_appended_to_buf = 0;
    while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
        // `as` simply truncates the higher bits, which is what we want here
        let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
        output[output_index] = selected_bits;
        output_index += 1;

        leftover_bits_appended_to_buf += 8;
    }

    Ok(output_index)
}
423
424/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
425/// first 6 of those contain meaningful data.
426///
427/// `input` is the bytes to decode, of which the first 8 bytes will be processed.
428/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
429/// accurately)
430/// `decode_table` is the lookup table for the particular base64 alphabet.
431/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
432/// data.
433// yes, really inline (worth 30-50% speedup)
434#[inline(always)]
435fn decode_chunk(
436    input: &[u8],
437    index_at_start_of_input: usize,
438    decode_table: &[u8; 256],
439    output: &mut [u8],
440) -> Result<(), DecodeError> {
441    let mut accum: u64;
442
443    let morsel = decode_table[input[0] as usize];
444    if morsel == tables::INVALID_VALUE {
445        return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
446    }
447    accum = (morsel as u64) << 58;
448
449    let morsel = decode_table[input[1] as usize];
450    if morsel == tables::INVALID_VALUE {
451        return Err(DecodeError::InvalidByte(
452            index_at_start_of_input + 1,
453            input[1],
454        ));
455    }
456    accum |= (morsel as u64) << 52;
457
458    let morsel = decode_table[input[2] as usize];
459    if morsel == tables::INVALID_VALUE {
460        return Err(DecodeError::InvalidByte(
461            index_at_start_of_input + 2,
462            input[2],
463        ));
464    }
465    accum |= (morsel as u64) << 46;
466
467    let morsel = decode_table[input[3] as usize];
468    if morsel == tables::INVALID_VALUE {
469        return Err(DecodeError::InvalidByte(
470            index_at_start_of_input + 3,
471            input[3],
472        ));
473    }
474    accum |= (morsel as u64) << 40;
475
476    let morsel = decode_table[input[4] as usize];
477    if morsel == tables::INVALID_VALUE {
478        return Err(DecodeError::InvalidByte(
479            index_at_start_of_input + 4,
480            input[4],
481        ));
482    }
483    accum |= (morsel as u64) << 34;
484
485    let morsel = decode_table[input[5] as usize];
486    if morsel == tables::INVALID_VALUE {
487        return Err(DecodeError::InvalidByte(
488            index_at_start_of_input + 5,
489            input[5],
490        ));
491    }
492    accum |= (morsel as u64) << 28;
493
494    let morsel = decode_table[input[6] as usize];
495    if morsel == tables::INVALID_VALUE {
496        return Err(DecodeError::InvalidByte(
497            index_at_start_of_input + 6,
498            input[6],
499        ));
500    }
501    accum |= (morsel as u64) << 22;
502
503    let morsel = decode_table[input[7] as usize];
504    if morsel == tables::INVALID_VALUE {
505        return Err(DecodeError::InvalidByte(
506            index_at_start_of_input + 7,
507            input[7],
508        ));
509    }
510    accum |= (morsel as u64) << 16;
511
512    BigEndian::write_u64(output, accum);
513
514    Ok(())
515}
516
517/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
518/// trailing garbage bytes.
519#[inline]
520fn decode_chunk_precise(
521    input: &[u8],
522    index_at_start_of_input: usize,
523    decode_table: &[u8; 256],
524    output: &mut [u8],
525) -> Result<(), DecodeError> {
526    let mut tmp_buf = [0_u8; 8];
527
528    decode_chunk(
529        input,
530        index_at_start_of_input,
531        decode_table,
532        &mut tmp_buf[..],
533    )?;
534
535    output[0..6].copy_from_slice(&tmp_buf[0..6]);
536
537    Ok(())
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    use rand::{FromEntropy, Rng};
    use rand::distributions::{Distribution, Uniform};

    use encode::encode_config_buf;
    use encode::encode_config_slice;
    use tests::{assert_encode_sanity, random_config};

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk_precise(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
    }

    #[test]
    fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decoded_with_prefix = Vec::new();
        let mut decoded_without_prefix = Vec::new();
        let mut prefix = Vec::new();

        let prefix_len_range = Uniform::new(0, 1000);
        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decoded_with_prefix.clear();
            decoded_without_prefix.clear();
            prefix.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            let prefix_len = prefix_len_range.sample(&mut rng);

            // fill the buf with a prefix
            for _ in 0..prefix_len {
                prefix.push(rng.gen());
            }

            decoded_with_prefix.resize(prefix_len, 0);
            decoded_with_prefix.copy_from_slice(&prefix);

            // decode into the non-empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_with_prefix).unwrap();
            // also decode into the empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_without_prefix).unwrap();

            assert_eq!(
                prefix_len + decoded_without_prefix.len(),
                decoded_with_prefix.len()
            );
            assert_eq!(orig_data, decoded_without_prefix);

            // append plain decode onto prefix
            prefix.append(&mut decoded_without_prefix);

            assert_eq!(prefix, decoded_with_prefix);
        }
    }

    #[test]
    fn decode_into_slice_doesnt_clobber_existing_prefix_or_suffix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();
        let mut decode_buf_copy: Vec<u8> = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();
            decode_buf_copy.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            // fill the buffer with random garbage, long enough to have some room before and after
            for _ in 0..5000 {
                decode_buf.push(rng.gen());
            }

            // keep a copy for later comparison
            decode_buf_copy.extend(decode_buf.iter());

            let offset = 1000;

            // decode into the non-empty buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[offset..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(
                orig_data,
                &decode_buf[offset..(offset + decode_bytes_written)]
            );
            assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]);
            assert_eq!(
                &decode_buf_copy[offset + decode_bytes_written..],
                &decode_buf[offset + decode_bytes_written..]
            );
        }
    }

    #[test]
    fn decode_into_slice_fits_in_precisely_sized_slice() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();

            let input_len = input_len_range.sample(&mut rng);

            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            decode_buf.resize(input_len, 0);

            // decode into the non-empty buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(orig_data, decode_buf);
        }
    }

    #[test]
    fn detect_invalid_last_symbol_two_bytes() {
        let decode =
            |input, forgiving| decode_config(input, STANDARD.decode_allow_trailing_bits(forgiving));

        // example from https://github.com/alicemaz/rust-base64/issues/75
        assert!(decode("iYU=", false).is_ok());
        // trailing 01
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'V')),
            decode("iYV=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
        // trailing 10
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'W')),
            decode("iYW=", false)
        );
        // each forgiving check must use the same input as the strict check above it
        assert_eq!(Ok(vec![137, 133]), decode("iYW=", true));
        // trailing 11
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'X')),
            decode("iYX=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYX=", true));

        // also works when there are 2 quads in the last block
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(6, b'X')),
            decode("AAAAiYX=", false)
        );
        assert_eq!(Ok(vec![0, 0, 0, 137, 133]), decode("AAAAiYX=", true));
    }

    #[test]
    fn detect_invalid_last_symbol_one_byte() {
        // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol

        assert!(decode("/w==").is_ok());
        // trailing 01
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//=="));

        // also works when there are 2 quads in the last block
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(5, b'x')),
            decode("AAAA/x==")
        );
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_three_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        let mut bytes = [0_u8; 2];
        for b1 in 0_u16..256 {
            bytes[0] = b1 as u8;
            for b2 in 0_u16..256 {
                bytes[1] = b2 as u8;
                let mut b64 = vec![0_u8; 4];
                assert_eq!(4, encode_config_slice(&bytes, STANDARD, &mut b64[..]));
                let mut v = ::std::vec::Vec::with_capacity(2);
                v.extend_from_slice(&bytes[..]);

                assert!(base64_to_bytes.insert(b64, v).is_none());
            }
        }

        // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol

        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                for &s3 in STANDARD.char_set.encode_table().iter() {
                    symbols[2] = s3;
                    symbols[3] = b'=';

                    match base64_to_bytes.get(&symbols[..]) {
                        Some(bytes) => {
                            assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                        }
                        None => assert_eq!(
                            Err(DecodeError::InvalidLastSymbol(2, s3)),
                            decode_config(&symbols[..], STANDARD)
                        ),
                    }
                }
            }
        }
    }

    #[test]
    fn detect_invalid_last_symbol_every_possible_two_symbols() {
        let mut base64_to_bytes = ::std::collections::HashMap::new();

        for b in 0_u16..256 {
            let mut b64 = vec![0_u8; 4];
            assert_eq!(4, encode_config_slice(&[b as u8], STANDARD, &mut b64[..]));
            let mut v = ::std::vec::Vec::with_capacity(1);
            v.push(b as u8);

            assert!(base64_to_bytes.insert(b64, v).is_none());
        }

        // every possible combination of symbols must either decode to 1 byte or get InvalidLastSymbol

        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                symbols[2] = b'=';
                symbols[3] = b'=';

                match base64_to_bytes.get(&symbols[..]) {
                    Some(bytes) => {
                        assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                    }
                    None => assert_eq!(
                        Err(DecodeError::InvalidLastSymbol(1, s2)),
                        decode_config(&symbols[..], STANDARD)
                    ),
                }
            }
        }
    }
}