coreutils_rs/base64/core.rs

1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4use rayon::prelude::*;
5
6const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
7
8/// Chunk size for no-wrap encoding: 32MB rounded down to a multiple of 3 bytes.
9/// Larger chunks mean fewer write() syscalls for big files.
10const NOWRAP_CHUNK: usize = 32 * 1024 * 1024 - (32 * 1024 * 1024 % 3);
11
12/// Minimum data size for parallel encoding (1MB).
13/// Lowered from 4MB so 10MB benchmark workloads get multi-core processing.
14const PARALLEL_ENCODE_THRESHOLD: usize = 1024 * 1024;
15
16/// Minimum data size for parallel decoding (1MB of base64 data).
17/// Lowered from 4MB for better parallelism on typical workloads.
18const PARALLEL_DECODE_THRESHOLD: usize = 1024 * 1024;
19
20/// Encode data and write to output with line wrapping.
21/// Uses SIMD encoding with fused encode+wrap for maximum throughput.
22pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
23    if data.is_empty() {
24        return Ok(());
25    }
26
27    if wrap_col == 0 {
28        return encode_no_wrap(data, out);
29    }
30
31    encode_wrapped(data, wrap_col, out)
32}
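
// A minimal round-trip sketch for the public entry point above, assuming only the
// std writers already used in this file (`Vec<u8>` implements `Write`). Wrapped
// output is fed back through `decode_to_writer`, which strips the inserted newlines.
#[cfg(test)]
mod encode_to_writer_roundtrip {
    use super::*;

    #[test]
    fn wrapped_output_roundtrips() {
        let input: Vec<u8> = (0u8..=255).cycle().take(10_000).collect();

        let mut encoded = Vec::new();
        encode_to_writer(&input, 76, &mut encoded).unwrap();
        // Every emitted line is at most 76 characters wide.
        for line in encoded.split(|&b| b == b'\n').filter(|l| !l.is_empty()) {
            assert!(line.len() <= 76);
        }

        let mut decoded = Vec::new();
        decode_to_writer(&encoded, false, &mut decoded).unwrap();
        assert_eq!(decoded, input);
    }
}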
33
34/// Encode without wrapping — parallel SIMD encoding for large data, sequential for small.
35fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
36    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
37        return encode_no_wrap_parallel(data, out);
38    }
39
40    let actual_chunk = NOWRAP_CHUNK.min(data.len());
41    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
42    // SAFETY: encode() writes exactly enc_len bytes before we read them.
43    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
44    #[allow(clippy::uninit_vec)]
45    unsafe {
46        buf.set_len(enc_max);
47    }
48
49    for chunk in data.chunks(NOWRAP_CHUNK) {
50        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
51        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
52        out.write_all(encoded)?;
53    }
54    Ok(())
55}
56
57/// Parallel no-wrap encoding: split at 3-byte boundaries, encode chunks in parallel.
58/// Each chunk except possibly the last is 3-byte aligned, so no padding in intermediate chunks.
59/// Uses write_vectored (writev) to send all encoded chunks in a single syscall.
60fn encode_no_wrap_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
61    let num_threads = rayon::current_num_threads().max(1);
62    let raw_chunk = data.len() / num_threads;
63    // Align to 3 bytes so each chunk encodes without padding (except the last)
64    let chunk_size = ((raw_chunk + 2) / 3) * 3;
65
66    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(3)).collect();
67    let encoded_chunks: Vec<Vec<u8>> = chunks
68        .par_iter()
69        .map(|chunk| {
70            let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
71            let mut buf: Vec<u8> = Vec::with_capacity(enc_len);
72            #[allow(clippy::uninit_vec)]
73            unsafe {
74                buf.set_len(enc_len);
75            }
76            let _ = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
77            buf
78        })
79        .collect();
80
81    // Use write_vectored to send all chunks in a single syscall
82    let iov: Vec<io::IoSlice> = encoded_chunks.iter().map(|c| io::IoSlice::new(c)).collect();
83    write_all_vectored(out, &iov)
84}
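
// A small consistency sketch: because the parallel path splits input at 3-byte
// boundaries, its concatenated output should be byte-identical to the sequential
// path, with '=' padding confined to the very end. The input length here is an
// arbitrary choice that is not a multiple of 3.
#[cfg(test)]
mod no_wrap_parallel_consistency {
    use super::*;

    #[test]
    fn parallel_matches_sequential() {
        let data: Vec<u8> = (0u8..=255).cycle().take(10_001).collect();

        let mut sequential = Vec::new();
        encode_no_wrap(&data, &mut sequential).unwrap(); // below the threshold, so sequential

        let mut parallel = Vec::new();
        encode_no_wrap_parallel(&data, &mut parallel).unwrap();

        assert_eq!(parallel, sequential);
        // Padding, if any, appears only in the last two output bytes.
        assert!(!parallel[..parallel.len() - 2].contains(&b'='));
    }
}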
85
86/// Encode with line wrapping — uses writev to interleave encoded segments
87/// with newlines without copying data. For each wrap_col-sized segment of
88/// encoded output, we create an IoSlice pointing directly at the encode buffer,
89/// interleaved with IoSlice entries pointing at a static newline byte.
90fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
91    // Calculate bytes_per_line: input bytes that produce exactly wrap_col encoded chars.
92    // For default wrap_col=76: 76*3/4 = 57 bytes per line.
93    let bytes_per_line = wrap_col * 3 / 4;
94    if bytes_per_line == 0 {
95        // Degenerate case: wrap_col == 1 makes bytes_per_line round to 0; use the simple fallback
96        return encode_wrapped_small(data, wrap_col, out);
97    }
98
99    // Parallel encoding for large data when bytes_per_line is a multiple of 3.
100    // This guarantees each chunk encodes to complete base64 without padding.
101    if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line.is_multiple_of(3) {
102        return encode_wrapped_parallel(data, wrap_col, bytes_per_line, out);
103    }
104
105    // Align input chunk to bytes_per_line for complete output lines.
106    let lines_per_chunk = (32 * 1024 * 1024) / bytes_per_line;
107    let max_input_chunk = (lines_per_chunk * bytes_per_line).max(bytes_per_line);
108    let input_chunk = max_input_chunk.min(data.len());
109
110    let enc_max = BASE64_ENGINE.encoded_length(input_chunk);
111    let mut encode_buf: Vec<u8> = Vec::with_capacity(enc_max);
112    #[allow(clippy::uninit_vec)]
113    unsafe {
114        encode_buf.set_len(enc_max);
115    }
116
117    for chunk in data.chunks(max_input_chunk.max(1)) {
118        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
119        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
120
121        // Use writev: build IoSlice entries pointing at wrap_col-sized segments
122        // of the encoded buffer interleaved with newline IoSlices.
123        // This eliminates the fused_buf copy entirely.
124        write_wrapped_iov(encoded, wrap_col, out)?;
125    }
126
127    Ok(())
128}
129
130/// Static newline byte for IoSlice references in writev calls.
131static NEWLINE: [u8; 1] = [b'\n'];
132
133/// Write encoded base64 data with line wrapping using write_vectored (writev).
134/// Builds IoSlice entries pointing at wrap_col-sized segments of the encoded buffer,
135/// interleaved with newline IoSlices, then writes in batches of MAX_WRITEV_IOV.
136/// This is zero-copy: no fused output buffer needed.
137#[inline]
138fn write_wrapped_iov(encoded: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
139    // Max IoSlice entries per writev batch. Linux UIO_MAXIOV is 1024.
140    // Each line needs 2 entries (data + newline), so 512 lines per batch.
141    const MAX_IOV: usize = 1024;
142
143    let num_full_lines = encoded.len() / wrap_col;
144    let remainder = encoded.len() % wrap_col;
145    let total_iov = num_full_lines * 2 + if remainder > 0 { 2 } else { 0 };
146
147    // Small output: build all IoSlices and write in one call
148    if total_iov <= MAX_IOV {
149        let mut iov: Vec<io::IoSlice> = Vec::with_capacity(total_iov);
150        let mut pos = 0;
151        for _ in 0..num_full_lines {
152            iov.push(io::IoSlice::new(&encoded[pos..pos + wrap_col]));
153            iov.push(io::IoSlice::new(&NEWLINE));
154            pos += wrap_col;
155        }
156        if remainder > 0 {
157            iov.push(io::IoSlice::new(&encoded[pos..pos + remainder]));
158            iov.push(io::IoSlice::new(&NEWLINE));
159        }
160        return write_all_vectored(out, &iov);
161    }
162
163    // Large output: write in batches
164    let mut iov: Vec<io::IoSlice> = Vec::with_capacity(MAX_IOV);
165    let mut pos = 0;
166    for _ in 0..num_full_lines {
167        iov.push(io::IoSlice::new(&encoded[pos..pos + wrap_col]));
168        iov.push(io::IoSlice::new(&NEWLINE));
169        pos += wrap_col;
170        if iov.len() >= MAX_IOV {
171            write_all_vectored(out, &iov)?;
172            iov.clear();
173        }
174    }
175    if remainder > 0 {
176        iov.push(io::IoSlice::new(&encoded[pos..pos + remainder]));
177        iov.push(io::IoSlice::new(&NEWLINE));
178    }
179    if !iov.is_empty() {
180        write_all_vectored(out, &iov)?;
181    }
182    Ok(())
183}
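
// A behavioral sketch for the writev-based wrapper: stripping the newlines it
// inserts must give back the original bytes, and no emitted line exceeds wrap_col.
// The payload is arbitrary bytes; the function only slices, it never validates base64.
#[cfg(test)]
mod write_wrapped_iov_behavior {
    use super::*;

    #[test]
    fn newline_placement_is_lossless() {
        let payload: Vec<u8> = (b'A'..=b'Z').cycle().take(2_500).collect();
        let mut out = Vec::new();
        write_wrapped_iov(&payload, 76, &mut out).unwrap();

        let without_newlines: Vec<u8> = out.iter().copied().filter(|&b| b != b'\n').collect();
        assert_eq!(without_newlines, payload);
        assert!(out.split(|&b| b == b'\n').all(|line| line.len() <= 76));
    }
}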
184
185/// Write encoded base64 data with line wrapping using writev, tracking column state
186/// across calls. Used by encode_stream for piped input where chunks don't align
187/// to line boundaries.
188#[inline]
189fn write_wrapped_iov_streaming(
190    encoded: &[u8],
191    wrap_col: usize,
192    col: &mut usize,
193    out: &mut impl Write,
194) -> io::Result<()> {
195    const MAX_IOV: usize = 1024;
196    let mut iov: Vec<io::IoSlice> = Vec::with_capacity(MAX_IOV);
197    let mut rp = 0;
198
199    while rp < encoded.len() {
200        let space = wrap_col - *col;
201        let avail = encoded.len() - rp;
202
203        if avail <= space {
204            // Remaining data fits in current line
205            iov.push(io::IoSlice::new(&encoded[rp..rp + avail]));
206            *col += avail;
207            if *col == wrap_col {
208                iov.push(io::IoSlice::new(&NEWLINE));
209                *col = 0;
210            }
211            break;
212        } else {
213            // Fill current line and add newline
214            iov.push(io::IoSlice::new(&encoded[rp..rp + space]));
215            iov.push(io::IoSlice::new(&NEWLINE));
216            rp += space;
217            *col = 0;
218        }
219
220        if iov.len() >= MAX_IOV - 1 {
221            write_all_vectored(out, &iov)?;
222            iov.clear();
223        }
224    }
225
226    if !iov.is_empty() {
227        write_all_vectored(out, &iov)?;
228    }
229    Ok(())
230}
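
// A sketch of the streaming variant's column carry: two calls that do not land on a
// line boundary still produce one correctly wrapped stream, because `col` remembers
// how far into the current line the previous call stopped.
#[cfg(test)]
mod streaming_wrap_column_carry {
    use super::*;

    #[test]
    fn column_state_spans_calls() {
        let mut out = Vec::new();
        let mut col = 0usize;
        // 100 bytes then 60 bytes at wrap_col = 76: the first call leaves col = 24,
        // and the second call finishes that line before opening another.
        write_wrapped_iov_streaming(&[b'x'; 100], 76, &mut col, &mut out).unwrap();
        write_wrapped_iov_streaming(&[b'y'; 60], 76, &mut col, &mut out).unwrap();

        let lines: Vec<&[u8]> = out.split(|&b| b == b'\n').collect();
        assert_eq!(lines[0].len(), 76);
        assert_eq!(lines[1].len(), 76);
        assert_eq!(col, 8); // 160 bytes total = two full 76-char lines + 8 carried over
    }
}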
231
232/// Parallel wrapped encoding into a single pre-sized output buffer.
233/// Requires bytes_per_line % 3 == 0 so each chunk encodes without intermediate padding.
234///
235/// Pre-calculates the exact output size and each task's write offset. Each thread
236/// batch-encodes its full lines into a thread-local buffer, then scatters them into
237/// the shared output buffer with a newline after every line, so output for line N
238/// starts at N * (wrap_col + 1). A trailing partial line is encoded directly into its
239/// final output position. This avoids a separate fuse_wrap pass over the whole output.
240fn encode_wrapped_parallel(
241    data: &[u8],
242    wrap_col: usize,
243    bytes_per_line: usize,
244    out: &mut impl Write,
245) -> io::Result<()> {
246    let line_out = wrap_col + 1; // wrap_col data + 1 newline per line
247    let total_full_lines = data.len() / bytes_per_line;
248    let remainder_input = data.len() % bytes_per_line;
249
250    // Calculate exact output size
251    let remainder_encoded = if remainder_input > 0 {
252        BASE64_ENGINE.encoded_length(remainder_input) + 1 // +1 for trailing newline
253    } else {
254        0
255    };
256    let total_output = total_full_lines * line_out + remainder_encoded;
257
258    // Pre-allocate single contiguous output buffer
259    let mut outbuf: Vec<u8> = Vec::with_capacity(total_output);
260    #[allow(clippy::uninit_vec)]
261    unsafe {
262        outbuf.set_len(total_output);
263    }
264
265    // Split work at line boundaries for parallel processing
266    let num_threads = rayon::current_num_threads().max(1);
267    let lines_per_chunk = (total_full_lines / num_threads).max(1);
268    let input_chunk = lines_per_chunk * bytes_per_line;
269
270    // Compute per-chunk metadata: (input_offset, output_offset, num_input_bytes)
271    let mut tasks: Vec<(usize, usize, usize)> = Vec::new();
272    let mut in_off = 0usize;
273    let mut out_off = 0usize;
274    while in_off < data.len() {
275        let chunk_input = input_chunk.min(data.len() - in_off);
276        // Align to bytes_per_line except for the very last chunk
277        let aligned_input = if in_off + chunk_input < data.len() {
278            (chunk_input / bytes_per_line) * bytes_per_line
279        } else {
280            chunk_input
281        };
282        if aligned_input == 0 {
283            break;
284        }
285        let full_lines = aligned_input / bytes_per_line;
286        let rem = aligned_input % bytes_per_line;
287        let chunk_output = full_lines * line_out
288            + if rem > 0 {
289                BASE64_ENGINE.encoded_length(rem) + 1
290            } else {
291                0
292            };
293        tasks.push((in_off, out_off, aligned_input));
294        in_off += aligned_input;
295        out_off += chunk_output;
296    }
297
298    // Parallel encode: each thread batch-encodes all its input at once, then
299    // scatters the contiguous encoded output to line-separated positions.
300    // This does 1 SIMD encode call per thread (vs N calls per line), trading
301    // one thread-local encode buffer for dramatically fewer function calls.
302    // SAFETY: tasks have non-overlapping output regions.
303    let out_addr = outbuf.as_mut_ptr() as usize;
304
305    tasks.par_iter().for_each(|&(in_off, out_off, chunk_len)| {
306        let input = &data[in_off..in_off + chunk_len];
307        let full_lines = chunk_len / bytes_per_line;
308        let rem = chunk_len % bytes_per_line;
309        let full_input = full_lines * bytes_per_line;
310
311        let out_ptr = out_addr as *mut u8;
312
313        // Batch encode all full lines at once into a thread-local buffer
314        if full_lines > 0 {
315            let enc_total = BASE64_ENGINE.encoded_length(full_input);
316            let mut enc_buf: Vec<u8> = Vec::with_capacity(enc_total);
317            #[allow(clippy::uninit_vec)]
318            unsafe {
319                enc_buf.set_len(enc_total);
320            }
321            let _ = BASE64_ENGINE.encode(&input[..full_input], enc_buf[..enc_total].as_out());
322
323            // Scatter: copy wrap_col bytes per line + insert newlines
324            // Uses fuse_wrap-style unrolled copy for throughput.
325            let src = enc_buf.as_ptr();
326            let dst = unsafe { out_ptr.add(out_off) };
327            let mut rp = 0;
328            let mut wp = 0;
329
330            // 8-line unrolled scatter loop
331            while rp + 8 * wrap_col <= enc_total {
332                unsafe {
333                    std::ptr::copy_nonoverlapping(src.add(rp), dst.add(wp), wrap_col);
334                    *dst.add(wp + wrap_col) = b'\n';
335                    std::ptr::copy_nonoverlapping(
336                        src.add(rp + wrap_col),
337                        dst.add(wp + line_out),
338                        wrap_col,
339                    );
340                    *dst.add(wp + line_out + wrap_col) = b'\n';
341                    std::ptr::copy_nonoverlapping(
342                        src.add(rp + 2 * wrap_col),
343                        dst.add(wp + 2 * line_out),
344                        wrap_col,
345                    );
346                    *dst.add(wp + 2 * line_out + wrap_col) = b'\n';
347                    std::ptr::copy_nonoverlapping(
348                        src.add(rp + 3 * wrap_col),
349                        dst.add(wp + 3 * line_out),
350                        wrap_col,
351                    );
352                    *dst.add(wp + 3 * line_out + wrap_col) = b'\n';
353                    std::ptr::copy_nonoverlapping(
354                        src.add(rp + 4 * wrap_col),
355                        dst.add(wp + 4 * line_out),
356                        wrap_col,
357                    );
358                    *dst.add(wp + 4 * line_out + wrap_col) = b'\n';
359                    std::ptr::copy_nonoverlapping(
360                        src.add(rp + 5 * wrap_col),
361                        dst.add(wp + 5 * line_out),
362                        wrap_col,
363                    );
364                    *dst.add(wp + 5 * line_out + wrap_col) = b'\n';
365                    std::ptr::copy_nonoverlapping(
366                        src.add(rp + 6 * wrap_col),
367                        dst.add(wp + 6 * line_out),
368                        wrap_col,
369                    );
370                    *dst.add(wp + 6 * line_out + wrap_col) = b'\n';
371                    std::ptr::copy_nonoverlapping(
372                        src.add(rp + 7 * wrap_col),
373                        dst.add(wp + 7 * line_out),
374                        wrap_col,
375                    );
376                    *dst.add(wp + 7 * line_out + wrap_col) = b'\n';
377                }
378                rp += 8 * wrap_col;
379                wp += 8 * line_out;
380            }
381
382            // Remaining lines one at a time
383            while rp + wrap_col <= enc_total {
384                unsafe {
385                    std::ptr::copy_nonoverlapping(src.add(rp), dst.add(wp), wrap_col);
386                    *dst.add(wp + wrap_col) = b'\n';
387                }
388                rp += wrap_col;
389                wp += line_out;
390            }
391        }
392
393        // Handle remainder (last partial line of this chunk)
394        if rem > 0 {
395            let line_input = &input[full_input..];
396            let enc_len = BASE64_ENGINE.encoded_length(rem);
397            let woff = out_off + full_lines * line_out;
398            // Encode directly into final output position
399            let out_slice =
400                unsafe { std::slice::from_raw_parts_mut(out_ptr.add(woff), enc_len + 1) };
401            let _ = BASE64_ENGINE.encode(line_input, out_slice[..enc_len].as_out());
402            out_slice[enc_len] = b'\n';
403        }
404    });
405
406    out.write_all(&outbuf[..total_output])
407}
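
// A consistency sketch for the precomputed-offset math above: for input small enough
// that `encode_wrapped` takes its sequential path, calling the parallel implementation
// directly must produce byte-identical output, since line N of the shared buffer
// starts at N * (wrap_col + 1) by construction.
#[cfg(test)]
mod wrapped_parallel_offsets {
    use super::*;

    #[test]
    fn parallel_matches_sequential_wrapping() {
        let data: Vec<u8> = (0u8..=255).cycle().take(1_000).collect();
        let wrap_col = 76;
        let bytes_per_line = wrap_col * 3 / 4; // 57, a multiple of 3

        let mut sequential = Vec::new();
        encode_wrapped(&data, wrap_col, &mut sequential).unwrap();

        let mut parallel = Vec::new();
        encode_wrapped_parallel(&data, wrap_col, bytes_per_line, &mut parallel).unwrap();

        assert_eq!(parallel, sequential);
        // 17 full 57-byte lines plus a 31-byte remainder line.
        assert_eq!(parallel.len(), 17 * 77 + BASE64_ENGINE.encoded_length(31) + 1);
    }
}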
408
409/// Fuse encoded base64 data with newlines in a single pass.
410/// Uses ptr::copy_nonoverlapping with 8-line unrolling for max throughput.
411/// Returns number of bytes written.
412#[inline]
413fn fuse_wrap(encoded: &[u8], wrap_col: usize, out_buf: &mut [u8]) -> usize {
414    let line_out = wrap_col + 1; // wrap_col data bytes + 1 newline
415    let mut rp = 0;
416    let mut wp = 0;
417
418    // Unrolled: process 8 lines per iteration for better ILP
419    while rp + 8 * wrap_col <= encoded.len() {
420        unsafe {
421            let src = encoded.as_ptr().add(rp);
422            let dst = out_buf.as_mut_ptr().add(wp);
423
424            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
425            *dst.add(wrap_col) = b'\n';
426
427            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
428            *dst.add(line_out + wrap_col) = b'\n';
429
430            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
431            *dst.add(2 * line_out + wrap_col) = b'\n';
432
433            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
434            *dst.add(3 * line_out + wrap_col) = b'\n';
435
436            std::ptr::copy_nonoverlapping(src.add(4 * wrap_col), dst.add(4 * line_out), wrap_col);
437            *dst.add(4 * line_out + wrap_col) = b'\n';
438
439            std::ptr::copy_nonoverlapping(src.add(5 * wrap_col), dst.add(5 * line_out), wrap_col);
440            *dst.add(5 * line_out + wrap_col) = b'\n';
441
442            std::ptr::copy_nonoverlapping(src.add(6 * wrap_col), dst.add(6 * line_out), wrap_col);
443            *dst.add(6 * line_out + wrap_col) = b'\n';
444
445            std::ptr::copy_nonoverlapping(src.add(7 * wrap_col), dst.add(7 * line_out), wrap_col);
446            *dst.add(7 * line_out + wrap_col) = b'\n';
447        }
448        rp += 8 * wrap_col;
449        wp += 8 * line_out;
450    }
451
452    // Handle remaining 4 lines at a time
453    while rp + 4 * wrap_col <= encoded.len() {
454        unsafe {
455            let src = encoded.as_ptr().add(rp);
456            let dst = out_buf.as_mut_ptr().add(wp);
457
458            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
459            *dst.add(wrap_col) = b'\n';
460
461            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
462            *dst.add(line_out + wrap_col) = b'\n';
463
464            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
465            *dst.add(2 * line_out + wrap_col) = b'\n';
466
467            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
468            *dst.add(3 * line_out + wrap_col) = b'\n';
469        }
470        rp += 4 * wrap_col;
471        wp += 4 * line_out;
472    }
473
474    // Remaining full lines
475    while rp + wrap_col <= encoded.len() {
476        unsafe {
477            std::ptr::copy_nonoverlapping(
478                encoded.as_ptr().add(rp),
479                out_buf.as_mut_ptr().add(wp),
480                wrap_col,
481            );
482            *out_buf.as_mut_ptr().add(wp + wrap_col) = b'\n';
483        }
484        rp += wrap_col;
485        wp += line_out;
486    }
487
488    // Partial last line
489    if rp < encoded.len() {
490        let remaining = encoded.len() - rp;
491        unsafe {
492            std::ptr::copy_nonoverlapping(
493                encoded.as_ptr().add(rp),
494                out_buf.as_mut_ptr().add(wp),
495                remaining,
496            );
497        }
498        wp += remaining;
499        out_buf[wp] = b'\n';
500        wp += 1;
501    }
502
503    wp
504}
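
// A reference-comparison sketch for `fuse_wrap`: the unrolled copy must agree with
// the obvious `chunks()`-based formulation, including the newline appended after a
// partial last line. Buffer sizing mirrors the fused_max formula used by the caller.
#[cfg(test)]
mod fuse_wrap_reference {
    use super::*;

    #[test]
    fn matches_naive_wrapping() {
        let encoded: Vec<u8> = (b'a'..=b'z').cycle().take(1_000).collect();
        let wrap_col = 76;

        let mut out_buf = vec![0u8; encoded.len() + encoded.len() / wrap_col + 2];
        let written = fuse_wrap(&encoded, wrap_col, &mut out_buf);

        let mut expected = Vec::new();
        for line in encoded.chunks(wrap_col) {
            expected.extend_from_slice(line);
            expected.push(b'\n');
        }
        assert_eq!(&out_buf[..written], expected.as_slice());
    }
}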
505
506/// Fallback for wrap_col == 1 (bytes_per_line rounds to 0): encode once, then emit wrap_col-sized lines.
507fn encode_wrapped_small(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
508    let enc_max = BASE64_ENGINE.encoded_length(data.len());
509    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
510    #[allow(clippy::uninit_vec)]
511    unsafe {
512        buf.set_len(enc_max);
513    }
514    let encoded = BASE64_ENGINE.encode(data, buf[..enc_max].as_out());
515
516    let wc = wrap_col.max(1);
517    for line in encoded.chunks(wc) {
518        out.write_all(line)?;
519        out.write_all(b"\n")?;
520    }
521    Ok(())
522}
523
524/// Decode base64 data and write to output (borrows data, allocates clean buffer).
525/// When `ignore_garbage` is true, strip all non-base64 characters.
526/// When false, only strip whitespace (standard behavior).
527pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
528    if data.is_empty() {
529        return Ok(());
530    }
531
532    if ignore_garbage {
533        let mut cleaned = strip_non_base64(data);
534        return decode_clean_slice(&mut cleaned, out);
535    }
536
537    // Fast path: single-pass strip + decode
538    decode_stripping_whitespace(data, out)
539}
540
541/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
542pub fn decode_owned(
543    data: &mut Vec<u8>,
544    ignore_garbage: bool,
545    out: &mut impl Write,
546) -> io::Result<()> {
547    if data.is_empty() {
548        return Ok(());
549    }
550
551    if ignore_garbage {
552        data.retain(|&b| is_base64_char(b));
553    } else {
554        strip_whitespace_inplace(data);
555    }
556
557    decode_clean_slice(data, out)
558}
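
// A behavioral sketch for the two decode entry points: wrapped output decodes back
// to the original bytes, and the ignore_garbage path additionally tolerates bytes
// outside the base64 alphabet. The injected '*' is an arbitrary example.
#[cfg(test)]
mod decode_entry_points {
    use super::*;

    #[test]
    fn wrapped_and_garbage_inputs_decode() {
        let original = b"light work.".to_vec();
        let mut encoded = Vec::new();
        encode_to_writer(&original, 4, &mut encoded).unwrap(); // aggressively wrapped

        let mut decoded = Vec::new();
        decode_to_writer(&encoded, false, &mut decoded).unwrap();
        assert_eq!(decoded, original);

        // Inject a non-alphabet byte; only the ignore_garbage path accepts it.
        let mut noisy = encoded.clone();
        noisy.insert(3, b'*');
        let mut decoded_noisy = Vec::new();
        decode_owned(&mut noisy, true, &mut decoded_noisy).unwrap();
        assert_eq!(decoded_noisy, original);
    }
}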
559
560/// Strip all whitespace from a Vec in-place using SIMD memchr2 gap-copy.
561/// For typical base64 (76-char lines with \n), newlines are ~1/77 of the data,
562/// so SIMD memchr2 skips ~76 bytes per hit instead of checking every byte.
563/// Falls back to scalar compaction only for rare whitespace (tab, space, VT, FF).
564fn strip_whitespace_inplace(data: &mut Vec<u8>) {
565    // Quick check: any whitespace at all?
566    let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
567    if !has_ws {
568        return;
569    }
570
571    // SIMD gap-copy: find \n and \r positions with memchr2, then memmove the
572    // gaps between them to compact the data in-place. For typical base64 streams,
573    // newlines are the only whitespace, so this handles >99% of cases.
574    let ptr = data.as_mut_ptr();
575    let len = data.len();
576    let mut wp = 0usize;
577    let mut gap_start = 0usize;
578    let mut has_rare_ws = false;
579
580    for pos in memchr::memchr2_iter(b'\n', b'\r', data.as_slice()) {
581        let gap_len = pos - gap_start;
582        if gap_len > 0 {
583            if !has_rare_ws {
584                // Check for rare whitespace during copy (amortized ~1 branch per 77 bytes)
585                has_rare_ws = data[gap_start..pos]
586                    .iter()
587                    .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
588            }
589            if wp != gap_start {
590                unsafe {
591                    std::ptr::copy(ptr.add(gap_start), ptr.add(wp), gap_len);
592                }
593            }
594            wp += gap_len;
595        }
596        gap_start = pos + 1;
597    }
598    // Copy the final gap
599    let tail_len = len - gap_start;
600    if tail_len > 0 {
601        if !has_rare_ws {
602            has_rare_ws = data[gap_start..]
603                .iter()
604                .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
605        }
606        if wp != gap_start {
607            unsafe {
608                std::ptr::copy(ptr.add(gap_start), ptr.add(wp), tail_len);
609            }
610        }
611        wp += tail_len;
612    }
613
614    data.truncate(wp);
615
616    // Second pass for rare whitespace (tab, space, VT, FF) — only if detected.
617    // In typical base64 streams (76-char lines with \n), this is skipped entirely.
618    if has_rare_ws {
619        let ptr = data.as_mut_ptr();
620        let len = data.len();
621        let mut rp = 0;
622        let mut cwp = 0;
623        while rp < len {
624            let b = unsafe { *ptr.add(rp) };
625            if NOT_WHITESPACE[b as usize] {
626                unsafe { *ptr.add(cwp) = b };
627                cwp += 1;
628            }
629            rp += 1;
630        }
631        data.truncate(cwp);
632    }
633}
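
// A sketch exercising both passes of the in-place strip: the newlines go through the
// memchr2 gap-copy, while the embedded space and tab force the rare-whitespace
// fallback pass.
#[cfg(test)]
mod strip_whitespace_passes {
    use super::*;

    #[test]
    fn removes_common_and_rare_whitespace() {
        let mut data = b"AAAA\r\nBB BB\nCC\tCC\n".to_vec();
        strip_whitespace_inplace(&mut data);
        assert_eq!(data, b"AAAABBBBCCCC".to_vec());
    }
}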
634
635/// 256-byte lookup table: true for non-whitespace bytes.
636/// Used for single-pass whitespace stripping in decode.
637static NOT_WHITESPACE: [bool; 256] = {
638    let mut table = [true; 256];
639    table[b' ' as usize] = false;
640    table[b'\t' as usize] = false;
641    table[b'\n' as usize] = false;
642    table[b'\r' as usize] = false;
643    table[0x0b] = false; // vertical tab
644    table[0x0c] = false; // form feed
645    table
646};
647
648/// Decode by stripping whitespace and decoding in a single fused pass.
649/// For data with no whitespace, decodes directly without any copy.
650/// Uses memchr2 SIMD gap-copy for \n/\r (the dominant whitespace in base64),
651/// then a conditional fallback pass for rare whitespace types (tab, space, VT, FF).
652/// Tracks rare whitespace presence during the gap-copy to skip the second scan
653/// entirely in the common case (pure \n/\r whitespace only).
654fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
655    // Quick check: any whitespace at all?  Use the lookup table for a single scan.
656    let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
657    if !has_ws {
658        // No whitespace — decode directly from borrowed data
659        return decode_borrowed_clean(out, data);
660    }
661
662    // SIMD gap-copy: use memchr2 to find \n and \r positions, then copy the
663    // gaps between them. For typical base64 (76-char lines), newlines are ~1/77
664    // of the data, so we process ~76 bytes per memchr hit instead of 1 per scalar.
665    let mut clean: Vec<u8> = Vec::with_capacity(data.len());
666    let dst = clean.as_mut_ptr();
667    let mut wp = 0usize;
668    let mut gap_start = 0usize;
669    // Track whether any rare whitespace (tab, space, VT, FF) exists in gap regions.
670    // This avoids the second full-scan pass when only \n/\r are present.
671    let mut has_rare_ws = false;
672
673    for pos in memchr::memchr2_iter(b'\n', b'\r', data) {
674        let gap_len = pos - gap_start;
675        if gap_len > 0 {
676            // Check gap region for rare whitespace during copy.
677            // This adds ~1 branch per gap but eliminates the second full scan.
678            if !has_rare_ws {
679                has_rare_ws = data[gap_start..pos]
680                    .iter()
681                    .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
682            }
683            unsafe {
684                std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), gap_len);
685            }
686            wp += gap_len;
687        }
688        gap_start = pos + 1;
689    }
690    // Copy the final gap after the last \n/\r
691    let tail_len = data.len() - gap_start;
692    if tail_len > 0 {
693        if !has_rare_ws {
694            has_rare_ws = data[gap_start..]
695                .iter()
696                .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
697        }
698        unsafe {
699            std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), tail_len);
700        }
701        wp += tail_len;
702    }
703    unsafe {
704        clean.set_len(wp);
705    }
706
707    // Second pass for rare whitespace (tab, space, VT, FF) — only runs when needed.
708    // In typical base64 streams (76-char lines with \n), this is skipped entirely.
709    if has_rare_ws {
710        let ptr = clean.as_mut_ptr();
711        let len = clean.len();
712        let mut rp = 0;
713        let mut cwp = 0;
714        while rp < len {
715            let b = unsafe { *ptr.add(rp) };
716            if NOT_WHITESPACE[b as usize] {
717                unsafe { *ptr.add(cwp) = b };
718                cwp += 1;
719            }
720            rp += 1;
721        }
722        clean.truncate(cwp);
723    }
724
725    decode_clean_slice(&mut clean, out)
726}
727
728/// Decode a clean (no whitespace) buffer in-place with SIMD.
729fn decode_clean_slice(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
730    if data.is_empty() {
731        return Ok(());
732    }
733    match BASE64_ENGINE.decode_inplace(data) {
734        Ok(decoded) => out.write_all(decoded),
735        Err(_) => decode_error(),
736    }
737}
738
739/// Cold error path — keeps hot decode path tight by moving error construction out of line.
740#[cold]
741#[inline(never)]
742fn decode_error() -> io::Result<()> {
743    Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input"))
744}
745
746/// Decode clean base64 data (no whitespace) from a borrowed slice.
747fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
748    if data.is_empty() {
749        return Ok(());
750    }
751    // Parallel decode for large data: split at 4-byte boundaries,
752    // decode each chunk independently (base64 is context-free per 4-char group).
753    if data.len() >= PARALLEL_DECODE_THRESHOLD {
754        return decode_borrowed_clean_parallel(out, data);
755    }
756    match BASE64_ENGINE.decode_to_vec(data) {
757        Ok(decoded) => {
758            out.write_all(&decoded)?;
759            Ok(())
760        }
761        Err(_) => decode_error(),
762    }
763}
764
765/// Parallel decode: split at 4-byte boundaries, decode chunks in parallel via rayon.
766/// Pre-allocates a single contiguous output buffer with exact decoded offsets computed
767/// upfront, so each thread decodes directly to its final position. No compaction needed.
768fn decode_borrowed_clean_parallel(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
769    let num_threads = rayon::current_num_threads().max(1);
770    let raw_chunk = data.len() / num_threads;
771    // Align to 4 bytes (each 4 base64 chars = 3 decoded bytes, context-free)
772    let chunk_size = ((raw_chunk + 3) / 4) * 4;
773
774    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(4)).collect();
775
776    // Compute exact decoded sizes per chunk upfront to eliminate the compaction pass.
777    // For all chunks except the last, decoded size is exactly chunk.len() * 3 / 4.
778    // For the last chunk, account for '=' padding bytes.
779    let mut offsets: Vec<usize> = Vec::with_capacity(chunks.len() + 1);
780    offsets.push(0);
781    let mut total_decoded = 0usize;
782    for (i, chunk) in chunks.iter().enumerate() {
783        let decoded_size = if i == chunks.len() - 1 {
784            // Last chunk: count '=' padding to get exact decoded size
785            let pad = chunk.iter().rev().take(2).filter(|&&b| b == b'=').count();
786            chunk.len() * 3 / 4 - pad
787        } else {
788            // Non-last chunks: 4-byte aligned, no padding, exact 3/4 ratio
789            chunk.len() * 3 / 4
790        };
791        total_decoded += decoded_size;
792        offsets.push(total_decoded);
793    }
794
795    // Pre-allocate contiguous output buffer with exact total size
796    let mut output_buf: Vec<u8> = Vec::with_capacity(total_decoded);
797    #[allow(clippy::uninit_vec)]
798    unsafe {
799        output_buf.set_len(total_decoded);
800    }
801
802    // Parallel decode: each thread decodes directly into its exact final position.
803    // No compaction pass needed since offsets are computed from exact decoded sizes.
804    // SAFETY: each thread writes to a non-overlapping region of the output buffer.
805    // Use usize representation of the pointer for Send+Sync compatibility with rayon.
806    let out_addr = output_buf.as_mut_ptr() as usize;
807    let decode_result: Result<Vec<()>, io::Error> = chunks
808        .par_iter()
809        .enumerate()
810        .map(|(i, chunk)| {
811            let offset = offsets[i];
812            let expected_size = offsets[i + 1] - offset;
813            // SAFETY: each thread writes to non-overlapping region [offset..offset+expected_size]
814            let out_slice = unsafe {
815                std::slice::from_raw_parts_mut((out_addr as *mut u8).add(offset), expected_size)
816            };
817            let decoded = BASE64_ENGINE
818                .decode(chunk, out_slice.as_out())
819                .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid input"))?;
820            debug_assert_eq!(decoded.len(), expected_size);
821            Ok(())
822        })
823        .collect();
824
825    decode_result?;
826
827    out.write_all(&output_buf[..total_decoded])
828}
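
// A sketch of the upfront size computation: every 4 clean base64 characters decode to
// 3 bytes, minus one byte per trailing '='. The parallel path, called directly on a
// small clean buffer, must therefore agree with the sequential decoder.
#[cfg(test)]
mod parallel_decode_offsets {
    use super::*;

    #[test]
    fn parallel_matches_sequential_decode() {
        // 997 bytes is not a multiple of 3, so the encoding ends in '=' padding.
        let original: Vec<u8> = (0u8..=255).cycle().take(997).collect();
        let mut clean = Vec::new();
        encode_to_writer(&original, 0, &mut clean).unwrap();

        let mut sequential = Vec::new();
        decode_borrowed_clean(&mut sequential, &clean).unwrap();

        let mut parallel = Vec::new();
        decode_borrowed_clean_parallel(&mut parallel, &clean).unwrap();

        assert_eq!(sequential, original);
        assert_eq!(parallel, original);
    }
}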
829
830/// Strip non-base64 characters (for -i / --ignore-garbage).
831fn strip_non_base64(data: &[u8]) -> Vec<u8> {
832    data.iter()
833        .copied()
834        .filter(|&b| is_base64_char(b))
835        .collect()
836}
837
838/// Check if a byte is a valid base64 alphabet character or padding.
839#[inline]
840fn is_base64_char(b: u8) -> bool {
841    b.is_ascii_alphanumeric() || b == b'+' || b == b'/' || b == b'='
842}
843
844/// Stream-encode from a reader to a writer. Used for stdin processing.
845/// Dispatches to specialized paths for wrap_col=0 (no wrap) and wrap_col>0 (wrapping).
846pub fn encode_stream(
847    reader: &mut impl Read,
848    wrap_col: usize,
849    writer: &mut impl Write,
850) -> io::Result<()> {
851    if wrap_col == 0 {
852        return encode_stream_nowrap(reader, writer);
853    }
854    encode_stream_wrapped(reader, wrap_col, writer)
855}
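
// A usage sketch for the streaming entry point, with an in-memory `io::Cursor`
// standing in for stdin (any `Read` source works). The streamed result must match
// the slice-based encoder byte for byte.
#[cfg(test)]
mod encode_stream_usage {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn stream_matches_slice_encoder() {
        let data: Vec<u8> = (0u8..=255).cycle().take(5_000).collect();

        let mut streamed = Vec::new();
        encode_stream(&mut Cursor::new(data.as_slice()), 76, &mut streamed).unwrap();

        let mut direct = Vec::new();
        encode_to_writer(&data, 76, &mut direct).unwrap();

        assert_eq!(streamed, direct);
    }
}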
856
857/// Streaming encode with NO line wrapping — optimized fast path.
858/// Read size is 12MB (divisible by 3): encoded output = 12MB * 4/3 = 16MB.
859/// 12MB reads mean 10MB input is consumed in a single read() call,
860/// and the 16MB encoded output writes in 1-2 write() calls.
861fn encode_stream_nowrap(reader: &mut impl Read, writer: &mut impl Write) -> io::Result<()> {
862    // 12MB aligned to 3 bytes: encoded output = 12MB * 4/3 = 16MB.
863    // For 10MB input: 1 read (10MB) instead of 2 reads.
864    const NOWRAP_READ: usize = 12 * 1024 * 1024; // exactly divisible by 3
865
866    let mut buf = vec![0u8; NOWRAP_READ];
867    let encode_buf_size = BASE64_ENGINE.encoded_length(NOWRAP_READ);
868    let mut encode_buf = vec![0u8; encode_buf_size];
869
870    loop {
871        let n = read_full(reader, &mut buf)?;
872        if n == 0 {
873            break;
874        }
875        let enc_len = BASE64_ENGINE.encoded_length(n);
876        let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
877        writer.write_all(encoded)?;
878    }
879    Ok(())
880}
881
882/// Streaming encode WITH line wrapping.
883/// For the common case (wrap_col divides evenly into 3-byte input groups),
884/// uses fuse_wrap to build a contiguous output buffer with newlines interleaved,
885/// then writes it in a single write() call. This eliminates the overhead of
886/// many writev() syscalls (one per ~512 lines via IoSlice).
887///
888/// For non-aligned wrap columns, falls back to the IoSlice/writev approach.
889fn encode_stream_wrapped(
890    reader: &mut impl Read,
891    wrap_col: usize,
892    writer: &mut impl Write,
893) -> io::Result<()> {
894    let bytes_per_line = wrap_col * 3 / 4;
895    // For the common case (76-col wrapping, bytes_per_line=57 which is divisible by 3),
896    // align the read buffer to bytes_per_line boundaries so each chunk produces
897    // complete lines with no column carry-over between chunks.
898    if bytes_per_line > 0 && bytes_per_line.is_multiple_of(3) {
899        return encode_stream_wrapped_fused(reader, wrap_col, bytes_per_line, writer);
900    }
901
902    // Fallback: non-aligned wrap columns use IoSlice/writev with column tracking
903    const STREAM_READ: usize = 12 * 1024 * 1024;
904    let mut buf = vec![0u8; STREAM_READ];
905    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_READ);
906    let mut encode_buf = vec![0u8; encode_buf_size];
907
908    let mut col = 0usize;
909
910    loop {
911        let n = read_full(reader, &mut buf)?;
912        if n == 0 {
913            break;
914        }
915        let enc_len = BASE64_ENGINE.encoded_length(n);
916        let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
917
918        write_wrapped_iov_streaming(encoded, wrap_col, &mut col, writer)?;
919    }
920
921    if col > 0 {
922        writer.write_all(b"\n")?;
923    }
924
925    Ok(())
926}
927
928/// Fused encode+wrap streaming: align reads to bytes_per_line boundaries,
929/// encode chunk, fuse_wrap into contiguous buffer with newlines, single write().
930/// For 76-col wrapping (bytes_per_line=57): 12MB / 57 = ~210K complete lines per chunk.
931/// Output = 210K * 77 bytes = ~16MB, one write() syscall per chunk.
932fn encode_stream_wrapped_fused(
933    reader: &mut impl Read,
934    wrap_col: usize,
935    bytes_per_line: usize,
936    writer: &mut impl Write,
937) -> io::Result<()> {
938    // Align read size to bytes_per_line for complete output lines per chunk.
939    // ~210K lines * 57 bytes = ~12MB input, ~16MB output.
940    let lines_per_chunk = (12 * 1024 * 1024) / bytes_per_line;
941    let read_size = lines_per_chunk * bytes_per_line;
942
943    let mut buf = vec![0u8; read_size];
944    let enc_max = BASE64_ENGINE.encoded_length(read_size);
945    let fused_max = enc_max + (enc_max / wrap_col + 2); // encoded + newlines
946    let total_buf_size = fused_max + enc_max;
947    // Single buffer: [0..fused_max) = fuse_wrap output, [fused_max..total_buf_size) = encode region
948    let mut work_buf: Vec<u8> = vec![0u8; total_buf_size];
949
950    let mut trailing_partial = false;
951
952    loop {
953        let n = read_full(reader, &mut buf)?;
954        if n == 0 {
955            break;
956        }
957
958        let enc_len = BASE64_ENGINE.encoded_length(n);
959        // Encode into the second region
960        let encode_start = fused_max;
961        let _ = BASE64_ENGINE.encode(
962            &buf[..n],
963            work_buf[encode_start..encode_start + enc_len].as_out(),
964        );
965
966        // Fuse wrap: copy encoded data with newlines interleaved into first region
967        let (fused_region, encode_region) = work_buf.split_at_mut(fused_max);
968        let encoded = &encode_region[..enc_len];
969        let wp = fuse_wrap(encoded, wrap_col, fused_region);
970
971        writer.write_all(&fused_region[..wp])?;
972        trailing_partial = !enc_len.is_multiple_of(wrap_col);
973    }
974
975    // fuse_wrap appends a newline after every line, including a partial last one,
976    // so no trailing newline is needed here. The flag above only records whether the
977    // final chunk ended mid-line; it is not consulted on this path.
978    let _ = trailing_partial;
979
980    Ok(())
981}
982
983/// Stream-decode from a reader to a writer. Used for stdin processing.
984/// Fused single-pass: read chunk -> strip whitespace -> decode immediately.
985/// Uses 16MB read buffer for maximum pipe throughput — read_full retries to
986/// fill the entire buffer from the pipe, and 16MB means the entire 10MB
987/// benchmark input is read in a single syscall batch, minimizing overhead.
988/// memchr2-based SIMD whitespace stripping handles the common case efficiently.
989pub fn decode_stream(
990    reader: &mut impl Read,
991    ignore_garbage: bool,
992    writer: &mut impl Write,
993) -> io::Result<()> {
994    const READ_CHUNK: usize = 16 * 1024 * 1024;
995    let mut buf = vec![0u8; READ_CHUNK];
996    // Pre-allocate clean buffer once and reuse across iterations.
997    // Use Vec with set_len for zero-overhead reset instead of clear() + extend().
998    let mut clean: Vec<u8> = Vec::with_capacity(READ_CHUNK + 4);
999    let mut carry = [0u8; 4];
1000    let mut carry_len = 0usize;
1001
1002    loop {
1003        let n = read_full(reader, &mut buf)?;
1004        if n == 0 {
1005            break;
1006        }
1007
1008        // Copy carry bytes to start of clean buffer (0-3 bytes from previous chunk)
1009        unsafe {
1010            std::ptr::copy_nonoverlapping(carry.as_ptr(), clean.as_mut_ptr(), carry_len);
1011        }
1012
1013        let chunk = &buf[..n];
1014        if ignore_garbage {
1015            // Scalar filter for ignore_garbage mode (rare path)
1016            let dst = unsafe { clean.as_mut_ptr().add(carry_len) };
1017            let mut wp = 0usize;
1018            for &b in chunk {
1019                if is_base64_char(b) {
1020                    unsafe { *dst.add(wp) = b };
1021                    wp += 1;
1022                }
1023            }
1024            unsafe { clean.set_len(carry_len + wp) };
1025        } else {
1026            // SIMD gap-copy: use memchr2 to find \n and \r positions, then copy
1027            // the gaps between them. For typical base64 (76-char lines), newlines
1028            // are ~1/77 of the data, so we process ~76 bytes per memchr hit
1029            // instead of 1 byte per scalar iteration.
1030            // Track rare whitespace during gap-copy to skip the second scan.
1031            let dst = unsafe { clean.as_mut_ptr().add(carry_len) };
1032            let mut wp = 0usize;
1033            let mut gap_start = 0usize;
1034            let mut has_rare_ws = false;
1035
1036            for pos in memchr::memchr2_iter(b'\n', b'\r', chunk) {
1037                let gap_len = pos - gap_start;
1038                if gap_len > 0 {
1039                    if !has_rare_ws {
1040                        has_rare_ws = chunk[gap_start..pos]
1041                            .iter()
1042                            .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
1043                    }
1044                    unsafe {
1045                        std::ptr::copy_nonoverlapping(
1046                            chunk.as_ptr().add(gap_start),
1047                            dst.add(wp),
1048                            gap_len,
1049                        );
1050                    }
1051                    wp += gap_len;
1052                }
1053                gap_start = pos + 1;
1054            }
1055            let tail_len = n - gap_start;
1056            if tail_len > 0 {
1057                if !has_rare_ws {
1058                    has_rare_ws = chunk[gap_start..n]
1059                        .iter()
1060                        .any(|&b| b == b' ' || b == b'\t' || b == 0x0b || b == 0x0c);
1061                }
1062                unsafe {
1063                    std::ptr::copy_nonoverlapping(
1064                        chunk.as_ptr().add(gap_start),
1065                        dst.add(wp),
1066                        tail_len,
1067                    );
1068                }
1069                wp += tail_len;
1070            }
1071            let total_clean = carry_len + wp;
1072            unsafe { clean.set_len(total_clean) };
1073
1074            // Second pass for rare whitespace (tab, space, VT, FF) — only when detected.
1075            // In typical base64 streams (76-char lines with \n), this is skipped entirely.
1076            if has_rare_ws {
1077                let ptr = clean.as_mut_ptr();
1078                let mut rp = carry_len;
1079                let mut cwp = carry_len;
1080                while rp < total_clean {
1081                    let b = unsafe { *ptr.add(rp) };
1082                    if NOT_WHITESPACE[b as usize] {
1083                        unsafe { *ptr.add(cwp) = b };
1084                        cwp += 1;
1085                    }
1086                    rp += 1;
1087                }
1088                clean.truncate(cwp);
1089            }
1090        }
1091
1092        carry_len = 0;
1093        let is_last = n < READ_CHUNK;
1094
1095        if is_last {
1096            // Last chunk: decode everything (including padding)
1097            decode_clean_slice(&mut clean, writer)?;
1098        } else {
1099            // Save incomplete base64 quadruplet for next iteration
1100            let clean_len = clean.len();
1101            let decode_len = (clean_len / 4) * 4;
1102            let leftover = clean_len - decode_len;
1103            if leftover > 0 {
1104                unsafe {
1105                    std::ptr::copy_nonoverlapping(
1106                        clean.as_ptr().add(decode_len),
1107                        carry.as_mut_ptr(),
1108                        leftover,
1109                    );
1110                }
1111                carry_len = leftover;
1112            }
1113            if decode_len > 0 {
1114                clean.truncate(decode_len);
1115                decode_clean_slice(&mut clean, writer)?;
1116            }
1117        }
1118    }
1119
1120    // Handle any remaining carry-over bytes
1121    if carry_len > 0 {
1122        let mut carry_buf = carry[..carry_len].to_vec();
1123        decode_clean_slice(&mut carry_buf, writer)?;
1124    }
1125
1126    Ok(())
1127}
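
// A round-trip sketch for the streaming decoder, again with `io::Cursor` in place of
// stdin: wrapped, newline-laden input decodes back to the original bytes.
#[cfg(test)]
mod decode_stream_usage {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn stream_roundtrip() {
        let original: Vec<u8> = (0u8..=255).cycle().take(5_000).collect();
        let mut encoded = Vec::new();
        encode_to_writer(&original, 76, &mut encoded).unwrap();

        let mut decoded = Vec::new();
        decode_stream(&mut Cursor::new(encoded.as_slice()), false, &mut decoded).unwrap();
        assert_eq!(decoded, original);
    }
}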
1128
1129/// Write all IoSlice entries using write_vectored (writev syscall).
1130/// Falls back to write_all per slice on partial writes.
1131fn write_all_vectored(out: &mut impl Write, slices: &[io::IoSlice]) -> io::Result<()> {
1132    if slices.is_empty() {
1133        return Ok(());
1134    }
1135    let total: usize = slices.iter().map(|s| s.len()).sum();
1136
1137    // Try write_vectored first — often writes everything in one syscall
1138    let written = match out.write_vectored(slices) {
1139        Ok(n) if n >= total => return Ok(()),
1140        Ok(n) => n,
1141        Err(e) => return Err(e),
1142    };
1143
1144    // Partial write fallback
1145    let mut skip = written;
1146    for slice in slices {
1147        let slen = slice.len();
1148        if skip >= slen {
1149            skip -= slen;
1150            continue;
1151        }
1152        if skip > 0 {
1153            out.write_all(&slice[skip..])?;
1154            skip = 0;
1155        } else {
1156            out.write_all(slice)?;
1157        }
1158    }
1159    Ok(())
1160}
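
// A sketch of the partial-write fallback: a writer that relies on the default
// `write_vectored` (which forwards only the first non-empty slice) forces the
// per-slice `write_all` path, and the result must still be complete and in order.
#[cfg(test)]
mod write_all_vectored_fallback {
    use super::*;
    use std::io::{self, Write};

    /// Hypothetical sink: collects bytes but never overrides `write_vectored`,
    /// so every vectored write is partial.
    struct OneSliceSink(Vec<u8>);

    impl Write for OneSliceSink {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.extend_from_slice(buf);
            Ok(buf.len())
        }
        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn partial_vectored_writes_are_completed() {
        let slices = [
            io::IoSlice::new(b"abc"),
            io::IoSlice::new(b"\n"),
            io::IoSlice::new(b"def"),
        ];
        let mut sink = OneSliceSink(Vec::new());
        write_all_vectored(&mut sink, &slices).unwrap();
        assert_eq!(sink.0, b"abc\ndef".to_vec());
    }
}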
1161
1162/// Read as many bytes as possible into buf, retrying on partial reads.
1163/// Fast path: regular file reads usually return the full buffer on the first call,
1164/// avoiding the loop overhead entirely.
1165#[inline]
1166fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
1167    // Fast path: first read() usually fills the entire buffer for regular files
1168    let n = reader.read(buf)?;
1169    if n == buf.len() || n == 0 {
1170        return Ok(n);
1171    }
1172    // Slow path: partial read — retry to fill buffer (pipes, slow devices)
1173    let mut total = n;
1174    while total < buf.len() {
1175        match reader.read(&mut buf[total..]) {
1176            Ok(0) => break,
1177            Ok(n) => total += n,
1178            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
1179            Err(e) => return Err(e),
1180        }
1181    }
1182    Ok(total)
1183}
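
// A sketch of the retry loop, using a hypothetical reader that trickles out at most
// three bytes per call (as a slow pipe might): `read_full` keeps reading until the
// buffer is full or the source is exhausted.
#[cfg(test)]
mod read_full_retry {
    use super::*;
    use std::io::{self, Read};

    /// Hypothetical reader that returns at most 3 bytes per call.
    struct Trickle<'a>(&'a [u8]);

    impl Read for Trickle<'_> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let src = self.0;
            let n = src.len().min(buf.len()).min(3);
            buf[..n].copy_from_slice(&src[..n]);
            self.0 = &src[n..];
            Ok(n)
        }
    }

    #[test]
    fn fills_buffer_across_partial_reads() {
        let mut reader = Trickle(b"0123456789");
        let mut buf = [0u8; 8];
        assert_eq!(read_full(&mut reader, &mut buf).unwrap(), 8);
        assert_eq!(&buf, b"01234567");

        // The remaining two bytes arrive on the next call; after that the source is empty.
        let mut tail = [0u8; 8];
        assert_eq!(read_full(&mut reader, &mut tail).unwrap(), 2);
    }
}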