base64-ng 1.2.0

no_std-first Base64 encoding and decoding with strict APIs and a security-heavy release process
Documentation
use super::equality::{ct_accumulate_u8, ct_mask_eq_u8, ct_mask_nonzero_u8, report_ct_error};
use super::padded::{ct_decode_padded, ct_decode_padded_in_place, ct_padded_final_quantum};
use super::unpadded::{ct_decode_unpadded, ct_decode_unpadded_in_place};
use crate::{Alphabet, DecodeError, read_quad, wipe_bytes, wipe_tail};

/// Decodes `input` into `output` with the constant-time decoder selected by `PAD`.
///
/// An empty `input` yields `Ok(0)` without calling either decoder. Any other
/// input is handed to `ct_decode_padded` (when `PAD` is true) or to
/// `ct_decode_unpadded`, and that decoder's result is returned unchanged.
pub(super) fn ct_decode_slice<A: Alphabet, const PAD: bool>(
    input: &[u8],
    output: &mut [u8],
) -> Result<usize, DecodeError> {
    match (input.is_empty(), PAD) {
        (true, _) => Ok(0),
        (false, true) => ct_decode_padded::<A>(input, output),
        (false, false) => ct_decode_unpadded::<A>(input, output),
    }
}

/// Decodes `input` into a caller-supplied `staging` buffer, then copies the
/// result into `output`.
///
/// On success:
/// - the decoded bytes are copied to the front of `output`;
/// - all of `staging` is wiped with `wipe_bytes`;
/// - `wipe_tail(output, written)` is applied, presumably clearing `output[written..]`;
/// - the decoded byte count is returned.
///
/// On any failure, both `output` and `staging` are wiped in full before the
/// error is returned. Decoded data is therefore never left in either buffer.
///
/// # Errors
///
/// - Any validation error from `ct_decoded_len`.
/// - `DecodeError::OutputTooSmall` when `output` is shorter than the decoded length.
/// - `DecodeError::StagingTooSmall` when `staging` is shorter than the decoded length.
/// - Any error from the constant-time decoder itself.
pub(super) fn ct_decode_slice_staged_clear_tail<A: Alphabet, const PAD: bool>(
    input: &[u8],
    output: &mut [u8],
    staging: &mut [u8],
) -> Result<usize, DecodeError> {
    // Every failure path clears both buffers before surfacing the error.
    fn fail(
        output: &mut [u8],
        staging: &mut [u8],
        err: DecodeError,
    ) -> Result<usize, DecodeError> {
        wipe_bytes(output);
        wipe_bytes(staging);
        Err(err)
    }

    let required = match ct_decoded_len::<A, PAD>(input) {
        Ok(len) => len,
        Err(err) => return fail(output, staging, err),
    };

    let output_capacity = output.len();
    if output_capacity < required {
        let err = DecodeError::OutputTooSmall {
            required,
            available: output_capacity,
        };
        return fail(output, staging, err);
    }

    let staging_capacity = staging.len();
    if staging_capacity < required {
        let err = DecodeError::StagingTooSmall {
            required,
            available: staging_capacity,
        };
        return fail(output, staging, err);
    }

    // Decode into scratch space first; `output` is written only after the
    // decoder has succeeded.
    let written = match ct_decode_slice::<A, PAD>(input, &mut staging[..required]) {
        Ok(count) => count,
        Err(err) => return fail(output, staging, err),
    };

    let decoded = &staging[..written];
    output[..written].copy_from_slice(decoded);
    wipe_bytes(staging);
    wipe_tail(output, written);
    Ok(written)
}

/// Decodes `buffer` in place with the constant-time decoder selected by `PAD`.
///
/// An empty buffer yields `Ok(0)` without calling either decoder. Otherwise
/// the result of `ct_decode_padded_in_place` (when `PAD` is true) or
/// `ct_decode_unpadded_in_place` is returned unchanged.
pub(super) fn ct_decode_in_place<A: Alphabet, const PAD: bool>(
    buffer: &mut [u8],
) -> Result<usize, DecodeError> {
    if buffer.is_empty() {
        Ok(0)
    } else if PAD {
        ct_decode_padded_in_place::<A>(buffer)
    } else {
        ct_decode_unpadded_in_place::<A>(buffer)
    }
}

/// Checks that `input` is well-formed Base64 for alphabet `A` without decoding it.
///
/// Empty input is always accepted. Non-empty input is checked by
/// `ct_validate_padded` when `PAD` is true, otherwise by `ct_validate_unpadded`.
pub(super) fn ct_validate_decode<A: Alphabet, const PAD: bool>(
    input: &[u8],
) -> Result<(), DecodeError> {
    if input.is_empty() {
        Ok(())
    } else if PAD {
        ct_validate_padded::<A>(input)
    } else {
        ct_validate_unpadded::<A>(input)
    }
}

/// Returns the exact number of bytes `input` decodes to, after validating it.
///
/// Validation runs first, so the length arithmetic only ever sees input that
/// passed `ct_validate_decode`.
///
/// # Errors
///
/// - Any error from `ct_validate_decode`.
/// - `DecodeError::InvalidLength` for an unpadded length that leaves a single
///   trailing symbol (defensive; validation already rejects this case).
pub(super) fn ct_decoded_len<A: Alphabet, const PAD: bool>(
    input: &[u8],
) -> Result<usize, DecodeError> {
    ct_validate_decode::<A, PAD>(input)?;

    let len = input.len();
    if len == 0 {
        return Ok(0);
    }

    // Every complete 4-symbol quad carries 3 decoded bytes.
    let quad_bytes = (len / 4) * 3;

    if PAD {
        // Validated padded input is a whole number of quads; each trailing
        // '=' counted by `ct_padding_len` removes one byte.
        return Ok(quad_bytes - ct_padding_len(input));
    }

    // Unpadded tails: 2 symbols -> 1 byte, 3 symbols -> 2 bytes.
    let tail_bytes = match len % 4 {
        0 => 0,
        2 => 1,
        3 => 2,
        _ => return Err(DecodeError::InvalidLength),
    };
    Ok(quad_bytes + tail_bytes)
}

/// Validates padded Base64 `input` for alphabet `A`, scanning the symbols in
/// constant time.
///
/// A length that is not a multiple of four is rejected immediately. This early
/// return means length is treated as public, not secret.
///
/// Every quad except the last gets identical per-byte work. Problems are
/// OR-ed into the `invalid_byte` / `invalid_padding` masks rather than
/// returned early, so the scan does not branch on byte contents. The last quad
/// is handed to `ct_padded_final_quantum` together with the trailing-`=` count
/// from `ct_padding_len`. `report_ct_error` turns the accumulated masks into
/// the final result.
///
/// NOTE(review): only reached via `ct_validate_decode`, which filters out empty
/// input. An empty slice would make the final read fail and hit the
/// `debug_assert!` in `read_quad_or_mark_invalid`.
fn ct_validate_padded<A: Alphabet>(input: &[u8]) -> Result<(), DecodeError> {
    if !input.len().is_multiple_of(4) {
        return Err(DecodeError::InvalidLength);
    }

    let padding = ct_padding_len(input);
    let mut invalid_byte = 0u8;
    let mut invalid_padding = 0u8;
    let mut read = 0;

    // Strict `<` stops before the final quad, which is checked separately below.
    while read + 4 < input.len() {
        let [b0, b1, b2, b3] =
            read_quad_or_mark_invalid(input, read, &mut invalid_byte, &mut invalid_padding);
        let (_, valid0) = ct_decode_alphabet_byte::<A>(b0);
        let (_, valid1) = ct_decode_alphabet_byte::<A>(b1);
        let (_, valid2) = ct_decode_alphabet_byte::<A>(b2);
        let (_, valid3) = ct_decode_alphabet_byte::<A>(b3);

        invalid_byte |= !valid0;
        invalid_byte |= !valid1;
        invalid_byte |= !valid2;
        invalid_byte |= !valid3;
        // '=' may only appear in the final quad. Here only positions 2 and 3
        // are flagged as padding errors; a '=' at position 0 or 1 is
        // presumably caught via `invalid_byte` (assuming '=' is not in
        // `A::ENCODE`; TODO confirm).
        invalid_padding |= ct_mask_eq_u8(b2, b'=');
        invalid_padding |= ct_mask_eq_u8(b3, b'=');
        read += 4;
    }

    // The last quad may legitimately carry padding; its rules live in
    // `ct_padded_final_quantum`.
    let final_chunk =
        read_quad_or_mark_invalid(input, read, &mut invalid_byte, &mut invalid_padding);
    let (_, final_invalid_byte, final_invalid_padding, _) =
        ct_padded_final_quantum::<A>(final_chunk, padding);
    invalid_byte |= final_invalid_byte;
    invalid_padding |= final_invalid_padding;

    report_ct_error(invalid_byte, invalid_padding)
}

/// Validates unpadded Base64 `input` for alphabet `A`, scanning the symbols in
/// constant time.
///
/// A length of `4k + 1` is rejected immediately: a single trailing symbol
/// (6 bits) cannot encode a whole byte.
///
/// All complete quads, then the 0-, 2- or 3-symbol tail, get the same
/// mask-accumulating checks:
/// - every symbol must be in the alphabet;
/// - no symbol may be `=`, since padding is forbidden here;
/// - the unused low bits of the last tail symbol must be zero, which rejects
///   non-canonical encodings.
///
/// `report_ct_error` converts the masks into the result.
fn ct_validate_unpadded<A: Alphabet>(input: &[u8]) -> Result<(), DecodeError> {
    if input.len() % 4 == 1 {
        return Err(DecodeError::InvalidLength);
    }

    let mut invalid_byte = 0u8;
    let mut invalid_padding = 0u8;
    let mut read = 0;

    // `<=` covers every complete quad; the remainder is handled by the tail match.
    while read + 4 <= input.len() {
        let [b0, b1, b2, b3] =
            read_quad_or_mark_invalid(input, read, &mut invalid_byte, &mut invalid_padding);
        let (_, valid0) = ct_decode_alphabet_byte::<A>(b0);
        let (_, valid1) = ct_decode_alphabet_byte::<A>(b1);
        let (_, valid2) = ct_decode_alphabet_byte::<A>(b2);
        let (_, valid3) = ct_decode_alphabet_byte::<A>(b3);

        invalid_byte |= !valid0;
        invalid_byte |= !valid1;
        invalid_byte |= !valid2;
        invalid_byte |= !valid3;
        // Any '=' anywhere is a padding error in unpadded mode.
        invalid_padding |= ct_mask_eq_u8(b0, b'=');
        invalid_padding |= ct_mask_eq_u8(b1, b'=');
        invalid_padding |= ct_mask_eq_u8(b2, b'=');
        invalid_padding |= ct_mask_eq_u8(b3, b'=');

        read += 4;
    }

    match read_tail_or_mark_invalid(input, read, &mut invalid_byte, &mut invalid_padding) {
        [] => {}
        [b0, b1] => {
            let (_, valid0) = ct_decode_alphabet_byte::<A>(*b0);
            let (v1, valid1) = ct_decode_alphabet_byte::<A>(*b1);
            invalid_byte |= !valid0;
            invalid_byte |= !valid1;
            invalid_padding |= ct_mask_eq_u8(*b0, b'=');
            invalid_padding |= ct_mask_eq_u8(*b1, b'=');
            // 2 symbols = 12 bits, of which 8 form one byte; the low 4 bits of
            // the second symbol are unused and must be zero.
            invalid_padding |= ct_mask_nonzero_u8(v1 & 0b0000_1111);
        }
        [b0, b1, b2] => {
            let (_, valid0) = ct_decode_alphabet_byte::<A>(*b0);
            let (_, valid1) = ct_decode_alphabet_byte::<A>(*b1);
            let (v2, valid2) = ct_decode_alphabet_byte::<A>(*b2);
            invalid_byte |= !valid0;
            invalid_byte |= !valid1;
            invalid_byte |= !valid2;
            invalid_padding |= ct_mask_eq_u8(*b0, b'=');
            invalid_padding |= ct_mask_eq_u8(*b1, b'=');
            invalid_padding |= ct_mask_eq_u8(*b2, b'=');
            // 3 symbols = 18 bits, of which 16 form two bytes; the low 2 bits
            // of the third symbol are unused and must be zero.
            invalid_padding |= ct_mask_nonzero_u8(v2 & 0b0000_0011);
        }
        _ => {
            // Unreachable given the `% 4 == 1` check above; fail closed anyway.
            invalid_byte = 0xff;
            invalid_padding = 0xff;
        }
    }

    report_ct_error(invalid_byte, invalid_padding)
}

/// Returns the suffix of `input` that starts at `offset`.
///
/// `offset == input.len()` yields an empty slice.
///
/// # Errors
///
/// `DecodeError::InvalidLength` when `offset` is greater than `input.len()`.
fn read_tail(input: &[u8], offset: usize) -> Result<&[u8], DecodeError> {
    match input.get(offset..) {
        Some(tail) => Ok(tail),
        None => Err(DecodeError::InvalidLength),
    }
}

/// Reads the four bytes of `input` starting at `offset`.
///
/// The validators in this file call this only after checking the input
/// length, so a failed read indicates an internal bug.
/// - In debug builds it trips a `debug_assert!`.
/// - In release builds it sets both `invalid_byte` and `invalid_padding` to
///   `0xff` and returns an all-zero quad, so the caller's error masks record
///   the failure instead of silently passing.
pub(super) fn read_quad_or_mark_invalid(
    input: &[u8],
    offset: usize,
    invalid_byte: &mut u8,
    invalid_padding: &mut u8,
) -> [u8; 4] {
    match read_quad(input, offset) {
        Ok(quad) => quad,
        Err(_) => {
            debug_assert!(
                false,
                "read_quad failed inside length-validated constant-time decode loop"
            );
            *invalid_byte = 0xff;
            *invalid_padding = 0xff;
            [0; 4]
        }
    }
}

/// Returns the suffix of `input` starting at `offset`.
///
/// Callers have already checked the length, so an out-of-range `offset` is an
/// internal bug.
/// - In debug builds it trips a `debug_assert!`.
/// - In release builds it sets both error masks to `0xff` and returns an empty
///   slice, so the failure is still recorded.
pub(super) fn read_tail_or_mark_invalid<'a>(
    input: &'a [u8],
    offset: usize,
    invalid_byte: &mut u8,
    invalid_padding: &mut u8,
) -> &'a [u8] {
    match read_tail(input, offset) {
        Ok(tail) => tail,
        Err(_) => {
            debug_assert!(
                false,
                "read_tail failed inside length-validated constant-time decode loop"
            );
            *invalid_byte = 0xff;
            *invalid_padding = 0xff;
            &[]
        }
    }
}

/// Maps one encoded symbol back to its 6-bit value for alphabet `A`, comparing
/// against every alphabet entry regardless of where `byte` matches.
///
/// Returns `(decoded, valid)`.
/// - The scan always compares `byte` against all 64 entries of `A::ENCODE`,
///   with no early exit. The work done therefore does not depend on which
///   entry, if any, matches.
/// - `valid` accumulates the per-candidate equality masks. Callers negate it to
///   obtain an "invalid" mask, which suggests it is `0xff` for an alphabet
///   symbol and `0x00` otherwise (TODO confirm against `ct_mask_eq_u8` /
///   `ct_accumulate_u8`).
/// - `decoded` is only meaningful when `valid` is set.
///
/// `#[inline(never)]` and the `black_box` calls are there to discourage the
/// optimizer from turning the scan into a data-dependent table lookup or an
/// early exit. `black_box` is a hint, not a guarantee.
#[inline(never)]
pub(super) fn ct_decode_alphabet_byte<A: Alphabet>(byte: u8) -> (u8, u8) {
    let mut decoded = 0u8;
    let mut valid = 0u8;
    let mut candidate = 0u8;

    while candidate < 64 {
        let matches = core::hint::black_box(ct_mask_eq_u8(
            core::hint::black_box(byte),
            core::hint::black_box(A::ENCODE[candidate as usize]),
        ));
        decoded = ct_accumulate_u8(decoded, candidate & matches);
        valid = ct_accumulate_u8(valid, matches);
        candidate += 1;
    }

    (decoded, valid)
}

/// Counts how many of the last two bytes of `input` are `=` (result 0, 1 or 2).
///
/// Inputs shorter than two bytes return 0. Each position is tested with
/// `ct_mask_eq_u8` rather than `==`. The two positions are counted
/// independently; checking that the padding is well-formed is left to the
/// validators.
pub(super) fn ct_padding_len(input: &[u8]) -> usize {
    let [.., before_last, last] = input else {
        return 0;
    };
    let is_pad = |byte: u8| usize::from(ct_mask_eq_u8(byte, b'=') & 1);
    is_pad(*last) + is_pad(*before_last)
}