// e_utils/algorithm/base64/decode.rs

use super::{tables, PAD_BYTE};
use super::{Config, STANDARD};
use std::error;
use std::fmt;

5/// decode
6pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeError> {
7  decode_config(input, STANDARD)
8}
9/// decode_config
10pub fn decode_config<T: AsRef<[u8]>>(input: T, config: Config) -> Result<Vec<u8>, DecodeError> {
11  let mut buffer = Vec::<u8>::with_capacity(input.as_ref().len() * 4 / 3);
12  decode_config_buf(input, config, &mut buffer).map(|_| buffer)
13}
14
/// Errors that can occur while decoding base64 input.
///
/// Offsets carried by the variants are byte positions in the caller's input.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DecodeError {
  /// An invalid byte was found in the input. The offset and offending byte are provided.
  InvalidByte(usize, u8),
  /// The length of the input is invalid.
  /// A typical cause of this is stray trailing whitespace or other separator bytes.
  /// In the case where excess trailing bytes have produced an invalid length *and* the last byte
  /// is also an invalid base64 symbol (as would be the case for whitespace, etc), `InvalidByte`
  /// will be emitted instead of `InvalidLength` to make the issue easier to debug.
  InvalidLength,
  /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
  /// This is indicative of corrupted or truncated Base64.
  /// Unlike InvalidByte, which reports symbols that aren't in the alphabet, this error is for
  /// symbols that are in the alphabet but represent nonsensical encodings.
  /// The offset and offending symbol are provided.
  InvalidLastSymbol(usize, u8),
}

impl fmt::Display for DecodeError {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    match *self {
      DecodeError::InvalidByte(index, byte) => {
        write!(f, "Invalid byte {}, offset {}.", byte, index)
      }
      DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
      DecodeError::InvalidLastSymbol(index, byte) => {
        write!(f, "Invalid last symbol {}, offset {}.", byte, index)
      }
    }
  }
}

/// Makes `DecodeError` usable as a standard error: it can be boxed into `Box<dyn Error>` and
/// propagated with `?` from functions returning such errors. No variant wraps another error, so
/// the default `source()` (which returns `None`) is correct.
impl error::Error for DecodeError {}

46/// decode_config_buf
47pub fn decode_config_buf<T: AsRef<[u8]>>(
48  input: T,
49  config: Config,
50  buffer: &mut Vec<u8>,
51) -> Result<(), DecodeError> {
52  let input_bytes = input.as_ref();
53
54  let starting_output_len = buffer.len();
55
56  let num_chunks = num_chunks(input_bytes);
57  let decoded_len_estimate = num_chunks
58    .checked_mul(DECODED_CHUNK_LEN)
59    .and_then(|p| p.checked_add(starting_output_len))
60    .expect("Overflow when calculating output buffer length");
61  buffer.resize(decoded_len_estimate, 0);
62
63  let bytes_written;
64  {
65    let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..];
66    bytes_written = decode_helper(input_bytes, num_chunks, config, buffer_slice)?;
67  }
68
69  buffer.truncate(starting_output_len + bytes_written);
70
71  Ok(())
72}
// Decode logic operates on chunks of 8 input symbols, without padding, on the fast paths.

/// Number of base64 symbols consumed per decode chunk.
const INPUT_CHUNK_LEN: usize = 8;
/// Bytes produced by decoding one full chunk (8 symbols * 6 bits = 48 bits = 6 bytes).
const DECODED_CHUNK_LEN: usize = 6;

/// Counts the input chunks, rounding up so that a trailing partial chunk is included.
///
/// # Panics
/// Panics if `input.len() + INPUT_CHUNK_LEN - 1` overflows `usize`.
fn num_chunks(input: &[u8]) -> usize {
  let rounded_up_len = input
    .len()
    .checked_add(INPUT_CHUNK_LEN - 1)
    .expect("Overflow when calculating number of chunks in input");
  rounded_up_len / INPUT_CHUNK_LEN
}

/// The fast path stores a whole `u64` per chunk, but only `DECODED_CHUNK_LEN` of its bytes are
/// decoded data. The remaining trailing bytes must not be counted as written, yet they must
/// still fit inside the output slice being written to.
const DECODED_CHUNK_SUFFIX: usize = std::mem::size_of::<u64>() - DECODED_CHUNK_LEN;
/// How many `u64` chunks of input the unrolled stage-1 loop handles per iteration.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
/// Input bytes consumed by one stage-1 iteration.
const INPUT_BLOCK_LEN: usize = INPUT_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;
/// Output bytes touched by one stage-1 iteration, including the trailing suffix of the final
/// `u64` write.
const DECODED_BLOCK_LEN: usize =
  DECODED_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK + DECODED_CHUNK_SUFFIX;

/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
///
/// `num_chunks` is expected to equal `num_chunks(input)`, i.e. the input length divided by
/// `INPUT_CHUNK_LEN` and rounded up. `decode_config_buf` computes it once, passes it in, and
/// sizes `output` to `num_chunks * DECODED_CHUNK_LEN` bytes. The stage boundaries below rely on
/// that size to keep every 8-byte write in bounds. A smaller `output` makes slice indexing panic.
///
/// Decoding runs in four stages:
/// 1. blocks of `CHUNKS_PER_FAST_LOOP_BLOCK` full chunks, each stored as a whole `u64`;
/// 2. single full chunks, each stored as a whole `u64`;
/// 3. any remaining chunks except the last, stored precisely (6 bytes, no 2-byte suffix);
/// 4. the last (possibly partial, possibly padded) chunk, checked symbol by symbol.
///
/// # Errors
/// - `InvalidLength` if `input.len() % INPUT_CHUNK_LEN` is 1 or 5. The exception is when the
///   last byte is neither `PAD_BYTE` nor a valid symbol: then `InvalidByte` for that byte is
///   returned instead.
/// - `InvalidByte` for a symbol not in the alphabet, or for misplaced padding.
/// - `InvalidLastSymbol` if the last symbol has nonzero bits that would be discarded and
///   `config.decode_allow_trailing_bits` is false.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
fn decode_helper(
  input: &[u8],
  num_chunks: usize,
  config: Config,
  output: &mut [u8],
) -> Result<usize, DecodeError> {
  let char_set = config.char_set;
  let decode_table = char_set.decode_table();

  // Number of symbols in the final partial chunk (0 when the input is a whole number of chunks).
  let remainder_len = input.len() % INPUT_CHUNK_LEN;

  // Because the fast decode loop writes in groups of 8 bytes (unrolled to
  // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
  // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
  // soon enough that there will always be 2 more bytes of valid data written after that loop.
  let trailing_bytes_to_skip = match remainder_len {
    // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
    // and the fast decode logic cannot handle padding
    0 => INPUT_CHUNK_LEN,
    // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
    1 | 5 => {
      // trailing whitespace is so common that it's worth it to check the last byte to
      // possibly return a better error message
      if let Some(b) = input.last() {
        if *b != PAD_BYTE && decode_table[*b as usize] == tables::INVALID_VALUE {
          return Err(DecodeError::InvalidByte(input.len() - 1, *b));
        }
      }

      return Err(DecodeError::InvalidLength);
    }
    // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
    // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
    // previous chunk.
    2 => INPUT_CHUNK_LEN + 2,
    // If this is 3 unpadded chars, then it would actually decode to 2 bytes. However, if this
    // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
    // with an error, not panic from going past the bounds of the output slice, so we let it
    // use stage 3 + 4.
    3 => INPUT_CHUNK_LEN + 3,
    // This can also decode to one output byte because it may be 2 input chars + 2 padding
    // chars, which would decode to 1 byte.
    4 => INPUT_CHUNK_LEN + 4,
    // Everything else is a legal decode len (given that we don't require padding), and will
    // decode to at least 2 bytes of output.
    _ => remainder_len,
  };

  // rounded up to include partial chunks
  let mut remaining_chunks = num_chunks;

  let mut input_index = 0;
  let mut output_index = 0;

  {
    // Only this prefix of the input is handed to the u64-writing fast loops. Everything after it
    // is left for stages 3 and 4, which never write past the decoded data.
    let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

    // Fast loop, stage 1
    // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
    if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
      while input_index <= max_start_index {
        let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
        let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

        decode_chunk(
          &input_slice[0..],
          input_index,
          decode_table,
          &mut output_slice[0..],
        )?;
        decode_chunk(
          &input_slice[8..],
          input_index + 8,
          decode_table,
          &mut output_slice[6..],
        )?;
        decode_chunk(
          &input_slice[16..],
          input_index + 16,
          decode_table,
          &mut output_slice[12..],
        )?;
        decode_chunk(
          &input_slice[24..],
          input_index + 24,
          decode_table,
          &mut output_slice[18..],
        )?;

        input_index += INPUT_BLOCK_LEN;
        // Advance only past the real decoded bytes. The 2-byte suffix of the last u64 write is
        // overwritten by whatever is written next.
        output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
        remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
      }
    }

    // Fast loop, stage 2 (aka still pretty fast loop)
    // 8 bytes at a time for whatever we didn't do in stage 1.
    // NOTE(review): stage 1 uses `<=`, but this loop uses `<`. A chunk starting exactly at
    // `max_start_index` is therefore left to stage 3's precise decode. That is still correct
    // (stage 3 handles every chunk but the last), just slightly slower. Confirm before changing.
    if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
      while input_index < max_start_index {
        decode_chunk(
          &input[input_index..(input_index + INPUT_CHUNK_LEN)],
          input_index,
          decode_table,
          &mut output[output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
        )?;

        output_index += DECODED_CHUNK_LEN;
        input_index += INPUT_CHUNK_LEN;
        remaining_chunks -= 1;
      }
    }
  }

  // Stage 3
  // If input length was such that a chunk had to be deferred until after the fast loop
  // because decoding it would have produced 2 trailing bytes that wouldn't then be
  // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
  // trailing bytes.
  // However, we still need to avoid the last chunk (partial or complete) because it could
  // have padding, so we always do 1 fewer to avoid the last chunk.
  for _ in 1..remaining_chunks {
    decode_chunk_precise(
      &input[input_index..],
      input_index,
      decode_table,
      &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
    )?;

    input_index += INPUT_CHUNK_LEN;
    output_index += DECODED_CHUNK_LEN;
  }

  // always have one more (possibly partial) block of 8 input
  debug_assert!(input.len() - input_index > 1 || input.is_empty());
  debug_assert!(input.len() - input_index <= 8);

  // Stage 4
  // Finally, decode any leftovers that aren't a complete input block of 8 bytes.
  // Use a u64 as a stack-resident 8 byte buffer.
  let mut leftover_bits: u64 = 0;
  let mut morsels_in_leftover = 0;
  let mut padding_bytes = 0;
  // Offset of the first '=' relative to `start_of_leftovers`. Only meaningful once
  // `padding_bytes > 0`.
  let mut first_padding_index: usize = 0;
  // Most recent non-padding symbol, reported if InvalidLastSymbol is returned below.
  let mut last_symbol = 0_u8;
  let start_of_leftovers = input_index;
  for (i, b) in input[start_of_leftovers..].iter().enumerate() {
    // '=' padding
    if *b == PAD_BYTE {
      // There can be bad padding in a few ways:
      // 1 - Padding with non-padding characters after it
      // 2 - Padding after zero or one non-padding characters before it
      //     in the current quad.
      // 3 - More than two characters of padding. If 3 or 4 padding chars
      //     are in the same quad, that implies it will be caught by #2.
      //     If it spreads from one quad to another, it will be caught by
      //     #2 in the second quad.

      if i % 4 < 2 {
        // Check for case #2.
        let bad_padding_index = start_of_leftovers
          + if padding_bytes > 0 {
            // If we've already seen padding, report the first padding index.
            // This is to be consistent with the faster logic above: it will report an
            // error on the first padding character (since it doesn't expect to see
            // anything but actual encoded data).
            first_padding_index
          } else {
            // haven't seen padding before, just use where we are now
            i
          };
        return Err(DecodeError::InvalidByte(bad_padding_index, *b));
      }

      if padding_bytes == 0 {
        first_padding_index = i;
      }

      padding_bytes += 1;
      continue;
    }

    // Check for case #1.
    // To make '=' handling consistent with the main loop, don't allow
    // non-suffix '=' in trailing chunk either. Report error as first
    // erroneous padding.
    if padding_bytes > 0 {
      return Err(DecodeError::InvalidByte(
        start_of_leftovers + first_padding_index,
        PAD_BYTE,
      ));
    }
    last_symbol = *b;

    // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
    // To minimize shifts, pack the leftovers from left to right.
    let shift = 64 - (morsels_in_leftover + 1) * 6;
    // tables are all 256 elements, lookup with a u8 index always succeeds
    let morsel = decode_table[*b as usize];
    if morsel == tables::INVALID_VALUE {
      return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
    }

    leftover_bits |= (morsel as u64) << shift;
    morsels_in_leftover += 1;
  }

  // Whole bytes recoverable from the collected morsels: floor(morsels * 6 / 8) * 8 bits.
  // 1 and 5 morsels cannot occur here. Lengths 1 and 5 were rejected up front, and padding
  // after fewer than two symbols in a quad is rejected in the loop above.
  let leftover_bits_ready_to_append = match morsels_in_leftover {
    0 => 0,
    2 => 8,
    3 => 16,
    4 => 24,
    6 => 32,
    7 => 40,
    8 => 48,
    _ => unreachable!(
      "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
    ),
  };

  // if there are bits set outside the bits we care about, last symbol encodes trailing bits that
  // will not be included in the output
  let mask = !0 >> leftover_bits_ready_to_append;
  if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
    // last morsel is at `morsels_in_leftover` - 1
    return Err(DecodeError::InvalidLastSymbol(
      start_of_leftovers + morsels_in_leftover - 1,
      last_symbol,
    ));
  }

  // Emit the ready bits one byte at a time, most significant byte first.
  let mut leftover_bits_appended_to_buf = 0;
  while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
    // `as` simply truncates the higher bits, which is what we want here
    let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
    output[output_index] = selected_bits;
    output_index += 1;

    leftover_bits_appended_to_buf += 8;
  }

  Ok(output_index)
}

344/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
345/// first 6 of those contain meaningful data.
346///
347/// `input` is the bytes to decode, of which the first 8 bytes will be processed.
348/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
349/// accurately)
350/// `decode_table` is the lookup table for the particular base64 alphabet.
351/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
352/// data.
353// yes, really inline (worth 30-50% speedup)
354#[inline(always)]
355fn decode_chunk(
356  input: &[u8],
357  index_at_start_of_input: usize,
358  decode_table: &[u8; 256],
359  output: &mut [u8],
360) -> Result<(), DecodeError> {
361  let mut accum: u64;
362
363  let morsel = decode_table[input[0] as usize];
364  if morsel == tables::INVALID_VALUE {
365    return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
366  }
367  accum = (morsel as u64) << 58;
368
369  let morsel = decode_table[input[1] as usize];
370  if morsel == tables::INVALID_VALUE {
371    return Err(DecodeError::InvalidByte(
372      index_at_start_of_input + 1,
373      input[1],
374    ));
375  }
376  accum |= (morsel as u64) << 52;
377
378  let morsel = decode_table[input[2] as usize];
379  if morsel == tables::INVALID_VALUE {
380    return Err(DecodeError::InvalidByte(
381      index_at_start_of_input + 2,
382      input[2],
383    ));
384  }
385  accum |= (morsel as u64) << 46;
386
387  let morsel = decode_table[input[3] as usize];
388  if morsel == tables::INVALID_VALUE {
389    return Err(DecodeError::InvalidByte(
390      index_at_start_of_input + 3,
391      input[3],
392    ));
393  }
394  accum |= (morsel as u64) << 40;
395
396  let morsel = decode_table[input[4] as usize];
397  if morsel == tables::INVALID_VALUE {
398    return Err(DecodeError::InvalidByte(
399      index_at_start_of_input + 4,
400      input[4],
401    ));
402  }
403  accum |= (morsel as u64) << 34;
404
405  let morsel = decode_table[input[5] as usize];
406  if morsel == tables::INVALID_VALUE {
407    return Err(DecodeError::InvalidByte(
408      index_at_start_of_input + 5,
409      input[5],
410    ));
411  }
412  accum |= (morsel as u64) << 28;
413
414  let morsel = decode_table[input[6] as usize];
415  if morsel == tables::INVALID_VALUE {
416    return Err(DecodeError::InvalidByte(
417      index_at_start_of_input + 6,
418      input[6],
419    ));
420  }
421  accum |= (morsel as u64) << 22;
422
423  let morsel = decode_table[input[7] as usize];
424  if morsel == tables::INVALID_VALUE {
425    return Err(DecodeError::InvalidByte(
426      index_at_start_of_input + 7,
427      input[7],
428    ));
429  }
430  accum |= (morsel as u64) << 16;
431
432  write_u64(output, accum);
433
434  Ok(())
435}
436
/// Stores `value` into `output[0..8]` in big-endian order (most significant byte first).
///
/// # Panics
/// Panics if `output` is shorter than 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
  let bytes = value.to_be_bytes();
  let destination = &mut output[..bytes.len()];
  destination.copy_from_slice(&bytes);
}

442/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
443/// trailing garbage bytes.
444#[inline]
445fn decode_chunk_precise(
446  input: &[u8],
447  index_at_start_of_input: usize,
448  decode_table: &[u8; 256],
449  output: &mut [u8],
450) -> Result<(), DecodeError> {
451  let mut tmp_buf = [0_u8; 8];
452
453  decode_chunk(
454    input,
455    index_at_start_of_input,
456    decode_table,
457    &mut tmp_buf[..],
458  )?;
459
460  output[0..6].copy_from_slice(&tmp_buf[0..6]);
461
462  Ok(())
463}