coreutils_rs/base64/core.rs

use std::io::{self, Read, Write};

use base64_simd::AsOut;
use rayon::prelude::*;

const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;

/// Streaming encode chunk: 24MB aligned to 3 bytes.
const STREAM_ENCODE_CHUNK: usize = 24 * 1024 * 1024 - (24 * 1024 * 1024 % 3);

/// Chunk size for no-wrap encoding: 32MB aligned to 3 bytes.
/// Larger chunks = fewer write() syscalls for big files.
const NOWRAP_CHUNK: usize = 32 * 1024 * 1024 - (32 * 1024 * 1024 % 3);

/// Minimum data size for parallel encoding (1MB).
/// Lowered from 4MB so 10MB benchmark workloads get multi-core processing.
const PARALLEL_ENCODE_THRESHOLD: usize = 1024 * 1024;

/// Minimum data size for parallel decoding (1MB of base64 data).
/// Lowered from 4MB for better parallelism on typical workloads.
const PARALLEL_DECODE_THRESHOLD: usize = 1024 * 1024;

/// Encode data and write to output with line wrapping.
/// Uses SIMD encoding with fused encode+wrap for maximum throughput.
pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if wrap_col == 0 {
        return encode_no_wrap(data, out);
    }

    encode_wrapped(data, wrap_col, out)
}
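
// A minimal round-trip sketch for the public buffer API (an added example, not part
// of the original source; it assumes base64_simd's STANDARD alphabet matches GNU
// `base64`, which the rest of this module already relies on).
#[cfg(test)]
mod buffer_roundtrip_sketch {
    #[test]
    fn encode_then_decode() {
        let mut encoded = Vec::new();
        super::encode_to_writer(b"hello", 76, &mut encoded).unwrap();
        assert_eq!(encoded, b"aGVsbG8=\n");

        let mut decoded = Vec::new();
        super::decode_to_writer(&encoded, false, &mut decoded).unwrap();
        assert_eq!(decoded, b"hello");
    }
}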

/// Encode without wrapping — parallel SIMD encoding for large data, sequential for small.
fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
        return encode_no_wrap_parallel(data, out);
    }

    let actual_chunk = NOWRAP_CHUNK.min(data.len());
    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
    // SAFETY: encode() writes exactly enc_len bytes before we read them.
    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(enc_max);
    }

    for chunk in data.chunks(NOWRAP_CHUNK) {
        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
        out.write_all(encoded)?;
    }
    Ok(())
}

/// Parallel no-wrap encoding: split at 3-byte boundaries, encode chunks in parallel.
/// Each chunk except possibly the last is 3-byte aligned, so no padding in intermediate chunks.
/// Uses write_vectored (writev) to send all encoded chunks in a single syscall.
fn encode_no_wrap_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    let raw_chunk = data.len() / num_threads;
    // Align to 3 bytes so each chunk encodes without padding (except the last)
    let chunk_size = ((raw_chunk + 2) / 3) * 3;

    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(3)).collect();
    let encoded_chunks: Vec<Vec<u8>> = chunks
        .par_iter()
        .map(|chunk| {
            let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
            let mut buf: Vec<u8> = Vec::with_capacity(enc_len);
            #[allow(clippy::uninit_vec)]
            unsafe {
                buf.set_len(enc_len);
            }
            let _ = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
            buf
        })
        .collect();

    // Use write_vectored to send all chunks in a single syscall
    let iov: Vec<io::IoSlice> = encoded_chunks.iter().map(|c| io::IoSlice::new(c)).collect();
    write_all_vectored(out, &iov)
}
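
// A small sketch of the 3-byte split property (an added example, not part of the
// original source; it assumes rayon's default global pool is available). Because
// every chunk except the last is a multiple of 3 bytes, concatenating the per-chunk
// encodings equals encoding the whole input at once.
#[cfg(test)]
mod parallel_encode_sketch {
    #[test]
    fn chunked_encoding_matches_whole_input() {
        let mut out = Vec::new();
        super::encode_no_wrap_parallel(b"hello", &mut out).unwrap();
        assert_eq!(out, b"aGVsbG8=");
    }
}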

/// Encode with line wrapping — fused encode+wrap in a single output buffer.
/// Encodes aligned input chunks, then interleaves newlines directly into
/// a single output buffer, eliminating the separate wrap pass.
fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    // Calculate bytes_per_line: input bytes that produce exactly wrap_col encoded chars.
    // For default wrap_col=76: 76*3/4 = 57 bytes per line.
    let bytes_per_line = wrap_col * 3 / 4;
    if bytes_per_line == 0 {
        // Degenerate case: only wrap_col == 1 rounds down to zero bytes per line;
        // fall back to character-at-a-time wrapping.
        return encode_wrapped_small(data, wrap_col, out);
    }

    // Parallel encoding for large data when bytes_per_line is a multiple of 3.
    // This guarantees each chunk encodes to complete base64 without padding.
    if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line.is_multiple_of(3) {
        return encode_wrapped_parallel(data, wrap_col, bytes_per_line, out);
    }

    // Align input chunk to bytes_per_line for complete output lines.
    // Use 32MB chunks — large enough to process most files in a single pass,
    // reducing write() syscalls.
    let lines_per_chunk = (32 * 1024 * 1024) / bytes_per_line;
    let max_input_chunk = (lines_per_chunk * bytes_per_line).max(bytes_per_line);
    let input_chunk = max_input_chunk.min(data.len());

    let enc_max = BASE64_ENGINE.encoded_length(input_chunk);
    let mut encode_buf: Vec<u8> = Vec::with_capacity(enc_max);
    #[allow(clippy::uninit_vec)]
    unsafe {
        encode_buf.set_len(enc_max);
    }

    // Fused output buffer: holds encoded data with newlines interleaved
    let max_lines = enc_max / wrap_col + 2;
    let fused_max = enc_max + max_lines;
    let mut fused_buf: Vec<u8> = Vec::with_capacity(fused_max);
    #[allow(clippy::uninit_vec)]
    unsafe {
        fused_buf.set_len(fused_max);
    }

    for chunk in data.chunks(max_input_chunk.max(1)) {
        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());

        // Fuse: copy encoded data into fused_buf with newlines interleaved
        let wp = fuse_wrap(encoded, wrap_col, &mut fused_buf);
        out.write_all(&fused_buf[..wp])?;
    }

    Ok(())
}

/// Parallel wrapped encoding: split at bytes_per_line boundaries, encode + wrap in parallel.
/// Requires bytes_per_line.is_multiple_of(3) so each chunk encodes without intermediate padding.
/// Uses write_vectored (writev) to send all encoded+wrapped chunks in a single syscall.
fn encode_wrapped_parallel(
    data: &[u8],
    wrap_col: usize,
    bytes_per_line: usize,
    out: &mut impl Write,
) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    // Split at bytes_per_line boundaries for complete output lines per chunk
    let lines_per_chunk = (data.len() / bytes_per_line / num_threads).max(1);
    let chunk_size = lines_per_chunk * bytes_per_line;

    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(bytes_per_line)).collect();
    let encoded_chunks: Vec<Vec<u8>> = chunks
        .par_iter()
        .map(|chunk| {
            let enc_max = BASE64_ENGINE.encoded_length(chunk.len());
            let max_lines = enc_max / wrap_col + 2;
            // Single allocation with two non-overlapping regions:
            //   [0..fused_size) = fuse_wrap output region
            //   [fused_size..fused_size+enc_max) = encode region
            let fused_size = enc_max + max_lines;
            let total_size = fused_size + enc_max;
            let mut buf: Vec<u8> = Vec::with_capacity(total_size);
            #[allow(clippy::uninit_vec)]
            unsafe {
                buf.set_len(total_size);
            }
            // Encode into the second region [fused_size..fused_size+enc_max]
            let _ = BASE64_ENGINE.encode(chunk, buf[fused_size..fused_size + enc_max].as_out());
            // Use split_at_mut to get non-overlapping mutable/immutable refs
            let (fused_region, encode_region) = buf.split_at_mut(fused_size);
            let encoded = &encode_region[..enc_max];
            let wp = fuse_wrap(encoded, wrap_col, fused_region);
            buf.truncate(wp);
            buf
        })
        .collect();

    // Use write_vectored to send all chunks in a single syscall
    let iov: Vec<io::IoSlice> = encoded_chunks.iter().map(|c| io::IoSlice::new(c)).collect();
    write_all_vectored(out, &iov)
}

/// Fuse encoded base64 data with newlines in a single pass.
/// Uses ptr::copy_nonoverlapping with 8-line unrolling for max throughput.
/// Returns number of bytes written.
#[inline]
fn fuse_wrap(encoded: &[u8], wrap_col: usize, out_buf: &mut [u8]) -> usize {
    let line_out = wrap_col + 1; // wrap_col data bytes + 1 newline
    let mut rp = 0;
    let mut wp = 0;

    // Unrolled: process 8 lines per iteration for better ILP
    while rp + 8 * wrap_col <= encoded.len() {
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = out_buf.as_mut_ptr().add(wp);

            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(4 * wrap_col), dst.add(4 * line_out), wrap_col);
            *dst.add(4 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(5 * wrap_col), dst.add(5 * line_out), wrap_col);
            *dst.add(5 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(6 * wrap_col), dst.add(6 * line_out), wrap_col);
            *dst.add(6 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(7 * wrap_col), dst.add(7 * line_out), wrap_col);
            *dst.add(7 * line_out + wrap_col) = b'\n';
        }
        rp += 8 * wrap_col;
        wp += 8 * line_out;
    }

    // Handle remaining 4 lines at a time
    while rp + 4 * wrap_col <= encoded.len() {
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = out_buf.as_mut_ptr().add(wp);

            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';
        }
        rp += 4 * wrap_col;
        wp += 4 * line_out;
    }

    // Remaining full lines
    while rp + wrap_col <= encoded.len() {
        unsafe {
            std::ptr::copy_nonoverlapping(
                encoded.as_ptr().add(rp),
                out_buf.as_mut_ptr().add(wp),
                wrap_col,
            );
            *out_buf.as_mut_ptr().add(wp + wrap_col) = b'\n';
        }
        rp += wrap_col;
        wp += line_out;
    }

    // Partial last line
    if rp < encoded.len() {
        let remaining = encoded.len() - rp;
        unsafe {
            std::ptr::copy_nonoverlapping(
                encoded.as_ptr().add(rp),
                out_buf.as_mut_ptr().add(wp),
                remaining,
            );
        }
        wp += remaining;
        out_buf[wp] = b'\n';
        wp += 1;
    }

    wp
}
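
// A tiny behavioural sketch of the fused wrap pass (an added example, not part of
// the original source; the caller-sized output buffer mirrors encode_wrapped). Full
// lines and the partial last line are all newline-terminated.
#[cfg(test)]
mod fuse_wrap_sketch {
    #[test]
    fn wraps_and_terminates_partial_line() {
        let mut buf = vec![0u8; 32];
        let n = super::fuse_wrap(b"ABCDEFG", 4, &mut buf);
        assert_eq!(&buf[..n], b"ABCD\nEFG\n");
    }
}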

/// Fallback for very small wrap columns (in practice only wrap_col == 1, the case
/// where bytes_per_line rounds to zero).
fn encode_wrapped_small(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    let enc_max = BASE64_ENGINE.encoded_length(data.len());
    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(enc_max);
    }
    let encoded = BASE64_ENGINE.encode(data, buf[..enc_max].as_out());

    let wc = wrap_col.max(1);
    for line in encoded.chunks(wc) {
        out.write_all(line)?;
        out.write_all(b"\n")?;
    }
    Ok(())
}
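
// A sketch of the wrap_col == 1 fallback (an added example, not part of the original
// source): every encoded character, including the '=' padding, lands on its own line.
#[cfg(test)]
mod small_wrap_sketch {
    #[test]
    fn one_column_wrap() {
        let mut out = Vec::new();
        super::encode_to_writer(b"hi", 1, &mut out).unwrap();
        assert_eq!(out, b"a\nG\nk\n=\n");
    }
}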

/// Decode base64 data and write to output (borrows data, allocates clean buffer).
/// When `ignore_garbage` is true, strip all non-base64 characters.
/// When false, only strip whitespace (standard behavior).
pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if ignore_garbage {
        let mut cleaned = strip_non_base64(data);
        return decode_clean_slice(&mut cleaned, out);
    }

    // Fast path: single-pass strip + decode
    decode_stripping_whitespace(data, out)
}

/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
pub fn decode_owned(
    data: &mut Vec<u8>,
    ignore_garbage: bool,
    out: &mut impl Write,
) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if ignore_garbage {
        data.retain(|&b| is_base64_char(b));
    } else {
        strip_whitespace_inplace(data);
    }

    decode_clean_slice(data, out)
}

/// Strip all whitespace from a Vec in-place using the lookup table.
/// Single-pass compaction: uses NOT_WHITESPACE table to classify all whitespace
/// types simultaneously, avoiding the previous multi-scan approach.
fn strip_whitespace_inplace(data: &mut Vec<u8>) {
    // Quick check: any whitespace at all?
    let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
    if !has_ws {
        return;
    }

    // Single-pass in-place compaction using the lookup table.
    let ptr = data.as_ptr();
    let mut_ptr = data.as_mut_ptr();
    let len = data.len();
    let mut wp = 0usize;

    for i in 0..len {
        let b = unsafe { *ptr.add(i) };
        if NOT_WHITESPACE[b as usize] {
            unsafe { *mut_ptr.add(wp) = b };
            wp += 1;
        }
    }

    data.truncate(wp);
}
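
// A behavioural sketch for the in-place strip (an added example, not part of the
// original source). All six byte values marked as whitespace in NOT_WHITESPACE are
// removed; everything else keeps its relative order.
#[cfg(test)]
mod strip_whitespace_sketch {
    #[test]
    fn removes_all_whitespace_kinds() {
        let mut v = b"aG\n V\tsb\r G\x0b8\x0c=".to_vec();
        super::strip_whitespace_inplace(&mut v);
        assert_eq!(v, b"aGVsbG8=");
    }
}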

/// 256-byte lookup table: true for non-whitespace bytes.
/// Used for single-pass whitespace stripping in decode.
static NOT_WHITESPACE: [bool; 256] = {
    let mut table = [true; 256];
    table[b' ' as usize] = false;
    table[b'\t' as usize] = false;
    table[b'\n' as usize] = false;
    table[b'\r' as usize] = false;
    table[0x0b] = false; // vertical tab
    table[0x0c] = false; // form feed
    table
};

/// Decode by stripping whitespace and decoding in a single fused pass.
/// For data with no whitespace, decodes directly without any copy.
/// Uses memchr2 SIMD gap-copy for \n/\r (the dominant whitespace in base64),
/// then a fallback pass for rare whitespace types (tab, space, VT, FF).
fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    // Quick check: any whitespace at all? Use the lookup table for a single scan.
    let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
    if !has_ws {
        // No whitespace — decode directly from borrowed data
        return decode_borrowed_clean(out, data);
    }

    // SIMD gap-copy: use memchr2 to find \n and \r positions, then copy the
    // gaps between them. For typical base64 (76-char lines), newlines are ~1/77
    // of the data, so we process ~76 bytes per memchr hit instead of 1 per scalar.
    let mut clean: Vec<u8> = Vec::with_capacity(data.len());
    let dst = clean.as_mut_ptr();
    let mut wp = 0usize;
    let mut gap_start = 0usize;

    for pos in memchr::memchr2_iter(b'\n', b'\r', data) {
        let gap_len = pos - gap_start;
        if gap_len > 0 {
            unsafe {
                std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), gap_len);
            }
            wp += gap_len;
        }
        gap_start = pos + 1;
    }
    // Copy the final gap after the last \n/\r
    let tail_len = data.len() - gap_start;
    if tail_len > 0 {
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), tail_len);
        }
        wp += tail_len;
    }
    unsafe {
        clean.set_len(wp);
    }

    // Second pass for rare whitespace (tab, space, VT, FF) using lookup table.
    // In typical base64 streams this does nothing, but correctness requires it.
    let has_rare_ws = clean.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
    if has_rare_ws {
        let ptr = clean.as_mut_ptr();
        let len = clean.len();
        let mut rp = 0;
        let mut cwp = 0;
        while rp < len {
            let b = unsafe { *ptr.add(rp) };
            if NOT_WHITESPACE[b as usize] {
                unsafe { *ptr.add(cwp) = b };
                cwp += 1;
            }
            rp += 1;
        }
        clean.truncate(cwp);
    }

    decode_clean_slice(&mut clean, out)
}

/// Decode a clean (no whitespace) buffer in-place with SIMD.
fn decode_clean_slice(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    match BASE64_ENGINE.decode_inplace(data) {
        Ok(decoded) => out.write_all(decoded),
        Err(_) => decode_error(),
    }
}

/// Cold error path — keeps hot decode path tight by moving error construction out of line.
#[cold]
#[inline(never)]
fn decode_error() -> io::Result<()> {
    Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input"))
}

/// Decode clean base64 data (no whitespace) from a borrowed slice.
fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    // Parallel decode for large data: split at 4-byte boundaries,
    // decode each chunk independently (base64 is context-free per 4-char group).
    if data.len() >= PARALLEL_DECODE_THRESHOLD {
        return decode_borrowed_clean_parallel(out, data);
    }
    match BASE64_ENGINE.decode_to_vec(data) {
        Ok(decoded) => {
            out.write_all(&decoded)?;
            Ok(())
        }
        Err(_) => decode_error(),
    }
}

/// Parallel decode: split at 4-byte boundaries, decode chunks in parallel via rayon.
/// Pre-allocates a single contiguous output buffer with exact decoded offsets computed
/// upfront, so each thread decodes directly to its final position. No compaction needed.
fn decode_borrowed_clean_parallel(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    let raw_chunk = data.len() / num_threads;
    // Align to 4 bytes (each 4 base64 chars = 3 decoded bytes, context-free)
    let chunk_size = ((raw_chunk + 3) / 4) * 4;

    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(4)).collect();

    // Compute exact decoded sizes per chunk upfront to eliminate the compaction pass.
    // For all chunks except the last, decoded size is exactly chunk.len() * 3 / 4.
    // For the last chunk, account for '=' padding bytes.
    let mut offsets: Vec<usize> = Vec::with_capacity(chunks.len() + 1);
    offsets.push(0);
    let mut total_decoded = 0usize;
    for (i, chunk) in chunks.iter().enumerate() {
        let decoded_size = if i == chunks.len() - 1 {
            // Last chunk: count '=' padding to get exact decoded size
            let pad = chunk.iter().rev().take(2).filter(|&&b| b == b'=').count();
            chunk.len() * 3 / 4 - pad
        } else {
            // Non-last chunks: 4-byte aligned, no padding, exact 3/4 ratio
            chunk.len() * 3 / 4
        };
        total_decoded += decoded_size;
        offsets.push(total_decoded);
    }

    // Pre-allocate contiguous output buffer with exact total size
    let mut output_buf: Vec<u8> = Vec::with_capacity(total_decoded);
    #[allow(clippy::uninit_vec)]
    unsafe {
        output_buf.set_len(total_decoded);
    }

    // Parallel decode: each thread decodes directly into its exact final position.
    // No compaction pass needed since offsets are computed from exact decoded sizes.
    // SAFETY: each thread writes to a non-overlapping region of the output buffer.
    // Use usize representation of the pointer for Send+Sync compatibility with rayon.
    let out_addr = output_buf.as_mut_ptr() as usize;
    let decode_result: Result<Vec<()>, io::Error> = chunks
        .par_iter()
        .enumerate()
        .map(|(i, chunk)| {
            let offset = offsets[i];
            let expected_size = offsets[i + 1] - offset;
            // SAFETY: each thread writes to non-overlapping region [offset..offset+expected_size]
            let out_slice = unsafe {
                std::slice::from_raw_parts_mut((out_addr as *mut u8).add(offset), expected_size)
            };
            let decoded = BASE64_ENGINE
                .decode(chunk, out_slice.as_out())
                .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid input"))?;
            debug_assert_eq!(decoded.len(), expected_size);
            Ok(())
        })
        .collect();

    decode_result?;

    out.write_all(&output_buf[..total_decoded])
}
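
// A decode sketch for the chunk-offset bookkeeping above (an added example, not part
// of the original source; it assumes rayon's default global pool). Middle 4-byte
// chunks decode to exactly 3 bytes each and only the final chunk carries '=' padding,
// so the per-chunk outputs land contiguously at their precomputed offsets.
#[cfg(test)]
mod parallel_decode_sketch {
    #[test]
    fn small_input_through_parallel_path() {
        let mut out = Vec::new();
        super::decode_borrowed_clean_parallel(&mut out, b"aGVsbG8gd29ybGQ=").unwrap();
        assert_eq!(out, b"hello world");
    }
}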

/// Strip non-base64 characters (for -i / --ignore-garbage).
fn strip_non_base64(data: &[u8]) -> Vec<u8> {
    data.iter()
        .copied()
        .filter(|&b| is_base64_char(b))
        .collect()
}

/// Check if a byte is a valid base64 alphabet character or padding.
#[inline]
fn is_base64_char(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'+' || b == b'/' || b == b'='
}

/// Stream-encode from a reader to a writer. Used for stdin processing.
pub fn encode_stream(
    reader: &mut impl Read,
    wrap_col: usize,
    writer: &mut impl Write,
) -> io::Result<()> {
    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];

    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
    let mut encode_buf = vec![0u8; encode_buf_size];

    if wrap_col == 0 {
        // No wrapping: encode each chunk and write directly.
        loop {
            let n = read_full(reader, &mut buf)?;
            if n == 0 {
                break;
            }
            let enc_len = BASE64_ENGINE.encoded_length(n);
            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
            writer.write_all(encoded)?;
        }
    } else {
        // Wrapping: fused encode+wrap into a single output buffer.
        let max_fused = encode_buf_size + (encode_buf_size / wrap_col + 2);
        let mut fused_buf = vec![0u8; max_fused];
        let mut col = 0usize;

        loop {
            let n = read_full(reader, &mut buf)?;
            if n == 0 {
                break;
            }
            let enc_len = BASE64_ENGINE.encoded_length(n);
            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());

            // Build fused output in a single buffer, then one write.
            let wp = build_fused_output(encoded, wrap_col, &mut col, &mut fused_buf);
            writer.write_all(&fused_buf[..wp])?;
        }

        if col > 0 {
            writer.write_all(b"\n")?;
        }
    }

    Ok(())
}

/// Build fused encode+wrap output into a pre-allocated buffer.
/// Returns the number of bytes written.
#[inline]
fn build_fused_output(data: &[u8], wrap_col: usize, col: &mut usize, out_buf: &mut [u8]) -> usize {
    let mut rp = 0;
    let mut wp = 0;

    while rp < data.len() {
        let space = wrap_col - *col;
        let avail = data.len() - rp;

        if avail <= space {
            out_buf[wp..wp + avail].copy_from_slice(&data[rp..rp + avail]);
            wp += avail;
            *col += avail;
            if *col == wrap_col {
                out_buf[wp] = b'\n';
                wp += 1;
                *col = 0;
            }
            break;
        } else {
            out_buf[wp..wp + space].copy_from_slice(&data[rp..rp + space]);
            wp += space;
            out_buf[wp] = b'\n';
            wp += 1;
            rp += space;
            *col = 0;
        }
    }

    wp
}
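
// A column-carry sketch (an added example, not part of the original source). The
// `col` state lets one logical line span two streamed chunks: the first call leaves
// a partial line open and the second call finishes it before starting the next.
#[cfg(test)]
mod build_fused_output_sketch {
    #[test]
    fn carries_column_across_calls() {
        let mut out = vec![0u8; 32];
        let mut col = 0usize;
        let w1 = super::build_fused_output(b"ABC", 4, &mut col, &mut out);
        assert_eq!(&out[..w1], b"ABC");
        assert_eq!(col, 3);
        let w2 = super::build_fused_output(b"DEF", 4, &mut col, &mut out);
        assert_eq!(&out[..w2], b"D\nEF");
        assert_eq!(col, 2);
    }
}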

/// Stream-decode from a reader to a writer. Used for stdin processing.
/// Fused single-pass: read chunk -> strip whitespace -> decode immediately.
/// Uses 16MB read buffer to reduce syscalls and memchr2-based SIMD whitespace
/// stripping for the common case (only \n and \r whitespace in base64 streams).
pub fn decode_stream(
    reader: &mut impl Read,
    ignore_garbage: bool,
    writer: &mut impl Write,
) -> io::Result<()> {
    const READ_CHUNK: usize = 16 * 1024 * 1024;
    let mut buf = vec![0u8; READ_CHUNK];
    let mut clean = Vec::with_capacity(READ_CHUNK);
    let mut carry: Vec<u8> = Vec::with_capacity(4);

    loop {
        let n = read_full(reader, &mut buf)?;
        if n == 0 {
            break;
        }

        // Fused: build clean buffer from carry-over + stripped chunk in one pass
        clean.clear();
        clean.extend_from_slice(&carry);
        carry.clear();

        let chunk = &buf[..n];
        if ignore_garbage {
            clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
        } else {
            // SIMD gap-copy: use memchr2 to find \n and \r positions, then copy
            // the gaps between them. For typical base64 (76-char lines), newlines
            // are ~1/77 of the data, so we process ~76 bytes per memchr hit
            // instead of 1 byte per scalar iteration.
            clean.reserve(n);
            let base_len = clean.len();
            let dst = unsafe { clean.as_mut_ptr().add(base_len) };
            let mut wp = 0usize;
            let mut gap_start = 0usize;

            for pos in memchr::memchr2_iter(b'\n', b'\r', chunk) {
                // Copy the gap before this \n or \r
                let gap_len = pos - gap_start;
                if gap_len > 0 {
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            chunk.as_ptr().add(gap_start),
                            dst.add(wp),
                            gap_len,
                        );
                    }
                    wp += gap_len;
                }
                gap_start = pos + 1;
            }
            // Copy the final gap after the last \n/\r
            let tail_len = n - gap_start;
            if tail_len > 0 {
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        chunk.as_ptr().add(gap_start),
                        dst.add(wp),
                        tail_len,
                    );
                }
                wp += tail_len;
            }
            unsafe { clean.set_len(base_len + wp) };

            // Second pass for rare whitespace (tab, space, VT, FF) using lookup table.
            // In typical base64 streams this does nothing, but we need correctness.
            let has_rare_ws = clean[base_len..]
                .iter()
                .any(|&b| !NOT_WHITESPACE[b as usize]);
            if has_rare_ws {
                // Compact in-place within the clean region we just wrote
                let start = base_len;
                let end = clean.len();
                let ptr = clean.as_mut_ptr();
                let mut rp = start;
                let mut cwp = start;
                while rp < end {
                    let b = unsafe { *ptr.add(rp) };
                    if NOT_WHITESPACE[b as usize] {
                        unsafe { *ptr.add(cwp) = b };
                        cwp += 1;
                    }
                    rp += 1;
                }
                clean.truncate(cwp);
            }
        }

        let is_last = n < READ_CHUNK;

        if is_last {
            // Last chunk: decode everything (including padding)
            decode_clean_slice(&mut clean, writer)?;
        } else {
            // Save incomplete base64 quadruplet for next iteration
            let decode_len = (clean.len() / 4) * 4;
            if decode_len < clean.len() {
                carry.extend_from_slice(&clean[decode_len..]);
            }
            if decode_len > 0 {
                clean.truncate(decode_len);
                decode_clean_slice(&mut clean, writer)?;
            }
        }
    }

    // Handle any remaining carry-over bytes
    if !carry.is_empty() {
        decode_clean_slice(&mut carry, writer)?;
    }

    Ok(())
}

/// Write all IoSlice entries using write_vectored (writev syscall).
/// Falls back to write_all per slice on partial writes.
fn write_all_vectored(out: &mut impl Write, slices: &[io::IoSlice]) -> io::Result<()> {
    if slices.is_empty() {
        return Ok(());
    }
    let total: usize = slices.iter().map(|s| s.len()).sum();

    // Try write_vectored first — often writes everything in one syscall
    let written = match out.write_vectored(slices) {
        Ok(n) if n >= total => return Ok(()),
        Ok(n) => n,
        Err(e) => return Err(e),
    };

    // Partial write fallback
    let mut skip = written;
    for slice in slices {
        let slen = slice.len();
        if skip >= slen {
            skip -= slen;
            continue;
        }
        if skip > 0 {
            out.write_all(&slice[skip..])?;
            skip = 0;
        } else {
            out.write_all(slice)?;
        }
    }
    Ok(())
}

/// Read as many bytes as possible into buf, retrying on partial reads.
/// Fast path: regular file reads usually return the full buffer on the first call,
/// avoiding the loop overhead entirely.
#[inline]
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    // Fast path: first read() usually fills the entire buffer for regular files
    let n = reader.read(buf)?;
    if n == buf.len() || n == 0 {
        return Ok(n);
    }
    // Slow path: partial read — retry to fill buffer (pipes, slow devices)
    let mut total = n;
    while total < buf.len() {
        match reader.read(&mut buf[total..]) {
            Ok(0) => break,
            Ok(n) => total += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(total)
}
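
// An end-to-end sketch for the streaming entry points (an added example, not part of
// the original source). std::io::Cursor stands in for stdin/stdout; note that the
// test allocates the module's large streaming buffers, so it is heavier than the
// other sketches.
#[cfg(test)]
mod stream_roundtrip_sketch {
    use std::io::Cursor;

    #[test]
    fn encode_then_decode_stream() {
        let input: Vec<u8> = (0u8..=255).cycle().take(1000).collect();

        let mut encoded = Vec::new();
        super::encode_stream(&mut Cursor::new(input.as_slice()), 76, &mut encoded).unwrap();
        // Every wrapped line is at most 76 characters wide.
        assert!(encoded.split(|&b| b == b'\n').all(|line| line.len() <= 76));

        let mut decoded = Vec::new();
        super::decode_stream(&mut Cursor::new(encoded.as_slice()), false, &mut decoded).unwrap();
        assert_eq!(decoded, input);
    }
}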