1use std::io::{self, Read, Write};
2
3use base64_simd::AsOut;
4use rayon::prelude::*;
5
/// Engine used for every encode and decode call in this module (`base64_simd::STANDARD`).
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;
7
/// Bytes read per iteration by `encode_stream`: 12 MiB rounded down to a multiple
/// of 3, so every full read encodes without `=` padding.
const STREAM_ENCODE_CHUNK: usize = (12 << 20) / 3 * 3;

/// Input bytes per iteration of the sequential unwrapped encoder: 2 MiB rounded
/// down to a multiple of 3 (again, so no interior padding is produced).
const NOWRAP_CHUNK: usize = (2 << 20) / 3 * 3;

/// Inputs of at least this many bytes (32 MiB) take the rayon-parallel paths.
const PARALLEL_ENCODE_THRESHOLD: usize = 32 << 20;
20
21pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
24 if data.is_empty() {
25 return Ok(());
26 }
27
28 if wrap_col == 0 {
29 return encode_no_wrap(data, out);
30 }
31
32 encode_wrapped(data, wrap_col, out)
33}
34
35fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
37 if data.len() >= PARALLEL_ENCODE_THRESHOLD {
38 let num_threads = rayon::current_num_threads().max(1);
40 let raw_chunk = (data.len() + num_threads - 1) / num_threads;
41 let chunk_size = ((raw_chunk + 2) / 3) * 3;
43
44 let encoded_chunks: Vec<Vec<u8>> = data
45 .par_chunks(chunk_size)
46 .map(|chunk| {
47 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
48 let mut buf = vec![0u8; enc_len];
49 let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
50 let len = encoded.len();
51 buf.truncate(len);
52 buf
53 })
54 .collect();
55
56 for chunk in &encoded_chunks {
57 out.write_all(chunk)?;
58 }
59 return Ok(());
60 }
61
62 let enc_max = BASE64_ENGINE.encoded_length(NOWRAP_CHUNK);
63 let mut buf = vec![0u8; enc_max];
64
65 for chunk in data.chunks(NOWRAP_CHUNK) {
66 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
67 let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
68 out.write_all(encoded)?;
69 }
70 Ok(())
71}
72
73fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
75 let bytes_per_line = wrap_col * 3 / 4;
76
77 if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line > 0 {
78 let num_threads = rayon::current_num_threads().max(1);
81 let lines_per_thread = ((data.len() / bytes_per_line) + num_threads - 1) / num_threads;
82 let chunk_input = (lines_per_thread * bytes_per_line).max(bytes_per_line);
83
84 let wrapped_chunks: Vec<Vec<u8>> = data
85 .par_chunks(chunk_input)
86 .map(|chunk| {
87 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
88 let mut encode_buf = vec![0u8; enc_len];
89 let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
90
91 let line_out = wrap_col + 1;
93 let max_lines = (encoded.len() + wrap_col - 1) / wrap_col + 1;
94 let mut wrap_buf = vec![0u8; max_lines * line_out];
95 let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
96 wrap_buf.truncate(wp);
97 wrap_buf
98 })
99 .collect();
100
101 for chunk in &wrapped_chunks {
102 out.write_all(chunk)?;
103 }
104 return Ok(());
105 }
106
107 let lines_per_chunk = (2 * 1024 * 1024) / bytes_per_line.max(1);
109 let chunk_input = lines_per_chunk * bytes_per_line.max(1);
110 let chunk_encoded_max = BASE64_ENGINE.encoded_length(chunk_input.max(1));
111 let mut encode_buf = vec![0u8; chunk_encoded_max];
112 let wrapped_max = (lines_per_chunk + 1) * (wrap_col + 1);
113 let mut wrap_buf = vec![0u8; wrapped_max];
114
115 for chunk in data.chunks(chunk_input.max(1)) {
116 let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
117 let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
118 let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
119 out.write_all(&wrap_buf[..wp])?;
120 }
121
122 Ok(())
123}
124
/// Copy base64 text from `encoded` into `wrap_buf`. A `\n` is inserted after every
/// `wrap_col` characters and after any trailing partial line.
///
/// Returns the number of bytes written to `wrap_buf`. Empty `encoded` writes
/// nothing and returns 0.
///
/// # Panics
///
/// Panics if `wrap_col == 0`. Also panics if `wrap_buf` is shorter than
/// `encoded.len()` plus the number of output lines (`ceil(encoded.len() / wrap_col)`).
/// Both checks run up front, so an undersized buffer can never be written past its end.
#[inline]
fn wrap_encoded(encoded: &[u8], wrap_col: usize, wrap_buf: &mut [u8]) -> usize {
    assert!(wrap_col > 0, "wrap_col must be non-zero");

    let lines = encoded.chunks(wrap_col);
    // One newline per line, including a final partial line.
    let needed = encoded.len() + lines.len();
    assert!(
        wrap_buf.len() >= needed,
        "wrap buffer too small: need {} bytes, have {}",
        needed,
        wrap_buf.len()
    );

    let mut wp = 0;
    for line in lines {
        let end = wp + line.len();
        wrap_buf[wp..end].copy_from_slice(line);
        wrap_buf[end] = b'\n';
        wp = end + 1;
    }
    wp
}
175
176pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
180 if data.is_empty() {
181 return Ok(());
182 }
183
184 if ignore_garbage {
185 let mut cleaned = strip_non_base64(data);
186 return decode_owned_clean(&mut cleaned, out);
187 }
188
189 decode_stripping_whitespace(data, out)
191}
192
193pub fn decode_owned(
197 data: &mut Vec<u8>,
198 ignore_garbage: bool,
199 out: &mut impl Write,
200) -> io::Result<()> {
201 if data.is_empty() {
202 return Ok(());
203 }
204
205 if ignore_garbage {
206 data.retain(|&b| is_base64_char(b));
207 } else {
208 strip_whitespace_inplace(data);
209 }
210
211 decode_owned_clean(data, out)
212}
213
214fn strip_whitespace_inplace(data: &mut Vec<u8>) {
217 let positions: Vec<usize> = memchr::memchr_iter(b'\n', data.as_slice()).collect();
219
220 if positions.is_empty() {
221 if data.iter().any(|&b| is_whitespace(b)) {
223 data.retain(|&b| !is_whitespace(b));
224 }
225 return;
226 }
227
228 let mut wp = 0;
230 let mut rp = 0;
231
232 for &pos in &positions {
233 if pos > rp {
234 let len = pos - rp;
235 data.copy_within(rp..pos, wp);
236 wp += len;
237 }
238 rp = pos + 1;
239 }
240
241 let data_len = data.len();
242 if rp < data_len {
243 let len = data_len - rp;
244 data.copy_within(rp..data_len, wp);
245 wp += len;
246 }
247
248 data.truncate(wp);
249
250 if data.iter().any(|&b| is_whitespace(b)) {
252 data.retain(|&b| !is_whitespace(b));
253 }
254}
255
256fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
260 if memchr::memchr2(b'\n', b'\r', data).is_none()
263 && !data.iter().any(|&b| b == b' ' || b == b'\t')
264 {
265 if data.len() >= PARALLEL_ENCODE_THRESHOLD {
267 return decode_parallel(data, out);
268 }
269 return decode_borrowed_clean(out, data);
270 }
271
272 let mut clean = Vec::with_capacity(data.len());
274 let mut last = 0;
275 for pos in memchr::memchr_iter(b'\n', data) {
276 if pos > last {
277 clean.extend_from_slice(&data[last..pos]);
278 }
279 last = pos + 1;
280 }
281 if last < data.len() {
282 clean.extend_from_slice(&data[last..]);
283 }
284
285 if clean.iter().any(|&b| is_whitespace(b)) {
287 clean.retain(|&b| !is_whitespace(b));
288 }
289
290 if clean.len() >= PARALLEL_ENCODE_THRESHOLD {
292 return decode_parallel(&clean, out);
293 }
294
295 decode_owned_clean(&mut clean, out)
296}
297
298fn decode_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
300 let num_threads = rayon::current_num_threads().max(1);
301 let raw_chunk = (data.len() + num_threads - 1) / num_threads;
303 let chunk_size = ((raw_chunk + 3) / 4) * 4;
304
305 let decoded_chunks: Vec<Result<Vec<u8>, _>> = data
308 .par_chunks(chunk_size)
309 .map(|chunk| match BASE64_ENGINE.decode_to_vec(chunk) {
310 Ok(decoded) => Ok(decoded),
311 Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
312 })
313 .collect();
314
315 for chunk_result in decoded_chunks {
316 let chunk = chunk_result?;
317 out.write_all(&chunk)?;
318 }
319
320 Ok(())
321}
322
323fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
325 if data.is_empty() {
326 return Ok(());
327 }
328 match BASE64_ENGINE.decode_inplace(data) {
329 Ok(decoded) => out.write_all(decoded),
330 Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
331 }
332}
333
334fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
336 if data.is_empty() {
337 return Ok(());
338 }
339 match BASE64_ENGINE.decode_to_vec(data) {
340 Ok(decoded) => {
341 out.write_all(&decoded)?;
342 Ok(())
343 }
344 Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
345 }
346}
347
348fn strip_non_base64(data: &[u8]) -> Vec<u8> {
350 data.iter()
351 .copied()
352 .filter(|&b| is_base64_char(b))
353 .collect()
354}
355
/// True for bytes of the standard base64 alphabet (`A-Z`, `a-z`, `0-9`, `+`, `/`)
/// and the `=` padding character.
#[inline]
fn is_base64_char(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/' | b'=')
}
361
/// True for ASCII whitespace: space, tab, LF, VT (0x0b), FF (0x0c) and CR.
///
/// `u8::is_ascii_whitespace` covers all of these except vertical tab, which is
/// added explicitly.
#[inline]
fn is_whitespace(b: u8) -> bool {
    b.is_ascii_whitespace() || b == 0x0b
}
367
368pub fn encode_stream(
372 reader: &mut impl Read,
373 wrap_col: usize,
374 writer: &mut impl Write,
375) -> io::Result<()> {
376 let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];
377
378 let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
379 let mut encode_buf = vec![0u8; encode_buf_size];
380
381 if wrap_col == 0 {
382 loop {
384 let n = read_full(reader, &mut buf)?;
385 if n == 0 {
386 break;
387 }
388 let enc_len = BASE64_ENGINE.encoded_length(n);
389 let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
390 writer.write_all(encoded)?;
391 }
392 } else {
393 let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
396 let mut wrap_buf = vec![0u8; max_wrapped];
397 let mut col = 0usize;
398
399 loop {
400 let n = read_full(reader, &mut buf)?;
401 if n == 0 {
402 break;
403 }
404 let enc_len = BASE64_ENGINE.encoded_length(n);
405 let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
406
407 let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
409 writer.write_all(&wrap_buf[..wp])?;
410 }
411
412 if col > 0 {
413 writer.write_all(b"\n")?;
414 }
415 }
416
417 Ok(())
418}
419
/// Append `data` to `wrap_buf` as the continuation of a wrapped stream whose
/// current line already holds `*col` characters.
///
/// A `\n` is emitted each time the line reaches `wrap_col` characters. `*col` is
/// updated to the column after the last byte written, and the function returns
/// the number of bytes written to `wrap_buf`. It expects `0 < wrap_col` and
/// `*col < wrap_col` on entry (it re-establishes the latter itself). The caller
/// must size `wrap_buf` for `data` plus the inserted newlines.
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut rest = data;
    let mut written = 0;

    while !rest.is_empty() {
        // Fill the current line as far as the input allows.
        let take = (wrap_col - *col).min(rest.len());
        let (head, tail) = rest.split_at(take);
        wrap_buf[written..written + take].copy_from_slice(head);
        written += take;
        *col += take;

        if *col == wrap_col {
            wrap_buf[written] = b'\n';
            written += 1;
            *col = 0;
        }
        rest = tail;
    }

    written
}
459
460pub fn decode_stream(
464 reader: &mut impl Read,
465 ignore_garbage: bool,
466 writer: &mut impl Write,
467) -> io::Result<()> {
468 const READ_CHUNK: usize = 4 * 1024 * 1024;
469 let mut buf = vec![0u8; READ_CHUNK];
470 let mut clean = Vec::with_capacity(READ_CHUNK);
471 let mut carry: Vec<u8> = Vec::with_capacity(4);
472
473 loop {
474 let n = read_full(reader, &mut buf)?;
475 if n == 0 {
476 break;
477 }
478
479 clean.clear();
481 clean.extend_from_slice(&carry);
482 carry.clear();
483
484 let chunk = &buf[..n];
485 if ignore_garbage {
486 clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
487 } else {
488 let mut last = 0;
490 for pos in memchr::memchr_iter(b'\n', chunk) {
491 if pos > last {
492 clean.extend_from_slice(&chunk[last..pos]);
493 }
494 last = pos + 1;
495 }
496 if last < n {
497 clean.extend_from_slice(&chunk[last..]);
498 }
499 if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
501 clean.retain(|&b| !is_whitespace(b));
502 }
503 }
504
505 let is_last = n < READ_CHUNK;
506
507 if is_last {
508 decode_owned_clean(&mut clean, writer)?;
510 } else {
511 let decode_len = (clean.len() / 4) * 4;
513 if decode_len < clean.len() {
514 carry.extend_from_slice(&clean[decode_len..]);
515 }
516 if decode_len > 0 {
517 clean.truncate(decode_len);
518 decode_owned_clean(&mut clean, writer)?;
519 }
520 }
521 }
522
523 if !carry.is_empty() {
525 decode_owned_clean(&mut carry, writer)?;
526 }
527
528 Ok(())
529}
530
/// Read from `reader` until `buf` is full or the reader reports end of input.
///
/// `Interrupted` errors are retried. Any other error is returned. The result is
/// the number of bytes placed at the start of `buf`. It is less than `buf.len()`
/// only if a `read` call returned 0.
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0usize;
    loop {
        let rest = &mut buf[filled..];
        if rest.is_empty() {
            return Ok(filled);
        }
        match reader.read(rest) {
            Ok(0) => return Ok(filled),
            Ok(got) => filled += got,
            Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    }
}