1use std::io;
2use std::io::Error;
3use std::pin::{pin, Pin};
4use std::sync::Arc;
5use std::task::{ready, Context, Poll};
6use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
7use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
8use tokio::net::TcpStream;
9use tokio::sync::Mutex;
10use crate::thin_addr::SocketAddr;
11use crate::constructor::ConstructExt;
12use crate::poll_mutex::PollMutex;
13use crate::read::{ReaderInner, SharedReader};
14use crate::write::WriteInner;
15
16pub mod thin_addr;
17mod poll_mutex;
18mod packet_buffer;
19mod constructor;
20mod write;
21mod read;
22mod protocol;
23mod integers;
24
25type Writer = OwnedWriteHalf;
26type Reader = BufReader<OwnedReadHalf>;
27
28pub struct MuxConnection {
40 write: Box<WriteInner>,
41 read: ReaderInner
42}
43
44impl MuxConnection {
45 fn new(write: Box<WriteInner>, read: ReaderInner) -> Self {
46 Self {
47 write,
48 read
49 }
50 }
51
52 pub fn addr(&self) -> SocketAddr {
53 self.write.addr()
54 }
55}
56
57impl AsyncWrite for MuxConnection {
58 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
59 Pin::new(&mut Pin::into_inner(self).write).poll_write(cx, buf)
60 }
61
62 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
63 Pin::new(&mut Pin::into_inner(self).write).poll_flush(cx)
64 }
65
66 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
67 Pin::new(&mut Pin::into_inner(self).write).poll_shutdown(cx)
68 }
69}
70
71impl AsyncRead for MuxConnection {
72 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
73 Pin::new(&mut Pin::into_inner(self).read).poll_read(cx, buf)
74 }
75}
76
77
78#[derive(Clone)]
92pub struct MuxPipe {
93 write: Arc<Mutex<Writer>>,
94 read: Arc<SharedReader>,
95}
96
97impl MuxPipe {
98 pub fn new(stream: TcpStream) -> Self {
100 MuxListener::with_listener_capacity(stream, 0).into_pipe()
101 }
102
103 fn make_writer(&self, addr: SocketAddr) -> Box<WriteInner> {
104 WriteInner::box_new((addr, PollMutex::new(Arc::clone(&self.write))))
105 }
106
107
108 pub async fn add_connection(&self, addr: SocketAddr) -> io::Result<MuxConnection> {
152 let reader = self.read.add_connection(addr)?;
153 let mut writer = self.make_writer(addr);
154 writer.handshake().await?;
155 Ok(MuxConnection::new(writer, reader))
156 }
157}
158
159
160
161
162pub struct MuxListener {
169 pipe: MuxPipe,
170 receiver: flume::Receiver<(SocketAddr, ReaderInner)>
171}
172
173impl MuxListener {
174 pub fn new(stream: TcpStream) -> Self {
175 Self::with_listener_capacity(stream, 1)
176 }
177
178 fn with_listener_capacity(stream: TcpStream, capacity: usize) -> Self {
179 let (read, write) = stream.into_split();
180 let reader = BufReader::new(read);
181 let (sender, receiver) = flume::bounded(capacity);
182 let read = SharedReader::new(reader, sender);
183 let write = Arc::new(Mutex::new(write));
184
185 Self {
186 pipe: MuxPipe { write, read },
187 receiver
188 }
189 }
190
191 pub async fn add_connection(&self, addr: SocketAddr) -> io::Result<MuxConnection> {
192 self.pipe.add_connection(addr).await
193 }
194
195 pub async fn accept(&self) -> io::Result<MuxConnection> {
196 let mut fut = pin!(self.receiver.recv_async());
197 let (addr, reader) = std::future::poll_fn(move |cx| {
198 if let Poll::Ready(res) = fut.as_mut().poll(cx) {
199 return Poll::Ready(Ok::<_, Error>(res.expect("receiver should never close")))
200 }
201
202 match ready!(self.pipe.read.poll(cx))? {}
203 }).await?;
204 let writer = self.pipe.make_writer(addr);
205 Ok(MuxConnection::new(writer, reader))
206 }
207
208 pub fn pipe(&self) -> &MuxPipe {
209 &self.pipe
210 }
211
212 pub fn into_pipe(self) -> MuxPipe {
213 self.pipe
214 }
215}
216
#[cfg(all(test, not(miri)))]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Virtual address used to tag the first multiplexed connection in these tests.
    fn dummy_addr() -> SocketAddr {
        "127.0.0.1:12345".parse().unwrap()
    }

    /// Creates a loopback TCP socket pair.
    ///
    /// The accepting end is wrapped in a [`MuxListener`] and the connecting end in a
    /// [`MuxPipe`].
    async fn mux_pipe() -> (MuxListener, MuxPipe) {
        let tcp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_to = tcp_listener.local_addr().unwrap();

        let accepting = async {
            let (socket, _) = tcp_listener.accept().await.unwrap();
            MuxListener::new(socket)
        };
        let connecting = async {
            let socket = TcpStream::connect(bound_to).await.unwrap();
            MuxPipe::new(socket)
        };

        tokio::join!(accepting, connecting)
    }

    #[tokio::test]
    async fn test_mux_listener_accept_connection() {
        let (listener, pipe) = mux_pipe().await;

        let sending = async {
            let mut outgoing = pipe.add_connection(dummy_addr()).await.unwrap();
            outgoing.write_all(b"hello world").await.unwrap();
            outgoing.flush().await.unwrap();
            outgoing.shutdown().await.unwrap();
        };

        let receiving = async {
            let mut incoming = listener.accept().await.unwrap();
            let mut payload = Vec::new();
            let len = incoming.read_to_end(&mut payload).await.unwrap();
            assert_eq!(&payload[..len], b"hello world");
        };

        tokio::join!(sending, receiving);
    }

    #[tokio::test]
    async fn test_mux_pipe_add_connection_multiple_times() {
        let (listener, pipe) = mux_pipe().await;

        let first_addr = dummy_addr();
        let second_addr = "127.0.0.1:12346".parse::<SocketAddr>().unwrap();

        let client_task = async {
            let send_on_new_connection = async |addr, bytes| {
                let mut outgoing = pipe.add_connection(addr).await?;
                outgoing.write_all(bytes).await?;
                outgoing.flush().await?;
                outgoing.shutdown().await
            };

            tokio::try_join!(
                send_on_new_connection(first_addr, b"first connection"),
                send_on_new_connection(second_addr, b"second connection")
            )
        };

        let server_task = async {
            let a = listener.accept().await?;
            let b = listener.accept().await?;

            // Both connections are opened concurrently, so accept order is not fixed:
            // pair each stream with the address it was opened for.
            let (addr_a, addr_b) = (a.addr(), b.addr());
            let (mut first, mut second) = if addr_a == first_addr && addr_b == second_addr {
                (a, b)
            } else if addr_a == second_addr && addr_b == first_addr {
                (b, a)
            } else {
                unreachable!()
            };

            let mut first_payload = Vec::new();
            let first_len = first.read_to_end(&mut first_payload).await?;
            assert_eq!(&first_payload[..first_len], b"first connection");

            let mut second_payload = Vec::new();
            let second_len = second.read_to_end(&mut second_payload).await?;
            assert_eq!(&second_payload[..second_len], b"second connection");

            Ok(())
        };

        tokio::try_join!(client_task, server_task).unwrap();
    }
}