1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4use rayon::prelude::*;
5
/// Base64 engine shared by every encode and decode routine in this file
/// (the `base64_simd::STANDARD` configuration).
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
7
/// Read-buffer size used by `encode_stream`: 12 MiB rounded down to a multiple
/// of 3 so that every full chunk encodes without padding.
const STREAM_ENCODE_CHUNK: usize = (12 * 1024 * 1024 / 3) * 3;

/// Input chunk size used by the sequential path of `encode_no_wrap`: 8 MiB
/// rounded down to a multiple of 3 so that only the final chunk can be padded.
const NOWRAP_CHUNK: usize = (8 * 1024 * 1024 / 3) * 3;

/// Inputs of at least this many bytes (1 MiB) are split across rayon workers.
const PARALLEL_ENCODE_THRESHOLD: usize = 1 << 20;
16
17pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
20 if data.is_empty() {
21 return Ok(());
22 }
23
24 if wrap_col == 0 {
25 return encode_no_wrap(data, out);
26 }
27
28 encode_wrapped(data, wrap_col, out)
29}
30
31fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
33 if data.len() >= PARALLEL_ENCODE_THRESHOLD {
34 let num_threads = rayon::current_num_threads().max(1);
36 let raw_chunk = (data.len() + num_threads - 1) / num_threads;
37 let chunk_size = ((raw_chunk + 2) / 3) * 3;
39
40 let encoded_chunks: Vec<Vec<u8>> = data
41 .par_chunks(chunk_size)
42 .map(|chunk| {
43 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
44 let mut buf = vec![0u8; enc_len];
45 let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
46 let len = encoded.len();
47 buf.truncate(len);
48 buf
49 })
50 .collect();
51
52 for chunk in &encoded_chunks {
53 out.write_all(chunk)?;
54 }
55 return Ok(());
56 }
57
58 let enc_max = BASE64_ENGINE.encoded_length(NOWRAP_CHUNK);
59 let mut buf = vec![0u8; enc_max];
60
61 for chunk in data.chunks(NOWRAP_CHUNK) {
62 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
63 let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
64 out.write_all(encoded)?;
65 }
66 Ok(())
67}
68
69fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
71 let bytes_per_line = wrap_col * 3 / 4;
72
73 if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line > 0 {
74 let num_threads = rayon::current_num_threads().max(1);
77 let lines_per_thread = ((data.len() / bytes_per_line) + num_threads - 1) / num_threads;
78 let chunk_input = (lines_per_thread * bytes_per_line).max(bytes_per_line);
79
80 let wrapped_chunks: Vec<Vec<u8>> = data
81 .par_chunks(chunk_input)
82 .map(|chunk| {
83 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
84 let mut encode_buf = vec![0u8; enc_len];
85 let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
86
87 let line_out = wrap_col + 1;
89 let max_lines = (encoded.len() + wrap_col - 1) / wrap_col + 1;
90 let mut wrap_buf = vec![0u8; max_lines * line_out];
91 let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
92 wrap_buf.truncate(wp);
93 wrap_buf
94 })
95 .collect();
96
97 for chunk in &wrapped_chunks {
98 out.write_all(chunk)?;
99 }
100 return Ok(());
101 }
102
103 let lines_per_chunk = (8 * 1024 * 1024) / bytes_per_line.max(1);
105 let chunk_input = lines_per_chunk * bytes_per_line.max(1);
106 let chunk_encoded_max = BASE64_ENGINE.encoded_length(chunk_input.max(1));
107 let mut encode_buf = vec![0u8; chunk_encoded_max];
108 let wrapped_max = (lines_per_chunk + 1) * (wrap_col + 1);
109 let mut wrap_buf = vec![0u8; wrapped_max];
110
111 for chunk in data.chunks(chunk_input.max(1)) {
112 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
113 let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
114 let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
115 out.write_all(&wrap_buf[..wp])?;
116 }
117
118 Ok(())
119}
120
/// Copies `encoded` into `wrap_buf`, inserting `\n` after every `wrap_col`
/// bytes and after a trailing partial line. Returns the number of bytes
/// written to `wrap_buf`.
///
/// A `wrap_col` of 0 means "no wrapping": the input is copied verbatim with no
/// newline added.
///
/// # Panics
///
/// Panics if `wrap_buf` is too small to hold the wrapped output
/// (`encoded.len()` plus one byte per started line).
#[inline]
fn wrap_encoded(encoded: &[u8], wrap_col: usize, wrap_buf: &mut [u8]) -> usize {
    if wrap_col == 0 {
        wrap_buf[..encoded.len()].copy_from_slice(encoded);
        return encoded.len();
    }

    let mut wp = 0;
    // `chunks` yields every full line followed by the partial tail, if any.
    // All indexing is bounds-checked and no arithmetic involves `wrap_col`,
    // so an undersized buffer or a huge column width cannot corrupt memory.
    for line in encoded.chunks(wrap_col) {
        let end = wp + line.len();
        wrap_buf[wp..end].copy_from_slice(line);
        wrap_buf[end] = b'\n';
        wp = end + 1;
    }
    wp
}
171
172pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
176 if data.is_empty() {
177 return Ok(());
178 }
179
180 if ignore_garbage {
181 let mut cleaned = strip_non_base64(data);
182 return decode_owned_clean(&mut cleaned, out);
183 }
184
185 decode_stripping_whitespace(data, out)
187}
188
189pub fn decode_owned(
193 data: &mut Vec<u8>,
194 ignore_garbage: bool,
195 out: &mut impl Write,
196) -> io::Result<()> {
197 if data.is_empty() {
198 return Ok(());
199 }
200
201 if ignore_garbage {
202 data.retain(|&b| is_base64_char(b));
203 } else {
204 strip_whitespace_inplace(data);
205 }
206
207 decode_owned_clean(data, out)
208}
209
210fn strip_whitespace_inplace(data: &mut Vec<u8>) {
213 let positions: Vec<usize> = memchr::memchr_iter(b'\n', data.as_slice()).collect();
215
216 if positions.is_empty() {
217 if data.iter().any(|&b| is_whitespace(b)) {
219 data.retain(|&b| !is_whitespace(b));
220 }
221 return;
222 }
223
224 let mut wp = 0;
226 let mut rp = 0;
227
228 for &pos in &positions {
229 if pos > rp {
230 let len = pos - rp;
231 data.copy_within(rp..pos, wp);
232 wp += len;
233 }
234 rp = pos + 1;
235 }
236
237 let data_len = data.len();
238 if rp < data_len {
239 let len = data_len - rp;
240 data.copy_within(rp..data_len, wp);
241 wp += len;
242 }
243
244 data.truncate(wp);
245
246 if data.iter().any(|&b| is_whitespace(b)) {
248 data.retain(|&b| !is_whitespace(b));
249 }
250}
251
252fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
256 if memchr::memchr2(b'\n', b'\r', data).is_none()
259 && !data.iter().any(|&b| b == b' ' || b == b'\t')
260 {
261 if data.len() >= PARALLEL_ENCODE_THRESHOLD {
263 return decode_parallel(data, out);
264 }
265 return decode_borrowed_clean(out, data);
266 }
267
268 let mut clean = Vec::with_capacity(data.len());
270 let mut last = 0;
271 for pos in memchr::memchr_iter(b'\n', data) {
272 if pos > last {
273 clean.extend_from_slice(&data[last..pos]);
274 }
275 last = pos + 1;
276 }
277 if last < data.len() {
278 clean.extend_from_slice(&data[last..]);
279 }
280
281 if clean.iter().any(|&b| is_whitespace(b)) {
283 clean.retain(|&b| !is_whitespace(b));
284 }
285
286 if clean.len() >= PARALLEL_ENCODE_THRESHOLD {
288 return decode_parallel(&clean, out);
289 }
290
291 decode_owned_clean(&mut clean, out)
292}
293
294fn decode_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
296 let num_threads = rayon::current_num_threads().max(1);
297 let raw_chunk = (data.len() + num_threads - 1) / num_threads;
299 let chunk_size = ((raw_chunk + 3) / 4) * 4;
300
301 let decoded_chunks: Vec<Result<Vec<u8>, _>> = data
304 .par_chunks(chunk_size)
305 .map(|chunk| match BASE64_ENGINE.decode_to_vec(chunk) {
306 Ok(decoded) => Ok(decoded),
307 Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
308 })
309 .collect();
310
311 for chunk_result in decoded_chunks {
312 let chunk = chunk_result?;
313 out.write_all(&chunk)?;
314 }
315
316 Ok(())
317}
318
319fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
321 if data.is_empty() {
322 return Ok(());
323 }
324 match BASE64_ENGINE.decode_inplace(data) {
325 Ok(decoded) => out.write_all(decoded),
326 Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
327 }
328}
329
330fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
332 if data.is_empty() {
333 return Ok(());
334 }
335 match BASE64_ENGINE.decode_to_vec(data) {
336 Ok(decoded) => {
337 out.write_all(&decoded)?;
338 Ok(())
339 }
340 Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
341 }
342}
343
344fn strip_non_base64(data: &[u8]) -> Vec<u8> {
346 data.iter()
347 .copied()
348 .filter(|&b| is_base64_char(b))
349 .collect()
350}
351
/// Returns `true` for bytes of the standard base64 alphabet (`A-Z`, `a-z`,
/// `0-9`, `+`, `/`) and for the padding byte `=`.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(
        b,
        b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'='
    )
}
357
/// Returns `true` for space and for the ASCII control whitespace bytes
/// 0x09..=0x0d (tab, line feed, vertical tab, form feed, carriage return).
#[inline]
fn is_whitespace(b: u8) -> bool {
    b == b' ' || (0x09..=0x0d).contains(&b)
}
363
364pub fn encode_stream(
368 reader: &mut impl Read,
369 wrap_col: usize,
370 writer: &mut impl Write,
371) -> io::Result<()> {
372 let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
373
374 let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
375 let mut encode_buf = vec![0u8; encode_buf_size];
376
377 if wrap_col == 0 {
378 loop {
380 let n = read_full(reader, &mut buf)?;
381 if n == 0 {
382 break;
383 }
384 let enc_len = BASE64_ENGINE.encoded_length(n);
385 let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
386 writer.write_all(encoded)?;
387 }
388 } else {
389 let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
392 let mut wrap_buf = vec![0u8; max_wrapped];
393 let mut col = 0usize;
394
395 loop {
396 let n = read_full(reader, &mut buf)?;
397 if n == 0 {
398 break;
399 }
400 let enc_len = BASE64_ENGINE.encoded_length(n);
401 let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
402
403 let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
405 writer.write_all(&wrap_buf[..wp])?;
406 }
407
408 if col > 0 {
409 writer.write_all(b"\n")?;
410 }
411 }
412
413 Ok(())
414}
415
/// Copies `data` into `wrap_buf`, inserting `\n` whenever the running column
/// reaches `wrap_col`, and returns the number of bytes written.
///
/// `col` is the number of characters already on the current line. It is
/// updated in place so wrapping continues seamlessly across calls. No newline
/// is emitted for a trailing partial line.
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut pending = data;
    let mut written = 0;

    while !pending.is_empty() {
        // Take as much as still fits on the current line.
        let room = wrap_col - *col;
        let take = pending.len().min(room);
        let (head, tail) = pending.split_at(take);

        wrap_buf[written..written + take].copy_from_slice(head);
        written += take;
        *col += take;

        if *col == wrap_col {
            wrap_buf[written] = b'\n';
            written += 1;
            *col = 0;
        }

        pending = tail;
    }

    written
}
455
456pub fn decode_stream(
460 reader: &mut impl Read,
461 ignore_garbage: bool,
462 writer: &mut impl Write,
463) -> io::Result<()> {
464 const READ_CHUNK: usize = 4 * 1024 * 1024;
465 let mut buf = vec![0u8; READ_CHUNK];
466 let mut clean = Vec::with_capacity(READ_CHUNK);
467 let mut carry: Vec<u8> = Vec::with_capacity(4);
468
469 loop {
470 let n = read_full(reader, &mut buf)?;
471 if n == 0 {
472 break;
473 }
474
475 clean.clear();
477 clean.extend_from_slice(&carry);
478 carry.clear();
479
480 let chunk = &buf[..n];
481 if ignore_garbage {
482 clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
483 } else {
484 let mut last = 0;
486 for pos in memchr::memchr_iter(b'\n', chunk) {
487 if pos > last {
488 clean.extend_from_slice(&chunk[last..pos]);
489 }
490 last = pos + 1;
491 }
492 if last < n {
493 clean.extend_from_slice(&chunk[last..]);
494 }
495 if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
497 clean.retain(|&b| !is_whitespace(b));
498 }
499 }
500
501 let is_last = n < READ_CHUNK;
502
503 if is_last {
504 decode_owned_clean(&mut clean, writer)?;
506 } else {
507 let decode_len = (clean.len() / 4) * 4;
509 if decode_len < clean.len() {
510 carry.extend_from_slice(&clean[decode_len..]);
511 }
512 if decode_len > 0 {
513 clean.truncate(decode_len);
514 decode_owned_clean(&mut clean, writer)?;
515 }
516 }
517 }
518
519 if !carry.is_empty() {
521 decode_owned_clean(&mut carry, writer)?;
522 }
523
524 Ok(())
525}
526
/// Reads from `reader` until `buf` is full or the reader reports end of input
/// (`Ok(0)`), retrying reads that fail with `Interrupted`.
///
/// Returns the number of bytes placed in `buf`. A count smaller than
/// `buf.len()` means the reader returned `Ok(0)`.
///
/// # Errors
///
/// Returns the first non-`Interrupted` error produced by `reader`.
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0usize;
    loop {
        let remaining = &mut buf[filled..];
        if remaining.is_empty() {
            return Ok(filled);
        }
        match reader.read(remaining) {
            Ok(0) => return Ok(filled),
            Ok(n) => filled += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
}