base64 0.22.1

encodes and decodes base64 as bytes or utf8
Documentation
use crate::{
    engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
    DecodeError, DecodeSliceError, PAD_BYTE,
};

/// Decoded-length estimate derived solely from an encoded input length.
///
/// Constructed via `GeneralPurposeEstimate::new(encoded_len)`; the input bytes themselves are
/// never inspected.
#[doc(hidden)]
pub struct GeneralPurposeEstimate {
    /// input len % 4
    rem: usize,
    /// 3 bytes for every started 4-byte quad of input, i.e. `ceil(encoded_len / 4) * 3`; never
    /// smaller than the actual decoded length.
    conservative_decoded_len: usize,
}

impl GeneralPurposeEstimate {
    /// Builds the estimate for an encoded input that is `encoded_len` bytes long.
    ///
    /// Each quad of input, including a trailing partial one, is budgeted 3 output bytes.
    pub(crate) fn new(encoded_len: usize) -> Self {
        let rem = encoded_len % 4;
        // A trailing partial quad still counts as a whole quad.
        let started_quads = encoded_len / 4 + usize::from(rem != 0);
        Self {
            conservative_decoded_len: started_quads * 3,
            rem,
        }
    }
}

impl DecodeEstimate for GeneralPurposeEstimate {
    /// Reports the precomputed, rounded-up decoded length (3 bytes per started quad).
    fn decoded_len_estimate(&self) -> usize {
        let Self {
            conservative_decoded_len,
            ..
        } = *self;
        conservative_decoded_len
    }
}

/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the decode metadata, or an error.
///
/// Decoding proceeds in three stages:
/// 1. complete quads other than the final one, 32 input bytes (8 quads) per iteration;
/// 2. the complete non-final quads left over from stage 1, one quad (4 bytes) at a time;
/// 3. the final quad, which may be partial or padded, via `decode_suffix`.
///
/// - `input`: the base64 input.
/// - `estimate`: only its `rem` field (`input.len() % 4`) is read here.
/// - `output`: receives the decoded bytes, starting at index 0.
/// - `decode_table`: maps each input byte to its 6-bit value, or to `INVALID_VALUE`.
/// - `decode_allow_trailing_bits`, `padding_mode`: forwarded unchanged to `decode_suffix`.
///
/// # Errors
///
/// Propagates errors from `complete_quads_len` (output too small for the non-final quads, or an
/// invalid lone trailing byte), an `InvalidByte` error for the first invalid symbol in a
/// non-final quad, and any error `decode_suffix` reports.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeSliceError> {
    // Length of the input prefix made of complete quads, excluding the final quad (left for
    // decode_suffix). This call also checks that `output` can hold those quads' decoded bytes.
    let input_complete_nonterminal_quads_len =
        complete_quads_len(input, estimate.rem, output.len(), decode_table)?;

    const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
    const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;

    // Split that prefix into a whole number of 32-byte chunks (stage 1) and the leftover
    // quads (stage 2).
    let input_complete_quads_after_unrolled_chunks_len =
        input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE;

    let input_unrolled_loop_len =
        input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len;

    // chunks of 32 bytes
    for (chunk_index, chunk) in input[..input_unrolled_loop_len]
        .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE)
        .enumerate()
    {
        // Absolute offset of this chunk within `input`, so errors report the right position.
        let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE;
        let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE
            ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE];

        // Four 8-byte steps, each turning 8 input symbols into 6 output bytes.
        decode_chunk_8(
            &chunk[0..8],
            input_index,
            decode_table,
            &mut chunk_output[0..6],
        )?;
        decode_chunk_8(
            &chunk[8..16],
            input_index + 8,
            decode_table,
            &mut chunk_output[6..12],
        )?;
        decode_chunk_8(
            &chunk[16..24],
            input_index + 16,
            decode_table,
            &mut chunk_output[12..18],
        )?;
        decode_chunk_8(
            &chunk[24..32],
            input_index + 24,
            decode_table,
            &mut chunk_output[18..24],
        )?;
    }

    // remaining quads, except for the last possibly partial one, as it may have padding
    // Output offsets where stage 1 ends and where stage 2 ends (= where stage 3 starts writing).
    let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3;
    let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3;
    {
        let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len];

        for (chunk_index, chunk) in input
            [input_unrolled_loop_len..input_complete_nonterminal_quads_len]
            .chunks_exact(4)
            .enumerate()
        {
            let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3];

            decode_chunk_4(
                chunk,
                input_unrolled_loop_len + chunk_index * 4,
                decode_table,
                chunk_output,
            )?;
        }
    }

    // Stage 3: the final quad, decoded into `output` right after the non-final quads' bytes.
    super::decode_suffix::decode_suffix(
        input,
        input_complete_nonterminal_quads_len,
        output,
        output_complete_quad_len,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}

/// Returns the length of the input prefix made of complete quads, leaving out the final quad
/// even when that quad is complete (it may carry padding, so the suffix decoder handles it).
///
/// - `input` is the base64 input
/// - `input_len_rem` must equal `input.len() % 4`
/// - `output_len` is the length of the output slice
/// - `decode_table` maps a byte to its 6-bit value, or to `INVALID_VALUE`
///
/// # Errors
///
/// - `OutputSliceTooSmall` if `output_len` cannot hold the bytes those quads decode to.
/// - `InvalidByte` if `input.len() % 4 == 1` and that lone last byte is neither a symbol of the
///   alphabet nor a pad byte.
pub(crate) fn complete_quads_len(
    input: &[u8],
    input_len_rem: usize,
    output_len: usize,
    decode_table: &[u8; 256],
) -> Result<usize, DecodeSliceError> {
    debug_assert!(input.len() % 4 == input_len_rem);

    // A lone leftover byte is checked eagerly so that stray trailing junk (e.g. a newline) is
    // reported at its own position. Pad bytes are exempt: they may continue padding that began
    // earlier in the input.
    if input_len_rem == 1 {
        let last_index = input.len() - 1;
        let last_byte = input[last_index];
        let is_pad = last_byte == PAD_BYTE;
        let is_symbol = decode_table[usize::from(last_byte)] != INVALID_VALUE;
        if !is_pad && !is_symbol {
            return Err(DecodeError::InvalidByte(last_index, last_byte).into());
        }
    }

    // Hold back the final quad: the partial tail when there is one, otherwise a whole
    // (possibly padded) quad.
    let held_back = if input_len_rem == 0 { 4 } else { input_len_rem };
    let nonterminal_quads_len = input.len().saturating_sub(held_back);
    debug_assert!(input.is_empty() || (1..=4).contains(&(input.len() - nonterminal_quads_len)));

    // Only the output of the non-final quads is checked here; the final quad is decode_suffix's.
    let required_output_len = nonterminal_quads_len / 4 * 3;
    if required_output_len > output_len {
        return Err(DecodeSliceError::OutputSliceTooSmall);
    }

    Ok(nonterminal_quads_len)
}

/// Decodes exactly 8 base64 symbols into 6 bytes.
///
/// - `input`: the symbols to decode; only `input[0..8]` is read.
/// - `index_at_start_of_input`: absolute offset of `input[0]` within the whole input, so that an
///   invalid-byte error points at the right position.
/// - `decode_table`: per-alphabet table mapping a byte to its 6-bit value, or to `INVALID_VALUE`.
/// - `output`: receives the 6 decoded bytes in `output[..6]`; anything after that is untouched.
///
/// Symbols are checked left to right, so the first invalid one is reported; `output` is not
/// written at all in that case.
// Forced inlining is deliberate: it is worth a 30-50% speedup.
#[inline(always)]
fn decode_chunk_8(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    // Look up the symbol at `offset`, widened to u64 so it can be shifted into place.
    let sextet = |offset: usize| -> Result<u64, DecodeError> {
        let byte = input[offset];
        let value = decode_table[usize::from(byte)];
        if value == INVALID_VALUE {
            Err(DecodeError::InvalidByte(index_at_start_of_input + offset, byte))
        } else {
            Ok(u64::from(value))
        }
    };

    // Pack the eight 6-bit values into the top 48 bits, most significant first. Operands are
    // evaluated left to right, which keeps the "first invalid symbol wins" reporting order.
    let accum = (sextet(0)? << 58)
        | (sextet(1)? << 52)
        | (sextet(2)? << 46)
        | (sextet(3)? << 40)
        | (sextet(4)? << 34)
        | (sextet(5)? << 28)
        | (sextet(6)? << 22)
        | (sextet(7)? << 16);

    output[..6].copy_from_slice(&accum.to_be_bytes()[..6]);
    Ok(())
}

/// Decodes exactly 4 base64 symbols into 3 bytes; the single-quad sibling of [decode_chunk_8].
///
/// Reads `input[0..4]` and writes only `output[..3]`. Symbols are checked left to right, and the
/// first invalid one is reported at `index_at_start_of_input` plus its offset within `input`;
/// `output` is left untouched in that case.
#[inline(always)]
fn decode_chunk_4(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    // Look up the symbol at `offset`, widened to u32 so it can be shifted into place.
    let sextet = |offset: usize| -> Result<u32, DecodeError> {
        let byte = input[offset];
        let value = decode_table[usize::from(byte)];
        if value == INVALID_VALUE {
            Err(DecodeError::InvalidByte(index_at_start_of_input + offset, byte))
        } else {
            Ok(u32::from(value))
        }
    };

    // Pack the four 6-bit values into the top 24 bits, most significant first.
    let accum = (sextet(0)? << 26) | (sextet(1)? << 20) | (sextet(2)? << 14) | (sextet(3)? << 8);

    output[..3].copy_from_slice(&accum.to_be_bytes()[..3]);
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    #[test]
    fn decode_chunk_8_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], output);
    }

    #[test]
    fn decode_chunk_4_writes_only_3_bytes() {
        let input = b"Zm9v"; // "foo"
        let mut output = [0_u8, 1, 2, 3];

        decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', 3], output);
    }

    #[test]
    fn decode_chunk_8_reports_absolute_index_of_first_invalid_byte() {
        let mut output = [0_u8; 6];

        assert_eq!(
            Err(DecodeError::InvalidByte(14, b'*')),
            decode_chunk_8(b"Zm9v*mF*", 10, &STANDARD.decode_table, &mut output)
        );
        // nothing is written when decoding fails
        assert_eq!([0_u8; 6], output);
    }

    #[test]
    fn decode_chunk_4_reports_absolute_index_of_first_invalid_byte() {
        let mut output = [0_u8; 3];

        assert_eq!(
            Err(DecodeError::InvalidByte(6, b'*')),
            decode_chunk_4(b"Zm*v", 4, &STANDARD.decode_table, &mut output)
        );
        // nothing is written when decoding fails
        assert_eq!([0_u8; 3], output);
    }

    #[test]
    fn complete_quads_len_holds_back_final_quad() {
        let table = &STANDARD.decode_table;

        // the final quad is held back even when it is complete
        assert_eq!(4, complete_quads_len(b"Zm9vYmFy", 0, 3, table).unwrap());
        // a partial final quad is held back too
        assert_eq!(4, complete_quads_len(b"Zm9vYm", 2, 3, table).unwrap());
        // output must fit the decoded non-final quads
        assert!(matches!(
            complete_quads_len(b"Zm9vYmFy", 0, 2, table),
            Err(DecodeSliceError::OutputSliceTooSmall)
        ));
    }

    #[test]
    fn estimate_short_lengths() {
        for (range, decoded_len_estimate) in [
            (0..=0, 0),
            (1..=4, 3),
            (5..=8, 6),
            (9..=12, 9),
            (13..=16, 12),
            (17..=20, 15),
        ] {
            for encoded_len in range {
                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate());
            }
        }
    }

    #[test]
    fn estimate_via_u128_inflation() {
        // cover both ends of usize
        (0..1000)
            .chain(usize::MAX - 1000..=usize::MAX)
            .for_each(|encoded_len| {
                // inflate to 128 bit type to be able to safely use the easy formulas
                let len_128 = encoded_len as u128;

                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(
                    (len_128 + 3) / 4 * 3,
                    estimate.conservative_decoded_len as u128
                );
            })
    }
}