Skip to main content

base64_stream/
from_base64_reader.rs

1use std::{
2    fmt,
3    io::{self, ErrorKind, Read},
4};
5
6use base64::{
7    DecodeSliceError, Engine,
8    engine::{GeneralPurpose, general_purpose::STANDARD},
9};
10
11/// Read base64 data and decode them to plain data.
12pub struct FromBase64Reader<R: Read, const N: usize = 4096> {
13    inner:       R,
14    buf:         [u8; N],
15    buf_length:  usize,
16    buf_offset:  usize,
17    temp:        [u8; 2],
18    temp_length: usize,
19    engine:      &'static GeneralPurpose,
20}
21
22impl<R: Read, const N: usize> fmt::Debug for FromBase64Reader<R, N> {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_struct("FromBase64Reader")
25            .field("buf_length", &self.buf_length)
26            .field("buf_offset", &self.buf_offset)
27            .field("temp", &&self.temp[..self.temp_length])
28            .field("temp_length", &self.temp_length)
29            .finish_non_exhaustive()
30    }
31}
32
33impl<R: Read> FromBase64Reader<R> {
34    #[inline]
35    pub fn new(reader: R) -> FromBase64Reader<R> {
36        Self::new2(reader)
37    }
38}
39
40impl<R: Read, const N: usize> FromBase64Reader<R, N> {
41    #[inline]
42    pub fn new2(reader: R) -> FromBase64Reader<R, N> {
43        const { assert!(N >= 4, "buffer size N must be at least 4") };
44        FromBase64Reader {
45            inner:       reader,
46            buf:         [0u8; N],
47            buf_length:  0,
48            buf_offset:  0,
49            temp:        [0; 2],
50            temp_length: 0,
51            engine:      &STANDARD,
52        }
53    }
54}
55
56impl<R: Read, const N: usize> FromBase64Reader<R, N> {
57    fn buf_left_shift(&mut self, distance: usize) {
58        debug_assert!(self.buf_length >= distance);
59
60        self.buf_offset += distance;
61        self.buf_length -= distance;
62
63        if self.buf_offset >= N - 4 {
64            self.buf.copy_within(self.buf_offset..self.buf_offset + self.buf_length, 0);
65
66            self.buf_offset = 0;
67        }
68    }
69
70    #[inline]
71    fn drain_temp<'a>(&mut self, buf: &'a mut [u8]) -> &'a mut [u8] {
72        debug_assert!(self.temp_length > 0);
73        debug_assert!(!buf.is_empty());
74
75        let drain_length = buf.len().min(self.temp_length);
76
77        buf[..drain_length].copy_from_slice(&self.temp[..drain_length]);
78
79        self.temp_length -= drain_length;
80        self.temp.copy_within(drain_length..drain_length + self.temp_length, 0);
81
82        &mut buf[drain_length..]
83    }
84
85    #[inline]
86    fn drain_block<'a>(&mut self, mut buf: &'a mut [u8]) -> Result<&'a mut [u8], DecodeSliceError> {
87        debug_assert!(self.buf_length > 0);
88        debug_assert!(self.temp_length == 0);
89        debug_assert!(!buf.is_empty());
90
91        let drain_length = self.buf_length.min(4);
92
93        let mut b = [0; 3];
94
95        let decode_length = self
96            .engine
97            .decode_slice(&self.buf[self.buf_offset..(self.buf_offset + drain_length)], &mut b)?;
98
99        self.buf_left_shift(drain_length);
100
101        let buf_length = buf.len();
102
103        if buf_length >= decode_length {
104            buf[..decode_length].copy_from_slice(&b[..decode_length]);
105
106            buf = &mut buf[decode_length..];
107        } else {
108            buf[..buf_length].copy_from_slice(&b[..buf_length]);
109
110            buf = &mut buf[buf_length..];
111
112            self.temp_length = decode_length - buf_length;
113            self.temp[..self.temp_length].copy_from_slice(&b[buf_length..decode_length]);
114        }
115
116        Ok(buf)
117    }
118
119    fn drain<'a>(&mut self, mut buf: &'a mut [u8]) -> Result<&'a mut [u8], DecodeSliceError> {
120        if buf.is_empty() {
121            return Ok(buf);
122        }
123
124        if self.temp_length > 0 {
125            buf = self.drain_temp(buf);
126        }
127
128        debug_assert!(self.buf_length >= 4);
129
130        let buf_length = buf.len();
131
132        if buf_length >= 3 {
133            debug_assert!(self.temp_length == 0);
134
135            let actual_max_read_size = (buf_length / 3) << 2; // (buf_length / 3) * 4
136            let max_available_self_buf_length = self.buf_length & !0b11;
137
138            let drain_length = max_available_self_buf_length.min(actual_max_read_size);
139
140            let decode_length = self
141                .engine
142                .decode_slice(&self.buf[self.buf_offset..(self.buf_offset + drain_length)], buf)?;
143
144            buf = &mut buf[decode_length..];
145
146            self.buf_left_shift(drain_length);
147        }
148
149        if !buf.is_empty() && self.buf_length >= 4 { self.drain_block(buf) } else { Ok(buf) }
150    }
151
152    #[inline]
153    fn drain_end<'a>(&mut self, mut buf: &'a mut [u8]) -> Result<&'a mut [u8], DecodeSliceError> {
154        if buf.is_empty() {
155            return Ok(buf);
156        }
157
158        if self.temp_length > 0 {
159            buf = self.drain_temp(buf);
160        }
161
162        if !buf.is_empty() && self.buf_length > 0 { self.drain_block(buf) } else { Ok(buf) }
163    }
164
165    /// Returns the inner reader, consuming this wrapper.
166    #[inline]
167    pub fn into_inner(self) -> R {
168        self.inner
169    }
170}
171
172impl<R: Read, const N: usize> Read for FromBase64Reader<R, N> {
173    fn read(&mut self, mut buf: &mut [u8]) -> Result<usize, io::Error> {
174        let original_buf_length = buf.len();
175
176        while self.buf_length < 4 {
177            match self.inner.read(&mut self.buf[(self.buf_offset + self.buf_length)..]) {
178                Ok(0) => {
179                    buf = self.drain_end(buf).map_err(io::Error::other)?;
180
181                    return Ok(original_buf_length - buf.len());
182                },
183                Ok(c) => self.buf_length += c,
184                Err(ref e) if e.kind() == ErrorKind::Interrupted => {},
185                Err(e) => return Err(e),
186            }
187        }
188
189        buf = self.drain(buf).map_err(io::Error::other)?;
190
191        Ok(original_buf_length - buf.len())
192    }
193}
194
195impl<R: Read> From<R> for FromBase64Reader<R> {
196    #[inline]
197    fn from(reader: R) -> Self {
198        FromBase64Reader::new(reader)
199    }
200}