base64-simd 0.5.0

SIMD-accelerated base64 encoding and decoding
Documentation
#![allow(missing_docs)]

use crate::utils::{empty_slice_mut, read, write};
use crate::{Base64, Base64Kind, Error, OutBuf, ERROR};

/// The standard base64 alphabet (`A`-`Z`, `a`-`z`, `0`-`9`, `+`, `/`).
///
/// Byte `i` of this array is the output symbol for the 6-bit value `i`.
pub(crate) const STANDARD_CHARSET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// The URL- and filename-safe base64 alphabet: identical to the standard one except
/// that values 62 and 63 are `-` and `_` instead of `+` and `/`.
///
/// Byte `i` of this array is the output symbol for the 6-bit value `i`.
pub(crate) const URL_SAFE_CHARSET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

/// Builds the reverse lookup table for a 64-symbol alphabet.
///
/// The result maps every byte value to its index in `charset` (a 6-bit value), or to
/// `0xff` when the byte is not part of `charset`. Should a byte occur more than once,
/// its last position wins. Written with `loop` because iterators are not usable in a
/// `const fn`.
const fn decode_table(charset: &'static [u8; 64]) -> [u8; 256] {
    let mut lut = [0xff; 256];
    let mut pos = 0;
    loop {
        if pos == charset.len() {
            break lut;
        }
        let symbol = charset[pos] as usize;
        lut[symbol] = pos as u8;
        pos += 1;
    }
}

/// Reverse lookup for `STANDARD_CHARSET`: maps an input byte to its 6-bit value,
/// or to `0xff` when the byte is not a standard-alphabet symbol (including `=`).
pub(crate) const STANDARD_DECODE_TABLE: &[u8; 256] = &decode_table(STANDARD_CHARSET);
/// Reverse lookup for `URL_SAFE_CHARSET`: maps an input byte to its 6-bit value,
/// or to `0xff` when the byte is not a URL-safe-alphabet symbol (including `=`).
pub(crate) const URL_SAFE_DECODE_TABLE: &[u8; 256] = &decode_table(URL_SAFE_CHARSET);

/// Encodes `src` into `dst` using the alphabet and padding policy of `base64`.
///
/// On success the returned slice is the first `out_len` bytes of `dst`, where
/// `out_len` is `Base64::encoded_length_unchecked(src.len(), base64.padding)`.
/// An empty `src` produces an empty slice without writing anything.
///
/// # Errors
///
/// Returns `ERROR` when `dst` is shorter than the encoded length.
#[inline]
pub fn encode<'s, 'd>(
    base64: &'_ Base64,
    src: &'s [u8],
    dst: OutBuf<'d, u8>,
) -> Result<&'d mut [u8], Error> {
    // SAFETY: `dst` is checked to hold at least `out_len` bytes before any write.
    // All reads stay inside `src[..len]`. The unrolled path only starts a round whose
    // last 8-byte load ends at or before `len / 3 * 3`. The scalar path reads exactly
    // the remaining complete 3-byte groups, and `encode_extra` reads the
    // `len % 3` leftovers.
    unsafe {
        if src.is_empty() {
            return Ok(empty_slice_mut(dst.as_mut_ptr()));
        }

        let len = src.len();
        let out_len = Base64::encoded_length_unchecked(len, base64.padding);
        if dst.len() < out_len {
            return Err(ERROR);
        }

        let charset = match base64.kind {
            Base64Kind::Standard => STANDARD_CHARSET.as_ptr(),
            Base64Kind::UrlSafe => URL_SAFE_CHARSET.as_ptr(),
        };

        // Number of `src` bytes that form complete 3-byte groups.
        let full = len / 3 * 3;

        let mut inp = src.as_ptr();
        let mut out = dst.as_mut_ptr();
        let out_full_end = out.add(len / 3 * 4);

        // Unrolled path: each step loads 8 bytes but consumes only 6 (-> 8 symbols),
        // so one round of `UNROLL` steps needs `UNROLL * 6 + 2` readable bytes.
        const UNROLL: usize = 4;
        const ROUND_READ: usize = UNROLL * 6 + 2;
        if full >= ROUND_READ {
            let last_round_start = inp.add(full - ROUND_READ);
            while inp <= last_round_start {
                for _ in 0..UNROLL {
                    let word = u64::from_be_bytes(inp.cast::<[u8; 8]>().read());
                    // The top 48 bits hold the 6 consumed bytes; emit them as 8 sextets.
                    for k in 0..8 {
                        let sextet = ((word >> (58 - k * 6)) & 0x3f) as usize;
                        write(out, k, read(charset, sextet));
                    }
                    inp = inp.add(6);
                    out = out.add(8);
                }
            }
        }

        // Scalar path for whatever complete 3-byte groups remain.
        while out < out_full_end {
            let [b0, b1, b2] = inp.cast::<[u8; 3]>().read();
            let group = u32::from_be_bytes([0, b0, b1, b2]);
            for k in 0..4 {
                let sextet = ((group >> (18 - k * 6)) & 0x3f) as usize;
                write(out, k, read(charset, sextet));
            }
            inp = inp.add(3);
            out = out.add(4);
        }

        // The 0, 1 or 2 leftover bytes, plus `=` padding if requested.
        encode_extra(len % 3, inp, out, charset, base64.padding);

        Ok(core::slice::from_raw_parts_mut(dst.as_mut_ptr(), out_len))
    }
}

/// Encodes the trailing `extra` input bytes that do not fill a whole 3-byte group.
///
/// One byte becomes two symbols and two bytes become three symbols. When `padding`
/// is set, `=` is appended until four symbols have been written. `extra == 0` writes
/// nothing.
///
/// # Safety
///
/// * `extra` must be 0, 1 or 2; any other value reaches `unreachable_unchecked`.
/// * `src` must be readable for `extra` bytes.
/// * `dst` must be writable for `extra + 1` bytes, or 4 bytes when `padding` is true
///   (unless `extra == 0`).
/// * `charset` must point to at least 64 readable bytes.
pub(crate) unsafe fn encode_extra(
    extra: usize,
    src: *const u8,
    dst: *mut u8,
    charset: *const u8,
    padding: bool,
) {
    match extra {
        0 => {}
        1 => {
            let b0 = read(src, 0);
            // 8 bits -> 2 symbols: the top 6 bits, then the low 2 bits moved up.
            write(dst, 0, read(charset, (b0 >> 2) as usize));
            write(dst, 1, read(charset, ((b0 & 0x03) << 4) as usize));
            if padding {
                write(dst, 2, Base64::PAD);
                write(dst, 3, Base64::PAD);
            }
        }
        2 => {
            let (b0, b1) = (read(src, 0), read(src, 1));
            // 16 bits -> 3 symbols; the last one carries b1's low 4 bits moved up.
            write(dst, 0, read(charset, (b0 >> 2) as usize));
            write(dst, 1, read(charset, (((b0 & 0x03) << 4) | (b1 >> 4)) as usize));
            write(dst, 2, read(charset, ((b1 & 0x0f) << 2) as usize));
            if padding {
                write(dst, 3, Base64::PAD);
            }
        }
        _ => core::hint::unreachable_unchecked(),
    }
}

/// Decodes the base64 text in `src` into `dst` using the alphabet and padding policy
/// of `base64`.
///
/// On success the returned slice is the first `out_len` bytes of `dst`, where
/// `out_len` comes from `Base64::decoded_length_unchecked`. An empty `src` produces
/// an empty slice without writing anything.
///
/// # Errors
///
/// Returns an error when any of the following holds:
/// * `Base64::decoded_length_unchecked` rejects `src`;
/// * `dst` is shorter than `out_len`;
/// * a symbol is outside the alphabet;
/// * the final partial group carries non-zero unused bits.
///
/// `dst` may already be partially written when a bad symbol is found.
#[inline]
pub fn decode<'s, 'd>(
    base64: &'_ Base64,
    src: &'s [u8],
    dst: OutBuf<'d, u8>,
) -> Result<&'d mut [u8], Error> {
    // SAFETY: `dst` is checked to hold at least `out_len` bytes before any write.
    // The unrolled path stores whole `u64`s (6 useful bytes + 2 scratch). It only
    // starts a round when the whole round's stores fit inside `dst[..out_len]`; the
    // scalar path and `decode_extra` write exactly 3 bytes per group plus the tail.
    // NOTE(review): this relies on `decoded_length_unchecked` returning
    // `len <= src.len()` and an `out_len` consistent with `len`; confirm there.
    unsafe {
        if src.is_empty() {
            return Ok(empty_slice_mut(dst.as_mut_ptr()));
        }

        let (len, out_len) = Base64::decoded_length_unchecked(src, base64.padding)?;
        if dst.len() < out_len {
            return Err(ERROR);
        }

        let table = match base64.kind {
            Base64Kind::Standard => STANDARD_DECODE_TABLE.as_ptr(),
            Base64Kind::UrlSafe => URL_SAFE_DECODE_TABLE.as_ptr(),
        };

        let mut inp = src.as_ptr();
        let mut out = dst.as_mut_ptr();
        // End of the complete 4-symbol groups.
        let inp_full_end = inp.add(len / 4 * 4);

        // Unrolled path: each step turns 8 symbols into 6 bytes but stores a whole
        // `u64`, so one round of `UNROLL` steps touches `UNROLL * 6 + 2` output bytes.
        const UNROLL: usize = 4;
        const ROUND_WRITE: usize = UNROLL * 6 + 2;
        if out_len >= ROUND_WRITE {
            let last_round_start = out.add(out_len - ROUND_WRITE);
            while out <= last_round_start {
                for _ in 0..UNROLL {
                    let symbols = inp.cast::<[u8; 8]>().read();
                    let mut packed: u64 = 0;
                    // Invalid symbols look up as 0xff while valid ones are <= 0x3f,
                    // so the OR of all lookups is 0xff exactly when one was invalid.
                    let mut seen = 0;
                    for (k, &sym) in symbols.iter().enumerate() {
                        let bits = read(table, sym as usize);
                        seen |= bits;
                        packed |= (bits as u64) << (58 - k * 6);
                    }
                    if seen == 0xff {
                        return Err(ERROR);
                    }
                    // The 6 decoded bytes sit in the top 48 bits; storing big-endian
                    // puts them in memory in order.
                    out.cast::<u64>().write_unaligned(packed.to_be());
                    inp = inp.add(8);
                    out = out.add(6);
                }
            }
        }

        // Scalar path for whatever complete 4-symbol groups remain.
        while inp < inp_full_end {
            let symbols = inp.cast::<[u8; 4]>().read();
            let mut packed: u32 = 0;
            let mut seen = 0;
            for (k, &sym) in symbols.iter().enumerate() {
                let bits = read(table, sym as usize);
                seen |= bits;
                packed |= (bits as u32) << (18 - k * 6);
            }
            if seen == 0xff {
                return Err(ERROR);
            }
            let [_, b0, b1, b2] = packed.to_be_bytes();
            write(out, 0, b0);
            write(out, 1, b1);
            write(out, 2, b2);
            inp = inp.add(4);
            out = out.add(3);
        }

        // The final `len % 4` symbols of an incomplete group.
        decode_extra(len % 4, inp, out, table)?;

        Ok(core::slice::from_raw_parts_mut(dst.as_mut_ptr(), out_len))
    }
}

/// Decodes the trailing `extra` symbols that do not fill a whole 4-symbol group.
///
/// Two symbols yield one byte and three symbols yield two bytes. The bits of the
/// last symbol that do not fit into the output must be zero; otherwise the input is
/// rejected. `extra == 0` does nothing.
///
/// # Safety
///
/// * `extra` must be 0, 2 or 3; 1 or anything above 3 reaches `unreachable_unchecked`.
/// * `src` must be readable for `extra` bytes, and `dst` writable for `extra - 1`
///   bytes when `extra` is non-zero.
/// * `table` must point to a 256-byte lookup table that maps invalid symbols to
///   `0xff` and valid ones to values below 64.
///
/// # Errors
///
/// Returns `ERROR` when a symbol is invalid or the unused trailing bits are non-zero.
pub(crate) unsafe fn decode_extra(
    extra: usize,
    src: *const u8,
    dst: *mut u8,
    table: *const u8,
) -> Result<(), Error> {
    match extra {
        0 => Ok(()),
        2 => {
            let [c0, c1] = src.cast::<[u8; 2]>().read();
            let v0 = read(table, c0 as usize);
            let v1 = read(table, c1 as usize);
            // 12 input bits -> 8 output bits: the low 4 bits of `v1` must be zero.
            if (v1 & 0x0f) != 0 || (v0 | v1) == 0xff {
                return Err(ERROR);
            }
            write(dst, 0, (v0 << 2) | (v1 >> 4));
            Ok(())
        }
        3 => {
            let [c0, c1, c2] = src.cast::<[u8; 3]>().read();
            let v0 = read(table, c0 as usize);
            let v1 = read(table, c1 as usize);
            let v2 = read(table, c2 as usize);
            // 18 input bits -> 16 output bits: the low 2 bits of `v2` must be zero.
            if (v2 & 0x03) != 0 || (v0 | v1 | v2) == 0xff {
                return Err(ERROR);
            }
            write(dst, 0, (v0 << 2) | (v1 >> 4));
            write(dst, 1, (v1 << 4) | (v2 >> 2));
            Ok(())
        }
        _ => core::hint::unreachable_unchecked(),
    }
}

/// Runs the crate's shared test driver (`crate::tests::test`, defined elsewhere)
/// against this module's scalar `encode` / `decode` pair.
#[test]
fn test() {
    crate::tests::test(encode, decode);
}