Skip to main content

base64_stream/
to_base64_writer.rs

1use std::{
2    fmt,
3    io::{self, Write},
4};
5
6use base64::{
7    Engine,
8    engine::{GeneralPurpose, general_purpose::STANDARD},
9};
10
11/// Write base64 data and encode them to plain data.
12pub struct ToBase64Writer<W: Write, const N: usize = 4096> {
13    inner:      W,
14    buf:        [u8; 3],
15    buf_length: usize,
16    temp:       [u8; N],
17    engine:     &'static GeneralPurpose,
18}
19
20impl<W: Write, const N: usize> fmt::Debug for ToBase64Writer<W, N> {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        f.debug_struct("ToBase64Writer")
23            .field("buf", &&self.buf[..self.buf_length])
24            .field("buf_length", &self.buf_length)
25            .finish_non_exhaustive()
26    }
27}
28
29impl<W: Write> ToBase64Writer<W> {
30    #[inline]
31    pub fn new(writer: W) -> ToBase64Writer<W> {
32        Self::new2(writer)
33    }
34}
35
36impl<W: Write, const N: usize> ToBase64Writer<W, N> {
37    #[inline]
38    pub fn new2(writer: W) -> ToBase64Writer<W, N> {
39        const { assert!(N >= 4, "buffer size N must be at least 4") };
40        ToBase64Writer {
41            inner:      writer,
42            buf:        [0; 3],
43            buf_length: 0,
44            temp:       [0u8; N],
45            engine:     &STANDARD,
46        }
47    }
48}
49
50impl<W: Write, const N: usize> ToBase64Writer<W, N> {
51    fn drain_block(&mut self) -> Result<(), io::Error> {
52        debug_assert!(self.buf_length > 0);
53
54        let encode_length =
55            self.engine.encode_slice(&self.buf[..self.buf_length], &mut self.temp).unwrap();
56
57        self.inner.write_all(&self.temp[..encode_length])?;
58
59        self.buf_length = 0;
60
61        Ok(())
62    }
63
64    /// Returns the inner writer, consuming this wrapper.
65    ///
66    /// Call [`flush`](std::io::Write::flush) before this method to ensure all
67    /// buffered data is written.
68    #[inline]
69    pub fn into_inner(self) -> W {
70        self.inner
71    }
72}
73
74impl<W: Write, const N: usize> Write for ToBase64Writer<W, N> {
75    fn write(&mut self, mut buf: &[u8]) -> Result<usize, io::Error> {
76        let original_buf_length = buf.len();
77
78        if self.buf_length == 0 {
79            while buf.len() >= 3 {
80                let max_available_buf_length = (buf.len() - (buf.len() % 3)).min((N >> 2) * 3); // (N / 4) * 3
81
82                let encode_length = self
83                    .engine
84                    .encode_slice(&buf[..max_available_buf_length], &mut self.temp)
85                    .unwrap();
86
87                buf = &buf[max_available_buf_length..];
88
89                self.inner.write_all(&self.temp[..encode_length])?;
90            }
91
92            let buf_length = buf.len();
93
94            if buf_length > 0 {
95                self.buf[..buf_length].copy_from_slice(&buf[..buf_length]);
96
97                self.buf_length = buf_length;
98            }
99        } else {
100            debug_assert!(self.buf_length < 3);
101
102            let r = 3 - self.buf_length;
103
104            let buf_length = buf.len();
105
106            let drain_length = r.min(buf_length);
107
108            self.buf[self.buf_length..self.buf_length + drain_length]
109                .copy_from_slice(&buf[..drain_length]);
110
111            buf = &buf[drain_length..];
112
113            self.buf_length += drain_length;
114
115            if self.buf_length == 3 {
116                self.drain_block()?;
117
118                if buf_length > r {
119                    self.write_all(buf)?;
120                }
121            }
122        }
123
124        Ok(original_buf_length)
125    }
126
127    fn flush(&mut self) -> Result<(), io::Error> {
128        if self.buf_length > 0 {
129            self.drain_block()?;
130        }
131
132        Ok(())
133    }
134}
135
136impl<W: Write> From<W> for ToBase64Writer<W> {
137    #[inline]
138    fn from(reader: W) -> Self {
139        ToBase64Writer::new(reader)
140    }
141}