1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{PollFd, PollFlags, poll};
13
/// A TCP socket wrapper built on [`socket2::Socket`].
///
/// The constructors on this type create the underlying socket with
/// `Protocol::TCP`. The wrapper exposes option setters/getters, connect/accept
/// helpers, and conversions into the standard library's `TcpStream` /
/// `TcpListener`. An arbitrary [`Socket`] can also be wrapped via
/// `TcpSocket::from_socket`; no protocol check is performed there.
#[derive(Debug)]
pub struct TcpSocket {
    // The wrapped OS socket; every method delegates to it.
    socket: Socket,
}
19
20impl TcpSocket {
21 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
23 let socket = Socket::new(
24 config.socket_family.to_domain(),
25 config.socket_type.to_sock_type(),
26 Some(Protocol::TCP),
27 )?;
28
29 socket.set_nonblocking(config.nonblocking)?;
30
31 if let Some(flag) = config.reuseaddr {
33 socket.set_reuse_address(flag)?;
34 }
35 #[cfg(any(
36 target_os = "android",
37 target_os = "dragonfly",
38 target_os = "freebsd",
39 target_os = "fuchsia",
40 target_os = "ios",
41 target_os = "linux",
42 target_os = "macos",
43 target_os = "netbsd",
44 target_os = "openbsd",
45 target_os = "tvos",
46 target_os = "visionos",
47 target_os = "watchos"
48 ))]
49 if let Some(flag) = config.reuseport {
50 socket.set_reuse_port(flag)?;
51 }
52 if let Some(flag) = config.nodelay {
53 socket.set_nodelay(flag)?;
54 }
55 if let Some(dur) = config.linger {
56 socket.set_linger(Some(dur))?;
57 }
58 if let Some(ttl) = config.ttl {
59 socket.set_ttl(ttl)?;
60 }
61 if let Some(hoplimit) = config.hoplimit {
62 socket.set_unicast_hops_v6(hoplimit)?;
63 }
64 if let Some(keepalive) = config.keepalive {
65 socket.set_keepalive(keepalive)?;
66 }
67 if let Some(timeout) = config.read_timeout {
68 socket.set_read_timeout(Some(timeout))?;
69 }
70 if let Some(timeout) = config.write_timeout {
71 socket.set_write_timeout(Some(timeout))?;
72 }
73 if let Some(size) = config.recv_buffer_size {
74 socket.set_recv_buffer_size(size)?;
75 }
76 if let Some(size) = config.send_buffer_size {
77 socket.set_send_buffer_size(size)?;
78 }
79 if let Some(tos) = config.tos {
80 socket.set_tos(tos)?;
81 }
82 #[cfg(any(
83 target_os = "android",
84 target_os = "dragonfly",
85 target_os = "freebsd",
86 target_os = "fuchsia",
87 target_os = "ios",
88 target_os = "linux",
89 target_os = "macos",
90 target_os = "netbsd",
91 target_os = "openbsd",
92 target_os = "tvos",
93 target_os = "visionos",
94 target_os = "watchos"
95 ))]
96 if let Some(tclass) = config.tclass_v6 {
97 socket.set_tclass_v6(tclass)?;
98 }
99 if let Some(only_v6) = config.only_v6 {
100 socket.set_only_v6(only_v6)?;
101 }
102
103 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
105 if let Some(iface) = &config.bind_device {
106 socket.bind_device(Some(iface.as_bytes()))?;
107 }
108
109 if let Some(addr) = config.bind_addr {
111 socket.bind(&addr.into())?;
112 }
113
114 Ok(Self { socket })
115 }
116
117 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
119 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
120 socket.set_nonblocking(false)?;
121 Ok(Self { socket })
122 }
123
124 pub fn v4_stream() -> io::Result<Self> {
126 Self::new(Domain::IPV4, SockType::STREAM)
127 }
128
129 pub fn v6_stream() -> io::Result<Self> {
131 Self::new(Domain::IPV6, SockType::STREAM)
132 }
133
134 pub fn raw_v4() -> io::Result<Self> {
136 Self::new(Domain::IPV4, SockType::RAW)
137 }
138
139 pub fn raw_v6() -> io::Result<Self> {
141 Self::new(Domain::IPV6, SockType::RAW)
142 }
143
144 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
146 self.socket.bind(&addr.into())
147 }
148
149 pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
151 self.socket.connect(&addr.into())
152 }
153
154 #[cfg(unix)]
156 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
157 let raw_fd = self.socket.as_raw_fd();
158 self.socket.set_nonblocking(true)?;
159
160 match self.socket.connect(&target.into()) {
162 Ok(_) => { }
163 Err(err)
164 if err.kind() == io::ErrorKind::WouldBlock
165 || err.raw_os_error() == Some(libc::EINPROGRESS) =>
166 {
167 }
169 Err(e) => return Err(e),
170 }
171
172 let timeout_ms = timeout.as_millis() as i32;
174 use std::os::unix::io::BorrowedFd;
175 let mut fds = [PollFd::new(
177 unsafe { BorrowedFd::borrow_raw(raw_fd) },
178 PollFlags::POLLOUT,
179 )];
180 let n = poll(&mut fds, Some(timeout_ms as u16))?;
181
182 if n == 0 {
183 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
184 }
185
186 let err: i32 = self
188 .socket
189 .take_error()?
190 .map(|e| e.raw_os_error().unwrap_or(0))
191 .unwrap_or(0);
192 if err != 0 {
193 return Err(io::Error::from_raw_os_error(err));
194 }
195
196 self.socket.set_nonblocking(false)?;
197
198 match self.socket.try_clone() {
199 Ok(cloned_socket) => {
200 let std_stream: TcpStream = cloned_socket.into();
202 Ok(std_stream)
203 }
204 Err(e) => Err(e),
205 }
206 }
207
208 #[cfg(windows)]
209 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
210 use std::mem::size_of;
211 use std::os::windows::io::AsRawSocket;
212 use windows_sys::Win32::Networking::WinSock::{
213 POLLWRNORM, SO_ERROR, SOCKET, SOCKET_ERROR, SOL_SOCKET, WSAPOLLFD, WSAPoll, getsockopt,
214 };
215
216 let sock = self.socket.as_raw_socket() as SOCKET;
217 self.socket.set_nonblocking(true)?;
218
219 match self.socket.connect(&target.into()) {
221 Ok(_) => { }
222 Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) => {}
223 Err(e) => return Err(e),
224 }
225
226 let mut fds = [WSAPOLLFD {
228 fd: sock,
229 events: POLLWRNORM,
230 revents: 0,
231 }];
232
233 let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
234 let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
235 if result == SOCKET_ERROR {
236 return Err(io::Error::last_os_error());
237 } else if result == 0 {
238 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
239 }
240
241 let mut so_error: i32 = 0;
243 let mut optlen = size_of::<i32>() as i32;
244 let ret = unsafe {
245 getsockopt(
246 sock,
247 SOL_SOCKET as i32,
248 SO_ERROR as i32,
249 &mut so_error as *mut _ as *mut _,
250 &mut optlen,
251 )
252 };
253
254 if ret == SOCKET_ERROR || so_error != 0 {
255 return Err(io::Error::from_raw_os_error(so_error));
256 }
257
258 self.socket.set_nonblocking(false)?;
259
260 let std_stream: TcpStream = self.socket.try_clone()?.into();
261 Ok(std_stream)
262 }
263
264 pub fn listen(&self, backlog: i32) -> io::Result<()> {
266 self.socket.listen(backlog)
267 }
268
269 pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
271 let (stream, addr) = self.socket.accept()?;
272 Ok((stream.into(), addr.as_socket().unwrap()))
273 }
274
275 pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
277 Ok(self.socket.into())
278 }
279
280 pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
282 Ok(self.socket.into())
283 }
284
285 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
287 self.socket.send_to(buf, &target.into())
288 }
289
290 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
292 let buf_maybe = unsafe {
294 std::slice::from_raw_parts_mut(
295 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
296 buf.len(),
297 )
298 };
299
300 let (n, addr) = self.socket.recv_from(buf_maybe)?;
301 let addr = addr
302 .as_socket()
303 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
304
305 Ok((n, addr))
306 }
307
308 pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
310 self.socket.shutdown(how)
311 }
312
313 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
315 self.socket.set_reuse_address(on)
316 }
317
318 pub fn reuseaddr(&self) -> io::Result<bool> {
320 self.socket.reuse_address()
321 }
322
323 #[cfg(any(
325 target_os = "android",
326 target_os = "dragonfly",
327 target_os = "freebsd",
328 target_os = "fuchsia",
329 target_os = "ios",
330 target_os = "linux",
331 target_os = "macos",
332 target_os = "netbsd",
333 target_os = "openbsd",
334 target_os = "tvos",
335 target_os = "visionos",
336 target_os = "watchos"
337 ))]
338 pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
339 self.socket.set_reuse_port(on)
340 }
341
342 #[cfg(any(
344 target_os = "android",
345 target_os = "dragonfly",
346 target_os = "freebsd",
347 target_os = "fuchsia",
348 target_os = "ios",
349 target_os = "linux",
350 target_os = "macos",
351 target_os = "netbsd",
352 target_os = "openbsd",
353 target_os = "tvos",
354 target_os = "visionos",
355 target_os = "watchos"
356 ))]
357 pub fn reuseport(&self) -> io::Result<bool> {
358 self.socket.reuse_port()
359 }
360
361 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
363 self.socket.set_nodelay(on)
364 }
365
366 pub fn nodelay(&self) -> io::Result<bool> {
368 self.socket.nodelay()
369 }
370
371 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
373 self.socket.set_linger(dur)
374 }
375
376 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
378 self.socket.set_ttl(ttl)
379 }
380
381 pub fn ttl(&self) -> io::Result<u32> {
383 self.socket.ttl()
384 }
385
386 pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
388 self.socket.set_unicast_hops_v6(hops)
389 }
390
391 pub fn hoplimit(&self) -> io::Result<u32> {
393 self.socket.unicast_hops_v6()
394 }
395
396 pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
398 self.socket.set_keepalive(on)
399 }
400
401 pub fn keepalive(&self) -> io::Result<bool> {
403 self.socket.keepalive()
404 }
405
406 pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
408 self.socket.set_recv_buffer_size(size)
409 }
410
411 pub fn recv_buffer_size(&self) -> io::Result<usize> {
413 self.socket.recv_buffer_size()
414 }
415
416 pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
418 self.socket.set_send_buffer_size(size)
419 }
420
421 pub fn send_buffer_size(&self) -> io::Result<usize> {
423 self.socket.send_buffer_size()
424 }
425
426 pub fn set_tos(&self, tos: u32) -> io::Result<()> {
428 self.socket.set_tos(tos)
429 }
430
431 pub fn tos(&self) -> io::Result<u32> {
433 self.socket.tos()
434 }
435
436 #[cfg(any(
438 target_os = "android",
439 target_os = "dragonfly",
440 target_os = "freebsd",
441 target_os = "fuchsia",
442 target_os = "ios",
443 target_os = "linux",
444 target_os = "macos",
445 target_os = "netbsd",
446 target_os = "openbsd",
447 target_os = "tvos",
448 target_os = "visionos",
449 target_os = "watchos"
450 ))]
451 pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
452 self.socket.set_tclass_v6(tclass)
453 }
454
455 #[cfg(any(
457 target_os = "android",
458 target_os = "dragonfly",
459 target_os = "freebsd",
460 target_os = "fuchsia",
461 target_os = "ios",
462 target_os = "linux",
463 target_os = "macos",
464 target_os = "netbsd",
465 target_os = "openbsd",
466 target_os = "tvos",
467 target_os = "visionos",
468 target_os = "watchos"
469 ))]
470 pub fn tclass_v6(&self) -> io::Result<u32> {
471 self.socket.tclass_v6()
472 }
473
474 pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
476 self.socket.set_only_v6(only_v6)
477 }
478
479 pub fn only_v6(&self) -> io::Result<bool> {
481 self.socket.only_v6()
482 }
483
484 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
486 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
487 return self.socket.bind_device(Some(iface.as_bytes()));
488
489 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
490 {
491 let _ = iface;
492 Err(io::Error::new(
493 io::ErrorKind::Unsupported,
494 "bind_device is not supported on this platform",
495 ))
496 }
497 }
498
499 pub fn local_addr(&self) -> io::Result<SocketAddr> {
501 self.socket
502 .local_addr()?
503 .as_socket()
504 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
505 }
506
507 #[cfg(unix)]
509 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
510 use std::os::fd::AsRawFd;
511 self.socket.as_raw_fd()
512 }
513
514 #[cfg(windows)]
516 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
517 use std::os::windows::io::AsRawSocket;
518 self.socket.as_raw_socket()
519 }
520
521 pub fn from_socket(socket: Socket) -> Self {
523 Self { socket }
524 }
525
526 pub fn socket(&self) -> &Socket {
528 &self.socket
529 }
530
531 pub fn into_socket(self) -> Socket {
533 self.socket
534 }
535}