Skip to main content

coreutils_rs/base64/
core.rs

1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4
/// Codec used by every encode and decode routine in this file: the `STANDARD`
/// configuration exported by the `base64_simd` crate.
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
6
/// Input bytes consumed per read when stream-encoding: 24 MiB rounded down to a
/// multiple of 3, so a completely filled chunk encodes without padding.
const STREAM_ENCODE_CHUNK: usize = (24 * 1024 * 1024 / 3) * 3;

/// Input bytes encoded per pass when no line wrapping is requested: 24 MiB
/// rounded down to a multiple of 3.
const NOWRAP_CHUNK: usize = (24 * 1024 * 1024 / 3) * 3;
12
13/// Encode data and write to output with line wrapping.
14/// Uses SIMD encoding with fused encode+wrap for maximum throughput.
15pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
16    if data.is_empty() {
17        return Ok(());
18    }
19
20    if wrap_col == 0 {
21        return encode_no_wrap(data, out);
22    }
23
24    encode_wrapped(data, wrap_col, out)
25}
26
27/// Encode without wrapping — sequential SIMD encoding in chunks.
28fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
29    let actual_chunk = NOWRAP_CHUNK.min(data.len());
30    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
31    let mut buf = vec![0u8; enc_max];
32
33    for chunk in data.chunks(NOWRAP_CHUNK) {
34        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
35        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
36        out.write_all(encoded)?;
37    }
38    Ok(())
39}
40
41/// Encode with line wrapping — fused encode+wrap in a single output buffer.
42/// Encodes aligned input chunks, then interleaves newlines directly into
43/// a single output buffer, eliminating the separate wrap pass.
44fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
45    // Calculate bytes_per_line: input bytes that produce exactly wrap_col encoded chars.
46    // For default wrap_col=76: 76*3/4 = 57 bytes per line.
47    let bytes_per_line = wrap_col * 3 / 4;
48    if bytes_per_line == 0 {
49        // Degenerate case: wrap_col < 4, fall back to byte-at-a-time
50        return encode_wrapped_small(data, wrap_col, out);
51    }
52
53    // Process in chunks aligned to bytes_per_line for complete output lines.
54    // 16MB input -> ~21MB encoded -> ~21.3MB with newlines
55    // Large chunks ensure typical files (≤10MB) process in a single pass.
56    let lines_per_chunk = (16 * 1024 * 1024) / bytes_per_line;
57    let max_input_chunk = (lines_per_chunk * bytes_per_line).max(bytes_per_line);
58    let input_chunk = max_input_chunk.min(data.len());
59
60    let enc_max = BASE64_ENGINE.encoded_length(input_chunk);
61    let mut encode_buf = vec![0u8; enc_max];
62
63    // Fused output buffer: holds encoded data with newlines interleaved
64    let max_lines = enc_max / wrap_col + 2;
65    let fused_max = enc_max + max_lines;
66    let mut fused_buf = vec![0u8; fused_max];
67
68    for chunk in data.chunks(max_input_chunk.max(1)) {
69        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
70        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
71
72        // Fuse: copy encoded data into fused_buf with newlines interleaved
73        let wp = fuse_wrap(encoded, wrap_col, &mut fused_buf);
74        out.write_all(&fused_buf[..wp])?;
75    }
76
77    Ok(())
78}
79
/// Copy `encoded` into `out_buf`, inserting a newline after every `wrap_col`
/// bytes and after a shorter final line.
///
/// Full lines are copied four at a time with `ptr::copy_nonoverlapping`; the
/// remaining lines (at most three full ones plus a partial one) use safe slice
/// copies.
///
/// A `wrap_col` of 0 copies the bytes unchanged with no newline, matching
/// `encode_to_writer`, where a width of 0 disables wrapping.
///
/// Returns the number of bytes written to `out_buf`.
///
/// # Panics
/// Panics if `out_buf` cannot hold the wrapped output, i.e. `encoded.len()`
/// bytes plus one newline per (full or partial) line.
#[inline]
fn fuse_wrap(encoded: &[u8], wrap_col: usize, out_buf: &mut [u8]) -> usize {
    if wrap_col == 0 {
        out_buf[..encoded.len()].copy_from_slice(encoded);
        return encoded.len();
    }

    // Validate the destination once, up front, so the raw-pointer writes below
    // are provably in bounds.
    let line_count = encoded.len() / wrap_col + usize::from(encoded.len() % wrap_col != 0);
    let required = encoded.len() + line_count;
    assert!(
        out_buf.len() >= required,
        "fuse_wrap: output buffer too small ({} < {})",
        out_buf.len(),
        required
    );

    let mut rp = 0;
    let mut wp = 0;

    // Unrolled: process 4 lines per iteration. Skipped entirely when
    // 4 * wrap_col would overflow (no input can be that long anyway).
    if let Some(quad) = wrap_col.checked_mul(4) {
        let line_out = wrap_col + 1; // wrap_col data bytes + 1 newline
        while encoded.len() - rp >= quad {
            // SAFETY: `rp + 4 * wrap_col <= encoded.len()`, so all four source
            // lines are in bounds. `wp == rp + rp / wrap_col` holds on entry, so
            // the four output lines end at `wp + 4 * line_out <= required`,
            // which the assert above proved fits in `out_buf`. Source and
            // destination cannot overlap because one is a shared borrow and the
            // other an exclusive borrow.
            unsafe {
                let src = encoded.as_ptr().add(rp);
                let dst = out_buf.as_mut_ptr().add(wp);
                for k in 0..4 {
                    std::ptr::copy_nonoverlapping(
                        src.add(k * wrap_col),
                        dst.add(k * line_out),
                        wrap_col,
                    );
                    *dst.add(k * line_out + wrap_col) = b'\n';
                }
            }
            rp += quad;
            wp += 4 * line_out;
        }
    }

    // Remaining full lines and the partial last line, each followed by a newline.
    for line in encoded[rp..].chunks(wrap_col) {
        let end = wp + line.len();
        out_buf[wp..end].copy_from_slice(line);
        out_buf[end] = b'\n';
        wp = end + 1;
    }

    wp
}
137
138/// Fallback for very small wrap columns (< 4 chars).
139fn encode_wrapped_small(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
140    let enc_max = BASE64_ENGINE.encoded_length(data.len());
141    let mut buf = vec![0u8; enc_max];
142    let encoded = BASE64_ENGINE.encode(data, buf[..enc_max].as_out());
143
144    let wc = wrap_col.max(1);
145    for line in encoded.chunks(wc) {
146        out.write_all(line)?;
147        out.write_all(b"\n")?;
148    }
149    Ok(())
150}
151
152/// Decode base64 data and write to output (borrows data, allocates clean buffer).
153/// When `ignore_garbage` is true, strip all non-base64 characters.
154/// When false, only strip whitespace (standard behavior).
155pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
156    if data.is_empty() {
157        return Ok(());
158    }
159
160    if ignore_garbage {
161        let mut cleaned = strip_non_base64(data);
162        return decode_clean_slice(&mut cleaned, out);
163    }
164
165    // Fast path: single-pass strip + decode
166    decode_stripping_whitespace(data, out)
167}
168
169/// Decode base64 from an owned Vec (in-place whitespace strip + decode).
170pub fn decode_owned(
171    data: &mut Vec<u8>,
172    ignore_garbage: bool,
173    out: &mut impl Write,
174) -> io::Result<()> {
175    if data.is_empty() {
176        return Ok(());
177    }
178
179    if ignore_garbage {
180        data.retain(|&b| is_base64_char(b));
181    } else {
182        strip_whitespace_inplace(data);
183    }
184
185    decode_clean_slice(data, out)
186}
187
188/// Strip all whitespace from a Vec in-place using SIMD memchr.
189/// Single-pass compaction: scan for newlines (most common whitespace in base64),
190/// compact non-newline segments, then handle rare other whitespace.
191fn strip_whitespace_inplace(data: &mut Vec<u8>) {
192    // Quick check: no newlines at all?
193    if memchr::memchr(b'\n', data).is_none() {
194        if data.iter().any(|&b| is_whitespace(b)) {
195            data.retain(|&b| !is_whitespace(b));
196        }
197        return;
198    }
199
200    // In-place compaction using raw pointers
201    let ptr = data.as_ptr();
202    let mut_ptr = data.as_mut_ptr();
203    let len = data.len();
204    let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
205
206    let mut wp = 0usize;
207    let mut rp = 0usize;
208
209    for pos in memchr::memchr_iter(b'\n', slice) {
210        if pos > rp {
211            let seg = pos - rp;
212            unsafe {
213                std::ptr::copy(ptr.add(rp), mut_ptr.add(wp), seg);
214            }
215            wp += seg;
216        }
217        rp = pos + 1;
218    }
219
220    if rp < len {
221        let seg = len - rp;
222        unsafe {
223            std::ptr::copy(ptr.add(rp), mut_ptr.add(wp), seg);
224        }
225        wp += seg;
226    }
227
228    data.truncate(wp);
229
230    // Handle rare non-newline whitespace (CR, tab, etc.)
231    if data.iter().any(|&b| is_whitespace(b)) {
232        data.retain(|&b| !is_whitespace(b));
233    }
234}
235
236/// Decode by stripping whitespace and decoding in a single fused pass.
237/// For data with no whitespace, decodes directly without any copy.
238fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
239    // Quick check: any whitespace at all?
240    if memchr::memchr2(b'\n', b'\r', data).is_none()
241        && !data.iter().any(|&b| b == b' ' || b == b'\t')
242    {
243        // No whitespace — decode directly from borrowed data
244        return decode_borrowed_clean(out, data);
245    }
246
247    // Fused strip+collect: use SIMD memchr to find newlines,
248    // copy non-newline segments directly into clean buffer.
249    let mut clean = Vec::with_capacity(data.len());
250    let mut last = 0;
251    for pos in memchr::memchr_iter(b'\n', data) {
252        if pos > last {
253            clean.extend_from_slice(&data[last..pos]);
254        }
255        last = pos + 1;
256    }
257    if last < data.len() {
258        clean.extend_from_slice(&data[last..]);
259    }
260
261    // Handle rare non-newline whitespace (CR, tab, etc.)
262    if clean.iter().any(|&b| is_whitespace(b)) {
263        clean.retain(|&b| !is_whitespace(b));
264    }
265
266    decode_clean_slice(&mut clean, out)
267}
268
269/// Decode a clean (no whitespace) buffer in-place with SIMD.
270fn decode_clean_slice(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
271    if data.is_empty() {
272        return Ok(());
273    }
274    match BASE64_ENGINE.decode_inplace(data) {
275        Ok(decoded) => out.write_all(decoded),
276        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
277    }
278}
279
280/// Decode clean base64 data (no whitespace) from a borrowed slice.
281fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
282    if data.is_empty() {
283        return Ok(());
284    }
285    match BASE64_ENGINE.decode_to_vec(data) {
286        Ok(decoded) => {
287            out.write_all(&decoded)?;
288            Ok(())
289        }
290        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
291    }
292}
293
294/// Strip non-base64 characters (for -i / --ignore-garbage).
295fn strip_non_base64(data: &[u8]) -> Vec<u8> {
296    data.iter()
297        .copied()
298        .filter(|&b| is_base64_char(b))
299        .collect()
300}
301
/// True for the bytes `A-Z`, `a-z`, `0-9`, `+`, `/` and the padding byte `=`.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
307
/// True for space and for the contiguous control range tab (0x09) through
/// carriage return (0x0d): tab, line feed, vertical tab, form feed, CR.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b == b' ' || (b'\t'..=b'\r').contains(&b)
}
313
314/// Stream-encode from a reader to a writer. Used for stdin processing.
315pub fn encode_stream(
316    reader: &mut impl Read,
317    wrap_col: usize,
318    writer: &mut impl Write,
319) -> io::Result<()> {
320    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
321
322    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
323    let mut encode_buf = vec![0u8; encode_buf_size];
324
325    if wrap_col == 0 {
326        // No wrapping: encode each chunk and write directly.
327        loop {
328            let n = read_full(reader, &mut buf)?;
329            if n == 0 {
330                break;
331            }
332            let enc_len = BASE64_ENGINE.encoded_length(n);
333            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
334            writer.write_all(encoded)?;
335        }
336    } else {
337        // Wrapping: fused encode+wrap into a single output buffer.
338        let max_fused = encode_buf_size + (encode_buf_size / wrap_col + 2);
339        let mut fused_buf = vec![0u8; max_fused];
340        let mut col = 0usize;
341
342        loop {
343            let n = read_full(reader, &mut buf)?;
344            if n == 0 {
345                break;
346            }
347            let enc_len = BASE64_ENGINE.encoded_length(n);
348            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
349
350            // Build fused output in a single buffer, then one write.
351            let wp = build_fused_output(encoded, wrap_col, &mut col, &mut fused_buf);
352            writer.write_all(&fused_buf[..wp])?;
353        }
354
355        if col > 0 {
356            writer.write_all(b"\n")?;
357        }
358    }
359
360    Ok(())
361}
362
/// Copy `data` into `out_buf`, inserting a newline each time the running
/// column reaches `wrap_col`.
///
/// `col` is the number of characters already on the current line; it is
/// updated so the next call continues the same line. No newline is added after
/// a line that is still incomplete when `data` runs out.
///
/// Returns the number of bytes written to `out_buf`.
#[inline]
fn build_fused_output(data: &[u8], wrap_col: usize, col: &mut usize, out_buf: &mut [u8]) -> usize {
    let mut pending = data;
    let mut written = 0;

    while !pending.is_empty() {
        // Take as much as fits on the current line, or everything that is left.
        let room = wrap_col - *col;
        let take = pending.len().min(room);
        let (piece, rest) = pending.split_at(take);

        out_buf[written..written + take].copy_from_slice(piece);
        written += take;
        *col += take;

        // Line is full: terminate it and start a new one.
        if *col == wrap_col {
            out_buf[written] = b'\n';
            written += 1;
            *col = 0;
        }

        pending = rest;
    }

    written
}
396
397/// Stream-decode from a reader to a writer. Used for stdin processing.
398/// Fused single-pass: read chunk → strip whitespace in-place → decode immediately.
399pub fn decode_stream(
400    reader: &mut impl Read,
401    ignore_garbage: bool,
402    writer: &mut impl Write,
403) -> io::Result<()> {
404    const READ_CHUNK: usize = 4 * 1024 * 1024;
405    let mut buf = vec![0u8; READ_CHUNK];
406    let mut clean = Vec::with_capacity(READ_CHUNK);
407    let mut carry: Vec<u8> = Vec::with_capacity(4);
408
409    loop {
410        let n = read_full(reader, &mut buf)?;
411        if n == 0 {
412            break;
413        }
414
415        // Fused: build clean buffer from carry-over + stripped chunk in one pass
416        clean.clear();
417        clean.extend_from_slice(&carry);
418        carry.clear();
419
420        let chunk = &buf[..n];
421        if ignore_garbage {
422            clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
423        } else {
424            // Strip newlines using SIMD memchr — single pass
425            let mut last = 0;
426            for pos in memchr::memchr_iter(b'\n', chunk) {
427                if pos > last {
428                    clean.extend_from_slice(&chunk[last..pos]);
429                }
430                last = pos + 1;
431            }
432            if last < n {
433                clean.extend_from_slice(&chunk[last..]);
434            }
435            // Handle rare non-newline whitespace
436            if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
437                clean.retain(|&b| !is_whitespace(b));
438            }
439        }
440
441        let is_last = n < READ_CHUNK;
442
443        if is_last {
444            // Last chunk: decode everything (including padding)
445            decode_clean_slice(&mut clean, writer)?;
446        } else {
447            // Save incomplete base64 quadruplet for next iteration
448            let decode_len = (clean.len() / 4) * 4;
449            if decode_len < clean.len() {
450                carry.extend_from_slice(&clean[decode_len..]);
451            }
452            if decode_len > 0 {
453                clean.truncate(decode_len);
454                decode_clean_slice(&mut clean, writer)?;
455            }
456        }
457    }
458
459    // Handle any remaining carry-over bytes
460    if !carry.is_empty() {
461        decode_clean_slice(&mut carry, writer)?;
462    }
463
464    Ok(())
465}
466
/// Read from `reader` until `buf` is full or the reader reports end of input.
///
/// Returns the number of bytes placed in `buf`; a value smaller than
/// `buf.len()` means end of input was reached. When the first `read` fills the
/// buffer the loop body runs exactly once.
///
/// # Errors
/// Returns the first read error other than `ErrorKind::Interrupted`.
/// Interrupted reads are retried, including the very first one.
#[inline]
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut total = 0;
    while total < buf.len() {
        match reader.read(&mut buf[total..]) {
            // End of input.
            Ok(0) => break,
            // Partial read (pipes, slow devices): keep filling.
            Ok(n) => total += n,
            // A signal interrupted the call before any data arrived: retry.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(total)
}