nex_socket/tcp/
sync_impl.rs1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{poll, PollFd, PollFlags};
13
14#[derive(Debug)]
16pub struct TcpSocket {
17 socket: Socket,
18}
19
20impl TcpSocket {
21 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
23 let socket = Socket::new(
24 config.socket_family.to_domain(),
25 config.socket_type.to_sock_type(),
26 Some(Protocol::TCP),
27 )?;
28
29 socket.set_nonblocking(config.nonblocking)?;
30
31 if let Some(flag) = config.reuseaddr {
33 socket.set_reuse_address(flag)?;
34 }
35 if let Some(flag) = config.nodelay {
36 socket.set_nodelay(flag)?;
37 }
38 if let Some(dur) = config.linger {
39 socket.set_linger(Some(dur))?;
40 }
41 if let Some(ttl) = config.ttl {
42 socket.set_ttl(ttl)?;
43 }
44 if let Some(hoplimit) = config.hoplimit {
45 socket.set_unicast_hops_v6(hoplimit)?;
46 }
47 if let Some(keepalive) = config.keepalive {
48 socket.set_keepalive(keepalive)?;
49 }
50 if let Some(timeout) = config.read_timeout {
51 socket.set_read_timeout(Some(timeout))?;
52 }
53 if let Some(timeout) = config.write_timeout {
54 socket.set_write_timeout(Some(timeout))?;
55 }
56
57 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
59 if let Some(iface) = &config.bind_device {
60 socket.bind_device(Some(iface.as_bytes()))?;
61 }
62
63 if let Some(addr) = config.bind_addr {
65 socket.bind(&addr.into())?;
66 }
67
68 Ok(Self { socket })
69 }
70
71 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
73 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
74 socket.set_nonblocking(false)?;
75 Ok(Self { socket })
76 }
77
78 pub fn v4_stream() -> io::Result<Self> {
80 Self::new(Domain::IPV4, SockType::STREAM)
81 }
82
83 pub fn v6_stream() -> io::Result<Self> {
85 Self::new(Domain::IPV6, SockType::STREAM)
86 }
87
88 pub fn raw_v4() -> io::Result<Self> {
90 Self::new(Domain::IPV4, SockType::RAW)
91 }
92
93 pub fn raw_v6() -> io::Result<Self> {
95 Self::new(Domain::IPV6, SockType::RAW)
96 }
97
98 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
100 self.socket.bind(&addr.into())
101 }
102
103 pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
105 self.socket.connect(&addr.into())
106 }
107
108 #[cfg(unix)]
110 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
111 let raw_fd = self.socket.as_raw_fd();
112 self.socket.set_nonblocking(true)?;
113
114 match self.socket.connect(&target.into()) {
116 Ok(_) => { }
117 Err(err)
118 if err.kind() == io::ErrorKind::WouldBlock
119 || err.raw_os_error() == Some(libc::EINPROGRESS) =>
120 {
121 }
123 Err(e) => return Err(e),
124 }
125
126 let timeout_ms = timeout.as_millis() as i32;
128 use std::os::unix::io::BorrowedFd;
129 let mut fds = [PollFd::new(
131 unsafe { BorrowedFd::borrow_raw(raw_fd) },
132 PollFlags::POLLOUT,
133 )];
134 let n = poll(&mut fds, Some(timeout_ms as u16))?;
135
136 if n == 0 {
137 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
138 }
139
140 let err: i32 = self
142 .socket
143 .take_error()?
144 .map(|e| e.raw_os_error().unwrap_or(0))
145 .unwrap_or(0);
146 if err != 0 {
147 return Err(io::Error::from_raw_os_error(err));
148 }
149
150 self.socket.set_nonblocking(false)?;
151
152 match self.socket.try_clone() {
153 Ok(cloned_socket) => {
154 let std_stream: TcpStream = cloned_socket.into();
156 Ok(std_stream)
157 }
158 Err(e) => Err(e),
159 }
160 }
161
162 #[cfg(windows)]
163 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
164 use std::mem::size_of;
165 use std::os::windows::io::AsRawSocket;
166 use windows_sys::Win32::Networking::WinSock::{
167 getsockopt, WSAPoll, POLLWRNORM, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, WSAPOLLFD,
168 };
169
170 let sock = self.socket.as_raw_socket() as SOCKET;
171 self.socket.set_nonblocking(true)?;
172
173 match self.socket.connect(&target.into()) {
175 Ok(_) => { }
176 Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) => {}
177 Err(e) => return Err(e),
178 }
179
180 let mut fds = [WSAPOLLFD {
182 fd: sock,
183 events: POLLWRNORM,
184 revents: 0,
185 }];
186
187 let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
188 let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
189 if result == SOCKET_ERROR {
190 return Err(io::Error::last_os_error());
191 } else if result == 0 {
192 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
193 }
194
195 let mut so_error: i32 = 0;
197 let mut optlen = size_of::<i32>() as i32;
198 let ret = unsafe {
199 getsockopt(
200 sock,
201 SOL_SOCKET as i32,
202 SO_ERROR as i32,
203 &mut so_error as *mut _ as *mut _,
204 &mut optlen,
205 )
206 };
207
208 if ret == SOCKET_ERROR || so_error != 0 {
209 return Err(io::Error::from_raw_os_error(so_error));
210 }
211
212 self.socket.set_nonblocking(false)?;
213
214 let std_stream: TcpStream = self.socket.try_clone()?.into();
215 Ok(std_stream)
216 }
217
218 pub fn listen(&self, backlog: i32) -> io::Result<()> {
220 self.socket.listen(backlog)
221 }
222
223 pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
225 let (stream, addr) = self.socket.accept()?;
226 Ok((stream.into(), addr.as_socket().unwrap()))
227 }
228
229 pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
231 Ok(self.socket.into())
232 }
233
234 pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
236 Ok(self.socket.into())
237 }
238
239 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
241 self.socket.send_to(buf, &target.into())
242 }
243
244 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
246 let buf_maybe = unsafe {
248 std::slice::from_raw_parts_mut(
249 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
250 buf.len(),
251 )
252 };
253
254 let (n, addr) = self.socket.recv_from(buf_maybe)?;
255 let addr = addr
256 .as_socket()
257 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
258
259 Ok((n, addr))
260 }
261
262 pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
264 self.socket.shutdown(how)
265 }
266
267 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
269 self.socket.set_reuse_address(on)
270 }
271
272 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
274 self.socket.set_nodelay(on)
275 }
276
277 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
279 self.socket.set_linger(dur)
280 }
281
282 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
284 self.socket.set_ttl(ttl)
285 }
286
287 pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
289 self.socket.set_unicast_hops_v6(hops)
290 }
291
292 pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
294 self.socket.set_keepalive(on)
295 }
296
297 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
299 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
300 return self.socket.bind_device(Some(iface.as_bytes()));
301
302 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
303 {
304 let _ = iface;
305 Err(io::Error::new(
306 io::ErrorKind::Unsupported,
307 "bind_device not supported on this OS",
308 ))
309 }
310 }
311
312 pub fn local_addr(&self) -> io::Result<SocketAddr> {
314 self.socket
315 .local_addr()?
316 .as_socket()
317 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to retrieve local address"))
318 }
319
320 #[cfg(unix)]
322 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
323 use std::os::fd::AsRawFd;
324 self.socket.as_raw_fd()
325 }
326
327 #[cfg(windows)]
329 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
330 use std::os::windows::io::AsRawSocket;
331 self.socket.as_raw_socket()
332 }
333}