use std::io::{self, Read, Write};
use base64_simd::AsOut;
/// Engine used by every encode/decode call in this module: `base64_simd::STANDARD`
/// (presumably the RFC 4648 standard alphabet with `=` padding — confirm against the crate docs).
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
/// Number of worker threads worth using: the platform-reported parallelism, or 1 when the
/// platform cannot report it.
#[inline]
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
/// Input bytes encoded per write by the sequential unwrapped encoder: the largest multiple
/// of 3 not exceeding 8 MiB, so only the final chunk can produce `=` padding.
const NOWRAP_CHUNK: usize = (8 * 1024 * 1024 / 3) * 3;
/// Input size at which unwrapped encoding is split across threads.
const PARALLEL_NOWRAP_THRESHOLD: usize = 16 << 20;
/// Input size at which line-wrapped encoding is split across threads.
const PARALLEL_WRAPPED_THRESHOLD: usize = 12 << 20;
/// Size at which decoding is split across threads.
const PARALLEL_DECODE_THRESHOLD: usize = 1 << 20;
/// Best-effort request that the kernel back `buf`'s allocation with transparent huge pages.
///
/// Only buffers with at least 2 MiB of capacity are advised. The `madvise` result is
/// ignored, so a refusal (for example an allocation that is not page-aligned) is harmless.
#[cfg(target_os = "linux")]
fn hint_hugepage(buf: &mut Vec<u8>) {
    const MIN_HUGEPAGE_CAPACITY: usize = 2 * 1024 * 1024;
    let capacity = buf.capacity();
    if capacity < MIN_HUGEPAGE_CAPACITY {
        return;
    }
    let start = buf.as_mut_ptr().cast::<libc::c_void>();
    // SAFETY: `start..start + capacity` is exactly the allocation owned by `buf`, and
    // MADV_HUGEPAGE only changes paging advice; it does not modify the contents.
    unsafe {
        libc::madvise(start, capacity, libc::MADV_HUGEPAGE);
    }
}
/// Base64-encodes `data` and writes the text to `out`.
///
/// `wrap_col == 0` disables line wrapping. Any other value inserts a newline after every
/// `wrap_col` output characters and after the final (possibly short) line. Empty input
/// writes nothing.
pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    match (data.is_empty(), wrap_col) {
        (true, _) => Ok(()),
        (false, 0) => encode_no_wrap(data, out),
        (false, _) => encode_wrapped(data, wrap_col, out),
    }
}
/// Encodes `data` without line breaks.
///
/// Large inputs on multi-core machines go to the parallel encoder. Otherwise the input is
/// encoded in `NOWRAP_CHUNK`-sized pieces through one reusable scratch buffer. Each piece
/// is a multiple of 3 bytes, so only the last one can carry padding.
fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    if data.len() >= PARALLEL_NOWRAP_THRESHOLD && num_cpus() > 1 {
        return encode_no_wrap_parallel(data, out);
    }
    // The first piece is the largest, so its encoded size bounds every later piece.
    let scratch_len = BASE64_ENGINE.encoded_length(NOWRAP_CHUNK.min(data.len()));
    let mut scratch: Vec<u8> = Vec::with_capacity(scratch_len);
    #[allow(clippy::uninit_vec)]
    unsafe {
        scratch.set_len(scratch_len);
    }
    data.chunks(NOWRAP_CHUNK).try_for_each(|piece| {
        let piece_len = BASE64_ENGINE.encoded_length(piece.len());
        out.write_all(BASE64_ENGINE.encode(piece, scratch[..piece_len].as_out()))
    })
}
/// Encodes `data` without line breaks by splitting it across rayon tasks.
///
/// The input is cut into roughly `num_cpus()` pieces. Piece sizes are rounded up to a
/// multiple of 3, so only the final piece can produce `=` padding, and the pieces'
/// encodings concatenate into the encoding of the whole input. Each task encodes into its
/// own disjoint region of one preallocated output buffer, which is then written in a
/// single call.
fn encode_no_wrap_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    let num_threads = num_cpus().max(1);
    let raw_chunk = data.len() / num_threads;
    // Round the per-task size up to a multiple of 3 so interior pieces need no padding.
    let chunk_size = ((raw_chunk + 2) / 3) * 3;
    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(3)).collect();
    // offsets[i] is where piece i's encoding starts in `output`.
    let mut offsets: Vec<usize> = Vec::with_capacity(chunks.len() + 1);
    let mut total_out = 0usize;
    for chunk in &chunks {
        offsets.push(total_out);
        total_out += BASE64_ENGINE.encoded_length(chunk.len());
    }
    // Left uninitialized: the per-piece ranges are disjoint and together cover
    // `0..total_out`, assuming `encode` fills its full `encoded_length` range.
    let mut output: Vec<u8> = Vec::with_capacity(total_out);
    #[allow(clippy::uninit_vec)]
    unsafe {
        output.set_len(total_out);
    }
    #[cfg(target_os = "linux")]
    hint_hugepage(&mut output);
    // Passed to the tasks as an integer so the closures are `Send`; each task rebuilds a
    // slice over only its own `[out_off, out_off + enc_len)` range.
    let output_base = output.as_mut_ptr() as usize;
    rayon::scope(|s| {
        for (i, chunk) in chunks.iter().enumerate() {
            let out_off = offsets[i];
            let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
            let base = output_base;
            s.spawn(move |_| {
                let dest =
                    unsafe { std::slice::from_raw_parts_mut((base + out_off) as *mut u8, enc_len) };
                let _ = BASE64_ENGINE.encode(chunk, dest.as_out());
            });
        }
    });
    out.write_all(&output[..total_out])
}
/// Encodes `data` as base64 with a newline after every `wrap_col` output characters and
/// after the final partial line.
///
/// The parallel and chunked paths encode groups of `bytes_per_line` input bytes into fixed
/// `wrap_col`-wide line slots. That is only exact when `wrap_col` is a multiple of 4: then
/// `bytes_per_line = wrap_col * 3 / 4` is a multiple of 3 and encodes to exactly `wrap_col`
/// characters.
///
/// For `wrap_col % 4 == 1` (5, 77, ...) `bytes_per_line` is still a multiple of 3, but it
/// encodes to only `wrap_col - 1` characters. Those widths must therefore stay off the
/// slot-based paths.
fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    let bytes_per_line = wrap_col * 3 / 4;
    if bytes_per_line == 0 {
        return encode_wrapped_small(data, wrap_col, out);
    }
    // True when one output line corresponds to exactly `bytes_per_line` input bytes.
    let exact_lines = wrap_col.is_multiple_of(4);
    if exact_lines && data.len() >= PARALLEL_WRAPPED_THRESHOLD {
        return encode_wrapped_parallel(data, wrap_col, bytes_per_line, out);
    }
    if exact_lines && data.len() > 1024 * 1024 {
        let lines_per_chunk = (1024 * 1024) / bytes_per_line;
        if lines_per_chunk > 0 {
            return encode_wrapped_chunked(data, wrap_col, bytes_per_line, out);
        }
    }
    // Encodes the whole input first and inserts newlines afterwards, so it is correct for
    // any `wrap_col`; it only debug-asserts `bytes_per_line % 3 == 0`.
    if bytes_per_line.is_multiple_of(3) {
        return encode_wrapped_expand(data, wrap_col, bytes_per_line, out);
    }
    // General fallback: encode everything, then copy into a second buffer with newlines.
    let enc_max = BASE64_ENGINE.encoded_length(data.len());
    let num_full = enc_max / wrap_col;
    let rem = enc_max % wrap_col;
    let out_len = num_full * (wrap_col + 1) + if rem > 0 { rem + 1 } else { 0 };
    let mut enc_buf: Vec<u8> = Vec::with_capacity(enc_max);
    #[allow(clippy::uninit_vec)]
    unsafe {
        enc_buf.set_len(enc_max);
    }
    let _ = BASE64_ENGINE.encode(data, enc_buf[..enc_max].as_out());
    let mut out_buf: Vec<u8> = Vec::with_capacity(out_len);
    #[allow(clippy::uninit_vec)]
    unsafe {
        out_buf.set_len(out_len);
    }
    let n = fuse_wrap(&enc_buf, wrap_col, &mut out_buf);
    out.write_all(&out_buf[..n])
}
/// Encodes `data` with line wrapping in roughly 1 MiB input chunks. Each chunk's lines are
/// formatted directly into one reusable output buffer, which is written out per chunk.
///
/// Every chunk except the last holds a whole number of `bytes_per_line`-byte lines. The
/// last chunk may end in a partial line, which is emitted with its own trailing newline.
///
/// # Panics
/// Panics unless `bytes_per_line` is a multiple of 3 that encodes to exactly `wrap_col`
/// characters (i.e. `wrap_col % 4 == 0`). The unchecked line writes below rely on this: a
/// shorter encoding would leave uninitialized bytes in each `wrap_col`-byte slot.
fn encode_wrapped_chunked(
    data: &[u8],
    wrap_col: usize,
    bytes_per_line: usize,
    out: &mut impl Write,
) -> io::Result<()> {
    assert!(
        bytes_per_line.is_multiple_of(3)
            && BASE64_ENGINE.encoded_length(bytes_per_line) == wrap_col,
        "encode_wrapped_chunked: bytes_per_line must encode to exactly wrap_col characters"
    );
    // At least one line per chunk; zero would make `chunk_input == 0` and never advance.
    let lines_per_chunk = ((1024 * 1024) / bytes_per_line).max(1);
    let chunk_input = lines_per_chunk * bytes_per_line;
    let line_out = wrap_col + 1;
    // Room for every full line of a chunk plus a final partial line and its newline.
    let max_chunk_out =
        lines_per_chunk * line_out + BASE64_ENGINE.encoded_length(bytes_per_line) + 2;
    let mut out_buf: Vec<u8> = Vec::with_capacity(max_chunk_out);
    #[allow(clippy::uninit_vec)]
    unsafe {
        out_buf.set_len(max_chunk_out);
    }
    let mut pos = 0;
    while pos < data.len() {
        let remaining = data.len() - pos;
        let chunk_len = remaining.min(chunk_input);
        let chunk = &data[pos..pos + chunk_len];
        let full_lines = chunk_len / bytes_per_line;
        let remainder = chunk_len % bytes_per_line;
        let dst = out_buf.as_mut_ptr();
        // SAFETY (all raw writes in this iteration): output line `k` occupies
        // `out_buf[k * line_out..(k + 1) * line_out]` with `k < full_lines <= lines_per_chunk`.
        // Each `encode` of `bytes_per_line` bytes fills exactly `wrap_col` bytes (asserted
        // above). A trailing partial line needs at most `encoded_length(bytes_per_line) + 1`
        // bytes. Everything therefore stays within `max_chunk_out`.
        let mut line_idx = 0;
        // Four lines per iteration.
        while line_idx + 4 <= full_lines {
            let in_base = line_idx * bytes_per_line;
            let out_base = line_idx * line_out;
            unsafe {
                let s0 = std::slice::from_raw_parts_mut(dst.add(out_base), wrap_col);
                let _ =
                    BASE64_ENGINE.encode(&chunk[in_base..in_base + bytes_per_line], s0.as_out());
                *dst.add(out_base + wrap_col) = b'\n';
                let s1 = std::slice::from_raw_parts_mut(dst.add(out_base + line_out), wrap_col);
                let _ = BASE64_ENGINE.encode(
                    &chunk[in_base + bytes_per_line..in_base + 2 * bytes_per_line],
                    s1.as_out(),
                );
                *dst.add(out_base + line_out + wrap_col) = b'\n';
                let s2 = std::slice::from_raw_parts_mut(dst.add(out_base + 2 * line_out), wrap_col);
                let _ = BASE64_ENGINE.encode(
                    &chunk[in_base + 2 * bytes_per_line..in_base + 3 * bytes_per_line],
                    s2.as_out(),
                );
                *dst.add(out_base + 2 * line_out + wrap_col) = b'\n';
                let s3 = std::slice::from_raw_parts_mut(dst.add(out_base + 3 * line_out), wrap_col);
                let _ = BASE64_ENGINE.encode(
                    &chunk[in_base + 3 * bytes_per_line..in_base + 4 * bytes_per_line],
                    s3.as_out(),
                );
                *dst.add(out_base + 3 * line_out + wrap_col) = b'\n';
            }
            line_idx += 4;
        }
        // Leftover full lines, one at a time.
        while line_idx < full_lines {
            let in_off = line_idx * bytes_per_line;
            let out_off = line_idx * line_out;
            unsafe {
                let s = std::slice::from_raw_parts_mut(dst.add(out_off), wrap_col);
                let _ = BASE64_ENGINE.encode(&chunk[in_off..in_off + bytes_per_line], s.as_out());
                *dst.add(out_off + wrap_col) = b'\n';
            }
            line_idx += 1;
        }
        let mut total_out = full_lines * line_out;
        // Only the final chunk can have a partial line (with padding).
        if remainder > 0 {
            let in_off = full_lines * bytes_per_line;
            let enc_len = BASE64_ENGINE.encoded_length(remainder);
            unsafe {
                let s = std::slice::from_raw_parts_mut(dst.add(total_out), enc_len);
                let _ = BASE64_ENGINE.encode(&chunk[in_off..in_off + remainder], s.as_out());
                *dst.add(total_out + enc_len) = b'\n';
            }
            total_out += enc_len + 1;
        }
        out.write_all(&out_buf[..total_out])?;
        pos += chunk_len;
    }
    Ok(())
}
/// Encodes `data` into the front of a buffer already sized for the wrapped output, then
/// spreads the lines out in place (back to front) to make room for the newlines.
///
/// A newline follows every `wrap_col` characters and the final partial line. Empty input
/// writes nothing.
fn encode_wrapped_expand(
    data: &[u8],
    wrap_col: usize,
    bytes_per_line: usize,
    out: &mut impl Write,
) -> io::Result<()> {
    debug_assert!(bytes_per_line.is_multiple_of(3));
    let enc_len = BASE64_ENGINE.encoded_length(data.len());
    if enc_len == 0 {
        return Ok(());
    }
    // One newline per full line, plus one after a trailing partial line.
    let has_partial = enc_len % wrap_col > 0;
    let newline_count = enc_len / wrap_col + usize::from(has_partial);
    let out_len = enc_len + newline_count;
    let mut buf: Vec<u8> = Vec::with_capacity(out_len);
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(out_len);
    }
    #[cfg(target_os = "linux")]
    hint_hugepage(&mut buf);
    let written = BASE64_ENGINE.encode(data, buf[..enc_len].as_out()).len();
    debug_assert_eq!(written, enc_len, "encode wrote unexpected length");
    expand_backward(buf.as_mut_ptr(), enc_len, out_len, wrap_col);
    out.write_all(&buf)
}
#[allow(dead_code)]
fn encode_wrapped_scatter(
data: &[u8],
wrap_col: usize,
bytes_per_line: usize,
out: &mut impl Write,
) -> io::Result<()> {
let enc_len = BASE64_ENGINE.encoded_length(data.len());
if enc_len == 0 {
return Ok(());
}
let num_full = enc_len / wrap_col;
let rem = enc_len % wrap_col;
let out_len = num_full * (wrap_col + 1) + if rem > 0 { rem + 1 } else { 0 };
let mut buf: Vec<u8> = Vec::with_capacity(out_len);
#[allow(clippy::uninit_vec)]
unsafe {
buf.set_len(out_len);
}
#[cfg(target_os = "linux")]
hint_hugepage(&mut buf);
const GROUP_LINES: usize = 256;
let group_input = GROUP_LINES * bytes_per_line;
let temp_size = GROUP_LINES * wrap_col;
let mut temp: Vec<u8> = Vec::with_capacity(temp_size);
#[allow(clippy::uninit_vec)]
unsafe {
temp.set_len(temp_size);
}
let line_out = wrap_col + 1;
let mut wp = 0usize;
for chunk in data.chunks(group_input) {
let clen = BASE64_ENGINE.encoded_length(chunk.len());
let _ = BASE64_ENGINE.encode(chunk, temp[..clen].as_out());
let lines = clen / wrap_col;
let chunk_rem = clen % wrap_col;
let mut i = 0;
while i + 8 <= lines {
unsafe {
let src = temp.as_ptr().add(i * wrap_col);
let dst = buf.as_mut_ptr().add(wp);
std::ptr::copy_nonoverlapping(src, dst, wrap_col);
*dst.add(wrap_col) = b'\n';
std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
*dst.add(line_out + wrap_col) = b'\n';
std::ptr::copy_nonoverlapping(
src.add(2 * wrap_col),
dst.add(2 * line_out),
wrap_col,
);
*dst.add(2 * line_out + wrap_col) = b'\n';
std::ptr::copy_nonoverlapping(
src.add(3 * wrap_col),
dst.add(3 * line_out),
wrap_col,
);
*dst.add(3 * line_out + wrap_col) = b'\n';
std::ptr::copy_nonoverlapping(
src.add(4 * wrap_col),
dst.add(4 * line_out),
wrap_col,
);
*dst.add(4 * line_out + wrap_col) = b'\n';
std::ptr::copy_nonoverlapping(
src.add(5 * wrap_col),
dst.add(5 * line_out),
wrap_col,
);
*dst.add(5 * line_out + wrap_col) = b'\n';
std::ptr::copy_nonoverlapping(
src.add(6 * wrap_col),
dst.add(6 * line_out),
wrap_col,
);
*dst.add(6 * line_out + wrap_col) = b'\n';
std::ptr::copy_nonoverlapping(
src.add(7 * wrap_col),
dst.add(7 * line_out),
wrap_col,
);
*dst.add(7 * line_out + wrap_col) = b'\n';
}
wp += 8 * line_out;
i += 8;
}
while i < lines {
unsafe {
std::ptr::copy_nonoverlapping(
temp.as_ptr().add(i * wrap_col),
buf.as_mut_ptr().add(wp),
wrap_col,
);
*buf.as_mut_ptr().add(wp + wrap_col) = b'\n';
}
wp += line_out;
i += 1;
}
if chunk_rem > 0 {
unsafe {
std::ptr::copy_nonoverlapping(
temp.as_ptr().add(lines * wrap_col),
buf.as_mut_ptr().add(wp),
chunk_rem,
);
*buf.as_mut_ptr().add(wp + chunk_rem) = b'\n';
}
wp += chunk_rem + 1;
}
}
out.write_all(&buf[..wp])
}
/// Copies `count` consecutive `wrap_col`-byte lines from `temp` into `buf` and writes a
/// `\n` right after each copied line.
///
/// Copying starts at output line `line_start`, with output lines placed `line_out` bytes
/// apart.
///
/// # Panics
/// Panics if `temp` holds fewer than `count * wrap_col` bytes, or if any line plus its
/// newline would not fit in `buf`. The previous raw-pointer version wrote out of bounds in
/// that case, even though it is a safe function.
#[inline]
#[allow(dead_code)]
fn scatter_lines(
    temp: &[u8],
    buf: &mut [u8],
    line_start: usize,
    count: usize,
    wrap_col: usize,
    line_out: usize,
) {
    for i in 0..count {
        let src = &temp[i * wrap_col..(i + 1) * wrap_col];
        let dst_off = (line_start + i) * line_out;
        buf[dst_off..dst_off + wrap_col].copy_from_slice(src);
        buf[dst_off + wrap_col] = b'\n';
    }
}
/// Inserts line breaks into an already-encoded buffer in place, working from the end.
///
/// On entry `ptr[..enc_len]` holds the encoded bytes. On return `ptr[..out_len]` holds the
/// same bytes with a `\n` after every `wrap_col` bytes and after a trailing partial line.
/// `out_len` must equal `(enc_len / wrap_col) * (wrap_col + 1)`, plus `rem + 1` when
/// `rem = enc_len % wrap_col` is non-zero. The buffer must be at least `out_len` bytes long.
///
/// Lines are moved back-to-front, and the write cursor `wp` never falls below the read
/// cursor `rp`. No encoded byte is therefore overwritten before it has been moved.
/// `ptr::copy` (memmove semantics) handles the partially overlapping ranges.
///
/// NOTE(review): this is a safe `fn` that dereferences a raw pointer, so its soundness
/// depends entirely on callers honouring the contract above.
#[inline]
fn expand_backward(ptr: *mut u8, enc_len: usize, out_len: usize, wrap_col: usize) {
    let num_full = enc_len / wrap_col;
    let rem = enc_len % wrap_col;
    unsafe {
        let mut rp = enc_len;
        let mut wp = out_len;
        // Trailing partial line: final newline, then shift its `rem` bytes up into place.
        if rem > 0 {
            wp -= 1;
            *ptr.add(wp) = b'\n';
            wp -= rem;
            rp -= rem;
            if rp != wp {
                std::ptr::copy(ptr.add(rp), ptr.add(wp), rem);
            }
        }
        let mut lines_left = num_full;
        // Full lines, eight per iteration: place the newline, then move the line up.
        while lines_left >= 8 {
            wp -= 1;
            *ptr.add(wp) = b'\n';
            rp -= wrap_col;
            wp -= wrap_col;
            std::ptr::copy(ptr.add(rp), ptr.add(wp), wrap_col);
            wp -= 1;
            *ptr.add(wp) = b'\n';
            rp -= wrap_col;
            wp -= wrap_col;
            std::ptr::copy(ptr.add(rp), ptr.add(wp), wrap_col);
            wp -= 1;
            *ptr.add(wp) = b'\n';
            rp -= wrap_col;
            wp -= wrap_col;
            std::ptr::copy(ptr.add(rp), ptr.add(wp), wrap_col);
            wp -= 1;
            *ptr.add(wp) = b'\n';
            rp -= wrap_col;
            wp -= wrap_col;
            std::ptr::copy(ptr.add(rp), ptr.add(wp), wrap_col);
            wp -= 1;
            *ptr.add(wp) = b'\n';
            rp -= wrap_col;
            wp -= wrap_col;
            std::ptr::copy(ptr.add(rp), ptr.add(wp), wrap_col);
            wp -= 1;
            *ptr.add(wp) = b'\n';
            rp -= wrap_col;
            wp -= wrap_col;
            std::ptr::copy(ptr.add(rp), ptr.add(wp), wrap_col);
            wp -= 1;
            *ptr.add(wp) = b'\n';
            rp -= wrap_col;
            wp -= wrap_col;
            std::ptr::copy(ptr.add(rp), ptr.add(wp), wrap_col);
            wp -= 1;
            *ptr.add(wp) = b'\n';
            rp -= wrap_col;
            wp -= wrap_col;
            std::ptr::copy(ptr.add(rp), ptr.add(wp), wrap_col);
            lines_left -= 8;
        }
        // Remaining full lines; the copy is skipped once a line is already in place.
        while lines_left > 0 {
            wp -= 1;
            *ptr.add(wp) = b'\n';
            rp -= wrap_col;
            wp -= wrap_col;
            if rp != wp {
                std::ptr::copy(ptr.add(rp), ptr.add(wp), wrap_col);
            }
            lines_left -= 1;
        }
    }
}
/// A single newline byte with a `'static` address, used as the separator slice in
/// vectored writes.
static NEWLINE: [u8; 1] = *b"\n";
/// Writes `encoded` to `out` with a `\n` after every `wrap_col` bytes and after a trailing
/// partial line.
///
/// When the result needs at most `MAX_IOV` slices, it goes out as one vectored write that
/// interleaves line slices with a shared newline. Larger inputs are fused into a reusable
/// `BATCH_LINES`-line buffer and written batch by batch.
#[inline]
#[allow(dead_code)]
fn write_wrapped_iov(encoded: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    const MAX_IOV: usize = 1024;
    const BATCH_LINES: usize = 512;
    let full_lines = encoded.len() / wrap_col;
    let (body, tail) = encoded.split_at(full_lines * wrap_col);
    let slices_needed = 2 * full_lines + if tail.is_empty() { 0 } else { 2 };
    if slices_needed <= MAX_IOV {
        let mut iov: Vec<io::IoSlice> = Vec::with_capacity(slices_needed);
        for line in body.chunks_exact(wrap_col) {
            iov.push(io::IoSlice::new(line));
            iov.push(io::IoSlice::new(&NEWLINE));
        }
        if !tail.is_empty() {
            iov.push(io::IoSlice::new(tail));
            iov.push(io::IoSlice::new(&NEWLINE));
        }
        return write_all_vectored(out, &iov);
    }
    let fused_len = BATCH_LINES * (wrap_col + 1);
    let mut fused: Vec<u8> = Vec::with_capacity(fused_len);
    #[allow(clippy::uninit_vec)]
    unsafe {
        fused.set_len(fused_len);
    }
    // Full batches first, then whatever whole lines remain; every batch is whole lines.
    for batch in body.chunks(BATCH_LINES * wrap_col) {
        let n = fuse_wrap(batch, wrap_col, &mut fused);
        out.write_all(&fused[..n])?;
    }
    if !tail.is_empty() {
        out.write_all(tail)?;
        out.write_all(b"\n")?;
    }
    Ok(())
}
/// Writes `encoded` as the continuation of a wrapped stream whose current line already
/// holds `*col` characters.
///
/// A `\n` is emitted whenever a line reaches `wrap_col`, and `*col` is updated to the new
/// column. A final partial line is left open (no newline). Slices are flushed through
/// `write_all_vectored` whenever the pending batch nears `MAX_IOV` entries.
#[inline]
fn write_wrapped_iov_streaming(
    encoded: &[u8],
    wrap_col: usize,
    col: &mut usize,
    out: &mut impl Write,
) -> io::Result<()> {
    const MAX_IOV: usize = 1024;
    let mut slices: Vec<io::IoSlice> = Vec::with_capacity(MAX_IOV);
    let mut rest = encoded;
    while !rest.is_empty() {
        let room = wrap_col - *col;
        if rest.len() <= room {
            // Everything left fits on the current line.
            slices.push(io::IoSlice::new(rest));
            *col += rest.len();
            if *col == wrap_col {
                slices.push(io::IoSlice::new(&NEWLINE));
                *col = 0;
            }
            break;
        }
        let (line, after) = rest.split_at(room);
        slices.push(io::IoSlice::new(line));
        slices.push(io::IoSlice::new(&NEWLINE));
        *col = 0;
        rest = after;
        if slices.len() + 1 >= MAX_IOV {
            write_all_vectored(out, &slices)?;
            slices.clear();
        }
    }
    if !slices.is_empty() {
        write_all_vectored(out, &slices)?;
    }
    Ok(())
}
/// Encodes `data` with line wrapping on several threads.
///
/// The input is split into pieces that each hold a whole number of `bytes_per_line`-byte
/// lines. Each piece's wrapped encoding is therefore independent, and the pieces can be
/// placed at precomputed offsets in one output buffer, which is written with a single call.
///
/// # Panics
/// Panics unless `bytes_per_line` is a multiple of 3 that encodes to exactly `wrap_col`
/// characters (`wrap_col % 4 == 0`). Otherwise the per-piece line breaks would not match a
/// wrap of the whole stream, and the scatter could overrun its slice.
fn encode_wrapped_parallel(
    data: &[u8],
    wrap_col: usize,
    bytes_per_line: usize,
    out: &mut impl Write,
) -> io::Result<()> {
    assert!(
        bytes_per_line.is_multiple_of(3)
            && BASE64_ENGINE.encoded_length(bytes_per_line) == wrap_col,
        "encode_wrapped_parallel: bytes_per_line must encode to exactly wrap_col characters"
    );
    let num_threads = num_cpus().max(1);
    let lines_per_chunk = ((data.len() / bytes_per_line) / num_threads).max(1);
    let chunk_input = lines_per_chunk * bytes_per_line;
    let chunks: Vec<&[u8]> = data.chunks(chunk_input.max(bytes_per_line)).collect();
    // offsets[i] is where piece i's wrapped output starts.
    let mut offsets: Vec<usize> = Vec::with_capacity(chunks.len() + 1);
    let mut total_out = 0usize;
    for chunk in &chunks {
        offsets.push(total_out);
        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
        let full_lines = enc_len / wrap_col;
        let remainder = enc_len % wrap_col;
        total_out += full_lines * (wrap_col + 1) + if remainder > 0 { remainder + 1 } else { 0 };
    }
    let mut output: Vec<u8> = Vec::with_capacity(total_out);
    #[allow(clippy::uninit_vec)]
    unsafe {
        output.set_len(total_out);
    }
    #[cfg(target_os = "linux")]
    hint_hugepage(&mut output);
    // Shared as an integer so the task closures are `Send`.
    let output_base = output.as_mut_ptr() as usize;
    rayon::scope(|s| {
        for (i, chunk) in chunks.iter().enumerate() {
            let out_off = offsets[i];
            let out_end = if i + 1 < offsets.len() {
                offsets[i + 1]
            } else {
                total_out
            };
            let out_size = out_end - out_off;
            let base = output_base;
            s.spawn(move |_| {
                // SAFETY: `[out_off, out_end)` lies within `output`, is disjoint from every
                // other task's range, and `output` outlives the scope.
                let out_slice = unsafe {
                    std::slice::from_raw_parts_mut((base + out_off) as *mut u8, out_size)
                };
                encode_chunk_l1_scatter_into(chunk, out_slice, wrap_col, bytes_per_line);
            });
        }
    });
    out.write_all(&output[..total_out])
}
/// Encodes `data` and writes it into `output` wrapped at `wrap_col` columns, with `\n`
/// after every full line and after a trailing partial line.
///
/// Works in groups of `GROUP_LINES` (256) lines. Each group is encoded into a small scratch
/// buffer of `GROUP_LINES * wrap_col` bytes; the "l1" in the name suggests it is meant to
/// stay cache-resident. The group is then copied into `output` line by line, eight lines
/// per unrolled step.
///
/// Preconditions (not checked; the writes use raw pointers):
/// - `bytes_per_line` must encode to exactly `wrap_col` characters. Every full group then
///   yields exactly `GROUP_LINES` lines, and only the final group can end in a partial line.
/// - `output` must be at least as long as the wrapped encoding of `data`.
fn encode_chunk_l1_scatter_into(
    data: &[u8],
    output: &mut [u8],
    wrap_col: usize,
    bytes_per_line: usize,
) {
    const GROUP_LINES: usize = 256;
    let group_input = GROUP_LINES * bytes_per_line;
    let temp_size = GROUP_LINES * wrap_col;
    let mut temp: Vec<u8> = Vec::with_capacity(temp_size);
    #[allow(clippy::uninit_vec)]
    unsafe {
        temp.set_len(temp_size);
    }
    let line_out = wrap_col + 1;
    // Write cursor into `output`.
    let mut wp = 0usize;
    for chunk in data.chunks(group_input) {
        let clen = BASE64_ENGINE.encoded_length(chunk.len());
        let _ = BASE64_ENGINE.encode(chunk, temp[..clen].as_out());
        let lines = clen / wrap_col;
        let chunk_rem = clen % wrap_col;
        let mut i = 0;
        // Eight lines per iteration: copy the line, then append its newline.
        while i + 8 <= lines {
            unsafe {
                let src = temp.as_ptr().add(i * wrap_col);
                let dst = output.as_mut_ptr().add(wp);
                std::ptr::copy_nonoverlapping(src, dst, wrap_col);
                *dst.add(wrap_col) = b'\n';
                std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
                *dst.add(line_out + wrap_col) = b'\n';
                std::ptr::copy_nonoverlapping(
                    src.add(2 * wrap_col),
                    dst.add(2 * line_out),
                    wrap_col,
                );
                *dst.add(2 * line_out + wrap_col) = b'\n';
                std::ptr::copy_nonoverlapping(
                    src.add(3 * wrap_col),
                    dst.add(3 * line_out),
                    wrap_col,
                );
                *dst.add(3 * line_out + wrap_col) = b'\n';
                std::ptr::copy_nonoverlapping(
                    src.add(4 * wrap_col),
                    dst.add(4 * line_out),
                    wrap_col,
                );
                *dst.add(4 * line_out + wrap_col) = b'\n';
                std::ptr::copy_nonoverlapping(
                    src.add(5 * wrap_col),
                    dst.add(5 * line_out),
                    wrap_col,
                );
                *dst.add(5 * line_out + wrap_col) = b'\n';
                std::ptr::copy_nonoverlapping(
                    src.add(6 * wrap_col),
                    dst.add(6 * line_out),
                    wrap_col,
                );
                *dst.add(6 * line_out + wrap_col) = b'\n';
                std::ptr::copy_nonoverlapping(
                    src.add(7 * wrap_col),
                    dst.add(7 * line_out),
                    wrap_col,
                );
                *dst.add(7 * line_out + wrap_col) = b'\n';
            }
            wp += 8 * line_out;
            i += 8;
        }
        // Leftover full lines of this group.
        while i < lines {
            unsafe {
                std::ptr::copy_nonoverlapping(
                    temp.as_ptr().add(i * wrap_col),
                    output.as_mut_ptr().add(wp),
                    wrap_col,
                );
                *output.as_mut_ptr().add(wp + wrap_col) = b'\n';
            }
            wp += line_out;
            i += 1;
        }
        // Partial line; under the preconditions above this only happens in the last group.
        if chunk_rem > 0 {
            unsafe {
                std::ptr::copy_nonoverlapping(
                    temp.as_ptr().add(lines * wrap_col),
                    output.as_mut_ptr().add(wp),
                    chunk_rem,
                );
                *output.as_mut_ptr().add(wp + chunk_rem) = b'\n';
            }
            wp += chunk_rem + 1;
        }
    }
}
/// Copies `encoded` into `out_buf`, inserting `\n` after every `wrap_col` bytes and after a
/// trailing partial line. Returns the number of bytes written.
///
/// # Panics
/// Panics if `wrap_col == 0`, or if `out_buf` is shorter than the wrapped length
/// (`encoded.len()` plus one newline per started line). These two checks make the
/// unchecked pointer copies below sound; without them a zero width looped forever and a
/// short buffer was written out of bounds.
#[inline]
fn fuse_wrap(encoded: &[u8], wrap_col: usize, out_buf: &mut [u8]) -> usize {
    assert!(wrap_col > 0, "fuse_wrap: wrap_col must be non-zero");
    let needed = encoded.len() + encoded.len().div_ceil(wrap_col);
    assert!(out_buf.len() >= needed, "fuse_wrap: output buffer too small");
    let line_out = wrap_col + 1;
    let mut rp = 0;
    let mut wp = 0;
    // SAFETY (all raw copies below): every full line read satisfies
    // `rp + wrap_col <= encoded.len()`, and every write ends at or before `needed`, which
    // was checked against `out_buf.len()` above.
    while rp + 8 * wrap_col <= encoded.len() {
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = out_buf.as_mut_ptr().add(wp);
            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(4 * wrap_col), dst.add(4 * line_out), wrap_col);
            *dst.add(4 * line_out + wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(5 * wrap_col), dst.add(5 * line_out), wrap_col);
            *dst.add(5 * line_out + wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(6 * wrap_col), dst.add(6 * line_out), wrap_col);
            *dst.add(6 * line_out + wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(7 * wrap_col), dst.add(7 * line_out), wrap_col);
            *dst.add(7 * line_out + wrap_col) = b'\n';
        }
        rp += 8 * wrap_col;
        wp += 8 * line_out;
    }
    while rp + 4 * wrap_col <= encoded.len() {
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = out_buf.as_mut_ptr().add(wp);
            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';
            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';
        }
        rp += 4 * wrap_col;
        wp += 4 * line_out;
    }
    while rp + wrap_col <= encoded.len() {
        unsafe {
            std::ptr::copy_nonoverlapping(
                encoded.as_ptr().add(rp),
                out_buf.as_mut_ptr().add(wp),
                wrap_col,
            );
            *out_buf.as_mut_ptr().add(wp + wrap_col) = b'\n';
        }
        rp += wrap_col;
        wp += line_out;
    }
    // Trailing partial line.
    if rp < encoded.len() {
        let remaining = encoded.len() - rp;
        out_buf[wp..wp + remaining].copy_from_slice(&encoded[rp..]);
        wp += remaining;
        out_buf[wp] = b'\n';
        wp += 1;
    }
    wp
}
/// Encodes `data` with a newline after every `wrap_col` output characters. This is the path
/// for widths too narrow to hold a 3-byte group per line (`wrap_col * 3 / 4 == 0`, in
/// practice `wrap_col == 1` when reached from `encode_wrapped`).
///
/// The wrapped text is assembled in memory and emitted with one `write_all`, rather than
/// two writes per output line (i.e. two per character at width 1). A width of 0 is treated
/// as 1.
fn encode_wrapped_small(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    let enc_max = BASE64_ENGINE.encoded_length(data.len());
    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(enc_max);
    }
    let encoded = BASE64_ENGINE.encode(data, buf[..enc_max].as_out());
    let width = wrap_col.max(1);
    let mut wrapped: Vec<u8> =
        Vec::with_capacity(encoded.len() + encoded.len().div_ceil(width));
    for line in encoded.chunks(width) {
        wrapped.extend_from_slice(line);
        wrapped.push(b'\n');
    }
    out.write_all(&wrapped)
}
/// Decodes base64 `data` and writes the decoded bytes to `out`. Empty input writes nothing.
///
/// With `ignore_garbage`, the input is first filtered through `strip_non_base64`. Otherwise
/// inputs of 77 bytes up to (but excluding) 512 KiB first try the line-structured fast
/// path. Everything else goes through the whitespace-stripping decoder.
pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    if ignore_garbage {
        let mut filtered = strip_non_base64(data);
        return decode_clean_slice(&mut filtered, out);
    }
    let line_path_eligible = (77..512 * 1024).contains(&data.len());
    if line_path_eligible {
        if let Some(result) = try_line_decode(data, out) {
            return result;
        }
    }
    decode_stripping_whitespace(data, out)
}
/// Decodes base64 held in a mutable buffer (e.g. a private memory map) and writes the
/// decoded bytes to `out`. Unlike `decode_to_writer`, the input is cleaned in place, so
/// the contents of `data` are unspecified afterwards.
///
/// Order of attempts:
/// 1. Without `ignore_garbage` and for 77 B to < 512 KiB inputs: the line-structured fast
///    path (`try_line_decode`).
/// 2. With `ignore_garbage`: keep only bytes accepted by `is_base64_char`, then decode.
/// 3. For inputs of at least 77 bytes: the uniform-line fast path.
/// 4. No `\n`/`\r` at all: decode directly, or after dropping the remaining whitespace.
/// 5. Otherwise: compact out `\n`/`\r` with `memchr`, then (only if needed) a second pass
///    for space/tab/VT/FF. Large results go to the parallel decoder.
pub fn decode_mmap_inplace(
    data: &mut [u8],
    ignore_garbage: bool,
    out: &mut impl Write,
) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    if !ignore_garbage && data.len() >= 77 && data.len() < 512 * 1024 {
        if let Some(result) = try_line_decode(data, out) {
            return result;
        }
    }
    if ignore_garbage {
        // In-place filter: the write cursor `wp` never passes the read cursor `rp`.
        let ptr = data.as_mut_ptr();
        let len = data.len();
        let mut wp = 0;
        for rp in 0..len {
            let b = unsafe { *ptr.add(rp) };
            if is_base64_char(b) {
                unsafe { *ptr.add(wp) = b };
                wp += 1;
            }
        }
        let r = decode_inplace_with_padding(&mut data[..wp], out);
        return r;
    }
    if data.len() >= 77 {
        if let Some(result) = try_decode_uniform_lines(data, out) {
            return result;
        }
    }
    // Single-line input: only the "rare" whitespace (space, tab, VT, FF) can be present.
    if memchr::memchr2(b'\n', b'\r', data).is_none() {
        if !data
            .iter()
            .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c)
        {
            return decode_inplace_with_padding(data, out);
        }
        let ptr = data.as_mut_ptr();
        let len = data.len();
        let mut wp = 0;
        for rp in 0..len {
            let b = unsafe { *ptr.add(rp) };
            if NOT_WHITESPACE[b as usize] {
                unsafe { *ptr.add(wp) = b };
                wp += 1;
            }
        }
        return decode_inplace_with_padding(&mut data[..wp], out);
    }
    // Compact the runs between `\n`/`\r` bytes towards the front of `data`, noting whether
    // any rare whitespace was seen so the second pass can be skipped when it wasn't.
    // NOTE(review): the writes through `ptr` happen while `memchr2_iter` still borrows
    // `data`; only already-scanned bytes are written, but this is worth checking under Miri.
    let ptr = data.as_mut_ptr();
    let len = data.len();
    let mut wp = 0usize;
    let mut gap_start = 0usize;
    let mut has_rare_ws = false;
    for pos in memchr::memchr2_iter(b'\n', b'\r', data) {
        let gap_len = pos - gap_start;
        if gap_len > 0 {
            if !has_rare_ws {
                has_rare_ws = unsafe {
                    std::slice::from_raw_parts(ptr.add(gap_start), gap_len)
                        .iter()
                        .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c)
                };
            }
            if wp != gap_start {
                unsafe { std::ptr::copy(ptr.add(gap_start), ptr.add(wp), gap_len) };
            }
            wp += gap_len;
        }
        gap_start = pos + 1;
    }
    // Bytes after the last `\n`/`\r`.
    let tail_len = len - gap_start;
    if tail_len > 0 {
        if !has_rare_ws {
            has_rare_ws = unsafe {
                std::slice::from_raw_parts(ptr.add(gap_start), tail_len)
                    .iter()
                    .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c)
            };
        }
        if wp != gap_start {
            unsafe { std::ptr::copy(ptr.add(gap_start), ptr.add(wp), tail_len) };
        }
        wp += tail_len;
    }
    // Second compaction pass, only when rare whitespace was seen.
    if has_rare_ws {
        let mut rp = 0;
        let mut cwp = 0;
        while rp < wp {
            let b = unsafe { *ptr.add(rp) };
            if NOT_WHITESPACE[b as usize] {
                unsafe { *ptr.add(cwp) = b };
                cwp += 1;
            }
            rp += 1;
        }
        wp = cwp;
    }
    if wp >= PARALLEL_DECODE_THRESHOLD {
        return decode_borrowed_clean_parallel(out, &data[..wp]);
    }
    decode_inplace_with_padding(&mut data[..wp], out)
}
/// Decodes base64 held in an owned buffer, cleaning the buffer in place first.
///
/// With `ignore_garbage`, only bytes accepted by `is_base64_char` are kept. Otherwise
/// whitespace is removed via `strip_whitespace_inplace`. Empty input writes nothing.
pub fn decode_owned(
    data: &mut Vec<u8>,
    ignore_garbage: bool,
    out: &mut impl Write,
) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    if !ignore_garbage {
        strip_whitespace_inplace(data);
    } else {
        data.retain(|&b| is_base64_char(b));
    }
    decode_clean_slice(data, out)
}
/// Removes every whitespace byte (per `NOT_WHITESPACE`) from `data`, preserving the order
/// of the remaining bytes.
///
/// `\n`/`\r` runs are located with `memchr` and the segments between them are slid down
/// with `copy_within`. The slower byte-by-byte filter for space/tab/VT/FF runs only if
/// such a byte was actually seen.
fn strip_whitespace_inplace(data: &mut Vec<u8>) {
    let is_rare_ws = |b: &u8| matches!(*b, b' ' | b'\t' | 0x0b | 0x0c);
    if memchr::memchr2(b'\n', b'\r', &data[..]).is_none() {
        if data.iter().any(is_rare_ws) {
            data.retain(|&b| NOT_WHITESPACE[b as usize]);
        }
        return;
    }
    let total = data.len();
    let mut write_at = 0usize;
    let mut read_at = 0usize;
    let mut saw_rare_ws = false;
    loop {
        let segment_end = memchr::memchr2(b'\n', b'\r', &data[read_at..])
            .map_or(total, |offset| read_at + offset);
        if segment_end > read_at {
            saw_rare_ws = saw_rare_ws || data[read_at..segment_end].iter().any(is_rare_ws);
            if write_at != read_at {
                data.copy_within(read_at..segment_end, write_at);
            }
            write_at += segment_end - read_at;
        }
        if segment_end == total {
            break;
        }
        read_at = segment_end + 1;
    }
    data.truncate(write_at);
    if saw_rare_ws {
        data.retain(|&b| NOT_WHITESPACE[b as usize]);
    }
}
/// Byte lookup table: `NOT_WHITESPACE[b]` is `false` exactly for the whitespace bytes that
/// decoding strips (space, `\t`, `\n`, `\r`, vertical tab 0x0b, form feed 0x0c), and
/// `true` for every other byte.
static NOT_WHITESPACE: [bool; 256] = {
    const WHITESPACE: [u8; 6] = [b' ', b'\t', b'\n', b'\r', 0x0b, 0x0c];
    let mut table = [true; 256];
    let mut i = 0;
    while i < WHITESPACE.len() {
        table[WHITESPACE[i] as usize] = false;
        i += 1;
    }
    table
};
fn try_decode_uniform_lines(data: &[u8], out: &mut impl Write) -> Option<io::Result<()>> {
let first_nl = memchr::memchr(b'\n', data)?;
let line_len = first_nl;
if line_len == 0 || line_len % 4 != 0 {
return None;
}
let stride = line_len + 1;
let check_lines = 4.min(data.len() / stride);
for i in 1..check_lines {
let expected_nl = i * stride - 1;
if expected_nl >= data.len() || data[expected_nl] != b'\n' {
return None;
}
}
let full_lines = if data.len() >= stride {
let candidate = data.len() / stride;
if candidate > 0 && data[candidate * stride - 1] != b'\n' {
return None;
}
candidate
} else {
0
};
let remainder_start = full_lines * stride;
let remainder = &data[remainder_start..];
let rem_clean = if remainder.last() == Some(&b'\n') {
&remainder[..remainder.len() - 1]
} else {
remainder
};
let decoded_per_line = line_len * 3 / 4;
let rem_decoded_size = if rem_clean.is_empty() {
0
} else {
let pad = rem_clean
.iter()
.rev()
.take(2)
.filter(|&&b| b == b'=')
.count();
rem_clean.len() * 3 / 4 - pad
};
let total_decoded = full_lines * decoded_per_line + rem_decoded_size;
let clean_len = full_lines * line_len;
if clean_len >= PARALLEL_DECODE_THRESHOLD && num_cpus() > 1 {
let mut output: Vec<u8> = Vec::with_capacity(total_decoded);
#[allow(clippy::uninit_vec)]
unsafe {
output.set_len(total_decoded);
}
#[cfg(target_os = "linux")]
hint_hugepage(&mut output);
let out_ptr = output.as_mut_ptr() as usize;
let src_ptr = data.as_ptr() as usize;
let num_threads = num_cpus().max(1);
let lines_per_thread = (full_lines + num_threads - 1) / num_threads;
let lines_per_sub = (512 * 1024 / line_len).max(1);
let err_flag = std::sync::atomic::AtomicBool::new(false);
rayon::scope(|s| {
for t in 0..num_threads {
let err_flag = &err_flag;
s.spawn(move |_| {
let start_line = t * lines_per_thread;
if start_line >= full_lines {
return;
}
let end_line = (start_line + lines_per_thread).min(full_lines);
let chunk_lines = end_line - start_line;
let sub_buf_size = lines_per_sub.min(chunk_lines) * line_len;
let mut local_buf: Vec<u8> = Vec::with_capacity(sub_buf_size);
#[allow(clippy::uninit_vec)]
unsafe {
local_buf.set_len(sub_buf_size);
}
let src = src_ptr as *const u8;
let out_base = out_ptr as *mut u8;
let local_dst = local_buf.as_mut_ptr();
let mut sub_start = 0usize;
while sub_start < chunk_lines {
if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
return;
}
let sub_count = (chunk_lines - sub_start).min(lines_per_sub);
let sub_clean = sub_count * line_len;
for i in 0..sub_count {
unsafe {
std::ptr::copy_nonoverlapping(
src.add((start_line + sub_start + i) * stride),
local_dst.add(i * line_len),
line_len,
);
}
}
let out_offset = (start_line + sub_start) * decoded_per_line;
let out_size = sub_count * decoded_per_line;
let out_slice = unsafe {
std::slice::from_raw_parts_mut(out_base.add(out_offset), out_size)
};
if BASE64_ENGINE
.decode(&local_buf[..sub_clean], out_slice.as_out())
.is_err()
{
err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
return;
}
sub_start += sub_count;
}
});
}
});
let result: Result<(), io::Error> = if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input"))
} else {
Ok(())
};
if let Err(e) = result {
return Some(Err(e));
}
if !rem_clean.is_empty() {
let rem_out = &mut output[full_lines * decoded_per_line..total_decoded];
match BASE64_ENGINE.decode(rem_clean, rem_out.as_out()) {
Ok(_) => {}
Err(_) => return Some(decode_error()),
}
}
return Some(out.write_all(&output[..total_decoded]));
}
let lines_per_sub = (256 * 1024 / line_len).max(1);
let sub_buf_size = lines_per_sub * line_len;
let mut local_buf: Vec<u8> = Vec::with_capacity(sub_buf_size);
#[allow(clippy::uninit_vec)]
unsafe {
local_buf.set_len(sub_buf_size);
}
let src = data.as_ptr();
let local_dst = local_buf.as_mut_ptr();
let mut line_idx = 0usize;
while line_idx < full_lines {
let sub_count = (full_lines - line_idx).min(lines_per_sub);
let sub_clean = sub_count * line_len;
for i in 0..sub_count {
unsafe {
std::ptr::copy_nonoverlapping(
src.add((line_idx + i) * stride),
local_dst.add(i * line_len),
line_len,
);
}
}
match BASE64_ENGINE.decode_inplace(&mut local_buf[..sub_clean]) {
Ok(decoded) => {
if let Err(e) = out.write_all(decoded) {
return Some(Err(e));
}
}
Err(_) => return Some(decode_error()),
}
line_idx += sub_count;
}
if !rem_clean.is_empty() {
let mut rem_buf = rem_clean.to_vec();
match BASE64_ENGINE.decode_inplace(&mut rem_buf) {
Ok(decoded) => {
if let Err(e) = out.write_all(decoded) {
return Some(Err(e));
}
}
Err(_) => return Some(decode_error()),
}
}
Some(Ok(()))
}
/// Decodes base64 that may contain whitespace anywhere, writing the result to `out`.
///
/// Inputs of at least 77 bytes first try the uniform-line fast path. Otherwise the input is
/// copied into a clean buffer without `\n`/`\r`. Space, tab, VT and FF are filtered in a
/// second pass only when one was seen. Large clean buffers go to the parallel decoder.
fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    if data.len() >= 77 {
        if let Some(result) = try_decode_uniform_lines(data, out) {
            return result;
        }
    }
    let is_rare_ws = |b: &u8| matches!(*b, b' ' | b'\t' | 0x0b | 0x0c);
    if memchr::memchr2(b'\n', b'\r', data).is_none() {
        if !data.iter().any(is_rare_ws) {
            return decode_borrowed_clean(out, data);
        }
        let mut cleaned: Vec<u8> = Vec::with_capacity(data.len());
        cleaned.extend(data.iter().copied().filter(|&b| NOT_WHITESPACE[b as usize]));
        return decode_clean_slice(&mut cleaned, out);
    }
    let mut clean: Vec<u8> = Vec::with_capacity(data.len());
    let mut saw_rare_ws = false;
    let mut segment_start = 0usize;
    // Segment boundaries are every `\n`/`\r`, plus the end of the input.
    let segment_ends =
        memchr::memchr2_iter(b'\n', b'\r', data).chain(std::iter::once(data.len()));
    for segment_end in segment_ends {
        let segment = &data[segment_start..segment_end];
        saw_rare_ws = saw_rare_ws || segment.iter().any(is_rare_ws);
        clean.extend_from_slice(segment);
        segment_start = segment_end + 1;
    }
    if saw_rare_ws {
        clean.retain(|&b| NOT_WHITESPACE[b as usize]);
    }
    if clean.len() >= PARALLEL_DECODE_THRESHOLD {
        decode_borrowed_clean_parallel(out, &clean)
    } else {
        decode_clean_slice(&mut clean, out)
    }
}
fn try_line_decode(data: &[u8], out: &mut impl Write) -> Option<io::Result<()>> {
let first_nl = memchr::memchr(b'\n', data)?;
let line_len = first_nl;
if line_len == 0 || line_len % 4 != 0 {
return None;
}
let line_stride = line_len + 1; let decoded_per_line = line_len * 3 / 4;
let check_lines = 4.min(data.len() / line_stride);
for i in 1..check_lines {
let expected_nl = i * line_stride - 1;
if expected_nl >= data.len() {
break;
}
if data[expected_nl] != b'\n' {
return None; }
}
let full_lines = if data.len() >= line_stride {
let candidate = data.len() / line_stride;
if candidate > 0 && data[candidate * line_stride - 1] != b'\n' {
return None; }
candidate
} else {
0
};
let remainder_start = full_lines * line_stride;
let remainder = &data[remainder_start..];
let remainder_clean_len = if remainder.is_empty() {
0
} else {
let rem = if remainder.last() == Some(&b'\n') {
&remainder[..remainder.len() - 1]
} else {
remainder
};
if rem.is_empty() {
0
} else {
let pad = rem.iter().rev().take(2).filter(|&&b| b == b'=').count();
if rem.len() % 4 != 0 {
return None; }
rem.len() * 3 / 4 - pad
}
};
let total_decoded = full_lines * decoded_per_line + remainder_clean_len;
let mut out_buf: Vec<u8> = Vec::with_capacity(total_decoded);
#[allow(clippy::uninit_vec)]
unsafe {
out_buf.set_len(total_decoded);
}
let dst = out_buf.as_mut_ptr();
if data.len() >= PARALLEL_DECODE_THRESHOLD && full_lines >= 64 {
let out_addr = dst as usize;
let num_threads = num_cpus().max(1);
let lines_per_chunk = (full_lines / num_threads).max(1);
let mut tasks: Vec<(usize, usize)> = Vec::new();
let mut line_off = 0;
while line_off < full_lines {
let end = (line_off + lines_per_chunk).min(full_lines);
tasks.push((line_off, end));
line_off = end;
}
let decode_err = std::sync::atomic::AtomicBool::new(false);
rayon::scope(|s| {
for &(start_line, end_line) in &tasks {
let decode_err = &decode_err;
s.spawn(move |_| {
let out_ptr = out_addr as *mut u8;
let mut i = start_line;
while i + 4 <= end_line {
if decode_err.load(std::sync::atomic::Ordering::Relaxed) {
return;
}
let in_base = i * line_stride;
let ob = i * decoded_per_line;
unsafe {
let s0 =
std::slice::from_raw_parts_mut(out_ptr.add(ob), decoded_per_line);
if BASE64_ENGINE
.decode(&data[in_base..in_base + line_len], s0.as_out())
.is_err()
{
decode_err.store(true, std::sync::atomic::Ordering::Relaxed);
return;
}
let s1 = std::slice::from_raw_parts_mut(
out_ptr.add(ob + decoded_per_line),
decoded_per_line,
);
if BASE64_ENGINE
.decode(
&data[in_base + line_stride..in_base + line_stride + line_len],
s1.as_out(),
)
.is_err()
{
decode_err.store(true, std::sync::atomic::Ordering::Relaxed);
return;
}
let s2 = std::slice::from_raw_parts_mut(
out_ptr.add(ob + 2 * decoded_per_line),
decoded_per_line,
);
if BASE64_ENGINE
.decode(
&data[in_base + 2 * line_stride
..in_base + 2 * line_stride + line_len],
s2.as_out(),
)
.is_err()
{
decode_err.store(true, std::sync::atomic::Ordering::Relaxed);
return;
}
let s3 = std::slice::from_raw_parts_mut(
out_ptr.add(ob + 3 * decoded_per_line),
decoded_per_line,
);
if BASE64_ENGINE
.decode(
&data[in_base + 3 * line_stride
..in_base + 3 * line_stride + line_len],
s3.as_out(),
)
.is_err()
{
decode_err.store(true, std::sync::atomic::Ordering::Relaxed);
return;
}
}
i += 4;
}
while i < end_line {
if decode_err.load(std::sync::atomic::Ordering::Relaxed) {
return;
}
let in_start = i * line_stride;
let out_off = i * decoded_per_line;
let out_slice = unsafe {
std::slice::from_raw_parts_mut(out_ptr.add(out_off), decoded_per_line)
};
if BASE64_ENGINE
.decode(&data[in_start..in_start + line_len], out_slice.as_out())
.is_err()
{
decode_err.store(true, std::sync::atomic::Ordering::Relaxed);
return;
}
i += 1;
}
});
}
});
if decode_err.load(std::sync::atomic::Ordering::Relaxed) {
return Some(decode_error());
}
} else {
let mut i = 0;
while i + 4 <= full_lines {
let in_base = i * line_stride;
let out_base = i * decoded_per_line;
unsafe {
let s0 = std::slice::from_raw_parts_mut(dst.add(out_base), decoded_per_line);
if BASE64_ENGINE
.decode(&data[in_base..in_base + line_len], s0.as_out())
.is_err()
{
return Some(decode_error());
}
let s1 = std::slice::from_raw_parts_mut(
dst.add(out_base + decoded_per_line),
decoded_per_line,
);
if BASE64_ENGINE
.decode(
&data[in_base + line_stride..in_base + line_stride + line_len],
s1.as_out(),
)
.is_err()
{
return Some(decode_error());
}
let s2 = std::slice::from_raw_parts_mut(
dst.add(out_base + 2 * decoded_per_line),
decoded_per_line,
);
if BASE64_ENGINE
.decode(
&data[in_base + 2 * line_stride..in_base + 2 * line_stride + line_len],
s2.as_out(),
)
.is_err()
{
return Some(decode_error());
}
let s3 = std::slice::from_raw_parts_mut(
dst.add(out_base + 3 * decoded_per_line),
decoded_per_line,
);
if BASE64_ENGINE
.decode(
&data[in_base + 3 * line_stride..in_base + 3 * line_stride + line_len],
s3.as_out(),
)
.is_err()
{
return Some(decode_error());
}
}
i += 4;
}
while i < full_lines {
let in_start = i * line_stride;
let in_end = in_start + line_len;
let out_off = i * decoded_per_line;
let out_slice =
unsafe { std::slice::from_raw_parts_mut(dst.add(out_off), decoded_per_line) };
match BASE64_ENGINE.decode(&data[in_start..in_end], out_slice.as_out()) {
Ok(_) => {}
Err(_) => return Some(decode_error()),
}
i += 1;
}
}
if remainder_clean_len > 0 {
let rem = if remainder.last() == Some(&b'\n') {
&remainder[..remainder.len() - 1]
} else {
remainder
};
let out_off = full_lines * decoded_per_line;
let out_slice =
unsafe { std::slice::from_raw_parts_mut(dst.add(out_off), remainder_clean_len) };
match BASE64_ENGINE.decode(rem, out_slice.as_out()) {
Ok(_) => {}
Err(_) => return Some(decode_error()),
}
}
Some(out.write_all(&out_buf[..total_decoded]))
}
/// Decodes `data`, base64 with all whitespace already removed, and writes the result
/// to `out`.
///
/// An empty slice is a no-op. Everything else is delegated to
/// `decode_inplace_with_padding`, which may overwrite `data`.
fn decode_clean_slice(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        Ok(())
    } else {
        decode_inplace_with_padding(data, out)
    }
}
/// Builds the error reported for undecodable input: `InvalidData` with the message
/// "invalid input".
///
/// It is kept out of line and marked cold so that the many error branches calling it
/// stay off the hot path.
#[cold]
#[inline(never)]
fn decode_error() -> io::Result<()> {
    let err = io::Error::new(io::ErrorKind::InvalidData, "invalid input");
    Err(err)
}
/// Decodes `data` in place and writes the decoded bytes to `out`.
///
/// If the first attempt fails and `data.len() % 4` is 2 or 3, a copy with the
/// missing `=` characters appended is decoded instead. When that succeeds, the
/// decoded bytes are written. However, if `data` already contained a `=`, the call
/// still returns "invalid input" afterwards. Every other failure returns
/// "invalid input" without writing anything.
fn decode_inplace_with_padding(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
    if let Ok(decoded) = BASE64_ENGINE.decode_inplace(data) {
        return out.write_all(decoded);
    }
    let missing = match data.len() % 4 {
        2 => 2,
        3 => 1,
        _ => return decode_error(),
    };
    // NOTE(review): the retry copies `data` as it is after the failed attempt. This
    // relies on `decode_inplace` leaving `data` unmodified when it rejects an input
    // whose length is not a multiple of 4; confirm against base64_simd.
    let had_padding = memchr::memchr(b'=', data).is_some();
    let mut padded = Vec::with_capacity(data.len() + missing);
    padded.extend_from_slice(data);
    padded.resize(data.len() + missing, b'=');
    let Ok(decoded) = BASE64_ENGINE.decode_inplace(&mut padded) else {
        return decode_error();
    };
    out.write_all(decoded)?;
    if had_padding {
        decode_error()
    } else {
        Ok(())
    }
}
/// Decodes `data`, base64 containing no whitespace, and writes the result to `out`.
///
/// Inputs of at least `PARALLEL_DECODE_THRESHOLD` bytes are delegated to
/// `decode_borrowed_clean_parallel`. For smaller inputs whose length is 2 or 3 mod 4,
/// the missing `=` padding is appended and decoding is retried. If `data` already
/// contained a `=`, the decoded bytes are still written, but "invalid input" is
/// returned afterwards.
///
/// # Errors
/// Returns `InvalidData` ("invalid input") for undecodable input, and any write error
/// from `out`.
fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    if data.len() >= PARALLEL_DECODE_THRESHOLD {
        return decode_borrowed_clean_parallel(out, data);
    }
    let remainder = data.len() % 4;
    if remainder == 2 || remainder == 3 {
        let has_existing_padding = memchr::memchr(b'=', data).is_some();
        let mut padded = Vec::with_capacity(data.len() + (4 - remainder));
        padded.extend_from_slice(data);
        padded.extend(std::iter::repeat_n(b'=', 4 - remainder));
        // `padded.len()` is now a multiple of 4, so this recursion does not pad again.
        let result = decode_borrowed_clean(out, &padded);
        if has_existing_padding && result.is_ok() {
            return decode_error();
        }
        return result;
    }
    let pad = data.iter().rev().take(2).filter(|&&b| b == b'=').count();
    // Saturating subtraction: for a lone "=" the estimate `len * 3 / 4` is 0, and a
    // plain subtraction would underflow and panic. Such an input (length not a multiple
    // of 4) is expected to be rejected by the decoder below.
    let decoded_size = (data.len() * 3 / 4).saturating_sub(pad);
    let mut buf: Vec<u8> = Vec::with_capacity(decoded_size);
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(decoded_size);
    }
    match BASE64_ENGINE.decode(data, buf[..decoded_size].as_out()) {
        Ok(decoded) => out.write_all(decoded),
        Err(_) => decode_error(),
    }
}
fn decode_borrowed_clean_parallel(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
let num_threads = num_cpus().max(1);
let raw_chunk = data.len() / num_threads;
let chunk_size = ((raw_chunk + 3) / 4) * 4;
let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(4)).collect();
let mut offsets: Vec<usize> = Vec::with_capacity(chunks.len() + 1);
offsets.push(0);
let mut total_decoded = 0usize;
for (i, chunk) in chunks.iter().enumerate() {
let decoded_size = if i == chunks.len() - 1 {
let pad = chunk.iter().rev().take(2).filter(|&&b| b == b'=').count();
chunk.len() * 3 / 4 - pad
} else {
chunk.len() * 3 / 4
};
total_decoded += decoded_size;
offsets.push(total_decoded);
}
let mut output_buf: Vec<u8> = Vec::with_capacity(total_decoded);
#[allow(clippy::uninit_vec)]
unsafe {
output_buf.set_len(total_decoded);
}
#[cfg(target_os = "linux")]
hint_hugepage(&mut output_buf);
let out_addr = output_buf.as_mut_ptr() as usize;
let err_flag = std::sync::atomic::AtomicBool::new(false);
rayon::scope(|s| {
for (i, chunk) in chunks.iter().enumerate() {
let offset = offsets[i];
let expected_size = offsets[i + 1] - offset;
let err_flag = &err_flag;
s.spawn(move |_| {
if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
return;
}
let out_slice = unsafe {
std::slice::from_raw_parts_mut((out_addr as *mut u8).add(offset), expected_size)
};
if BASE64_ENGINE.decode(chunk, out_slice.as_out()).is_err() {
err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
}
});
}
});
if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input"));
}
out.write_all(&output_buf[..total_decoded])
}
/// Returns a copy of `data` that keeps only bytes accepted by `is_base64_char`:
/// ASCII letters, digits, `+`, `/` and `=`.
///
/// The original order of those bytes is preserved.
fn strip_non_base64(data: &[u8]) -> Vec<u8> {
    let mut kept = Vec::new();
    for &byte in data {
        if is_base64_char(byte) {
            kept.push(byte);
        }
    }
    kept
}
/// Reports whether `b` belongs to the standard base64 alphabet or is the `=` padding
/// character.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
/// Encodes everything read from `reader` as base64 and writes it to `writer`.
///
/// A `wrap_col` of 0 produces a single unwrapped line. Any other value inserts a
/// newline after every `wrap_col` output characters.
pub fn encode_stream(
    reader: &mut impl Read,
    wrap_col: usize,
    writer: &mut impl Write,
) -> io::Result<()> {
    match wrap_col {
        0 => encode_stream_nowrap(reader, writer),
        _ => encode_stream_wrapped(reader, wrap_col, writer),
    }
}
/// Streams `reader` to `writer` as one unbroken run of base64; no newline is added.
///
/// Input is read in 24 MiB blocks. 24 MiB is a multiple of 3, and `read_full` returns
/// a short block only at end of input. So `=` padding can appear only in the output
/// of the last block.
fn encode_stream_nowrap(reader: &mut impl Read, writer: &mut impl Write) -> io::Result<()> {
    const NOWRAP_READ: usize = 24 * 1024 * 1024;

    /// Allocates a `len`-byte buffer without initializing it. Both buffers below are
    /// only read in ranges that were written first.
    #[allow(clippy::uninit_vec)]
    fn scratch(len: usize) -> Vec<u8> {
        let mut v = Vec::with_capacity(len);
        unsafe { v.set_len(len) };
        v
    }

    let mut input = scratch(NOWRAP_READ);
    let mut encoded = scratch(BASE64_ENGINE.encoded_length(NOWRAP_READ));
    loop {
        let filled = read_full(reader, &mut input)?;
        if filled == 0 {
            return Ok(());
        }
        let text_len = BASE64_ENGINE.encoded_length(filled);
        let text = BASE64_ENGINE.encode(&input[..filled], encoded[..text_len].as_out());
        writer.write_all(text)?;
    }
}
/// Streams `reader` to `writer` as base64 wrapped at `wrap_col` characters per line.
/// The output ends with `\n` when the last line is non-empty.
///
/// When `wrap_col` is a positive multiple of 4, one output line encodes exactly
/// `wrap_col / 4 * 3` input bytes, and `encode_stream_wrapped_fused` encodes line by
/// line.
///
/// Any other width uses a generic path:
/// - Input is read in 12 MiB blocks. That is a multiple of 3, so padding can only
///   occur at the end of input.
/// - Each block is encoded whole and handed to `write_wrapped_iov_streaming` together
///   with the running column `col`.
fn encode_stream_wrapped(
    reader: &mut impl Read,
    wrap_col: usize,
    writer: &mut impl Write,
) -> io::Result<()> {
    // The fused encoder gives every line exactly `wrap_col` bytes of output space.
    // That space is only completely filled when `wrap_col` is a multiple of 4.
    // Checking merely that `wrap_col * 3 / 4` is divisible by 3 also accepts widths
    // such as 5 or 77. Their per-line input encodes to `wrap_col - 1` characters,
    // which would leave one unwritten byte per line in the output.
    if wrap_col > 0 && wrap_col.is_multiple_of(4) {
        let bytes_per_line = wrap_col / 4 * 3;
        return encode_stream_wrapped_fused(reader, wrap_col, bytes_per_line, writer);
    }
    const STREAM_READ: usize = 12 * 1024 * 1024;
    let mut buf: Vec<u8> = Vec::with_capacity(STREAM_READ);
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(STREAM_READ);
    }
    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_READ);
    let mut encode_buf: Vec<u8> = Vec::with_capacity(encode_buf_size);
    #[allow(clippy::uninit_vec)]
    unsafe {
        encode_buf.set_len(encode_buf_size);
    }
    let mut col = 0usize;
    loop {
        let n = read_full(reader, &mut buf)?;
        if n == 0 {
            break;
        }
        let enc_len = BASE64_ENGINE.encoded_length(n);
        let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
        write_wrapped_iov_streaming(encoded, wrap_col, &mut col, writer)?;
    }
    if col > 0 {
        writer.write_all(b"\n")?;
    }
    Ok(())
}
fn encode_stream_wrapped_fused(
reader: &mut impl Read,
wrap_col: usize,
bytes_per_line: usize,
writer: &mut impl Write,
) -> io::Result<()> {
let lines_per_chunk = (12 * 1024 * 1024) / bytes_per_line;
let read_size = lines_per_chunk * bytes_per_line;
let line_out = wrap_col + 1;
let mut buf: Vec<u8> = Vec::with_capacity(read_size);
#[allow(clippy::uninit_vec)]
unsafe {
buf.set_len(read_size);
}
let max_output = lines_per_chunk * line_out + BASE64_ENGINE.encoded_length(bytes_per_line) + 2;
let mut out_buf: Vec<u8> = Vec::with_capacity(max_output);
#[allow(clippy::uninit_vec)]
unsafe {
out_buf.set_len(max_output);
}
loop {
let n = read_full(reader, &mut buf)?;
if n == 0 {
break;
}
let full_lines = n / bytes_per_line;
let remainder = n % bytes_per_line;
let dst = out_buf.as_mut_ptr();
let mut line_idx = 0;
while line_idx + 4 <= full_lines {
let in_base = line_idx * bytes_per_line;
let out_base = line_idx * line_out;
unsafe {
let s0 = std::slice::from_raw_parts_mut(dst.add(out_base), wrap_col);
let _ = BASE64_ENGINE.encode(&buf[in_base..in_base + bytes_per_line], s0.as_out());
*dst.add(out_base + wrap_col) = b'\n';
let s1 = std::slice::from_raw_parts_mut(dst.add(out_base + line_out), wrap_col);
let _ = BASE64_ENGINE.encode(
&buf[in_base + bytes_per_line..in_base + 2 * bytes_per_line],
s1.as_out(),
);
*dst.add(out_base + line_out + wrap_col) = b'\n';
let s2 = std::slice::from_raw_parts_mut(dst.add(out_base + 2 * line_out), wrap_col);
let _ = BASE64_ENGINE.encode(
&buf[in_base + 2 * bytes_per_line..in_base + 3 * bytes_per_line],
s2.as_out(),
);
*dst.add(out_base + 2 * line_out + wrap_col) = b'\n';
let s3 = std::slice::from_raw_parts_mut(dst.add(out_base + 3 * line_out), wrap_col);
let _ = BASE64_ENGINE.encode(
&buf[in_base + 3 * bytes_per_line..in_base + 4 * bytes_per_line],
s3.as_out(),
);
*dst.add(out_base + 3 * line_out + wrap_col) = b'\n';
}
line_idx += 4;
}
while line_idx < full_lines {
let in_base = line_idx * bytes_per_line;
let out_base = line_idx * line_out;
unsafe {
let s = std::slice::from_raw_parts_mut(dst.add(out_base), wrap_col);
let _ = BASE64_ENGINE.encode(&buf[in_base..in_base + bytes_per_line], s.as_out());
*dst.add(out_base + wrap_col) = b'\n';
}
line_idx += 1;
}
let mut wp = full_lines * line_out;
if remainder > 0 {
let enc_len = BASE64_ENGINE.encoded_length(remainder);
let line_input = &buf[full_lines * bytes_per_line..n];
unsafe {
let s = std::slice::from_raw_parts_mut(dst.add(wp), enc_len);
let _ = BASE64_ENGINE.encode(line_input, s.as_out());
*dst.add(wp + enc_len) = b'\n';
}
wp += enc_len + 1;
}
writer.write_all(&out_buf[..wp])?;
}
Ok(())
}
/// Decodes base64 read from `reader` and writes the decoded bytes to `writer`.
///
/// Input is read in chunks of `READ_CHUNK` (32 MiB) bytes. Each chunk is compacted
/// in place before decoding:
/// - With `ignore_garbage`, every byte for which `is_base64_char` is false is dropped.
/// - Otherwise `\n` and `\r` are removed. If the chunk also contains a space, tab,
///   vertical tab (0x0b) or form feed (0x0c), a second pass keeps only the bytes for
///   which `NOT_WHITESPACE` is true. That table is defined elsewhere in this file.
///
/// For a non-final chunk, only whole 4-character groups are decoded. The 0–3 leftover
/// characters are carried to the front of the next chunk. A read shorter than
/// `READ_CHUNK` is treated as the last chunk and decoded in full via
/// `decode_clean_slice` (`read_full` only returns a short count at end of input).
///
/// # Errors
/// Returns any error from `reader` or `writer`, and the error produced by
/// `decode_clean_slice` for undecodable input.
pub fn decode_stream(
    reader: &mut impl Read,
    ignore_garbage: bool,
    writer: &mut impl Write,
) -> io::Result<()> {
    const READ_CHUNK: usize = 32 * 1024 * 1024;
    // 4 spare bytes so the (at most 3) carried characters fit in front of a full
    // READ_CHUNK read.
    let mut buf: Vec<u8> = Vec::with_capacity(READ_CHUNK + 4);
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(READ_CHUNK + 4);
    }
    // Characters from the end of the previous chunk that did not complete a
    // 4-character group.
    let mut carry = [0u8; 4];
    let mut carry_len = 0usize;
    loop {
        // Put the carried characters back at the front of the buffer; the new read
        // lands right after them.
        if carry_len > 0 {
            unsafe {
                std::ptr::copy_nonoverlapping(carry.as_ptr(), buf.as_mut_ptr(), carry_len);
            }
        }
        let n = read_full(reader, &mut buf[carry_len..carry_len + READ_CHUNK])?;
        if n == 0 {
            break;
        }
        let total_raw = carry_len + n;
        let clean_len = if ignore_garbage {
            // Keep only base64 alphabet bytes (including '='), compacting in place.
            // `wp <= i` always, so no unread byte is overwritten.
            let ptr = buf.as_mut_ptr();
            let mut wp = 0usize;
            for i in 0..total_raw {
                let b = unsafe { *ptr.add(i) };
                if is_base64_char(b) {
                    unsafe { *ptr.add(wp) = b };
                    wp += 1;
                }
            }
            wp
        } else {
            // Remove '\n' and '\r' by shifting each run between them to the left.
            // A run can overlap its destination, hence `ptr::copy` rather than
            // `copy_nonoverlapping`.
            let ptr = buf.as_mut_ptr();
            let data = &buf[..total_raw];
            let mut wp = 0usize;
            let mut gap_start = 0usize;
            // Set once any space/tab/VT/FF is seen, so the second pass below runs
            // only when needed.
            let mut has_rare_ws = false;
            for pos in memchr::memchr2_iter(b'\n', b'\r', data) {
                let gap_len = pos - gap_start;
                if gap_len > 0 {
                    if !has_rare_ws {
                        has_rare_ws = data[gap_start..pos]
                            .iter()
                            .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
                    }
                    // Until the first line break is dropped, the run is already in place.
                    if wp != gap_start {
                        unsafe {
                            std::ptr::copy(ptr.add(gap_start), ptr.add(wp), gap_len);
                        }
                    }
                    wp += gap_len;
                }
                gap_start = pos + 1;
            }
            // Bytes after the last '\n'/'\r'.
            let tail_len = total_raw - gap_start;
            if tail_len > 0 {
                if !has_rare_ws {
                    has_rare_ws = data[gap_start..total_raw]
                        .iter()
                        .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
                }
                if wp != gap_start {
                    unsafe {
                        std::ptr::copy(ptr.add(gap_start), ptr.add(wp), tail_len);
                    }
                }
                wp += tail_len;
            }
            // Second, rarely needed pass: drop every byte `NOT_WHITESPACE` rejects.
            if has_rare_ws {
                let mut rp = 0;
                let mut cwp = 0;
                while rp < wp {
                    let b = unsafe { *ptr.add(rp) };
                    if NOT_WHITESPACE[b as usize] {
                        unsafe { *ptr.add(cwp) = b };
                        cwp += 1;
                    }
                    rp += 1;
                }
                cwp
            } else {
                wp
            }
        };
        carry_len = 0;
        // `read_full` only returns a short count at end of input.
        let is_last = n < READ_CHUNK;
        if is_last {
            decode_clean_slice(&mut buf[..clean_len], writer)?;
        } else {
            // Decode complete 4-character groups now; carry the rest into the next
            // iteration.
            let decode_len = (clean_len / 4) * 4;
            let leftover = clean_len - decode_len;
            if leftover > 0 {
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        buf.as_ptr().add(decode_len),
                        carry.as_mut_ptr(),
                        leftover,
                    );
                }
                carry_len = leftover;
            }
            if decode_len > 0 {
                decode_clean_slice(&mut buf[..decode_len], writer)?;
            }
        }
    }
    // Reached when the last chunk filled the buffer exactly and the next read
    // returned 0: decode whatever was carried.
    if carry_len > 0 {
        let mut carry_buf = carry[..carry_len].to_vec();
        decode_clean_slice(&mut carry_buf, writer)?;
    }
    Ok(())
}
/// Writes every byte of `slices` to `out`.
///
/// In the common case the writer accepts everything in a single `write_vectored`
/// call. A partial vectored write is finished by `write_all_vectored_slow`.
///
/// Errors and retries:
/// - A vectored write that reports 0 bytes while data remains fails with
///   `ErrorKind::WriteZero`.
/// - `ErrorKind::Interrupted` is retried, matching `Write::write_all`.
#[inline(always)]
fn write_all_vectored(out: &mut impl Write, slices: &[io::IoSlice]) -> io::Result<()> {
    if slices.is_empty() {
        return Ok(());
    }
    let total: usize = slices.iter().map(|s| s.len()).sum();
    let written = loop {
        match out.write_vectored(slices) {
            Ok(n) => break n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    };
    if written >= total {
        return Ok(());
    }
    if written == 0 {
        return Err(io::Error::new(io::ErrorKind::WriteZero, "write zero"));
    }
    write_all_vectored_slow(out, slices, written)
}
/// Writes the part of `slices` that remains after the first `skip` bytes.
///
/// Slices are treated as one concatenated buffer, and each remaining piece is sent
/// with `write_all`. This completes a partial vectored write.
#[cold]
#[inline(never)]
fn write_all_vectored_slow(
    out: &mut impl Write,
    slices: &[io::IoSlice],
    mut skip: usize,
) -> io::Result<()> {
    for bytes in slices.iter().map(|s| &s[..]) {
        if skip < bytes.len() {
            out.write_all(&bytes[skip..])?;
            skip = 0;
        } else {
            // This slice was already written in full.
            skip -= bytes.len();
        }
    }
    Ok(())
}
/// Reads from `reader` until `buf` is full or the reader reports end of input.
///
/// Returns the number of bytes read. A value smaller than `buf.len()` means `reader`
/// reached end of input. `ErrorKind::Interrupted` is retried on every read, including
/// the first; any other error is returned immediately.
#[inline]
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut total = 0;
    while total < buf.len() {
        match reader.read(&mut buf[total..]) {
            Ok(0) => break,
            Ok(n) => total += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(total)
}