distant_net/common/transport/
tcp.rs1use std::net::IpAddr;
2use std::{fmt, io};
3
4use async_trait::async_trait;
5use tokio::net::{TcpStream, ToSocketAddrs};
6
7use super::{Interest, Ready, Reconnectable, Transport};
8
9pub struct TcpTransport {
11 pub(crate) addr: IpAddr,
12 pub(crate) port: u16,
13 pub(crate) inner: TcpStream,
14}
15
16impl TcpTransport {
17 pub async fn connect(addrs: impl ToSocketAddrs) -> io::Result<Self> {
20 let stream = TcpStream::connect(addrs).await?;
21 let addr = stream.peer_addr()?;
22 Ok(Self {
23 addr: addr.ip(),
24 port: addr.port(),
25 inner: stream,
26 })
27 }
28
29 pub fn ip_addr(&self) -> IpAddr {
31 self.addr
32 }
33
34 pub fn port(&self) -> u16 {
36 self.port
37 }
38}
39
40impl fmt::Debug for TcpTransport {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.debug_struct("TcpTransport")
43 .field("addr", &self.addr)
44 .field("port", &self.port)
45 .finish()
46 }
47}
48
49#[async_trait]
50impl Reconnectable for TcpTransport {
51 async fn reconnect(&mut self) -> io::Result<()> {
52 self.inner = TcpStream::connect((self.addr, self.port)).await?;
53 Ok(())
54 }
55}
56
57#[async_trait]
58impl Transport for TcpTransport {
59 fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
60 self.inner.try_read(buf)
61 }
62
63 fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
64 self.inner.try_write(buf)
65 }
66
67 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
68 self.inner.ready(interest).await
69 }
70}
71
#[cfg(test)]
mod tests {
    use std::net::{Ipv6Addr, SocketAddr};

    use test_log::test;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;
    use tokio::task::JoinHandle;

    use super::*;
    use crate::common::TransportExt;

    /// Returns an IPv6 loopback address whose port had nothing bound to it a moment ago.
    ///
    /// The probing listener is dropped before returning, so the port is free at that
    /// instant, but nothing reserves it afterwards. Only use this where an address
    /// with nothing listening on it is wanted.
    async fn find_ephemeral_addr() -> SocketAddr {
        let addr = IpAddr::V6(Ipv6Addr::LOCALHOST);

        let listener = TcpListener::bind((addr, 0))
            .await
            .expect("Failed to bind on an ephemeral port");

        let port = listener
            .local_addr()
            .expect("Failed to look up ephemeral port")
            .port();

        SocketAddr::from((addr, port))
    }

    /// Binds a loopback server on an OS-assigned port, reports the bound address
    /// through `tx`, then serves a single connection via `run_server`.
    async fn start_and_run_server(tx: oneshot::Sender<SocketAddr>) -> io::Result<()> {
        // Bind port 0 and read back the assigned address instead of probing for a free
        // port and rebinding it, which would race with anything else grabbing the port.
        let listener = TcpListener::bind((IpAddr::V6(Ipv6Addr::LOCALHOST), 0)).await?;
        let addr = listener.local_addr()?;

        tx.send(addr).map_err(|unsent| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Failed to report server address {unsent}: receiver dropped"),
            )
        })?;

        run_server(listener).await
    }

    /// Accepts one connection, writes `b"hello conn"`, then requires the peer to
    /// reply with exactly `b"hello server"`.
    async fn run_server(listener: TcpListener) -> io::Result<()> {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let (mut conn, _) = listener.accept().await?;

        conn.write_all(b"hello conn").await?;

        let mut buf = [0u8; 12];
        conn.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"hello server");

        Ok(())
    }

    #[test(tokio::test)]
    async fn should_fail_to_connect_if_nothing_listening() {
        let addr = find_ephemeral_addr().await;

        // `assert!` only formats its message on failure, unlike `expect_err(&format!(..))`.
        let res = TcpTransport::connect(addr).await;
        assert!(
            res.is_err(),
            "Unexpectedly succeeded in connecting to ghost address: {addr}"
        );
    }

    #[test(tokio::test)]
    async fn should_be_able_to_read_and_write_data() {
        let (tx, rx) = oneshot::channel();

        let task: JoinHandle<io::Result<()>> = tokio::spawn(start_and_run_server(tx));

        let addr = rx.await.expect("Failed to get server address");

        let conn = TcpTransport::connect(&addr)
            .await
            .expect("Conn failed to connect");

        let mut buf = [0u8; 10];
        conn.read_exact(&mut buf)
            .await
            .expect("Conn failed to read");
        assert_eq!(&buf, b"hello conn");

        conn.write_all(b"hello server")
            .await
            .expect("Conn failed to write");

        // Check both that the task did not panic AND that the server itself succeeded;
        // discarding the inner result would hide server-side I/O failures.
        task.await
            .expect("Server task failed unexpectedly")
            .expect("Server returned an error");
    }

    #[test(tokio::test)]
    async fn should_be_able_to_reconnect() {
        let (tx, rx) = oneshot::channel();

        let task: JoinHandle<io::Result<()>> = tokio::spawn(start_and_run_server(tx));

        let addr = rx.await.expect("Failed to get server address");

        let mut conn = TcpTransport::connect(&addr)
            .await
            .expect("Conn failed to connect");

        // `abort` only requests cancellation; awaiting the handle guarantees the server
        // task (and with it the listener and accepted socket) has been dropped before
        // we probe the connection and rebind the same address below.
        task.abort();
        let _ = task.await;

        conn.readable()
            .await
            .expect("Failed to wait for conn to be readable");
        let res = conn.read_exact(&mut [0; 10]).await;
        assert!(
            matches!(res, Ok(0) | Err(_)),
            "Unexpected read result: {res:?}"
        );

        let task: JoinHandle<io::Result<()>> = tokio::spawn(run_server(
            TcpListener::bind(addr)
                .await
                .expect("Failed to rebind server"),
        ));

        conn.reconnect().await.expect("Conn failed to reconnect");

        let mut buf = [0u8; 10];
        conn.read_exact(&mut buf)
            .await
            .expect("Conn failed to read");
        assert_eq!(&buf, b"hello conn");

        conn.write_all(b"hello server")
            .await
            .expect("Conn failed to write");

        task.await
            .expect("Server task failed unexpectedly")
            .expect("Server returned an error");
    }
}