nex_socket/tcp/
sync_impl.rs

1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{poll, PollFd, PollFlags};
13
/// Low level synchronous TCP socket.
///
/// Thin wrapper around a `socket2` [`Socket`]; the methods on this type
/// delegate to the wrapped socket and translate `socket2` address types into
/// their `std::net` counterparts.
#[derive(Debug)]
pub struct TcpSocket {
    // The wrapped `socket2` socket.
    socket: Socket,
}
19
20impl TcpSocket {
21    /// Build a socket according to `TcpSocketConfig`.
22    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
23        let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?;
24
25        // Apply all configuration options
26        if let Some(flag) = config.reuseaddr {
27            socket.set_reuse_address(flag)?;
28        }
29        if let Some(flag) = config.nodelay {
30            socket.set_nodelay(flag)?;
31        }
32        if let Some(dur) = config.linger {
33            socket.set_linger(Some(dur))?;
34        }
35        if let Some(ttl) = config.ttl {
36            socket.set_ttl(ttl)?;
37        }
38
39        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
40        if let Some(iface) = &config.bind_device {
41            socket.bind_device(Some(iface.as_bytes()))?;
42        }
43
44        // Bind to the specified address if provided
45        if let Some(addr) = config.bind_addr {
46            socket.bind(&addr.into())?;
47        }
48
49        // Set non blocking mode
50        socket.set_nonblocking(config.nonblocking)?;
51
52        Ok(Self { socket })
53    }
54
55    /// Create a socket of arbitrary type (STREAM or RAW).
56    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
57        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
58        socket.set_nonblocking(false)?;
59        Ok(Self { socket })
60    }
61
62    /// Convenience constructor for an IPv4 STREAM socket.
63    pub fn v4_stream() -> io::Result<Self> {
64        Self::new(Domain::IPV4, SockType::STREAM)
65    }
66
67    /// Convenience constructor for an IPv6 STREAM socket.
68    pub fn v6_stream() -> io::Result<Self> {
69        Self::new(Domain::IPV6, SockType::STREAM)
70    }
71
72    /// IPv4 RAW TCP. Requires administrator privileges.
73    pub fn raw_v4() -> io::Result<Self> {
74        Self::new(Domain::IPV4, SockType::RAW)
75    }
76
77    /// IPv6 RAW TCP. Requires administrator privileges.
78    pub fn raw_v6() -> io::Result<Self> {
79        Self::new(Domain::IPV6, SockType::RAW)
80    }
81
82    // --- socket operations ---
83
84    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
85        self.socket.bind(&addr.into())
86    }
87
88    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
89        self.socket.connect(&addr.into())
90    }
91
92    #[cfg(unix)]
93    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
94        let raw_fd = self.socket.as_raw_fd();
95        self.socket.set_nonblocking(true)?;
96
97        // Try to connect first
98        match self.socket.connect(&target.into()) {
99            Ok(_) => { /* succeeded immediately */ }
100            Err(err)
101                if err.kind() == io::ErrorKind::WouldBlock
102                    || err.raw_os_error() == Some(libc::EINPROGRESS) =>
103            {
104                // Continue waiting
105            }
106            Err(e) => return Err(e),
107        }
108
109        // Wait for the connection using poll
110        let timeout_ms = timeout.as_millis() as i32;
111        use std::os::unix::io::BorrowedFd;
112        // Safety: raw_fd is valid for the lifetime of this scope
113        let mut fds = [PollFd::new(
114            unsafe { BorrowedFd::borrow_raw(raw_fd) },
115            PollFlags::POLLOUT,
116        )];
117        let n = poll(&mut fds, Some(timeout_ms as u16))?;
118
119        if n == 0 {
120            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
121        }
122
123        // Check the result with `SO_ERROR`
124        let err: i32 = self
125            .socket
126            .take_error()?
127            .map(|e| e.raw_os_error().unwrap_or(0))
128            .unwrap_or(0);
129        if err != 0 {
130            return Err(io::Error::from_raw_os_error(err));
131        }
132
133        self.socket.set_nonblocking(false)?;
134
135        match self.socket.try_clone() {
136            Ok(cloned_socket) => {
137                // Convert the socket into a `std::net::TcpStream`
138                let std_stream: TcpStream = cloned_socket.into();
139                Ok(std_stream)
140            }
141            Err(e) => Err(e),
142        }
143    }
144
145    #[cfg(windows)]
146    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
147        use std::mem::size_of;
148        use std::os::windows::io::AsRawSocket;
149        use windows_sys::Win32::Networking::WinSock::{
150            getsockopt, WSAPoll, POLLWRNORM, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, WSAPOLLFD,
151        };
152
153        let sock = self.socket.as_raw_socket() as SOCKET;
154        self.socket.set_nonblocking(true)?;
155
156        // Start connect
157        match self.socket.connect(&target.into()) {
158            Ok(_) => { /* connection succeeded immediately */ }
159            Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) /* WSAEWOULDBLOCK */ => {}
160            Err(e) => return Err(e),
161        }
162
163        // Wait using WSAPoll until writable
164        let mut fds = [WSAPOLLFD {
165            fd: sock,
166            events: POLLWRNORM,
167            revents: 0,
168        }];
169
170        let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
171        let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
172        if result == SOCKET_ERROR {
173            return Err(io::Error::last_os_error());
174        } else if result == 0 {
175            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
176        }
177
178        // Check for errors via `SO_ERROR`
179        let mut so_error: i32 = 0;
180        let mut optlen = size_of::<i32>() as i32;
181        let ret = unsafe {
182            getsockopt(
183                sock,
184                SOL_SOCKET as i32,
185                SO_ERROR as i32,
186                &mut so_error as *mut _ as *mut _,
187                &mut optlen,
188            )
189        };
190
191        if ret == SOCKET_ERROR || so_error != 0 {
192            return Err(io::Error::from_raw_os_error(so_error));
193        }
194
195        self.socket.set_nonblocking(false)?;
196
197        let std_stream: TcpStream = self.socket.try_clone()?.into();
198        Ok(std_stream)
199    }
200
201    pub fn listen(&self, backlog: i32) -> io::Result<()> {
202        self.socket.listen(backlog)
203    }
204
205    pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
206        let (stream, addr) = self.socket.accept()?;
207        Ok((stream.into(), addr.as_socket().unwrap()))
208    }
209
210    pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
211        Ok(self.socket.into())
212    }
213
214    pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
215        Ok(self.socket.into())
216    }
217
218    /// Send a raw packet (for RAW TCP use).
219    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
220        self.socket.send_to(buf, &target.into())
221    }
222
223    /// Receive a raw packet (for RAW TCP use).
224    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
225        // Safety: `MaybeUninit<u8>` is layout-compatible with `u8`.
226        let buf_maybe = unsafe {
227            std::slice::from_raw_parts_mut(
228                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
229                buf.len(),
230            )
231        };
232
233        let (n, addr) = self.socket.recv_from(buf_maybe)?;
234        let addr = addr
235            .as_socket()
236            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
237
238        Ok((n, addr))
239    }
240
241    // --- option helpers ---
242
243    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
244        self.socket.set_reuse_address(on)
245    }
246
247    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
248        self.socket.set_nodelay(on)
249    }
250
251    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
252        self.socket.set_linger(dur)
253    }
254
255    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
256        self.socket.set_ttl(ttl)
257    }
258
259    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
260        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
261        return self.socket.bind_device(Some(iface.as_bytes()));
262
263        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
264        {
265            let _ = iface;
266            Err(io::Error::new(
267                io::ErrorKind::Unsupported,
268                "bind_device not supported on this OS",
269            ))
270        }
271    }
272
273    // --- information helpers ---
274
275    pub fn local_addr(&self) -> io::Result<SocketAddr> {
276        self.socket
277            .local_addr()?
278            .as_socket()
279            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to retrieve local address"))
280    }
281
282    #[cfg(unix)]
283    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
284        use std::os::fd::AsRawFd;
285        self.socket.as_raw_fd()
286    }
287
288    #[cfg(windows)]
289    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
290        use std::os::windows::io::AsRawSocket;
291        self.socket.as_raw_socket()
292    }
293}