1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{PollFd, PollFlags, PollTimeout, poll};
13
14#[derive(Debug)]
16pub struct TcpSocket {
17 socket: Socket,
18 nonblocking: bool,
19}
20
21impl TcpSocket {
22 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
24 config.validate()?;
25
26 let socket = Socket::new(
27 config.socket_family.to_domain(),
28 config.socket_type.to_sock_type(),
29 Some(Protocol::TCP),
30 )?;
31
32 socket.set_nonblocking(config.nonblocking)?;
33
34 if let Some(flag) = config.reuseaddr {
36 socket.set_reuse_address(flag)?;
37 }
38 #[cfg(any(
39 target_os = "android",
40 target_os = "dragonfly",
41 target_os = "freebsd",
42 target_os = "fuchsia",
43 target_os = "ios",
44 target_os = "linux",
45 target_os = "macos",
46 target_os = "netbsd",
47 target_os = "openbsd",
48 target_os = "tvos",
49 target_os = "visionos",
50 target_os = "watchos"
51 ))]
52 if let Some(flag) = config.reuseport {
53 socket.set_reuse_port(flag)?;
54 }
55 if let Some(flag) = config.nodelay {
56 socket.set_nodelay(flag)?;
57 }
58 if let Some(dur) = config.linger {
59 socket.set_linger(Some(dur))?;
60 }
61 if let Some(ttl) = config.ttl {
62 socket.set_ttl(ttl)?;
63 }
64 if let Some(hoplimit) = config.hoplimit {
65 socket.set_unicast_hops_v6(hoplimit)?;
66 }
67 if let Some(keepalive) = config.keepalive {
68 socket.set_keepalive(keepalive)?;
69 }
70 if let Some(timeout) = config.read_timeout {
71 socket.set_read_timeout(Some(timeout))?;
72 }
73 if let Some(timeout) = config.write_timeout {
74 socket.set_write_timeout(Some(timeout))?;
75 }
76 if let Some(size) = config.recv_buffer_size {
77 socket.set_recv_buffer_size(size)?;
78 }
79 if let Some(size) = config.send_buffer_size {
80 socket.set_send_buffer_size(size)?;
81 }
82 if let Some(tos) = config.tos {
83 socket.set_tos(tos)?;
84 }
85 #[cfg(any(
86 target_os = "android",
87 target_os = "dragonfly",
88 target_os = "freebsd",
89 target_os = "fuchsia",
90 target_os = "ios",
91 target_os = "linux",
92 target_os = "macos",
93 target_os = "netbsd",
94 target_os = "openbsd",
95 target_os = "tvos",
96 target_os = "visionos",
97 target_os = "watchos"
98 ))]
99 if let Some(tclass) = config.tclass_v6 {
100 socket.set_tclass_v6(tclass)?;
101 }
102 if let Some(only_v6) = config.only_v6 {
103 socket.set_only_v6(only_v6)?;
104 }
105
106 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
108 if let Some(iface) = &config.bind_device {
109 socket.bind_device(Some(iface.as_bytes()))?;
110 }
111
112 if let Some(addr) = config.bind_addr {
114 socket.bind(&addr.into())?;
115 }
116
117 Ok(Self {
118 socket,
119 nonblocking: config.nonblocking,
120 })
121 }
122
123 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
125 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
126 socket.set_nonblocking(false)?;
127 Ok(Self {
128 socket,
129 nonblocking: false,
130 })
131 }
132
133 pub fn v4_stream() -> io::Result<Self> {
135 Self::new(Domain::IPV4, SockType::STREAM)
136 }
137
138 pub fn v6_stream() -> io::Result<Self> {
140 Self::new(Domain::IPV6, SockType::STREAM)
141 }
142
143 pub fn raw_v4() -> io::Result<Self> {
145 Self::new(Domain::IPV4, SockType::RAW)
146 }
147
148 pub fn raw_v6() -> io::Result<Self> {
150 Self::new(Domain::IPV6, SockType::RAW)
151 }
152
153 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
155 self.socket.bind(&addr.into())
156 }
157
158 pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
160 self.socket.connect(&addr.into())
161 }
162
163 #[cfg(unix)]
167 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
168 let socket = self.socket.try_clone()?;
169 socket.set_nonblocking(true)?;
170 let raw_fd = socket.as_raw_fd();
171
172 match socket.connect(&target.into()) {
174 Ok(_) => { }
175 Err(err)
176 if err.kind() == io::ErrorKind::WouldBlock
177 || err.raw_os_error() == Some(libc::EINPROGRESS) =>
178 {
179 }
181 Err(e) => return Err(e),
182 }
183
184 use std::os::unix::io::BorrowedFd;
186 let mut fds = [PollFd::new(
188 unsafe { BorrowedFd::borrow_raw(raw_fd) },
189 PollFlags::POLLOUT,
190 )];
191 let poll_timeout = PollTimeout::try_from(timeout).unwrap_or(PollTimeout::MAX);
192 let n = poll(&mut fds, poll_timeout)?;
193
194 if n == 0 {
195 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
196 }
197
198 let err: i32 = socket
200 .take_error()?
201 .map(|e| e.raw_os_error().unwrap_or(0))
202 .unwrap_or(0);
203 if err != 0 {
204 return Err(io::Error::from_raw_os_error(err));
205 }
206
207 socket.set_nonblocking(self.nonblocking)?;
208
209 match socket.try_clone() {
210 Ok(cloned_socket) => {
211 let std_stream: TcpStream = cloned_socket.into();
213 Ok(std_stream)
214 }
215 Err(e) => Err(e),
216 }
217 }
218
219 #[cfg(windows)]
223 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
224 use std::mem::size_of;
225 use std::os::windows::io::AsRawSocket;
226 use windows_sys::Win32::Networking::WinSock::{
227 POLLWRNORM, SO_ERROR, SOCKET, SOCKET_ERROR, SOL_SOCKET, WSAPOLLFD, WSAPoll, getsockopt,
228 };
229
230 let socket = self.socket.try_clone()?;
231 socket.set_nonblocking(true)?;
232 let sock = socket.as_raw_socket() as SOCKET;
233
234 match socket.connect(&target.into()) {
236 Ok(_) => { }
237 Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) => {}
238 Err(e) => return Err(e),
239 }
240
241 let mut fds = [WSAPOLLFD {
243 fd: sock,
244 events: POLLWRNORM,
245 revents: 0,
246 }];
247
248 let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
249 let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
250 if result == SOCKET_ERROR {
251 return Err(io::Error::last_os_error());
252 } else if result == 0 {
253 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
254 }
255
256 let mut so_error: i32 = 0;
258 let mut optlen = size_of::<i32>() as i32;
259 let ret = unsafe {
260 getsockopt(
261 sock,
262 SOL_SOCKET as i32,
263 SO_ERROR as i32,
264 &mut so_error as *mut _ as *mut _,
265 &mut optlen,
266 )
267 };
268
269 if ret == SOCKET_ERROR || so_error != 0 {
270 return Err(io::Error::from_raw_os_error(so_error));
271 }
272
273 socket.set_nonblocking(self.nonblocking)?;
274
275 let std_stream: TcpStream = socket.into();
276 Ok(std_stream)
277 }
278
279 pub fn listen(&self, backlog: i32) -> io::Result<()> {
281 self.socket.listen(backlog)
282 }
283
284 pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
286 let (stream, addr) = self.socket.accept()?;
287 Ok((stream.into(), addr.as_socket().unwrap()))
288 }
289
290 pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
292 Ok(self.socket.into())
293 }
294
295 pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
297 Ok(self.socket.into())
298 }
299
300 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
302 self.socket.send_to(buf, &target.into())
303 }
304
305 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
307 let buf_maybe = unsafe {
309 std::slice::from_raw_parts_mut(
310 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
311 buf.len(),
312 )
313 };
314
315 let (n, addr) = self.socket.recv_from(buf_maybe)?;
316 let addr = addr
317 .as_socket()
318 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
319
320 Ok((n, addr))
321 }
322
323 pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
325 self.socket.shutdown(how)
326 }
327
328 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
330 self.socket.set_reuse_address(on)
331 }
332
333 pub fn reuseaddr(&self) -> io::Result<bool> {
335 self.socket.reuse_address()
336 }
337
338 #[cfg(any(
340 target_os = "android",
341 target_os = "dragonfly",
342 target_os = "freebsd",
343 target_os = "fuchsia",
344 target_os = "ios",
345 target_os = "linux",
346 target_os = "macos",
347 target_os = "netbsd",
348 target_os = "openbsd",
349 target_os = "tvos",
350 target_os = "visionos",
351 target_os = "watchos"
352 ))]
353 pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
354 self.socket.set_reuse_port(on)
355 }
356
357 #[cfg(any(
359 target_os = "android",
360 target_os = "dragonfly",
361 target_os = "freebsd",
362 target_os = "fuchsia",
363 target_os = "ios",
364 target_os = "linux",
365 target_os = "macos",
366 target_os = "netbsd",
367 target_os = "openbsd",
368 target_os = "tvos",
369 target_os = "visionos",
370 target_os = "watchos"
371 ))]
372 pub fn reuseport(&self) -> io::Result<bool> {
373 self.socket.reuse_port()
374 }
375
376 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
378 self.socket.set_nodelay(on)
379 }
380
381 pub fn nodelay(&self) -> io::Result<bool> {
383 self.socket.nodelay()
384 }
385
386 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
388 self.socket.set_linger(dur)
389 }
390
391 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
393 self.socket.set_ttl(ttl)
394 }
395
396 pub fn ttl(&self) -> io::Result<u32> {
398 self.socket.ttl()
399 }
400
401 pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
403 self.socket.set_unicast_hops_v6(hops)
404 }
405
406 pub fn hoplimit(&self) -> io::Result<u32> {
408 self.socket.unicast_hops_v6()
409 }
410
411 pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
413 self.socket.set_keepalive(on)
414 }
415
416 pub fn keepalive(&self) -> io::Result<bool> {
418 self.socket.keepalive()
419 }
420
421 pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
423 self.socket.set_recv_buffer_size(size)
424 }
425
426 pub fn recv_buffer_size(&self) -> io::Result<usize> {
428 self.socket.recv_buffer_size()
429 }
430
431 pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
433 self.socket.set_send_buffer_size(size)
434 }
435
436 pub fn send_buffer_size(&self) -> io::Result<usize> {
438 self.socket.send_buffer_size()
439 }
440
441 pub fn set_tos(&self, tos: u32) -> io::Result<()> {
443 self.socket.set_tos(tos)
444 }
445
446 pub fn tos(&self) -> io::Result<u32> {
448 self.socket.tos()
449 }
450
451 #[cfg(any(
453 target_os = "android",
454 target_os = "dragonfly",
455 target_os = "freebsd",
456 target_os = "fuchsia",
457 target_os = "ios",
458 target_os = "linux",
459 target_os = "macos",
460 target_os = "netbsd",
461 target_os = "openbsd",
462 target_os = "tvos",
463 target_os = "visionos",
464 target_os = "watchos"
465 ))]
466 pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
467 self.socket.set_tclass_v6(tclass)
468 }
469
470 #[cfg(any(
472 target_os = "android",
473 target_os = "dragonfly",
474 target_os = "freebsd",
475 target_os = "fuchsia",
476 target_os = "ios",
477 target_os = "linux",
478 target_os = "macos",
479 target_os = "netbsd",
480 target_os = "openbsd",
481 target_os = "tvos",
482 target_os = "visionos",
483 target_os = "watchos"
484 ))]
485 pub fn tclass_v6(&self) -> io::Result<u32> {
486 self.socket.tclass_v6()
487 }
488
489 pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
491 self.socket.set_only_v6(only_v6)
492 }
493
494 pub fn only_v6(&self) -> io::Result<bool> {
496 self.socket.only_v6()
497 }
498
499 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
501 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
502 return self.socket.bind_device(Some(iface.as_bytes()));
503
504 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
505 {
506 let _ = iface;
507 Err(io::Error::new(
508 io::ErrorKind::Unsupported,
509 "bind_device is not supported on this platform",
510 ))
511 }
512 }
513
514 pub fn local_addr(&self) -> io::Result<SocketAddr> {
516 self.socket
517 .local_addr()?
518 .as_socket()
519 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
520 }
521
522 #[cfg(unix)]
524 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
525 use std::os::fd::AsRawFd;
526 self.socket.as_raw_fd()
527 }
528
529 #[cfg(windows)]
531 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
532 use std::os::windows::io::AsRawSocket;
533 self.socket.as_raw_socket()
534 }
535
536 pub fn from_socket(socket: Socket) -> Self {
538 Self {
539 socket,
540 nonblocking: false,
544 }
545 }
546
547 pub fn socket(&self) -> &Socket {
549 &self.socket
550 }
551
552 pub fn into_socket(self) -> Socket {
554 self.socket
555 }
556}
557
#[cfg(test)]
mod tests {
    #[cfg(unix)]
    use super::*;
    #[cfg(unix)]
    use std::net::TcpListener as StdTcpListener;

    /// Reads `O_NONBLOCK` directly from the descriptor's status flags, so the
    /// check does not depend on anything cached by `TcpSocket`.
    #[cfg(unix)]
    fn socket_is_nonblocking(socket: &Socket) -> bool {
        // SAFETY: F_GETFL only reads the status flags of a descriptor that
        // `socket` owns for the duration of this call.
        let status = unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_GETFL) };
        assert!(status >= 0, "F_GETFL failed: {}", io::Error::last_os_error());
        status & libc::O_NONBLOCK == libc::O_NONBLOCK
    }

    #[cfg(unix)]
    #[test]
    fn connect_timeout_does_not_mutate_original_nonblocking_state_after_invalid_input() {
        let sock = TcpSocket::v4_stream().expect("socket");
        sock.socket.set_nonblocking(true).expect("set nonblocking");

        // An IPv6 destination on an IPv4 socket is expected to be rejected by
        // connect() itself.
        let v6_target: SocketAddr = "[::1]:80".parse().unwrap();
        let outcome = sock.connect_timeout(v6_target, Duration::from_secs(1));

        assert!(outcome.is_err());
        assert!(socket_is_nonblocking(&sock.socket));
    }

    #[cfg(unix)]
    #[test]
    fn connect_timeout_does_not_mutate_original_blocking_state_after_success() {
        let listener = StdTcpListener::bind("127.0.0.1:0").expect("listener");
        let target = listener.local_addr().expect("local addr");
        let acceptor = std::thread::spawn(move || listener.accept().expect("accept"));

        let sock = TcpSocket::v4_stream().expect("socket");
        sock.socket.set_nonblocking(false).expect("set blocking");
        // Bound to a name so the stream stays open until the end of the test.
        let _stream = sock
            .connect_timeout(target, Duration::from_secs(1))
            .expect("connect");

        assert!(!socket_is_nonblocking(&sock.socket));
        let _ = acceptor.join();
    }
}