Skip to main content

coreutils_rs/base64/
core.rs

1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4use rayon::prelude::*;
5
6const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
7
/// Input bytes encoded per iteration when wrapping is disabled: the largest
/// multiple of 3 not exceeding 32 MiB, so only the final chunk can need padding.
/// Big chunks keep the number of write calls low for large inputs.
const NOWRAP_CHUNK: usize = (32 * 1024 * 1024 / 3) * 3;

/// Inputs of at least this many bytes (1 MiB) are encoded on multiple threads.
/// Was 4 MiB; lowered so ~10 MiB workloads already benefit from several cores.
const PARALLEL_ENCODE_THRESHOLD: usize = 1 << 20;

/// Clean base64 inputs of at least this many bytes (1 MiB) are decoded on
/// multiple threads. Was 4 MiB; lowered for the same reason as encoding.
const PARALLEL_DECODE_THRESHOLD: usize = 1 << 20;
19
20/// Encode data and write to output with line wrapping.
21/// Uses SIMD encoding with fused encode+wrap for maximum throughput.
22pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
23    if data.is_empty() {
24        return Ok(());
25    }
26
27    if wrap_col == 0 {
28        return encode_no_wrap(data, out);
29    }
30
31    encode_wrapped(data, wrap_col, out)
32}
33
34/// Encode without wrapping — parallel SIMD encoding for large data, sequential for small.
35fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
36    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
37        return encode_no_wrap_parallel(data, out);
38    }
39
40    let actual_chunk = NOWRAP_CHUNK.min(data.len());
41    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
42    // SAFETY: encode() writes exactly enc_len bytes before we read them.
43    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
44    #[allow(clippy::uninit_vec)]
45    unsafe {
46        buf.set_len(enc_max);
47    }
48
49    for chunk in data.chunks(NOWRAP_CHUNK) {
50        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
51        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
52        out.write_all(encoded)?;
53    }
54    Ok(())
55}
56
57/// Parallel no-wrap encoding: split at 3-byte boundaries, encode chunks in parallel.
58/// Each chunk except possibly the last is 3-byte aligned, so no padding in intermediate chunks.
59/// Uses write_vectored (writev) to send all encoded chunks in a single syscall.
60fn encode_no_wrap_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
61    let num_threads = rayon::current_num_threads().max(1);
62    let raw_chunk = data.len() / num_threads;
63    // Align to 3 bytes so each chunk encodes without padding (except the last)
64    let chunk_size = ((raw_chunk + 2) / 3) * 3;
65
66    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(3)).collect();
67    let encoded_chunks: Vec<Vec<u8>> = chunks
68        .par_iter()
69        .map(|chunk| {
70            let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
71            let mut buf: Vec<u8> = Vec::with_capacity(enc_len);
72            #[allow(clippy::uninit_vec)]
73            unsafe {
74                buf.set_len(enc_len);
75            }
76            let _ = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
77            buf
78        })
79        .collect();
80
81    // Use write_vectored to send all chunks in a single syscall
82    let iov: Vec<io::IoSlice> = encoded_chunks.iter().map(|c| io::IoSlice::new(c)).collect();
83    write_all_vectored(out, &iov)
84}
85
86/// Encode with line wrapping — uses writev to interleave encoded segments
87/// with newlines without copying data. For each wrap_col-sized segment of
88/// encoded output, we create an IoSlice pointing directly at the encode buffer,
89/// interleaved with IoSlice entries pointing at a static newline byte.
90fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
91    // Calculate bytes_per_line: input bytes that produce exactly wrap_col encoded chars.
92    // For default wrap_col=76: 76*3/4 = 57 bytes per line.
93    let bytes_per_line = wrap_col * 3 / 4;
94    if bytes_per_line == 0 {
95        // Degenerate case: wrap_col < 4, fall back to byte-at-a-time
96        return encode_wrapped_small(data, wrap_col, out);
97    }
98
99    // Parallel encoding for large data when bytes_per_line is a multiple of 3.
100    // This guarantees each chunk encodes to complete base64 without padding.
101    if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line.is_multiple_of(3) {
102        return encode_wrapped_parallel(data, wrap_col, bytes_per_line, out);
103    }
104
105    // Align input chunk to bytes_per_line for complete output lines.
106    let lines_per_chunk = (32 * 1024 * 1024) / bytes_per_line;
107    let max_input_chunk = (lines_per_chunk * bytes_per_line).max(bytes_per_line);
108    let input_chunk = max_input_chunk.min(data.len());
109
110    let enc_max = BASE64_ENGINE.encoded_length(input_chunk);
111    let mut encode_buf: Vec<u8> = Vec::with_capacity(enc_max);
112    #[allow(clippy::uninit_vec)]
113    unsafe {
114        encode_buf.set_len(enc_max);
115    }
116
117    for chunk in data.chunks(max_input_chunk.max(1)) {
118        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
119        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
120
121        // Use writev: build IoSlice entries pointing at wrap_col-sized segments
122        // of the encoded buffer interleaved with newline IoSlices.
123        // This eliminates the fused_buf copy entirely.
124        write_wrapped_iov(encoded, wrap_col, out)?;
125    }
126
127    Ok(())
128}
129
/// The single `\n` byte that vectored writes reference between output lines.
/// A `static` so every `IoSlice` can borrow it for as long as needed.
static NEWLINE: [u8; 1] = *b"\n";
132
133/// Write encoded base64 data with line wrapping using write_vectored (writev).
134/// Builds IoSlice entries pointing at wrap_col-sized segments of the encoded buffer,
135/// interleaved with newline IoSlices, then writes in batches of MAX_WRITEV_IOV.
136/// This is zero-copy: no fused output buffer needed.
137#[inline]
138fn write_wrapped_iov(encoded: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
139    // Max IoSlice entries per writev batch. Linux UIO_MAXIOV is 1024.
140    // Each line needs 2 entries (data + newline), so 512 lines per batch.
141    const MAX_IOV: usize = 1024;
142
143    let num_full_lines = encoded.len() / wrap_col;
144    let remainder = encoded.len() % wrap_col;
145    let total_iov = num_full_lines * 2 + if remainder > 0 { 2 } else { 0 };
146
147    // Small output: build all IoSlices and write in one call
148    if total_iov <= MAX_IOV {
149        let mut iov: Vec<io::IoSlice> = Vec::with_capacity(total_iov);
150        let mut pos = 0;
151        for _ in 0..num_full_lines {
152            iov.push(io::IoSlice::new(&encoded[pos..pos + wrap_col]));
153            iov.push(io::IoSlice::new(&NEWLINE));
154            pos += wrap_col;
155        }
156        if remainder > 0 {
157            iov.push(io::IoSlice::new(&encoded[pos..pos + remainder]));
158            iov.push(io::IoSlice::new(&NEWLINE));
159        }
160        return write_all_vectored(out, &iov);
161    }
162
163    // Large output: write in batches
164    let mut iov: Vec<io::IoSlice> = Vec::with_capacity(MAX_IOV);
165    let mut pos = 0;
166    for _ in 0..num_full_lines {
167        iov.push(io::IoSlice::new(&encoded[pos..pos + wrap_col]));
168        iov.push(io::IoSlice::new(&NEWLINE));
169        pos += wrap_col;
170        if iov.len() >= MAX_IOV {
171            write_all_vectored(out, &iov)?;
172            iov.clear();
173        }
174    }
175    if remainder > 0 {
176        iov.push(io::IoSlice::new(&encoded[pos..pos + remainder]));
177        iov.push(io::IoSlice::new(&NEWLINE));
178    }
179    if !iov.is_empty() {
180        write_all_vectored(out, &iov)?;
181    }
182    Ok(())
183}
184
185/// Write encoded base64 data with line wrapping using writev, tracking column state
186/// across calls. Used by encode_stream for piped input where chunks don't align
187/// to line boundaries.
188#[inline]
189fn write_wrapped_iov_streaming(
190    encoded: &[u8],
191    wrap_col: usize,
192    col: &mut usize,
193    out: &mut impl Write,
194) -> io::Result<()> {
195    const MAX_IOV: usize = 1024;
196    let mut iov: Vec<io::IoSlice> = Vec::with_capacity(MAX_IOV);
197    let mut rp = 0;
198
199    while rp < encoded.len() {
200        let space = wrap_col - *col;
201        let avail = encoded.len() - rp;
202
203        if avail <= space {
204            // Remaining data fits in current line
205            iov.push(io::IoSlice::new(&encoded[rp..rp + avail]));
206            *col += avail;
207            if *col == wrap_col {
208                iov.push(io::IoSlice::new(&NEWLINE));
209                *col = 0;
210            }
211            break;
212        } else {
213            // Fill current line and add newline
214            iov.push(io::IoSlice::new(&encoded[rp..rp + space]));
215            iov.push(io::IoSlice::new(&NEWLINE));
216            rp += space;
217            *col = 0;
218        }
219
220        if iov.len() >= MAX_IOV - 1 {
221            write_all_vectored(out, &iov)?;
222            iov.clear();
223        }
224    }
225
226    if !iov.is_empty() {
227        write_all_vectored(out, &iov)?;
228    }
229    Ok(())
230}
231
232/// Parallel wrapped encoding: split at bytes_per_line boundaries, encode + wrap in parallel.
233/// Requires bytes_per_line % 3 == 0 so each chunk encodes without intermediate padding.
234/// Uses write_vectored (writev) to send all encoded+wrapped chunks in a single syscall.
235fn encode_wrapped_parallel(
236    data: &[u8],
237    wrap_col: usize,
238    bytes_per_line: usize,
239    out: &mut impl Write,
240) -> io::Result<()> {
241    let num_threads = rayon::current_num_threads().max(1);
242    // Split at bytes_per_line boundaries for complete output lines per chunk
243    let lines_per_chunk = (data.len() / bytes_per_line / num_threads).max(1);
244    let chunk_size = lines_per_chunk * bytes_per_line;
245
246    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(bytes_per_line)).collect();
247    let encoded_chunks: Vec<Vec<u8>> = chunks
248        .par_iter()
249        .map(|chunk| {
250            let enc_max = BASE64_ENGINE.encoded_length(chunk.len());
251            let max_lines = enc_max / wrap_col + 2;
252            // Single allocation with two non-overlapping regions:
253            //   [0..fused_size) = fuse_wrap output region
254            //   [fused_size..fused_size+enc_max) = encode region
255            let fused_size = enc_max + max_lines;
256            let total_size = fused_size + enc_max;
257            let mut buf: Vec<u8> = Vec::with_capacity(total_size);
258            #[allow(clippy::uninit_vec)]
259            unsafe {
260                buf.set_len(total_size);
261            }
262            // Encode into the second region [fused_size..fused_size+enc_max]
263            let _ = BASE64_ENGINE.encode(chunk, buf[fused_size..fused_size + enc_max].as_out());
264            // Use split_at_mut to get non-overlapping mutable/immutable refs
265            let (fused_region, encode_region) = buf.split_at_mut(fused_size);
266            let encoded = &encode_region[..enc_max];
267            let wp = fuse_wrap(encoded, wrap_col, fused_region);
268            buf.truncate(wp);
269            buf
270        })
271        .collect();
272
273    // Use write_vectored to send all chunks in a single syscall
274    let iov: Vec<io::IoSlice> = encoded_chunks.iter().map(|c| io::IoSlice::new(c)).collect();
275    write_all_vectored(out, &iov)
276}
277
/// Copy `encoded` into `out_buf`, inserting a `\n` after every `wrap_col` bytes
/// and after a trailing partial line. Returns the number of bytes written.
///
/// # Panics
///
/// Panics if `wrap_col` is 0 or if `out_buf` is shorter than the wrapped output
/// (`encoded.len()` plus one byte per line). Both are checked up front, before
/// anything is written, so this safe function can never write out of bounds.
#[inline]
fn fuse_wrap(encoded: &[u8], wrap_col: usize, out_buf: &mut [u8]) -> usize {
    assert!(wrap_col > 0, "fuse_wrap: wrap_col must be non-zero");

    // One newline per (possibly partial) line.
    let line_count = encoded.len() / wrap_col + usize::from(encoded.len() % wrap_col != 0);
    let needed = encoded.len() + line_count;
    assert!(
        out_buf.len() >= needed,
        "fuse_wrap: output buffer too small ({} < {})",
        out_buf.len(),
        needed
    );

    let mut wp = 0;
    for line in encoded.chunks(wrap_col) {
        let end = wp + line.len();
        // copy_from_slice on u8 lowers to a memcpy of the line.
        out_buf[wp..end].copy_from_slice(line);
        out_buf[end] = b'\n';
        wp = end + 1;
    }
    wp
}
374
375/// Fallback for very small wrap columns (< 4 chars).
376fn encode_wrapped_small(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
377    let enc_max = BASE64_ENGINE.encoded_length(data.len());
378    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
379    #[allow(clippy::uninit_vec)]
380    unsafe {
381        buf.set_len(enc_max);
382    }
383    let encoded = BASE64_ENGINE.encode(data, buf[..enc_max].as_out());
384
385    let wc = wrap_col.max(1);
386    for line in encoded.chunks(wc) {
387        out.write_all(line)?;
388        out.write_all(b"\n")?;
389    }
390    Ok(())
391}
392
393/// Decode base64 data and write to output (borrows data, allocates clean buffer).
394/// When `ignore_garbage` is true, strip all non-base64 characters.
395/// When false, only strip whitespace (standard behavior).
396pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
397    if data.is_empty() {
398        return Ok(());
399    }
400
401    if ignore_garbage {
402        let mut cleaned = strip_non_base64(data);
403        return decode_clean_slice(&mut cleaned, out);
404    }
405
406    // Fast path: single-pass strip + decode
407    decode_stripping_whitespace(data, out)
408}
409
410/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
411pub fn decode_owned(
412    data: &mut Vec<u8>,
413    ignore_garbage: bool,
414    out: &mut impl Write,
415) -> io::Result<()> {
416    if data.is_empty() {
417        return Ok(());
418    }
419
420    if ignore_garbage {
421        data.retain(|&b| is_base64_char(b));
422    } else {
423        strip_whitespace_inplace(data);
424    }
425
426    decode_clean_slice(data, out)
427}
428
429/// Strip all whitespace from a Vec in-place using the lookup table.
430/// Single-pass compaction: uses NOT_WHITESPACE table to classify all whitespace
431/// types simultaneously, avoiding the previous multi-scan approach.
432fn strip_whitespace_inplace(data: &mut Vec<u8>) {
433    // Quick check: any whitespace at all?
434    let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
435    if !has_ws {
436        return;
437    }
438
439    // Single-pass in-place compaction using the lookup table.
440    let ptr = data.as_ptr();
441    let mut_ptr = data.as_mut_ptr();
442    let len = data.len();
443    let mut wp = 0usize;
444
445    for i in 0..len {
446        let b = unsafe { *ptr.add(i) };
447        if NOT_WHITESPACE[b as usize] {
448            unsafe { *mut_ptr.add(wp) = b };
449            wp += 1;
450        }
451    }
452
453    data.truncate(wp);
454}
455
/// Byte-indexed classification table: `true` for every byte that is *not*
/// whitespace. Whitespace here means space, tab, LF, CR, vertical tab (0x0b)
/// and form feed (0x0c). Lets decode paths classify a byte with one lookup.
static NOT_WHITESPACE: [bool; 256] = {
    const WHITESPACE: [u8; 6] = [b' ', b'\t', b'\n', b'\r', 0x0b, 0x0c];

    let mut table = [true; 256];
    let mut i = 0;
    while i < WHITESPACE.len() {
        table[WHITESPACE[i] as usize] = false;
        i += 1;
    }
    table
};
468
469/// Decode by stripping whitespace and decoding in a single fused pass.
470/// For data with no whitespace, decodes directly without any copy.
471/// Uses memchr2 SIMD gap-copy for \n/\r (the dominant whitespace in base64),
472/// then a fallback pass for rare whitespace types (tab, space, VT, FF).
473fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
474    // Quick check: any whitespace at all?  Use the lookup table for a single scan.
475    let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
476    if !has_ws {
477        // No whitespace — decode directly from borrowed data
478        return decode_borrowed_clean(out, data);
479    }
480
481    // SIMD gap-copy: use memchr2 to find \n and \r positions, then copy the
482    // gaps between them. For typical base64 (76-char lines), newlines are ~1/77
483    // of the data, so we process ~76 bytes per memchr hit instead of 1 per scalar.
484    let mut clean: Vec<u8> = Vec::with_capacity(data.len());
485    let dst = clean.as_mut_ptr();
486    let mut wp = 0usize;
487    let mut gap_start = 0usize;
488
489    for pos in memchr::memchr2_iter(b'\n', b'\r', data) {
490        let gap_len = pos - gap_start;
491        if gap_len > 0 {
492            unsafe {
493                std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), gap_len);
494            }
495            wp += gap_len;
496        }
497        gap_start = pos + 1;
498    }
499    // Copy the final gap after the last \n/\r
500    let tail_len = data.len() - gap_start;
501    if tail_len > 0 {
502        unsafe {
503            std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), tail_len);
504        }
505        wp += tail_len;
506    }
507    unsafe {
508        clean.set_len(wp);
509    }
510
511    // Second pass for rare whitespace (tab, space, VT, FF) using lookup table.
512    // In typical base64 streams this does nothing, but correctness requires it.
513    let has_rare_ws = clean.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
514    if has_rare_ws {
515        let ptr = clean.as_mut_ptr();
516        let len = clean.len();
517        let mut rp = 0;
518        let mut cwp = 0;
519        while rp < len {
520            let b = unsafe { *ptr.add(rp) };
521            if NOT_WHITESPACE[b as usize] {
522                unsafe { *ptr.add(cwp) = b };
523                cwp += 1;
524            }
525            rp += 1;
526        }
527        clean.truncate(cwp);
528    }
529
530    decode_clean_slice(&mut clean, out)
531}
532
533/// Decode a clean (no whitespace) buffer in-place with SIMD.
534fn decode_clean_slice(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
535    if data.is_empty() {
536        return Ok(());
537    }
538    match BASE64_ENGINE.decode_inplace(data) {
539        Ok(decoded) => out.write_all(decoded),
540        Err(_) => decode_error(),
541    }
542}
543
/// Build the `InvalidData` / "invalid input" error returned for undecodable
/// input. Marked cold and never inlined so error construction stays out of the
/// hot decode paths.
#[cold]
#[inline(never)]
fn decode_error() -> io::Result<()> {
    let invalid = io::Error::new(io::ErrorKind::InvalidData, "invalid input");
    Err(invalid)
}
550
551/// Decode clean base64 data (no whitespace) from a borrowed slice.
552fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
553    if data.is_empty() {
554        return Ok(());
555    }
556    // Parallel decode for large data: split at 4-byte boundaries,
557    // decode each chunk independently (base64 is context-free per 4-char group).
558    if data.len() >= PARALLEL_DECODE_THRESHOLD {
559        return decode_borrowed_clean_parallel(out, data);
560    }
561    match BASE64_ENGINE.decode_to_vec(data) {
562        Ok(decoded) => {
563            out.write_all(&decoded)?;
564            Ok(())
565        }
566        Err(_) => decode_error(),
567    }
568}
569
570/// Parallel decode: split at 4-byte boundaries, decode chunks in parallel via rayon.
571/// Pre-allocates a single contiguous output buffer with exact decoded offsets computed
572/// upfront, so each thread decodes directly to its final position. No compaction needed.
573fn decode_borrowed_clean_parallel(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
574    let num_threads = rayon::current_num_threads().max(1);
575    let raw_chunk = data.len() / num_threads;
576    // Align to 4 bytes (each 4 base64 chars = 3 decoded bytes, context-free)
577    let chunk_size = ((raw_chunk + 3) / 4) * 4;
578
579    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(4)).collect();
580
581    // Compute exact decoded sizes per chunk upfront to eliminate the compaction pass.
582    // For all chunks except the last, decoded size is exactly chunk.len() * 3 / 4.
583    // For the last chunk, account for '=' padding bytes.
584    let mut offsets: Vec<usize> = Vec::with_capacity(chunks.len() + 1);
585    offsets.push(0);
586    let mut total_decoded = 0usize;
587    for (i, chunk) in chunks.iter().enumerate() {
588        let decoded_size = if i == chunks.len() - 1 {
589            // Last chunk: count '=' padding to get exact decoded size
590            let pad = chunk.iter().rev().take(2).filter(|&&b| b == b'=').count();
591            chunk.len() * 3 / 4 - pad
592        } else {
593            // Non-last chunks: 4-byte aligned, no padding, exact 3/4 ratio
594            chunk.len() * 3 / 4
595        };
596        total_decoded += decoded_size;
597        offsets.push(total_decoded);
598    }
599
600    // Pre-allocate contiguous output buffer with exact total size
601    let mut output_buf: Vec<u8> = Vec::with_capacity(total_decoded);
602    #[allow(clippy::uninit_vec)]
603    unsafe {
604        output_buf.set_len(total_decoded);
605    }
606
607    // Parallel decode: each thread decodes directly into its exact final position.
608    // No compaction pass needed since offsets are computed from exact decoded sizes.
609    // SAFETY: each thread writes to a non-overlapping region of the output buffer.
610    // Use usize representation of the pointer for Send+Sync compatibility with rayon.
611    let out_addr = output_buf.as_mut_ptr() as usize;
612    let decode_result: Result<Vec<()>, io::Error> = chunks
613        .par_iter()
614        .enumerate()
615        .map(|(i, chunk)| {
616            let offset = offsets[i];
617            let expected_size = offsets[i + 1] - offset;
618            // SAFETY: each thread writes to non-overlapping region [offset..offset+expected_size]
619            let out_slice = unsafe {
620                std::slice::from_raw_parts_mut((out_addr as *mut u8).add(offset), expected_size)
621            };
622            let decoded = BASE64_ENGINE
623                .decode(chunk, out_slice.as_out())
624                .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid input"))?;
625            debug_assert_eq!(decoded.len(), expected_size);
626            Ok(())
627        })
628        .collect();
629
630    decode_result?;
631
632    out.write_all(&output_buf[..total_decoded])
633}
634
635/// Strip non-base64 characters (for -i / --ignore-garbage).
636fn strip_non_base64(data: &[u8]) -> Vec<u8> {
637    data.iter()
638        .copied()
639        .filter(|&b| is_base64_char(b))
640        .collect()
641}
642
/// `true` when `b` belongs to the standard base64 alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`) or is the padding character `=`.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
648
649/// Stream-encode from a reader to a writer. Used for stdin processing.
650/// Uses 3MB read chunks (aligned to 3 bytes for padding-free intermediate encoding).
651/// 3MB is optimal for piped input: large enough for good throughput, small enough
652/// that read_full() fills the buffer quickly from pipes (3 reads at 1MB pipe size).
653pub fn encode_stream(
654    reader: &mut impl Read,
655    wrap_col: usize,
656    writer: &mut impl Write,
657) -> io::Result<()> {
658    // 3MB aligned to 3 bytes — sweet spot for pipe throughput
659    const STREAM_READ: usize = 3 * 1024 * 1024;
660    let mut buf = vec![0u8; STREAM_READ];
661
662    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_READ);
663    let mut encode_buf = vec![0u8; encode_buf_size];
664
665    if wrap_col == 0 {
666        // No wrapping: encode each chunk and write directly.
667        loop {
668            let n = read_full(reader, &mut buf)?;
669            if n == 0 {
670                break;
671            }
672            let enc_len = BASE64_ENGINE.encoded_length(n);
673            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
674            writer.write_all(encoded)?;
675        }
676    } else {
677        // Wrapping: use writev with IoSlice to interleave newlines without copying.
678        // For streaming, we need to track the column position across chunks
679        // because the last encoded chunk may not fill a complete line.
680        let mut col = 0usize;
681
682        loop {
683            let n = read_full(reader, &mut buf)?;
684            if n == 0 {
685                break;
686            }
687            let enc_len = BASE64_ENGINE.encoded_length(n);
688            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
689
690            // For streaming wrapping: build IoSlice entries with column tracking
691            write_wrapped_iov_streaming(encoded, wrap_col, &mut col, writer)?;
692        }
693
694        if col > 0 {
695            writer.write_all(b"\n")?;
696        }
697    }
698
699    Ok(())
700}
701
702/// Stream-decode from a reader to a writer. Used for stdin processing.
703/// Fused single-pass: read chunk -> strip whitespace -> decode immediately.
704/// Uses 16MB read buffer to reduce syscalls and memchr2-based SIMD whitespace
705/// stripping for the common case (only \n and \r whitespace in base64 streams).
706pub fn decode_stream(
707    reader: &mut impl Read,
708    ignore_garbage: bool,
709    writer: &mut impl Write,
710) -> io::Result<()> {
711    const READ_CHUNK: usize = 16 * 1024 * 1024;
712    let mut buf = vec![0u8; READ_CHUNK];
713    // Pre-allocate clean buffer once and reuse across iterations.
714    // Use Vec with set_len for zero-overhead reset instead of clear() + extend().
715    let mut clean: Vec<u8> = Vec::with_capacity(READ_CHUNK + 4);
716    let mut carry = [0u8; 4];
717    let mut carry_len = 0usize;
718
719    loop {
720        let n = read_full(reader, &mut buf)?;
721        if n == 0 {
722            break;
723        }
724
725        // Copy carry bytes to start of clean buffer (0-3 bytes from previous chunk)
726        unsafe {
727            std::ptr::copy_nonoverlapping(carry.as_ptr(), clean.as_mut_ptr(), carry_len);
728        }
729
730        let chunk = &buf[..n];
731        if ignore_garbage {
732            // Scalar filter for ignore_garbage mode (rare path)
733            let dst = unsafe { clean.as_mut_ptr().add(carry_len) };
734            let mut wp = 0usize;
735            for &b in chunk {
736                if is_base64_char(b) {
737                    unsafe { *dst.add(wp) = b };
738                    wp += 1;
739                }
740            }
741            unsafe { clean.set_len(carry_len + wp) };
742        } else {
743            // SIMD gap-copy: use memchr2 to find \n and \r positions, then copy
744            // the gaps between them. For typical base64 (76-char lines), newlines
745            // are ~1/77 of the data, so we process ~76 bytes per memchr hit
746            // instead of 1 byte per scalar iteration.
747            let dst = unsafe { clean.as_mut_ptr().add(carry_len) };
748            let mut wp = 0usize;
749            let mut gap_start = 0usize;
750
751            for pos in memchr::memchr2_iter(b'\n', b'\r', chunk) {
752                let gap_len = pos - gap_start;
753                if gap_len > 0 {
754                    unsafe {
755                        std::ptr::copy_nonoverlapping(
756                            chunk.as_ptr().add(gap_start),
757                            dst.add(wp),
758                            gap_len,
759                        );
760                    }
761                    wp += gap_len;
762                }
763                gap_start = pos + 1;
764            }
765            let tail_len = n - gap_start;
766            if tail_len > 0 {
767                unsafe {
768                    std::ptr::copy_nonoverlapping(
769                        chunk.as_ptr().add(gap_start),
770                        dst.add(wp),
771                        tail_len,
772                    );
773                }
774                wp += tail_len;
775            }
776            let total_clean = carry_len + wp;
777            unsafe { clean.set_len(total_clean) };
778
779            // Second pass for rare whitespace (tab, space, VT, FF) using lookup table.
780            // In typical base64 streams this does nothing, but we need correctness.
781            let has_rare_ws = clean[carry_len..total_clean]
782                .iter()
783                .any(|&b| !NOT_WHITESPACE[b as usize]);
784            if has_rare_ws {
785                let ptr = clean.as_mut_ptr();
786                let mut rp = carry_len;
787                let mut cwp = carry_len;
788                while rp < total_clean {
789                    let b = unsafe { *ptr.add(rp) };
790                    if NOT_WHITESPACE[b as usize] {
791                        unsafe { *ptr.add(cwp) = b };
792                        cwp += 1;
793                    }
794                    rp += 1;
795                }
796                clean.truncate(cwp);
797            }
798        }
799
800        carry_len = 0;
801        let is_last = n < READ_CHUNK;
802
803        if is_last {
804            // Last chunk: decode everything (including padding)
805            decode_clean_slice(&mut clean, writer)?;
806        } else {
807            // Save incomplete base64 quadruplet for next iteration
808            let clean_len = clean.len();
809            let decode_len = (clean_len / 4) * 4;
810            let leftover = clean_len - decode_len;
811            if leftover > 0 {
812                unsafe {
813                    std::ptr::copy_nonoverlapping(
814                        clean.as_ptr().add(decode_len),
815                        carry.as_mut_ptr(),
816                        leftover,
817                    );
818                }
819                carry_len = leftover;
820            }
821            if decode_len > 0 {
822                clean.truncate(decode_len);
823                decode_clean_slice(&mut clean, writer)?;
824            }
825        }
826    }
827
828    // Handle any remaining carry-over bytes
829    if carry_len > 0 {
830        let mut carry_buf = carry[..carry_len].to_vec();
831        decode_clean_slice(&mut carry_buf, writer)?;
832    }
833
834    Ok(())
835}
836
/// Write every byte of `slices` to `out`.
///
/// Tries a single `write_vectored` call first, retrying it while it fails with
/// `ErrorKind::Interrupted` (the same policy `write_all` applies). If that call
/// writes only part of the data, the remainder is written slice by slice with
/// `write_all`, skipping the bytes that already went out.
///
/// # Errors
///
/// Returns the first non-`Interrupted` error reported by the writer.
fn write_all_vectored(out: &mut impl Write, slices: &[io::IoSlice]) -> io::Result<()> {
    if slices.is_empty() {
        return Ok(());
    }
    let total: usize = slices.iter().map(|s| s.len()).sum();

    // Common case: everything goes out in one vectored write.
    let written = loop {
        match out.write_vectored(slices) {
            Ok(n) if n >= total => return Ok(()),
            Ok(n) => break n,
            // A signal interrupted the call before any data was written: retry.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    };

    // Partial write: skip what was already written, then finish with write_all.
    let mut skip = written;
    for slice in slices {
        if skip >= slice.len() {
            skip -= slice.len();
            continue;
        }
        out.write_all(&slice[skip..])?;
        skip = 0;
    }
    Ok(())
}
869
/// Fill `buf` from `reader` as far as possible and return the number of bytes
/// read. A result smaller than `buf.len()` means the reader reported end of
/// input (`Ok(0)`) before the buffer was full.
///
/// Every read, including the first, is retried when it fails with
/// `ErrorKind::Interrupted`. A reader that fills the buffer on its first call
/// (typical for regular files) still costs exactly one `read`.
///
/// # Errors
///
/// Returns the first non-`Interrupted` error reported by the reader.
#[inline]
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut total = 0;
    while total < buf.len() {
        match reader.read(&mut buf[total..]) {
            Ok(0) => break,
            Ok(n) => total += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(total)
}