Skip to main content

coreutils_rs/base64/
core.rs

1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4
/// Shared codec instance used by every encode/decode routine in this file:
/// a reference to `base64_simd::STANDARD`.
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
6
/// Input bytes consumed per iteration when stream-encoding: 4 MiB rounded
/// down to a multiple of 3, so every full chunk encodes without padding.
const STREAM_ENCODE_CHUNK: usize = (4 * 1024 * 1024 / 3) * 3;

/// Input bytes consumed per iteration when encoding without line wrapping:
/// 4 MiB rounded down to a multiple of 3.
const NOWRAP_CHUNK: usize = (4 * 1024 * 1024 / 3) * 3;
12
13/// Encode data and write to output with line wrapping.
14/// Uses SIMD encoding with reusable buffers for maximum throughput.
15pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
16    if data.is_empty() {
17        return Ok(());
18    }
19
20    if wrap_col == 0 {
21        return encode_no_wrap(data, out);
22    }
23
24    encode_wrapped(data, wrap_col, out)
25}
26
27/// Encode without wrapping — sequential SIMD encoding in chunks.
28fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
29    // Size buffer for actual data, not max chunk size.
30    let actual_chunk = NOWRAP_CHUNK.min(data.len());
31    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
32    let mut buf = vec![0u8; enc_max];
33
34    for chunk in data.chunks(NOWRAP_CHUNK) {
35        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
36        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
37        out.write_all(encoded)?;
38    }
39    Ok(())
40}
41
42/// Encode with line wrapping — sequential SIMD encoding in chunks.
43fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
44    let bytes_per_line = wrap_col * 3 / 4;
45
46    // Sequential path: 2MB chunks fit in L2 cache.
47    let lines_per_chunk = (2 * 1024 * 1024) / bytes_per_line.max(1);
48    let chunk_input = lines_per_chunk * bytes_per_line.max(1);
49    let effective_chunk = chunk_input.max(1).min(data.len());
50    let chunk_encoded_max = BASE64_ENGINE.encoded_length(effective_chunk);
51    let mut encode_buf = vec![0u8; chunk_encoded_max];
52    let effective_lines = effective_chunk / bytes_per_line.max(1) + 1;
53    let wrapped_max = (effective_lines + 1) * (wrap_col + 1);
54    let mut wrap_buf = vec![0u8; wrapped_max];
55
56    for chunk in data.chunks(chunk_input.max(1)) {
57        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
58        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
59        let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
60        out.write_all(&wrap_buf[..wp])?;
61    }
62
63    Ok(())
64}
65
/// Copy `encoded` into `wrap_buf`, inserting a newline after every
/// `wrap_col` bytes and after a trailing partial line.
///
/// Returns the number of bytes written to `wrap_buf`, which is
/// `encoded.len()` plus one newline per started line.
///
/// A `wrap_col` of 0 is treated as "no wrapping": the input is copied
/// unchanged and no newline is added.
///
/// # Panics
/// Panics if `wrap_buf` is too small to hold the wrapped output. The required
/// size is checked once up front, so an undersized buffer can never be
/// written out of bounds.
#[inline]
fn wrap_encoded(encoded: &[u8], wrap_col: usize, wrap_buf: &mut [u8]) -> usize {
    if wrap_col == 0 {
        wrap_buf[..encoded.len()].copy_from_slice(encoded);
        return encoded.len();
    }

    // One newline per started line (ceiling division).
    let line_count = encoded.len() / wrap_col + usize::from(encoded.len() % wrap_col != 0);
    // Single bounds check for the whole output region.
    let dst = &mut wrap_buf[..encoded.len() + line_count];

    let mut wp = 0;
    for line in encoded.chunks(wrap_col) {
        let end = wp + line.len();
        dst[wp..end].copy_from_slice(line);
        dst[end] = b'\n';
        wp = end + 1;
    }

    wp
}
116
117/// Decode base64 data and write to output (borrows data, allocates clean buffer).
118/// When `ignore_garbage` is true, strip all non-base64 characters.
119/// When false, only strip whitespace (standard behavior).
120pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
121    if data.is_empty() {
122        return Ok(());
123    }
124
125    if ignore_garbage {
126        let mut cleaned = strip_non_base64(data);
127        return decode_owned_clean(&mut cleaned, out);
128    }
129
130    // Fast path: strip newlines with memchr (SIMD), then SIMD decode
131    decode_stripping_whitespace(data, out)
132}
133
134/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
135pub fn decode_owned(
136    data: &mut Vec<u8>,
137    ignore_garbage: bool,
138    out: &mut impl Write,
139) -> io::Result<()> {
140    if data.is_empty() {
141        return Ok(());
142    }
143
144    if ignore_garbage {
145        data.retain(|&b| is_base64_char(b));
146    } else {
147        strip_whitespace_inplace(data);
148    }
149
150    decode_owned_clean(data, out)
151}
152
153/// Strip all whitespace from a Vec in-place using SIMD memchr for newlines.
154fn strip_whitespace_inplace(data: &mut Vec<u8>) {
155    // Quick check for newlines using SIMD
156    if memchr::memchr(b'\n', data).is_none() {
157        if data.iter().any(|&b| is_whitespace(b)) {
158            data.retain(|&b| !is_whitespace(b));
159        }
160        return;
161    }
162
163    // In-place compaction using raw pointers to avoid borrow conflict.
164    let ptr = data.as_ptr();
165    let mut_ptr = data.as_mut_ptr();
166    let len = data.len();
167    let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
168
169    let mut wp = 0usize;
170    let mut rp = 0usize;
171
172    for pos in memchr::memchr_iter(b'\n', slice) {
173        if pos > rp {
174            let seg = pos - rp;
175            unsafe {
176                std::ptr::copy(ptr.add(rp), mut_ptr.add(wp), seg);
177            }
178            wp += seg;
179        }
180        rp = pos + 1;
181    }
182
183    if rp < len {
184        let seg = len - rp;
185        unsafe {
186            std::ptr::copy(ptr.add(rp), mut_ptr.add(wp), seg);
187        }
188        wp += seg;
189    }
190
191    data.truncate(wp);
192
193    // Handle rare non-newline whitespace (CR, tab, etc.)
194    if data.iter().any(|&b| is_whitespace(b)) {
195        data.retain(|&b| !is_whitespace(b));
196    }
197}
198
199/// Decode by stripping all whitespace from the entire input at once,
200/// then performing a single SIMD decode pass.
201fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
202    // Quick check: any whitespace at all?
203    if memchr::memchr2(b'\n', b'\r', data).is_none()
204        && !data.iter().any(|&b| b == b' ' || b == b'\t')
205    {
206        // No whitespace — decode directly from borrowed data
207        return decode_borrowed_clean(out, data);
208    }
209
210    // Strip newlines from entire input in a single pass using SIMD memchr.
211    let mut clean = Vec::with_capacity(data.len());
212    let mut last = 0;
213    for pos in memchr::memchr_iter(b'\n', data) {
214        if pos > last {
215            clean.extend_from_slice(&data[last..pos]);
216        }
217        last = pos + 1;
218    }
219    if last < data.len() {
220        clean.extend_from_slice(&data[last..]);
221    }
222
223    // Handle rare non-newline whitespace (CR, tab, etc.)
224    if clean.iter().any(|&b| is_whitespace(b)) {
225        clean.retain(|&b| !is_whitespace(b));
226    }
227
228    decode_owned_clean(&mut clean, out)
229}
230
231/// Decode a clean (no whitespace) owned buffer in-place with SIMD.
232fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
233    if data.is_empty() {
234        return Ok(());
235    }
236    match BASE64_ENGINE.decode_inplace(data) {
237        Ok(decoded) => out.write_all(decoded),
238        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
239    }
240}
241
242/// Decode clean base64 data (no whitespace) from a borrowed slice.
243fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
244    if data.is_empty() {
245        return Ok(());
246    }
247    match BASE64_ENGINE.decode_to_vec(data) {
248        Ok(decoded) => {
249            out.write_all(&decoded)?;
250            Ok(())
251        }
252        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
253    }
254}
255
256/// Strip non-base64 characters (for -i / --ignore-garbage).
257fn strip_non_base64(data: &[u8]) -> Vec<u8> {
258    data.iter()
259        .copied()
260        .filter(|&b| is_base64_char(b))
261        .collect()
262}
263
/// True for bytes of the standard base64 alphabet (`A-Z`, `a-z`, `0-9`,
/// `+`, `/`) and for the `=` padding character.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
269
/// True for ASCII whitespace: space, or any byte from tab (0x09) through
/// carriage return (0x0d), i.e. tab, newline, vertical tab, form feed, CR.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b == b' ' || (b'\t'..=b'\r').contains(&b)
}
275
276/// Stream-encode from a reader to a writer. Used for stdin processing.
277pub fn encode_stream(
278    reader: &mut impl Read,
279    wrap_col: usize,
280    writer: &mut impl Write,
281) -> io::Result<()> {
282    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
283
284    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
285    let mut encode_buf = vec![0u8; encode_buf_size];
286
287    if wrap_col == 0 {
288        // No wrapping: encode each 4MB chunk and write directly.
289        loop {
290            let n = read_full(reader, &mut buf)?;
291            if n == 0 {
292                break;
293            }
294            let enc_len = BASE64_ENGINE.encoded_length(n);
295            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
296            writer.write_all(encoded)?;
297        }
298    } else {
299        // Wrapping: batch wrapped output into a pre-allocated buffer.
300        let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
301        let mut wrap_buf = vec![0u8; max_wrapped];
302        let mut col = 0usize;
303
304        loop {
305            let n = read_full(reader, &mut buf)?;
306            if n == 0 {
307                break;
308            }
309            let enc_len = BASE64_ENGINE.encoded_length(n);
310            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
311
312            // Build wrapped output in wrap_buf, then single write.
313            let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
314            writer.write_all(&wrap_buf[..wp])?;
315        }
316
317        if col > 0 {
318            writer.write_all(b"\n")?;
319        }
320    }
321
322    Ok(())
323}
324
/// Copy `data` into `wrap_buf`, inserting a newline each time a line reaches
/// `wrap_col` characters.
///
/// `col` is the number of characters already on the current line. It is
/// updated so the next call continues the same line. Returns how many bytes
/// were written to `wrap_buf`.
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut pending = data;
    let mut written = 0;

    while !pending.is_empty() {
        let room = wrap_col - *col;

        if pending.len() > room {
            // Fill up the current line, end it, and carry on with the rest.
            let (head, rest) = pending.split_at(room);
            wrap_buf[written..written + room].copy_from_slice(head);
            written += room;
            wrap_buf[written] = b'\n';
            written += 1;
            *col = 0;
            pending = rest;
            continue;
        }

        // Everything left fits on the current line.
        let count = pending.len();
        wrap_buf[written..written + count].copy_from_slice(pending);
        written += count;
        *col += count;
        if *col == wrap_col {
            wrap_buf[written] = b'\n';
            written += 1;
            *col = 0;
        }
        break;
    }

    written
}
363
364/// Stream-decode from a reader to a writer. Used for stdin processing.
365pub fn decode_stream(
366    reader: &mut impl Read,
367    ignore_garbage: bool,
368    writer: &mut impl Write,
369) -> io::Result<()> {
370    const READ_CHUNK: usize = 4 * 1024 * 1024;
371    let mut buf = vec![0u8; READ_CHUNK];
372    let mut clean = Vec::with_capacity(READ_CHUNK);
373    let mut carry: Vec<u8> = Vec::with_capacity(4);
374
375    loop {
376        let n = read_full(reader, &mut buf)?;
377        if n == 0 {
378            break;
379        }
380
381        // Build clean buffer: carry-over + stripped chunk
382        clean.clear();
383        clean.extend_from_slice(&carry);
384        carry.clear();
385
386        let chunk = &buf[..n];
387        if ignore_garbage {
388            clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
389        } else {
390            // Strip newlines using SIMD memchr
391            let mut last = 0;
392            for pos in memchr::memchr_iter(b'\n', chunk) {
393                if pos > last {
394                    clean.extend_from_slice(&chunk[last..pos]);
395                }
396                last = pos + 1;
397            }
398            if last < n {
399                clean.extend_from_slice(&chunk[last..]);
400            }
401            // Handle rare non-newline whitespace
402            if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
403                clean.retain(|&b| !is_whitespace(b));
404            }
405        }
406
407        let is_last = n < READ_CHUNK;
408
409        if is_last {
410            // Last chunk: decode everything (including padding)
411            decode_owned_clean(&mut clean, writer)?;
412        } else {
413            // Save incomplete base64 quadruplet for next iteration
414            let decode_len = (clean.len() / 4) * 4;
415            if decode_len < clean.len() {
416                carry.extend_from_slice(&clean[decode_len..]);
417            }
418            if decode_len > 0 {
419                clean.truncate(decode_len);
420                decode_owned_clean(&mut clean, writer)?;
421            }
422        }
423    }
424
425    // Handle any remaining carry-over bytes
426    if !carry.is_empty() {
427        decode_owned_clean(&mut carry, writer)?;
428    }
429
430    Ok(())
431}
432
/// Fill `buf` from `reader`, issuing further reads after short ones.
///
/// Stops when `buf` is full or the reader reports end of input, and returns
/// the number of bytes stored. `Interrupted` errors are retried; any other
/// error is returned to the caller.
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0;
    loop {
        let rest = &mut buf[filled..];
        if rest.is_empty() {
            return Ok(filled);
        }
        match reader.read(rest) {
            Ok(0) => return Ok(filled),
            Ok(got) => filled += got,
            Err(err) => {
                if err.kind() != io::ErrorKind::Interrupted {
                    return Err(err);
                }
            }
        }
    }
}