1use std::net::SocketAddr;
2
3use crate::{maybe_fut_constructor_result, maybe_fut_method, maybe_fut_method_sync};
4
5#[derive(Debug, Unwrap, Read, Write)]
12#[io(feature("tokio-net"))]
13#[unwrap_types(
14 std(std::net::TcpStream),
15 tokio(tokio::net::TcpStream),
16 tokio_gated("tokio-net")
17)]
18pub struct TcpStream(TcpStreamInner);
19
/// The concrete socket implementation wrapped by `TcpStream`.
#[derive(Debug)]
enum TcpStreamInner {
    /// A standard-library stream.
    Std(std::net::TcpStream),
    /// A tokio stream. It exists only when the `tokio_net` cfg is active, which the
    /// docs.rs annotation attributes to the `tokio-net` feature.
    #[cfg(tokio_net)]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio-net")))]
    Tokio(tokio::net::TcpStream),
}
27
28impl From<std::net::TcpStream> for TcpStream {
29 fn from(stream: std::net::TcpStream) -> Self {
30 Self(TcpStreamInner::Std(stream))
31 }
32}
33
/// Wraps a tokio stream so it can be used through `TcpStream`'s API.
#[cfg(tokio_net)]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-net")))]
impl From<tokio::net::TcpStream> for TcpStream {
    fn from(inner: tokio::net::TcpStream) -> Self {
        let backend = TcpStreamInner::Tokio(inner);
        TcpStream(backend)
    }
}
41
42#[cfg(unix)]
43impl std::os::fd::AsFd for TcpStream {
44 fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
45 match &self.0 {
46 TcpStreamInner::Std(file) => file.as_fd(),
47 #[cfg(tokio_net)]
48 TcpStreamInner::Tokio(file) => file.as_fd(),
49 }
50 }
51}
52
53#[cfg(unix)]
54impl std::os::fd::AsRawFd for TcpStream {
55 fn as_raw_fd(&self) -> std::os::fd::RawFd {
56 match &self.0 {
57 TcpStreamInner::Std(file) => file.as_raw_fd(),
58 #[cfg(tokio_net)]
59 TcpStreamInner::Tokio(file) => file.as_raw_fd(),
60 }
61 }
62}
63
#[cfg(windows)]
impl std::os::windows::io::AsSocket for TcpStream {
    /// Borrows the Windows socket handle of the underlying stream, whichever backend owns it.
    fn as_socket(&self) -> std::os::windows::io::BorrowedSocket<'_> {
        match self.0 {
            TcpStreamInner::Std(ref stream) => std::os::windows::io::AsSocket::as_socket(stream),
            #[cfg(tokio_net)]
            TcpStreamInner::Tokio(ref stream) => {
                std::os::windows::io::AsSocket::as_socket(stream)
            }
        }
    }
}
74
#[cfg(windows)]
impl std::os::windows::io::AsRawSocket for TcpStream {
    /// Returns the raw Windows socket of the underlying stream, whichever backend owns it.
    fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        match self.0 {
            TcpStreamInner::Std(ref stream) => {
                std::os::windows::io::AsRawSocket::as_raw_socket(stream)
            }
            #[cfg(tokio_net)]
            TcpStreamInner::Tokio(ref stream) => {
                std::os::windows::io::AsRawSocket::as_raw_socket(stream)
            }
        }
    }
}
85
86impl TcpStream {
87 maybe_fut_constructor_result!(
88 connect(addr: SocketAddr) -> std::io::Result<TcpStream>,
90 std::net::TcpStream::connect,
91 tokio::net::TcpStream::connect,
92 tokio_net
93 );
94
95 maybe_fut_method_sync!(
96 local_addr() -> std::io::Result<SocketAddr>,
98 TcpStreamInner::Std,
99 TcpStreamInner::Tokio,
100 tokio_net
101 );
102
103 maybe_fut_method_sync!(
104 take_error() -> std::io::Result<Option<std::io::Error>>,
106 TcpStreamInner::Std,
107 TcpStreamInner::Tokio,
108 tokio_net
109 );
110
111 maybe_fut_method_sync!(
112 peer_addr() -> std::io::Result<SocketAddr>,
114 TcpStreamInner::Std,
115 TcpStreamInner::Tokio,
116 tokio_net
117 );
118
119 maybe_fut_method_sync!(
120 nodelay() -> std::io::Result<bool>,
122 TcpStreamInner::Std,
123 TcpStreamInner::Tokio,
124 tokio_net
125 );
126
127 maybe_fut_method_sync!(
128 set_nodelay(nodelay: bool) -> std::io::Result<()>,
130 TcpStreamInner::Std,
131 TcpStreamInner::Tokio,
132 tokio_net
133 );
134
135 maybe_fut_method!(
136 peek(buf: &mut [u8]) -> std::io::Result<usize>,
139 TcpStreamInner::Std,
140 TcpStreamInner::Tokio,
141 tokio_net
142 );
143
144 maybe_fut_method_sync!(
145 ttl() -> std::io::Result<u32>,
147 TcpStreamInner::Std,
148 TcpStreamInner::Tokio,
149 tokio_net
150 );
151
152 maybe_fut_method_sync!(
153 set_ttl(ttl: u32) -> std::io::Result<()>,
155 TcpStreamInner::Std,
156 TcpStreamInner::Tokio,
157 tokio_net
158 );
159}
160
#[cfg(test)]
mod test {

    use std::io::{Read as _, Write as _};
    use std::net::TcpListener;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread::JoinHandle;
    use std::time::Duration;

    use super::*;
    use crate::block_on;
    use crate::io::{Read as _, Write};

    #[test]
    #[serial_test::serial]
    fn test_should_connect_std() {
        let server = ping_server();
        assert!(block_on(TcpStream::connect(server.addr)).is_ok());
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_connect_tokio() {
        let server = ping_server();
        assert!(TcpStream::connect(server.addr).await.is_ok());
    }

    #[test]
    #[serial_test::serial]
    fn test_should_get_local_and_peer_addr() {
        let server = ping_server();
        let stream = block_on(TcpStream::connect(server.addr)).unwrap();

        assert!(stream.local_addr().is_ok());
        assert_eq!(stream.peer_addr().unwrap(), server.addr);
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_get_local_and_peer_addr_tokio() {
        let server = ping_server();
        let stream = TcpStream::connect(server.addr).await.unwrap();

        assert!(stream.local_addr().is_ok());
        assert_eq!(stream.peer_addr().unwrap(), server.addr);
    }

    #[test]
    #[serial_test::serial]
    fn test_should_get_nodelay() {
        let server = ping_server();
        let stream = block_on(TcpStream::connect(server.addr)).unwrap();

        assert!(stream.nodelay().is_ok());
        assert!(stream.set_nodelay(true).is_ok());
        assert!(stream.nodelay().unwrap());
        assert!(stream.set_nodelay(false).is_ok());
        assert!(!stream.nodelay().unwrap());
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_get_nodelay_tokio() {
        let server = ping_server();
        let stream = TcpStream::connect(server.addr).await.unwrap();

        assert!(stream.nodelay().is_ok());
        assert!(stream.set_nodelay(true).is_ok());
        assert!(stream.nodelay().unwrap());
        assert!(stream.set_nodelay(false).is_ok());
        assert!(!stream.nodelay().unwrap());
    }

    #[test]
    #[serial_test::serial]
    fn test_should_get_ttl() {
        let server = ping_server();
        let stream = block_on(TcpStream::connect(server.addr)).unwrap();

        assert!(stream.ttl().is_ok());
        assert!(stream.set_ttl(64).is_ok());
        assert_eq!(stream.ttl().unwrap(), 64);
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_get_ttl_tokio() {
        let server = ping_server();
        let stream = TcpStream::connect(server.addr).await.unwrap();

        assert!(stream.ttl().is_ok());
        assert!(stream.set_ttl(64).is_ok());
        assert_eq!(stream.ttl().unwrap(), 64);
    }

    #[test]
    #[serial_test::serial]
    fn test_should_read_and_write_from_tcp_stream_std() {
        let server = ping_server();

        let mut stream = block_on(TcpStream::connect(server.addr)).unwrap();
        block_on(stream.write_all(b"Ping")).expect("Failed to write to stream");
        let mut buf = [0; 1024];
        let size = block_on(stream.read(&mut buf)).expect("Failed to read from stream");
        assert_eq!(size, 4);
        assert_eq!(&buf[..size], b"Pong");
    }

    #[cfg(tokio_net)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_should_read_and_write_from_tcp_stream_tokio() {
        let server = ping_server();

        let mut stream = TcpStream::connect(server.addr).await.unwrap();
        stream
            .write_all(b"Ping")
            .await
            .expect("Failed to write to stream");
        let mut buf = [0; 1024];
        let size = stream
            .read(&mut buf)
            .await
            .expect("Failed to read from stream");
        assert_eq!(size, 4);
        assert_eq!(&buf[..size], b"Pong");
    }

    /// Loopback server used by the tests. It replies `Pong` to every accepted connection.
    ///
    /// Dropping the handle raises the exit flag and joins the accept thread. The server is
    /// therefore torn down even when a test panics before reaching its end. Tests declare
    /// the server before their client stream, so the client (dropped first) is already
    /// closed when the join happens.
    struct PingServer {
        /// Address the server listens on (`127.0.0.1` with an OS-assigned port).
        addr: SocketAddr,
        exit: Arc<AtomicBool>,
        join: Option<JoinHandle<()>>,
    }

    impl Drop for PingServer {
        fn drop(&mut self) {
            self.exit.store(true, Ordering::SeqCst);
            if let Some(join) = self.join.take() {
                // A panicked server thread is ignored here on purpose. Panicking inside
                // `drop` while a failing test is already unwinding would abort the
                // whole test binary.
                let _ = join.join();
            }
        }
    }

    /// Starts a [`PingServer`] on an OS-assigned loopback port.
    ///
    /// Binding to port 0 already guarantees a free port, so no start-up delay is needed.
    fn ping_server() -> PingServer {
        let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind ping server");
        // A non-blocking accept lets the loop notice the exit flag while no client is
        // connecting.
        listener
            .set_nonblocking(true)
            .expect("Failed to set listener to non-blocking");
        let addr = listener
            .local_addr()
            .expect("Failed to read ping server address");

        let exit = Arc::new(AtomicBool::new(false));
        let exit_flag = Arc::clone(&exit);

        let join = std::thread::spawn(move || {
            while !exit_flag.load(Ordering::SeqCst) {
                match listener.accept() {
                    Ok((stream, peer)) => serve_ping(stream, peer),
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        std::thread::sleep(Duration::from_millis(100));
                    }
                    Err(e) => {
                        eprintln!("Failed to accept connection: {}", e);
                        break;
                    }
                }
            }
        });

        PingServer {
            addr,
            exit,
            join: Some(join),
        }
    }

    /// Handles one accepted connection. It reads (and logs) one chunk, then replies `Pong`.
    fn serve_ping(mut stream: std::net::TcpStream, peer: SocketAddr) {
        // Uses the address returned by `accept` instead of `peer_addr()`, which can fail
        // if the client has already gone away.
        println!("Accepted connection from {}", peer);

        // Whether an accepted socket inherits the listener's non-blocking flag depends on
        // the platform: BSD-derived systems such as macOS do, Linux does not. If it were
        // inherited, the read below could fail with `WouldBlock`. We would then reply and
        // close with the client's "Ping" still unread, and the kernel would send RST
        // instead of FIN, making the client's read of "Pong" flaky. Force blocking mode.
        if let Err(e) = stream.set_nonblocking(false) {
            eprintln!("Failed to set stream to blocking: {}", e);
        }
        // Bound the wait on a client that never writes, so neither the accept loop nor
        // the join in `PingServer::drop` can stall indefinitely.
        if let Err(e) = stream.set_read_timeout(Some(Duration::from_secs(5))) {
            eprintln!("Failed to set read timeout: {}", e);
        }

        let mut buf = [0; 1024];
        if let Ok(size) = stream.read(&mut buf) {
            if size > 0 {
                println!("Received: {}", String::from_utf8_lossy(&buf[..size]));
            }
        }
        if let Err(e) = stream.write_all(b"Pong") {
            eprintln!("Failed to write to stream: {}", e);
        }
    }
}