Skip to main content

nex_socket/tcp/sync_impl.rs

1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{PollFd, PollFlags, poll};
13
14/// Low level synchronous TCP socket.
15#[derive(Debug)]
16pub struct TcpSocket {
17    socket: Socket,
18}
19
20impl TcpSocket {
21    /// Build a socket according to `TcpSocketConfig`.
22    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
23        let socket = Socket::new(
24            config.socket_family.to_domain(),
25            config.socket_type.to_sock_type(),
26            Some(Protocol::TCP),
27        )?;
28
29        socket.set_nonblocking(config.nonblocking)?;
30
31        // Set socket options based on configuration
32        if let Some(flag) = config.reuseaddr {
33            socket.set_reuse_address(flag)?;
34        }
35        #[cfg(any(
36            target_os = "android",
37            target_os = "dragonfly",
38            target_os = "freebsd",
39            target_os = "fuchsia",
40            target_os = "ios",
41            target_os = "linux",
42            target_os = "macos",
43            target_os = "netbsd",
44            target_os = "openbsd",
45            target_os = "tvos",
46            target_os = "visionos",
47            target_os = "watchos"
48        ))]
49        if let Some(flag) = config.reuseport {
50            socket.set_reuse_port(flag)?;
51        }
52        if let Some(flag) = config.nodelay {
53            socket.set_nodelay(flag)?;
54        }
55        if let Some(dur) = config.linger {
56            socket.set_linger(Some(dur))?;
57        }
58        if let Some(ttl) = config.ttl {
59            socket.set_ttl(ttl)?;
60        }
61        if let Some(hoplimit) = config.hoplimit {
62            socket.set_unicast_hops_v6(hoplimit)?;
63        }
64        if let Some(keepalive) = config.keepalive {
65            socket.set_keepalive(keepalive)?;
66        }
67        if let Some(timeout) = config.read_timeout {
68            socket.set_read_timeout(Some(timeout))?;
69        }
70        if let Some(timeout) = config.write_timeout {
71            socket.set_write_timeout(Some(timeout))?;
72        }
73        if let Some(size) = config.recv_buffer_size {
74            socket.set_recv_buffer_size(size)?;
75        }
76        if let Some(size) = config.send_buffer_size {
77            socket.set_send_buffer_size(size)?;
78        }
79        if let Some(tos) = config.tos {
80            socket.set_tos(tos)?;
81        }
82        #[cfg(any(
83            target_os = "android",
84            target_os = "dragonfly",
85            target_os = "freebsd",
86            target_os = "fuchsia",
87            target_os = "ios",
88            target_os = "linux",
89            target_os = "macos",
90            target_os = "netbsd",
91            target_os = "openbsd",
92            target_os = "tvos",
93            target_os = "visionos",
94            target_os = "watchos"
95        ))]
96        if let Some(tclass) = config.tclass_v6 {
97            socket.set_tclass_v6(tclass)?;
98        }
99        if let Some(only_v6) = config.only_v6 {
100            socket.set_only_v6(only_v6)?;
101        }
102
103        // Linux: optional interface name
104        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
105        if let Some(iface) = &config.bind_device {
106            socket.bind_device(Some(iface.as_bytes()))?;
107        }
108
109        // bind to the specified address if provided
110        if let Some(addr) = config.bind_addr {
111            socket.bind(&addr.into())?;
112        }
113
114        Ok(Self { socket })
115    }
116
117    /// Create a socket of arbitrary type (STREAM or RAW).
118    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
119        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
120        socket.set_nonblocking(false)?;
121        Ok(Self { socket })
122    }
123
124    /// Convenience constructor for an IPv4 STREAM socket.
125    pub fn v4_stream() -> io::Result<Self> {
126        Self::new(Domain::IPV4, SockType::STREAM)
127    }
128
129    /// Convenience constructor for an IPv6 STREAM socket.
130    pub fn v6_stream() -> io::Result<Self> {
131        Self::new(Domain::IPV6, SockType::STREAM)
132    }
133
134    /// IPv4 RAW TCP. Requires administrator privileges.
135    pub fn raw_v4() -> io::Result<Self> {
136        Self::new(Domain::IPV4, SockType::RAW)
137    }
138
139    /// IPv6 RAW TCP. Requires administrator privileges.
140    pub fn raw_v6() -> io::Result<Self> {
141        Self::new(Domain::IPV6, SockType::RAW)
142    }
143
144    /// Bind the socket to a specific address.
145    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
146        self.socket.bind(&addr.into())
147    }
148
149    /// Connect to a remote address.
150    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
151        self.socket.connect(&addr.into())
152    }
153
154    /// Connect to the target address with a timeout.
155    #[cfg(unix)]
156    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
157        let raw_fd = self.socket.as_raw_fd();
158        self.socket.set_nonblocking(true)?;
159
160        // Try to connect first
161        match self.socket.connect(&target.into()) {
162            Ok(_) => { /* succeeded immediately */ }
163            Err(err)
164                if err.kind() == io::ErrorKind::WouldBlock
165                    || err.raw_os_error() == Some(libc::EINPROGRESS) =>
166            {
167                // Continue waiting
168            }
169            Err(e) => return Err(e),
170        }
171
172        // Wait for the connection using poll
173        let timeout_ms = timeout.as_millis() as i32;
174        use std::os::unix::io::BorrowedFd;
175        // Safety: raw_fd is valid for the lifetime of this scope
176        let mut fds = [PollFd::new(
177            unsafe { BorrowedFd::borrow_raw(raw_fd) },
178            PollFlags::POLLOUT,
179        )];
180        let n = poll(&mut fds, Some(timeout_ms as u16))?;
181
182        if n == 0 {
183            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
184        }
185
186        // Check the result with `SO_ERROR`
187        let err: i32 = self
188            .socket
189            .take_error()?
190            .map(|e| e.raw_os_error().unwrap_or(0))
191            .unwrap_or(0);
192        if err != 0 {
193            return Err(io::Error::from_raw_os_error(err));
194        }
195
196        self.socket.set_nonblocking(false)?;
197
198        match self.socket.try_clone() {
199            Ok(cloned_socket) => {
200                // Convert the socket into a `std::net::TcpStream`
201                let std_stream: TcpStream = cloned_socket.into();
202                Ok(std_stream)
203            }
204            Err(e) => Err(e),
205        }
206    }
207
208    #[cfg(windows)]
209    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
210        use std::mem::size_of;
211        use std::os::windows::io::AsRawSocket;
212        use windows_sys::Win32::Networking::WinSock::{
213            POLLWRNORM, SO_ERROR, SOCKET, SOCKET_ERROR, SOL_SOCKET, WSAPOLLFD, WSAPoll, getsockopt,
214        };
215
216        let sock = self.socket.as_raw_socket() as SOCKET;
217        self.socket.set_nonblocking(true)?;
218
219        // Start connect
220        match self.socket.connect(&target.into()) {
221            Ok(_) => { /* connection succeeded immediately */ }
222            Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) /* WSAEWOULDBLOCK */ => {}
223            Err(e) => return Err(e),
224        }
225
226        // Wait using WSAPoll until writable
227        let mut fds = [WSAPOLLFD {
228            fd: sock,
229            events: POLLWRNORM,
230            revents: 0,
231        }];
232
233        let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
234        let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
235        if result == SOCKET_ERROR {
236            return Err(io::Error::last_os_error());
237        } else if result == 0 {
238            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
239        }
240
241        // Check for errors via `SO_ERROR`
242        let mut so_error: i32 = 0;
243        let mut optlen = size_of::<i32>() as i32;
244        let ret = unsafe {
245            getsockopt(
246                sock,
247                SOL_SOCKET as i32,
248                SO_ERROR as i32,
249                &mut so_error as *mut _ as *mut _,
250                &mut optlen,
251            )
252        };
253
254        if ret == SOCKET_ERROR || so_error != 0 {
255            return Err(io::Error::from_raw_os_error(so_error));
256        }
257
258        self.socket.set_nonblocking(false)?;
259
260        let std_stream: TcpStream = self.socket.try_clone()?.into();
261        Ok(std_stream)
262    }
263
264    /// Start listening for incoming connections.
265    pub fn listen(&self, backlog: i32) -> io::Result<()> {
266        self.socket.listen(backlog)
267    }
268
269    /// Accept an incoming connection.
270    pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
271        let (stream, addr) = self.socket.accept()?;
272        Ok((stream.into(), addr.as_socket().unwrap()))
273    }
274
275    /// Convert the socket into a `TcpStream`.
276    pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
277        Ok(self.socket.into())
278    }
279
280    /// Convert the socket into a `TcpListener`.
281    pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
282        Ok(self.socket.into())
283    }
284
285    /// Send a raw packet (for RAW TCP use).
286    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
287        self.socket.send_to(buf, &target.into())
288    }
289
290    /// Receive a raw packet (for RAW TCP use).
291    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
292        // Safety: `MaybeUninit<u8>` is layout-compatible with `u8`.
293        let buf_maybe = unsafe {
294            std::slice::from_raw_parts_mut(
295                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
296                buf.len(),
297            )
298        };
299
300        let (n, addr) = self.socket.recv_from(buf_maybe)?;
301        let addr = addr
302            .as_socket()
303            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
304
305        Ok((n, addr))
306    }
307
308    /// Shutdown the socket.
309    pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
310        self.socket.shutdown(how)
311    }
312
313    /// Set the socket to reuse the address.
314    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
315        self.socket.set_reuse_address(on)
316    }
317
318    /// Get the socket address reuse option.
319    pub fn reuseaddr(&self) -> io::Result<bool> {
320        self.socket.reuse_address()
321    }
322
323    /// Set the socket port reuse option where supported.
324    #[cfg(any(
325        target_os = "android",
326        target_os = "dragonfly",
327        target_os = "freebsd",
328        target_os = "fuchsia",
329        target_os = "ios",
330        target_os = "linux",
331        target_os = "macos",
332        target_os = "netbsd",
333        target_os = "openbsd",
334        target_os = "tvos",
335        target_os = "visionos",
336        target_os = "watchos"
337    ))]
338    pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
339        self.socket.set_reuse_port(on)
340    }
341
342    /// Get the socket port reuse option where supported.
343    #[cfg(any(
344        target_os = "android",
345        target_os = "dragonfly",
346        target_os = "freebsd",
347        target_os = "fuchsia",
348        target_os = "ios",
349        target_os = "linux",
350        target_os = "macos",
351        target_os = "netbsd",
352        target_os = "openbsd",
353        target_os = "tvos",
354        target_os = "visionos",
355        target_os = "watchos"
356    ))]
357    pub fn reuseport(&self) -> io::Result<bool> {
358        self.socket.reuse_port()
359    }
360
361    /// Set the socket to not delay packets.
362    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
363        self.socket.set_nodelay(on)
364    }
365
366    /// Get the no delay option.
367    pub fn nodelay(&self) -> io::Result<bool> {
368        self.socket.nodelay()
369    }
370
371    /// Set the linger option for the socket.
372    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
373        self.socket.set_linger(dur)
374    }
375
376    /// Set the time-to-live for IPv4 packets.
377    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
378        self.socket.set_ttl(ttl)
379    }
380
381    /// Get the time-to-live for IPv4 packets.
382    pub fn ttl(&self) -> io::Result<u32> {
383        self.socket.ttl()
384    }
385
386    /// Set the hop limit for IPv6 packets.
387    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
388        self.socket.set_unicast_hops_v6(hops)
389    }
390
391    /// Get the hop limit for IPv6 packets.
392    pub fn hoplimit(&self) -> io::Result<u32> {
393        self.socket.unicast_hops_v6()
394    }
395
396    /// Set the keepalive option for the socket.
397    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
398        self.socket.set_keepalive(on)
399    }
400
401    /// Get the keepalive option.
402    pub fn keepalive(&self) -> io::Result<bool> {
403        self.socket.keepalive()
404    }
405
406    /// Set the receive buffer size.
407    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
408        self.socket.set_recv_buffer_size(size)
409    }
410
411    /// Get the receive buffer size.
412    pub fn recv_buffer_size(&self) -> io::Result<usize> {
413        self.socket.recv_buffer_size()
414    }
415
416    /// Set the send buffer size.
417    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
418        self.socket.set_send_buffer_size(size)
419    }
420
421    /// Get the send buffer size.
422    pub fn send_buffer_size(&self) -> io::Result<usize> {
423        self.socket.send_buffer_size()
424    }
425
426    /// Set IPv4 TOS / DSCP.
427    pub fn set_tos(&self, tos: u32) -> io::Result<()> {
428        self.socket.set_tos(tos)
429    }
430
431    /// Get IPv4 TOS / DSCP.
432    pub fn tos(&self) -> io::Result<u32> {
433        self.socket.tos()
434    }
435
436    /// Set IPv6 traffic class where supported.
437    #[cfg(any(
438        target_os = "android",
439        target_os = "dragonfly",
440        target_os = "freebsd",
441        target_os = "fuchsia",
442        target_os = "ios",
443        target_os = "linux",
444        target_os = "macos",
445        target_os = "netbsd",
446        target_os = "openbsd",
447        target_os = "tvos",
448        target_os = "visionos",
449        target_os = "watchos"
450    ))]
451    pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
452        self.socket.set_tclass_v6(tclass)
453    }
454
455    /// Get IPv6 traffic class where supported.
456    #[cfg(any(
457        target_os = "android",
458        target_os = "dragonfly",
459        target_os = "freebsd",
460        target_os = "fuchsia",
461        target_os = "ios",
462        target_os = "linux",
463        target_os = "macos",
464        target_os = "netbsd",
465        target_os = "openbsd",
466        target_os = "tvos",
467        target_os = "visionos",
468        target_os = "watchos"
469    ))]
470    pub fn tclass_v6(&self) -> io::Result<u32> {
471        self.socket.tclass_v6()
472    }
473
474    /// Set whether this socket is IPv6 only.
475    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
476        self.socket.set_only_v6(only_v6)
477    }
478
479    /// Get whether this socket is IPv6 only.
480    pub fn only_v6(&self) -> io::Result<bool> {
481        self.socket.only_v6()
482    }
483
484    /// Set the bind device for the socket (Linux specific).
485    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
486        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
487        return self.socket.bind_device(Some(iface.as_bytes()));
488
489        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
490        {
491            let _ = iface;
492            Err(io::Error::new(
493                io::ErrorKind::Unsupported,
494                "bind_device is not supported on this platform",
495            ))
496        }
497    }
498
499    /// Retrieve the local address of the socket.
500    pub fn local_addr(&self) -> io::Result<SocketAddr> {
501        self.socket
502            .local_addr()?
503            .as_socket()
504            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
505    }
506
507    /// Extract the RAW file descriptor for Unix.
508    #[cfg(unix)]
509    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
510        use std::os::fd::AsRawFd;
511        self.socket.as_raw_fd()
512    }
513
514    /// Extract the RAW socket handle for Windows.
515    #[cfg(windows)]
516    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
517        use std::os::windows::io::AsRawSocket;
518        self.socket.as_raw_socket()
519    }
520
521    /// Construct from a raw `socket2::Socket`.
522    pub fn from_socket(socket: Socket) -> Self {
523        Self { socket }
524    }
525
526    /// Borrow the inner `socket2::Socket`.
527    pub fn socket(&self) -> &Socket {
528        &self.socket
529    }
530
531    /// Consume and return the inner `socket2::Socket`.
532    pub fn into_socket(self) -> Socket {
533        self.socket
534    }
535}