// Source file: base64_stream/from_base64_writer.rs

1use std::{
2    fmt,
3    io::{self, Write},
4};
5
6use base64::{
7    Engine,
8    engine::{GeneralPurpose, general_purpose::STANDARD},
9};
10
11/// Write base64 data and decode them to plain data.
12pub struct FromBase64Writer<W: Write, const N: usize = 4096> {
13    inner:      W,
14    buf:        [u8; 4],
15    buf_length: usize,
16    temp:       [u8; N],
17    engine:     &'static GeneralPurpose,
18}
19
20impl<W: Write, const N: usize> fmt::Debug for FromBase64Writer<W, N> {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        f.debug_struct("FromBase64Writer")
23            .field("buf", &&self.buf[..self.buf_length])
24            .field("buf_length", &self.buf_length)
25            .finish_non_exhaustive()
26    }
27}
28
29impl<W: Write> FromBase64Writer<W> {
30    #[inline]
31    pub fn new(writer: W) -> FromBase64Writer<W> {
32        Self::new2(writer)
33    }
34}
35
36impl<W: Write, const N: usize> FromBase64Writer<W, N> {
37    #[inline]
38    pub fn new2(writer: W) -> FromBase64Writer<W, N> {
39        const { assert!(N >= 4, "buffer size N must be at least 4") };
40        FromBase64Writer {
41            inner:      writer,
42            buf:        [0; 4],
43            buf_length: 0,
44            temp:       [0u8; N],
45            engine:     &STANDARD,
46        }
47    }
48}
49
50impl<W: Write, const N: usize> FromBase64Writer<W, N> {
51    fn drain_block(&mut self) -> Result<(), io::Error> {
52        debug_assert!(self.buf_length > 0);
53
54        let decode_length = self
55            .engine
56            .decode_slice(&self.buf[..self.buf_length], &mut self.temp)
57            .map_err(io::Error::other)?;
58
59        self.inner.write_all(&self.temp[..decode_length])?;
60
61        self.buf_length = 0;
62
63        Ok(())
64    }
65
66    /// Returns the inner writer, consuming this wrapper.
67    ///
68    /// Call [`flush`](std::io::Write::flush) before this method to ensure all
69    /// buffered data is written.
70    #[inline]
71    pub fn into_inner(self) -> W {
72        self.inner
73    }
74}
75
76impl<W: Write, const N: usize> Write for FromBase64Writer<W, N> {
77    fn write(&mut self, mut buf: &[u8]) -> Result<usize, io::Error> {
78        let original_buf_length = buf.len();
79
80        if self.buf_length == 0 {
81            while buf.len() >= 4 {
82                let max_available_buf_length = (buf.len() & !0b11).min((N / 3) << 2); // (N / 3) * 4
83
84                let decode_length = self
85                    .engine
86                    .decode_slice(&buf[..max_available_buf_length], &mut self.temp)
87                    .map_err(io::Error::other)?;
88
89                buf = &buf[max_available_buf_length..];
90
91                self.inner.write_all(&self.temp[..decode_length])?;
92            }
93
94            let buf_length = buf.len();
95
96            if buf_length > 0 {
97                self.buf[..buf_length].copy_from_slice(&buf[..buf_length]);
98
99                self.buf_length = buf_length;
100            }
101        } else {
102            debug_assert!(self.buf_length < 4);
103
104            let r = 4 - self.buf_length;
105
106            let buf_length = buf.len();
107
108            let drain_length = r.min(buf_length);
109
110            self.buf[self.buf_length..self.buf_length + drain_length]
111                .copy_from_slice(&buf[..drain_length]);
112
113            buf = &buf[drain_length..];
114
115            self.buf_length += drain_length;
116
117            if self.buf_length == 4 {
118                self.drain_block()?;
119
120                if buf_length > r {
121                    self.write_all(buf)?;
122                }
123            }
124        }
125
126        Ok(original_buf_length)
127    }
128
129    #[inline]
130    fn flush(&mut self) -> Result<(), io::Error> {
131        if self.buf_length > 0 {
132            self.drain_block()?;
133        }
134
135        Ok(())
136    }
137}
138
139impl<W: Write> From<W> for FromBase64Writer<W> {
140    #[inline]
141    fn from(reader: W) -> Self {
142        FromBase64Writer::new(reader)
143    }
144}