distant_net/common/transport/tcp.rs

1use std::net::IpAddr;
2use std::{fmt, io};
3
4use async_trait::async_trait;
5use tokio::net::{TcpStream, ToSocketAddrs};
6
7use super::{Interest, Ready, Reconnectable, Transport};
8
9/// Represents a [`Transport`] that leverages a TCP stream
10pub struct TcpTransport {
11    pub(crate) addr: IpAddr,
12    pub(crate) port: u16,
13    pub(crate) inner: TcpStream,
14}
15
16impl TcpTransport {
17    /// Creates a new stream by connecting to a remote machine at the specified
18    /// IP address and port
19    pub async fn connect(addrs: impl ToSocketAddrs) -> io::Result<Self> {
20        let stream = TcpStream::connect(addrs).await?;
21        let addr = stream.peer_addr()?;
22        Ok(Self {
23            addr: addr.ip(),
24            port: addr.port(),
25            inner: stream,
26        })
27    }
28
29    /// Returns the IP address that the stream is connected to
30    pub fn ip_addr(&self) -> IpAddr {
31        self.addr
32    }
33
34    /// Returns the port that the stream is connected to
35    pub fn port(&self) -> u16 {
36        self.port
37    }
38}
39
40impl fmt::Debug for TcpTransport {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.debug_struct("TcpTransport")
43            .field("addr", &self.addr)
44            .field("port", &self.port)
45            .finish()
46    }
47}
48
49#[async_trait]
50impl Reconnectable for TcpTransport {
51    async fn reconnect(&mut self) -> io::Result<()> {
52        self.inner = TcpStream::connect((self.addr, self.port)).await?;
53        Ok(())
54    }
55}
56
57#[async_trait]
58impl Transport for TcpTransport {
59    fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
60        self.inner.try_read(buf)
61    }
62
63    fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
64        self.inner.try_write(buf)
65    }
66
67    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
68        self.inner.ready(interest).await
69    }
70}
71
#[cfg(test)]
mod tests {
    use std::net::{Ipv6Addr, SocketAddr};

    use test_log::test;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;
    use tokio::task::JoinHandle;

    use super::*;
    use crate::common::TransportExt;

    /// Finds an IPv6 localhost address whose port was free a moment ago.
    ///
    /// A listener is bound to an OS-assigned port, the port is recorded, and
    /// the listener is dropped when this function returns.
    async fn find_ephemeral_addr() -> SocketAddr {
        // NOTE: This is a race condition as something else could bind to
        //       this port in between us releasing it and us using it again.
        //       We're willing to take that chance.
        let addr = IpAddr::V6(Ipv6Addr::LOCALHOST);

        let listener = TcpListener::bind((addr, 0))
            .await
            .expect("Failed to bind on an ephemeral port");

        let port = listener
            .local_addr()
            .expect("Failed to look up ephemeral port")
            .port();

        SocketAddr::from((addr, port))
    }

    /// Binds a listener on a fresh ephemeral address, reports that address
    /// through `tx`, and then serves a single connection via [`run_server`].
    async fn start_and_run_server(tx: oneshot::Sender<SocketAddr>) -> io::Result<()> {
        let addr = find_ephemeral_addr().await;

        // Start listening at the distinct address
        let listener = TcpListener::bind(addr).await?;

        // Send the address back to our main test thread
        tx.send(addr)
            .map_err(|x| io::Error::new(io::ErrorKind::Other, x.to_string()))?;

        run_server(listener).await
    }

    /// Accepts one connection, sends `b"hello conn"`, and then requires exactly
    /// `b"hello server"` to be received back (panicking otherwise).
    async fn run_server(listener: TcpListener) -> io::Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        // Get the connection
        let (mut conn, _) = listener.accept().await?;

        // Send some data to the connection (10 bytes)
        conn.write_all(b"hello conn").await?;

        // Receive some data from the connection (12 bytes)
        let mut buf = [0u8; 12];
        conn.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"hello server");

        Ok(())
    }

    /// Waits for a server task and fails the test if it panicked, was
    /// cancelled, or finished with an I/O error.
    ///
    /// Checking only the `JoinError` layer would silently ignore the server's
    /// own `io::Result`, letting I/O failures on the server side go unnoticed.
    async fn expect_server_success(task: JoinHandle<io::Result<()>>) {
        task.await
            .expect("Server task failed unexpectedly")
            .expect("Server returned an error");
    }

    #[test(tokio::test)]
    async fn should_fail_to_connect_if_nothing_listening() {
        let addr = find_ephemeral_addr().await;

        // Now this should fail as we've stopped what was listening
        let res = TcpTransport::connect(addr).await;
        assert!(
            res.is_err(),
            "Unexpectedly succeeded in connecting to ghost address: {addr}"
        );
    }

    #[test(tokio::test)]
    async fn should_be_able_to_read_and_write_data() {
        let (tx, rx) = oneshot::channel();

        // Spawn a task that will wait for a connection, send data,
        // and then verify the data it receives back
        let task: JoinHandle<io::Result<()>> = tokio::spawn(start_and_run_server(tx));

        // Wait for the server to be ready
        let addr = rx.await.expect("Failed to get server address");

        // Connect to the socket, send some bytes, and get some bytes
        let conn = TcpTransport::connect(&addr)
            .await
            .expect("Conn failed to connect");

        // Continually read until we get all of the data
        let mut buf = [0u8; 10];
        conn.read_exact(&mut buf)
            .await
            .expect("Conn failed to read");
        assert_eq!(&buf, b"hello conn");

        conn.write_all(b"hello server")
            .await
            .expect("Conn failed to write");

        // Verify that the server finished without panicking or erroring
        expect_server_success(task).await;
    }

    #[test(tokio::test)]
    async fn should_be_able_to_reconnect() {
        let (tx, rx) = oneshot::channel();

        // Spawn a task that will wait for a connection, send data,
        // and then verify the data it receives back
        let task: JoinHandle<io::Result<()>> = tokio::spawn(start_and_run_server(tx));

        // Wait for the server to be ready
        let addr = rx.await.expect("Failed to get server address");

        // Connect to the server
        let mut conn = TcpTransport::connect(&addr)
            .await
            .expect("Conn failed to connect");

        // Kill the server to make the connection fail. Awaiting the aborted
        // handle makes sure the task (and with it the old listener and
        // connection) has actually been dropped before we probe the connection
        // and rebind the address. The result is expected to be a cancellation
        // error, so it is ignored.
        task.abort();
        let _ = task.await;

        // Verify the connection fails by trying to read from it (should get connection reset)
        conn.readable()
            .await
            .expect("Failed to wait for conn to be readable");
        let res = conn.read_exact(&mut [0; 10]).await;
        assert!(
            matches!(res, Ok(0) | Err(_)),
            "Unexpected read result: {res:?}"
        );

        // Restart the server on the same address
        let task: JoinHandle<io::Result<()>> = tokio::spawn(run_server(
            TcpListener::bind(addr)
                .await
                .expect("Failed to rebind server"),
        ));

        // Reconnect to the socket, send some bytes, and get some bytes
        conn.reconnect().await.expect("Conn failed to reconnect");

        // Continually read until we get all of the data
        let mut buf = [0u8; 10];
        conn.read_exact(&mut buf)
            .await
            .expect("Conn failed to read");
        assert_eq!(&buf, b"hello conn");

        conn.write_all(b"hello server")
            .await
            .expect("Conn failed to write");

        // Verify that the restarted server finished without panicking or erroring
        expect_server_success(task).await;
    }
}