Skip to main content

coreutils_rs/base64/
core.rs

1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4use rayon::prelude::*;
5
6const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
7
8/// Streaming encode chunk: 12MB aligned to 3 bytes for maximum throughput.
9const STREAM_ENCODE_CHUNK: usize = 12 * 1024 * 1024 - (12 * 1024 * 1024 % 3);
10
11/// Chunk size for no-wrap encoding: 8MB aligned to 3 bytes.
12const NOWRAP_CHUNK: usize = 8 * 1024 * 1024 - (8 * 1024 * 1024 % 3);
13
14/// Minimum input size for parallel encoding.
15const PARALLEL_ENCODE_THRESHOLD: usize = 1024 * 1024;
16
17/// Encode data and write to output with line wrapping.
18/// Uses SIMD encoding with reusable buffers for maximum throughput.
19pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
20    if data.is_empty() {
21        return Ok(());
22    }
23
24    if wrap_col == 0 {
25        return encode_no_wrap(data, out);
26    }
27
28    encode_wrapped(data, wrap_col, out)
29}
30
31/// Encode without wrapping using parallel SIMD encoding for large inputs.
32fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
33    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
34        // Split into per-thread chunks aligned to 3-byte boundaries
35        let num_threads = rayon::current_num_threads().max(1);
36        let raw_chunk = (data.len() + num_threads - 1) / num_threads;
37        // Align to 3 bytes for clean base64 boundaries (no padding mid-stream)
38        let chunk_size = ((raw_chunk + 2) / 3) * 3;
39
40        let encoded_chunks: Vec<Vec<u8>> = data
41            .par_chunks(chunk_size)
42            .map(|chunk| {
43                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
44                let mut buf = vec![0u8; enc_len];
45                let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
46                let len = encoded.len();
47                buf.truncate(len);
48                buf
49            })
50            .collect();
51
52        for chunk in &encoded_chunks {
53            out.write_all(chunk)?;
54        }
55        return Ok(());
56    }
57
58    let enc_max = BASE64_ENGINE.encoded_length(NOWRAP_CHUNK);
59    let mut buf = vec![0u8; enc_max];
60
61    for chunk in data.chunks(NOWRAP_CHUNK) {
62        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
63        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
64        out.write_all(encoded)?;
65    }
66    Ok(())
67}
68
69/// Encode with line wrapping. For large inputs, uses parallel encoding.
70fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
71    let bytes_per_line = wrap_col * 3 / 4;
72
73    if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line > 0 {
74        // Parallel: split input into chunks aligned to bytes_per_line (= 3-byte aligned)
75        // so each chunk produces complete lines (no cross-chunk line splitting).
76        let num_threads = rayon::current_num_threads().max(1);
77        let lines_per_thread = ((data.len() / bytes_per_line) + num_threads - 1) / num_threads;
78        let chunk_input = (lines_per_thread * bytes_per_line).max(bytes_per_line);
79
80        let wrapped_chunks: Vec<Vec<u8>> = data
81            .par_chunks(chunk_input)
82            .map(|chunk| {
83                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
84                let mut encode_buf = vec![0u8; enc_len];
85                let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
86
87                // Wrap the encoded output
88                let line_out = wrap_col + 1;
89                let max_lines = (encoded.len() + wrap_col - 1) / wrap_col + 1;
90                let mut wrap_buf = vec![0u8; max_lines * line_out];
91                let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
92                wrap_buf.truncate(wp);
93                wrap_buf
94            })
95            .collect();
96
97        for chunk in &wrapped_chunks {
98            out.write_all(chunk)?;
99        }
100        return Ok(());
101    }
102
103    // Sequential path
104    let lines_per_chunk = (8 * 1024 * 1024) / bytes_per_line.max(1);
105    let chunk_input = lines_per_chunk * bytes_per_line.max(1);
106    let chunk_encoded_max = BASE64_ENGINE.encoded_length(chunk_input.max(1));
107    let mut encode_buf = vec![0u8; chunk_encoded_max];
108    let wrapped_max = (lines_per_chunk + 1) * (wrap_col + 1);
109    let mut wrap_buf = vec![0u8; wrapped_max];
110
111    for chunk in data.chunks(chunk_input.max(1)) {
112        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
113        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
114        let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
115        out.write_all(&wrap_buf[..wp])?;
116    }
117
118    Ok(())
119}
120
/// Copy `encoded` into `wrap_buf`, ending every `wrap_col`-byte line (and a
/// trailing shorter line, if any) with `\n`. Returns the number of bytes
/// written to `wrap_buf`; empty `encoded` writes nothing and returns 0.
///
/// Implemented with safe slice operations only. No offset is ever computed
/// as a multiple of `wrap_col`, so every `wrap_col` up to `usize::MAX` is
/// handled without arithmetic overflow.
///
/// # Panics
/// Panics if `wrap_col == 0`, or if `wrap_buf` cannot hold `encoded.len()`
/// plus one byte per emitted newline. A length of
/// `encoded.len() + encoded.len() / wrap_col + 1` always suffices.
#[inline]
fn wrap_encoded(encoded: &[u8], wrap_col: usize, wrap_buf: &mut [u8]) -> usize {
    let full_lines = encoded.chunks_exact(wrap_col);
    // A trailing partial line also gets its own newline.
    let partial_line = Some(full_lines.remainder()).filter(|rest| !rest.is_empty());

    let mut wp = 0;
    for line in full_lines.chain(partial_line) {
        let end = wp + line.len();
        wrap_buf[wp..end].copy_from_slice(line);
        wrap_buf[end] = b'\n';
        wp = end + 1;
    }
    wp
}
171
172/// Decode base64 data and write to output (borrows data, allocates clean buffer).
173/// When `ignore_garbage` is true, strip all non-base64 characters.
174/// When false, only strip whitespace (standard behavior).
175pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
176    if data.is_empty() {
177        return Ok(());
178    }
179
180    if ignore_garbage {
181        let mut cleaned = strip_non_base64(data);
182        return decode_owned_clean(&mut cleaned, out);
183    }
184
185    // Fast path: strip newlines with memchr (SIMD), then SIMD decode
186    decode_stripping_whitespace(data, out)
187}
188
189/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
190/// Avoids a full buffer copy by stripping whitespace in the existing allocation,
191/// then decoding in-place. Ideal when the caller already has an owned Vec.
192pub fn decode_owned(
193    data: &mut Vec<u8>,
194    ignore_garbage: bool,
195    out: &mut impl Write,
196) -> io::Result<()> {
197    if data.is_empty() {
198        return Ok(());
199    }
200
201    if ignore_garbage {
202        data.retain(|&b| is_base64_char(b));
203    } else {
204        strip_whitespace_inplace(data);
205    }
206
207    decode_owned_clean(data, out)
208}
209
210/// Strip all whitespace from a Vec in-place using SIMD memchr for newlines
211/// and a fallback scan for rare non-newline whitespace.
212fn strip_whitespace_inplace(data: &mut Vec<u8>) {
213    // First, collect newline positions using SIMD memchr.
214    let positions: Vec<usize> = memchr::memchr_iter(b'\n', data.as_slice()).collect();
215
216    if positions.is_empty() {
217        // No newlines; check for other whitespace only.
218        if data.iter().any(|&b| is_whitespace(b)) {
219            data.retain(|&b| !is_whitespace(b));
220        }
221        return;
222    }
223
224    // Compact data in-place, removing newlines using copy_within.
225    let mut wp = 0;
226    let mut rp = 0;
227
228    for &pos in &positions {
229        if pos > rp {
230            let len = pos - rp;
231            data.copy_within(rp..pos, wp);
232            wp += len;
233        }
234        rp = pos + 1;
235    }
236
237    let data_len = data.len();
238    if rp < data_len {
239        let len = data_len - rp;
240        data.copy_within(rp..data_len, wp);
241        wp += len;
242    }
243
244    data.truncate(wp);
245
246    // Handle rare non-newline whitespace (CR, tab, etc.)
247    if data.iter().any(|&b| is_whitespace(b)) {
248        data.retain(|&b| !is_whitespace(b));
249    }
250}
251
252/// Decode by stripping all whitespace from the entire input at once,
253/// then performing a single SIMD decode pass. Used when data is borrowed.
254/// For large inputs, decodes in parallel chunks for maximum throughput.
255fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
256    // Quick check: any whitespace at all?
257    // Use SIMD memchr2 to check for both \n and \r simultaneously
258    if memchr::memchr2(b'\n', b'\r', data).is_none()
259        && !data.iter().any(|&b| b == b' ' || b == b'\t')
260    {
261        // No whitespace — decode directly from borrowed data
262        if data.len() >= PARALLEL_ENCODE_THRESHOLD {
263            return decode_parallel(data, out);
264        }
265        return decode_borrowed_clean(out, data);
266    }
267
268    // Strip newlines from entire input in a single pass using SIMD memchr.
269    let mut clean = Vec::with_capacity(data.len());
270    let mut last = 0;
271    for pos in memchr::memchr_iter(b'\n', data) {
272        if pos > last {
273            clean.extend_from_slice(&data[last..pos]);
274        }
275        last = pos + 1;
276    }
277    if last < data.len() {
278        clean.extend_from_slice(&data[last..]);
279    }
280
281    // Handle rare non-newline whitespace (CR, tab, etc.)
282    if clean.iter().any(|&b| is_whitespace(b)) {
283        clean.retain(|&b| !is_whitespace(b));
284    }
285
286    // Parallel decode for large inputs
287    if clean.len() >= PARALLEL_ENCODE_THRESHOLD {
288        return decode_parallel(&clean, out);
289    }
290
291    decode_owned_clean(&mut clean, out)
292}
293
294/// Decode clean base64 data in parallel chunks.
295fn decode_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
296    let num_threads = rayon::current_num_threads().max(1);
297    // Each chunk must be aligned to 4 bytes (base64 quadruplet boundary)
298    let raw_chunk = (data.len() + num_threads - 1) / num_threads;
299    let chunk_size = ((raw_chunk + 3) / 4) * 4;
300
301    // Check if last chunk has padding — only the very last chunk can have '='
302    // Split so that all but the last chunk are padless and 4-aligned
303    let decoded_chunks: Vec<Result<Vec<u8>, _>> = data
304        .par_chunks(chunk_size)
305        .map(|chunk| match BASE64_ENGINE.decode_to_vec(chunk) {
306            Ok(decoded) => Ok(decoded),
307            Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
308        })
309        .collect();
310
311    for chunk_result in decoded_chunks {
312        let chunk = chunk_result?;
313        out.write_all(&chunk)?;
314    }
315
316    Ok(())
317}
318
319/// Decode a clean (no whitespace) owned buffer in-place with SIMD.
320fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
321    if data.is_empty() {
322        return Ok(());
323    }
324    match BASE64_ENGINE.decode_inplace(data) {
325        Ok(decoded) => out.write_all(decoded),
326        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
327    }
328}
329
330/// Decode clean base64 data (no whitespace) from a borrowed slice.
331fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
332    if data.is_empty() {
333        return Ok(());
334    }
335    match BASE64_ENGINE.decode_to_vec(data) {
336        Ok(decoded) => {
337            out.write_all(&decoded)?;
338            Ok(())
339        }
340        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
341    }
342}
343
344/// Strip non-base64 characters (for -i / --ignore-garbage).
345fn strip_non_base64(data: &[u8]) -> Vec<u8> {
346    data.iter()
347        .copied()
348        .filter(|&b| is_base64_char(b))
349        .collect()
350}
351
/// True for bytes of the standard base64 alphabet (`A-Z`, `a-z`, `0-9`,
/// `+`, `/`) and for the `=` padding character.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
357
/// True for ASCII whitespace: space, `\t`, `\n`, `\r`, form feed (0x0c),
/// and vertical tab (0x0b). `u8::is_ascii_whitespace` covers all of these
/// except the vertical tab, which is checked separately.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b == 0x0b || b.is_ascii_whitespace()
}
363
364/// Stream-encode from a reader to a writer. Used for stdin processing.
365/// Uses 4MB read chunks and batches wrapped output for minimum syscalls.
366/// The caller is expected to provide a suitably buffered or raw fd writer.
367pub fn encode_stream(
368    reader: &mut impl Read,
369    wrap_col: usize,
370    writer: &mut impl Write,
371) -> io::Result<()> {
372    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
373
374    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
375    let mut encode_buf = vec![0u8; encode_buf_size];
376
377    if wrap_col == 0 {
378        // No wrapping: encode each 4MB chunk and write directly.
379        loop {
380            let n = read_full(reader, &mut buf)?;
381            if n == 0 {
382                break;
383            }
384            let enc_len = BASE64_ENGINE.encoded_length(n);
385            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
386            writer.write_all(encoded)?;
387        }
388    } else {
389        // Wrapping: batch wrapped output into a pre-allocated buffer.
390        // For 4MB input at 76-col wrap, wrapped output is ~5.6MB.
391        let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
392        let mut wrap_buf = vec![0u8; max_wrapped];
393        let mut col = 0usize;
394
395        loop {
396            let n = read_full(reader, &mut buf)?;
397            if n == 0 {
398                break;
399            }
400            let enc_len = BASE64_ENGINE.encoded_length(n);
401            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
402
403            // Build wrapped output in wrap_buf, then single write.
404            let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
405            writer.write_all(&wrap_buf[..wp])?;
406        }
407
408        if col > 0 {
409            writer.write_all(b"\n")?;
410        }
411    }
412
413    Ok(())
414}
415
/// Write `data` into `wrap_buf` as wrapped text and return the number of
/// bytes written.
///
/// `col` is the number of characters already on the current output line. It
/// is carried across calls, so one logical stream can be fed in pieces.
/// A `\n` is emitted each time a line reaches `wrap_col` characters, and `col`
/// is left at the column where the next call resumes. No newline is emitted
/// for a trailing partial line; the caller finishes that.
///
/// `wrap_buf` must hold `data.len()` bytes plus one per emitted newline.
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut written = 0;
    let mut rest = data;

    while !rest.is_empty() {
        let room = wrap_col - *col;
        let take = room.min(rest.len());
        let (head, tail) = rest.split_at(take);

        wrap_buf[written..written + take].copy_from_slice(head);
        written += take;
        *col += take;
        rest = tail;

        if *col == wrap_col {
            wrap_buf[written] = b'\n';
            written += 1;
            *col = 0;
        }
    }

    written
}
455
456/// Stream-decode from a reader to a writer. Used for stdin processing.
457/// Reads 4MB chunks, strips whitespace, decodes, and writes incrementally.
458/// Handles base64 quadruplet boundaries across chunk reads.
459pub fn decode_stream(
460    reader: &mut impl Read,
461    ignore_garbage: bool,
462    writer: &mut impl Write,
463) -> io::Result<()> {
464    const READ_CHUNK: usize = 4 * 1024 * 1024;
465    let mut buf = vec![0u8; READ_CHUNK];
466    let mut clean = Vec::with_capacity(READ_CHUNK);
467    let mut carry: Vec<u8> = Vec::with_capacity(4);
468
469    loop {
470        let n = read_full(reader, &mut buf)?;
471        if n == 0 {
472            break;
473        }
474
475        // Build clean buffer: carry-over + stripped chunk
476        clean.clear();
477        clean.extend_from_slice(&carry);
478        carry.clear();
479
480        let chunk = &buf[..n];
481        if ignore_garbage {
482            clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
483        } else {
484            // Strip newlines using SIMD memchr
485            let mut last = 0;
486            for pos in memchr::memchr_iter(b'\n', chunk) {
487                if pos > last {
488                    clean.extend_from_slice(&chunk[last..pos]);
489                }
490                last = pos + 1;
491            }
492            if last < n {
493                clean.extend_from_slice(&chunk[last..]);
494            }
495            // Handle rare non-newline whitespace
496            if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
497                clean.retain(|&b| !is_whitespace(b));
498            }
499        }
500
501        let is_last = n < READ_CHUNK;
502
503        if is_last {
504            // Last chunk: decode everything (including padding)
505            decode_owned_clean(&mut clean, writer)?;
506        } else {
507            // Save incomplete base64 quadruplet for next iteration
508            let decode_len = (clean.len() / 4) * 4;
509            if decode_len < clean.len() {
510                carry.extend_from_slice(&clean[decode_len..]);
511            }
512            if decode_len > 0 {
513                clean.truncate(decode_len);
514                decode_owned_clean(&mut clean, writer)?;
515            }
516        }
517    }
518
519    // Handle any remaining carry-over bytes
520    if !carry.is_empty() {
521        decode_owned_clean(&mut carry, writer)?;
522    }
523
524    Ok(())
525}
526
/// Fill `buf` from `reader` as far as possible, retrying `Interrupted` reads.
///
/// Returns the number of bytes stored. A value smaller than `buf.len()` means
/// the reader signalled end of input with `Ok(0)`. Any other error is returned
/// immediately, and bytes already stored in `buf` are not reported.
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0;
    while filled < buf.len() {
        let got = match reader.read(&mut buf[filled..]) {
            Ok(0) => return Ok(filled),
            Ok(count) => count,
            Err(err) if err.kind() == io::ErrorKind::Interrupted => 0,
            Err(err) => return Err(err),
        };
        filled += got;
    }
    Ok(filled)
}