coreutils_rs/base64/core.rs

use std::io::{self, Read, Write};

use base64_simd::AsOut;
use rayon::prelude::*;

const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;

/// Chunk size for no-wrap encoding: 32MB aligned to 3 bytes.
/// Larger chunks = fewer write() syscalls for big files.
const NOWRAP_CHUNK: usize = 32 * 1024 * 1024 - (32 * 1024 * 1024 % 3);

/// Minimum data size for parallel encoding (1MB).
/// Lowered from 4MB so 10MB benchmark workloads get multi-core processing.
const PARALLEL_ENCODE_THRESHOLD: usize = 1024 * 1024;

/// Minimum data size for parallel decoding (1MB of base64 data).
/// Lowered from 4MB for better parallelism on typical workloads.
const PARALLEL_DECODE_THRESHOLD: usize = 1024 * 1024;

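// A small sanity-check sketch for the constant arithmetic above: NOWRAP_CHUNK must stay
// 3-byte aligned so chunks encode without padding, and the default 76-column wrap
// corresponds to 57 input bytes per line (76 * 3 / 4 = 57, itself a multiple of 3).
#[cfg(test)]
mod chunk_math_checks {
    use super::*;

    #[test]
    fn nowrap_chunk_is_three_byte_aligned() {
        assert_eq!(NOWRAP_CHUNK % 3, 0);
    }

    #[test]
    fn default_wrap_column_maps_to_57_input_bytes() {
        let wrap_col: usize = 76;
        assert_eq!(wrap_col * 3 / 4, 57);
        assert_eq!((wrap_col * 3 / 4) % 3, 0);
    }
}
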
/// Encode data and write to output with line wrapping.
/// Uses SIMD encoding with fused encode+wrap for maximum throughput.
pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if wrap_col == 0 {
        return encode_no_wrap(data, out);
    }

    encode_wrapped(data, wrap_col, out)
}
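
// A small usage sketch for encode_to_writer: "hello" encodes to "aGVsbG8=", with a
// trailing newline when wrapping is enabled and none when wrap_col == 0.
#[cfg(test)]
mod encode_to_writer_examples {
    use super::*;

    #[test]
    fn wrapped_output_ends_with_newline() {
        let mut out = Vec::new();
        encode_to_writer(b"hello", 76, &mut out).unwrap();
        assert_eq!(out, b"aGVsbG8=\n");
    }

    #[test]
    fn no_wrap_output_has_no_newline() {
        let mut out = Vec::new();
        encode_to_writer(b"hello", 0, &mut out).unwrap();
        assert_eq!(out, b"aGVsbG8=");
    }
}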

/// Encode without wrapping — parallel SIMD encoding for large data, sequential for small.
fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
        return encode_no_wrap_parallel(data, out);
    }

    let actual_chunk = NOWRAP_CHUNK.min(data.len());
    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
    // SAFETY: encode() writes exactly enc_len bytes before we read them.
    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(enc_max);
    }

    for chunk in data.chunks(NOWRAP_CHUNK) {
        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
        out.write_all(encoded)?;
    }
    Ok(())
}

/// Parallel no-wrap encoding: split at 3-byte boundaries, encode chunks in parallel.
/// Each chunk except possibly the last is 3-byte aligned, so no padding in intermediate chunks.
/// Uses write_vectored (writev) to send all encoded chunks in a single syscall.
fn encode_no_wrap_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    let raw_chunk = data.len() / num_threads;
    // Align to 3 bytes so each chunk encodes without padding (except the last)
    let chunk_size = ((raw_chunk + 2) / 3) * 3;

    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(3)).collect();
    let encoded_chunks: Vec<Vec<u8>> = chunks
        .par_iter()
        .map(|chunk| {
            let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
            let mut buf: Vec<u8> = Vec::with_capacity(enc_len);
            #[allow(clippy::uninit_vec)]
            unsafe {
                buf.set_len(enc_len);
            }
            let _ = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
            buf
        })
        .collect();

    // Use write_vectored to send all chunks in a single syscall
    let iov: Vec<io::IoSlice> = encoded_chunks.iter().map(|c| io::IoSlice::new(c)).collect();
    write_all_vectored(out, &iov)
}
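
// A sketch of the 3-byte splitting invariant: regardless of the rayon thread count,
// interior chunks are 3-byte aligned, so the concatenated per-chunk encodings equal a
// single-shot encoding of the same input.
#[cfg(test)]
mod parallel_nowrap_examples {
    use super::*;

    #[test]
    fn parallel_chunks_concatenate_to_single_shot_encoding() {
        let data = b"hello, parallel base64!";

        let mut parallel = Vec::new();
        encode_no_wrap_parallel(data, &mut parallel).unwrap();

        let mut sequential = Vec::new();
        encode_no_wrap(data, &mut sequential).unwrap();

        assert_eq!(parallel, sequential);
    }
}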

/// Encode with line wrapping — uses writev to interleave encoded segments
/// with newlines without copying data. For each wrap_col-sized segment of
/// encoded output, we create an IoSlice pointing directly at the encode buffer,
/// interleaved with IoSlice entries pointing at a static newline byte.
fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    // Calculate bytes_per_line: input bytes that produce exactly wrap_col encoded chars.
    // For default wrap_col=76: 76*3/4 = 57 bytes per line.
    let bytes_per_line = wrap_col * 3 / 4;
    if bytes_per_line == 0 {
        // Degenerate case: wrap_col == 1 makes bytes_per_line round to 0, so use the simple fallback
        return encode_wrapped_small(data, wrap_col, out);
    }

    // Parallel encoding for large data when wrap_col is a multiple of 4.
    // Then bytes_per_line = wrap_col * 3 / 4 exactly, so every chunk is 3-byte
    // aligned (no intermediate padding) and encodes to whole output lines.
    if data.len() >= PARALLEL_ENCODE_THRESHOLD && wrap_col.is_multiple_of(4) {
        return encode_wrapped_parallel(data, wrap_col, bytes_per_line, out);
    }

    // Align input chunks to 3 * wrap_col bytes: such a group always encodes to exactly
    // 4 * wrap_col chars (a whole number of lines, no padding), so for any wrap_col
    // every chunk except the last ends cleanly at a line boundary.
    let group = 3 * wrap_col;
    let max_input_chunk = ((32 * 1024 * 1024) / group).max(1) * group;
    let input_chunk = max_input_chunk.min(data.len());

    let enc_max = BASE64_ENGINE.encoded_length(input_chunk);
    let mut encode_buf: Vec<u8> = Vec::with_capacity(enc_max);
    #[allow(clippy::uninit_vec)]
    unsafe {
        encode_buf.set_len(enc_max);
    }

    for chunk in data.chunks(max_input_chunk.max(1)) {
        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());

        // Use writev: build IoSlice entries pointing at wrap_col-sized segments
        // of the encoded buffer interleaved with newline IoSlices.
        // This eliminates the fused_buf copy entirely.
        write_wrapped_iov(encoded, wrap_col, out)?;
    }

    Ok(())
}

/// Static newline byte for IoSlice references in writev calls.
static NEWLINE: [u8; 1] = [b'\n'];

/// Write encoded base64 data with line wrapping using write_vectored (writev).
/// Builds IoSlice entries pointing at wrap_col-sized segments of the encoded buffer,
/// interleaved with newline IoSlices, then writes in batches of MAX_IOV entries.
/// This is zero-copy: no fused output buffer needed.
#[inline]
fn write_wrapped_iov(encoded: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    // Max IoSlice entries per writev batch. Linux UIO_MAXIOV is 1024.
    // Each line needs 2 entries (data + newline), so 512 lines per batch.
    const MAX_IOV: usize = 1024;

    let num_full_lines = encoded.len() / wrap_col;
    let remainder = encoded.len() % wrap_col;
    let total_iov = num_full_lines * 2 + if remainder > 0 { 2 } else { 0 };

    // Small output: build all IoSlices and write in one call
    if total_iov <= MAX_IOV {
        let mut iov: Vec<io::IoSlice> = Vec::with_capacity(total_iov);
        let mut pos = 0;
        for _ in 0..num_full_lines {
            iov.push(io::IoSlice::new(&encoded[pos..pos + wrap_col]));
            iov.push(io::IoSlice::new(&NEWLINE));
            pos += wrap_col;
        }
        if remainder > 0 {
            iov.push(io::IoSlice::new(&encoded[pos..pos + remainder]));
            iov.push(io::IoSlice::new(&NEWLINE));
        }
        return write_all_vectored(out, &iov);
    }

    // Large output: write in batches
    let mut iov: Vec<io::IoSlice> = Vec::with_capacity(MAX_IOV);
    let mut pos = 0;
    for _ in 0..num_full_lines {
        iov.push(io::IoSlice::new(&encoded[pos..pos + wrap_col]));
        iov.push(io::IoSlice::new(&NEWLINE));
        pos += wrap_col;
        if iov.len() >= MAX_IOV {
            write_all_vectored(out, &iov)?;
            iov.clear();
        }
    }
    if remainder > 0 {
        iov.push(io::IoSlice::new(&encoded[pos..pos + remainder]));
        iov.push(io::IoSlice::new(&NEWLINE));
    }
    if !iov.is_empty() {
        write_all_vectored(out, &iov)?;
    }
    Ok(())
}
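
// A worked sketch of write_wrapped_iov: already-encoded bytes wrapped at a small column
// become full lines plus a newline-terminated partial tail.
#[cfg(test)]
mod write_wrapped_iov_examples {
    use super::*;

    #[test]
    fn full_lines_and_partial_tail() {
        let mut out = Vec::new();
        write_wrapped_iov(b"Zm9vYmFy", 4, &mut out).unwrap();
        assert_eq!(out, b"Zm9v\nYmFy\n");

        let mut out = Vec::new();
        write_wrapped_iov(b"Zm9vYmE=", 3, &mut out).unwrap();
        assert_eq!(out, b"Zm9\nvYm\nE=\n");
    }
}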

/// Write encoded base64 data with line wrapping using writev, tracking column state
/// across calls. Used by encode_stream for piped input where chunks don't align
/// to line boundaries.
#[inline]
fn write_wrapped_iov_streaming(
    encoded: &[u8],
    wrap_col: usize,
    col: &mut usize,
    out: &mut impl Write,
) -> io::Result<()> {
    const MAX_IOV: usize = 1024;
    let mut iov: Vec<io::IoSlice> = Vec::with_capacity(MAX_IOV);
    let mut rp = 0;

    while rp < encoded.len() {
        let space = wrap_col - *col;
        let avail = encoded.len() - rp;

        if avail <= space {
            // Remaining data fits in current line
            iov.push(io::IoSlice::new(&encoded[rp..rp + avail]));
            *col += avail;
            if *col == wrap_col {
                iov.push(io::IoSlice::new(&NEWLINE));
                *col = 0;
            }
            break;
        } else {
            // Fill current line and add newline
            iov.push(io::IoSlice::new(&encoded[rp..rp + space]));
            iov.push(io::IoSlice::new(&NEWLINE));
            rp += space;
            *col = 0;
        }

        if iov.len() >= MAX_IOV - 1 {
            write_all_vectored(out, &iov)?;
            iov.clear();
        }
    }

    if !iov.is_empty() {
        write_all_vectored(out, &iov)?;
    }
    Ok(())
}
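
// A worked sketch of the column carry-over: two streaming calls with wrap_col = 4.
// The first call ends mid-line (col = 2), and the second call completes that line
// before starting a new one, so chunk boundaries never disturb the wrapping.
#[cfg(test)]
mod streaming_wrap_examples {
    use super::*;

    #[test]
    fn column_state_carries_across_calls() {
        let mut out = Vec::new();
        let mut col = 0usize;

        write_wrapped_iov_streaming(b"ABCDEF", 4, &mut col, &mut out).unwrap();
        assert_eq!(out, b"ABCD\nEF");
        assert_eq!(col, 2);

        write_wrapped_iov_streaming(b"GHIJ", 4, &mut col, &mut out).unwrap();
        assert_eq!(out, b"ABCD\nEFGH\nIJ");
        assert_eq!(col, 2);
    }
}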

/// Parallel wrapped encoding: split at bytes_per_line boundaries, encode + wrap in parallel.
/// Requires wrap_col % 4 == 0, so bytes_per_line = wrap_col * 3 / 4 exactly: each chunk is
/// 3-byte aligned (no intermediate padding) and encodes to a whole number of output lines.
/// Uses write_vectored (writev) to send all encoded+wrapped chunks in a single syscall.
fn encode_wrapped_parallel(
    data: &[u8],
    wrap_col: usize,
    bytes_per_line: usize,
    out: &mut impl Write,
) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    // Split at bytes_per_line boundaries for complete output lines per chunk
    let lines_per_chunk = (data.len() / bytes_per_line / num_threads).max(1);
    let chunk_size = lines_per_chunk * bytes_per_line;

    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(bytes_per_line)).collect();
    let encoded_chunks: Vec<Vec<u8>> = chunks
        .par_iter()
        .map(|chunk| {
            let enc_max = BASE64_ENGINE.encoded_length(chunk.len());
            let max_lines = enc_max / wrap_col + 2;
            // Single allocation with two non-overlapping regions:
            //   [0..fused_size) = fuse_wrap output region
            //   [fused_size..fused_size+enc_max) = encode region
            let fused_size = enc_max + max_lines;
            let total_size = fused_size + enc_max;
            let mut buf: Vec<u8> = Vec::with_capacity(total_size);
            #[allow(clippy::uninit_vec)]
            unsafe {
                buf.set_len(total_size);
            }
            // Encode into the second region [fused_size..fused_size+enc_max]
            let _ = BASE64_ENGINE.encode(chunk, buf[fused_size..fused_size + enc_max].as_out());
            // Use split_at_mut to get non-overlapping mutable/immutable refs
            let (fused_region, encode_region) = buf.split_at_mut(fused_size);
            let encoded = &encode_region[..enc_max];
            let wp = fuse_wrap(encoded, wrap_col, fused_region);
            buf.truncate(wp);
            buf
        })
        .collect();

    // Use write_vectored to send all chunks in a single syscall
    let iov: Vec<io::IoSlice> = encoded_chunks.iter().map(|c| io::IoSlice::new(c)).collect();
    write_all_vectored(out, &iov)
}

/// Fuse encoded base64 data with newlines in a single pass.
/// Uses ptr::copy_nonoverlapping with 8-line unrolling for max throughput.
/// Returns number of bytes written.
#[inline]
fn fuse_wrap(encoded: &[u8], wrap_col: usize, out_buf: &mut [u8]) -> usize {
    let line_out = wrap_col + 1; // wrap_col data bytes + 1 newline
    let mut rp = 0;
    let mut wp = 0;

    // Unrolled: process 8 lines per iteration for better ILP
    while rp + 8 * wrap_col <= encoded.len() {
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = out_buf.as_mut_ptr().add(wp);

            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(4 * wrap_col), dst.add(4 * line_out), wrap_col);
            *dst.add(4 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(5 * wrap_col), dst.add(5 * line_out), wrap_col);
            *dst.add(5 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(6 * wrap_col), dst.add(6 * line_out), wrap_col);
            *dst.add(6 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(7 * wrap_col), dst.add(7 * line_out), wrap_col);
            *dst.add(7 * line_out + wrap_col) = b'\n';
        }
        rp += 8 * wrap_col;
        wp += 8 * line_out;
    }

    // Handle remaining 4 lines at a time
    while rp + 4 * wrap_col <= encoded.len() {
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = out_buf.as_mut_ptr().add(wp);

            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';
        }
        rp += 4 * wrap_col;
        wp += 4 * line_out;
    }

    // Remaining full lines
    while rp + wrap_col <= encoded.len() {
        unsafe {
            std::ptr::copy_nonoverlapping(
                encoded.as_ptr().add(rp),
                out_buf.as_mut_ptr().add(wp),
                wrap_col,
            );
            *out_buf.as_mut_ptr().add(wp + wrap_col) = b'\n';
        }
        rp += wrap_col;
        wp += line_out;
    }

    // Partial last line
    if rp < encoded.len() {
        let remaining = encoded.len() - rp;
        unsafe {
            std::ptr::copy_nonoverlapping(
                encoded.as_ptr().add(rp),
                out_buf.as_mut_ptr().add(wp),
                remaining,
            );
        }
        wp += remaining;
        out_buf[wp] = b'\n';
        wp += 1;
    }

    wp
}
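
// A worked sketch of fuse_wrap: 10 encoded bytes wrapped at 4 columns become
// "AAAA\nBBBB\nCC\n" (13 bytes). The output buffer only needs room for the encoded
// data plus one newline per (possibly partial) line.
#[cfg(test)]
mod fuse_wrap_examples {
    use super::*;

    #[test]
    fn wraps_full_and_partial_lines() {
        let encoded = b"AAAABBBBCC";
        let mut out_buf = vec![0u8; encoded.len() + encoded.len() / 4 + 2];
        let wp = fuse_wrap(encoded, 4, &mut out_buf);
        assert_eq!(wp, 13);
        assert_eq!(&out_buf[..wp], b"AAAA\nBBBB\nCC\n");
    }
}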

/// Fallback for the degenerate wrap column (wrap_col == 1, where wrap_col * 3 / 4 rounds to 0).
fn encode_wrapped_small(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    let enc_max = BASE64_ENGINE.encoded_length(data.len());
    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(enc_max);
    }
    let encoded = BASE64_ENGINE.encode(data, buf[..enc_max].as_out());

    let wc = wrap_col.max(1);
    for line in encoded.chunks(wc) {
        out.write_all(line)?;
        out.write_all(b"\n")?;
    }
    Ok(())
}

/// Decode base64 data and write to output (borrows data, allocates clean buffer).
/// When `ignore_garbage` is true, strip all non-base64 characters.
/// When false, only strip whitespace (standard behavior).
pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if ignore_garbage {
        let mut cleaned = strip_non_base64(data);
        return decode_clean_slice(&mut cleaned, out);
    }

    // Fast path: single-pass strip + decode
    decode_stripping_whitespace(data, out)
}

/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
pub fn decode_owned(
    data: &mut Vec<u8>,
    ignore_garbage: bool,
    out: &mut impl Write,
) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if ignore_garbage {
        data.retain(|&b| is_base64_char(b));
    } else {
        strip_whitespace_inplace(data);
    }

    decode_clean_slice(data, out)
}
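
// A round-trip sketch for the decode entry points: decode_to_writer strips the line
// wrapping that encode_to_writer inserts, and decode_owned with ignore_garbage drops
// non-alphabet bytes before decoding.
#[cfg(test)]
mod decode_entry_examples {
    use super::*;

    #[test]
    fn decode_to_writer_strips_line_wrapping() {
        let mut out = Vec::new();
        decode_to_writer(b"aGVs\nbG8=\n", false, &mut out).unwrap();
        assert_eq!(out, b"hello");
    }

    #[test]
    fn decode_owned_with_ignore_garbage_drops_junk() {
        let mut data = b"aGVs*bG8=!".to_vec();
        let mut out = Vec::new();
        decode_owned(&mut data, true, &mut out).unwrap();
        assert_eq!(out, b"hello");
    }
}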

/// Strip all whitespace from a Vec in-place using the lookup table.
/// Single-pass compaction: uses NOT_WHITESPACE table to classify all whitespace
/// types simultaneously, avoiding the previous multi-scan approach.
fn strip_whitespace_inplace(data: &mut Vec<u8>) {
    // Quick check: any whitespace at all?
    let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
    if !has_ws {
        return;
    }

    // Single-pass in-place compaction using the lookup table.
    let ptr = data.as_ptr();
    let mut_ptr = data.as_mut_ptr();
    let len = data.len();
    let mut wp = 0usize;

    for i in 0..len {
        let b = unsafe { *ptr.add(i) };
        if NOT_WHITESPACE[b as usize] {
            unsafe { *mut_ptr.add(wp) = b };
            wp += 1;
        }
    }

    data.truncate(wp);
}

/// 256-byte lookup table: true for non-whitespace bytes.
/// Used for single-pass whitespace stripping in decode.
static NOT_WHITESPACE: [bool; 256] = {
    let mut table = [true; 256];
    table[b' ' as usize] = false;
    table[b'\t' as usize] = false;
    table[b'\n' as usize] = false;
    table[b'\r' as usize] = false;
    table[0x0b] = false; // vertical tab
    table[0x0c] = false; // form feed
    table
};

/// Decode by stripping whitespace and decoding in a single fused pass.
/// For data with no whitespace, decodes directly without any copy.
/// Uses memchr2 SIMD gap-copy for \n/\r (the dominant whitespace in base64),
/// then a conditional fallback pass for rare whitespace types (tab, space, VT, FF).
/// Tracks rare whitespace presence during the gap-copy to skip the second scan
/// entirely in the common case (pure \n/\r whitespace only).
fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    // Quick check: any whitespace at all?  Use the lookup table for a single scan.
    let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
    if !has_ws {
        // No whitespace — decode directly from borrowed data
        return decode_borrowed_clean(out, data);
    }

    // SIMD gap-copy: use memchr2 to find \n and \r positions, then copy the
    // gaps between them. For typical base64 (76-char lines), newlines are ~1/77
    // of the data, so we process ~76 bytes per memchr hit instead of 1 per scalar.
    let mut clean: Vec<u8> = Vec::with_capacity(data.len());
    let dst = clean.as_mut_ptr();
    let mut wp = 0usize;
    let mut gap_start = 0usize;
    // Track whether any rare whitespace (tab, space, VT, FF) exists in gap regions.
    // This avoids the second full-scan pass when only \n/\r are present.
    let mut has_rare_ws = false;

    for pos in memchr::memchr2_iter(b'\n', b'\r', data) {
        let gap_len = pos - gap_start;
        if gap_len > 0 {
            // Check gap region for rare whitespace during copy.
            // This adds ~1 branch per gap but eliminates the second full scan.
            if !has_rare_ws {
                has_rare_ws = data[gap_start..pos]
                    .iter()
                    .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
            }
            unsafe {
                std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), gap_len);
            }
            wp += gap_len;
        }
        gap_start = pos + 1;
    }
    // Copy the final gap after the last \n/\r
    let tail_len = data.len() - gap_start;
    if tail_len > 0 {
        if !has_rare_ws {
            has_rare_ws = data[gap_start..]
                .iter()
                .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
        }
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), tail_len);
        }
        wp += tail_len;
    }
    unsafe {
        clean.set_len(wp);
    }

    // Second pass for rare whitespace (tab, space, VT, FF) — only runs when needed.
    // In typical base64 streams (76-char lines with \n), this is skipped entirely.
    if has_rare_ws {
        let ptr = clean.as_mut_ptr();
        let len = clean.len();
        let mut rp = 0;
        let mut cwp = 0;
        while rp < len {
            let b = unsafe { *ptr.add(rp) };
            if NOT_WHITESPACE[b as usize] {
                unsafe { *ptr.add(cwp) = b };
                cwp += 1;
            }
            rp += 1;
        }
        clean.truncate(cwp);
    }

    decode_clean_slice(&mut clean, out)
}
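
// A sketch of the rare-whitespace path: spaces and tabs inside the gaps trigger the
// second compaction pass, while plain \n/\r input never does. Either way the cleaned
// buffer decodes to the original bytes.
#[cfg(test)]
mod whitespace_strip_examples {
    use super::*;

    #[test]
    fn tabs_and_spaces_are_stripped_before_decoding() {
        let mut out = Vec::new();
        decode_stripping_whitespace(b"aGVs \tbG8=\r\n", &mut out).unwrap();
        assert_eq!(out, b"hello");
    }
}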

/// Decode a clean (no whitespace) buffer in-place with SIMD.
fn decode_clean_slice(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    match BASE64_ENGINE.decode_inplace(data) {
        Ok(decoded) => out.write_all(decoded),
        Err(_) => decode_error(),
    }
}

/// Cold error path — keeps hot decode path tight by moving error construction out of line.
#[cold]
#[inline(never)]
fn decode_error() -> io::Result<()> {
    Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input"))
}

/// Decode clean base64 data (no whitespace) from a borrowed slice.
fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    // Parallel decode for large data: split at 4-byte boundaries,
    // decode each chunk independently (base64 is context-free per 4-char group).
    if data.len() >= PARALLEL_DECODE_THRESHOLD {
        return decode_borrowed_clean_parallel(out, data);
    }
    match BASE64_ENGINE.decode_to_vec(data) {
        Ok(decoded) => {
            out.write_all(&decoded)?;
            Ok(())
        }
        Err(_) => decode_error(),
    }
}

/// Parallel decode: split at 4-byte boundaries, decode chunks in parallel via rayon.
/// Pre-allocates a single contiguous output buffer with exact decoded offsets computed
/// upfront, so each thread decodes directly to its final position. No compaction needed.
fn decode_borrowed_clean_parallel(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    let raw_chunk = data.len() / num_threads;
    // Align to 4 bytes (each 4 base64 chars = 3 decoded bytes, context-free)
    let chunk_size = ((raw_chunk + 3) / 4) * 4;

    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(4)).collect();

    // Compute exact decoded sizes per chunk upfront to eliminate the compaction pass.
    // For all chunks except the last, decoded size is exactly chunk.len() * 3 / 4.
    // For the last chunk, account for '=' padding bytes.
    let mut offsets: Vec<usize> = Vec::with_capacity(chunks.len() + 1);
    offsets.push(0);
    let mut total_decoded = 0usize;
    for (i, chunk) in chunks.iter().enumerate() {
        let decoded_size = if i == chunks.len() - 1 {
            // Last chunk: count '=' padding to get exact decoded size
            let pad = chunk.iter().rev().take(2).filter(|&&b| b == b'=').count();
            chunk.len() * 3 / 4 - pad
        } else {
            // Non-last chunks: 4-byte aligned, no padding, exact 3/4 ratio
            chunk.len() * 3 / 4
        };
        total_decoded += decoded_size;
        offsets.push(total_decoded);
    }

    // Pre-allocate contiguous output buffer with exact total size
    let mut output_buf: Vec<u8> = Vec::with_capacity(total_decoded);
    #[allow(clippy::uninit_vec)]
    unsafe {
        output_buf.set_len(total_decoded);
    }

    // Parallel decode: each thread decodes directly into its exact final position.
    // No compaction pass needed since offsets are computed from exact decoded sizes.
    // SAFETY: each thread writes to a non-overlapping region of the output buffer.
    // Use usize representation of the pointer for Send+Sync compatibility with rayon.
    let out_addr = output_buf.as_mut_ptr() as usize;
    let decode_result: Result<Vec<()>, io::Error> = chunks
        .par_iter()
        .enumerate()
        .map(|(i, chunk)| {
            let offset = offsets[i];
            let expected_size = offsets[i + 1] - offset;
            // SAFETY: each thread writes to non-overlapping region [offset..offset+expected_size]
            let out_slice = unsafe {
                std::slice::from_raw_parts_mut((out_addr as *mut u8).add(offset), expected_size)
            };
            let decoded = BASE64_ENGINE
                .decode(chunk, out_slice.as_out())
                .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid input"))?;
            debug_assert_eq!(decoded.len(), expected_size);
            Ok(())
        })
        .collect();

    decode_result?;

    out.write_all(&output_buf[..total_decoded])
}
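
// A sketch of the chunked decode: for a small input like "aGVsbG8=" the splitter
// produces one or two 4-byte-aligned chunks depending on the rayon pool size, but
// either way the per-chunk decodes land in adjacent output regions and concatenate
// to the original bytes.
#[cfg(test)]
mod parallel_decode_examples {
    use super::*;

    #[test]
    fn chunked_decode_matches_whole_input() {
        let mut out = Vec::new();
        decode_borrowed_clean_parallel(&mut out, b"aGVsbG8=").unwrap();
        assert_eq!(out, b"hello");
    }
}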

/// Strip non-base64 characters (for -i / --ignore-garbage).
fn strip_non_base64(data: &[u8]) -> Vec<u8> {
    data.iter()
        .copied()
        .filter(|&b| is_base64_char(b))
        .collect()
}

/// Check if a byte is a valid base64 alphabet character or padding.
#[inline]
fn is_base64_char(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'+' || b == b'/' || b == b'='
}

/// Stream-encode from a reader to a writer. Used for stdin processing.
/// Dispatches to specialized paths for wrap_col=0 (no wrap) and wrap_col>0 (wrapping).
pub fn encode_stream(
    reader: &mut impl Read,
    wrap_col: usize,
    writer: &mut impl Write,
) -> io::Result<()> {
    if wrap_col == 0 {
        return encode_stream_nowrap(reader, writer);
    }
    encode_stream_wrapped(reader, wrap_col, writer)
}
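
// A streaming sketch: feeding the same bytes through encode_stream (with an in-memory
// Cursor standing in for stdin) should produce exactly the output of the slice-based
// encode_to_writer, for both the unwrapped and the default 76-column wrapped modes.
#[cfg(test)]
mod encode_stream_examples {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn stream_matches_slice_encoding() {
        let data = vec![b'x'; 200];

        for wrap_col in [0usize, 76] {
            let mut streamed = Vec::new();
            encode_stream(&mut Cursor::new(&data), wrap_col, &mut streamed).unwrap();

            let mut sliced = Vec::new();
            encode_to_writer(&data, wrap_col, &mut sliced).unwrap();

            assert_eq!(streamed, sliced);
        }
    }
}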

/// Streaming encode with NO line wrapping — optimized fast path.
/// Read size is 12MB (divisible by 3): encoded output = 12MB * 4/3 = 16MB.
/// 12MB reads mean 10MB input is consumed in a single read() call,
/// and the 16MB encoded output writes in 1-2 write() calls.
fn encode_stream_nowrap(reader: &mut impl Read, writer: &mut impl Write) -> io::Result<()> {
    // 12MB aligned to 3 bytes: encoded output = 12MB * 4/3 = 16MB.
    // For 10MB input: 1 read (10MB) instead of 2 reads.
    const NOWRAP_READ: usize = 12 * 1024 * 1024; // exactly divisible by 3

    let mut buf = vec![0u8; NOWRAP_READ];
    let encode_buf_size = BASE64_ENGINE.encoded_length(NOWRAP_READ);
    let mut encode_buf = vec![0u8; encode_buf_size];

    loop {
        let n = read_full(reader, &mut buf)?;
        if n == 0 {
            break;
        }
        let enc_len = BASE64_ENGINE.encoded_length(n);
        let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
        writer.write_all(encoded)?;
    }
    Ok(())
}

/// Streaming encode WITH line wrapping.
/// For the common case (wrap_col a multiple of 4, e.g. the default 76),
/// uses fuse_wrap to build a contiguous output buffer with newlines interleaved,
/// then writes it in a single write() call. This eliminates the overhead of
/// many writev() syscalls (one per ~512 lines via IoSlice).
///
/// For other wrap columns, falls back to the IoSlice/writev approach.
fn encode_stream_wrapped(
    reader: &mut impl Read,
    wrap_col: usize,
    writer: &mut impl Write,
) -> io::Result<()> {
    let bytes_per_line = wrap_col * 3 / 4;
    // For the common case (76-col wrapping: wrap_col % 4 == 0, bytes_per_line = 57),
    // align the read buffer to bytes_per_line boundaries so each chunk produces
    // complete lines with no column carry-over between chunks.
    if wrap_col.is_multiple_of(4) {
        return encode_stream_wrapped_fused(reader, wrap_col, bytes_per_line, writer);
    }

    // Fallback: non-aligned wrap columns use IoSlice/writev with column tracking
    const STREAM_READ: usize = 12 * 1024 * 1024;
    let mut buf = vec![0u8; STREAM_READ];
    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_READ);
    let mut encode_buf = vec![0u8; encode_buf_size];

    let mut col = 0usize;

    loop {
        let n = read_full(reader, &mut buf)?;
        if n == 0 {
            break;
        }
        let enc_len = BASE64_ENGINE.encoded_length(n);
        let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());

        write_wrapped_iov_streaming(encoded, wrap_col, &mut col, writer)?;
    }

    if col > 0 {
        writer.write_all(b"\n")?;
    }

    Ok(())
}

/// Fused encode+wrap streaming: align reads to bytes_per_line boundaries,
/// encode chunk, fuse_wrap into contiguous buffer with newlines, single write().
/// For 76-col wrapping (bytes_per_line=57): 12MB / 57 = ~220K complete lines per chunk.
/// Output = ~220K * 77 bytes = ~16MB, one write() syscall per chunk.
fn encode_stream_wrapped_fused(
    reader: &mut impl Read,
    wrap_col: usize,
    bytes_per_line: usize,
    writer: &mut impl Write,
) -> io::Result<()> {
    // Align read size to bytes_per_line for complete output lines per chunk.
    // ~220K lines * 57 bytes = ~12MB input, ~16MB output.
    let lines_per_chunk = ((12 * 1024 * 1024) / bytes_per_line).max(1);
    let read_size = lines_per_chunk * bytes_per_line;

    let mut buf = vec![0u8; read_size];
    let enc_max = BASE64_ENGINE.encoded_length(read_size);
    let fused_max = enc_max + (enc_max / wrap_col + 2); // encoded + newlines
    let total_buf_size = fused_max + enc_max;
    // Single buffer: [0..fused_max) = fuse_wrap output, [fused_max..total_buf_size) = encode region
    let mut work_buf: Vec<u8> = vec![0u8; total_buf_size];

    loop {
        let n = read_full(reader, &mut buf)?;
        if n == 0 {
            break;
        }

        let enc_len = BASE64_ENGINE.encoded_length(n);
        // Encode into the second region
        let encode_start = fused_max;
        let _ = BASE64_ENGINE.encode(
            &buf[..n],
            work_buf[encode_start..encode_start + enc_len].as_out(),
        );

        // Fuse wrap: copy encoded data with newlines interleaved into first region
        let (fused_region, encode_region) = work_buf.split_at_mut(fused_max);
        let encoded = &encode_region[..enc_len];
        let wp = fuse_wrap(encoded, wrap_col, fused_region);

        writer.write_all(&fused_region[..wp])?;
    }

    // fuse_wrap appends a newline after the final (possibly partial) line of every
    // chunk, so no extra trailing newline is needed here.

    Ok(())
}

/// Stream-decode from a reader to a writer. Used for stdin processing.
/// Fused single-pass: read chunk -> strip whitespace -> decode immediately.
/// Uses 16MB read buffer for maximum pipe throughput — read_full retries to
/// fill the entire buffer from the pipe, and 16MB means the entire 10MB
/// benchmark input is read in a single syscall batch, minimizing overhead.
/// memchr2-based SIMD whitespace stripping handles the common case efficiently.
pub fn decode_stream(
    reader: &mut impl Read,
    ignore_garbage: bool,
    writer: &mut impl Write,
) -> io::Result<()> {
    const READ_CHUNK: usize = 16 * 1024 * 1024;
    let mut buf = vec![0u8; READ_CHUNK];
    // Pre-allocate clean buffer once and reuse across iterations.
    // Use Vec with set_len for zero-overhead reset instead of clear() + extend().
    let mut clean: Vec<u8> = Vec::with_capacity(READ_CHUNK + 4);
    let mut carry = [0u8; 4];
    let mut carry_len = 0usize;

    loop {
        let n = read_full(reader, &mut buf)?;
        if n == 0 {
            break;
        }

        // Copy carry bytes to start of clean buffer (0-3 bytes from previous chunk)
        unsafe {
            std::ptr::copy_nonoverlapping(carry.as_ptr(), clean.as_mut_ptr(), carry_len);
        }

        let chunk = &buf[..n];
        if ignore_garbage {
            // Scalar filter for ignore_garbage mode (rare path)
            let dst = unsafe { clean.as_mut_ptr().add(carry_len) };
            let mut wp = 0usize;
            for &b in chunk {
                if is_base64_char(b) {
                    unsafe { *dst.add(wp) = b };
                    wp += 1;
                }
            }
            unsafe { clean.set_len(carry_len + wp) };
        } else {
            // SIMD gap-copy: use memchr2 to find \n and \r positions, then copy
            // the gaps between them. For typical base64 (76-char lines), newlines
            // are ~1/77 of the data, so we process ~76 bytes per memchr hit
            // instead of 1 byte per scalar iteration.
            // Track rare whitespace during gap-copy to skip the second scan.
            let dst = unsafe { clean.as_mut_ptr().add(carry_len) };
            let mut wp = 0usize;
            let mut gap_start = 0usize;
            let mut has_rare_ws = false;

            for pos in memchr::memchr2_iter(b'\n', b'\r', chunk) {
                let gap_len = pos - gap_start;
                if gap_len > 0 {
                    if !has_rare_ws {
                        has_rare_ws = chunk[gap_start..pos]
                            .iter()
                            .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
                    }
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            chunk.as_ptr().add(gap_start),
                            dst.add(wp),
                            gap_len,
                        );
                    }
                    wp += gap_len;
                }
                gap_start = pos + 1;
            }
            let tail_len = n - gap_start;
            if tail_len > 0 {
                if !has_rare_ws {
                    has_rare_ws = chunk[gap_start..n]
                        .iter()
                        .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
                }
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        chunk.as_ptr().add(gap_start),
                        dst.add(wp),
                        tail_len,
                    );
                }
                wp += tail_len;
            }
            let total_clean = carry_len + wp;
            unsafe { clean.set_len(total_clean) };

            // Second pass for rare whitespace (tab, space, VT, FF) — only when detected.
            // In typical base64 streams (76-char lines with \n), this is skipped entirely.
            if has_rare_ws {
                let ptr = clean.as_mut_ptr();
                let mut rp = carry_len;
                let mut cwp = carry_len;
                while rp < total_clean {
                    let b = unsafe { *ptr.add(rp) };
                    if NOT_WHITESPACE[b as usize] {
                        unsafe { *ptr.add(cwp) = b };
                        cwp += 1;
                    }
                    rp += 1;
                }
                clean.truncate(cwp);
            }
        }

        carry_len = 0;
        let is_last = n < READ_CHUNK;

        if is_last {
            // Last chunk: decode everything (including padding)
            decode_clean_slice(&mut clean, writer)?;
        } else {
            // Save incomplete base64 quadruplet for next iteration
            let clean_len = clean.len();
            let decode_len = (clean_len / 4) * 4;
            let leftover = clean_len - decode_len;
            if leftover > 0 {
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        clean.as_ptr().add(decode_len),
                        carry.as_mut_ptr(),
                        leftover,
                    );
                }
                carry_len = leftover;
            }
            if decode_len > 0 {
                clean.truncate(decode_len);
                decode_clean_slice(&mut clean, writer)?;
            }
        }
    }

    // Handle any remaining carry-over bytes
    if carry_len > 0 {
        let mut carry_buf = carry[..carry_len].to_vec();
        decode_clean_slice(&mut carry_buf, writer)?;
    }

    Ok(())
}
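
// A streaming decode sketch: wrapped output from encode_to_writer fed through
// decode_stream (again with a Cursor standing in for stdin) recovers the original
// bytes, exercising the whitespace gap-copy path end to end.
#[cfg(test)]
mod decode_stream_examples {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn stream_decode_round_trips_wrapped_input() {
        let original: Vec<u8> = (0u8..=255).cycle().take(1000).collect();

        let mut encoded = Vec::new();
        encode_to_writer(&original, 76, &mut encoded).unwrap();

        let mut decoded = Vec::new();
        decode_stream(&mut Cursor::new(&encoded), false, &mut decoded).unwrap();
        assert_eq!(decoded, original);
    }
}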

/// Write all IoSlice entries using write_vectored (writev syscall).
/// Falls back to write_all per slice on partial writes.
fn write_all_vectored(out: &mut impl Write, slices: &[io::IoSlice]) -> io::Result<()> {
    if slices.is_empty() {
        return Ok(());
    }
    let total: usize = slices.iter().map(|s| s.len()).sum();

    // Try write_vectored first — often writes everything in one syscall
    let written = match out.write_vectored(slices) {
        Ok(n) if n >= total => return Ok(()),
        Ok(n) => n,
        Err(e) => return Err(e),
    };

    // Partial write fallback
    let mut skip = written;
    for slice in slices {
        let slen = slice.len();
        if skip >= slen {
            skip -= slen;
            continue;
        }
        if skip > 0 {
            out.write_all(&slice[skip..])?;
            skip = 0;
        } else {
            out.write_all(slice)?;
        }
    }
    Ok(())
}
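
// A sketch of the vectored-write helper: the IoSlices are concatenated in order, and
// the partial-write fallback keeps the result identical even if the writer only
// accepts part of the batch on the first call.
#[cfg(test)]
mod write_vectored_examples {
    use super::*;
    use std::io::IoSlice;

    #[test]
    fn slices_are_written_in_order() {
        let parts: [&[u8]; 3] = [b"ab", b"cd", b"\n"];
        let iov: Vec<IoSlice> = parts.iter().map(|p| IoSlice::new(p)).collect();

        let mut out = Vec::new();
        write_all_vectored(&mut out, &iov).unwrap();
        assert_eq!(out, b"abcd\n");
    }
}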

/// Read as many bytes as possible into buf, retrying on partial reads.
/// Fast path: regular file reads usually return the full buffer on the first call,
/// avoiding the loop overhead entirely.
#[inline]
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    // Fast path: first read() usually fills the entire buffer for regular files
    let n = reader.read(buf)?;
    if n == buf.len() || n == 0 {
        return Ok(n);
    }
    // Slow path: partial read — retry to fill buffer (pipes, slow devices)
    let mut total = n;
    while total < buf.len() {
        match reader.read(&mut buf[total..]) {
            Ok(0) => break,
            Ok(n) => total += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(total)
}
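
// A sketch of the retry behaviour: a reader that returns only a few bytes per read()
// call (a stand-in for a pipe) still fills the whole buffer, while EOF short-circuits
// with whatever was gathered.
#[cfg(test)]
mod read_full_examples {
    use super::*;
    use std::io::{self, Read};

    /// Hypothetical test-only reader that serves its data in 3-byte dribbles.
    struct Dribble<'a>(&'a [u8]);

    impl Read for Dribble<'_> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let n = self.0.len().min(buf.len()).min(3);
            buf[..n].copy_from_slice(&self.0[..n]);
            self.0 = &self.0[n..];
            Ok(n)
        }
    }

    #[test]
    fn fills_buffer_across_partial_reads() {
        let mut reader = Dribble(b"0123456789");

        let mut buf = [0u8; 8];
        assert_eq!(read_full(&mut reader, &mut buf).unwrap(), 8);
        assert_eq!(&buf, b"01234567");

        let mut buf = [0u8; 8];
        assert_eq!(read_full(&mut reader, &mut buf).unwrap(), 2);
        assert_eq!(&buf[..2], b"89");
    }
}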