watermelon_net/connection/
streaming.rs

1use std::{
2    future::{self, Future},
3    io,
4    pin::{Pin, pin},
5    task::{Context, Poll},
6};
7
8use bytes::Buf;
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
10use watermelon_proto::proto::{
11    ClientOp, ServerOp, StreamDecoder, StreamEncoder, error::DecoderError,
12};
13
14#[derive(Debug)]
15pub struct StreamingConnection<S> {
16    socket: S,
17    encoder: StreamEncoder,
18    decoder: StreamDecoder,
19    may_flush: bool,
20}
21
22impl<S> StreamingConnection<S>
23where
24    S: AsyncRead + AsyncWrite + Unpin,
25{
26    #[must_use]
27    pub fn new(socket: S) -> Self {
28        Self {
29            socket,
30            encoder: StreamEncoder::new(),
31            decoder: StreamDecoder::new(),
32            may_flush: false,
33        }
34    }
35
36    pub fn poll_read_next(
37        &mut self,
38        cx: &mut Context<'_>,
39    ) -> Poll<Result<ServerOp, StreamingReadError>> {
40        loop {
41            match self.decoder.decode() {
42                Ok(Some(server_op)) => return Poll::Ready(Ok(server_op)),
43                Ok(None) => {}
44                Err(err) => return Poll::Ready(Err(StreamingReadError::Decoder(err))),
45            }
46
47            let read_buf_fut = pin!(self.socket.read_buf(self.decoder.read_buf()));
48            match read_buf_fut.poll(cx) {
49                Poll::Pending => return Poll::Pending,
50                Poll::Ready(Ok(1..)) => {}
51                Poll::Ready(Ok(0)) => {
52                    return Poll::Ready(Err(StreamingReadError::Io(
53                        io::ErrorKind::UnexpectedEof.into(),
54                    )));
55                }
56                Poll::Ready(Err(err)) => return Poll::Ready(Err(StreamingReadError::Io(err))),
57            }
58        }
59    }
60
61    /// Reads the next [`ServerOp`].
62    ///
63    /// # Errors
64    ///
65    /// It returns an error if the content cannot be decoded or if an I/O error occurs.
66    pub async fn read_next(&mut self) -> Result<ServerOp, StreamingReadError> {
67        future::poll_fn(|cx| self.poll_read_next(cx)).await
68    }
69
70    pub fn may_write(&self) -> bool {
71        self.encoder.has_remaining()
72    }
73
74    pub fn may_flush(&self) -> bool {
75        self.may_flush
76    }
77
78    pub fn may_enqueue_more_ops(&self) -> bool {
79        self.encoder.remaining() < 8_290_304
80    }
81
82    pub fn enqueue_write_op(&mut self, item: &ClientOp) {
83        self.encoder.enqueue_write_op(item);
84    }
85
86    pub fn poll_write_next(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
87        let remaining = self.encoder.remaining();
88        if remaining == 0 {
89            return Poll::Ready(Ok(0));
90        }
91
92        let chunk = self.encoder.chunk();
93        let write_outcome = if chunk.len() < remaining && self.socket.is_write_vectored() {
94            let mut bufs = [io::IoSlice::new(&[]); 64];
95            let n = self.encoder.chunks_vectored(&mut bufs);
96            debug_assert!(
97                n >= 2,
98                "perf: chunks_vectored yielded less than 2 chunks despite the apparently fragmented internal encoder representation"
99            );
100
101            Pin::new(&mut self.socket).poll_write_vectored(cx, &bufs[..n])
102        } else {
103            debug_assert!(
104                !chunk.is_empty(),
105                "perf: chunk shouldn't be empty given that `remaining > 0`"
106            );
107            Pin::new(&mut self.socket).poll_write(cx, chunk)
108        };
109
110        match write_outcome {
111            Poll::Pending => {
112                self.may_flush = false;
113                Poll::Pending
114            }
115            Poll::Ready(Ok(n)) => {
116                self.encoder.advance(n);
117                self.may_flush = true;
118                Poll::Ready(Ok(n))
119            }
120            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
121        }
122    }
123
124    /// Writes the next chunk of data to the socket.
125    ///
126    /// It returns the number of bytes that have been written.
127    ///
128    /// # Errors
129    ///
130    /// An I/O error is returned if it is not possible to write to the socket.
131    pub async fn write_next(&mut self) -> io::Result<usize> {
132        future::poll_fn(|cx| self.poll_write_next(cx)).await
133    }
134
135    pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136        match Pin::new(&mut self.socket).poll_flush(cx) {
137            Poll::Pending => Poll::Pending,
138            Poll::Ready(Ok(())) => {
139                self.may_flush = false;
140                Poll::Ready(Ok(()))
141            }
142            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
143        }
144    }
145
146    /// Flush any buffered writes to the connection
147    ///
148    /// # Errors
149    ///
150    /// Returns an error if flushing fails
151    pub async fn flush(&mut self) -> io::Result<()> {
152        future::poll_fn(|cx| self.poll_flush(cx)).await
153    }
154
155    /// Shutdown the connection
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if shutting down the connection fails.
160    /// Implementations usually ignore this error.
161    pub async fn shutdown(&mut self) -> io::Result<()> {
162        future::poll_fn(|cx| Pin::new(&mut self.socket).poll_shutdown(cx)).await
163    }
164
165    pub fn socket(&self) -> &S {
166        &self.socket
167    }
168
169    pub fn socket_mut(&mut self) -> &mut S {
170        &mut self.socket
171    }
172
173    pub fn replace_socket<F, S2>(self, replacer: F) -> StreamingConnection<S2>
174    where
175        F: FnOnce(S) -> S2,
176    {
177        StreamingConnection {
178            socket: replacer(self.socket),
179            encoder: self.encoder,
180            decoder: self.decoder,
181            may_flush: self.may_flush,
182        }
183    }
184
185    pub fn into_inner(self) -> S {
186        self.socket
187    }
188}
189
190#[derive(Debug, thiserror::Error)]
191pub enum StreamingReadError {
192    #[error("decoder")]
193    Decoder(#[source] DecoderError),
194    #[error("io")]
195    Io(#[source] io::Error),
196}
197
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use claims::assert_matches;
    use futures_util::task;
    use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
    use watermelon_proto::proto::{ClientOp, ServerOp};

    use super::StreamingConnection;

    /// Drives a `PING` / `PONG` exchange over an in-memory duplex pipe by
    /// polling both ends manually with a no-op waker.
    #[test]
    fn ping_pong() {
        let noop = task::noop_waker();
        let mut cx = Context::from_waker(&noop);

        let (client_io, mut server_io) = io::duplex(1024);
        let mut client = StreamingConnection::new(client_io);

        // Before any traffic: nothing to read and nothing queued for writing.
        assert!(client.poll_read_next(&mut cx).is_pending());
        assert_matches!(client.poll_write_next(&mut cx), Poll::Ready(Ok(0)));

        // The server side has nothing to read either.
        let mut storage = [0; 1024];
        let mut received = ReadBuf::new(&mut storage);
        let idle_read = Pin::new(&mut server_io).poll_read(&mut cx, &mut received);
        assert!(idle_read.is_pending());

        // Client sends PING; the server side must see the exact wire bytes.
        client.enqueue_write_op(&ClientOp::Ping);
        assert_matches!(client.poll_write_next(&mut cx), Poll::Ready(Ok(6)));
        let ping_read = Pin::new(&mut server_io).poll_read(&mut cx, &mut received);
        assert_matches!(ping_read, Poll::Ready(Ok(())));
        assert_eq!(received.filled(), b"PING\r\n");

        // Server answers PONG; the client decodes it and then has nothing more.
        let pong_write = Pin::new(&mut server_io).poll_write(&mut cx, b"PONG\r\n");
        assert_matches!(pong_write, Poll::Ready(Ok(6)));
        assert_matches!(
            client.poll_read_next(&mut cx),
            Poll::Ready(Ok(ServerOp::Pong))
        );
        assert!(client.poll_read_next(&mut cx).is_pending());
    }
}