#![doc = concat!("[graph-png]: data:image/png;base64,", include_str!("../images/graph.png.base64"))]
#![feature(portable_simd)]
use std::simd::LaneCount;
use std::simd::Simd;
use std::simd::SupportedLaneCount;
#[macro_use]
mod util;
mod simd;
/// Error returned when base64 input cannot be decoded.
///
/// It is a unit type: it carries no information about where or why decoding
/// failed.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Error;

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid base64 input")
    }
}

impl std::error::Error for Error {}
/// Decodes base64 `data` into a newly allocated byte vector.
///
/// Up to two trailing `=` characters are stripped before decoding.
///
/// # Errors
///
/// Returns [`Error`] when the vectorised decoder reports invalid input.
pub fn decode(data: &[u8]) -> Result<Vec<u8>, Error> {
    let mut decoded = Vec::new();
    decode_to(data, &mut decoded).map(|()| decoded)
}
/// Base64-encodes `data` into a new `String`, including `=` padding.
pub fn encode(data: &[u8]) -> String {
    let mut text = String::new();
    // SAFETY: `text` starts empty and `encode_to` only appends the bytes
    // produced by `simd::encode` plus `=` padding.
    // NOTE(review): this relies on `simd::encode` emitting ASCII only;
    // confirm in the `simd` module.
    unsafe { encode_to(data, text.as_mut_vec()) };
    text
}
/// Decodes base64 `data` and appends the bytes to `out`.
///
/// The vector width is fixed at compile time: 32 lanes when the `avx2` target
/// feature is statically enabled, 16 lanes otherwise.
///
/// # Errors
///
/// Returns [`Error`] when the input is rejected. `out.len()` is left unchanged.
pub fn decode_to(data: &[u8], out: &mut Vec<u8>) -> Result<(), Error> {
    #[cfg(target_feature = "avx2")]
    const LANES: usize = 32;
    #[cfg(not(target_feature = "avx2"))]
    const LANES: usize = 16;
    decode_tunable::<LANES>(data, out)
}
/// Appends the base64 encoding of `data` to `out`, keeping any bytes `out`
/// already contains.
///
/// Always uses 16-lane vectors, regardless of enabled target features.
pub fn encode_to(data: &[u8], out: &mut Vec<u8>) {
    const LANES: usize = 16;
    encode_tunable::<LANES>(data, out);
}
/// Decodes `data` using `N`-lane vectors and appends the bytes to `out`.
///
/// At most two trailing `=` are stripped first. The rest is decoded in
/// `N`-character chunks, each contributing `decoded_len(N)` bytes. A final
/// partial chunk is padded with `'A'` and contributes
/// `decoded_len(rest.len())` bytes.
///
/// The per-chunk validity flags from `simd::decode` are OR-ed into `failed`
/// and checked once after all chunks, so the loop never exits early.
///
/// # Errors
///
/// Returns [`Error`] if any chunk was reported invalid. `out.len()` is then
/// unchanged, although its capacity may have grown and bytes past its length
/// may have been overwritten.
///
/// # Panics
///
/// Panics if `N` is not a multiple of 4.
///
/// NOTE(review): nothing here checks that the unpadded length is not
/// `1 (mod 4)`, or that the number of `=` matches the length. So input such as
/// `Q==` is not rejected by this function; the `alphabet` test expects it to
/// succeed.
fn decode_tunable<const N: usize>(
    data: &[u8],
    out: &mut Vec<u8>,
) -> Result<(), Error>
where
    LaneCount<N>: SupportedLaneCount,
{
    assert!(N % 4 == 0);
    // Strip up to two trailing `=` padding characters.
    let data = match data {
        [p @ .., b'=', b'='] | [p @ .., b'='] | p => p,
    };
    if data.is_empty() {
        return Ok(());
    }
    // Every store below writes a full N-byte vector even though fewer decoded
    // bytes are kept, hence the N bytes of slack.
    out.reserve(decoded_len(data.len()) + N);
    // Write cursor into `out`'s spare capacity, starting after its current
    // contents.
    let mut raw_out = out.as_mut_ptr_range().end;
    let mut chunks = data.chunks_exact(N);
    let mut failed = false;
    for chunk in &mut chunks {
        let (decoded, ok) = simd::decode(Simd::from_slice(chunk));
        failed |= !ok;
        // SAFETY: before this store `raw_out` is at most
        // `decoded_len(data.len())` bytes past the old end. `reserve`
        // guaranteed N more bytes beyond that, so the N-byte unaligned write
        // stays inside the allocation. Only `decoded_len(N)` bytes are kept;
        // the rest are overwritten by the next store or lie past the final
        // `set_len`.
        unsafe {
            raw_out.cast::<Simd<u8, N>>().write_unaligned(decoded);
            raw_out = raw_out.add(decoded_len(N));
        }
    }
    let rest = chunks.remainder();
    if !rest.is_empty() {
        // SAFETY (read_slice_padded): `rest` is the remainder of
        // `chunks_exact(N)`, so `rest.len() < N`. The `'A'` fill is a base64
        // alphabet character, presumably chosen so the padding never clears
        // `ok`; confirm against `simd::decode`.
        let (decoded, ok) =
            simd::decode(unsafe { read_slice_padded::<N, b'A'>(rest) });
        failed |= !ok;
        // SAFETY: same capacity argument as in the loop above. Only
        // `decoded_len(rest.len())` bytes are kept.
        unsafe {
            raw_out.cast::<Simd<u8, N>>().write_unaligned(decoded);
            raw_out = raw_out.add(decoded_len(rest.len()));
        }
    }
    if failed {
        return Err(Error);
    }
    // SAFETY: every byte between the old end of `out` and `raw_out` was
    // written by the stores above, and `raw_out` lies within the reserved
    // capacity.
    unsafe {
        let new_len = raw_out.offset_from(out.as_ptr());
        out.set_len(new_len as usize);
    }
    Ok(())
}
/// Appends the base64 encoding of `data` to `out`, turning `N / 4 * 3` input
/// bytes into `N` output characters per vector step.
///
/// Existing contents of `out` are preserved. The `=` padding is derived from
/// `data.len()` alone, so it is correct however many bytes `out` already held.
///
/// # Panics
///
/// Panics if `N` is not a multiple of 4.
fn encode_tunable<const N: usize>(data: &[u8], out: &mut Vec<u8>)
where
    LaneCount<N>: SupportedLaneCount,
{
    assert!(N % 4 == 0);
    // Input bytes consumed per vector: every 3 bytes become 4 characters.
    let n3q = N / 4 * 3;
    if data.is_empty() {
        return;
    }
    // Every store below writes a full N-byte vector even when fewer output
    // bytes are kept, so reserve N bytes of slack past the encoded length.
    out.reserve(encoded_len(data.len()) + N);
    let mut raw_out = out.as_mut_ptr_range().end;
    let mut start = data.as_ptr();
    // The main loop loads N bytes but consumes only n3q of them. It must stop
    // early enough that its last N-byte load stays inside `data`. Whatever is
    // left over is handled by the padded tail loop below.
    let end = unsafe {
        if data.len() % n3q >= (N - n3q) {
            start.add(data.len() - data.len() % n3q)
        } else if data.len() < N {
            start
        } else {
            start.add(data.len() - data.len() % n3q - n3q)
        }
    };
    while start != end {
        // SAFETY: `end` was chosen so that `start + N` never exceeds the end
        // of `data` inside this loop.
        let chunk = unsafe { std::slice::from_raw_parts(start, N) };
        let encoded = simd::encode(Simd::from_slice(chunk));
        // SAFETY: `start` advances by n3q and reaches `end` exactly. Before
        // each store `raw_out` is at most `encoded_len(data.len())` bytes past
        // the old end, and `reserve` left N more bytes beyond that.
        unsafe {
            start = start.add(n3q);
            raw_out.cast::<Simd<u8, N>>().write_unaligned(encoded);
            raw_out = raw_out.add(N);
        }
    }
    let end = data.as_ptr_range().end;
    while start < end {
        // SAFETY: `start < end` and both lie within `data`. We take at most
        // the `rest` bytes that remain.
        let chunk = unsafe {
            let rest = end.offset_from(start) as usize;
            std::slice::from_raw_parts(start, rest.min(n3q))
        };
        // Zero-pad the short chunk to a full vector (`chunk.len() <= n3q < N`).
        // Only `encoded_len(chunk.len())` characters of the result are kept.
        let encoded = simd::encode(unsafe { read_slice_padded::<N, 0>(chunk) });
        // SAFETY: same capacity argument as in the main loop.
        unsafe {
            start = start.add(chunk.len());
            raw_out.cast::<Simd<u8, N>>().write_unaligned(encoded);
            raw_out = raw_out.add(encoded_len(chunk.len()));
        }
    }
    // SAFETY: every byte between the old end of `out` and `raw_out` was
    // written above, and `raw_out` lies within the reserved capacity.
    unsafe {
        let new_len = raw_out.offset_from(out.as_ptr());
        out.set_len(new_len as usize);
    }
    // Pad from this call's input only: 1 leftover byte -> "==", 2 -> "=".
    // Looking at `out.len() % 4` instead would be wrong whenever `out` was
    // non-empty on entry with a length that is not a multiple of 4.
    match data.len() % 3 {
        1 => out.extend_from_slice(b"=="),
        2 => out.extend_from_slice(b"="),
        _ => {}
    }
}
/// Number of bytes produced by `input` unpadded base64 characters.
///
/// Each full group of 4 characters yields 3 bytes. A trailing partial group
/// of 1, 2 or 3 characters yields 1, 1 or 2 bytes.
fn decoded_len(input: usize) -> usize {
    let (groups, tail) = (input / 4, input % 4);
    let tail_bytes = match tail {
        0 => 0,
        1 | 2 => 1,
        _ => 2,
    };
    groups * 3 + tail_bytes
}
/// Number of base64 characters, without `=` padding, needed for `input` bytes.
///
/// Each full group of 3 bytes yields 4 characters. A trailing 1 or 2 bytes
/// yield 2 or 3 characters.
fn encoded_len(input: usize) -> usize {
    let (groups, tail) = (input / 3, input % 3);
    let tail_chars = match tail {
        0 => 0,
        1 => 2,
        _ => 3,
    };
    groups * 4 + tail_chars
}
/// Loads `slice` into an `N`-lane vector, filling every lane at index
/// `>= slice.len()` with the byte `Z`.
///
/// Bytes are moved with unaligned word loads and stores rather than a per-byte
/// copy. Whole 16-byte words are copied first. The remaining
/// `slice.len() % 16` bytes are read with two (possibly overlapping) 8-, 4- or
/// 1-byte loads and merged with the `Z` fill into a single 16-, 8- or 4-byte
/// store.
///
/// # Safety
///
/// The final store writes a whole word and may extend past `slice.len()` into
/// the `N`-byte buffer. Callers must ensure it stays inside that buffer.
/// Passing `slice.len() < N`, with `N` one of 4, 8, 16, 32 or 64 (a supported
/// lane count that is a multiple of 4), guarantees this.
#[inline(always)]
unsafe fn read_slice_padded<const N: usize, const Z: u8>(
    slice: &[u8],
) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    let mut buf = [Z; N];
    let ascii_ptr = buf.as_mut_ptr();
    let mut write_at = ascii_ptr;
    if slice.len() >= 16 {
        for i in 0..slice.len() / 16 {
            // SAFETY: word `i` lies entirely within `slice`, and the store
            // stays within `buf` per the function's safety contract.
            unsafe {
                let word = slice.as_ptr().cast::<u128>().add(i).read_unaligned();
                write_at.cast::<u128>().write_unaligned(word);
                write_at = write_at.add(16);
            }
        }
    }
    // SAFETY: every load reads bytes within `slice` (offsets `0..len` from
    // `ptr`). The single store stays within `buf` per the safety contract.
    unsafe {
        // First byte of `slice` not yet copied, matching `write_at` in `buf`.
        let ptr = slice.as_ptr().offset(write_at.offset_from(ascii_ptr));
        let len = slice.len() % 16;
        // The shifts below place slice byte `i` at bit `8 * i`, i.e. they build
        // little-endian words. Loads are normalised with `from_le` and stores
        // with `to_le`. Both are no-ops on little-endian targets and keep the
        // lane order correct on big-endian ones.
        if len >= 8 {
            let lo = u64::from_le(ptr.cast::<u64>().read_unaligned()) as u128;
            let hi = u64::from_le(ptr.add(len - 8).cast::<u64>().read_unaligned()) as u128;
            // Overlapping bytes are identical in `lo` and `hi`, so OR-ing them
            // is harmless.
            let data = lo | (hi << ((len - 8) * 8));
            let z = u128::from_le_bytes([Z; 16]) << (len * 8);
            write_at.cast::<u128>().write_unaligned((data | z).to_le());
        } else if len >= 4 {
            let lo = u32::from_le(ptr.cast::<u32>().read_unaligned()) as u64;
            let hi = u32::from_le(ptr.add(len - 4).cast::<u32>().read_unaligned()) as u64;
            let data = lo | (hi << ((len - 4) * 8));
            let z = u64::from_le_bytes([Z; 8]) << (len * 8);
            write_at.cast::<u64>().write_unaligned((data | z).to_le());
        } else if len >= 1 {
            // 1..=3 bytes: the first, middle and last bytes cover every position.
            let lo = ptr.read() as u32;
            let mid = ptr.add(len / 2).read() as u32;
            let hi = ptr.add(len - 1).read() as u32;
            let data = lo | (mid << ((len / 2) * 8)) | hi << ((len - 1) * 8);
            let z = u32::from_le_bytes([Z; 4]) << (len * 8);
            write_at.cast::<u32>().write_unaligned((data | z).to_le());
        }
    }
    buf.into()
}
#[cfg(test)]
mod tests {
    /// Test vectors from `test_vectors.txt`, one entry per `\n`-separated line:
    /// `(line index, base64 line, reference decoding by the base64 crate)`.
    fn random_tests() -> Vec<(usize, &'static [u8], Vec<u8>)> {
        use base64::prelude::*;
        include_bytes!("test_vectors.txt")
            .split(|&b| b == b'\n')
            .enumerate()
            .map(|(i, b64)| (i, b64, BASE64_STANDARD.decode(b64).unwrap()))
            .collect()
    }

    /// `(length, reference encoding, input)` for inputs of 0..500 `0xff` bytes.
    fn all_ones_tests() -> Vec<(usize, Vec<u8>, Vec<u8>)> {
        use base64::prelude::*;
        (0..500)
            .map(|i| vec![0xff; i])
            .enumerate()
            .map(|(i, bin)| (i, BASE64_STANDARD.encode(&bin).into_bytes(), bin))
            .collect()
    }

    #[test]
    fn random_decode() {
        for (i, enc, dec) in random_tests() {
            assert_eq!(crate::decode(enc).unwrap(), dec, "case {i}");
        }
    }

    #[test]
    fn random_encode() {
        for (i, enc, dec) in random_tests() {
            assert_eq!(crate::encode(&dec).as_bytes(), enc, "case {i}");
        }
    }

    #[test]
    fn all_ones_decode() {
        for (i, enc, dec) in all_ones_tests() {
            assert_eq!(crate::decode(&enc).unwrap(), dec, "case {i}");
        }
    }

    #[test]
    fn all_ones_encode() {
        for (i, enc, dec) in all_ones_tests() {
            assert_eq!(crate::encode(&dec).as_bytes(), enc, "case {i}");
        }
    }

    /// Every byte value followed by `==` must decode if and only if it is in
    /// the standard base64 alphabet.
    #[test]
    fn alphabet() {
        // Inclusive range: the previous `0..255u8` silently skipped 0xff.
        for b in 0..=u8::MAX {
            let res = crate::decode(&[b, b'=', b'=']);
            if b.is_ascii_alphanumeric() || b == b'+' || b == b'/' {
                assert!(res.is_ok(), "{b:#04x} is valid data");
            } else {
                assert!(res.is_err(), "{b:#04x} is not valid data");
            }
        }
    }

    /// Keeps the public entry points referenced so their code can be inspected
    /// in the test binary's disassembly.
    #[test]
    #[ignore]
    fn keep_for_disassembly() {
        std::hint::black_box((super::decode as usize, super::encode as usize));
    }
}