Skip to main content

atomr_streams/
tcp.rs

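//! TCP stream helpers.
//!
//! * [`Tcp::outgoing_connection`] — connect to `addr` and expose the connection
//!   as an [`OutgoingConnection`]: a `Source<io::Result<Bytes>>` of inbound
//!   chunks plus an unbounded sender whose `Bytes` are written to the socket.
//! * [`Tcp::bind`] — listen on `addr` and yield accepted connections as a
//!   `Source` of `io::Result<IncomingConnection>`; each [`IncomingConnection`]
//!   carries the same reader/writer pair plus the peer and local addresses.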
7
8use std::io;
9use std::net::SocketAddr;
10
11use bytes::{Bytes, BytesMut};
12use futures::stream::StreamExt;
13use tokio::io::{split, AsyncReadExt, AsyncWriteExt};
14use tokio::net::{TcpListener, TcpStream};
15
16use crate::source::Source;
17
/// Stateless namespace for the TCP helpers in this module.
///
/// See [`Tcp::outgoing_connection`] for client connections and [`Tcp::bind`]
/// for accepting inbound connections.
#[derive(Debug, Clone, Copy, Default)]
pub struct Tcp;
19
/// Both directions of a client-side connection created by
/// [`Tcp::outgoing_connection`].
pub struct OutgoingConnection {
    /// Inbound byte chunks read from the socket. The source ends at EOF; a read
    /// error is yielded at most once, after which the source ends.
    pub reader: Source<io::Result<Bytes>>,
    /// Outbound queue (unbounded, so `send` never waits). A background task
    /// writes chunks in order; once every sender is dropped (or a write fails)
    /// it shuts down the write side of the socket.
    pub writer: tokio::sync::mpsc::UnboundedSender<Bytes>,
    /// Peer address of the connected socket.
    pub remote_addr: SocketAddr,
}
25
/// One connection accepted by the listener returned from [`Tcp::bind`].
pub struct IncomingConnection {
    /// Inbound byte chunks read from the socket. The source ends at EOF; a read
    /// error is yielded at most once, after which the source ends.
    pub reader: Source<io::Result<Bytes>>,
    /// Outbound queue (unbounded, so `send` never waits). A background task
    /// writes chunks in order; once every sender is dropped (or a write fails)
    /// it shuts down the write side of the socket.
    pub writer: tokio::sync::mpsc::UnboundedSender<Bytes>,
    /// Address of the peer that connected, as reported by `accept`.
    pub remote_addr: SocketAddr,
    /// Address the listener is bound to, captured once when it was bound.
    pub local_addr: SocketAddr,
}
32
33impl Tcp {
34    pub async fn outgoing_connection(addr: SocketAddr) -> io::Result<OutgoingConnection> {
35        let stream = TcpStream::connect(addr).await?;
36        let remote_addr = stream.peer_addr()?;
37        let (rd, mut wr) = split(stream);
38        let (w_tx, mut w_rx) = tokio::sync::mpsc::unbounded_channel::<Bytes>();
39        tokio::spawn(async move {
40            while let Some(b) = w_rx.recv().await {
41                if wr.write_all(&b).await.is_err() {
42                    break;
43                }
44            }
45            let _ = wr.shutdown().await;
46        });
47        let reader = read_stream(rd);
48        Ok(OutgoingConnection { reader, writer: w_tx, remote_addr })
49    }
50
51    pub async fn bind(addr: SocketAddr) -> io::Result<Source<io::Result<IncomingConnection>>> {
52        let listener = TcpListener::bind(addr).await?;
53        let local = listener.local_addr()?;
54        let s = futures::stream::unfold(AcceptState { listener, local }, |state| async move {
55            match state.listener.accept().await {
56                Ok((stream, remote)) => {
57                    let (rd, mut wr) = split(stream);
58                    let (w_tx, mut w_rx) = tokio::sync::mpsc::unbounded_channel::<Bytes>();
59                    tokio::spawn(async move {
60                        while let Some(b) = w_rx.recv().await {
61                            if wr.write_all(&b).await.is_err() {
62                                break;
63                            }
64                        }
65                        let _ = wr.shutdown().await;
66                    });
67                    let reader = read_stream(rd);
68                    let ic = IncomingConnection {
69                        reader,
70                        writer: w_tx,
71                        remote_addr: remote,
72                        local_addr: state.local,
73                    };
74                    Some((Ok(ic), state))
75                }
76                Err(e) => Some((Err(e), state)),
77            }
78        })
79        .boxed();
80        Ok(Source { inner: s })
81    }
82}
83
/// State threaded through the `unfold` that drives the accept loop in
/// [`Tcp::bind`].
struct AcceptState {
    /// Listener that produces inbound connections.
    listener: TcpListener,
    /// Bound address of `listener`, copied into every [`IncomingConnection`].
    local: SocketAddr,
}
88
89fn read_stream<R>(rd: R) -> Source<io::Result<Bytes>>
90where
91    R: tokio::io::AsyncRead + Unpin + Send + 'static,
92{
93    struct St<R> {
94        rd: R,
95        done: bool,
96    }
97    let s = futures::stream::unfold(St { rd, done: false }, |mut st| async move {
98        if st.done {
99            return None;
100        }
101        let mut buf = BytesMut::with_capacity(4096);
102        buf.resize(4096, 0);
103        match st.rd.read(&mut buf).await {
104            Ok(0) => None,
105            Ok(n) => {
106                buf.truncate(n);
107                Some((Ok(buf.freeze()), st))
108            }
109            Err(e) => {
110                st.done = true;
111                Some((Err(e), st))
112            }
113        }
114    })
115    .boxed();
116    Source { inner: s }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;

    #[tokio::test]
    async fn tcp_roundtrip() {
        // Reserve an ephemeral port, then release it so `Tcp::bind` can claim it.
        // NOTE: another process could grab the port in between. `Tcp::bind` does
        // not report its bound address, so binding port 0 directly is not an option.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);

        let incoming = Tcp::bind(addr).await.unwrap();
        let (tx_done, mut rx_done) = tokio::sync::mpsc::unbounded_channel::<Vec<Bytes>>();
        tokio::spawn(async move {
            let mut stream = incoming.into_boxed();
            if let Some(Ok(conn)) = stream.next().await {
                let collected = Sink::collect(conn.reader).await;
                // Keep only the successfully read chunks.
                let chunks: Vec<Bytes> = collected.into_iter().flatten().collect();
                let _ = tx_done.send(chunks);
            }
        });

        let out = Tcp::outgoing_connection(addr).await.unwrap();
        out.writer.send(Bytes::from_static(b"hello")).unwrap();
        // Dropping the only sender makes the writer task shut down the write side,
        // which the server observes as EOF.
        drop(out.writer);

        let received = rx_done
            .recv()
            .await
            .expect("server task finished without a connection");
        // TCP is a byte stream: the payload may arrive split across several chunks.
        let payload: Vec<u8> = received.iter().flat_map(|b| b.iter().copied()).collect();
        assert_eq!(payload, b"hello");
    }
}