cmail_rpgp/
base64_decoder.rs

1//! # base64 decoder module
2
3use std::io::{self, BufRead, Read};
4
5use base64::engine::{general_purpose::GeneralPurpose, Engine};
6use buffer_redux::{BufReader, Buffer};
7
/// Capacity, in bytes, requested for the buffer that holds not-yet-decoded base64 input.
const BUF_SIZE: usize = 1024;
/// Capacity, in bytes, of the decoded-output buffers: three output bytes for every
/// four input bytes that fit into `BUF_SIZE`.
const BUF_CAPACITY: usize = BUF_SIZE / 4 * 3;
/// The `base64` crate's `STANDARD` general-purpose engine, used for all decoding in this module.
const ENGINE: GeneralPurpose = base64::engine::general_purpose::STANDARD;
11
/// Decodes Base64 from the supplied reader.
///
/// Wraps a reader that yields base64 text and exposes the decoded bytes through
/// the `Read` implementation of this type.
#[derive(Debug)]
pub struct Base64Decoder<R> {
    /// The inner Read instance we are reading bytes from, wrapped in a buffer
    /// that holds the base64 input which has not been decoded yet.
    inner: BufReader<R>,
    /// Leftover decoded output: bytes that were already decoded but did not fit
    /// into the slice passed to `read`; they are handed out first on the next call.
    out: Buffer,
    /// Scratch space that input is decoded into when the caller's slice is too
    /// small to receive everything that is currently buffered.
    out_buffer: [u8; BUF_CAPACITY],
    /// Memorize if we had an error, so we can return it on calls to read again.
    err: Option<io::Error>,
}
23
24impl<R: Read> Base64Decoder<R> {
25    /// Creates a new `Base64Decoder`.
26    pub fn new(input: R) -> Self {
27        Base64Decoder {
28            inner: BufReader::with_capacity(BUF_SIZE, input),
29            out: Buffer::with_capacity(BUF_CAPACITY),
30            out_buffer: [0u8; BUF_CAPACITY],
31            err: None,
32        }
33    }
34
35    pub fn into_inner_with_buffer(self) -> (R, Buffer) {
36        self.inner.into_inner_with_buffer()
37    }
38}
39
40impl<R: Read> Read for Base64Decoder<R> {
41    fn read(&mut self, into: &mut [u8]) -> io::Result<usize> {
42        // take care of leftovers
43        if !self.out.is_empty() {
44            let len = self.out.copy_to_slice(into);
45            return Ok(len);
46        }
47
48        // if we had an error before, return it
49        if let Some(ref err) = self.err {
50            return Err(copy_err(err));
51        }
52
53        // fill our buffer
54        if self.inner.buf_len() < 4 {
55            let b = &mut self.inner;
56
57            if let Err(err) = b.read_into_buf() {
58                self.err = Some(copy_err(&err));
59                return Err(err);
60            }
61        }
62
63        // short circuit empty read
64        if self.inner.buf_len() == 0 {
65            return Ok(0);
66        }
67
68        let nr = self.inner.buf_len() / 4 * 4;
69        let nw = self.inner.buf_len() / 4 * 3;
70
71        let (consumed, written) = if nw > into.len() {
72            let (consumed, nw) =
73                try_decode_engine_slice(&self.inner.buffer()[..nr], &mut self.out_buffer[..]);
74
75            let n = std::cmp::min(nw, into.len());
76            let t = &self.out_buffer[0..nw];
77            let (t1, t2) = t.split_at(n);
78
79            // copy what we have into `into`
80            into[0..n].copy_from_slice(t1);
81            // store the rest
82            self.out.copy_from_slice(t2);
83
84            (consumed, n)
85        } else {
86            try_decode_engine_slice(&self.inner.buffer()[..nr], into)
87        };
88
89        self.inner.consume(consumed);
90
91        Ok(written)
92    }
93}
94
95/// Tries to decode as much of the given slice as possible.
96/// Returns the amount written and consumed.
97fn try_decode_engine_slice<T: ?Sized + AsRef<[u8]>>(
98    input: &T,
99    output: &mut [u8],
100) -> (usize, usize) {
101    let input_bytes = input.as_ref();
102    let mut n = input_bytes.len();
103    while n > 0 {
104        match ENGINE.decode_slice(&input_bytes[..n], output) {
105            Ok(size) => {
106                return (n, size);
107            }
108            Err(_) => {
109                if n % 4 != 0 {
110                    n -= n % 4
111                } else {
112                    n -= 4
113                }
114            }
115        }
116    }
117
118    (0, 0)
119}
120
/// Builds a new `io::Error` with the same kind and display message as `err`.
///
/// `io::Error` does not implement `Clone`, so this is how a remembered error
/// is handed out more than once.
fn copy_err(err: &io::Error) -> io::Error {
    let kind = err.kind();
    let message = err.to_string();
    io::Error::new(kind, message)
}
125
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use rand::{Rng, SeedableRng};
    use rand_xorshift::XorShiftRng;

    use crate::base64_reader::Base64Reader;

    /// Reads `reader` to the end and returns everything it produced.
    fn decode_all<T: Read>(mut reader: T) -> Vec<u8> {
        let mut out = Vec::new();
        reader.read_to_end(&mut out).unwrap();
        out
    }

    /// Encodes random payloads of every length in `0..n`, optionally sprinkles
    /// line breaks into the encoded text, and checks that decoding through a
    /// source buffered with capacity `cap` yields the original payload.
    fn test_roundtrip(cap: usize, n: usize, insert_lines: bool) {
        let mut rng = XorShiftRng::from_seed([
            0x3, 0x8, 0x3, 0xe, 0x3, 0x8, 0x3, 0xe, 0x3, 0x8, 0x3, 0xe, 0x3, 0x8, 0x3, 0xe,
        ]);

        for len in 0..n {
            let data: Vec<u8> = (0..len).map(|_| rng.gen()).collect();
            let mut encoded_data = ENGINE.encode(&data);

            if insert_lines {
                for pos in 0..len {
                    // insert line break with a 1/10 chance
                    if !rng.gen_ratio(1, 10) {
                        continue;
                    }
                    if pos < encoded_data.len() {
                        encoded_data.insert(pos, '\n');
                    } else {
                        encoded_data.push('\n');
                    }
                }
            }

            println!("testing: \n{}", encoded_data);

            let source = std::io::BufReader::with_capacity(cap, encoded_data.as_bytes());
            let out = if insert_lines {
                // Line breaks have to be filtered out before decoding.
                decode_all(Base64Decoder::new(Base64Reader::new(source)))
            } else {
                decode_all(Base64Decoder::new(source))
            };
            assert_eq!(data, out);
        }
    }

    /// Decodes `data`, which starts with the encoding of `"Lorem "` followed by
    /// `=TG9y` and `-----hello`, and checks where the decoder stops and what is
    /// left in its buffer and in the underlying source.
    fn assert_stops_before_tail(data: &str) {
        let mut decoder = Base64Decoder::new(Base64Reader::new(data.as_bytes()));
        let mut res = vec![0u8; 32];

        assert_eq!(decoder.read(&mut res).unwrap(), 6);
        assert_eq!(&res[0..6], b"Lorem ");

        let (base64_reader, buffer) = decoder.into_inner_with_buffer();
        let mut source = base64_reader.into_inner();

        assert_eq!(buffer.buf(), b"=TG9y");
        let mut rest = Vec::new();
        assert_eq!(source.read_to_end(&mut rest).unwrap(), 10);
        assert_eq!(&rest, b"-----hello");
    }

    #[test]
    fn test_base64_decoder_roundtrip_standard_1000_no_newlines() {
        for &cap in &[1, 2, 8, 256, 1024, 8 * 1024] {
            test_roundtrip(cap, 1000, false);
        }
    }

    #[test]
    fn test_base64_decoder_roundtrip_standard_1000_newlines() {
        for &cap in &[1, 2, 8, 256, 1024, 8 * 1024] {
            test_roundtrip(cap, 1000, true);
        }
    }

    #[test]
    fn test_base64_decoder_with_base64_reader() {
        let source = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";

        let data = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVyIGFkaXBpc2NpbmcgZWxpdCwgc2VkIGRvIGVpdXNtb2Qgd\n\
                    GVtcG9yIGluY2lkaWR1bnQgdXQgbGFib3JlIGV0IGRvbG9yZSBtYWduYSBhbGlxdWEuIFV0IGVuaW0gYWQgbWluaW0\n\
                    gdmVuaWFtLCBxdWlz\n\
                    IG5vc3RydWQgZXhlcmNpdGF0aW9uIHVsbGFtY28gbGFib3JpcyBuaXNpIHV0IGFsaXF1aXAgZXggZW\n\
                    EgY29tbW9kbyBjb25zZXF1YXQuIER1aXMgYXV0ZSBpcnVyZSBkb2\n\
                    xvciBpbiByZXByZWhlbmRlcml0IGluIHZvbHVwdGF0ZSB2ZWxpdCBlc3NlIGNpbGx1bSBkb2xvcmUgZXUgZnVnaWF\n\
                    0IG51bGxhIHBhcmlhdHVyLiBFeGNlcHRldXIgc2ludCBvY2NhZWNhdCBjdXBpZGF0YXQgbm9uIHByb2lkZW50LCBzdW50IGluIGN1bHBhIHF1aSBvZm\n\
                    ZpY2lhIGRlc2VydW50IG1vbGxpdCBhbmltIGlkIGVzdCBsYWJvcnVtLg==";

        let mut decoder = Base64Decoder::new(Base64Reader::new(data.as_bytes()));
        let mut res = String::new();

        decoder.read_to_string(&mut res).unwrap();
        assert_eq!(source, res);
    }

    #[test]
    fn test_base64_decoder_with_end_base() {
        assert_stops_before_tail("TG9yZW0g\n=TG9y\n-----hello");
    }

    #[test]
    fn test_base64_decoder_with_end_one_linebreak() {
        assert_stops_before_tail("TG9yZW0g\n=TG9y-----hello");
    }

    #[test]
    fn test_base64_decoder_with_end_no_linebreak() {
        assert_stops_before_tail("TG9yZW0g=TG9y-----hello");
    }
}