// base64_stream/from_base64_writer.rs

use std::{
2 intrinsics::copy_nonoverlapping,
3 io::{self, ErrorKind, Write},
4};
5
6use base64::{
7 engine::{general_purpose::STANDARD, GeneralPurpose},
8 Engine,
9};
10use generic_array::{
11 typenum::{IsGreaterOrEqual, True, U4, U4096},
12 ArrayLength, GenericArray,
13};
14
/// A writer that decodes Base64 input written to it and forwards the decoded
/// bytes to an inner writer.
///
/// `N` is the size in bytes of the internal decode buffer; it must be at
/// least `U4` and defaults to `U4096`.
#[derive(Educe)]
#[educe(Debug)]
pub struct FromBase64Writer<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True> = U4096>
{
    /// The underlying writer that receives the decoded bytes.
    #[educe(Debug(ignore))]
    inner: W,
    /// Carry-over storage for an incomplete Base64 quantum (up to 4 encoded
    /// bytes) held between `write` calls.
    buf: [u8; 4],
    /// Number of valid bytes currently stored in `buf` (0..=4).
    buf_length: usize,
    /// Scratch buffer receiving decoded output before it is forwarded to
    /// `inner`.
    temp: GenericArray<u8, N>,
    /// Base64 engine used for decoding (the standard-alphabet engine).
    #[educe(Debug(ignore))]
    engine: &'static GeneralPurpose,
}
28
29impl<W: Write> FromBase64Writer<W> {
30 #[inline]
31 pub fn new(writer: W) -> FromBase64Writer<W> {
32 Self::new2(writer)
33 }
34}
35
36impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> FromBase64Writer<W, N> {
37 #[inline]
38 pub fn new2(writer: W) -> FromBase64Writer<W, N> {
39 FromBase64Writer {
40 inner: writer,
41 buf: [0; 4],
42 buf_length: 0,
43 temp: GenericArray::default(),
44 engine: &STANDARD,
45 }
46 }
47}
48
49impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> FromBase64Writer<W, N> {
50 fn drain_block(&mut self) -> Result<(), io::Error> {
51 debug_assert!(self.buf_length > 0);
52
53 let decode_length = self
54 .engine
55 .decode_slice(&self.buf[..self.buf_length], &mut self.temp)
56 .map_err(|err| io::Error::new(ErrorKind::Other, err))?;
57
58 self.inner.write_all(&self.temp[..decode_length])?;
59
60 self.buf_length = 0;
61
62 Ok(())
63 }
64}
65
66impl<W: Write, N: ArrayLength + IsGreaterOrEqual<U4, Output = True>> Write
67 for FromBase64Writer<W, N>
68{
69 fn write(&mut self, mut buf: &[u8]) -> Result<usize, io::Error> {
70 let original_buf_length = buf.len();
71
72 if self.buf_length == 0 {
73 while buf.len() >= 4 {
74 let max_available_buf_length = (buf.len() & !0b11).min((N::USIZE / 3) << 2); let decode_length = self
77 .engine
78 .decode_slice(&buf[..max_available_buf_length], &mut self.temp)
79 .map_err(|err| io::Error::new(ErrorKind::Other, err))?;
80
81 buf = &buf[max_available_buf_length..];
82
83 self.inner.write_all(&self.temp[..decode_length])?;
84 }
85
86 let buf_length = buf.len();
87
88 if buf_length > 0 {
89 unsafe {
90 copy_nonoverlapping(buf.as_ptr(), self.buf.as_mut_ptr(), buf_length);
91 }
92
93 self.buf_length = buf_length;
94 }
95 } else {
96 debug_assert!(self.buf_length < 4);
97
98 let r = 4 - self.buf_length;
99
100 let buf_length = buf.len();
101
102 let drain_length = r.min(buf_length);
103
104 unsafe {
105 copy_nonoverlapping(
106 buf.as_ptr(),
107 self.buf.as_mut_ptr().add(self.buf_length),
108 drain_length,
109 );
110 }
111
112 buf = &buf[drain_length..];
113
114 self.buf_length += drain_length;
115
116 if self.buf_length == 4 {
117 self.drain_block()?;
118
119 if buf_length > r {
120 self.write_all(buf)?;
121 }
122 }
123 }
124
125 Ok(original_buf_length)
126 }
127
128 #[inline]
129 fn flush(&mut self) -> Result<(), io::Error> {
130 if self.buf_length > 0 {
131 self.drain_block()?;
132 }
133
134 Ok(())
135 }
136}
137
138impl<W: Write> From<W> for FromBase64Writer<W> {
139 #[inline]
140 fn from(reader: W) -> Self {
141 FromBase64Writer::new(reader)
142 }
143}