base64_stream/
to_base64_reader.rs

1use std::{
2    intrinsics::{copy, copy_nonoverlapping},
3    io::{self, ErrorKind, Read},
4};
5
6use base64::{
7    engine::{general_purpose::STANDARD, GeneralPurpose},
8    Engine,
9};
10use generic_array::{
11    typenum::{IsGreaterOrEqual, True, U4, U4096},
12    ArrayLength, GenericArray,
13};
14
/// A reader adapter that pulls raw bytes from an inner [`Read`] and yields them as base64 text.
///
/// `N` is the capacity (in bytes) of the internal raw-byte buffer; the bound requires it to be
/// at least 4, and it defaults to `U4096`.
#[derive(Educe)]
#[educe(Debug)]
pub struct ToBase64Reader<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True> = U4096> {
    /// The wrapped reader that supplies the raw (not yet encoded) bytes.
    #[educe(Debug(ignore))]
    inner:       R,
    /// Raw bytes read from `inner` that have not been encoded yet. The pending region is
    /// `buf[buf_offset..buf_offset + buf_length]`; new input is appended right after it.
    buf:         GenericArray<u8, N>,
    /// Number of pending (not yet encoded) raw bytes in `buf`.
    buf_length:  usize,
    /// Index in `buf` where the pending raw bytes start.
    buf_offset:  usize,
    /// Encoded characters that were already produced but did not fit into the caller's
    /// output buffer; they are handed out first on the next read.
    temp:        [u8; 3],
    /// Number of valid characters at the front of `temp`.
    temp_length: usize,
    /// The base64 engine used to encode `buf`.
    #[educe(Debug(ignore))]
    engine:      &'static GeneralPurpose,
}
29
30impl<R: Read> ToBase64Reader<R> {
31    #[inline]
32    pub fn new(reader: R) -> ToBase64Reader<R> {
33        Self::new2(reader)
34    }
35}
36
37impl<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> ToBase64Reader<R, N> {
38    #[inline]
39    pub fn new2(reader: R) -> ToBase64Reader<R, N> {
40        ToBase64Reader {
41            inner:       reader,
42            buf:         GenericArray::default(),
43            buf_length:  0,
44            buf_offset:  0,
45            temp:        [0; 3],
46            temp_length: 0,
47            engine:      &STANDARD,
48        }
49    }
50}
51
52impl<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> ToBase64Reader<R, N> {
53    fn buf_left_shift(&mut self, distance: usize) {
54        debug_assert!(self.buf_length >= distance);
55
56        self.buf_offset += distance;
57
58        if self.buf_offset >= N::USIZE - 4 {
59            unsafe {
60                copy(
61                    self.buf.as_ptr().add(self.buf_offset),
62                    self.buf.as_mut_ptr(),
63                    self.buf_length,
64                );
65            }
66
67            self.buf_offset = 0;
68        }
69
70        self.buf_length -= distance;
71    }
72
73    #[inline]
74    fn drain_temp<'a>(&mut self, buf: &'a mut [u8]) -> &'a mut [u8] {
75        debug_assert!(self.temp_length > 0);
76        debug_assert!(!buf.is_empty());
77
78        let drain_length = buf.len().min(self.temp_length);
79
80        unsafe {
81            copy_nonoverlapping(self.temp.as_ptr(), buf.as_mut_ptr(), drain_length);
82        }
83
84        self.temp_length -= drain_length;
85
86        unsafe {
87            copy(
88                self.temp.as_ptr().add(self.temp_length),
89                self.temp.as_mut_ptr(),
90                self.temp_length,
91            );
92        }
93
94        &mut buf[drain_length..]
95    }
96
97    #[inline]
98    fn drain_block<'a>(&mut self, mut buf: &'a mut [u8]) -> &'a mut [u8] {
99        debug_assert!(self.buf_length > 0);
100        debug_assert!(self.temp_length == 0);
101        debug_assert!(!buf.is_empty());
102
103        let drain_length = self.buf_length.min(3);
104
105        let mut b = [0; 4];
106
107        let encode_length = self
108            .engine
109            .encode_slice(&self.buf[self.buf_offset..(self.buf_offset + drain_length)], &mut b)
110            .unwrap();
111
112        self.buf_left_shift(drain_length);
113
114        let buf_length = buf.len();
115
116        if buf_length >= encode_length {
117            unsafe {
118                copy_nonoverlapping(b.as_ptr(), buf.as_mut_ptr(), encode_length);
119            }
120
121            buf = &mut buf[encode_length..];
122        } else {
123            unsafe {
124                copy_nonoverlapping(b.as_ptr(), buf.as_mut_ptr(), buf_length);
125            }
126
127            buf = &mut buf[buf_length..];
128
129            self.temp_length = encode_length - buf_length;
130
131            unsafe {
132                copy_nonoverlapping(
133                    b.as_ptr().add(buf_length),
134                    self.temp.as_mut_ptr(),
135                    self.temp_length,
136                );
137            }
138        }
139
140        buf
141    }
142
143    fn drain<'a>(&mut self, mut buf: &'a mut [u8]) -> &'a mut [u8] {
144        if buf.is_empty() {
145            return buf;
146        }
147
148        if self.temp_length > 0 {
149            buf = self.drain_temp(buf);
150        }
151
152        debug_assert!(self.buf_length >= 3);
153
154        let buf_length = buf.len();
155
156        if buf_length >= 4 {
157            debug_assert!(self.temp_length == 0);
158
159            let actual_max_read_size = (buf_length >> 2) * 3; // (buf_length / 4) * 3
160            let max_available_self_buf_length = self.buf_length - (self.buf_length % 3);
161
162            let drain_length = max_available_self_buf_length.min(actual_max_read_size);
163
164            let encode_length = self
165                .engine
166                .encode_slice(&self.buf[self.buf_offset..(self.buf_offset + drain_length)], buf)
167                .unwrap();
168
169            buf = &mut buf[encode_length..];
170
171            self.buf_left_shift(drain_length);
172        }
173
174        if !buf.is_empty() && self.buf_length >= 3 {
175            self.drain_block(buf)
176        } else {
177            buf
178        }
179    }
180
181    #[inline]
182    fn drain_end<'a>(&mut self, mut buf: &'a mut [u8]) -> &'a mut [u8] {
183        if buf.is_empty() {
184            return buf;
185        }
186
187        if self.temp_length > 0 {
188            buf = self.drain_temp(buf);
189        }
190
191        if !buf.is_empty() && self.buf_length > 0 {
192            self.drain_block(buf)
193        } else {
194            buf
195        }
196    }
197}
198
199impl<R: Read, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> Read for ToBase64Reader<R, N> {
200    fn read(&mut self, mut buf: &mut [u8]) -> Result<usize, io::Error> {
201        let original_buf_length = buf.len();
202
203        while self.buf_length < 3 {
204            match self.inner.read(&mut self.buf[(self.buf_offset + self.buf_length)..]) {
205                Ok(0) => {
206                    buf = self.drain_end(buf);
207
208                    return Ok(original_buf_length - buf.len());
209                },
210                Ok(c) => self.buf_length += c,
211                Err(ref e) if e.kind() == ErrorKind::Interrupted => {},
212                Err(e) => return Err(e),
213            }
214        }
215
216        buf = self.drain(buf);
217
218        Ok(original_buf_length - buf.len())
219    }
220}
221
222impl<R: Read> From<R> for ToBase64Reader<R> {
223    #[inline]
224    fn from(reader: R) -> Self {
225        ToBase64Reader::new(reader)
226    }
227}