Skip to main content

coreutils_rs/base64/
core.rs

1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4use rayon::prelude::*;
5
/// Shared base64 engine used by every encode and decode routine in this file:
/// the `STANDARD` instance exported by `base64_simd`.
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
7
/// Read size for streaming encode: the largest multiple of 3 that fits in 4 MiB.
/// A multiple of 3 means a full read encodes without `=` padding, so consecutive
/// chunks can be concatenated. At this size a 100 MiB input takes about 25 reads.
const STREAM_ENCODE_CHUNK: usize = (4 << 20) / 3 * 3;

/// Per-iteration input size for sequential no-wrap encoding: the largest
/// multiple of 3 that fits in 2 MiB. Kept modest so the scratch buffer is cheap
/// to allocate and the working set stays small.
const NOWRAP_CHUNK: usize = (2 << 20) / 3 * 3;

/// Inputs of at least this many bytes (32 MiB) take the rayon-based parallel
/// encode/decode paths. Below it the sequential path is used, on the assumption
/// that thread-pool start-up would cost more than the parallelism saves.
const PARALLEL_ENCODE_THRESHOLD: usize = 32 << 20;
22
23/// Encode data and write to output with line wrapping.
24/// Uses SIMD encoding with reusable buffers for maximum throughput.
25pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
26    if data.is_empty() {
27        return Ok(());
28    }
29
30    if wrap_col == 0 {
31        return encode_no_wrap(data, out);
32    }
33
34    encode_wrapped(data, wrap_col, out)
35}
36
37/// Encode without wrapping using parallel SIMD encoding for large inputs.
38fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
39    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
40        // Split into per-thread chunks aligned to 3-byte boundaries
41        let num_threads = rayon::current_num_threads().max(1);
42        let raw_chunk = (data.len() + num_threads - 1) / num_threads;
43        // Align to 3 bytes for clean base64 boundaries (no padding mid-stream)
44        let chunk_size = ((raw_chunk + 2) / 3) * 3;
45
46        let encoded_chunks: Vec<Vec<u8>> = data
47            .par_chunks(chunk_size)
48            .map(|chunk| {
49                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
50                let mut buf = vec![0u8; enc_len];
51                let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
52                let len = encoded.len();
53                buf.truncate(len);
54                buf
55            })
56            .collect();
57
58        for chunk in &encoded_chunks {
59            out.write_all(chunk)?;
60        }
61        return Ok(());
62    }
63
64    // Size buffer for actual data, not max chunk size.
65    // For 100KB input, allocates ~134KB instead of ~2.7MB.
66    let actual_chunk = NOWRAP_CHUNK.min(data.len());
67    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
68    let mut buf = vec![0u8; enc_max];
69
70    for chunk in data.chunks(NOWRAP_CHUNK) {
71        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
72        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
73        out.write_all(encoded)?;
74    }
75    Ok(())
76}
77
78/// Encode with line wrapping. For large inputs, uses parallel encoding.
79fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
80    let bytes_per_line = wrap_col * 3 / 4;
81
82    if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line > 0 {
83        // Parallel: split input into chunks aligned to bytes_per_line (= 3-byte aligned)
84        // so each chunk produces complete lines (no cross-chunk line splitting).
85        let num_threads = rayon::current_num_threads().max(1);
86        let lines_per_thread = ((data.len() / bytes_per_line) + num_threads - 1) / num_threads;
87        let chunk_input = (lines_per_thread * bytes_per_line).max(bytes_per_line);
88
89        let wrapped_chunks: Vec<Vec<u8>> = data
90            .par_chunks(chunk_input)
91            .map(|chunk| {
92                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
93                let mut encode_buf = vec![0u8; enc_len];
94                let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
95
96                // Wrap the encoded output
97                let line_out = wrap_col + 1;
98                let max_lines = (encoded.len() + wrap_col - 1) / wrap_col + 1;
99                let mut wrap_buf = vec![0u8; max_lines * line_out];
100                let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
101                wrap_buf.truncate(wp);
102                wrap_buf
103            })
104            .collect();
105
106        for chunk in &wrapped_chunks {
107            out.write_all(chunk)?;
108        }
109        return Ok(());
110    }
111
112    // Sequential path: 2MB chunks fit in L2 cache and reduce initial allocation.
113    // Cap buffer sizes to actual data size to avoid over-allocation for small inputs.
114    // For 100KB input: allocates ~270KB instead of ~5.5MB.
115    let lines_per_chunk = (2 * 1024 * 1024) / bytes_per_line.max(1);
116    let chunk_input = lines_per_chunk * bytes_per_line.max(1);
117    let effective_chunk = chunk_input.max(1).min(data.len());
118    let chunk_encoded_max = BASE64_ENGINE.encoded_length(effective_chunk);
119    let mut encode_buf = vec![0u8; chunk_encoded_max];
120    let effective_lines = effective_chunk / bytes_per_line.max(1) + 1;
121    let wrapped_max = (effective_lines + 1) * (wrap_col + 1);
122    let mut wrap_buf = vec![0u8; wrapped_max];
123
124    for chunk in data.chunks(chunk_input.max(1)) {
125        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
126        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
127        let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
128        out.write_all(&wrap_buf[..wp])?;
129    }
130
131    Ok(())
132}
133
/// Copy `encoded` into `wrap_buf`, inserting `\n` after every `wrap_col` bytes
/// and after a trailing partial line.
///
/// Returns the number of bytes written to `wrap_buf`, which is always
/// `encoded.len()` plus one byte per (full or partial) line.
///
/// A `wrap_col` of 0 means "no wrapping": `encoded` is copied unchanged and no
/// newline is added.
///
/// # Panics
/// Panics if `wrap_buf` is too small to hold the wrapped output.
#[inline]
fn wrap_encoded(encoded: &[u8], wrap_col: usize, wrap_buf: &mut [u8]) -> usize {
    if wrap_col == 0 {
        // A zero-width line would never advance through the input.
        wrap_buf[..encoded.len()].copy_from_slice(encoded);
        return encoded.len();
    }

    // Drive the loops by line count rather than by `rp + k * wrap_col`, so an
    // enormous `wrap_col` cannot overflow the loop condition.
    let full_lines = encoded.len() / wrap_col;
    let tail = encoded.len() % wrap_col;
    let needed = encoded.len() + full_lines + usize::from(tail != 0);

    // Single up-front capacity check; everything below relies on it.
    assert!(
        wrap_buf.len() >= needed,
        "wrap_buf too small: need {} bytes, have {}",
        needed,
        wrap_buf.len()
    );

    let mut rp = 0;
    let mut wp = 0;
    let mut lines_left = full_lines;

    // Unrolled: process 4 lines per iteration.
    while lines_left >= 4 {
        // Cannot overflow: at least four full lines exist, so wrap_col is at
        // most a quarter of the slice length.
        let line_out = wrap_col + 1;
        // SAFETY: `lines_left >= 4` full lines remain, so the source range
        // `rp .. rp + 4 * wrap_col` lies inside `encoded`. The destination range
        // `wp .. wp + 4 * line_out` ends at or before
        // `full_lines * line_out <= needed <= wrap_buf.len()` (asserted above).
        // `encoded` is a shared borrow and `wrap_buf` a unique one, so the two
        // regions cannot overlap.
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = wrap_buf.as_mut_ptr().add(wp);

            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';
        }
        rp += 4 * wrap_col;
        wp += 4 * line_out;
        lines_left -= 4;
    }

    // Remaining full lines (at most three).
    while lines_left > 0 {
        wrap_buf[wp..wp + wrap_col].copy_from_slice(&encoded[rp..rp + wrap_col]);
        wp += wrap_col;
        wrap_buf[wp] = b'\n';
        wp += 1;
        rp += wrap_col;
        lines_left -= 1;
    }

    // Partial last line.
    if tail != 0 {
        wrap_buf[wp..wp + tail].copy_from_slice(&encoded[rp..]);
        wp += tail;
        wrap_buf[wp] = b'\n';
        wp += 1;
    }

    wp
}
184
185/// Decode base64 data and write to output (borrows data, allocates clean buffer).
186/// When `ignore_garbage` is true, strip all non-base64 characters.
187/// When false, only strip whitespace (standard behavior).
188pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
189    if data.is_empty() {
190        return Ok(());
191    }
192
193    if ignore_garbage {
194        let mut cleaned = strip_non_base64(data);
195        return decode_owned_clean(&mut cleaned, out);
196    }
197
198    // Fast path: strip newlines with memchr (SIMD), then SIMD decode
199    decode_stripping_whitespace(data, out)
200}
201
202/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
203/// Avoids a full buffer copy by stripping whitespace in the existing allocation,
204/// then decoding in-place. Ideal when the caller already has an owned Vec.
205pub fn decode_owned(
206    data: &mut Vec<u8>,
207    ignore_garbage: bool,
208    out: &mut impl Write,
209) -> io::Result<()> {
210    if data.is_empty() {
211        return Ok(());
212    }
213
214    if ignore_garbage {
215        data.retain(|&b| is_base64_char(b));
216    } else {
217        strip_whitespace_inplace(data);
218    }
219
220    decode_owned_clean(data, out)
221}
222
223/// Strip all whitespace from a Vec in-place using SIMD memchr for newlines
224/// and a fallback scan for rare non-newline whitespace.
225fn strip_whitespace_inplace(data: &mut Vec<u8>) {
226    // First, collect newline positions using SIMD memchr.
227    let positions: Vec<usize> = memchr::memchr_iter(b'\n', data.as_slice()).collect();
228
229    if positions.is_empty() {
230        // No newlines; check for other whitespace only.
231        if data.iter().any(|&b| is_whitespace(b)) {
232            data.retain(|&b| !is_whitespace(b));
233        }
234        return;
235    }
236
237    // Compact data in-place, removing newlines using copy_within.
238    let mut wp = 0;
239    let mut rp = 0;
240
241    for &pos in &positions {
242        if pos > rp {
243            let len = pos - rp;
244            data.copy_within(rp..pos, wp);
245            wp += len;
246        }
247        rp = pos + 1;
248    }
249
250    let data_len = data.len();
251    if rp < data_len {
252        let len = data_len - rp;
253        data.copy_within(rp..data_len, wp);
254        wp += len;
255    }
256
257    data.truncate(wp);
258
259    // Handle rare non-newline whitespace (CR, tab, etc.)
260    if data.iter().any(|&b| is_whitespace(b)) {
261        data.retain(|&b| !is_whitespace(b));
262    }
263}
264
265/// Decode by stripping all whitespace from the entire input at once,
266/// then performing a single SIMD decode pass. Used when data is borrowed.
267/// For large inputs, decodes in parallel chunks for maximum throughput.
268fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
269    // Quick check: any whitespace at all?
270    // Use SIMD memchr2 to check for both \n and \r simultaneously
271    if memchr::memchr2(b'\n', b'\r', data).is_none()
272        && !data.iter().any(|&b| b == b' ' || b == b'\t')
273    {
274        // No whitespace — decode directly from borrowed data
275        if data.len() >= PARALLEL_ENCODE_THRESHOLD {
276            return decode_parallel(data, out);
277        }
278        return decode_borrowed_clean(out, data);
279    }
280
281    // Strip newlines from entire input in a single pass using SIMD memchr.
282    let mut clean = Vec::with_capacity(data.len());
283    let mut last = 0;
284    for pos in memchr::memchr_iter(b'\n', data) {
285        if pos > last {
286            clean.extend_from_slice(&data[last..pos]);
287        }
288        last = pos + 1;
289    }
290    if last < data.len() {
291        clean.extend_from_slice(&data[last..]);
292    }
293
294    // Handle rare non-newline whitespace (CR, tab, etc.)
295    if clean.iter().any(|&b| is_whitespace(b)) {
296        clean.retain(|&b| !is_whitespace(b));
297    }
298
299    // Parallel decode for large inputs
300    if clean.len() >= PARALLEL_ENCODE_THRESHOLD {
301        return decode_parallel(&clean, out);
302    }
303
304    decode_owned_clean(&mut clean, out)
305}
306
307/// Decode clean base64 data in parallel chunks.
308fn decode_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
309    let num_threads = rayon::current_num_threads().max(1);
310    // Each chunk must be aligned to 4 bytes (base64 quadruplet boundary)
311    let raw_chunk = (data.len() + num_threads - 1) / num_threads;
312    let chunk_size = ((raw_chunk + 3) / 4) * 4;
313
314    // Check if last chunk has padding — only the very last chunk can have '='
315    // Split so that all but the last chunk are padless and 4-aligned
316    let decoded_chunks: Vec<Result<Vec<u8>, _>> = data
317        .par_chunks(chunk_size)
318        .map(|chunk| match BASE64_ENGINE.decode_to_vec(chunk) {
319            Ok(decoded) => Ok(decoded),
320            Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
321        })
322        .collect();
323
324    for chunk_result in decoded_chunks {
325        let chunk = chunk_result?;
326        out.write_all(&chunk)?;
327    }
328
329    Ok(())
330}
331
332/// Decode a clean (no whitespace) owned buffer in-place with SIMD.
333fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
334    if data.is_empty() {
335        return Ok(());
336    }
337    match BASE64_ENGINE.decode_inplace(data) {
338        Ok(decoded) => out.write_all(decoded),
339        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
340    }
341}
342
343/// Decode clean base64 data (no whitespace) from a borrowed slice.
344fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
345    if data.is_empty() {
346        return Ok(());
347    }
348    match BASE64_ENGINE.decode_to_vec(data) {
349        Ok(decoded) => {
350            out.write_all(&decoded)?;
351            Ok(())
352        }
353        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
354    }
355}
356
357/// Strip non-base64 characters (for -i / --ignore-garbage).
358fn strip_non_base64(data: &[u8]) -> Vec<u8> {
359    data.iter()
360        .copied()
361        .filter(|&b| is_base64_char(b))
362        .collect()
363}
364
/// Returns `true` for the 64 symbols of the standard base64 alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`) and for the `=` padding byte.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
370
/// Returns `true` for the six ASCII whitespace bytes: space, and the contiguous
/// control range tab, line feed, vertical tab, form feed, carriage return.
#[inline]
fn is_whitespace(b: u8) -> bool {
    // '\t' (0x09) through '\r' (0x0d) covers \t, \n, 0x0b, 0x0c and \r.
    b == b' ' || (b'\t'..=b'\r').contains(&b)
}
376
377/// Stream-encode from a reader to a writer. Used for stdin processing.
378/// Uses 4MB read chunks and batches wrapped output for minimum syscalls.
379/// The caller is expected to provide a suitably buffered or raw fd writer.
380pub fn encode_stream(
381    reader: &mut impl Read,
382    wrap_col: usize,
383    writer: &mut impl Write,
384) -> io::Result<()> {
385    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
386
387    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
388    let mut encode_buf = vec![0u8; encode_buf_size];
389
390    if wrap_col == 0 {
391        // No wrapping: encode each 4MB chunk and write directly.
392        loop {
393            let n = read_full(reader, &mut buf)?;
394            if n == 0 {
395                break;
396            }
397            let enc_len = BASE64_ENGINE.encoded_length(n);
398            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
399            writer.write_all(encoded)?;
400        }
401    } else {
402        // Wrapping: batch wrapped output into a pre-allocated buffer.
403        // For 4MB input at 76-col wrap, wrapped output is ~5.6MB.
404        let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
405        let mut wrap_buf = vec![0u8; max_wrapped];
406        let mut col = 0usize;
407
408        loop {
409            let n = read_full(reader, &mut buf)?;
410            if n == 0 {
411                break;
412            }
413            let enc_len = BASE64_ENGINE.encoded_length(n);
414            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
415
416            // Build wrapped output in wrap_buf, then single write.
417            let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
418            writer.write_all(&wrap_buf[..wp])?;
419        }
420
421        if col > 0 {
422            writer.write_all(b"\n")?;
423        }
424    }
425
426    Ok(())
427}
428
/// Copy `data` into `wrap_buf`, breaking lines at `wrap_col` columns while
/// continuing from the column recorded in `col`.
///
/// A `\n` is written each time a line reaches `wrap_col` characters. A line
/// left unfinished at the end of `data` is not terminated; its length is stored
/// back into `col` so the next call resumes on the same line.
///
/// Returns the number of bytes written to `wrap_buf`.
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut pending = data;
    let mut written = 0;

    while !pending.is_empty() {
        let room = wrap_col - *col;

        if pending.len() <= room {
            // Everything left fits on the current line.
            let end = written + pending.len();
            wrap_buf[written..end].copy_from_slice(pending);
            written = end;
            *col += pending.len();

            if *col == wrap_col {
                wrap_buf[written] = b'\n';
                written += 1;
                *col = 0;
            }
            break;
        }

        // Fill the current line, terminate it, and continue with the rest.
        let (line, rest) = pending.split_at(room);
        wrap_buf[written..written + room].copy_from_slice(line);
        written += room;
        wrap_buf[written] = b'\n';
        written += 1;
        *col = 0;
        pending = rest;
    }

    written
}
468
469/// Stream-decode from a reader to a writer. Used for stdin processing.
470/// Reads 4MB chunks, strips whitespace, decodes, and writes incrementally.
471/// Handles base64 quadruplet boundaries across chunk reads.
472pub fn decode_stream(
473    reader: &mut impl Read,
474    ignore_garbage: bool,
475    writer: &mut impl Write,
476) -> io::Result<()> {
477    const READ_CHUNK: usize = 4 * 1024 * 1024;
478    let mut buf = vec![0u8; READ_CHUNK];
479    let mut clean = Vec::with_capacity(READ_CHUNK);
480    let mut carry: Vec<u8> = Vec::with_capacity(4);
481
482    loop {
483        let n = read_full(reader, &mut buf)?;
484        if n == 0 {
485            break;
486        }
487
488        // Build clean buffer: carry-over + stripped chunk
489        clean.clear();
490        clean.extend_from_slice(&carry);
491        carry.clear();
492
493        let chunk = &buf[..n];
494        if ignore_garbage {
495            clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
496        } else {
497            // Strip newlines using SIMD memchr
498            let mut last = 0;
499            for pos in memchr::memchr_iter(b'\n', chunk) {
500                if pos > last {
501                    clean.extend_from_slice(&chunk[last..pos]);
502                }
503                last = pos + 1;
504            }
505            if last < n {
506                clean.extend_from_slice(&chunk[last..]);
507            }
508            // Handle rare non-newline whitespace
509            if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
510                clean.retain(|&b| !is_whitespace(b));
511            }
512        }
513
514        let is_last = n < READ_CHUNK;
515
516        if is_last {
517            // Last chunk: decode everything (including padding)
518            decode_owned_clean(&mut clean, writer)?;
519        } else {
520            // Save incomplete base64 quadruplet for next iteration
521            let decode_len = (clean.len() / 4) * 4;
522            if decode_len < clean.len() {
523                carry.extend_from_slice(&clean[decode_len..]);
524            }
525            if decode_len > 0 {
526                clean.truncate(decode_len);
527                decode_owned_clean(&mut clean, writer)?;
528            }
529        }
530    }
531
532    // Handle any remaining carry-over bytes
533    if !carry.is_empty() {
534        decode_owned_clean(&mut carry, writer)?;
535    }
536
537    Ok(())
538}
539
/// Fill `buf` from `reader`, issuing as many reads as it takes.
///
/// Stops when `buf` is full or the reader reports end of input, and returns the
/// number of bytes placed in `buf`; a value smaller than `buf.len()` therefore
/// means end of input was reached. Reads interrupted by a signal are retried.
///
/// # Errors
/// Returns the first non-`Interrupted` error reported by `reader`.
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0;
    loop {
        if filled >= buf.len() {
            return Ok(filled);
        }
        match reader.read(&mut buf[filled..]) {
            Ok(0) => return Ok(filled),
            Ok(got) => filled += got,
            Err(err) => {
                if err.kind() != io::ErrorKind::Interrupted {
                    return Err(err);
                }
            }
        }
    }
}