Skip to main content

coreutils_rs/base64/
core.rs

1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4
5const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
6
7/// Streaming encode chunk: 8MB aligned to 3 bytes for maximum throughput.
8const STREAM_ENCODE_CHUNK: usize = 8 * 1024 * 1024 - (8 * 1024 * 1024 % 3);
9
10/// Chunk size for no-wrap encoding: 8MB aligned to 3 bytes.
11const NOWRAP_CHUNK: usize = 8 * 1024 * 1024 - (8 * 1024 * 1024 % 3);
12
13/// Encode data and write to output with line wrapping.
14/// Uses SIMD encoding with reusable buffers for maximum throughput.
15pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
16    if data.is_empty() {
17        return Ok(());
18    }
19
20    if wrap_col == 0 {
21        return encode_no_wrap(data, out);
22    }
23
24    encode_wrapped(data, wrap_col, out)
25}
26
27/// Encode without wrapping: process in 4MB chunks for bounded memory usage.
28/// Each chunk is SIMD-encoded and written immediately. Reuses a single
29/// encode buffer across chunks to avoid repeated allocation.
30fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
31    let enc_max = BASE64_ENGINE.encoded_length(NOWRAP_CHUNK);
32    let mut buf = vec![0u8; enc_max];
33
34    for chunk in data.chunks(NOWRAP_CHUNK) {
35        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
36        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
37        out.write_all(encoded)?;
38    }
39    Ok(())
40}
41
42/// Encode with line wrapping using large cache-friendly chunks.
43/// Each chunk is SIMD-encoded, then wrapped with newlines in a pre-allocated
44/// buffer using direct slice copies, and written with a single write_all.
45fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
46    let bytes_per_line = wrap_col * 3 / 4;
47
48    // Process ~8MB of input per chunk for maximum throughput.
49    // Aligned to bytes_per_line for clean line boundaries.
50    let lines_per_chunk = (8 * 1024 * 1024) / bytes_per_line;
51    let chunk_input = lines_per_chunk * bytes_per_line;
52    let chunk_encoded_max = BASE64_ENGINE.encoded_length(chunk_input);
53
54    // Pre-allocate reusable buffers (no per-chunk allocation).
55    let mut encode_buf = vec![0u8; chunk_encoded_max];
56    // Wrapped output: each line is wrap_col + 1 bytes (content + newline).
57    let wrapped_max = (lines_per_chunk + 1) * (wrap_col + 1);
58    let mut wrap_buf = vec![0u8; wrapped_max];
59
60    let line_out = wrap_col + 1; // bytes per output line (content + newline)
61
62    for chunk in data.chunks(chunk_input) {
63        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
64        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
65
66        // Build wrapped output with unrolled copies for common case.
67        let mut rp = 0;
68        let mut wp = 0;
69
70        // Unrolled: process 4 lines per iteration for better throughput
71        while rp + 4 * wrap_col <= encoded.len() {
72            unsafe {
73                let src = encoded.as_ptr().add(rp);
74                let dst = wrap_buf.as_mut_ptr().add(wp);
75
76                std::ptr::copy_nonoverlapping(src, dst, wrap_col);
77                *dst.add(wrap_col) = b'\n';
78
79                std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
80                *dst.add(line_out + wrap_col) = b'\n';
81
82                std::ptr::copy_nonoverlapping(
83                    src.add(2 * wrap_col),
84                    dst.add(2 * line_out),
85                    wrap_col,
86                );
87                *dst.add(2 * line_out + wrap_col) = b'\n';
88
89                std::ptr::copy_nonoverlapping(
90                    src.add(3 * wrap_col),
91                    dst.add(3 * line_out),
92                    wrap_col,
93                );
94                *dst.add(3 * line_out + wrap_col) = b'\n';
95            }
96            rp += 4 * wrap_col;
97            wp += 4 * line_out;
98        }
99
100        // Remaining full lines
101        while rp + wrap_col <= encoded.len() {
102            wrap_buf[wp..wp + wrap_col].copy_from_slice(&encoded[rp..rp + wrap_col]);
103            wp += wrap_col;
104            wrap_buf[wp] = b'\n';
105            wp += 1;
106            rp += wrap_col;
107        }
108
109        // Partial last line
110        if rp < encoded.len() {
111            let remaining = encoded.len() - rp;
112            wrap_buf[wp..wp + remaining].copy_from_slice(&encoded[rp..rp + remaining]);
113            wp += remaining;
114            wrap_buf[wp] = b'\n';
115            wp += 1;
116        }
117
118        out.write_all(&wrap_buf[..wp])?;
119    }
120
121    Ok(())
122}
123
124/// Decode base64 data and write to output (borrows data, allocates clean buffer).
125/// When `ignore_garbage` is true, strip all non-base64 characters.
126/// When false, only strip whitespace (standard behavior).
127pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
128    if data.is_empty() {
129        return Ok(());
130    }
131
132    if ignore_garbage {
133        let mut cleaned = strip_non_base64(data);
134        return decode_owned_clean(&mut cleaned, out);
135    }
136
137    // Fast path: strip newlines with memchr (SIMD), then SIMD decode
138    decode_stripping_whitespace(data, out)
139}
140
141/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
142/// Avoids a full buffer copy by stripping whitespace in the existing allocation,
143/// then decoding in-place. Ideal when the caller already has an owned Vec.
144pub fn decode_owned(
145    data: &mut Vec<u8>,
146    ignore_garbage: bool,
147    out: &mut impl Write,
148) -> io::Result<()> {
149    if data.is_empty() {
150        return Ok(());
151    }
152
153    if ignore_garbage {
154        data.retain(|&b| is_base64_char(b));
155    } else {
156        strip_whitespace_inplace(data);
157    }
158
159    decode_owned_clean(data, out)
160}
161
162/// Strip all whitespace from a Vec in-place using SIMD memchr for newlines
163/// and a fallback scan for rare non-newline whitespace.
164fn strip_whitespace_inplace(data: &mut Vec<u8>) {
165    // First, collect newline positions using SIMD memchr.
166    let positions: Vec<usize> = memchr::memchr_iter(b'\n', data.as_slice()).collect();
167
168    if positions.is_empty() {
169        // No newlines; check for other whitespace only.
170        if data.iter().any(|&b| is_whitespace(b)) {
171            data.retain(|&b| !is_whitespace(b));
172        }
173        return;
174    }
175
176    // Compact data in-place, removing newlines using copy_within.
177    let mut wp = 0;
178    let mut rp = 0;
179
180    for &pos in &positions {
181        if pos > rp {
182            let len = pos - rp;
183            data.copy_within(rp..pos, wp);
184            wp += len;
185        }
186        rp = pos + 1;
187    }
188
189    let data_len = data.len();
190    if rp < data_len {
191        let len = data_len - rp;
192        data.copy_within(rp..data_len, wp);
193        wp += len;
194    }
195
196    data.truncate(wp);
197
198    // Handle rare non-newline whitespace (CR, tab, etc.)
199    if data.iter().any(|&b| is_whitespace(b)) {
200        data.retain(|&b| !is_whitespace(b));
201    }
202}
203
204/// Decode by stripping all whitespace from the entire input at once,
205/// then performing a single SIMD decode pass. Used when data is borrowed.
206fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
207    // Quick check: any whitespace at all?
208    if memchr::memchr(b'\n', data).is_none() && !data.iter().any(|&b| is_whitespace(b)) {
209        return decode_borrowed_clean(out, data);
210    }
211
212    // Strip newlines from entire input in a single pass using SIMD memchr.
213    let mut clean = Vec::with_capacity(data.len());
214    let mut last = 0;
215    for pos in memchr::memchr_iter(b'\n', data) {
216        if pos > last {
217            clean.extend_from_slice(&data[last..pos]);
218        }
219        last = pos + 1;
220    }
221    if last < data.len() {
222        clean.extend_from_slice(&data[last..]);
223    }
224
225    // Handle rare non-newline whitespace (CR, tab, etc.)
226    if clean.iter().any(|&b| is_whitespace(b)) {
227        clean.retain(|&b| !is_whitespace(b));
228    }
229
230    decode_owned_clean(&mut clean, out)
231}
232
233/// Decode a clean (no whitespace) owned buffer in-place with SIMD.
234fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
235    if data.is_empty() {
236        return Ok(());
237    }
238    match BASE64_ENGINE.decode_inplace(data) {
239        Ok(decoded) => out.write_all(decoded),
240        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
241    }
242}
243
244/// Decode clean base64 data (no whitespace) from a borrowed slice.
245fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
246    if data.is_empty() {
247        return Ok(());
248    }
249    match BASE64_ENGINE.decode_to_vec(data) {
250        Ok(decoded) => {
251            out.write_all(&decoded)?;
252            Ok(())
253        }
254        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
255    }
256}
257
258/// Strip non-base64 characters (for -i / --ignore-garbage).
259fn strip_non_base64(data: &[u8]) -> Vec<u8> {
260    data.iter()
261        .copied()
262        .filter(|&b| is_base64_char(b))
263        .collect()
264}
265
/// Return `true` for bytes of the base64 alphabet (`A-Z`, `a-z`, `0-9`, `+`,
/// `/`) and for the `=` padding byte.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
271
/// Return `true` for the six ASCII whitespace bytes: space, `\t`, `\n`,
/// vertical tab (0x0b), form feed (0x0c) and `\r`.
///
/// Unlike `u8::is_ascii_whitespace`, this includes vertical tab.
#[inline]
fn is_whitespace(b: u8) -> bool {
    // 0x09..=0x0d is the contiguous run \t, \n, VT, FF, \r.
    b == b' ' || (0x09..=0x0d).contains(&b)
}
277
278/// Stream-encode from a reader to a writer. Used for stdin processing.
279/// Uses 4MB read chunks and batches wrapped output for minimum syscalls.
280/// The caller is expected to provide a suitably buffered or raw fd writer.
281pub fn encode_stream(
282    reader: &mut impl Read,
283    wrap_col: usize,
284    writer: &mut impl Write,
285) -> io::Result<()> {
286    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
287
288    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
289    let mut encode_buf = vec![0u8; encode_buf_size];
290
291    if wrap_col == 0 {
292        // No wrapping: encode each 4MB chunk and write directly.
293        loop {
294            let n = read_full(reader, &mut buf)?;
295            if n == 0 {
296                break;
297            }
298            let enc_len = BASE64_ENGINE.encoded_length(n);
299            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
300            writer.write_all(encoded)?;
301        }
302    } else {
303        // Wrapping: batch wrapped output into a pre-allocated buffer.
304        // For 4MB input at 76-col wrap, wrapped output is ~5.6MB.
305        let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
306        let mut wrap_buf = vec![0u8; max_wrapped];
307        let mut col = 0usize;
308
309        loop {
310            let n = read_full(reader, &mut buf)?;
311            if n == 0 {
312                break;
313            }
314            let enc_len = BASE64_ENGINE.encoded_length(n);
315            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
316
317            // Build wrapped output in wrap_buf, then single write.
318            let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
319            writer.write_all(&wrap_buf[..wp])?;
320        }
321
322        if col > 0 {
323            writer.write_all(b"\n")?;
324        }
325    }
326
327    Ok(())
328}
329
/// Copy `data` into `wrap_buf`, inserting a newline every time the running
/// column reaches `wrap_col`, and return the number of bytes written.
///
/// `col` is the column at which `data` starts. It is updated in place, so
/// wrapping continues seamlessly across consecutive calls. A newline is
/// emitted only when a line is completed, never for a trailing partial line.
/// `wrap_buf` must have room for `data.len()` bytes plus the inserted
/// newlines; otherwise slicing panics.
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut written = 0;
    let mut rest = data;

    while !rest.is_empty() {
        // Take whatever fits on the current line.
        let room = wrap_col - *col;
        let take = room.min(rest.len());
        let (head, tail) = rest.split_at(take);

        wrap_buf[written..written + take].copy_from_slice(head);
        written += take;
        *col += take;
        rest = tail;

        if *col == wrap_col {
            wrap_buf[written] = b'\n';
            written += 1;
            *col = 0;
        }
    }

    written
}
369
370/// Stream-decode from a reader to a writer. Used for stdin processing.
371/// Reads all input, strips whitespace, decodes in one SIMD pass, writes once.
372/// The caller is expected to provide a suitably buffered or raw fd writer.
373pub fn decode_stream(
374    reader: &mut impl Read,
375    ignore_garbage: bool,
376    writer: &mut impl Write,
377) -> io::Result<()> {
378    let mut data = Vec::new();
379    reader.read_to_end(&mut data)?;
380
381    if ignore_garbage {
382        data.retain(|&b| is_base64_char(b));
383    } else {
384        strip_whitespace_inplace(&mut data);
385    }
386
387    decode_owned_clean(&mut data, writer)
388}
389
/// Fill `buf` from `reader`, issuing as many `read` calls as needed.
///
/// Returns the number of bytes stored. This is `buf.len()` unless end of
/// input was reached first. Reads failing with `Interrupted` are retried;
/// any other error is returned immediately. An empty `buf` returns 0
/// without touching `reader`.
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0usize;
    loop {
        if filled == buf.len() {
            return Ok(filled);
        }
        let got = match reader.read(&mut buf[filled..]) {
            Ok(got) => got,
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        };
        if got == 0 {
            return Ok(filled);
        }
        filled += got;
    }
}