coreutils_rs/base64/
core.rs1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4
/// Shared base64 engine (`base64_simd::STANDARD`) used by every encode and
/// decode routine in this file.
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
6
/// Input bytes encoded per block by `encode_stream`: the largest multiple of
/// 3 that fits in 4 MiB, so every full block encodes without `=` padding.
const STREAM_ENCODE_CHUNK: usize = (4 * 1024 * 1024) / 3 * 3;

/// Input bytes encoded per chunk by `encode_no_wrap`; a multiple of 3 for the
/// same reason (only the final chunk may produce padding).
const NOWRAP_CHUNK: usize = STREAM_ENCODE_CHUNK;
12
13pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
16 if data.is_empty() {
17 return Ok(());
18 }
19
20 if wrap_col == 0 {
21 return encode_no_wrap(data, out);
22 }
23
24 encode_wrapped(data, wrap_col, out)
25}
26
27fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
29 let actual_chunk = NOWRAP_CHUNK.min(data.len());
31 let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
32 let mut buf = vec![0u8; enc_max];
33
34 for chunk in data.chunks(NOWRAP_CHUNK) {
35 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
36 let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
37 out.write_all(encoded)?;
38 }
39 Ok(())
40}
41
42fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
44 let bytes_per_line = wrap_col * 3 / 4;
45
46 let lines_per_chunk = (2 * 1024 * 1024) / bytes_per_line.max(1);
48 let chunk_input = lines_per_chunk * bytes_per_line.max(1);
49 let effective_chunk = chunk_input.max(1).min(data.len());
50 let chunk_encoded_max = BASE64_ENGINE.encoded_length(effective_chunk);
51 let mut encode_buf = vec![0u8; chunk_encoded_max];
52 let effective_lines = effective_chunk / bytes_per_line.max(1) + 1;
53 let wrapped_max = (effective_lines + 1) * (wrap_col + 1);
54 let mut wrap_buf = vec![0u8; wrapped_max];
55
56 for chunk in data.chunks(chunk_input.max(1)) {
57 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
58 let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
59 let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
60 out.write_all(&wrap_buf[..wp])?;
61 }
62
63 Ok(())
64}
65
/// Copies base64 text from `encoded` into `wrap_buf`, appending `\n` after
/// every `wrap_col` bytes and after a trailing partial line. Returns the
/// number of bytes written to `wrap_buf`.
///
/// # Panics
///
/// Panics if `wrap_col` is zero, or if `wrap_buf` is shorter than
/// `encoded.len()` plus one byte per output line.
#[inline]
fn wrap_encoded(encoded: &[u8], wrap_col: usize, wrap_buf: &mut [u8]) -> usize {
    let mut wp = 0;
    // Every `line` is `wrap_col` bytes except possibly the last. Slice
    // indexing is bounds-checked, so an undersized `wrap_buf` panics rather
    // than writing past its end; `copy_from_slice` still compiles to memcpy.
    for line in encoded.chunks(wrap_col) {
        let end = wp + line.len();
        wrap_buf[wp..end].copy_from_slice(line);
        wrap_buf[end] = b'\n';
        wp = end + 1;
    }
    wp
}
116
117pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
121 if data.is_empty() {
122 return Ok(());
123 }
124
125 if ignore_garbage {
126 let mut cleaned = strip_non_base64(data);
127 return decode_owned_clean(&mut cleaned, out);
128 }
129
130 decode_stripping_whitespace(data, out)
132}
133
134pub fn decode_owned(
136 data: &mut Vec<u8>,
137 ignore_garbage: bool,
138 out: &mut impl Write,
139) -> io::Result<()> {
140 if data.is_empty() {
141 return Ok(());
142 }
143
144 if ignore_garbage {
145 data.retain(|&b| is_base64_char(b));
146 } else {
147 strip_whitespace_inplace(data);
148 }
149
150 decode_owned_clean(data, out)
151}
152
153fn strip_whitespace_inplace(data: &mut Vec<u8>) {
155 if memchr::memchr(b'\n', data).is_none() {
157 if data.iter().any(|&b| is_whitespace(b)) {
158 data.retain(|&b| !is_whitespace(b));
159 }
160 return;
161 }
162
163 let ptr = data.as_ptr();
165 let mut_ptr = data.as_mut_ptr();
166 let len = data.len();
167 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
168
169 let mut wp = 0usize;
170 let mut rp = 0usize;
171
172 for pos in memchr::memchr_iter(b'\n', slice) {
173 if pos > rp {
174 let seg = pos - rp;
175 unsafe {
176 std::ptr::copy(ptr.add(rp), mut_ptr.add(wp), seg);
177 }
178 wp += seg;
179 }
180 rp = pos + 1;
181 }
182
183 if rp < len {
184 let seg = len - rp;
185 unsafe {
186 std::ptr::copy(ptr.add(rp), mut_ptr.add(wp), seg);
187 }
188 wp += seg;
189 }
190
191 data.truncate(wp);
192
193 if data.iter().any(|&b| is_whitespace(b)) {
195 data.retain(|&b| !is_whitespace(b));
196 }
197}
198
199fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
202 if memchr::memchr2(b'\n', b'\r', data).is_none()
204 && !data.iter().any(|&b| b == b' ' || b == b'\t')
205 {
206 return decode_borrowed_clean(out, data);
208 }
209
210 let mut clean = Vec::with_capacity(data.len());
212 let mut last = 0;
213 for pos in memchr::memchr_iter(b'\n', data) {
214 if pos > last {
215 clean.extend_from_slice(&data[last..pos]);
216 }
217 last = pos + 1;
218 }
219 if last < data.len() {
220 clean.extend_from_slice(&data[last..]);
221 }
222
223 if clean.iter().any(|&b| is_whitespace(b)) {
225 clean.retain(|&b| !is_whitespace(b));
226 }
227
228 decode_owned_clean(&mut clean, out)
229}
230
231fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
233 if data.is_empty() {
234 return Ok(());
235 }
236 match BASE64_ENGINE.decode_inplace(data) {
237 Ok(decoded) => out.write_all(decoded),
238 Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
239 }
240}
241
242fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
244 if data.is_empty() {
245 return Ok(());
246 }
247 match BASE64_ENGINE.decode_to_vec(data) {
248 Ok(decoded) => {
249 out.write_all(&decoded)?;
250 Ok(())
251 }
252 Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
253 }
254}
255
256fn strip_non_base64(data: &[u8]) -> Vec<u8> {
258 data.iter()
259 .copied()
260 .filter(|&b| is_base64_char(b))
261 .collect()
262}
263
/// True for bytes of the standard base64 alphabet (`A-Z`, `a-z`, `0-9`,
/// `+`, `/`) and the `=` padding character.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
269
/// True for ASCII space and the control bytes `\t`, `\n`, VT (0x0b),
/// FF (0x0c) and `\r`, which form the contiguous range 0x09..=0x0d.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b == b' ' || (b'\t'..=b'\r').contains(&b)
}
275
276pub fn encode_stream(
278 reader: &mut impl Read,
279 wrap_col: usize,
280 writer: &mut impl Write,
281) -> io::Result<()> {
282 let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
283
284 let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
285 let mut encode_buf = vec![0u8; encode_buf_size];
286
287 if wrap_col == 0 {
288 loop {
290 let n = read_full(reader, &mut buf)?;
291 if n == 0 {
292 break;
293 }
294 let enc_len = BASE64_ENGINE.encoded_length(n);
295 let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
296 writer.write_all(encoded)?;
297 }
298 } else {
299 let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
301 let mut wrap_buf = vec![0u8; max_wrapped];
302 let mut col = 0usize;
303
304 loop {
305 let n = read_full(reader, &mut buf)?;
306 if n == 0 {
307 break;
308 }
309 let enc_len = BASE64_ENGINE.encoded_length(n);
310 let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
311
312 let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
314 writer.write_all(&wrap_buf[..wp])?;
315 }
316
317 if col > 0 {
318 writer.write_all(b"\n")?;
319 }
320 }
321
322 Ok(())
323}
324
/// Copies `data` into `wrap_buf`, inserting `\n` whenever the running output
/// column `col` reaches `wrap_col`. `col` is updated so wrapping continues
/// correctly across successive calls. Returns the number of bytes written.
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut written = 0;
    let mut rest = data;

    while !rest.is_empty() {
        // Fill the current line as far as the remaining input allows.
        let take = (wrap_col - *col).min(rest.len());
        let (head, tail) = rest.split_at(take);
        wrap_buf[written..written + take].copy_from_slice(head);
        written += take;
        *col += take;

        // Line is full: terminate it and start a new one.
        if *col == wrap_col {
            wrap_buf[written] = b'\n';
            written += 1;
            *col = 0;
        }
        rest = tail;
    }

    written
}
363
364pub fn decode_stream(
366 reader: &mut impl Read,
367 ignore_garbage: bool,
368 writer: &mut impl Write,
369) -> io::Result<()> {
370 const READ_CHUNK: usize = 4 * 1024 * 1024;
371 let mut buf = vec![0u8; READ_CHUNK];
372 let mut clean = Vec::with_capacity(READ_CHUNK);
373 let mut carry: Vec<u8> = Vec::with_capacity(4);
374
375 loop {
376 let n = read_full(reader, &mut buf)?;
377 if n == 0 {
378 break;
379 }
380
381 clean.clear();
383 clean.extend_from_slice(&carry);
384 carry.clear();
385
386 let chunk = &buf[..n];
387 if ignore_garbage {
388 clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
389 } else {
390 let mut last = 0;
392 for pos in memchr::memchr_iter(b'\n', chunk) {
393 if pos > last {
394 clean.extend_from_slice(&chunk[last..pos]);
395 }
396 last = pos + 1;
397 }
398 if last < n {
399 clean.extend_from_slice(&chunk[last..]);
400 }
401 if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
403 clean.retain(|&b| !is_whitespace(b));
404 }
405 }
406
407 let is_last = n < READ_CHUNK;
408
409 if is_last {
410 decode_owned_clean(&mut clean, writer)?;
412 } else {
413 let decode_len = (clean.len() / 4) * 4;
415 if decode_len < clean.len() {
416 carry.extend_from_slice(&clean[decode_len..]);
417 }
418 if decode_len > 0 {
419 clean.truncate(decode_len);
420 decode_owned_clean(&mut clean, writer)?;
421 }
422 }
423 }
424
425 if !carry.is_empty() {
427 decode_owned_clean(&mut carry, writer)?;
428 }
429
430 Ok(())
431}
432
/// Reads from `reader` until `buf` is full or the reader reports EOF
/// (`Ok(0)`), retrying reads interrupted by a signal.
///
/// Returns the number of bytes placed in `buf`; a value smaller than
/// `buf.len()` means EOF was reached. Any other read error is returned.
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0;
    while filled < buf.len() {
        let got = match reader.read(&mut buf[filled..]) {
            Ok(0) => return Ok(filled),
            Ok(n) => n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        filled += got;
    }
    Ok(filled)
}