nex_socket/tcp/sync_impl.rs

1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{poll, PollFd, PollFlags};
13
14/// Low level synchronous TCP socket.
15#[derive(Debug)]
16pub struct TcpSocket {
17    socket: Socket,
18}
19
20impl TcpSocket {
21    /// Build a socket according to `TcpSocketConfig`.
22    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
23        let socket = Socket::new(
24            config.socket_family.to_domain(),
25            config.socket_type.to_sock_type(),
26            Some(Protocol::TCP),
27        )?;
28
29        socket.set_nonblocking(config.nonblocking)?;
30
31        // Set socket options based on configuration
32        if let Some(flag) = config.reuseaddr {
33            socket.set_reuse_address(flag)?;
34        }
35        if let Some(flag) = config.nodelay {
36            socket.set_nodelay(flag)?;
37        }
38        if let Some(dur) = config.linger {
39            socket.set_linger(Some(dur))?;
40        }
41        if let Some(ttl) = config.ttl {
42            socket.set_ttl(ttl)?;
43        }
44        if let Some(hoplimit) = config.hoplimit {
45            socket.set_unicast_hops_v6(hoplimit)?;
46        }
47        if let Some(keepalive) = config.keepalive {
48            socket.set_keepalive(keepalive)?;
49        }
50        if let Some(timeout) = config.read_timeout {
51            socket.set_read_timeout(Some(timeout))?;
52        }
53        if let Some(timeout) = config.write_timeout {
54            socket.set_write_timeout(Some(timeout))?;
55        }
56
57        // Linux: optional interface name
58        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
59        if let Some(iface) = &config.bind_device {
60            socket.bind_device(Some(iface.as_bytes()))?;
61        }
62
63        // bind to the specified address if provided
64        if let Some(addr) = config.bind_addr {
65            socket.bind(&addr.into())?;
66        }
67
68        Ok(Self { socket })
69    }
70
71    /// Create a socket of arbitrary type (STREAM or RAW).
72    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
73        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
74        socket.set_nonblocking(false)?;
75        Ok(Self { socket })
76    }
77
78    /// Convenience constructor for an IPv4 STREAM socket.
79    pub fn v4_stream() -> io::Result<Self> {
80        Self::new(Domain::IPV4, SockType::STREAM)
81    }
82
83    /// Convenience constructor for an IPv6 STREAM socket.
84    pub fn v6_stream() -> io::Result<Self> {
85        Self::new(Domain::IPV6, SockType::STREAM)
86    }
87
88    /// IPv4 RAW TCP. Requires administrator privileges.
89    pub fn raw_v4() -> io::Result<Self> {
90        Self::new(Domain::IPV4, SockType::RAW)
91    }
92
93    /// IPv6 RAW TCP. Requires administrator privileges.
94    pub fn raw_v6() -> io::Result<Self> {
95        Self::new(Domain::IPV6, SockType::RAW)
96    }
97
98    /// Bind the socket to a specific address.
99    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
100        self.socket.bind(&addr.into())
101    }
102
103    /// Connect to a remote address.
104    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
105        self.socket.connect(&addr.into())
106    }
107
108    /// Connect to the target address with a timeout.
109    #[cfg(unix)]
110    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
111        let raw_fd = self.socket.as_raw_fd();
112        self.socket.set_nonblocking(true)?;
113
114        // Try to connect first
115        match self.socket.connect(&target.into()) {
116            Ok(_) => { /* succeeded immediately */ }
117            Err(err)
118                if err.kind() == io::ErrorKind::WouldBlock
119                    || err.raw_os_error() == Some(libc::EINPROGRESS) =>
120            {
121                // Continue waiting
122            }
123            Err(e) => return Err(e),
124        }
125
126        // Wait for the connection using poll
127        let timeout_ms = timeout.as_millis() as i32;
128        use std::os::unix::io::BorrowedFd;
129        // Safety: raw_fd is valid for the lifetime of this scope
130        let mut fds = [PollFd::new(
131            unsafe { BorrowedFd::borrow_raw(raw_fd) },
132            PollFlags::POLLOUT,
133        )];
134        let n = poll(&mut fds, Some(timeout_ms as u16))?;
135
136        if n == 0 {
137            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
138        }
139
140        // Check the result with `SO_ERROR`
141        let err: i32 = self
142            .socket
143            .take_error()?
144            .map(|e| e.raw_os_error().unwrap_or(0))
145            .unwrap_or(0);
146        if err != 0 {
147            return Err(io::Error::from_raw_os_error(err));
148        }
149
150        self.socket.set_nonblocking(false)?;
151
152        match self.socket.try_clone() {
153            Ok(cloned_socket) => {
154                // Convert the socket into a `std::net::TcpStream`
155                let std_stream: TcpStream = cloned_socket.into();
156                Ok(std_stream)
157            }
158            Err(e) => Err(e),
159        }
160    }
161
162    #[cfg(windows)]
163    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
164        use std::mem::size_of;
165        use std::os::windows::io::AsRawSocket;
166        use windows_sys::Win32::Networking::WinSock::{
167            getsockopt, WSAPoll, POLLWRNORM, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, WSAPOLLFD,
168        };
169
170        let sock = self.socket.as_raw_socket() as SOCKET;
171        self.socket.set_nonblocking(true)?;
172
173        // Start connect
174        match self.socket.connect(&target.into()) {
175            Ok(_) => { /* connection succeeded immediately */ }
176            Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) /* WSAEWOULDBLOCK */ => {}
177            Err(e) => return Err(e),
178        }
179
180        // Wait using WSAPoll until writable
181        let mut fds = [WSAPOLLFD {
182            fd: sock,
183            events: POLLWRNORM,
184            revents: 0,
185        }];
186
187        let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
188        let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
189        if result == SOCKET_ERROR {
190            return Err(io::Error::last_os_error());
191        } else if result == 0 {
192            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
193        }
194
195        // Check for errors via `SO_ERROR`
196        let mut so_error: i32 = 0;
197        let mut optlen = size_of::<i32>() as i32;
198        let ret = unsafe {
199            getsockopt(
200                sock,
201                SOL_SOCKET as i32,
202                SO_ERROR as i32,
203                &mut so_error as *mut _ as *mut _,
204                &mut optlen,
205            )
206        };
207
208        if ret == SOCKET_ERROR || so_error != 0 {
209            return Err(io::Error::from_raw_os_error(so_error));
210        }
211
212        self.socket.set_nonblocking(false)?;
213
214        let std_stream: TcpStream = self.socket.try_clone()?.into();
215        Ok(std_stream)
216    }
217
218    /// Start listening for incoming connections.
219    pub fn listen(&self, backlog: i32) -> io::Result<()> {
220        self.socket.listen(backlog)
221    }
222
223    /// Accept an incoming connection.
224    pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
225        let (stream, addr) = self.socket.accept()?;
226        Ok((stream.into(), addr.as_socket().unwrap()))
227    }
228
229    /// Convert the socket into a `TcpStream`.
230    pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
231        Ok(self.socket.into())
232    }
233
234    /// Convert the socket into a `TcpListener`.
235    pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
236        Ok(self.socket.into())
237    }
238
239    /// Send a raw packet (for RAW TCP use).
240    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
241        self.socket.send_to(buf, &target.into())
242    }
243
244    /// Receive a raw packet (for RAW TCP use).
245    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
246        // Safety: `MaybeUninit<u8>` is layout-compatible with `u8`.
247        let buf_maybe = unsafe {
248            std::slice::from_raw_parts_mut(
249                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
250                buf.len(),
251            )
252        };
253
254        let (n, addr) = self.socket.recv_from(buf_maybe)?;
255        let addr = addr
256            .as_socket()
257            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
258
259        Ok((n, addr))
260    }
261
262    /// Shutdown the socket.
263    pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
264        self.socket.shutdown(how)
265    }
266
267    /// Set the socket to reuse the address.
268    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
269        self.socket.set_reuse_address(on)
270    }
271
272    /// Set the socket to not delay packets.
273    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
274        self.socket.set_nodelay(on)
275    }
276
277    /// Set the linger option for the socket.
278    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
279        self.socket.set_linger(dur)
280    }
281
282    /// Set the time-to-live for IPv4 packets.
283    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
284        self.socket.set_ttl(ttl)
285    }
286
287    /// Set the hop limit for IPv6 packets.
288    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
289        self.socket.set_unicast_hops_v6(hops)
290    }
291
292    /// Set the keepalive option for the socket.
293    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
294        self.socket.set_keepalive(on)
295    }
296
297    /// Set the bind device for the socket (Linux specific).
298    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
299        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
300        return self.socket.bind_device(Some(iface.as_bytes()));
301
302        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
303        {
304            let _ = iface;
305            Err(io::Error::new(
306                io::ErrorKind::Unsupported,
307                "bind_device not supported on this OS",
308            ))
309        }
310    }
311
312    /// Retrieve the local address of the socket.
313    pub fn local_addr(&self) -> io::Result<SocketAddr> {
314        self.socket
315            .local_addr()?
316            .as_socket()
317            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to retrieve local address"))
318    }
319
320    /// Extract the RAW file descriptor for Unix.
321    #[cfg(unix)]
322    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
323        use std::os::fd::AsRawFd;
324        self.socket.as_raw_fd()
325    }
326
327    /// Extract the RAW socket handle for Windows.
328    #[cfg(windows)]
329    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
330        use std::os::windows::io::AsRawSocket;
331        self.socket.as_raw_socket()
332    }
333}