use std::io::{self, Read, Write};

use base64_simd::AsOut;
use rayon::prelude::*;

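/// Base64 engine shared by all paths: the standard alphabet with `=` padding.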
const BASE64_ENGINE: &base64_simd::Base64 = &base64_simd::STANDARD;

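/// Read-buffer size for streaming encodes, rounded down to a multiple of 3 so
/// that only the final chunk can produce `=` padding.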
const STREAM_ENCODE_CHUNK: usize = 4 * 1024 * 1024 - (4 * 1024 * 1024 % 3);

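/// Chunk size for sequential, unwrapped in-memory encodes, also a multiple of 3.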
const NOWRAP_CHUNK: usize = 2 * 1024 * 1024 - (2 * 1024 * 1024 % 3);

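/// Inputs at least this large are split across the rayon thread pool.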
const PARALLEL_ENCODE_THRESHOLD: usize = 32 * 1024 * 1024;

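/// Encodes `data` as base64 into `out`, wrapping output lines at `wrap_col`
/// columns; `wrap_col == 0` disables wrapping.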
pub fn encode_to_writer(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if wrap_col == 0 {
        return encode_no_wrap(data, out);
    }

    encode_wrapped(data, wrap_col, out)
}

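/// Encodes without line wrapping. Inputs past the parallel threshold are split
/// across the rayon pool on boundaries that are multiples of 3, so only the
/// final chunk can carry `=` padding; smaller inputs reuse one scratch buffer.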
fn encode_no_wrap(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    if data.len() >= PARALLEL_ENCODE_THRESHOLD {
        let num_threads = rayon::current_num_threads().max(1);
        let raw_chunk = (data.len() + num_threads - 1) / num_threads;
        let chunk_size = ((raw_chunk + 2) / 3) * 3;

        let encoded_chunks: Vec<Vec<u8>> = data
            .par_chunks(chunk_size)
            .map(|chunk| {
                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
                let mut buf = vec![0u8; enc_len];
                let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
                let len = encoded.len();
                buf.truncate(len);
                buf
            })
            .collect();

        for chunk in &encoded_chunks {
            out.write_all(chunk)?;
        }
        return Ok(());
    }

    let actual_chunk = NOWRAP_CHUNK.min(data.len());
    let enc_max = BASE64_ENGINE.encoded_length(actual_chunk);
    let mut buf = vec![0u8; enc_max];

    for chunk in data.chunks(NOWRAP_CHUNK) {
        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
        let encoded = BASE64_ENGINE.encode(chunk, buf[..enc_len].as_out());
        out.write_all(encoded)?;
    }
    Ok(())
}

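/// Encodes with line wrapping at `wrap_col` columns. Chunked processing is
/// only used when each chunk encodes to whole output lines (`wrap_col` a
/// multiple of 4); otherwise the input is encoded in a single pass.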
fn encode_wrapped(data: &[u8], wrap_col: usize, out: &mut impl Write) -> io::Result<()> {
    let bytes_per_line = wrap_col * 3 / 4;

    // Parallel wrapping is only correct when every chunk encodes to whole
    // output lines: wrap_col must be a multiple of 4 so that bytes_per_line is
    // a multiple of 3 and encodes to exactly wrap_col characters.
    if data.len() >= PARALLEL_ENCODE_THRESHOLD && bytes_per_line > 0 && wrap_col % 4 == 0 {
        let num_threads = rayon::current_num_threads().max(1);
        let lines_per_thread = ((data.len() / bytes_per_line) + num_threads - 1) / num_threads;
        let chunk_input = (lines_per_thread * bytes_per_line).max(bytes_per_line);

        let wrapped_chunks: Vec<Vec<u8>> = data
            .par_chunks(chunk_input)
            .map(|chunk| {
                let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
                let mut encode_buf = vec![0u8; enc_len];
                let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());

                let line_out = wrap_col + 1;
                let max_lines = (encoded.len() + wrap_col - 1) / wrap_col + 1;
                let mut wrap_buf = vec![0u8; max_lines * line_out];
                let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
                wrap_buf.truncate(wp);
                wrap_buf
            })
            .collect();

        for chunk in &wrapped_chunks {
            out.write_all(chunk)?;
        }
        return Ok(());
    }

    // Sequential path. The same alignment constraint applies: when wrap_col is
    // not a multiple of 4, chunking would insert padding and stray newlines
    // mid-stream, so encode the whole input in one pass instead.
    let chunk_input = if bytes_per_line > 0 && wrap_col % 4 == 0 {
        ((2 * 1024 * 1024) / bytes_per_line).max(1) * bytes_per_line
    } else {
        data.len()
    };
    let effective_chunk = chunk_input.min(data.len());
    let chunk_encoded_max = BASE64_ENGINE.encoded_length(effective_chunk);
    let mut encode_buf = vec![0u8; chunk_encoded_max];
    // Worst case: one newline per wrap_col output bytes, plus one for a
    // trailing partial line.
    let wrapped_max = chunk_encoded_max + chunk_encoded_max / wrap_col + 2;
    let mut wrap_buf = vec![0u8; wrapped_max];

    for chunk in data.chunks(chunk_input.max(1)) {
        let enc_len = BASE64_ENGINE.encoded_length(chunk.len());
        let encoded = BASE64_ENGINE.encode(chunk, encode_buf[..enc_len].as_out());
        let wp = wrap_encoded(encoded, wrap_col, &mut wrap_buf);
        out.write_all(&wrap_buf[..wp])?;
    }

    Ok(())
}

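/// Copies `encoded` into `wrap_buf`, inserting a newline after every
/// `wrap_col` bytes and after a trailing partial line. Returns the number of
/// bytes written; callers must size `wrap_buf` for the fully wrapped output.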
#[inline]
fn wrap_encoded(encoded: &[u8], wrap_col: usize, wrap_buf: &mut [u8]) -> usize {
    let line_out = wrap_col + 1;
    let mut rp = 0;
    let mut wp = 0;

    // Unrolled fast path: copy four full lines (plus newlines) per iteration.
    // Safety: callers size `wrap_buf` to hold the fully wrapped output.
    while rp + 4 * wrap_col <= encoded.len() {
        unsafe {
            let src = encoded.as_ptr().add(rp);
            let dst = wrap_buf.as_mut_ptr().add(wp);

            std::ptr::copy_nonoverlapping(src, dst, wrap_col);
            *dst.add(wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(wrap_col), dst.add(line_out), wrap_col);
            *dst.add(line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(2 * wrap_col), dst.add(2 * line_out), wrap_col);
            *dst.add(2 * line_out + wrap_col) = b'\n';

            std::ptr::copy_nonoverlapping(src.add(3 * wrap_col), dst.add(3 * line_out), wrap_col);
            *dst.add(3 * line_out + wrap_col) = b'\n';
        }
        rp += 4 * wrap_col;
        wp += 4 * line_out;
    }

    while rp + wrap_col <= encoded.len() {
        wrap_buf[wp..wp + wrap_col].copy_from_slice(&encoded[rp..rp + wrap_col]);
        wp += wrap_col;
        wrap_buf[wp] = b'\n';
        wp += 1;
        rp += wrap_col;
    }

    if rp < encoded.len() {
        let remaining = encoded.len() - rp;
        wrap_buf[wp..wp + remaining].copy_from_slice(&encoded[rp..rp + remaining]);
        wp += remaining;
        wrap_buf[wp] = b'\n';
        wp += 1;
    }

    wp
}

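/// Decodes base64 `data` into `out`. With `ignore_garbage`, all bytes outside
/// the base64 alphabet are dropped first; otherwise only whitespace is
/// stripped before decoding.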
pub fn decode_to_writer(data: &[u8], ignore_garbage: bool, out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if ignore_garbage {
        let mut cleaned = strip_non_base64(data);
        return decode_owned_clean(&mut cleaned, out);
    }

    decode_stripping_whitespace(data, out)
}

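/// Like [`decode_to_writer`], but takes a mutable buffer so stripping and
/// decoding happen in place without an extra allocation.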
pub fn decode_owned(
    data: &mut Vec<u8>,
    ignore_garbage: bool,
    out: &mut impl Write,
) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if ignore_garbage {
        data.retain(|&b| is_base64_char(b));
    } else {
        strip_whitespace_inplace(data);
    }

    decode_owned_clean(data, out)
}

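/// Removes whitespace from `data` in place. Newlines (the common case) are
/// compacted between memchr hits with `copy_within`; a `retain` pass runs only
/// if other whitespace remains afterwards.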
fn strip_whitespace_inplace(data: &mut Vec<u8>) {
    let positions: Vec<usize> = memchr::memchr_iter(b'\n', data.as_slice()).collect();

    if positions.is_empty() {
        if data.iter().any(|&b| is_whitespace(b)) {
            data.retain(|&b| !is_whitespace(b));
        }
        return;
    }

    let mut wp = 0;
    let mut rp = 0;

    for &pos in &positions {
        if pos > rp {
            let len = pos - rp;
            data.copy_within(rp..pos, wp);
            wp += len;
        }
        rp = pos + 1;
    }

    let data_len = data.len();
    if rp < data_len {
        let len = data_len - rp;
        data.copy_within(rp..data_len, wp);
        wp += len;
    }

    data.truncate(wp);

    if data.iter().any(|&b| is_whitespace(b)) {
        data.retain(|&b| !is_whitespace(b));
    }
}

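/// Decodes borrowed input, copying into a whitespace-free scratch buffer only
/// when whitespace is actually present.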
fn decode_stripping_whitespace(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    if memchr::memchr2(b'\n', b'\r', data).is_none()
        && !data.iter().any(|&b| b == b' ' || b == b'\t')
    {
        if data.len() >= PARALLEL_ENCODE_THRESHOLD {
            return decode_parallel(data, out);
        }
        return decode_borrowed_clean(out, data);
    }

    let mut clean = Vec::with_capacity(data.len());
    let mut last = 0;
    for pos in memchr::memchr_iter(b'\n', data) {
        if pos > last {
            clean.extend_from_slice(&data[last..pos]);
        }
        last = pos + 1;
    }
    if last < data.len() {
        clean.extend_from_slice(&data[last..]);
    }

    if clean.iter().any(|&b| is_whitespace(b)) {
        clean.retain(|&b| !is_whitespace(b));
    }

    if clean.len() >= PARALLEL_ENCODE_THRESHOLD {
        return decode_parallel(&clean, out);
    }

    decode_owned_clean(&mut clean, out)
}

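/// Decodes whitespace-free input on the rayon pool. Chunks are cut at
/// multiples of 4 characters, so each chunk is independently valid base64.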
fn decode_parallel(data: &[u8], out: &mut impl Write) -> io::Result<()> {
    let num_threads = rayon::current_num_threads().max(1);
    let raw_chunk = (data.len() + num_threads - 1) / num_threads;
    let chunk_size = ((raw_chunk + 3) / 4) * 4;

    let decoded_chunks: Vec<Result<Vec<u8>, _>> = data
        .par_chunks(chunk_size)
        .map(|chunk| match BASE64_ENGINE.decode_to_vec(chunk) {
            Ok(decoded) => Ok(decoded),
            Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
        })
        .collect();

    for chunk_result in decoded_chunks {
        let chunk = chunk_result?;
        out.write_all(&chunk)?;
    }

    Ok(())
}

fn decode_owned_clean(data: &mut [u8], out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    match BASE64_ENGINE.decode_inplace(data) {
        Ok(decoded) => out.write_all(decoded),
        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
    }
}

fn decode_borrowed_clean(out: &mut impl Write, data: &[u8]) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }
    match BASE64_ENGINE.decode_to_vec(data) {
        Ok(decoded) => {
            out.write_all(&decoded)?;
            Ok(())
        }
        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "invalid input")),
    }
}

fn strip_non_base64(data: &[u8]) -> Vec<u8> {
    data.iter()
        .copied()
        .filter(|&b| is_base64_char(b))
        .collect()
}

#[inline]
fn is_base64_char(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'+' || b == b'/' || b == b'='
}

#[inline]
fn is_whitespace(b: u8) -> bool {
    matches!(b, b' ' | b'\t' | b'\n' | b'\r' | 0x0b | 0x0c)
}

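/// Streaming encoder: reads `STREAM_ENCODE_CHUNK`-sized blocks, encodes each,
/// and wraps across block boundaries by tracking the current output column.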
pub fn encode_stream(
    reader: &mut impl Read,
    wrap_col: usize,
    writer: &mut impl Write,
) -> io::Result<()> {
    let mut buf = vec![0u8; STREAM_ENCODE_CHUNK];

    let encode_buf_size = BASE64_ENGINE.encoded_length(STREAM_ENCODE_CHUNK);
    let mut encode_buf = vec![0u8; encode_buf_size];

    if wrap_col == 0 {
        loop {
            let n = read_full(reader, &mut buf)?;
            if n == 0 {
                break;
            }
            let enc_len = BASE64_ENGINE.encoded_length(n);
            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());
            writer.write_all(encoded)?;
        }
    } else {
        let max_wrapped = encode_buf_size + (encode_buf_size / wrap_col + 2);
        let mut wrap_buf = vec![0u8; max_wrapped];
        let mut col = 0usize;

        loop {
            let n = read_full(reader, &mut buf)?;
            if n == 0 {
                break;
            }
            let enc_len = BASE64_ENGINE.encoded_length(n);
            let encoded = BASE64_ENGINE.encode(&buf[..n], encode_buf[..enc_len].as_out());

            let wp = build_wrapped_output(encoded, wrap_col, &mut col, &mut wrap_buf);
            writer.write_all(&wrap_buf[..wp])?;
        }

        if col > 0 {
            writer.write_all(b"\n")?;
        }
    }

    Ok(())
}

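/// Wraps `data` into `wrap_buf`, continuing from output column `*col` and
/// updating it so wrapping stays correct across streamed blocks. Returns the
/// number of bytes written.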
#[inline]
fn build_wrapped_output(
    data: &[u8],
    wrap_col: usize,
    col: &mut usize,
    wrap_buf: &mut [u8],
) -> usize {
    let mut rp = 0;
    let mut wp = 0;

    while rp < data.len() {
        let space = wrap_col - *col;
        let avail = data.len() - rp;

        if avail <= space {
            wrap_buf[wp..wp + avail].copy_from_slice(&data[rp..rp + avail]);
            wp += avail;
            *col += avail;
            if *col == wrap_col {
                wrap_buf[wp] = b'\n';
                wp += 1;
                *col = 0;
            }
            break;
        } else {
            wrap_buf[wp..wp + space].copy_from_slice(&data[rp..rp + space]);
            wp += space;
            wrap_buf[wp] = b'\n';
            wp += 1;
            rp += space;
            *col = 0;
        }
    }

    wp
}

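/// Streaming decoder: per block, strips newlines (or everything outside the
/// alphabet with `ignore_garbage`), decodes the longest prefix that is a
/// multiple of 4 characters, and carries the remainder into the next block.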
pub fn decode_stream(
    reader: &mut impl Read,
    ignore_garbage: bool,
    writer: &mut impl Write,
) -> io::Result<()> {
    const READ_CHUNK: usize = 4 * 1024 * 1024;
    let mut buf = vec![0u8; READ_CHUNK];
    let mut clean = Vec::with_capacity(READ_CHUNK);
    let mut carry: Vec<u8> = Vec::with_capacity(4);

    loop {
        let n = read_full(reader, &mut buf)?;
        if n == 0 {
            break;
        }

        clean.clear();
        clean.extend_from_slice(&carry);
        carry.clear();

        let chunk = &buf[..n];
        if ignore_garbage {
            clean.extend(chunk.iter().copied().filter(|&b| is_base64_char(b)));
        } else {
            let mut last = 0;
            for pos in memchr::memchr_iter(b'\n', chunk) {
                if pos > last {
                    clean.extend_from_slice(&chunk[last..pos]);
                }
                last = pos + 1;
            }
            if last < n {
                clean.extend_from_slice(&chunk[last..]);
            }
            if clean.iter().any(|&b| is_whitespace(b) && b != b'\n') {
                clean.retain(|&b| !is_whitespace(b));
            }
        }

        let is_last = n < READ_CHUNK;

        if is_last {
            decode_owned_clean(&mut clean, writer)?;
        } else {
            let decode_len = (clean.len() / 4) * 4;
            if decode_len < clean.len() {
                carry.extend_from_slice(&clean[decode_len..]);
            }
            if decode_len > 0 {
                clean.truncate(decode_len);
                decode_owned_clean(&mut clean, writer)?;
            }
        }
    }

    if !carry.is_empty() {
        decode_owned_clean(&mut carry, writer)?;
    }

    Ok(())
}

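/// Reads until `buf` is full or EOF, retrying on `Interrupted`. A short count
/// therefore always means end of input.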
fn read_full(reader: &mut impl Read, buf: &mut [u8]) -> io::Result<usize> {
    let mut total = 0;
    while total < buf.len() {
        match reader.read(&mut buf[total..]) {
            Ok(0) => break,
            Ok(n) => total += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(total)
}
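
#[cfg(test)]
mod tests {
    // Minimal round-trip sketch added for illustration; it only exercises the
    // in-memory encode/decode paths above and assumes a `Vec<u8>` writer.
    use super::*;

    #[test]
    fn encode_no_wrap_matches_known_value() {
        let mut out = Vec::new();
        encode_to_writer(b"hello world", 0, &mut out).unwrap();
        assert_eq!(out, b"aGVsbG8gd29ybGQ=".to_vec());
    }

    #[test]
    fn wrapped_encode_round_trips() {
        let data: Vec<u8> = (0u8..=255).cycle().take(10_000).collect();

        let mut encoded = Vec::new();
        encode_to_writer(&data, 76, &mut encoded).unwrap();
        // Every output line, including the trailing partial one, stays within
        // the wrap column.
        assert!(encoded.split(|&b| b == b'\n').all(|line| line.len() <= 76));

        let mut decoded = Vec::new();
        decode_to_writer(&encoded, false, &mut decoded).unwrap();
        assert_eq!(decoded, data);
    }
}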