base64_stream/
from_base64_writer.rs1use std::{
2 fmt,
3 io::{self, Write},
4};
5
6use base64::{
7 Engine,
8 engine::{GeneralPurpose, general_purpose::STANDARD},
9};
10
11pub struct FromBase64Writer<W: Write, const N: usize = 4096> {
13 inner: W,
14 buf: [u8; 4],
15 buf_length: usize,
16 temp: [u8; N],
17 engine: &'static GeneralPurpose,
18}
19
20impl<W: Write, const N: usize> fmt::Debug for FromBase64Writer<W, N> {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 f.debug_struct("FromBase64Writer")
23 .field("buf", &&self.buf[..self.buf_length])
24 .field("buf_length", &self.buf_length)
25 .finish_non_exhaustive()
26 }
27}
28
29impl<W: Write> FromBase64Writer<W> {
30 #[inline]
31 pub fn new(writer: W) -> FromBase64Writer<W> {
32 Self::new2(writer)
33 }
34}
35
36impl<W: Write, const N: usize> FromBase64Writer<W, N> {
37 #[inline]
38 pub fn new2(writer: W) -> FromBase64Writer<W, N> {
39 const { assert!(N >= 4, "buffer size N must be at least 4") };
40 FromBase64Writer {
41 inner: writer,
42 buf: [0; 4],
43 buf_length: 0,
44 temp: [0u8; N],
45 engine: &STANDARD,
46 }
47 }
48}
49
50impl<W: Write, const N: usize> FromBase64Writer<W, N> {
51 fn drain_block(&mut self) -> Result<(), io::Error> {
52 debug_assert!(self.buf_length > 0);
53
54 let decode_length = self
55 .engine
56 .decode_slice(&self.buf[..self.buf_length], &mut self.temp)
57 .map_err(io::Error::other)?;
58
59 self.inner.write_all(&self.temp[..decode_length])?;
60
61 self.buf_length = 0;
62
63 Ok(())
64 }
65
66 #[inline]
71 pub fn into_inner(self) -> W {
72 self.inner
73 }
74}
75
76impl<W: Write, const N: usize> Write for FromBase64Writer<W, N> {
77 fn write(&mut self, mut buf: &[u8]) -> Result<usize, io::Error> {
78 let original_buf_length = buf.len();
79
80 if self.buf_length == 0 {
81 while buf.len() >= 4 {
82 let max_available_buf_length = (buf.len() & !0b11).min((N / 3) << 2); let decode_length = self
85 .engine
86 .decode_slice(&buf[..max_available_buf_length], &mut self.temp)
87 .map_err(io::Error::other)?;
88
89 buf = &buf[max_available_buf_length..];
90
91 self.inner.write_all(&self.temp[..decode_length])?;
92 }
93
94 let buf_length = buf.len();
95
96 if buf_length > 0 {
97 self.buf[..buf_length].copy_from_slice(&buf[..buf_length]);
98
99 self.buf_length = buf_length;
100 }
101 } else {
102 debug_assert!(self.buf_length < 4);
103
104 let r = 4 - self.buf_length;
105
106 let buf_length = buf.len();
107
108 let drain_length = r.min(buf_length);
109
110 self.buf[self.buf_length..self.buf_length + drain_length]
111 .copy_from_slice(&buf[..drain_length]);
112
113 buf = &buf[drain_length..];
114
115 self.buf_length += drain_length;
116
117 if self.buf_length == 4 {
118 self.drain_block()?;
119
120 if buf_length > r {
121 self.write_all(buf)?;
122 }
123 }
124 }
125
126 Ok(original_buf_length)
127 }
128
129 #[inline]
130 fn flush(&mut self) -> Result<(), io::Error> {
131 if self.buf_length > 0 {
132 self.drain_block()?;
133 }
134
135 Ok(())
136 }
137}
138
139impl<W: Write> From<W> for FromBase64Writer<W> {
140 #[inline]
141 fn from(reader: W) -> Self {
142 FromBase64Writer::new(reader)
143 }
144}