Skip to main content

coreutils_rs/base64/
core.rs

1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4use rayon::prelude::*;
5
/// Codec shared by every encode/decode path in this module:
/// `base64_simd::STANDARD` (standard `+`/`/` alphabet with `=` padding,
/// matching the characters `is_base64_char` accepts).
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
7
/// Input bytes consumed per iteration by the streaming encoder: 12 MiB,
/// rounded down to a multiple of 3 so only the final piece can emit `=`.
const STREAM_ENCODE_CHUNK: usize = (12 * 1024 * 1024 / 3) * 3;

/// Input bytes per chunk for the sequential no-wrap encoder: 2 MiB rounded
/// down to a multiple of 3. Kept small for cache friendliness and a cheaper
/// initial allocation (fewer page faults).
const NOWRAP_CHUNK: usize = (2 * 1024 * 1024 / 3) * 3;

/// Inputs of at least this many bytes (32 MiB) are encoded/decoded in
/// parallel. Below it, starting rayon's thread pool costs more than the
/// already-fast SIMD codec would save.
const PARALLEL_ENCODE_THRESHOLD: usize = 32 << 20;
20
21/// Encode data and write to output with line wrapping.
22/// Uses SIMD encoding with reusable buffers for maximum throughput.
23pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
24    if data.is_empty() {
25        return Ok(());
26    }
27
28    if wrap_col == 0 {
29        return encode_no_wrap(data, out);
30    }
31
32    encode_wrapped(data, wrap_col, out)
33}
34
35/// Encode without wrapping using parallel SIMD encoding for large inputs.
36fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
37    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
38        // Split into per-thread chunks aligned to 3-byte boundaries
39        let num_threads = rayon::current_num_threads().max(1);
40        let raw_chunk = (data.len() + num_threads - 1) / num_threads;
41        // Align to 3 bytes for clean base64 boundaries (no padding mid-stream)
42        let chunk_size = ((raw_chunk + 2) / 3) * 3;
43
44        let encoded_chunks: Vec<Vec<u8>> = data
45            .par_chunks(chunk_size)
46            .map(|chunk| {
47                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
48                let mut buf = vec![0u8; enc_len];
49                let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
50                let len = encoded.len();
51                buf.truncate(len);
52                buf
53            })
54            .collect();
55
56        for chunk in &encoded_chunks {
57            out.write_all(chunk)?;
58        }
59        return Ok(());
60    }
61
62    let enc_max = BASE64_ENGINE.encoded_length(NOWRAP_CHUNK);
63    let mut buf = vec![0u8; enc_max];
64
65    for chunk in data.chunks(NOWRAP_CHUNK) {
66        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
67        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
68        out.write_all(encoded)?;
69    }
70    Ok(())
71}
72
73/// Encode with line wrapping. For large inputs, uses parallel encoding.
74fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
75    let bytes_per_line = wrap_col * 3 / 4;
76
77    if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line > 0 {
78        // Parallel: split input into chunks aligned to bytes_per_line (= 3-byte aligned)
79        // so each chunk produces complete lines (no cross-chunk line splitting).
80        let num_threads = rayon::current_num_threads().max(1);
81        let lines_per_thread = ((data.len() / bytes_per_line) + num_threads - 1) / num_threads;
82        let chunk_input = (lines_per_thread * bytes_per_line).max(bytes_per_line);
83
84        let wrapped_chunks: Vec<Vec<u8>> = data
85            .par_chunks(chunk_input)
86            .map(|chunk| {
87                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
88                let mut encode_buf = vec![0u8; enc_len];
89                let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
90
91                // Wrap the encoded output
92                let line_out = wrap_col + 1;
93                let max_lines = (encoded.len() + wrap_col - 1) / wrap_col + 1;
94                let mut wrap_buf = vec![0u8; max_lines * line_out];
95                let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
96                wrap_buf.truncate(wp);
97                wrap_buf
98            })
99            .collect();
100
101        for chunk in &wrapped_chunks {
102            out.write_all(chunk)?;
103        }
104        return Ok(());
105    }
106
107    // Sequential path: 2MB chunks fit in L2 cache and reduce initial allocation.
108    let lines_per_chunk = (2 * 1024 * 1024) / bytes_per_line.max(1);
109    let chunk_input = lines_per_chunk * bytes_per_line.max(1);
110    let chunk_encoded_max = BASE64_ENGINE.encoded_length(chunk_input.max(1));
111    let mut encode_buf = vec![0u8; chunk_encoded_max];
112    let wrapped_max = (lines_per_chunk + 1) * (wrap_col + 1);
113    let mut wrap_buf = vec![0u8; wrapped_max];
114
115    for chunk in data.chunks(chunk_input.max(1)) {
116        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
117        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
118        let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
119        out.write_all(&wrap_buf[..wp])?;
120    }
121
122    Ok(())
123}
124
/// Wrap encoded base64 data with newlines at `wrap_col` columns.
///
/// Every full line of `wrap_col` bytes is followed by `\n`, and a trailing
/// partial line also gets a `\n`. Returns the number of bytes written to
/// `wrap_buf`, which is exactly
/// `encoded.len() + ceil(encoded.len() / wrap_col)`.
///
/// # Panics
/// Panics if `wrap_col == 0` or if `wrap_buf` is shorter than the output size
/// above. Both are checked up front, before any byte is written.
#[inline]
fn wrap_encoded(encoded: &[u8], wrap_col: usize, wrap_buf: &mut [u8]) -> usize {
    assert!(wrap_col > 0, "wrap_encoded: wrap_col must be non-zero");
    let full_lines = encoded.len() / wrap_col;
    let (body, tail) = encoded.split_at(full_lines * wrap_col);
    let needed = encoded.len() + full_lines + usize::from(!tail.is_empty());
    assert!(
        wrap_buf.len() >= needed,
        "wrap_encoded: output buffer too small ({} < {})",
        wrap_buf.len(),
        needed
    );

    let mut wp = 0;
    if full_lines > 0 {
        // full_lines > 0 implies wrap_col <= encoded.len(), so `+ 1` cannot overflow.
        let line_out = wrap_col + 1;
        wp = full_lines * line_out;
        // Fixed-length chunks on both sides keep every access provably in
        // bounds without any unsafe pointer arithmetic.
        for (src, dst) in body
            .chunks_exact(wrap_col)
            .zip(wrap_buf[..wp].chunks_exact_mut(line_out))
        {
            dst[..wrap_col].copy_from_slice(src);
            dst[wrap_col] = b'\n';
        }
    }

    // Partial last line.
    if !tail.is_empty() {
        wrap_buf[wp..wp + tail.len()].copy_from_slice(tail);
        wp += tail.len();
        wrap_buf[wp] = b'\n';
        wp += 1;
    }

    wp
}
175
176/// Decode base64 data and write to output (borrows data, allocates clean buffer).
177/// When `ignore_garbage` is true, strip all non-base64 characters.
178/// When false, only strip whitespace (standard behavior).
179pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
180    if data.is_empty() {
181        return Ok(());
182    }
183
184    if ignore_garbage {
185        let mut cleaned = strip_non_base64(data);
186        return decode_owned_clean(&mut cleaned, out);
187    }
188
189    // Fast path: strip newlines with memchr (SIMD), then SIMD decode
190    decode_stripping_whitespace(data, out)
191}
192
193/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
194/// Avoids a full buffer copy by stripping whitespace in the existing allocation,
195/// then decoding in-place. Ideal when the caller already has an owned Vec.
196pub fn decode_owned(
197    data: &mut Vec<u8>,
198    ignore_garbage: bool,
199    out: &mut impl Write,
200) -> io::Result<()> {
201    if data.is_empty() {
202        return Ok(());
203    }
204
205    if ignore_garbage {
206        data.retain(|&b| is_base64_char(b));
207    } else {
208        strip_whitespace_inplace(data);
209    }
210
211    decode_owned_clean(data, out)
212}
213
214/// Strip all whitespace from a Vec in-place using SIMD memchr for newlines
215/// and a fallback scan for rare non-newline whitespace.
216fn strip_whitespace_inplace(data: &mut Vec<u8>) {
217    // First, collect newline positions using SIMD memchr.
218    let positions: Vec<usize> = memchr::memchr_iter(b'\n', data.as_slice()).collect();
219
220    if positions.is_empty() {
221        // No newlines; check for other whitespace only.
222        if data.iter().any(|&b| is_whitespace(b)) {
223            data.retain(|&b| !is_whitespace(b));
224        }
225        return;
226    }
227
228    // Compact data in-place, removing newlines using copy_within.
229    let mut wp = 0;
230    let mut rp = 0;
231
232    for &pos in &positions {
233        if pos > rp {
234            let len = pos - rp;
235            data.copy_within(rp..pos, wp);
236            wp += len;
237        }
238        rp = pos + 1;
239    }
240
241    let data_len = data.len();
242    if rp < data_len {
243        let len = data_len - rp;
244        data.copy_within(rp..data_len, wp);
245        wp += len;
246    }
247
248    data.truncate(wp);
249
250    // Handle rare non-newline whitespace (CR, tab, etc.)
251    if data.iter().any(|&b| is_whitespace(b)) {
252        data.retain(|&b| !is_whitespace(b));
253    }
254}
255
256/// Decode by stripping all whitespace from the entire input at once,
257/// then performing a single SIMD decode pass. Used when data is borrowed.
258/// For large inputs, decodes in parallel chunks for maximum throughput.
259fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
260    // Quick check: any whitespace at all?
261    // Use SIMD memchr2 to check for both \n and \r simultaneously
262    if memchr::memchr2(b'\n', b'\r', data).is_none()
263        && !data.iter().any(|&b| b == b' ' || b == b'\t')
264    {
265        // No whitespace — decode directly from borrowed data
266        if data.len() >= PARALLEL_ENCODE_THRESHOLD {
267            return decode_parallel(data, out);
268        }
269        return decode_borrowed_clean(out, data);
270    }
271
272    // Strip newlines from entire input in a single pass using SIMD memchr.
273    let mut clean = Vec::with_capacity(data.len());
274    let mut last = 0;
275    for pos in memchr::memchr_iter(b'\n', data) {
276        if pos > last {
277            clean.extend_from_slice(&data[last..pos]);
278        }
279        last = pos + 1;
280    }
281    if last < data.len() {
282        clean.extend_from_slice(&data[last..]);
283    }
284
285    // Handle rare non-newline whitespace (CR, tab, etc.)
286    if clean.iter().any(|&b| is_whitespace(b)) {
287        clean.retain(|&b| !is_whitespace(b));
288    }
289
290    // Parallel decode for large inputs
291    if clean.len() >= PARALLEL_ENCODE_THRESHOLD {
292        return decode_parallel(&clean, out);
293    }
294
295    decode_owned_clean(&mut clean, out)
296}
297
298/// Decode clean base64 data in parallel chunks.
299fn decode_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
300    let num_threads = rayon::current_num_threads().max(1);
301    // Each chunk must be aligned to 4 bytes (base64 quadruplet boundary)
302    let raw_chunk = (data.len() + num_threads - 1) / num_threads;
303    let chunk_size = ((raw_chunk + 3) / 4) * 4;
304
305    // Check if last chunk has padding — only the very last chunk can have '='
306    // Split so that all but the last chunk are padless and 4-aligned
307    let decoded_chunks: Vec<Result<Vec<u8>, _>> = data
308        .par_chunks(chunk_size)
309        .map(|chunk| match BASE64_ENGINE.decode_to_vec(chunk) {
310            Ok(decoded) => Ok(decoded),
311            Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
312        })
313        .collect();
314
315    for chunk_result in decoded_chunks {
316        let chunk = chunk_result?;
317        out.write_all(&chunk)?;
318    }
319
320    Ok(())
321}
322
323/// Decode a clean (no whitespace) owned buffer in-place with SIMD.
324fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
325    if data.is_empty() {
326        return Ok(());
327    }
328    match BASE64_ENGINE.decode_inplace(data) {
329        Ok(decoded) => out.write_all(decoded),
330        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
331    }
332}
333
334/// Decode clean base64 data (no whitespace) from a borrowed slice.
335fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
336    if data.is_empty() {
337        return Ok(());
338    }
339    match BASE64_ENGINE.decode_to_vec(data) {
340        Ok(decoded) => {
341            out.write_all(&decoded)?;
342            Ok(())
343        }
344        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
345    }
346}
347
348/// Strip non-base64 characters (for -i / --ignore-garbage).
349fn strip_non_base64(data: &[u8]) -> Vec<u8> {
350    data.iter()
351        .copied()
352        .filter(|&b| is_base64_char(b))
353        .collect()
354}
355
/// True for `A-Z`, `a-z`, `0-9`, `+`, `/` and the padding byte `=`.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
361
/// True for space, tab, LF, CR, vertical tab (0x0b) and form feed (0x0c).
#[inline]
fn is_whitespace(b: u8) -> bool {
    // u8::is_ascii_whitespace covers everything except vertical tab.
    b.is_ascii_whitespace() || b == 0x0b
}
367
368/// Stream-encode from a reader to a writer. Used for stdin processing.
369/// Uses 4MB read chunks and batches wrapped output for minimum syscalls.
370/// The caller is expected to provide a suitably buffered or raw fd writer.
371pub fn encode_stream(
372    reader: &mut impl Read,
373    wrap_col: usize,
374    writer: &mut impl Write,
375) -> io::Result<()> {
376    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
377
378    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
379    let mut encode_buf = vec![0u8; encode_buf_size];
380
381    if wrap_col == 0 {
382        // No wrapping: encode each 4MB chunk and write directly.
383        loop {
384            let n = read_full(reader, &mut buf)?;
385            if n == 0 {
386                break;
387            }
388            let enc_len = BASE64_ENGINE.encoded_length(n);
389            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
390            writer.write_all(encoded)?;
391        }
392    } else {
393        // Wrapping: batch wrapped output into a pre-allocated buffer.
394        // For 4MB input at 76-col wrap, wrapped output is ~5.6MB.
395        let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
396        let mut wrap_buf = vec![0u8; max_wrapped];
397        let mut col = 0usize;
398
399        loop {
400            let n = read_full(reader, &mut buf)?;
401            if n == 0 {
402                break;
403            }
404            let enc_len = BASE64_ENGINE.encoded_length(n);
405            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
406
407            // Build wrapped output in wrap_buf, then single write.
408            let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
409            writer.write_all(&wrap_buf[..wp])?;
410        }
411
412        if col > 0 {
413            writer.write_all(b"\n")?;
414        }
415    }
416
417    Ok(())
418}
419
/// Copy `data` into `wrap_buf`, inserting `\n` each time the running column
/// reaches `wrap_col`.
///
/// `col` is the column the previous call stopped at. It is updated so the
/// next call continues the same line. No newline is written for a line that
/// is still incomplete when `data` runs out. Returns the number of bytes
/// written to `wrap_buf`.
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut written = 0;
    let mut rest = data;

    while !rest.is_empty() {
        let room = wrap_col - *col;
        let take = room.min(rest.len());
        let (head, tail) = rest.split_at(take);

        wrap_buf[written..written + take].copy_from_slice(head);
        written += take;
        *col += take;
        rest = tail;

        if *col == wrap_col {
            wrap_buf[written] = b'\n';
            written += 1;
            *col = 0;
        }
    }

    written
}
459
460/// Stream-decode from a reader to a writer. Used for stdin processing.
461/// Reads 4MB chunks, strips whitespace, decodes, and writes incrementally.
462/// Handles base64 quadruplet boundaries across chunk reads.
463pub fn decode_stream(
464    reader: &mut impl Read,
465    ignore_garbage: bool,
466    writer: &mut impl Write,
467) -> io::Result<()> {
468    const READ_CHUNK: usize = 4 * 1024 * 1024;
469    let mut buf = vec![0u8; READ_CHUNK];
470    let mut clean = Vec::with_capacity(READ_CHUNK);
471    let mut carry: Vec<u8> = Vec::with_capacity(4);
472
473    loop {
474        let n = read_full(reader, &mut buf)?;
475        if n == 0 {
476            break;
477        }
478
479        // Build clean buffer: carry-over + stripped chunk
480        clean.clear();
481        clean.extend_from_slice(&carry);
482        carry.clear();
483
484        let chunk = &buf[..n];
485        if ignore_garbage {
486            clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
487        } else {
488            // Strip newlines using SIMD memchr
489            let mut last = 0;
490            for pos in memchr::memchr_iter(b'\n', chunk) {
491                if pos > last {
492                    clean.extend_from_slice(&chunk[last..pos]);
493                }
494                last = pos + 1;
495            }
496            if last < n {
497                clean.extend_from_slice(&chunk[last..]);
498            }
499            // Handle rare non-newline whitespace
500            if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
501                clean.retain(|&b| !is_whitespace(b));
502            }
503        }
504
505        let is_last = n < READ_CHUNK;
506
507        if is_last {
508            // Last chunk: decode everything (including padding)
509            decode_owned_clean(&mut clean, writer)?;
510        } else {
511            // Save incomplete base64 quadruplet for next iteration
512            let decode_len = (clean.len() / 4) * 4;
513            if decode_len < clean.len() {
514                carry.extend_from_slice(&clean[decode_len..]);
515            }
516            if decode_len > 0 {
517                clean.truncate(decode_len);
518                decode_owned_clean(&mut clean, writer)?;
519            }
520        }
521    }
522
523    // Handle any remaining carry-over bytes
524    if !carry.is_empty() {
525        decode_owned_clean(&mut carry, writer)?;
526    }
527
528    Ok(())
529}
530
/// Fill `buf` from `reader`, looping over short reads.
///
/// Stops early only at EOF (a read returning 0). `Interrupted` errors are
/// retried, and any other error is returned immediately. Returns the number
/// of bytes placed in `buf`.
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0;
    while filled < buf.len() {
        let got = match reader.read(&mut buf[filled..]) {
            Ok(0) => return Ok(filled),
            Ok(k) => k,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => 0,
            Err(e) => return Err(e),
        };
        filled += got;
    }
    Ok(filled)
}