base64_stream/
from_base64_writer.rs

1use std::{
2    intrinsics::copy_nonoverlapping,
3    io::{self, ErrorKind, Write},
4};
5
6use base64::{
7    engine::{general_purpose::STANDARD, GeneralPurpose},
8    Engine,
9};
10use generic_array::{
11    typenum::{IsGreaterOrEqual, True, U4, U4096},
12    ArrayLength, GenericArray,
13};
14
15/// Write base64 data and decode them to plain data.
16#[derive(Educe)]
17#[educe(Debug)]
18pub struct FromBase64Writer<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True> = U4096>
19{
20    #[educe(Debug(ignore))]
21    inner:      W,
22    buf:        [u8; 4],
23    buf_length: usize,
24    temp:       GenericArray<u8, N>,
25    #[educe(Debug(ignore))]
26    engine:     &'static GeneralPurpose,
27}
28
29impl<W: Write> FromBase64Writer<W> {
30    #[inline]
31    pub fn new(writer: W) -> FromBase64Writer<W> {
32        Self::new2(writer)
33    }
34}
35
36impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> FromBase64Writer<W, N> {
37    #[inline]
38    pub fn new2(writer: W) -> FromBase64Writer<W, N> {
39        FromBase64Writer {
40            inner:      writer,
41            buf:        [0; 4],
42            buf_length: 0,
43            temp:       GenericArray::default(),
44            engine:     &STANDARD,
45        }
46    }
47}
48
49impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> FromBase64Writer<W, N> {
50    fn drain_block(&mut self) -> Result<(), io::Error> {
51        debug_assert!(self.buf_length > 0);
52
53        let decode_length = self
54            .engine
55            .decode_slice(&self.buf[..self.buf_length], &mut self.temp)
56            .map_err(|err| io::Error::new(ErrorKind::Other, err))?;
57
58        self.inner.write_all(&self.temp[..decode_length])?;
59
60        self.buf_length = 0;
61
62        Ok(())
63    }
64}
65
66impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> Write
67    for FromBase64Writer<W, N>
68{
69    fn write(&mut self, mut buf: &[u8]) -> Result<usize, io::Error> {
70        let original_buf_length = buf.len();
71
72        if self.buf_length == 0 {
73            while buf.len() >= 4 {
74                let max_available_buf_length = (buf.len() & !0b11).min((N::USIZE / 3) << 2); // (N::USIZE / 3) * 4
75
76                let decode_length = self
77                    .engine
78                    .decode_slice(&buf[..max_available_buf_length], &mut self.temp)
79                    .map_err(|err| io::Error::new(ErrorKind::Other, err))?;
80
81                buf = &buf[max_available_buf_length..];
82
83                self.inner.write_all(&self.temp[..decode_length])?;
84            }
85
86            let buf_length = buf.len();
87
88            if buf_length > 0 {
89                unsafe {
90                    copy_nonoverlapping(buf.as_ptr(), self.buf.as_mut_ptr(), buf_length);
91                }
92
93                self.buf_length = buf_length;
94            }
95        } else {
96            debug_assert!(self.buf_length < 4);
97
98            let r = 4 - self.buf_length;
99
100            let buf_length = buf.len();
101
102            let drain_length = r.min(buf_length);
103
104            unsafe {
105                copy_nonoverlapping(
106                    buf.as_ptr(),
107                    self.buf.as_mut_ptr().add(self.buf_length),
108                    drain_length,
109                );
110            }
111
112            buf = &buf[drain_length..];
113
114            self.buf_length += drain_length;
115
116            if self.buf_length == 4 {
117                self.drain_block()?;
118
119                if buf_length > r {
120                    self.write_all(buf)?;
121                }
122            }
123        }
124
125        Ok(original_buf_length)
126    }
127
128    #[inline]
129    fn flush(&mut self) -> Result<(), io::Error> {
130        if self.buf_length > 0 {
131            self.drain_block()?;
132        }
133
134        Ok(())
135    }
136}
137
138impl<W: Write> From<W> for FromBase64Writer<W> {
139    #[inline]
140    fn from(reader: W) -> Self {
141        FromBase64Writer::new(reader)
142    }
143}