base64_stream/
from_base64_reader.rs

1use std::{
2    io::{self, ErrorKind, Read},
3    ptr::{copy, copy_nonoverlapping},
4};
5
6use base64::{
7    engine::{general_purpose::STANDARD, GeneralPurpose},
8    DecodeSliceError, Engine,
9};
10use generic_array::{
11    typenum::{IsGreaterOrEqual, True, U4, U4096},
12    ArrayLength, GenericArray,
13};
14
15/// Read base64 data and decode them to plain data.
16#[derive(Educe)]
17#[educe(Debug)]
18pub struct FromBase64Reader<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True> = U4096> {
19    #[educe(Debug(ignore))]
20    inner:       R,
21    buf:         GenericArray<u8, N>,
22    buf_length:  usize,
23    buf_offset:  usize,
24    temp:        [u8; 2],
25    temp_length: usize,
26    #[educe(Debug(ignore))]
27    engine:      &'static GeneralPurpose,
28}
29
30impl<R: Read> FromBase64Reader<R> {
31    #[inline]
32    pub fn new(reader: R) -> FromBase64Reader<R> {
33        Self::new2(reader)
34    }
35}
36
37impl<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> FromBase64Reader<R, N> {
38    #[inline]
39    pub fn new2(reader: R) -> FromBase64Reader<R, N> {
40        FromBase64Reader {
41            inner:       reader,
42            buf:         GenericArray::default(),
43            buf_length:  0,
44            buf_offset:  0,
45            temp:        [0; 2],
46            temp_length: 0,
47            engine:      &STANDARD,
48        }
49    }
50}
51
52impl<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> FromBase64Reader<R, N> {
53    fn buf_left_shift(&mut self, distance: usize) {
54        debug_assert!(self.buf_length >= distance);
55
56        self.buf_offset += distance;
57
58        if self.buf_offset >= N::USIZE - 4 {
59            unsafe {
60                copy(
61                    self.buf.as_ptr().add(self.buf_offset),
62                    self.buf.as_mut_ptr(),
63                    self.buf_length,
64                );
65            }
66
67            self.buf_offset = 0;
68        }
69
70        self.buf_length -= distance;
71    }
72
73    #[inline]
74    fn drain_temp<'a>(&mut self, buf: &'a mut [u8]) -> &'a mut [u8] {
75        debug_assert!(self.temp_length > 0);
76        debug_assert!(!buf.is_empty());
77
78        let drain_length = buf.len().min(self.temp_length);
79
80        unsafe {
81            copy_nonoverlapping(self.temp.as_ptr(), buf.as_mut_ptr(), drain_length);
82        }
83
84        self.temp_length -= drain_length;
85
86        unsafe {
87            copy(
88                self.temp.as_ptr().add(self.temp_length),
89                self.temp.as_mut_ptr(),
90                self.temp_length,
91            );
92        }
93
94        &mut buf[drain_length..]
95    }
96
97    #[inline]
98    fn drain_block<'a>(&mut self, mut buf: &'a mut [u8]) -> Result<&'a mut [u8], DecodeSliceError> {
99        debug_assert!(self.buf_length > 0);
100        debug_assert!(self.temp_length == 0);
101        debug_assert!(!buf.is_empty());
102
103        let drain_length = self.buf_length.min(4);
104
105        let mut b = [0; 3];
106
107        let decode_length = self
108            .engine
109            .decode_slice(&self.buf[self.buf_offset..(self.buf_offset + drain_length)], &mut b)?;
110
111        self.buf_left_shift(drain_length);
112
113        let buf_length = buf.len();
114
115        if buf_length >= decode_length {
116            unsafe {
117                copy_nonoverlapping(b.as_ptr(), buf.as_mut_ptr(), decode_length);
118            }
119
120            buf = &mut buf[decode_length..];
121        } else {
122            unsafe {
123                copy_nonoverlapping(b.as_ptr(), buf.as_mut_ptr(), buf_length);
124            }
125
126            buf = &mut buf[buf_length..];
127
128            self.temp_length = decode_length - buf_length;
129
130            unsafe {
131                copy_nonoverlapping(
132                    b.as_ptr().add(buf_length),
133                    self.temp.as_mut_ptr(),
134                    self.temp_length,
135                );
136            }
137        }
138
139        Ok(buf)
140    }
141
142    fn drain<'a>(&mut self, mut buf: &'a mut [u8]) -> Result<&'a mut [u8], DecodeSliceError> {
143        if buf.is_empty() {
144            return Ok(buf);
145        }
146
147        if self.temp_length > 0 {
148            buf = self.drain_temp(buf);
149        }
150
151        debug_assert!(self.buf_length >= 4);
152
153        let buf_length = buf.len();
154
155        if buf_length >= 3 {
156            debug_assert!(self.temp_length == 0);
157
158            let actual_max_read_size = (buf_length / 3) << 2; // (buf_length / 3) * 4
159            let max_available_self_buf_length = self.buf_length & !0b11;
160
161            let drain_length = max_available_self_buf_length.min(actual_max_read_size);
162
163            let decode_length = self
164                .engine
165                .decode_slice(&self.buf[self.buf_offset..(self.buf_offset + drain_length)], buf)?;
166
167            buf = &mut buf[decode_length..];
168
169            self.buf_left_shift(drain_length);
170        }
171
172        if !buf.is_empty() && self.buf_length >= 4 {
173            self.drain_block(buf)
174        } else {
175            Ok(buf)
176        }
177    }
178
179    #[inline]
180    fn drain_end<'a>(&mut self, mut buf: &'a mut [u8]) -> Result<&'a mut [u8], DecodeSliceError> {
181        if buf.is_empty() {
182            return Ok(buf);
183        }
184
185        if self.temp_length > 0 {
186            buf = self.drain_temp(buf);
187        }
188
189        if !buf.is_empty() && self.buf_length > 0 {
190            self.drain_block(buf)
191        } else {
192            Ok(buf)
193        }
194    }
195}
196
197impl<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> Read
198    for FromBase64Reader<R, N>
199{
200    fn read(&mut self, mut buf: &mut [u8]) -> Result<usize, io::Error> {
201        let original_buf_length = buf.len();
202
203        while self.buf_length < 4 {
204            match self.inner.read(&mut self.buf[(self.buf_offset + self.buf_length)..]) {
205                Ok(0) => {
206                    buf =
207                        self.drain_end(buf).map_err(|err| io::Error::new(ErrorKind::Other, err))?;
208
209                    return Ok(original_buf_length - buf.len());
210                },
211                Ok(c) => self.buf_length += c,
212                Err(ref e) if e.kind() == ErrorKind::Interrupted => {},
213                Err(e) => return Err(e),
214            }
215        }
216
217        buf = self.drain(buf).map_err(|err| io::Error::new(ErrorKind::Other, err))?;
218
219        Ok(original_buf_length - buf.len())
220    }
221}
222
223impl<R: Read> From<R> for FromBase64Reader<R> {
224    #[inline]
225    fn from(reader: R) -> Self {
226        FromBase64Reader::new(reader)
227    }
228}