base64_stream/
to_base64_writer.rs

1use std::{
2    intrinsics::copy_nonoverlapping,
3    io::{self, Write},
4};
5
6use base64::{
7    engine::{general_purpose::STANDARD, GeneralPurpose},
8    Engine,
9};
10use generic_array::{
11    typenum::{IsGreaterOrEqual, True, U4, U4096},
12    ArrayLength, GenericArray,
13};
14
15/// Write base64 data and encode them to plain data.
16#[derive(Educe)]
17#[educe(Debug)]
18pub struct ToBase64Writer<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True> = U4096> {
19    #[educe(Debug(ignore))]
20    inner:      W,
21    buf:        [u8; 3],
22    buf_length: usize,
23    temp:       GenericArray<u8, N>,
24    #[educe(Debug(ignore))]
25    engine:     &'static GeneralPurpose,
26}
27
28impl<W: Write> ToBase64Writer<W> {
29    #[inline]
30    pub fn new(writer: W) -> ToBase64Writer<W> {
31        Self::new2(writer)
32    }
33}
34
35impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> ToBase64Writer<W, N> {
36    #[inline]
37    pub fn new2(writer: W) -> ToBase64Writer<W, N> {
38        ToBase64Writer {
39            inner:      writer,
40            buf:        [0; 3],
41            buf_length: 0,
42            temp:       GenericArray::default(),
43            engine:     &STANDARD,
44        }
45    }
46}
47
48impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> ToBase64Writer<W, N> {
49    fn drain_block(&mut self) -> Result<(), io::Error> {
50        debug_assert!(self.buf_length > 0);
51
52        let encode_length =
53            self.engine.encode_slice(&self.buf[..self.buf_length], &mut self.temp).unwrap();
54
55        self.inner.write_all(&self.temp[..encode_length])?;
56
57        self.buf_length = 0;
58
59        Ok(())
60    }
61}
62
63impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> Write
64    for ToBase64Writer<W, N>
65{
66    fn write(&mut self, mut buf: &[u8]) -> Result<usize, io::Error> {
67        let original_buf_length = buf.len();
68
69        if self.buf_length == 0 {
70            while buf.len() >= 3 {
71                let max_available_buf_length =
72                    (buf.len() - (buf.len() % 3)).min((N::USIZE >> 2) * 3); // (N::USIZE / 4) * 3
73
74                let encode_length = self
75                    .engine
76                    .encode_slice(&buf[..max_available_buf_length], &mut self.temp)
77                    .unwrap();
78
79                buf = &buf[max_available_buf_length..];
80
81                self.inner.write_all(&self.temp[..encode_length])?;
82            }
83
84            let buf_length = buf.len();
85
86            if buf_length > 0 {
87                unsafe {
88                    copy_nonoverlapping(buf.as_ptr(), self.buf.as_mut_ptr(), buf_length);
89                }
90
91                self.buf_length = buf_length;
92            }
93        } else {
94            debug_assert!(self.buf_length < 3);
95
96            let r = 3 - self.buf_length;
97
98            let buf_length = buf.len();
99
100            let drain_length = r.min(buf_length);
101
102            unsafe {
103                copy_nonoverlapping(
104                    buf.as_ptr(),
105                    self.buf.as_mut_ptr().add(self.buf_length),
106                    drain_length,
107                );
108            }
109
110            buf = &buf[drain_length..];
111
112            self.buf_length += drain_length;
113
114            if self.buf_length == 3 {
115                self.drain_block()?;
116
117                if buf_length > r {
118                    self.write_all(buf)?;
119                }
120            }
121        }
122
123        Ok(original_buf_length)
124    }
125
126    fn flush(&mut self) -> Result<(), io::Error> {
127        if self.buf_length > 0 {
128            self.drain_block()?;
129        }
130
131        Ok(())
132    }
133}
134
135impl<W: Write> From<W> for ToBase64Writer<W> {
136    #[inline]
137    fn from(reader: W) -> Self {
138        ToBase64Writer::new(reader)
139    }
140}