base64_stream/
from_base64_reader.rs1use std::{
2 io::{self, ErrorKind, Read},
3 ptr::{copy, copy_nonoverlapping},
4};
5
6use base64::{
7 engine::{general_purpose::STANDARD, GeneralPurpose},
8 DecodeSliceError, Engine,
9};
10use generic_array::{
11 typenum::{IsGreaterOrEqual, True, U4, U4096},
12 ArrayLength, GenericArray,
13};
14
15#[derive(Educe)]
17#[educe(Debug)]
18pub struct FromBase64Reader<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True> = U4096> {
19 #[educe(Debug(ignore))]
20 inner: R,
21 buf: GenericArray<u8, N>,
22 buf_length: usize,
23 buf_offset: usize,
24 temp: [u8; 2],
25 temp_length: usize,
26 #[educe(Debug(ignore))]
27 engine: &'static GeneralPurpose,
28}
29
30impl<R: Read> FromBase64Reader<R> {
31 #[inline]
32 pub fn new(reader: R) -> FromBase64Reader<R> {
33 Self::new2(reader)
34 }
35}
36
37impl<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> FromBase64Reader<R, N> {
38 #[inline]
39 pub fn new2(reader: R) -> FromBase64Reader<R, N> {
40 FromBase64Reader {
41 inner: reader,
42 buf: GenericArray::default(),
43 buf_length: 0,
44 buf_offset: 0,
45 temp: [0; 2],
46 temp_length: 0,
47 engine: &STANDARD,
48 }
49 }
50}
51
52impl<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> FromBase64Reader<R, N> {
53 fn buf_left_shift(&mut self, distance: usize) {
54 debug_assert!(self.buf_length >= distance);
55
56 self.buf_offset += distance;
57
58 if self.buf_offset >= N::USIZE - 4 {
59 unsafe {
60 copy(
61 self.buf.as_ptr().add(self.buf_offset),
62 self.buf.as_mut_ptr(),
63 self.buf_length,
64 );
65 }
66
67 self.buf_offset = 0;
68 }
69
70 self.buf_length -= distance;
71 }
72
73 #[inline]
74 fn drain_temp<'a>(&mut self, buf: &'a mut [u8]) -> &'a mut [u8] {
75 debug_assert!(self.temp_length > 0);
76 debug_assert!(!buf.is_empty());
77
78 let drain_length = buf.len().min(self.temp_length);
79
80 unsafe {
81 copy_nonoverlapping(self.temp.as_ptr(), buf.as_mut_ptr(), drain_length);
82 }
83
84 self.temp_length -= drain_length;
85
86 unsafe {
87 copy(
88 self.temp.as_ptr().add(self.temp_length),
89 self.temp.as_mut_ptr(),
90 self.temp_length,
91 );
92 }
93
94 &mut buf[drain_length..]
95 }
96
97 #[inline]
98 fn drain_block<'a>(&mut self, mut buf: &'a mut [u8]) -> Result<&'a mut [u8], DecodeSliceError> {
99 debug_assert!(self.buf_length > 0);
100 debug_assert!(self.temp_length == 0);
101 debug_assert!(!buf.is_empty());
102
103 let drain_length = self.buf_length.min(4);
104
105 let mut b = [0; 3];
106
107 let decode_length = self
108 .engine
109 .decode_slice(&self.buf[self.buf_offset..(self.buf_offset + drain_length)], &mut b)?;
110
111 self.buf_left_shift(drain_length);
112
113 let buf_length = buf.len();
114
115 if buf_length >= decode_length {
116 unsafe {
117 copy_nonoverlapping(b.as_ptr(), buf.as_mut_ptr(), decode_length);
118 }
119
120 buf = &mut buf[decode_length..];
121 } else {
122 unsafe {
123 copy_nonoverlapping(b.as_ptr(), buf.as_mut_ptr(), buf_length);
124 }
125
126 buf = &mut buf[buf_length..];
127
128 self.temp_length = decode_length - buf_length;
129
130 unsafe {
131 copy_nonoverlapping(
132 b.as_ptr().add(buf_length),
133 self.temp.as_mut_ptr(),
134 self.temp_length,
135 );
136 }
137 }
138
139 Ok(buf)
140 }
141
142 fn drain<'a>(&mut self, mut buf: &'a mut [u8]) -> Result<&'a mut [u8], DecodeSliceError> {
143 if buf.is_empty() {
144 return Ok(buf);
145 }
146
147 if self.temp_length > 0 {
148 buf = self.drain_temp(buf);
149 }
150
151 debug_assert!(self.buf_length >= 4);
152
153 let buf_length = buf.len();
154
155 if buf_length >= 3 {
156 debug_assert!(self.temp_length == 0);
157
158 let actual_max_read_size = (buf_length / 3) << 2; let max_available_self_buf_length = self.buf_length & !0b11;
160
161 let drain_length = max_available_self_buf_length.min(actual_max_read_size);
162
163 let decode_length = self
164 .engine
165 .decode_slice(&self.buf[self.buf_offset..(self.buf_offset + drain_length)], buf)?;
166
167 buf = &mut buf[decode_length..];
168
169 self.buf_left_shift(drain_length);
170 }
171
172 if !buf.is_empty() && self.buf_length >= 4 {
173 self.drain_block(buf)
174 } else {
175 Ok(buf)
176 }
177 }
178
179 #[inline]
180 fn drain_end<'a>(&mut self, mut buf: &'a mut [u8]) -> Result<&'a mut [u8], DecodeSliceError> {
181 if buf.is_empty() {
182 return Ok(buf);
183 }
184
185 if self.temp_length > 0 {
186 buf = self.drain_temp(buf);
187 }
188
189 if !buf.is_empty() && self.buf_length > 0 {
190 self.drain_block(buf)
191 } else {
192 Ok(buf)
193 }
194 }
195}
196
197impl<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> Read
198 for FromBase64Reader<R, N>
199{
200 fn read(&mut self, mut buf: &mut [u8]) -> Result<usize, io::Error> {
201 let original_buf_length = buf.len();
202
203 while self.buf_length < 4 {
204 match self.inner.read(&mut self.buf[(self.buf_offset + self.buf_length)..]) {
205 Ok(0) => {
206 buf =
207 self.drain_end(buf).map_err(|err| io::Error::new(ErrorKind::Other, err))?;
208
209 return Ok(original_buf_length - buf.len());
210 },
211 Ok(c) => self.buf_length += c,
212 Err(ref e) if e.kind() == ErrorKind::Interrupted => {},
213 Err(e) => return Err(e),
214 }
215 }
216
217 buf = self.drain(buf).map_err(|err| io::Error::new(ErrorKind::Other, err))?;
218
219 Ok(original_buf_length - buf.len())
220 }
221}
222
223impl<R: Read> From<R> for FromBase64Reader<R> {
224 #[inline]
225 fn from(reader: R) -> Self {
226 FromBase64Reader::new(reader)
227 }
228}