1use crate::{MemoryListener, MemorySocket};
2use bytes::{buf::BufExt, Buf};
3use futures::{
4 io::{AsyncRead, AsyncWrite},
5 ready,
6 stream::{FusedStream, Stream},
7};
8use std::{
9 io::{ErrorKind, Result},
10 pin::Pin,
11 task::{Context, Poll},
12};
13
14impl MemoryListener {
15 pub fn incoming_stream(&mut self) -> IncomingStream<'_> {
41 IncomingStream { inner: self }
42 }
43
44 fn poll_accept(&mut self, context: &mut Context) -> Poll<Result<MemorySocket>> {
45 match Pin::new(&mut self.incoming).poll_next(context) {
46 Poll::Ready(Some(socket)) => Poll::Ready(Ok(socket)),
47 Poll::Ready(None) => unreachable!(),
49 Poll::Pending => Poll::Pending,
50 }
51 }
52}
53
54pub struct IncomingStream<'a> {
62 inner: &'a mut MemoryListener,
63}
64
65impl<'a> Stream for IncomingStream<'a> {
66 type Item = Result<MemorySocket>;
67
68 fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Option<Self::Item>> {
69 let socket = ready!(self.inner.poll_accept(context)?);
70 Poll::Ready(Some(Ok(socket)))
71 }
72}
73
74impl AsyncRead for MemorySocket {
75 fn poll_read(
76 mut self: Pin<&mut Self>,
77 mut context: &mut Context,
78 buf: &mut [u8],
79 ) -> Poll<Result<usize>> {
80 if self.incoming.is_terminated() {
81 if self.seen_eof {
82 return Poll::Ready(Err(ErrorKind::UnexpectedEof.into()));
83 } else {
84 self.seen_eof = true;
85 return Poll::Ready(Ok(0));
86 }
87 }
88
89 let mut bytes_read = 0;
90
91 loop {
92 if bytes_read == buf.len() {
94 return Poll::Ready(Ok(bytes_read));
95 }
96
97 match self.current_buffer {
98 Some(ref mut current_buffer) if current_buffer.has_remaining() => {
100 let bytes_to_read =
101 ::std::cmp::min(buf.len() - bytes_read, current_buffer.remaining());
102 debug_assert!(bytes_to_read > 0);
103
104 current_buffer
105 .take(bytes_to_read)
106 .copy_to_slice(&mut buf[bytes_read..(bytes_read + bytes_to_read)]);
107 bytes_read += bytes_to_read;
108 }
109
110 _ => {
112 if bytes_read > 0 {
114 return Poll::Ready(Ok(bytes_read));
115 }
116
117 self.current_buffer = {
118 match Pin::new(&mut self.incoming).poll_next(&mut context) {
119 Poll::Pending => return Poll::Pending,
120 Poll::Ready(Some(buf)) => Some(buf),
121 Poll::Ready(None) => return Poll::Ready(Ok(bytes_read)),
122 }
123 };
124 }
125 }
126 }
127 }
128}
129
130impl AsyncWrite for MemorySocket {
131 fn poll_write(
132 mut self: Pin<&mut Self>,
133 _context: &mut Context,
134 buf: &[u8],
135 ) -> Poll<Result<usize>> {
136 self.write_buffer.extend_from_slice(buf);
137 Poll::Ready(Ok(buf.len()))
138 }
139
140 fn poll_flush(mut self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> {
141 use flume::TrySendError;
142
143 if !self.write_buffer.is_empty() {
144 let buffer = self.write_buffer.split().freeze();
145 match self.outgoing.try_send(buffer) {
146 Ok(()) => Poll::Ready(Ok(())),
147 Err(TrySendError::Disconnected(_)) => {
148 Poll::Ready(Err(ErrorKind::BrokenPipe.into()))
149 }
150 Err(TrySendError::Full(_)) => unreachable!(),
151 }
152 } else {
153 Poll::Ready(Ok(()))
154 }
155 }
156
157 fn poll_close(self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> {
158 Poll::Ready(Ok(()))
159 }
160}