Skip to main content

coreutils_rs/base64/
core.rs

1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4use rayon::prelude::*;
5
6const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
7
/// Bytes read per iteration by `encode_stream`: 12 MiB rounded down to a
/// multiple of 3, so only the final (short) read can produce `=` padding.
const STREAM_ENCODE_CHUNK: usize = (12 * 1024 * 1024) / 3 * 3;

/// Piece size for sequential no-wrap encoding: 2 MiB rounded down to a multiple
/// of 3. Kept modest so the scratch buffer stays cache-friendly and cheap to
/// allocate (fewer page faults on first touch).
const NOWRAP_CHUNK: usize = (2 * 1024 * 1024) / 3 * 3;

/// Input size (32 MiB) from which encoding and decoding switch to the rayon pool.
/// Below this, the thread-pool start-up cost (~0.5-1ms per process) outweighs the
/// gain over single-threaded SIMD, which is already very fast.
const PARALLEL_ENCODE_THRESHOLD: usize = 32 << 20;
20
21/// Encode data and write to output with line wrapping.
22/// Uses SIMD encoding with reusable buffers for maximum throughput.
23pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
24    if data.is_empty() {
25        return Ok(());
26    }
27
28    if wrap_col == 0 {
29        return encode_no_wrap(data, out);
30    }
31
32    encode_wrapped(data, wrap_col, out)
33}
34
35/// Encode without wrapping using parallel SIMD encoding for large inputs.
36fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
37    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
38        // Split into per-thread chunks aligned to 3-byte boundaries
39        let num_threads = rayon::current_num_threads().max(1);
40        let raw_chunk = (data.len() + num_threads - 1) / num_threads;
41        // Align to 3 bytes for clean base64 boundaries (no padding mid-stream)
42        let chunk_size = ((raw_chunk + 2) / 3) * 3;
43
44        let encoded_chunks: Vec<Vec<u8>> = data
45            .par_chunks(chunk_size)
46            .map(|chunk| {
47                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
48                let mut buf = vec![0u8; enc_len];
49                let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
50                let len = encoded.len();
51                buf.truncate(len);
52                buf
53            })
54            .collect();
55
56        for chunk in &encoded_chunks {
57            out.write_all(chunk)?;
58        }
59        return Ok(());
60    }
61
62    // Size buffer for actual data, not max chunk size.
63    // For 100KB input, allocates ~134KB instead of ~2.7MB.
64    let actual_chunk = NOWRAP_CHUNK.min(data.len());
65    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
66    let mut buf = vec![0u8; enc_max];
67
68    for chunk in data.chunks(NOWRAP_CHUNK) {
69        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
70        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
71        out.write_all(encoded)?;
72    }
73    Ok(())
74}
75
76/// Encode with line wrapping. For large inputs, uses parallel encoding.
77fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
78    let bytes_per_line = wrap_col * 3 / 4;
79
80    if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line > 0 {
81        // Parallel: split input into chunks aligned to bytes_per_line (= 3-byte aligned)
82        // so each chunk produces complete lines (no cross-chunk line splitting).
83        let num_threads = rayon::current_num_threads().max(1);
84        let lines_per_thread = ((data.len() / bytes_per_line) + num_threads - 1) / num_threads;
85        let chunk_input = (lines_per_thread * bytes_per_line).max(bytes_per_line);
86
87        let wrapped_chunks: Vec<Vec<u8>> = data
88            .par_chunks(chunk_input)
89            .map(|chunk| {
90                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
91                let mut encode_buf = vec![0u8; enc_len];
92                let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
93
94                // Wrap the encoded output
95                let line_out = wrap_col + 1;
96                let max_lines = (encoded.len() + wrap_col - 1) / wrap_col + 1;
97                let mut wrap_buf = vec![0u8; max_lines * line_out];
98                let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
99                wrap_buf.truncate(wp);
100                wrap_buf
101            })
102            .collect();
103
104        for chunk in &wrapped_chunks {
105            out.write_all(chunk)?;
106        }
107        return Ok(());
108    }
109
110    // Sequential path: 2MB chunks fit in L2 cache and reduce initial allocation.
111    // Cap buffer sizes to actual data size to avoid over-allocation for small inputs.
112    // For 100KB input: allocates ~270KB instead of ~5.5MB.
113    let lines_per_chunk = (2 * 1024 * 1024) / bytes_per_line.max(1);
114    let chunk_input = lines_per_chunk * bytes_per_line.max(1);
115    let effective_chunk = chunk_input.max(1).min(data.len());
116    let chunk_encoded_max = BASE64_ENGINE.encoded_length(effective_chunk);
117    let mut encode_buf = vec![0u8; chunk_encoded_max];
118    let effective_lines = effective_chunk / bytes_per_line.max(1) + 1;
119    let wrapped_max = (effective_lines + 1) * (wrap_col + 1);
120    let mut wrap_buf = vec![0u8; wrapped_max];
121
122    for chunk in data.chunks(chunk_input.max(1)) {
123        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
124        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
125        let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
126        out.write_all(&wrap_buf[..wp])?;
127    }
128
129    Ok(())
130}
131
/// Copy `encoded` into `wrap_buf`, inserting a newline after every `wrap_col`
/// characters and after a trailing partial line.
///
/// Returns the number of bytes written to `wrap_buf`.
///
/// # Panics
///
/// Panics if `wrap_col` is zero. Also panics if `wrap_buf` is shorter than
/// `encoded.len()` plus one byte per output line, full or partial.
#[inline]
fn wrap_encoded(encoded: &[u8], wrap_col: usize, wrap_buf: &mut [u8]) -> usize {
    assert!(wrap_col > 0, "wrap_encoded: wrap_col must be non-zero");
    let full_lines = encoded.len() / wrap_col;
    let partial = encoded.len() % wrap_col;
    let needed = encoded.len() + full_lines + usize::from(partial != 0);
    // The unrolled loop below writes through raw pointers. This check is what
    // makes it sound: an undersized buffer must be rejected here, not overrun.
    assert!(
        wrap_buf.len() >= needed,
        "wrap_encoded: output buffer too small ({} < {})",
        wrap_buf.len(),
        needed
    );

    let mut rp = 0;
    let mut wp = 0;

    // Unrolled: 4 full lines per iteration.
    for _ in 0..full_lines / 4 {
        // Cannot overflow: this loop only runs when 4 * wrap_col <= encoded.len().
        let line_out = wrap_col + 1;
        // SAFETY: this iteration reads encoded[rp..rp + 4*wrap_col] and writes
        // wrap_buf[wp..wp + 4*line_out]. Both ranges belong to the first
        // `full_lines` lines, so the read end is <= full_lines * wrap_col
        // <= encoded.len(). The write end is <= full_lines * (wrap_col + 1)
        // <= needed <= wrap_buf.len() (asserted above). The two slices are
        // distinct borrows (`&` vs `&mut`), so they cannot overlap.
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = wrap_buf.as_mut_ptr().add(wp);

            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';
        }
        rp += 4 * wrap_col;
        wp += 4 * line_out;
    }

    // Remaining (at most 3) full lines.
    for _ in 0..full_lines % 4 {
        wrap_buf[wp..wp + wrap_col].copy_from_slice(&encoded[rp..rp + wrap_col]);
        wrap_buf[wp + wrap_col] = b'\n';
        rp += wrap_col;
        wp += wrap_col + 1;
    }

    // Partial last line.
    if partial != 0 {
        wrap_buf[wp..wp + partial].copy_from_slice(&encoded[rp..]);
        wrap_buf[wp + partial] = b'\n';
        wp += partial + 1;
    }

    wp
}
182
183/// Decode base64 data and write to output (borrows data, allocates clean buffer).
184/// When `ignore_garbage` is true, strip all non-base64 characters.
185/// When false, only strip whitespace (standard behavior).
186pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
187    if data.is_empty() {
188        return Ok(());
189    }
190
191    if ignore_garbage {
192        let mut cleaned = strip_non_base64(data);
193        return decode_owned_clean(&mut cleaned, out);
194    }
195
196    // Fast path: strip newlines with memchr (SIMD), then SIMD decode
197    decode_stripping_whitespace(data, out)
198}
199
200/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
201/// Avoids a full buffer copy by stripping whitespace in the existing allocation,
202/// then decoding in-place. Ideal when the caller already has an owned Vec.
203pub fn decode_owned(
204    data: &mut Vec<u8>,
205    ignore_garbage: bool,
206    out: &mut impl Write,
207) -> io::Result<()> {
208    if data.is_empty() {
209        return Ok(());
210    }
211
212    if ignore_garbage {
213        data.retain(|&b| is_base64_char(b));
214    } else {
215        strip_whitespace_inplace(data);
216    }
217
218    decode_owned_clean(data, out)
219}
220
221/// Strip all whitespace from a Vec in-place using SIMD memchr for newlines
222/// and a fallback scan for rare non-newline whitespace.
223fn strip_whitespace_inplace(data: &mut Vec<u8>) {
224    // First, collect newline positions using SIMD memchr.
225    let positions: Vec<usize> = memchr::memchr_iter(b'\n', data.as_slice()).collect();
226
227    if positions.is_empty() {
228        // No newlines; check for other whitespace only.
229        if data.iter().any(|&b| is_whitespace(b)) {
230            data.retain(|&b| !is_whitespace(b));
231        }
232        return;
233    }
234
235    // Compact data in-place, removing newlines using copy_within.
236    let mut wp = 0;
237    let mut rp = 0;
238
239    for &pos in &positions {
240        if pos > rp {
241            let len = pos - rp;
242            data.copy_within(rp..pos, wp);
243            wp += len;
244        }
245        rp = pos + 1;
246    }
247
248    let data_len = data.len();
249    if rp < data_len {
250        let len = data_len - rp;
251        data.copy_within(rp..data_len, wp);
252        wp += len;
253    }
254
255    data.truncate(wp);
256
257    // Handle rare non-newline whitespace (CR, tab, etc.)
258    if data.iter().any(|&b| is_whitespace(b)) {
259        data.retain(|&b| !is_whitespace(b));
260    }
261}
262
263/// Decode by stripping all whitespace from the entire input at once,
264/// then performing a single SIMD decode pass. Used when data is borrowed.
265/// For large inputs, decodes in parallel chunks for maximum throughput.
266fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
267    // Quick check: any whitespace at all?
268    // Use SIMD memchr2 to check for both \n and \r simultaneously
269    if memchr::memchr2(b'\n', b'\r', data).is_none()
270        && !data.iter().any(|&b| b == b' ' || b == b'\t')
271    {
272        // No whitespace — decode directly from borrowed data
273        if data.len() >= PARALLEL_ENCODE_THRESHOLD {
274            return decode_parallel(data, out);
275        }
276        return decode_borrowed_clean(out, data);
277    }
278
279    // Strip newlines from entire input in a single pass using SIMD memchr.
280    let mut clean = Vec::with_capacity(data.len());
281    let mut last = 0;
282    for pos in memchr::memchr_iter(b'\n', data) {
283        if pos > last {
284            clean.extend_from_slice(&data[last..pos]);
285        }
286        last = pos + 1;
287    }
288    if last < data.len() {
289        clean.extend_from_slice(&data[last..]);
290    }
291
292    // Handle rare non-newline whitespace (CR, tab, etc.)
293    if clean.iter().any(|&b| is_whitespace(b)) {
294        clean.retain(|&b| !is_whitespace(b));
295    }
296
297    // Parallel decode for large inputs
298    if clean.len() >= PARALLEL_ENCODE_THRESHOLD {
299        return decode_parallel(&clean, out);
300    }
301
302    decode_owned_clean(&mut clean, out)
303}
304
305/// Decode clean base64 data in parallel chunks.
306fn decode_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
307    let num_threads = rayon::current_num_threads().max(1);
308    // Each chunk must be aligned to 4 bytes (base64 quadruplet boundary)
309    let raw_chunk = (data.len() + num_threads - 1) / num_threads;
310    let chunk_size = ((raw_chunk + 3) / 4) * 4;
311
312    // Check if last chunk has padding — only the very last chunk can have '='
313    // Split so that all but the last chunk are padless and 4-aligned
314    let decoded_chunks: Vec<Result<Vec<u8>, _>> = data
315        .par_chunks(chunk_size)
316        .map(|chunk| match BASE64_ENGINE.decode_to_vec(chunk) {
317            Ok(decoded) => Ok(decoded),
318            Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
319        })
320        .collect();
321
322    for chunk_result in decoded_chunks {
323        let chunk = chunk_result?;
324        out.write_all(&chunk)?;
325    }
326
327    Ok(())
328}
329
330/// Decode a clean (no whitespace) owned buffer in-place with SIMD.
331fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
332    if data.is_empty() {
333        return Ok(());
334    }
335    match BASE64_ENGINE.decode_inplace(data) {
336        Ok(decoded) => out.write_all(decoded),
337        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
338    }
339}
340
341/// Decode clean base64 data (no whitespace) from a borrowed slice.
342fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
343    if data.is_empty() {
344        return Ok(());
345    }
346    match BASE64_ENGINE.decode_to_vec(data) {
347        Ok(decoded) => {
348            out.write_all(&decoded)?;
349            Ok(())
350        }
351        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
352    }
353}
354
355/// Strip non-base64 characters (for -i / --ignore-garbage).
356fn strip_non_base64(data: &[u8]) -> Vec<u8> {
357    data.iter()
358        .copied()
359        .filter(|&b| is_base64_char(b))
360        .collect()
361}
362
/// Whether `b` belongs to the standard base64 alphabet (`A-Z`, `a-z`, `0-9`,
/// `+`, `/`) or is the `=` padding byte.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
368
/// Whether `b` is ASCII whitespace: space, `\t`, `\n`, `\r`, vertical tab
/// (0x0b) or form feed (0x0c).
#[inline]
fn is_whitespace(b: u8) -> bool {
    // `is_ascii_whitespace` covers space, \t, \n, form feed and \r, but not vertical tab.
    b == 0x0b || b.is_ascii_whitespace()
}
374
375/// Stream-encode from a reader to a writer. Used for stdin processing.
376/// Uses 4MB read chunks and batches wrapped output for minimum syscalls.
377/// The caller is expected to provide a suitably buffered or raw fd writer.
378pub fn encode_stream(
379    reader: &mut impl Read,
380    wrap_col: usize,
381    writer: &mut impl Write,
382) -> io::Result<()> {
383    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
384
385    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
386    let mut encode_buf = vec![0u8; encode_buf_size];
387
388    if wrap_col == 0 {
389        // No wrapping: encode each 4MB chunk and write directly.
390        loop {
391            let n = read_full(reader, &mut buf)?;
392            if n == 0 {
393                break;
394            }
395            let enc_len = BASE64_ENGINE.encoded_length(n);
396            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
397            writer.write_all(encoded)?;
398        }
399    } else {
400        // Wrapping: batch wrapped output into a pre-allocated buffer.
401        // For 4MB input at 76-col wrap, wrapped output is ~5.6MB.
402        let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
403        let mut wrap_buf = vec![0u8; max_wrapped];
404        let mut col = 0usize;
405
406        loop {
407            let n = read_full(reader, &mut buf)?;
408            if n == 0 {
409                break;
410            }
411            let enc_len = BASE64_ENGINE.encoded_length(n);
412            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
413
414            // Build wrapped output in wrap_buf, then single write.
415            let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
416            writer.write_all(&wrap_buf[..wp])?;
417        }
418
419        if col > 0 {
420            writer.write_all(b"\n")?;
421        }
422    }
423
424    Ok(())
425}
426
/// Append `data` to `wrap_buf`, inserting a newline whenever the running column
/// reaches `wrap_col`.
///
/// `col` is the column at which `data` starts and is updated to the column
/// after it, so wrapping continues seamlessly across successive calls. No
/// newline is added for a trailing partial line; the caller does that at end of
/// stream. Returns the number of bytes written to `wrap_buf`.
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut rest = data;
    let mut written = 0;

    while !rest.is_empty() {
        let room = wrap_col - *col;

        if rest.len() <= room {
            // Everything left fits on the current line.
            let take = rest.len();
            wrap_buf[written..written + take].copy_from_slice(rest);
            written += take;
            *col += take;
            if *col == wrap_col {
                wrap_buf[written] = b'\n';
                written += 1;
                *col = 0;
            }
            break;
        }

        // Fill the current line, terminate it, and continue on a fresh one.
        let (head, tail) = rest.split_at(room);
        wrap_buf[written..written + room].copy_from_slice(head);
        written += room;
        wrap_buf[written] = b'\n';
        written += 1;
        *col = 0;
        rest = tail;
    }

    written
}
466
467/// Stream-decode from a reader to a writer. Used for stdin processing.
468/// Reads 4MB chunks, strips whitespace, decodes, and writes incrementally.
469/// Handles base64 quadruplet boundaries across chunk reads.
470pub fn decode_stream(
471    reader: &mut impl Read,
472    ignore_garbage: bool,
473    writer: &mut impl Write,
474) -> io::Result<()> {
475    const READ_CHUNK: usize = 4 * 1024 * 1024;
476    let mut buf = vec![0u8; READ_CHUNK];
477    let mut clean = Vec::with_capacity(READ_CHUNK);
478    let mut carry: Vec<u8> = Vec::with_capacity(4);
479
480    loop {
481        let n = read_full(reader, &mut buf)?;
482        if n == 0 {
483            break;
484        }
485
486        // Build clean buffer: carry-over + stripped chunk
487        clean.clear();
488        clean.extend_from_slice(&carry);
489        carry.clear();
490
491        let chunk = &buf[..n];
492        if ignore_garbage {
493            clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
494        } else {
495            // Strip newlines using SIMD memchr
496            let mut last = 0;
497            for pos in memchr::memchr_iter(b'\n', chunk) {
498                if pos > last {
499                    clean.extend_from_slice(&chunk[last..pos]);
500                }
501                last = pos + 1;
502            }
503            if last < n {
504                clean.extend_from_slice(&chunk[last..]);
505            }
506            // Handle rare non-newline whitespace
507            if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
508                clean.retain(|&b| !is_whitespace(b));
509            }
510        }
511
512        let is_last = n < READ_CHUNK;
513
514        if is_last {
515            // Last chunk: decode everything (including padding)
516            decode_owned_clean(&mut clean, writer)?;
517        } else {
518            // Save incomplete base64 quadruplet for next iteration
519            let decode_len = (clean.len() / 4) * 4;
520            if decode_len < clean.len() {
521                carry.extend_from_slice(&clean[decode_len..]);
522            }
523            if decode_len > 0 {
524                clean.truncate(decode_len);
525                decode_owned_clean(&mut clean, writer)?;
526            }
527        }
528    }
529
530    // Handle any remaining carry-over bytes
531    if !carry.is_empty() {
532        decode_owned_clean(&mut carry, writer)?;
533    }
534
535    Ok(())
536}
537
/// Fill `buf` from `reader`, retrying on short reads and `Interrupted` errors.
///
/// Returns the number of bytes stored. This is less than `buf.len()` only when
/// the reader hit end of input. Any other error is returned as-is; bytes
/// already stored in `buf` are not reported in that case.
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0;
    loop {
        let rest = &mut buf[filled..];
        if rest.is_empty() {
            return Ok(filled);
        }
        match reader.read(rest) {
            Ok(0) => return Ok(filled),
            Ok(n) => filled += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
}