// nex_socket/tcp/sync_impl.rs

1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{PollFd, PollFlags, PollTimeout, poll};
13
14/// Low level synchronous TCP socket.
15#[derive(Debug)]
16pub struct TcpSocket {
17    socket: Socket,
18    nonblocking: bool,
19}
20
21impl TcpSocket {
22    /// Build a socket according to `TcpSocketConfig`.
23    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
24        config.validate()?;
25
26        let socket = Socket::new(
27            config.socket_family.to_domain(),
28            config.socket_type.to_sock_type(),
29            Some(Protocol::TCP),
30        )?;
31
32        socket.set_nonblocking(config.nonblocking)?;
33
34        // Set socket options based on configuration
35        if let Some(flag) = config.reuseaddr {
36            socket.set_reuse_address(flag)?;
37        }
38        #[cfg(any(
39            target_os = "android",
40            target_os = "dragonfly",
41            target_os = "freebsd",
42            target_os = "fuchsia",
43            target_os = "ios",
44            target_os = "linux",
45            target_os = "macos",
46            target_os = "netbsd",
47            target_os = "openbsd",
48            target_os = "tvos",
49            target_os = "visionos",
50            target_os = "watchos"
51        ))]
52        if let Some(flag) = config.reuseport {
53            socket.set_reuse_port(flag)?;
54        }
55        if let Some(flag) = config.nodelay {
56            socket.set_nodelay(flag)?;
57        }
58        if let Some(dur) = config.linger {
59            socket.set_linger(Some(dur))?;
60        }
61        if let Some(ttl) = config.ttl {
62            socket.set_ttl(ttl)?;
63        }
64        if let Some(hoplimit) = config.hoplimit {
65            socket.set_unicast_hops_v6(hoplimit)?;
66        }
67        if let Some(keepalive) = config.keepalive {
68            socket.set_keepalive(keepalive)?;
69        }
70        if let Some(timeout) = config.read_timeout {
71            socket.set_read_timeout(Some(timeout))?;
72        }
73        if let Some(timeout) = config.write_timeout {
74            socket.set_write_timeout(Some(timeout))?;
75        }
76        if let Some(size) = config.recv_buffer_size {
77            socket.set_recv_buffer_size(size)?;
78        }
79        if let Some(size) = config.send_buffer_size {
80            socket.set_send_buffer_size(size)?;
81        }
82        if let Some(tos) = config.tos {
83            socket.set_tos(tos)?;
84        }
85        #[cfg(any(
86            target_os = "android",
87            target_os = "dragonfly",
88            target_os = "freebsd",
89            target_os = "fuchsia",
90            target_os = "ios",
91            target_os = "linux",
92            target_os = "macos",
93            target_os = "netbsd",
94            target_os = "openbsd",
95            target_os = "tvos",
96            target_os = "visionos",
97            target_os = "watchos"
98        ))]
99        if let Some(tclass) = config.tclass_v6 {
100            socket.set_tclass_v6(tclass)?;
101        }
102        if let Some(only_v6) = config.only_v6 {
103            socket.set_only_v6(only_v6)?;
104        }
105
106        // Linux: optional interface name
107        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
108        if let Some(iface) = &config.bind_device {
109            socket.bind_device(Some(iface.as_bytes()))?;
110        }
111
112        // bind to the specified address if provided
113        if let Some(addr) = config.bind_addr {
114            socket.bind(&addr.into())?;
115        }
116
117        Ok(Self {
118            socket,
119            nonblocking: config.nonblocking,
120        })
121    }
122
123    /// Create a socket of arbitrary type (STREAM or RAW).
124    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
125        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
126        socket.set_nonblocking(false)?;
127        Ok(Self {
128            socket,
129            nonblocking: false,
130        })
131    }
132
133    /// Convenience constructor for an IPv4 STREAM socket.
134    pub fn v4_stream() -> io::Result<Self> {
135        Self::new(Domain::IPV4, SockType::STREAM)
136    }
137
138    /// Convenience constructor for an IPv6 STREAM socket.
139    pub fn v6_stream() -> io::Result<Self> {
140        Self::new(Domain::IPV6, SockType::STREAM)
141    }
142
143    /// IPv4 RAW TCP. Requires administrator privileges.
144    pub fn raw_v4() -> io::Result<Self> {
145        Self::new(Domain::IPV4, SockType::RAW)
146    }
147
148    /// IPv6 RAW TCP. Requires administrator privileges.
149    pub fn raw_v6() -> io::Result<Self> {
150        Self::new(Domain::IPV6, SockType::RAW)
151    }
152
153    /// Bind the socket to a specific address.
154    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
155        self.socket.bind(&addr.into())
156    }
157
158    /// Connect to a remote address.
159    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
160        self.socket.connect(&addr.into())
161    }
162
163    /// Connect to the target address with a timeout and return the connected stream.
164    ///
165    /// The returned `TcpStream` must be used for subsequent I/O.
166    #[cfg(unix)]
167    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
168        let socket = self.socket.try_clone()?;
169        socket.set_nonblocking(true)?;
170        let raw_fd = socket.as_raw_fd();
171
172        // Try to connect first
173        match socket.connect(&target.into()) {
174            Ok(_) => { /* succeeded immediately */ }
175            Err(err)
176                if err.kind() == io::ErrorKind::WouldBlock
177                    || err.raw_os_error() == Some(libc::EINPROGRESS) =>
178            {
179                // Continue waiting
180            }
181            Err(e) => return Err(e),
182        }
183
184        // Wait for the connection using poll
185        use std::os::unix::io::BorrowedFd;
186        // Safety: raw_fd is valid for the lifetime of this scope
187        let mut fds = [PollFd::new(
188            unsafe { BorrowedFd::borrow_raw(raw_fd) },
189            PollFlags::POLLOUT,
190        )];
191        let poll_timeout = PollTimeout::try_from(timeout).unwrap_or(PollTimeout::MAX);
192        let n = poll(&mut fds, poll_timeout)?;
193
194        if n == 0 {
195            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
196        }
197
198        // Check the result with `SO_ERROR`
199        let err: i32 = socket
200            .take_error()?
201            .map(|e| e.raw_os_error().unwrap_or(0))
202            .unwrap_or(0);
203        if err != 0 {
204            return Err(io::Error::from_raw_os_error(err));
205        }
206
207        socket.set_nonblocking(self.nonblocking)?;
208
209        match socket.try_clone() {
210            Ok(cloned_socket) => {
211                // Convert the socket into a `std::net::TcpStream`
212                let std_stream: TcpStream = cloned_socket.into();
213                Ok(std_stream)
214            }
215            Err(e) => Err(e),
216        }
217    }
218
219    /// Connect to the target address with a timeout and return the connected stream.
220    ///
221    /// The returned `TcpStream` must be used for subsequent I/O.
222    #[cfg(windows)]
223    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
224        use std::mem::size_of;
225        use std::os::windows::io::AsRawSocket;
226        use windows_sys::Win32::Networking::WinSock::{
227            POLLWRNORM, SO_ERROR, SOCKET, SOCKET_ERROR, SOL_SOCKET, WSAPOLLFD, WSAPoll, getsockopt,
228        };
229
230        let socket = self.socket.try_clone()?;
231        socket.set_nonblocking(true)?;
232        let sock = socket.as_raw_socket() as SOCKET;
233
234        // Start connect
235        match socket.connect(&target.into()) {
236            Ok(_) => { /* connection succeeded immediately */ }
237            Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) /* WSAEWOULDBLOCK */ => {}
238            Err(e) => return Err(e),
239        }
240
241        // Wait using WSAPoll until writable
242        let mut fds = [WSAPOLLFD {
243            fd: sock,
244            events: POLLWRNORM,
245            revents: 0,
246        }];
247
248        let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
249        let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
250        if result == SOCKET_ERROR {
251            return Err(io::Error::last_os_error());
252        } else if result == 0 {
253            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
254        }
255
256        // Check for errors via `SO_ERROR`
257        let mut so_error: i32 = 0;
258        let mut optlen = size_of::<i32>() as i32;
259        let ret = unsafe {
260            getsockopt(
261                sock,
262                SOL_SOCKET as i32,
263                SO_ERROR as i32,
264                &mut so_error as *mut _ as *mut _,
265                &mut optlen,
266            )
267        };
268
269        if ret == SOCKET_ERROR || so_error != 0 {
270            return Err(io::Error::from_raw_os_error(so_error));
271        }
272
273        socket.set_nonblocking(self.nonblocking)?;
274
275        let std_stream: TcpStream = socket.into();
276        Ok(std_stream)
277    }
278
279    /// Start listening for incoming connections.
280    pub fn listen(&self, backlog: i32) -> io::Result<()> {
281        self.socket.listen(backlog)
282    }
283
284    /// Accept an incoming connection.
285    pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
286        let (stream, addr) = self.socket.accept()?;
287        Ok((stream.into(), addr.as_socket().unwrap()))
288    }
289
290    /// Convert the socket into a `TcpStream`.
291    pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
292        Ok(self.socket.into())
293    }
294
295    /// Convert the socket into a `TcpListener`.
296    pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
297        Ok(self.socket.into())
298    }
299
300    /// Send a raw packet (for RAW TCP use).
301    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
302        self.socket.send_to(buf, &target.into())
303    }
304
305    /// Receive a raw packet (for RAW TCP use).
306    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
307        // Safety: `MaybeUninit<u8>` is layout-compatible with `u8`.
308        let buf_maybe = unsafe {
309            std::slice::from_raw_parts_mut(
310                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
311                buf.len(),
312            )
313        };
314
315        let (n, addr) = self.socket.recv_from(buf_maybe)?;
316        let addr = addr
317            .as_socket()
318            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
319
320        Ok((n, addr))
321    }
322
323    /// Shutdown the socket.
324    pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
325        self.socket.shutdown(how)
326    }
327
328    /// Set the socket to reuse the address.
329    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
330        self.socket.set_reuse_address(on)
331    }
332
333    /// Get the socket address reuse option.
334    pub fn reuseaddr(&self) -> io::Result<bool> {
335        self.socket.reuse_address()
336    }
337
338    /// Set the socket port reuse option where supported.
339    #[cfg(any(
340        target_os = "android",
341        target_os = "dragonfly",
342        target_os = "freebsd",
343        target_os = "fuchsia",
344        target_os = "ios",
345        target_os = "linux",
346        target_os = "macos",
347        target_os = "netbsd",
348        target_os = "openbsd",
349        target_os = "tvos",
350        target_os = "visionos",
351        target_os = "watchos"
352    ))]
353    pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
354        self.socket.set_reuse_port(on)
355    }
356
357    /// Get the socket port reuse option where supported.
358    #[cfg(any(
359        target_os = "android",
360        target_os = "dragonfly",
361        target_os = "freebsd",
362        target_os = "fuchsia",
363        target_os = "ios",
364        target_os = "linux",
365        target_os = "macos",
366        target_os = "netbsd",
367        target_os = "openbsd",
368        target_os = "tvos",
369        target_os = "visionos",
370        target_os = "watchos"
371    ))]
372    pub fn reuseport(&self) -> io::Result<bool> {
373        self.socket.reuse_port()
374    }
375
376    /// Set the socket to not delay packets.
377    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
378        self.socket.set_nodelay(on)
379    }
380
381    /// Get the no delay option.
382    pub fn nodelay(&self) -> io::Result<bool> {
383        self.socket.nodelay()
384    }
385
386    /// Set the linger option for the socket.
387    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
388        self.socket.set_linger(dur)
389    }
390
391    /// Set the time-to-live for IPv4 packets.
392    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
393        self.socket.set_ttl(ttl)
394    }
395
396    /// Get the time-to-live for IPv4 packets.
397    pub fn ttl(&self) -> io::Result<u32> {
398        self.socket.ttl()
399    }
400
401    /// Set the hop limit for IPv6 packets.
402    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
403        self.socket.set_unicast_hops_v6(hops)
404    }
405
406    /// Get the hop limit for IPv6 packets.
407    pub fn hoplimit(&self) -> io::Result<u32> {
408        self.socket.unicast_hops_v6()
409    }
410
411    /// Set the keepalive option for the socket.
412    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
413        self.socket.set_keepalive(on)
414    }
415
416    /// Get the keepalive option.
417    pub fn keepalive(&self) -> io::Result<bool> {
418        self.socket.keepalive()
419    }
420
421    /// Set the receive buffer size.
422    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
423        self.socket.set_recv_buffer_size(size)
424    }
425
426    /// Get the receive buffer size.
427    pub fn recv_buffer_size(&self) -> io::Result<usize> {
428        self.socket.recv_buffer_size()
429    }
430
431    /// Set the send buffer size.
432    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
433        self.socket.set_send_buffer_size(size)
434    }
435
436    /// Get the send buffer size.
437    pub fn send_buffer_size(&self) -> io::Result<usize> {
438        self.socket.send_buffer_size()
439    }
440
441    /// Set IPv4 TOS / DSCP.
442    pub fn set_tos(&self, tos: u32) -> io::Result<()> {
443        self.socket.set_tos(tos)
444    }
445
446    /// Get IPv4 TOS / DSCP.
447    pub fn tos(&self) -> io::Result<u32> {
448        self.socket.tos()
449    }
450
451    /// Set IPv6 traffic class where supported.
452    #[cfg(any(
453        target_os = "android",
454        target_os = "dragonfly",
455        target_os = "freebsd",
456        target_os = "fuchsia",
457        target_os = "ios",
458        target_os = "linux",
459        target_os = "macos",
460        target_os = "netbsd",
461        target_os = "openbsd",
462        target_os = "tvos",
463        target_os = "visionos",
464        target_os = "watchos"
465    ))]
466    pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
467        self.socket.set_tclass_v6(tclass)
468    }
469
470    /// Get IPv6 traffic class where supported.
471    #[cfg(any(
472        target_os = "android",
473        target_os = "dragonfly",
474        target_os = "freebsd",
475        target_os = "fuchsia",
476        target_os = "ios",
477        target_os = "linux",
478        target_os = "macos",
479        target_os = "netbsd",
480        target_os = "openbsd",
481        target_os = "tvos",
482        target_os = "visionos",
483        target_os = "watchos"
484    ))]
485    pub fn tclass_v6(&self) -> io::Result<u32> {
486        self.socket.tclass_v6()
487    }
488
489    /// Set whether this socket is IPv6 only.
490    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
491        self.socket.set_only_v6(only_v6)
492    }
493
494    /// Get whether this socket is IPv6 only.
495    pub fn only_v6(&self) -> io::Result<bool> {
496        self.socket.only_v6()
497    }
498
499    /// Set the bind device for the socket (Linux specific).
500    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
501        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
502        return self.socket.bind_device(Some(iface.as_bytes()));
503
504        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
505        {
506            let _ = iface;
507            Err(io::Error::new(
508                io::ErrorKind::Unsupported,
509                "bind_device is not supported on this platform",
510            ))
511        }
512    }
513
514    /// Retrieve the local address of the socket.
515    pub fn local_addr(&self) -> io::Result<SocketAddr> {
516        self.socket
517            .local_addr()?
518            .as_socket()
519            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
520    }
521
522    /// Extract the RAW file descriptor for Unix.
523    #[cfg(unix)]
524    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
525        use std::os::fd::AsRawFd;
526        self.socket.as_raw_fd()
527    }
528
529    /// Extract the RAW socket handle for Windows.
530    #[cfg(windows)]
531    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
532        use std::os::windows::io::AsRawSocket;
533        self.socket.as_raw_socket()
534    }
535
536    /// Construct from a raw `socket2::Socket`.
537    pub fn from_socket(socket: Socket) -> Self {
538        Self {
539            socket,
540            // `socket2::Socket` does not expose a portable getter for the current
541            // blocking mode, so externally supplied sockets default to blocking
542            // expectations in this synchronous wrapper.
543            nonblocking: false,
544        }
545    }
546
547    /// Borrow the inner `socket2::Socket`.
548    pub fn socket(&self) -> &Socket {
549        &self.socket
550    }
551
552    /// Consume and return the inner `socket2::Socket`.
553    pub fn into_socket(self) -> Socket {
554        self.socket
555    }
556}
557
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use libc::{F_GETFL, O_NONBLOCK, fcntl};
    use std::net::TcpListener as StdTcpListener;

    /// Report whether `O_NONBLOCK` is set, read directly from the descriptor so
    /// the assertions see the kernel's view rather than any tracked flag.
    fn socket_is_nonblocking(socket: &Socket) -> bool {
        // SAFETY: F_GETFL only reads the status flags of a descriptor owned by `socket`.
        let status = unsafe { fcntl(socket.as_raw_fd(), F_GETFL) };
        assert!(status >= 0, "F_GETFL failed: {}", io::Error::last_os_error());
        (status & O_NONBLOCK) == O_NONBLOCK
    }

    #[test]
    fn connect_timeout_does_not_mutate_original_nonblocking_state_after_invalid_input() {
        let sock = TcpSocket::v4_stream().expect("socket");
        sock.socket.set_nonblocking(true).expect("set nonblocking");

        // An IPv6 target on an IPv4 socket is expected to be rejected.
        let v6_target: SocketAddr = "[::1]:80".parse().unwrap();
        let outcome = sock.connect_timeout(v6_target, Duration::from_secs(1));

        assert!(outcome.is_err());
        assert!(socket_is_nonblocking(&sock.socket));
    }

    #[test]
    fn connect_timeout_does_not_mutate_original_blocking_state_after_success() {
        let server = StdTcpListener::bind("127.0.0.1:0").expect("listener");
        let server_addr = server.local_addr().expect("local addr");
        let acceptor = std::thread::spawn(move || server.accept().expect("accept"));

        let sock = TcpSocket::v4_stream().expect("socket");
        sock.socket.set_nonblocking(false).expect("set blocking");
        // Keep the stream alive until the end of the test.
        let _stream = sock
            .connect_timeout(server_addr, Duration::from_secs(1))
            .expect("connect");

        assert!(!socket_is_nonblocking(&sock.socket));
        let _ = acceptor.join();
    }
}
601}