// async_compression/tokio/write/generic/decoder.rs
use std::{
    io,
    pin::Pin,
    task::{Context, Poll},
};

use crate::{
    codec::Decode,
    tokio::write::{AsyncBufWrite, BufWriter},
    util::PartialBuffer,
};
use futures_core::ready;
use pin_project_lite::pin_project;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};

/// Phase of the [`Decoder`] state machine; decides which codec call the next poll drives.
#[derive(Debug)]
enum State {
    /// Written input is still being passed to `Decode::decode`.
    Decoding,
    /// `decode` reported completion, or a shutdown was requested; `Decode::finish`
    /// is driven until it reports completion.
    Finishing,
    /// `finish` has completed; any further non-empty write is rejected.
    Done,
}

23pin_project! {
24 #[derive(Debug)]
25 pub struct Decoder<W, D> {
26 #[pin]
27 writer: BufWriter<W>,
28 decoder: D,
29 state: State,
30 }
31}
32
33impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
34 pub fn new(writer: W, decoder: D) -> Self {
35 Self {
36 writer: BufWriter::new(writer),
37 decoder,
38 state: State::Decoding,
39 }
40 }
41}
42
43impl<W, D> Decoder<W, D> {
44 pub fn get_ref(&self) -> &W {
45 self.writer.get_ref()
46 }
47
48 pub fn get_mut(&mut self) -> &mut W {
49 self.writer.get_mut()
50 }
51
52 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
53 self.project().writer.get_pin_mut()
54 }
55
56 pub fn into_inner(self) -> W {
57 self.writer.into_inner()
58 }
59}
60
61impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
62 fn do_poll_write(
63 self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 input: &mut PartialBuffer<&[u8]>,
66 ) -> Poll<io::Result<()>> {
67 let mut this = self.project();
68
69 loop {
70 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
71 let mut output = PartialBuffer::new(output);
72
73 *this.state = match this.state {
74 State::Decoding => {
75 if this.decoder.decode(input, &mut output)? {
76 State::Finishing
77 } else {
78 State::Decoding
79 }
80 }
81
82 State::Finishing => {
83 if this.decoder.finish(&mut output)? {
84 State::Done
85 } else {
86 State::Finishing
87 }
88 }
89
90 State::Done => {
91 return Poll::Ready(Err(io::Error::other("Write after end of stream")))
92 }
93 };
94
95 let produced = output.written().len();
96 this.writer.as_mut().produce(produced);
97
98 if let State::Done = this.state {
99 return Poll::Ready(Ok(()));
100 }
101
102 if input.unwritten().is_empty() {
103 return Poll::Ready(Ok(()));
104 }
105 }
106 }
107
108 fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
109 let mut this = self.project();
110
111 loop {
112 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
113 let mut output = PartialBuffer::new(output);
114
115 let (state, done) = match this.state {
116 State::Decoding => {
117 let done = this.decoder.flush(&mut output)?;
118 (State::Decoding, done)
119 }
120
121 State::Finishing => {
122 if this.decoder.finish(&mut output)? {
123 (State::Done, false)
124 } else {
125 (State::Finishing, false)
126 }
127 }
128
129 State::Done => (State::Done, true),
130 };
131
132 *this.state = state;
133
134 let produced = output.written().len();
135 this.writer.as_mut().produce(produced);
136
137 if done {
138 return Poll::Ready(Ok(()));
139 }
140 }
141 }
142}
143
144impl<W: AsyncWrite, D: Decode> AsyncWrite for Decoder<W, D> {
145 fn poll_write(
146 self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 buf: &[u8],
149 ) -> Poll<io::Result<usize>> {
150 if buf.is_empty() {
151 return Poll::Ready(Ok(0));
152 }
153
154 let mut input = PartialBuffer::new(buf);
155
156 match self.do_poll_write(cx, &mut input)? {
157 Poll::Pending if input.written().is_empty() => Poll::Pending,
158 _ => Poll::Ready(Ok(input.written().len())),
159 }
160 }
161
162 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
163 ready!(self.as_mut().do_poll_flush(cx))?;
164 ready!(self.project().writer.as_mut().poll_flush(cx))?;
165 Poll::Ready(Ok(()))
166 }
167
168 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
169 if let State::Decoding = self.as_mut().project().state {
170 *self.as_mut().project().state = State::Finishing;
171 }
172
173 ready!(self.as_mut().do_poll_flush(cx))?;
174
175 if let State::Done = self.as_mut().project().state {
176 ready!(self.as_mut().project().writer.as_mut().poll_shutdown(cx))?;
177 Poll::Ready(Ok(()))
178 } else {
179 Poll::Ready(Err(io::Error::other(
180 "Attempt to shutdown before finishing input",
181 )))
182 }
183 }
184}
185
186impl<W: AsyncRead, D> AsyncRead for Decoder<W, D> {
187 fn poll_read(
188 self: Pin<&mut Self>,
189 cx: &mut Context<'_>,
190 buf: &mut ReadBuf<'_>,
191 ) -> Poll<io::Result<()>> {
192 self.get_pin_mut().poll_read(cx, buf)
193 }
194}
195
196impl<W: AsyncBufRead, D> AsyncBufRead for Decoder<W, D> {
197 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
198 self.get_pin_mut().poll_fill_buf(cx)
199 }
200
201 fn consume(self: Pin<&mut Self>, amt: usize) {
202 self.get_pin_mut().consume(amt)
203 }
204}