// nex_socket/tcp/sync_impl.rs

1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpStream, TcpListener};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{poll, PollFd, PollFlags};
13
14/// Low level synchronous TCP socket.
15#[derive(Debug)]
16pub struct TcpSocket {
17    socket: Socket,
18}
19
20impl TcpSocket {
21    /// Build a socket according to `TcpSocketConfig`.
22    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
23        let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?;
24
25        // Apply all configuration options
26        if let Some(flag) = config.reuseaddr {
27            socket.set_reuse_address(flag)?;
28        }
29        if let Some(flag) = config.nodelay {
30            socket.set_nodelay(flag)?;
31        }
32        if let Some(dur) = config.linger {
33            socket.set_linger(Some(dur))?;
34        }
35        if let Some(ttl) = config.ttl {
36            socket.set_ttl(ttl)?;
37        }
38
39        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
40        if let Some(iface) = &config.bind_device {
41            socket.bind_device(Some(iface.as_bytes()))?;
42        }
43
44        // Bind to the specified address if provided
45        if let Some(addr) = config.bind_addr {
46            socket.bind(&addr.into())?;
47        }
48
49        // Set non blocking mode
50        socket.set_nonblocking(config.nonblocking)?;
51
52        Ok(Self { socket })
53    }
54
55    /// Create a socket of arbitrary type (STREAM or RAW).
56    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
57        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
58        socket.set_nonblocking(false)?;
59        Ok(Self { socket })
60    }
61
62    /// Convenience constructor for an IPv4 STREAM socket.
63    pub fn v4_stream() -> io::Result<Self> {
64        Self::new(Domain::IPV4, SockType::STREAM)
65    }
66
67    /// Convenience constructor for an IPv6 STREAM socket.
68    pub fn v6_stream() -> io::Result<Self> {
69        Self::new(Domain::IPV6, SockType::STREAM)
70    }
71
72    /// IPv4 RAW TCP. Requires administrator privileges.
73    pub fn raw_v4() -> io::Result<Self> {
74        Self::new(Domain::IPV4, SockType::RAW)
75    }
76
77    /// IPv6 RAW TCP. Requires administrator privileges.
78    pub fn raw_v6() -> io::Result<Self> {
79        Self::new(Domain::IPV6, SockType::RAW)
80    }
81
82    // --- socket operations ---
83
84    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
85        self.socket.bind(&addr.into())
86    }
87
88    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
89        self.socket.connect(&addr.into())
90    }
91
92    #[cfg(unix)]
93    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
94        let raw_fd = self.socket.as_raw_fd();
95        self.socket.set_nonblocking(true)?;
96
97        // Try to connect first
98        match self.socket.connect(&target.into()) {
99            Ok(_) => { /* succeeded immediately */ }
100            Err(err) if err.kind() == io::ErrorKind::WouldBlock || err.raw_os_error() == Some(libc::EINPROGRESS) => {
101                // Continue waiting
102            }
103            Err(e) => return Err(e),
104        }
105
106        // Wait for the connection using poll
107        let timeout_ms = timeout.as_millis() as i32;
108        use std::os::unix::io::BorrowedFd;
109        // Safety: raw_fd is valid for the lifetime of this scope
110        let mut fds = [PollFd::new(unsafe { BorrowedFd::borrow_raw(raw_fd) }, PollFlags::POLLOUT)];
111        let n = poll(&mut fds, Some(timeout_ms as u16))?;
112
113        if n == 0 {
114            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
115        }
116
117        // Check the result with `SO_ERROR`
118        let err: i32 = self.socket.take_error()?.map(|e| e.raw_os_error().unwrap_or(0)).unwrap_or(0);
119        if err != 0 {
120            return Err(io::Error::from_raw_os_error(err));
121        }
122
123        self.socket.set_nonblocking(false)?;
124
125        match self.socket.try_clone() {
126            Ok(cloned_socket) => {
127                // Convert the socket into a `std::net::TcpStream`
128                let std_stream: TcpStream = cloned_socket.into();
129                Ok(std_stream)
130            }
131            Err(e) => Err(e),
132        }
133    }
134
135    #[cfg(windows)]
136    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
137        use std::os::windows::io::AsRawSocket;
138        use windows_sys::Win32::Networking::WinSock::{
139            WSAPOLLFD, WSAPoll, POLLWRNORM, SOCKET_ERROR, SO_ERROR, SOL_SOCKET,
140            getsockopt, SOCKET,
141        };
142        use std::mem::size_of;
143
144        let sock = self.socket.as_raw_socket() as SOCKET;
145        self.socket.set_nonblocking(true)?;
146
147        // Start connect
148        match self.socket.connect(&target.into()) {
149            Ok(_) => { /* connection succeeded immediately */ }
150            Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) /* WSAEWOULDBLOCK */ => {}
151            Err(e) => return Err(e),
152        }
153
154        // Wait using WSAPoll until writable
155        let mut fds = [WSAPOLLFD {
156            fd: sock,
157            events: POLLWRNORM,
158            revents: 0,
159        }];
160
161        let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
162        let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
163        if result == SOCKET_ERROR {
164            return Err(io::Error::last_os_error());
165        } else if result == 0 {
166            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
167        }
168
169        // Check for errors via `SO_ERROR`
170        let mut so_error: i32 = 0;
171        let mut optlen = size_of::<i32>() as i32;
172        let ret = unsafe {
173            getsockopt(
174                sock,
175                SOL_SOCKET as i32,
176                SO_ERROR as i32,
177                &mut so_error as *mut _ as *mut _,
178                &mut optlen,
179            )
180        };
181
182        if ret == SOCKET_ERROR || so_error != 0 {
183            return Err(io::Error::from_raw_os_error(so_error));
184        }
185
186        self.socket.set_nonblocking(false)?;
187
188        let std_stream: TcpStream = self.socket.try_clone()?.into();
189        Ok(std_stream)
190    }
191
192    pub fn listen(&self, backlog: i32) -> io::Result<()> {
193        self.socket.listen(backlog)
194    }
195
196    pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
197        let (stream, addr) = self.socket.accept()?;
198        Ok((stream.into(), addr.as_socket().unwrap()))
199    }
200
201    pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
202        Ok(self.socket.into())
203    }
204
205    pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
206        Ok(self.socket.into())
207    }
208
209    /// Send a raw packet (for RAW TCP use).
210    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
211        self.socket.send_to(buf, &target.into())
212    }
213
214    /// Receive a raw packet (for RAW TCP use).
215    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
216        // Safety: `MaybeUninit<u8>` is layout-compatible with `u8`.
217        let buf_maybe = unsafe {
218            std::slice::from_raw_parts_mut(
219                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
220                buf.len(),
221            )
222        };
223
224        let (n, addr) = self.socket.recv_from(buf_maybe)?;
225        let addr = addr.as_socket().ok_or_else(|| {
226            io::Error::new(io::ErrorKind::InvalidData, "invalid address format")
227        })?;
228
229        Ok((n, addr))
230    }
231
232    // --- option helpers ---
233
234    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
235        self.socket.set_reuse_address(on)
236    }
237
238    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
239        self.socket.set_nodelay(on)
240    }
241
242    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
243        self.socket.set_linger(dur)
244    }
245
246    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
247        self.socket.set_ttl(ttl)
248    }
249
250    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
251        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
252        return self.socket.bind_device(Some(iface.as_bytes()));
253
254        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
255        {
256            let _ = iface;
257            Err(io::Error::new(io::ErrorKind::Unsupported, "bind_device not supported on this OS"))
258        }
259    }
260
261    // --- information helpers ---
262
263    pub fn local_addr(&self) -> io::Result<SocketAddr> {
264        self.socket.local_addr()?.as_socket().ok_or_else(|| {
265            io::Error::new(io::ErrorKind::Other, "Failed to retrieve local address")
266        })
267    }
268
269    #[cfg(unix)]
270    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
271        use std::os::fd::AsRawFd;
272        self.socket.as_raw_fd()
273    }
274
275    #[cfg(windows)]
276    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
277        use std::os::windows::io::AsRawSocket;
278        self.socket.as_raw_socket()
279    }
280}