1use std::io;
9use std::net::SocketAddr;
10
11use bytes::{Bytes, BytesMut};
12use futures::stream::StreamExt;
13use tokio::io::{split, AsyncReadExt, AsyncWriteExt};
14use tokio::net::{TcpListener, TcpStream};
15
16use crate::source::Source;
17
/// Namespace for the TCP entry points: [`Tcp::outgoing_connection`] to dial a peer and
/// [`Tcp::bind`] to accept incoming connections.
///
/// The type carries no state; it exists only to group the associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct Tcp;
19
/// A client-side TCP connection returned by [`Tcp::outgoing_connection`].
pub struct OutgoingConnection {
    /// Chunks of bytes read from the socket. A read error is delivered as an `Err` item,
    /// after which the source ends.
    pub reader: Source<io::Result<Bytes>>,
    /// Queues bytes to be written to the socket, in order, by a background task.
    /// Once every sender is dropped, that task finishes and shuts down the write half of
    /// the socket. The task also stops after the first failed write.
    pub writer: tokio::sync::mpsc::UnboundedSender<Bytes>,
    /// Address of the connected peer, as reported by the socket.
    pub remote_addr: SocketAddr,
}
25
/// A server-side TCP connection yielded by the stream returned from [`Tcp::bind`].
pub struct IncomingConnection {
    /// Chunks of bytes read from the socket. A read error is delivered as an `Err` item,
    /// after which the source ends.
    pub reader: Source<io::Result<Bytes>>,
    /// Queues bytes to be written to the socket, in order, by a background task.
    /// Once every sender is dropped, that task finishes and shuts down the write half of
    /// the socket. The task also stops after the first failed write.
    pub writer: tokio::sync::mpsc::UnboundedSender<Bytes>,
    /// Address of the remote peer, as reported by `accept`.
    pub remote_addr: SocketAddr,
    /// Local address of the listener that accepted this connection.
    pub local_addr: SocketAddr,
}
32
33impl Tcp {
34 pub async fn outgoing_connection(addr: SocketAddr) -> io::Result<OutgoingConnection> {
35 let stream = TcpStream::connect(addr).await?;
36 let remote_addr = stream.peer_addr()?;
37 let (rd, mut wr) = split(stream);
38 let (w_tx, mut w_rx) = tokio::sync::mpsc::unbounded_channel::<Bytes>();
39 tokio::spawn(async move {
40 while let Some(b) = w_rx.recv().await {
41 if wr.write_all(&b).await.is_err() {
42 break;
43 }
44 }
45 let _ = wr.shutdown().await;
46 });
47 let reader = read_stream(rd);
48 Ok(OutgoingConnection { reader, writer: w_tx, remote_addr })
49 }
50
51 pub async fn bind(addr: SocketAddr) -> io::Result<Source<io::Result<IncomingConnection>>> {
52 let listener = TcpListener::bind(addr).await?;
53 let local = listener.local_addr()?;
54 let s = futures::stream::unfold(AcceptState { listener, local }, |state| async move {
55 match state.listener.accept().await {
56 Ok((stream, remote)) => {
57 let (rd, mut wr) = split(stream);
58 let (w_tx, mut w_rx) = tokio::sync::mpsc::unbounded_channel::<Bytes>();
59 tokio::spawn(async move {
60 while let Some(b) = w_rx.recv().await {
61 if wr.write_all(&b).await.is_err() {
62 break;
63 }
64 }
65 let _ = wr.shutdown().await;
66 });
67 let reader = read_stream(rd);
68 let ic = IncomingConnection {
69 reader,
70 writer: w_tx,
71 remote_addr: remote,
72 local_addr: state.local,
73 };
74 Some((Ok(ic), state))
75 }
76 Err(e) => Some((Err(e), state)),
77 }
78 })
79 .boxed();
80 Ok(Source { inner: s })
81 }
82}
83
/// State threaded through the `unfold` accept loop in [`Tcp::bind`].
struct AcceptState {
    /// Listener that connections are accepted from.
    listener: TcpListener,
    /// The listener's local address. It is read once at bind time and copied into every
    /// [`IncomingConnection`] as `local_addr`.
    local: SocketAddr,
}
88
89fn read_stream<R>(rd: R) -> Source<io::Result<Bytes>>
90where
91 R: tokio::io::AsyncRead + Unpin + Send + 'static,
92{
93 struct St<R> {
94 rd: R,
95 done: bool,
96 }
97 let s = futures::stream::unfold(St { rd, done: false }, |mut st| async move {
98 if st.done {
99 return None;
100 }
101 let mut buf = BytesMut::with_capacity(4096);
102 buf.resize(4096, 0);
103 match st.rd.read(&mut buf).await {
104 Ok(0) => None,
105 Ok(n) => {
106 buf.truncate(n);
107 Some((Ok(buf.freeze()), st))
108 }
109 Err(e) => {
110 st.done = true;
111 Some((Err(e), st))
112 }
113 }
114 })
115 .boxed();
116 Source { inner: s }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    /// Sends a payload over a real loopback connection and checks that the server side
    /// reads back exactly the bytes that were written.
    #[tokio::test]
    async fn tcp_roundtrip() {
        // Find a free port by binding an ephemeral listener, then release it so
        // `Tcp::bind` can claim the same address.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);
        let incoming = Tcp::bind(addr).await.unwrap();

        let (tx_done, mut rx_done) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
        tokio::spawn(async move {
            let mut stream = incoming.into_boxed();
            if let Some(Ok(conn)) = stream.next().await {
                let collected = Sink::collect(conn.reader).await;
                // TCP is a byte stream with no message boundaries: one write may arrive
                // split across several reads. Reassemble every successfully read chunk.
                let mut payload = Vec::new();
                for chunk in collected.into_iter().flatten() {
                    payload.extend_from_slice(&chunk);
                }
                let _ = tx_done.send(payload);
            }
        });

        let out = Tcp::outgoing_connection(addr).await.unwrap();
        out.writer.send(Bytes::from_static(b"hello")).unwrap();
        // Dropping the only sender ends the writer task, which half-closes the socket so
        // the server-side reader reaches EOF.
        drop(out.writer);

        let received = rx_done
            .recv()
            .await
            .expect("server task did not accept a connection");
        assert_eq!(received, b"hello");
    }
}