maybe_fut/api/net/tcp_stream.rs

1use std::net::SocketAddr;
2
3use crate::{maybe_fut_constructor_result, maybe_fut_method, maybe_fut_method_sync};
4
/// A TCP stream between a local and a remote socket.
///
/// A TCP stream can either be created by connecting to an endpoint, via the [`TcpStream::connect`] method,
/// or by [`super::TcpListener::accept`]ing a connection from a [`super::TcpListener`].
///
/// Reading and writing to a [`TcpStream`] is usually done by using the [`crate::io::Read`] and [`crate::io::Write`] traits.
///
/// An existing [`std::net::TcpStream`] can also be wrapped via [`From`]; when the `tokio-net`
/// feature is enabled, the same is possible for a `tokio::net::TcpStream`.
// NOTE(review): `Unwrap`, `Read` and `Write` are crate derive macros. Judging by the
// `#[io(...)]` / `#[unwrap_types(...)]` arguments they generate the I/O trait impls and
// backend accessors for the std and tokio variants, with the tokio side gated on the
// `tokio-net` feature — confirm against the macro definitions.
#[derive(Debug, Unwrap, Read, Write)]
#[io(feature("tokio-net"))]
#[unwrap_types(
    std(std::net::TcpStream),
    tokio(tokio::net::TcpStream),
    tokio_gated("tokio-net")
)]
pub struct TcpStream(TcpStreamInner);
19
/// Backend-specific socket wrapped by [`TcpStream`].
///
/// The `Tokio` variant only exists when compiled with `cfg(tokio_net)`, so every
/// `match` over this enum has to gate that arm with the same cfg.
#[derive(Debug)]
enum TcpStreamInner {
    /// Socket from the standard library.
    Std(std::net::TcpStream),
    /// Socket from tokio, only available with the `tokio-net` feature.
    #[cfg(tokio_net)]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-net")))]
    Tokio(tokio::net::TcpStream),
}
27
28impl From<std::net::TcpStream> for TcpStream {
29    fn from(stream: std::net::TcpStream) -> Self {
30        Self(TcpStreamInner::Std(stream))
31    }
32}
33
#[cfg(tokio_net)]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-net")))]
impl From<tokio::net::TcpStream> for TcpStream {
    /// Wraps a tokio stream in the tokio-backed variant.
    fn from(inner: tokio::net::TcpStream) -> TcpStream {
        let variant = TcpStreamInner::Tokio(inner);
        TcpStream(variant)
    }
}
41
42#[cfg(unix)]
43impl std::os::fd::AsFd for TcpStream {
44    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
45        match &self.0 {
46            TcpStreamInner::Std(file) => file.as_fd(),
47            #[cfg(tokio_net)]
48            TcpStreamInner::Tokio(file) => file.as_fd(),
49        }
50    }
51}
52
53#[cfg(unix)]
54impl std::os::fd::AsRawFd for TcpStream {
55    fn as_raw_fd(&self) -> std::os::fd::RawFd {
56        match &self.0 {
57            TcpStreamInner::Std(file) => file.as_raw_fd(),
58            #[cfg(tokio_net)]
59            TcpStreamInner::Tokio(file) => file.as_raw_fd(),
60        }
61    }
62}
63
#[cfg(windows)]
impl std::os::windows::io::AsSocket for TcpStream {
    /// Borrows the socket handle of whichever backend stream is wrapped.
    fn as_socket(&self) -> std::os::windows::io::BorrowedSocket<'_> {
        let TcpStream(inner) = self;
        match inner {
            TcpStreamInner::Std(stream) => stream.as_socket(),
            #[cfg(tokio_net)]
            TcpStreamInner::Tokio(stream) => stream.as_socket(),
        }
    }
}
74
#[cfg(windows)]
impl std::os::windows::io::AsRawSocket for TcpStream {
    /// Returns the raw socket handle of whichever backend stream is wrapped.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        match self {
            TcpStream(TcpStreamInner::Std(stream)) => stream.as_raw_socket(),
            #[cfg(tokio_net)]
            TcpStream(TcpStreamInner::Tokio(stream)) => stream.as_raw_socket(),
        }
    }
}
85
// Each method below is generated by a crate macro from a doc comment and a signature,
// followed by (presumably) the std / tokio items to dispatch to and the cfg that gates
// the tokio arm — see the macro definitions to confirm.
impl TcpStream {
    maybe_fut_constructor_result!(
        /// Opens a TCP connection to a remote host at the specified address.
        ///
        /// # Errors
        ///
        /// Returns an error if the connection cannot be established.
        connect(addr: SocketAddr) -> std::io::Result<TcpStream>,
        std::net::TcpStream::connect,
        tokio::net::TcpStream::connect,
        tokio_net
    );

    maybe_fut_method_sync!(
        /// Returns the local address that this stream is bound to.
        local_addr() -> std::io::Result<SocketAddr>,
        TcpStreamInner::Std,
        TcpStreamInner::Tokio,
        tokio_net
    );

    maybe_fut_method_sync!(
        /// Returns the value of the `SO_ERROR` option.
        take_error() -> std::io::Result<Option<std::io::Error>>,
        TcpStreamInner::Std,
        TcpStreamInner::Tokio,
        tokio_net
    );

    maybe_fut_method_sync!(
        /// Returns the remote address that this stream is connected to.
        peer_addr() -> std::io::Result<SocketAddr>,
        TcpStreamInner::Std,
        TcpStreamInner::Tokio,
        tokio_net
    );

    maybe_fut_method_sync!(
        /// Gets the value of the `TCP_NODELAY` option on this socket.
        nodelay() -> std::io::Result<bool>,
        TcpStreamInner::Std,
        TcpStreamInner::Tokio,
        tokio_net
    );

    maybe_fut_method_sync!(
        /// Sets the value of the `TCP_NODELAY` option on this socket.
        set_nodelay(nodelay: bool) -> std::io::Result<()>,
        TcpStreamInner::Std,
        TcpStreamInner::Tokio,
        tokio_net
    );

    // Unlike the accessors around it, `peek` is generated with `maybe_fut_method!`
    // rather than the `_sync` variant — presumably because it must be awaited on the
    // tokio backend (the tests call the other methods without `.await`).
    maybe_fut_method!(
        /// Receives data on the socket from the remote address to which it is connected, without removing that data from the queue.
        /// On success, returns the number of bytes read.
        peek(buf: &mut [u8]) -> std::io::Result<usize>,
        TcpStreamInner::Std,
        TcpStreamInner::Tokio,
        tokio_net
    );

    maybe_fut_method_sync!(
        /// Gets the value of the `IP_TTL` option on this socket.
        ttl() -> std::io::Result<u32>,
        TcpStreamInner::Std,
        TcpStreamInner::Tokio,
        tokio_net
    );

    maybe_fut_method_sync!(
        /// Sets the value of the `IP_TTL` option on this socket.
        set_ttl(ttl: u32) -> std::io::Result<()>,
        TcpStreamInner::Std,
        TcpStreamInner::Tokio,
        tokio_net
    );
}
160
#[cfg(test)]
mod test {

    use std::io::{Read as _, Write as _};
    use std::net::TcpListener;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread::JoinHandle;
    use std::time::Duration;

    use super::*;
    use crate::block_on;
    use crate::io::{Read as _, Write as _};

    /// How long the accept loop sleeps when no connection is pending.
    const ACCEPT_POLL_INTERVAL: Duration = Duration::from_millis(100);

    /// Upper bound on how long the server waits for a client to send its message.
    const CLIENT_READ_TIMEOUT: Duration = Duration::from_secs(5);

    #[test]
    #[serial_test::serial]
    fn test_should_connect_std() {
        let server = ping_server();
        assert!(block_on(TcpStream::connect(server.addr)).is_ok());

        server.stop();
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_connect_tokio() {
        let server = ping_server();
        assert!(TcpStream::connect(server.addr).await.is_ok());

        server.stop();
    }

    #[test]
    #[serial_test::serial]
    fn test_should_get_local_and_peer_addr() {
        let server = ping_server();
        let stream = block_on(TcpStream::connect(server.addr)).unwrap();

        assert!(stream.local_addr().is_ok());
        assert_eq!(stream.peer_addr().unwrap(), server.addr);

        // The client must be closed before joining, or the server keeps waiting on it.
        drop(stream);
        server.stop();
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_get_local_and_peer_addr_tokio() {
        let server = ping_server();
        let stream = TcpStream::connect(server.addr).await.unwrap();
        assert!(stream.local_addr().is_ok());
        assert_eq!(stream.peer_addr().unwrap(), server.addr);

        drop(stream);
        server.stop();
    }

    #[test]
    #[serial_test::serial]
    fn test_should_get_nodelay() {
        let server = ping_server();
        let stream = block_on(TcpStream::connect(server.addr)).unwrap();
        assert!(stream.nodelay().is_ok());
        assert!(stream.set_nodelay(true).is_ok());
        assert!(stream.nodelay().unwrap());
        assert!(stream.set_nodelay(false).is_ok());
        assert!(!stream.nodelay().unwrap());

        drop(stream);
        server.stop();
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_get_nodelay_tokio() {
        let server = ping_server();
        let stream = TcpStream::connect(server.addr).await.unwrap();
        assert!(stream.nodelay().is_ok());
        assert!(stream.set_nodelay(true).is_ok());
        assert!(stream.nodelay().unwrap());
        assert!(stream.set_nodelay(false).is_ok());
        assert!(!stream.nodelay().unwrap());

        drop(stream);
        server.stop();
    }

    #[test]
    #[serial_test::serial]
    fn test_should_get_ttl() {
        let server = ping_server();
        let stream = block_on(TcpStream::connect(server.addr)).unwrap();
        assert!(stream.ttl().is_ok());
        assert!(stream.set_ttl(64).is_ok());
        assert_eq!(stream.ttl().unwrap(), 64);

        drop(stream);
        server.stop();
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_get_ttl_tokio() {
        let server = ping_server();
        let stream = TcpStream::connect(server.addr).await.unwrap();
        assert!(stream.ttl().is_ok());
        assert!(stream.set_ttl(64).is_ok());
        assert_eq!(stream.ttl().unwrap(), 64);

        drop(stream);
        server.stop();
    }

    #[test]
    #[serial_test::serial]
    fn test_should_read_and_write_from_tcp_stream_std() {
        let server = ping_server();

        let mut stream = block_on(TcpStream::connect(server.addr)).unwrap();
        block_on(stream.write_all(b"Ping")).expect("Failed to write to stream");
        let mut buf = [0; 1024];
        let size = block_on(stream.read(&mut buf)).expect("Failed to read from stream");
        assert_eq!(size, 4);
        assert_eq!(&buf[..size], b"Pong");

        drop(stream);
        server.stop();
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_read_and_write_from_tcp_stream_tokio() {
        let server = ping_server();

        let mut stream = TcpStream::connect(server.addr).await.unwrap();
        stream
            .write_all(b"Ping")
            .await
            .expect("Failed to write to stream");
        let mut buf = [0; 1024];
        let size = stream
            .read(&mut buf)
            .await
            .expect("Failed to read from stream");
        assert_eq!(size, 4);
        assert_eq!(&buf[..size], b"Pong");

        drop(stream);
        server.stop();
    }

    /// Handle to a background "ping" server started by [`ping_server`].
    struct PingServer {
        /// Loopback address the server listens on.
        addr: SocketAddr,
        /// Set to `true` to make the accept loop exit.
        exit: Arc<AtomicBool>,
        /// Server thread; `None` once [`PingServer::stop`] has joined it.
        join: Option<JoinHandle<()>>,
    }

    impl PingServer {
        /// Asks the server loop to exit and waits for its thread to finish, so a
        /// panic inside the server thread surfaces as a test failure.
        ///
        /// Drop every client stream connected to this server first: the server may be
        /// blocked reading from an accepted connection until the client closes it.
        fn stop(mut self) {
            self.exit.store(true, Ordering::Relaxed);
            if let Some(join) = self.join.take() {
                join.join().expect("Failed to join server thread");
            }
        }
    }

    impl Drop for PingServer {
        fn drop(&mut self) {
            // Lets the server thread wind down even when a test panics before `stop`.
            self.exit.store(true, Ordering::Relaxed);
        }
    }

    /// Starts a server on an OS-assigned loopback port that answers every
    /// connection with `Pong` after reading at most one message from it.
    fn ping_server() -> PingServer {
        // Port 0 lets the OS pick a free port, so tests never collide and no
        // randomized start-up delay is needed.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        listener
            .set_nonblocking(true)
            .expect("Failed to set listener to non-blocking");
        let addr = listener.local_addr().unwrap();

        let exit = Arc::new(AtomicBool::new(false));
        let exit_flag = Arc::clone(&exit);

        let join = std::thread::spawn(move || {
            // Non-blocking accept + polling lets the loop notice `exit` between connections.
            while !exit_flag.load(Ordering::Relaxed) {
                match listener.accept() {
                    Ok((stream, peer)) => {
                        println!("Accepted connection from {peer}");
                        if let Err(e) = serve_ping(stream) {
                            eprintln!("Failed to serve connection: {e}");
                        }
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        // wait for next connection
                        std::thread::sleep(ACCEPT_POLL_INTERVAL);
                    }
                    Err(e) => {
                        eprintln!("Failed to accept connection: {e}");
                        break;
                    }
                }
            }
        });

        PingServer {
            addr,
            exit,
            join: Some(join),
        }
    }

    /// Reads at most one message from `stream` and replies with `Pong`.
    ///
    /// A failed or empty read is tolerated (a client may just connect and
    /// disconnect); only configuring or writing to the stream is reported as an error.
    fn serve_ping(mut stream: std::net::TcpStream) -> std::io::Result<()> {
        // On BSD-derived systems an accepted socket inherits the listener's non-blocking
        // mode, on Linux it does not. Force blocking mode with a timeout so the server
        // behaves the same everywhere and a silent client cannot wedge the thread.
        stream.set_nonblocking(false)?;
        stream.set_read_timeout(Some(CLIENT_READ_TIMEOUT))?;

        let mut buf = [0; 1024];
        match stream.read(&mut buf) {
            Ok(0) | Err(_) => {}
            Ok(size) => println!("Received: {}", String::from_utf8_lossy(&buf[..size])),
        }

        stream.write_all(b"Pong")
    }
}
364}