base64-ng 1.1.0

no_std-first Base64 encoding and decoding with strict APIs and a security-heavy release process
Documentation
use super::*;

/// Writes a reproducible pseudo-random byte sequence into `output`.
///
/// The byte at position `i` becomes `(i * 73 + seed * 19) % 256`, so changing
/// `seed` offsets every byte by `seed * 19` (mod 256).
fn fill_pattern(output: &mut [u8], seed: usize) {
    output
        .iter_mut()
        .zip(0usize..)
        .for_each(|(slot, position)| {
            let wrapped = (position * 73 + seed * 19) % 256;
            *slot = u8::try_from(wrapped).unwrap();
        });
}

/// Asserts that the dispatched encoder agrees with the scalar reference for `input`.
///
/// Two calls are compared:
/// * into a 256-byte buffer — results must be equal and, on success, the
///   written prefixes byte-identical;
/// * into a buffer one byte shorter than `checked_encoded_len` reports — both
///   paths must reject it, with equal errors.
///
/// The dispatched and scalar buffers are pre-filled with different bytes
/// (`0x55` / `0xaa`), so a prefix comparison cannot pass merely because
/// neither path wrote anything.
///
/// # Panics
///
/// Panics on any mismatch, if the undersized encode succeeds, or if
/// `checked_encoded_len` fails. The short-buffer slice also panics if the
/// encoded length exceeds 257 bytes; callers in this file pass at most 128
/// input bytes.
fn assert_encode_backend_matches_scalar<A, const PAD: bool>(input: &[u8])
where
    A: Alphabet,
{
    let engine = Engine::<A, PAD>::new();
    let mut dispatched = [0x55; 256];
    let mut scalar = [0xaa; 256];

    let dispatched_result = engine.encode_slice(input, &mut dispatched);
    let scalar_result = scalar::scalar_reference_encode_slice::<A, PAD>(input, &mut scalar);

    assert_eq!(dispatched_result, scalar_result);
    if let Ok(written) = dispatched_result {
        assert_eq!(&dispatched[..written], &scalar[..written]);
    }

    let required = checked_encoded_len(input.len(), PAD).unwrap();
    if required > 0 {
        let mut dispatched_short = [0x55; 256];
        let mut scalar_short = [0xaa; 256];
        let available = required - 1;

        let dispatched_short_result =
            engine.encode_slice(input, &mut dispatched_short[..available]);
        let scalar_short_result =
            scalar::scalar_reference_encode_slice::<A, PAD>(input, &mut scalar_short[..available]);

        assert_eq!(dispatched_short_result, scalar_short_result);
        // Equality alone would accept both paths silently truncating the
        // output; an undersized buffer must be rejected outright.
        assert!(
            dispatched_short_result.is_err(),
            "encoding into {available} of {required} required bytes must fail"
        );
    }
}

/// Asserts that the dispatched decoder agrees with the scalar reference for `input`.
///
/// `input` may be malformed; then only the two results (errors) are compared.
/// When decoding succeeds with a non-empty result, the call is repeated with
/// an output buffer one byte shorter than the decoded length, which both paths
/// must reject with equal errors.
///
/// The dispatched and scalar buffers are pre-filled with different bytes
/// (`0x55` / `0xaa`), so a prefix comparison cannot pass merely because
/// neither path wrote anything.
///
/// # Panics
///
/// Panics on any mismatch or if decoding into the undersized buffer succeeds.
fn assert_decode_backend_matches_scalar<A, const PAD: bool>(input: &[u8])
where
    A: Alphabet,
{
    let engine = Engine::<A, PAD>::new();
    let mut dispatched = [0x55; 128];
    let mut scalar = [0xaa; 128];

    let dispatched_result = engine.decode_slice(input, &mut dispatched);
    let scalar_result = scalar::scalar_reference_decode_slice::<A, PAD>(input, &mut scalar);

    assert_eq!(dispatched_result, scalar_result);
    if let Ok(written) = dispatched_result {
        assert_eq!(&dispatched[..written], &scalar[..written]);

        if written > 0 {
            let mut dispatched_short = [0x55; 128];
            let mut scalar_short = [0xaa; 128];
            let available = written - 1;

            let dispatched_short_result =
                engine.decode_slice(input, &mut dispatched_short[..available]);
            let scalar_short_result = scalar::scalar_reference_decode_slice::<A, PAD>(
                input,
                &mut scalar_short[..available],
            );

            assert_eq!(dispatched_short_result, scalar_short_result);
            // Equality alone would accept both paths silently truncating the
            // output; a buffer too small for the decoded bytes must be rejected.
            assert!(
                dispatched_short_result.is_err(),
                "decoding {written} bytes into a {available}-byte buffer must fail"
            );
        }
    }
}

fn assert_backend_round_trip_matches_scalar<A, const PAD: bool>(input: &[u8])
where
    A: Alphabet,
{
    assert_encode_backend_matches_scalar::<A, PAD>(input);

    let mut encoded = [0; 256];
    let encoded_len = scalar::scalar_reference_encode_slice::<A, PAD>(input, &mut encoded).unwrap();
    assert_decode_backend_matches_scalar::<A, PAD>(&encoded[..encoded_len]);
}

/// Encodes `input` with padded `STANDARD` and checks that `decode_chunk`
/// turns the resulting 4-symbol quantum back into `input`.
///
/// Expects `input` to encode to exactly one quantum (callers pass 1-3 bytes).
///
/// # Panics
///
/// Panics if the encoding is not exactly 4 bytes, if either call fails, or if
/// the decoded bytes differ from `input`.
fn assert_standard_decode_chunk_matches_input(input: &[u8]) {
    let mut quantum = [0u8; 4];
    let quantum_len = STANDARD.encode_slice(input, &mut quantum).unwrap();
    assert_eq!(quantum_len, 4);

    // `[u8; 4]` is `Copy`, so the quantum can be handed over directly.
    let mut plain = [0u8; 3];
    let plain_len = decode_chunk::<Standard, true>(quantum, &mut plain).unwrap();

    assert_eq!(plain_len, input.len());
    assert_eq!(&plain[..plain_len], input);
}

#[test]
fn backend_dispatch_matches_scalar_reference_for_canonical_inputs() {
    let mut storage = [0; 128];

    // Every prefix length 0..=128, each filled with a length-seeded pattern.
    for len in 0..=storage.len() {
        let prefix = &mut storage[..len];
        fill_pattern(prefix, len);
        let prefix: &[u8] = prefix;

        // All four alphabet/padding engines exercised by this module.
        assert_backend_round_trip_matches_scalar::<Standard, true>(prefix);
        assert_backend_round_trip_matches_scalar::<Standard, false>(prefix);
        assert_backend_round_trip_matches_scalar::<UrlSafe, true>(prefix);
        assert_backend_round_trip_matches_scalar::<UrlSafe, false>(prefix);
    }
}

#[test]
fn backend_dispatch_matches_scalar_reference_for_malformed_inputs() {
    let padded_standard: [&[u8]; 7] = [
        b"Z",        // lone symbol
        b"====",     // padding only
        b"AA=A",     // padding followed by data
        b"Zh==",     // non-canonical trailing bits
        b"Zm9=",     // non-canonical trailing bits
        b"Zm9v$g==", // byte outside the alphabet
        b"Zm9vZh==", // non-canonical bits in the second quantum
    ];
    for case in padded_standard {
        assert_decode_backend_matches_scalar::<Standard, true>(case);
    }

    let unpadded_standard: [&[u8]; 5] = [b"Z", b"AA=A", b"Zh", b"Zm9", b"Zm9vYg$"];
    for case in unpadded_standard {
        assert_decode_backend_matches_scalar::<Standard, false>(case);
    }

    // Each alphabet fed a symbol that belongs to the other alphabet.
    assert_decode_backend_matches_scalar::<UrlSafe, true>(b"AA+A");
    assert_decode_backend_matches_scalar::<UrlSafe, false>(b"AA/A");
    assert_decode_backend_matches_scalar::<Standard, true>(b"AA-A");
    assert_decode_backend_matches_scalar::<Standard, false>(b"AA_A");
}

#[test]
fn decode_chunk_bit_packing_matches_exhaustive_small_inputs() {
    // Every 1-byte input (encodes with two padding symbols).
    (u8::MIN..=u8::MAX).for_each(|byte| assert_standard_decode_chunk_matches_input(&[byte]));

    // Every 2-byte input (one padding symbol), first byte varying slowest.
    let pairs = (u8::MIN..=u8::MAX)
        .flat_map(|first| (u8::MIN..=u8::MAX).map(move |second| [first, second]));
    for pair in pairs {
        assert_standard_decode_chunk_matches_input(&pair);
    }
}

#[test]
fn decode_chunk_bit_packing_matches_representative_full_quanta() {
    // A fixed sample of 16 byte values (including 0, 255 and neighbours of
    // several power-of-two boundaries); all 16^3 ordered triples are checked.
    const SAMPLES: [u8; 16] = [
        0, 1, 2, 15, 16, 31, 32, 63, 64, 95, 127, 128, 191, 192, 254, 255,
    ];

    for &first in &SAMPLES {
        for &second in &SAMPLES {
            for &third in &SAMPLES {
                let quantum = [first, second, third];
                assert_standard_decode_chunk_matches_input(&quantum);
            }
        }
    }
}

#[test]
fn ct_padded_final_quantum_fails_closed_for_invalid_padding_count() {
    // Per the test name, `3` is an invalid padding count: both error flags
    // must be raised, nothing may be emitted, and the report must fail.
    let quantum = *b"ABCD";
    let (_, byte_error, padding_error, emitted) = ct_padded_final_quantum::<Standard>(quantum, 3);

    assert_ne!(byte_error, 0);
    assert_ne!(padding_error, 0);
    assert_eq!(emitted, 0);

    let reported = report_ct_error(byte_error, padding_error);
    assert_eq!(reported, Err(DecodeError::InvalidInput));
}

#[cfg(feature = "simd")]
#[test]
fn simd_dispatch_scaffold_keeps_scalar_active() {
    // The scalar backend must stay active; candidate detection is only
    // exercised to make sure it runs.
    let active = simd::active_backend();
    assert_eq!(active, simd::ActiveBackend::Scalar);
    let _detected = simd::detected_candidate();
}

#[test]
fn encodes_standard_vectors() {
    // RFC 4648 section 10 test vectors, padded standard alphabet.
    let cases: [(&[u8], &[u8]); 7] = [
        (b"", b""),
        (b"f", b"Zg=="),
        (b"fo", b"Zm8="),
        (b"foo", b"Zm9v"),
        (b"foob", b"Zm9vYg=="),
        (b"fooba", b"Zm9vYmE="),
        (b"foobar", b"Zm9vYmFy"),
    ];

    for (plain, want) in cases {
        let mut buf = [0u8; 16];
        let len = STANDARD.encode_slice(plain, &mut buf).unwrap();
        assert_eq!(&buf[..len], want);
    }
}

#[test]
fn decodes_standard_vectors() {
    // RFC 4648 section 10 test vectors, decoded with the padded standard alphabet.
    let cases: [(&[u8], &[u8]); 7] = [
        (b"", b""),
        (b"Zg==", b"f"),
        (b"Zm8=", b"fo"),
        (b"Zm9v", b"foo"),
        (b"Zm9vYg==", b"foob"),
        (b"Zm9vYmE=", b"fooba"),
        (b"Zm9vYmFy", b"foobar"),
    ];

    for (text, want) in cases {
        let mut buf = [0u8; 16];
        let len = STANDARD.decode_slice(text, &mut buf).unwrap();
        assert_eq!(&buf[..len], want);
    }
}

#[test]
fn supports_unpadded_url_safe() {
    // 0xfb 0xff exercises the two URL-safe symbols ('-' and '_').
    let original: &[u8] = b"\xfb\xff";

    let mut text = [0u8; 16];
    let text_len = URL_SAFE_NO_PAD.encode_slice(original, &mut text).unwrap();
    let text = &text[..text_len];
    assert_eq!(text, b"-_8");

    let mut bytes = [0u8; 2];
    let bytes_len = URL_SAFE_NO_PAD.decode_slice(text, &mut bytes).unwrap();
    assert_eq!(&bytes[..bytes_len], original);
}

#[test]
fn decodes_in_place() {
    // Unpadded input decoded through `decode_in_place` on its own buffer.
    let mut buffer = *b"Zm9vYmFy";
    let plain = STANDARD_NO_PAD.decode_in_place(&mut buffer);
    assert_eq!(plain.unwrap(), b"foobar");
}

#[test]
fn rejects_non_canonical_padding_bits() {
    // Each case has non-zero leftover bits in its last data symbol; the
    // reported index is that symbol's position.
    let mut scratch = [0u8; 4];
    for (encoded, bad_index) in [(&b"Zh=="[..], 1), (&b"Zm9="[..], 2)] {
        assert_eq!(
            STANDARD.decode_slice(encoded, &mut scratch),
            Err(DecodeError::InvalidPadding { index: bad_index })
        );
    }
}