base64_stream/to_base64_writer.rs

use std::{
    io::{self, Write},
    ptr::copy_nonoverlapping,
};

use base64::{
    engine::{general_purpose::STANDARD, GeneralPurpose},
    Engine,
};
use educe::Educe;
use generic_array::{
    typenum::{IsGreaterOrEqual, True, U4, U4096},
    ArrayLength, GenericArray,
};
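
/// A writer that Base64-encodes (standard alphabet, with padding) everything
/// written to it and forwards the encoded text to the inner writer. Input is
/// gathered into 3-byte blocks; `N` is the length of the intermediate output
/// buffer, at least `U4` and `U4096` by default.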
#[derive(Educe)]
#[educe(Debug)]
pub struct ToBase64Writer<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True> = U4096> {
    #[educe(Debug(ignore))]
    inner: W,
    buf: [u8; 3],
    buf_length: usize,
    temp: GenericArray<u8, N>,
    #[educe(Debug(ignore))]
    engine: &'static GeneralPurpose,
}

impl<W: Write> ToBase64Writer<W> {
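    /// Creates a `ToBase64Writer` with the default `U4096` output buffer.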
    #[inline]
    pub fn new(writer: W) -> ToBase64Writer<W> {
        Self::new2(writer)
    }
}

impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> ToBase64Writer<W, N> {
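    /// Creates a `ToBase64Writer` with a caller-chosen output buffer length `N`.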
    #[inline]
    pub fn new2(writer: W) -> ToBase64Writer<W, N> {
        ToBase64Writer {
            inner: writer,
            buf: [0; 3],
            buf_length: 0,
            temp: GenericArray::default(),
            engine: &STANDARD,
        }
    }
}

impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> ToBase64Writer<W, N> {
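    /// Encode the 1 to 3 buffered bytes and write the resulting Base64 block to
    /// the inner writer.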
    fn drain_block(&mut self) -> Result<(), io::Error> {
        debug_assert!(self.buf_length > 0);

        let encode_length =
            self.engine.encode_slice(&self.buf[..self.buf_length], &mut self.temp).unwrap();

        self.inner.write_all(&self.temp[..encode_length])?;

        self.buf_length = 0;

        Ok(())
    }
}

impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> Write
    for ToBase64Writer<W, N>
{
    fn write(&mut self, mut buf: &[u8]) -> Result<usize, io::Error> {
        let original_buf_length = buf.len();

        if self.buf_length == 0 {
            // Fast path: encode whole 3-byte blocks straight from `buf`, as many per
            // pass as the `N`-byte output buffer allows (3 input bytes become 4
            // output bytes).
            while buf.len() >= 3 {
                let max_available_buf_length =
                    (buf.len() - (buf.len() % 3)).min((N::USIZE >> 2) * 3);

                let encode_length = self
                    .engine
                    .encode_slice(&buf[..max_available_buf_length], &mut self.temp)
                    .unwrap();

                buf = &buf[max_available_buf_length..];

                self.inner.write_all(&self.temp[..encode_length])?;
            }

            let buf_length = buf.len();

            if buf_length > 0 {
                // Keep the 1 or 2 leftover bytes for the next `write` or `flush`.
                unsafe {
                    copy_nonoverlapping(buf.as_ptr(), self.buf.as_mut_ptr(), buf_length);
                }

                self.buf_length = buf_length;
            }
        } else {
            debug_assert!(self.buf_length < 3);

            // Top up the pending block with bytes from `buf`.
            let r = 3 - self.buf_length;

            let buf_length = buf.len();

            let drain_length = r.min(buf_length);

            unsafe {
                copy_nonoverlapping(
                    buf.as_ptr(),
                    self.buf.as_mut_ptr().add(self.buf_length),
                    drain_length,
                );
            }

            buf = &buf[drain_length..];

            self.buf_length += drain_length;

            if self.buf_length == 3 {
                self.drain_block()?;

                if buf_length > r {
                    // Process whatever is left of `buf` now that the pending block is out.
                    self.write_all(buf)?;
                }
            }
        }

        Ok(original_buf_length)
    }

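    /// Encodes and writes out any buffered bytes. A 1- or 2-byte block is
    /// encoded with `=` padding, so flushing while a partial block is pending
    /// leaves padding in the middle of the output; flush only after all input
    /// has been written.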
    fn flush(&mut self) -> Result<(), io::Error> {
        if self.buf_length > 0 {
            self.drain_block()?;
        }

        Ok(())
    }
}

impl<W: Write> From<W> for ToBase64Writer<W> {
    #[inline]
    fn from(writer: W) -> Self {
        ToBase64Writer::new(writer)
    }
}
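
// A minimal usage sketch: stream a short input through the writer into a
// `&mut Vec<u8>` sink (any `Write` works), flush the trailing partial block,
// and check the standard-alphabet output.
#[cfg(test)]
mod tests {
    use std::io::Write;

    use super::*;

    #[test]
    fn encodes_short_input() {
        let mut output: Vec<u8> = Vec::new();

        {
            let mut writer = ToBase64Writer::new(&mut output);

            writer.write_all(b"Hello").unwrap();

            // "Hel" is encoded on the fast path; "lo" stays buffered until the flush.
            writer.flush().unwrap();
        }

        assert_eq!(output, b"SGVsbG8=".to_vec());
    }
}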