// async_compression/tokio/write/generic/decoder.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::{
8    codec::Decode,
9    tokio::write::{AsyncBufWrite, BufWriter},
10    util::PartialBuffer,
11};
12use futures_core::ready;
13use pin_project_lite::pin_project;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
15
/// Progress of a [`Decoder`] through its input stream.
#[derive(Debug)]
enum State {
    /// Written bytes are fed to `Decode::decode`; moves to `Finishing` once `decode`
    /// returns `true`, or when shutdown is requested.
    Decoding,
    /// Remaining output is drained with `Decode::finish`; moves to `Done` once `finish`
    /// returns `true`.
    Finishing,
    /// The stream is complete; any further write is rejected with an error.
    Done,
}
22
23pin_project! {
24    #[derive(Debug)]
25    pub struct Decoder<W, D> {
26        #[pin]
27        writer: BufWriter<W>,
28        decoder: D,
29        state: State,
30    }
31}
32
33impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
34    pub fn new(writer: W, decoder: D) -> Self {
35        Self {
36            writer: BufWriter::new(writer),
37            decoder,
38            state: State::Decoding,
39        }
40    }
41}
42
43impl<W, D> Decoder<W, D> {
44    pub fn get_ref(&self) -> &W {
45        self.writer.get_ref()
46    }
47
48    pub fn get_mut(&mut self) -> &mut W {
49        self.writer.get_mut()
50    }
51
52    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
53        self.project().writer.get_pin_mut()
54    }
55
56    pub fn into_inner(self) -> W {
57        self.writer.into_inner()
58    }
59}
60
61impl<W: AsyncWrite, D: Decode> Decoder<W, D> {
62    fn do_poll_write(
63        self: Pin<&mut Self>,
64        cx: &mut Context<'_>,
65        input: &mut PartialBuffer<&[u8]>,
66    ) -> Poll<io::Result<()>> {
67        let mut this = self.project();
68
69        loop {
70            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
71            let mut output = PartialBuffer::new(output);
72
73            *this.state = match this.state {
74                State::Decoding => {
75                    if this.decoder.decode(input, &mut output)? {
76                        State::Finishing
77                    } else {
78                        State::Decoding
79                    }
80                }
81
82                State::Finishing => {
83                    if this.decoder.finish(&mut output)? {
84                        State::Done
85                    } else {
86                        State::Finishing
87                    }
88                }
89
90                State::Done => {
91                    return Poll::Ready(Err(io::Error::other("Write after end of stream")))
92                }
93            };
94
95            let produced = output.written().len();
96            this.writer.as_mut().produce(produced);
97
98            if let State::Done = this.state {
99                return Poll::Ready(Ok(()));
100            }
101
102            if input.unwritten().is_empty() {
103                return Poll::Ready(Ok(()));
104            }
105        }
106    }
107
108    fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
109        let mut this = self.project();
110
111        loop {
112            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
113            let mut output = PartialBuffer::new(output);
114
115            let (state, done) = match this.state {
116                State::Decoding => {
117                    let done = this.decoder.flush(&mut output)?;
118                    (State::Decoding, done)
119                }
120
121                State::Finishing => {
122                    if this.decoder.finish(&mut output)? {
123                        (State::Done, false)
124                    } else {
125                        (State::Finishing, false)
126                    }
127                }
128
129                State::Done => (State::Done, true),
130            };
131
132            *this.state = state;
133
134            let produced = output.written().len();
135            this.writer.as_mut().produce(produced);
136
137            if done {
138                return Poll::Ready(Ok(()));
139            }
140        }
141    }
142}
143
144impl<W: AsyncWrite, D: Decode> AsyncWrite for Decoder<W, D> {
145    fn poll_write(
146        self: Pin<&mut Self>,
147        cx: &mut Context<'_>,
148        buf: &[u8],
149    ) -> Poll<io::Result<usize>> {
150        if buf.is_empty() {
151            return Poll::Ready(Ok(0));
152        }
153
154        let mut input = PartialBuffer::new(buf);
155
156        match self.do_poll_write(cx, &mut input)? {
157            Poll::Pending if input.written().is_empty() => Poll::Pending,
158            _ => Poll::Ready(Ok(input.written().len())),
159        }
160    }
161
162    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
163        ready!(self.as_mut().do_poll_flush(cx))?;
164        ready!(self.project().writer.as_mut().poll_flush(cx))?;
165        Poll::Ready(Ok(()))
166    }
167
168    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
169        if let State::Decoding = self.as_mut().project().state {
170            *self.as_mut().project().state = State::Finishing;
171        }
172
173        ready!(self.as_mut().do_poll_flush(cx))?;
174
175        if let State::Done = self.as_mut().project().state {
176            ready!(self.as_mut().project().writer.as_mut().poll_shutdown(cx))?;
177            Poll::Ready(Ok(()))
178        } else {
179            Poll::Ready(Err(io::Error::other(
180                "Attempt to shutdown before finishing input",
181            )))
182        }
183    }
184}
185
186impl<W: AsyncRead, D> AsyncRead for Decoder<W, D> {
187    fn poll_read(
188        self: Pin<&mut Self>,
189        cx: &mut Context<'_>,
190        buf: &mut ReadBuf<'_>,
191    ) -> Poll<io::Result<()>> {
192        self.get_pin_mut().poll_read(cx, buf)
193    }
194}
195
196impl<W: AsyncBufRead, D> AsyncBufRead for Decoder<W, D> {
197    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
198        self.get_pin_mut().poll_fill_buf(cx)
199    }
200
201    fn consume(self: Pin<&mut Self>, amt: usize) {
202        self.get_pin_mut().consume(amt)
203    }
204}