memory_socket/
async.rs

1use crate::{MemoryListener, MemorySocket};
2use bytes::{buf::BufExt, Buf};
3use futures::{
4    io::{AsyncRead, AsyncWrite},
5    ready,
6    stream::{FusedStream, Stream},
7};
8use std::{
9    io::{ErrorKind, Result},
10    pin::Pin,
11    task::{Context, Poll},
12};
13
14impl MemoryListener {
15    /// Returns a stream over the connections being received on this
16    /// listener.
17    ///
18    /// The returned stream will never return `None`.
19    ///
20    /// # Examples
21    ///
22    /// ```no_run
23    /// use futures::prelude::*;
24    /// use memory_socket::MemoryListener;
25    ///
26    /// # async fn work () -> ::std::io::Result<()> {
27    /// let mut listener = MemoryListener::bind(80).unwrap();
28    /// let mut incoming = listener.incoming_stream();
29    ///
30    /// while let Some(stream) = incoming.next().await {
31    ///     match stream {
32    ///         Ok(stream) => {
33    ///             println!("new client!");
34    ///         }
35    ///         Err(e) => { /* connection failed */ }
36    ///     }
37    /// }
38    /// # Ok(())}
39    /// ```
40    pub fn incoming_stream(&mut self) -> IncomingStream<'_> {
41        IncomingStream { inner: self }
42    }
43
44    fn poll_accept(&mut self, context: &mut Context) -> Poll<Result<MemorySocket>> {
45        match Pin::new(&mut self.incoming).poll_next(context) {
46            Poll::Ready(Some(socket)) => Poll::Ready(Ok(socket)),
47            // The stream will never terminate
48            Poll::Ready(None) => unreachable!(),
49            Poll::Pending => Poll::Pending,
50        }
51    }
52}
53
54/// A Stream that infinitely accepts connections on a [`MemoryListener`].
55///
56/// This `struct` is created by the [`incoming_stream`] method on [`MemoryListener`].
57/// See its documentation for more info.
58///
59/// [`incoming_stream`]: struct.MemoryListener.html#method.incoming_stream
60/// [`MemoryListener`]: struct.MemoryListener.html
61pub struct IncomingStream<'a> {
62    inner: &'a mut MemoryListener,
63}
64
65impl<'a> Stream for IncomingStream<'a> {
66    type Item = Result<MemorySocket>;
67
68    fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Option<Self::Item>> {
69        let socket = ready!(self.inner.poll_accept(context)?);
70        Poll::Ready(Some(Ok(socket)))
71    }
72}
73
74impl AsyncRead for MemorySocket {
75    fn poll_read(
76        mut self: Pin<&mut Self>,
77        mut context: &mut Context,
78        buf: &mut [u8],
79    ) -> Poll<Result<usize>> {
80        if self.incoming.is_terminated() {
81            if self.seen_eof {
82                return Poll::Ready(Err(ErrorKind::UnexpectedEof.into()));
83            } else {
84                self.seen_eof = true;
85                return Poll::Ready(Ok(0));
86            }
87        }
88
89        let mut bytes_read = 0;
90
91        loop {
92            // If we've already filled up the buffer then we can return
93            if bytes_read == buf.len() {
94                return Poll::Ready(Ok(bytes_read));
95            }
96
97            match self.current_buffer {
98                // We still have data to copy to `buf`
99                Some(ref mut current_buffer) if current_buffer.has_remaining() => {
100                    let bytes_to_read =
101                        ::std::cmp::min(buf.len() - bytes_read, current_buffer.remaining());
102                    debug_assert!(bytes_to_read > 0);
103
104                    current_buffer
105                        .take(bytes_to_read)
106                        .copy_to_slice(&mut buf[bytes_read..(bytes_read + bytes_to_read)]);
107                    bytes_read += bytes_to_read;
108                }
109
110                // Either we've exhausted our current buffer or we don't have one
111                _ => {
112                    // If we've read anything up to this point return the bytes read
113                    if bytes_read > 0 {
114                        return Poll::Ready(Ok(bytes_read));
115                    }
116
117                    self.current_buffer = {
118                        match Pin::new(&mut self.incoming).poll_next(&mut context) {
119                            Poll::Pending => return Poll::Pending,
120                            Poll::Ready(Some(buf)) => Some(buf),
121                            Poll::Ready(None) => return Poll::Ready(Ok(bytes_read)),
122                        }
123                    };
124                }
125            }
126        }
127    }
128}
129
130impl AsyncWrite for MemorySocket {
131    fn poll_write(
132        mut self: Pin<&mut Self>,
133        _context: &mut Context,
134        buf: &[u8],
135    ) -> Poll<Result<usize>> {
136        self.write_buffer.extend_from_slice(buf);
137        Poll::Ready(Ok(buf.len()))
138    }
139
140    fn poll_flush(mut self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> {
141        use flume::TrySendError;
142
143        if !self.write_buffer.is_empty() {
144            let buffer = self.write_buffer.split().freeze();
145            match self.outgoing.try_send(buffer) {
146                Ok(()) => Poll::Ready(Ok(())),
147                Err(TrySendError::Disconnected(_)) => {
148                    Poll::Ready(Err(ErrorKind::BrokenPipe.into()))
149                }
150                Err(TrySendError::Full(_)) => unreachable!(),
151            }
152        } else {
153            Poll::Ready(Ok(()))
154        }
155    }
156
157    fn poll_close(self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> {
158        Poll::Ready(Ok(()))
159    }
160}