1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4use rayon::prelude::*;
5
/// SIMD-accelerated engine for the standard base64 alphabet ('+', '/', '=' padding).
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;

/// Read size for streaming encode, rounded down to a multiple of 3 so every
/// non-final chunk encodes with no '=' padding (3 input bytes -> 4 chars).
const STREAM_ENCODE_CHUNK: usize = 24 * 1024 * 1024 - (24 * 1024 * 1024 % 3);

/// Chunk size for in-memory unwrapped encoding; also a multiple of 3 so
/// chunk outputs concatenate without mid-stream padding.
const NOWRAP_CHUNK: usize = 32 * 1024 * 1024 - (32 * 1024 * 1024 % 3);

/// At or above this input size, encoding fans out across rayon threads.
const PARALLEL_ENCODE_THRESHOLD: usize = 1024 * 1024;

/// At or above this input size, decoding fans out across rayon threads.
const PARALLEL_DECODE_THRESHOLD: usize = 1024 * 1024;
22
23pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
26 if data.is_empty() {
27 return Ok(());
28 }
29
30 if wrap_col == 0 {
31 return encode_no_wrap(data, out);
32 }
33
34 encode_wrapped(data, wrap_col, out)
35}
36
/// Encodes `data` with no line wrapping, streaming the base64 text to `out`.
fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    // Large inputs are encoded across rayon worker threads instead.
    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
        return encode_no_wrap_parallel(data, out);
    }

    // One reusable output buffer sized for the largest chunk we will encode.
    let actual_chunk = NOWRAP_CHUNK.min(data.len());
    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
    let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
    // SAFETY: buf[..enc_len] is fully overwritten by encode() before any
    // read; bytes past enc_len are never read.
    #[allow(clippy::uninit_vec)]
    unsafe {
        buf.set_len(enc_max);
    }

    // NOWRAP_CHUNK is a multiple of 3, so every chunk except possibly the
    // last encodes without '=' padding and the pieces concatenate cleanly.
    for chunk in data.chunks(NOWRAP_CHUNK) {
        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
        out.write_all(encoded)?;
    }
    Ok(())
}
59
/// Parallel (rayon) variant of `encode_no_wrap`: each thread encodes an
/// independent chunk, then the results are written out in input order.
fn encode_no_wrap_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    let raw_chunk = data.len() / num_threads;
    // Round up to a multiple of 3 so only the final chunk can carry '='
    // padding; the per-chunk outputs then concatenate into valid base64.
    let chunk_size = ((raw_chunk + 2) / 3) * 3;

    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(3)).collect();
    let encoded_chunks: Vec<Vec<u8>> = chunks
        .par_iter()
        .map(|chunk| {
            let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
            let mut buf: Vec<u8> = Vec::with_capacity(enc_len);
            // SAFETY: encode() overwrites all enc_len bytes before any read.
            #[allow(clippy::uninit_vec)]
            unsafe {
                buf.set_len(enc_len);
            }
            let _ = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
            buf
        })
        .collect();

    // Single vectored write avoids concatenating the per-thread buffers.
    let iov: Vec<io::IoSlice> = encoded_chunks.iter().map(|c| io::IoSlice::new(c)).collect();
    write_all_vectored(out, &iov)
}
88
89fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
93 let bytes_per_line = wrap_col * 3 / 4;
96 if bytes_per_line == 0 {
97 return encode_wrapped_small(data, wrap_col, out);
99 }
100
101 if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line.is_multiple_of(3) {
104 return encode_wrapped_parallel(data, wrap_col, bytes_per_line, out);
105 }
106
107 let lines_per_chunk = (32 * 1024 * 1024) / bytes_per_line;
111 let max_input_chunk = (lines_per_chunk * bytes_per_line).max(bytes_per_line);
112 let input_chunk = max_input_chunk.min(data.len());
113
114 let enc_max = BASE64_ENGINE.encoded_length(input_chunk);
115 let mut encode_buf: Vec<u8> = Vec::with_capacity(enc_max);
116 #[allow(clippy::uninit_vec)]
117 unsafe {
118 encode_buf.set_len(enc_max);
119 }
120
121 let max_lines = enc_max / wrap_col + 2;
123 let fused_max = enc_max + max_lines;
124 let mut fused_buf: Vec<u8> = Vec::with_capacity(fused_max);
125 #[allow(clippy::uninit_vec)]
126 unsafe {
127 fused_buf.set_len(fused_max);
128 }
129
130 for chunk in data.chunks(max_input_chunk.max(1)) {
131 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
132 let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
133
134 let wp = fuse_wrap(encoded, wrap_col, &mut fused_buf);
136 out.write_all(&fused_buf[..wp])?;
137 }
138
139 Ok(())
140}
141
/// Parallel wrapped encoding: each rayon worker encodes and newline-wraps
/// its own chunk, then the pieces are written out in order.
///
/// NOTE(review): chunk-independent wrapping is only seam-free when each
/// chunk's encoded length is an exact multiple of wrap_col, i.e. when
/// bytes_per_line * 4 == wrap_col * 3 (wrap_col % 4 == 0). The caller is
/// responsible for only routing such wrap_col values here — verify the
/// guard in `encode_wrapped` enforces that, not merely bytes_per_line % 3.
fn encode_wrapped_parallel(
    data: &[u8],
    wrap_col: usize,
    bytes_per_line: usize,
    out: &mut impl Write,
) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    // Whole output lines per chunk, at least one.
    let lines_per_chunk = (data.len() / bytes_per_line / num_threads).max(1);
    let chunk_size = lines_per_chunk * bytes_per_line;

    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(bytes_per_line)).collect();
    let encoded_chunks: Vec<Vec<u8>> = chunks
        .par_iter()
        .map(|chunk| {
            // One allocation per chunk, laid out [fused output | raw base64]:
            // encode into the tail region, fuse newlines into the head, then
            // truncate to just the fused bytes.
            let enc_max = BASE64_ENGINE.encoded_length(chunk.len());
            let max_lines = enc_max / wrap_col + 2;
            let fused_size = enc_max + max_lines;
            let total_size = fused_size + enc_max;
            let mut buf: Vec<u8> = Vec::with_capacity(total_size);
            // SAFETY: the encode region is fully written by encode() below,
            // and the fused region is written by fuse_wrap() before any read.
            #[allow(clippy::uninit_vec)]
            unsafe {
                buf.set_len(total_size);
            }
            let _ = BASE64_ENGINE.encode(chunk, buf[fused_size..fused_size + enc_max].as_out());
            // Disjoint views: mutable fused head, read-only encoded tail.
            let (fused_region, encode_region) = buf.split_at_mut(fused_size);
            let encoded = &encode_region[..enc_max];
            let wp = fuse_wrap(encoded, wrap_col, fused_region);
            buf.truncate(wp);
            buf
        })
        .collect();

    let iov: Vec<io::IoSlice> = encoded_chunks.iter().map(|c| io::IoSlice::new(c)).collect();
    write_all_vectored(out, &iov)
}
187
/// Copies `encoded` into `out_buf`, inserting a b'\n' after every `wrap_col`
/// bytes and after the trailing partial line (if any). Returns the number of
/// bytes written.
///
/// Caller contract: `out_buf` must hold at least
/// `encoded.len() + encoded.len() / wrap_col + 2` bytes — both in-memory
/// wrapped encode paths size their buffers exactly this way.
#[inline]
fn fuse_wrap(encoded: &[u8], wrap_col: usize, out_buf: &mut [u8]) -> usize {
    // One output line = wrap_col payload bytes plus its newline.
    let line_out = wrap_col + 1; let mut rp = 0;
    let mut wp = 0;

    // Hot loop, unrolled to 8 full lines per iteration.
    while rp + 8 * wrap_col <= encoded.len() {
        // SAFETY: the loop guard bounds every read at rp + 8 * wrap_col, and
        // the caller's sizing contract bounds every write at wp + 8 * line_out.
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = out_buf.as_mut_ptr().add(wp);

            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(4 * wrap_col), dst.add(4 * line_out), wrap_col);
            *dst.add(4 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(5 * wrap_col), dst.add(5 * line_out), wrap_col);
            *dst.add(5 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(6 * wrap_col), dst.add(6 * line_out), wrap_col);
            *dst.add(6 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(7 * wrap_col), dst.add(7 * line_out), wrap_col);
            *dst.add(7 * line_out + wrap_col) = b'\n';
        }
        rp += 8 * wrap_col;
        wp += 8 * line_out;
    }

    // Mid tail: 4 full lines at a time (at most one iteration after the
    // 8-line loop has drained).
    while rp + 4 * wrap_col <= encoded.len() {
        // SAFETY: same bounds argument as above, with a 4-line span.
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = out_buf.as_mut_ptr().add(wp);

            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';
        }
        rp += 4 * wrap_col;
        wp += 4 * line_out;
    }

    // Remaining full lines, one at a time.
    while rp + wrap_col <= encoded.len() {
        // SAFETY: guard bounds the read; caller's sizing bounds the write.
        unsafe {
            std::ptr::copy_nonoverlapping(
                encoded.as_ptr().add(rp),
                out_buf.as_mut_ptr().add(wp),
                wrap_col,
            );
            *out_buf.as_mut_ptr().add(wp + wrap_col) = b'\n';
        }
        rp += wrap_col;
        wp += line_out;
    }

    // Trailing partial line, newline-terminated like the full ones.
    if rp < encoded.len() {
        let remaining = encoded.len() - rp;
        // SAFETY: remaining < wrap_col, so the write stays within the
        // caller-guaranteed capacity.
        unsafe {
            std::ptr::copy_nonoverlapping(
                encoded.as_ptr().add(rp),
                out_buf.as_mut_ptr().add(wp),
                remaining,
            );
        }
        wp += remaining;
        out_buf[wp] = b'\n';
        wp += 1;
    }

    wp
}
284
285fn encode_wrapped_small(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
287 let enc_max = BASE64_ENGINE.encoded_length(data.len());
288 let mut buf: Vec<u8> = Vec::with_capacity(enc_max);
289 #[allow(clippy::uninit_vec)]
290 unsafe {
291 buf.set_len(enc_max);
292 }
293 let encoded = BASE64_ENGINE.encode(data, buf[..enc_max].as_out());
294
295 let wc = wrap_col.max(1);
296 for line in encoded.chunks(wc) {
297 out.write_all(line)?;
298 out.write_all(b"\n")?;
299 }
300 Ok(())
301}
302
303pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
307 if data.is_empty() {
308 return Ok(());
309 }
310
311 if ignore_garbage {
312 let mut cleaned = strip_non_base64(data);
313 return decode_clean_slice(&mut cleaned, out);
314 }
315
316 decode_stripping_whitespace(data, out)
318}
319
320pub fn decode_owned(
322 data: &mut Vec<u8>,
323 ignore_garbage: bool,
324 out: &mut impl Write,
325) -> io::Result<()> {
326 if data.is_empty() {
327 return Ok(());
328 }
329
330 if ignore_garbage {
331 data.retain(|&b| is_base64_char(b));
332 } else {
333 strip_whitespace_inplace(data);
334 }
335
336 decode_clean_slice(data, out)
337}
338
339fn strip_whitespace_inplace(data: &mut Vec<u8>) {
343 let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
345 if !has_ws {
346 return;
347 }
348
349 let ptr = data.as_ptr();
351 let mut_ptr = data.as_mut_ptr();
352 let len = data.len();
353 let mut wp = 0usize;
354
355 for i in 0..len {
356 let b = unsafe { *ptr.add(i) };
357 if NOT_WHITESPACE[b as usize] {
358 unsafe { *mut_ptr.add(wp) = b };
359 wp += 1;
360 }
361 }
362
363 data.truncate(wp);
364}
365
/// Byte-indexed table: true for every byte EXCEPT ASCII whitespace
/// (space, tab, LF, CR, vertical tab, form feed). Used by the decode paths
/// to strip whitespace before handing data to the base64 engine.
static NOT_WHITESPACE: [bool; 256] = {
    let mut table = [true; 256];
    table[b' ' as usize] = false;
    table[b'\t' as usize] = false;
    table[b'\n' as usize] = false;
    table[b'\r' as usize] = false;
    // 0x0b = vertical tab, 0x0c = form feed.
    table[0x0b] = false; table[0x0c] = false; table
};
378
/// Decodes `data` after removing whitespace — most commonly the '\n'/'\r'
/// line breaks produced by wrapped encoders.
fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    // Fast path: no whitespace at all — decode the borrowed slice directly.
    let has_ws = data.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
    if !has_ws {
        return decode_borrowed_clean(out, data);
    }

    // Pass 1: drop '\n'/'\r' using memchr, raw-copying the spans between
    // line breaks into `clean`. Capacity data.len() covers all writes; no
    // push/reserve happens while `dst` is live, so it never dangles.
    let mut clean: Vec<u8> = Vec::with_capacity(data.len());
    let dst = clean.as_mut_ptr();
    let mut wp = 0usize;
    let mut gap_start = 0usize;

    for pos in memchr::memchr2_iter(b'\n', b'\r', data) {
        let gap_len = pos - gap_start;
        if gap_len > 0 {
            // SAFETY: wp + gap_len <= data.len() == clean's capacity, and
            // src/dst are different allocations (never overlap).
            unsafe {
                std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), gap_len);
            }
            wp += gap_len;
        }
        gap_start = pos + 1;
    }
    // Copy whatever follows the final line break.
    let tail_len = data.len() - gap_start;
    if tail_len > 0 {
        // SAFETY: same bounds/aliasing argument as above.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr().add(gap_start), dst.add(wp), tail_len);
        }
        wp += tail_len;
    }
    // SAFETY: exactly wp bytes of clean were initialized above.
    unsafe {
        clean.set_len(wp);
    }

    // Pass 2 (rare): compact out the remaining whitespace kinds
    // (space, tab, VT, FF) with an in-place shift-down.
    let has_rare_ws = clean.iter().any(|&b| !NOT_WHITESPACE[b as usize]);
    if has_rare_ws {
        let ptr = clean.as_mut_ptr();
        let len = clean.len();
        let mut rp = 0;
        let mut cwp = 0;
        while rp < len {
            // SAFETY: rp < len and cwp <= rp keep both accesses in bounds;
            // all access goes through the single mutable pointer.
            let b = unsafe { *ptr.add(rp) };
            if NOT_WHITESPACE[b as usize] {
                unsafe { *ptr.add(cwp) = b };
                cwp += 1;
            }
            rp += 1;
        }
        clean.truncate(cwp);
    }

    decode_clean_slice(&mut clean, out)
}
442
443fn decode_clean_slice(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
445 if data.is_empty() {
446 return Ok(());
447 }
448 match BASE64_ENGINE.decode_inplace(data) {
449 Ok(decoded) => out.write_all(decoded),
450 Err(_) => decode_error(),
451 }
452}
453
/// Shared error constructor for malformed base64 input; marked cold and
/// never inlined so the hot decode paths stay compact.
#[cold]
#[inline(never)]
fn decode_error() -> io::Result<()> {
    let err = io::Error::new(io::ErrorKind::InvalidData, "invalid input");
    Err(err)
}
460
461fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
463 if data.is_empty() {
464 return Ok(());
465 }
466 if data.len() >= PARALLEL_DECODE_THRESHOLD {
469 return decode_borrowed_clean_parallel(out, data);
470 }
471 match BASE64_ENGINE.decode_to_vec(data) {
472 Ok(decoded) => {
473 out.write_all(&decoded)?;
474 Ok(())
475 }
476 Err(_) => decode_error(),
477 }
478}
479
/// Parallel decode of a whitespace-free base64 slice: the input is split on
/// 4-char boundaries, each piece is decoded into its precomputed slot of a
/// single shared output buffer, and the buffer is written in one call.
fn decode_borrowed_clean_parallel(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    let raw_chunk = data.len() / num_threads;
    // Round up to a multiple of 4 so only the final chunk can hold '='
    // padding (a non-multiple-of-4 total fails decode, as it should).
    let chunk_size = ((raw_chunk + 3) / 4) * 4;

    let chunks: Vec<&[u8]> = data.chunks(chunk_size.max(4)).collect();

    // Prefix sums of decoded sizes: offsets[i] is where chunk i's output
    // starts in the shared buffer.
    let mut offsets: Vec<usize> = Vec::with_capacity(chunks.len() + 1);
    offsets.push(0);
    let mut total_decoded = 0usize;
    for (i, chunk) in chunks.iter().enumerate() {
        let decoded_size = if i == chunks.len() - 1 {
            // Only the last chunk may end in '=' padding (up to two bytes).
            let pad = chunk.iter().rev().take(2).filter(|&&b| b == b'=').count();
            chunk.len() * 3 / 4 - pad
        } else {
            chunk.len() * 3 / 4
        };
        total_decoded += decoded_size;
        offsets.push(total_decoded);
    }

    let mut output_buf: Vec<u8> = Vec::with_capacity(total_decoded);
    // SAFETY: every byte of output_buf is written by exactly one chunk's
    // decode below before the buffer is read; invalid input aborts before
    // the read via `decode_result?`.
    #[allow(clippy::uninit_vec)]
    unsafe {
        output_buf.set_len(total_decoded);
    }

    // The base pointer is smuggled as usize so the closure is Send. The
    // per-chunk ranges [offset, offset + expected_size) are disjoint by
    // construction of the prefix sums, so the parallel raw writes never
    // alias each other.
    let out_addr = output_buf.as_mut_ptr() as usize;
    let decode_result: Result<Vec<()>, io::Error> = chunks
        .par_iter()
        .enumerate()
        .map(|(i, chunk)| {
            let offset = offsets[i];
            let expected_size = offsets[i + 1] - offset;
            // SAFETY: disjoint, in-bounds sub-slice of output_buf (above).
            let out_slice = unsafe {
                std::slice::from_raw_parts_mut((out_addr as *mut u8).add(offset), expected_size)
            };
            let decoded = BASE64_ENGINE
                .decode(chunk, out_slice.as_out())
                .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid input"))?;
            debug_assert_eq!(decoded.len(), expected_size);
            Ok(())
        })
        .collect();

    decode_result?;

    out.write_all(&output_buf[..total_decoded])
}
544
545fn strip_non_base64(data: &[u8]) -> Vec<u8> {
547 data.iter()
548 .copied()
549 .filter(|&b| is_base64_char(b))
550 .collect()
551}
552
/// True for bytes in the standard base64 alphabet, including '=' padding.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
558
/// Streams base64 encoding from `reader` to `writer` without buffering the
/// whole input, wrapping output lines at `wrap_col` (0 = no wrapping).
pub fn encode_stream(
    reader: &mut impl Read,
    wrap_col: usize,
    writer: &mut impl Write,
) -> io::Result<()> {
    // STREAM_ENCODE_CHUNK is a multiple of 3, so only the final (short)
    // chunk can produce '=' padding.
    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];

    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
    let mut encode_buf = vec![0u8; encode_buf_size];

    if wrap_col == 0 {
        loop {
            // read_full only returns short at EOF.
            let n = read_full(reader, &mut buf)?;
            if n == 0 {
                break;
            }
            let enc_len = BASE64_ENGINE.encoded_length(n);
            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
            writer.write_all(encoded)?;
        }
    } else {
        // Worst case: one '\n' per wrap_col chars plus a trailing newline.
        let max_fused = encode_buf_size + (encode_buf_size / wrap_col + 2);
        let mut fused_buf = vec![0u8; max_fused];
        // Current output column, carried across chunks so wrapping stays
        // continuous even though each chunk is fused independently.
        let mut col = 0usize;

        loop {
            let n = read_full(reader, &mut buf)?;
            if n == 0 {
                break;
            }
            let enc_len = BASE64_ENGINE.encoded_length(n);
            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());

            let wp = build_fused_output(encoded, wrap_col, &mut col, &mut fused_buf);
            writer.write_all(&fused_buf[..wp])?;
        }

        // Terminate a trailing partial line.
        if col > 0 {
            writer.write_all(b"\n")?;
        }
    }

    Ok(())
}
607
/// Copies `data` into `out_buf`, inserting a b'\n' every `wrap_col` output
/// characters. `col` is the current column, carried across calls so that
/// wrapping stays continuous between chunks; it is updated in place.
/// Returns the number of bytes written to `out_buf`.
#[inline]
fn build_fused_output(data: &[u8], wrap_col: usize, col: &mut usize, out_buf: &mut [u8]) -> usize {
    let mut read = 0;
    let mut written = 0;

    while read < data.len() {
        let room = wrap_col - *col;      // chars left on the current line
        let pending = data.len() - read; // input chars not yet emitted

        if pending > room {
            // Input overflows the current line: emit `room` chars + newline.
            out_buf[written..written + room].copy_from_slice(&data[read..read + room]);
            written += room;
            out_buf[written] = b'\n';
            written += 1;
            read += room;
            *col = 0;
        } else {
            // Input ends on this line; newline only if the line became full.
            out_buf[written..written + pending].copy_from_slice(&data[read..read + pending]);
            written += pending;
            *col += pending;
            if *col == wrap_col {
                out_buf[written] = b'\n';
                written += 1;
                *col = 0;
            }
            break;
        }
    }

    written
}
641
/// Streaming base64 decoder: reads `reader` in fixed-size chunks, strips
/// whitespace (or every non-base64 byte with `ignore_garbage`), and decodes
/// into `writer`. Incomplete 4-char groups at a chunk boundary are carried
/// into the next iteration so decoding never splits a quantum.
pub fn decode_stream(
    reader: &mut impl Read,
    ignore_garbage: bool,
    writer: &mut impl Write,
) -> io::Result<()> {
    const READ_CHUNK: usize = 16 * 1024 * 1024;
    let mut buf = vec![0u8; READ_CHUNK];
    let mut clean = Vec::with_capacity(READ_CHUNK);
    // Leftover (< 4) base64 chars from the previous chunk.
    let mut carry: Vec<u8> = Vec::with_capacity(4);

    loop {
        let n = read_full(reader, &mut buf)?;
        if n == 0 {
            break;
        }

        // Seed this round's cleaned buffer with last round's leftovers.
        clean.clear();
        clean.extend_from_slice(&carry);
        carry.clear();

        let chunk = &buf[..n];
        if ignore_garbage {
            clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
        } else {
            // Pass 1: drop '\n'/'\r' via memchr, raw-copying the spans
            // between line breaks. reserve(n) guarantees capacity for every
            // write through `dst`, and nothing reallocates while it's live.
            clean.reserve(n);
            let base_len = clean.len();
            let dst = unsafe { clean.as_mut_ptr().add(base_len) };
            let mut wp = 0usize;
            let mut gap_start = 0usize;

            for pos in memchr::memchr2_iter(b'\n', b'\r', chunk) {
                let gap_len = pos - gap_start;
                if gap_len > 0 {
                    // SAFETY: wp + gap_len <= n <= reserved capacity; src and
                    // dst are different allocations (never overlap).
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            chunk.as_ptr().add(gap_start),
                            dst.add(wp),
                            gap_len,
                        );
                    }
                    wp += gap_len;
                }
                gap_start = pos + 1;
            }
            // Copy whatever follows the final line break.
            let tail_len = n - gap_start;
            if tail_len > 0 {
                // SAFETY: same bounds/aliasing argument as above.
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        chunk.as_ptr().add(gap_start),
                        dst.add(wp),
                        tail_len,
                    );
                }
                wp += tail_len;
            }
            // SAFETY: exactly base_len + wp bytes are initialized.
            unsafe { clean.set_len(base_len + wp) };

            // Pass 2 (rare): compact out the remaining whitespace kinds
            // (space, tab, VT, FF) in place. Only the bytes appended this
            // round need scanning; the carry was cleaned last round.
            let has_rare_ws = clean[base_len..]
                .iter()
                .any(|&b| !NOT_WHITESPACE[b as usize]);
            if has_rare_ws {
                let start = base_len;
                let end = clean.len();
                let ptr = clean.as_mut_ptr();
                let mut rp = start;
                let mut cwp = start;
                while rp < end {
                    // SAFETY: rp < end == clean.len() and cwp <= rp keep both
                    // accesses in bounds through the single mutable pointer.
                    let b = unsafe { *ptr.add(rp) };
                    if NOT_WHITESPACE[b as usize] {
                        unsafe { *ptr.add(cwp) = b };
                        cwp += 1;
                    }
                    rp += 1;
                }
                clean.truncate(cwp);
            }
        }

        // read_full only returns short at EOF, so a short read marks the
        // final chunk — padding may legitimately appear now.
        let is_last = n < READ_CHUNK;

        if is_last {
            decode_clean_slice(&mut clean, writer)?;
        } else {
            // Decode only whole 4-char groups; carry the remainder forward.
            let decode_len = (clean.len() / 4) * 4;
            if decode_len < clean.len() {
                carry.extend_from_slice(&clean[decode_len..]);
            }
            if decode_len > 0 {
                clean.truncate(decode_len);
                decode_clean_slice(&mut clean, writer)?;
            }
        }
    }

    // Input length was an exact multiple of READ_CHUNK: flush the leftover
    // (an invalid trailing group fails here, as it should).
    if !carry.is_empty() {
        decode_clean_slice(&mut carry, writer)?;
    }

    Ok(())
}
759
/// Writes every slice in `slices` to `out`. Tries a single vectored write
/// first; if the writer accepts only part of it, falls back to per-slice
/// `write_all` for the unwritten remainder.
fn write_all_vectored(out: &mut impl Write, slices: &[io::IoSlice]) -> io::Result<()> {
    if slices.is_empty() {
        return Ok(());
    }
    let total: usize = slices.iter().map(|s| s.len()).sum();

    // Common case: one vectored call takes everything.
    let written = out.write_vectored(slices)?;
    if written >= total {
        return Ok(());
    }

    // Slow path: skip what the vectored write consumed, flush the rest.
    let mut consumed = written;
    for slice in slices {
        if consumed >= slice.len() {
            consumed -= slice.len();
        } else {
            out.write_all(&slice[consumed..])?;
            consumed = 0;
        }
    }
    Ok(())
}
792
/// Reads from `reader` until `buf` is completely full or EOF is reached,
/// retrying on `Interrupted`. Returns the number of bytes read; 0 means
/// EOF, and a value below `buf.len()` means EOF was hit mid-buffer.
#[inline]
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    // Fast path: a single read() that fills the buffer or hits EOF at once.
    let first = reader.read(buf)?;
    if first == 0 || first == buf.len() {
        return Ok(first);
    }

    let mut filled = first;
    while filled < buf.len() {
        match reader.read(&mut buf[filled..]) {
            Ok(0) => break,
            Ok(n) => filled += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(filled)
}