nex_socket/tcp/
sync_impl.rs1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{poll, PollFd, PollFlags};
13
14#[derive(Debug)]
16pub struct TcpSocket {
17 socket: Socket,
18}
19
20impl TcpSocket {
21 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
23 let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?;
24
25 if let Some(flag) = config.reuseaddr {
27 socket.set_reuse_address(flag)?;
28 }
29 if let Some(flag) = config.nodelay {
30 socket.set_nodelay(flag)?;
31 }
32 if let Some(dur) = config.linger {
33 socket.set_linger(Some(dur))?;
34 }
35 if let Some(ttl) = config.ttl {
36 socket.set_ttl(ttl)?;
37 }
38
39 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
40 if let Some(iface) = &config.bind_device {
41 socket.bind_device(Some(iface.as_bytes()))?;
42 }
43
44 if let Some(addr) = config.bind_addr {
46 socket.bind(&addr.into())?;
47 }
48
49 socket.set_nonblocking(config.nonblocking)?;
51
52 Ok(Self { socket })
53 }
54
55 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
57 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
58 socket.set_nonblocking(false)?;
59 Ok(Self { socket })
60 }
61
62 pub fn v4_stream() -> io::Result<Self> {
64 Self::new(Domain::IPV4, SockType::STREAM)
65 }
66
67 pub fn v6_stream() -> io::Result<Self> {
69 Self::new(Domain::IPV6, SockType::STREAM)
70 }
71
72 pub fn raw_v4() -> io::Result<Self> {
74 Self::new(Domain::IPV4, SockType::RAW)
75 }
76
77 pub fn raw_v6() -> io::Result<Self> {
79 Self::new(Domain::IPV6, SockType::RAW)
80 }
81
82 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
85 self.socket.bind(&addr.into())
86 }
87
88 pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
89 self.socket.connect(&addr.into())
90 }
91
92 #[cfg(unix)]
93 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
94 let raw_fd = self.socket.as_raw_fd();
95 self.socket.set_nonblocking(true)?;
96
97 match self.socket.connect(&target.into()) {
99 Ok(_) => { }
100 Err(err)
101 if err.kind() == io::ErrorKind::WouldBlock
102 || err.raw_os_error() == Some(libc::EINPROGRESS) =>
103 {
104 }
106 Err(e) => return Err(e),
107 }
108
109 let timeout_ms = timeout.as_millis() as i32;
111 use std::os::unix::io::BorrowedFd;
112 let mut fds = [PollFd::new(
114 unsafe { BorrowedFd::borrow_raw(raw_fd) },
115 PollFlags::POLLOUT,
116 )];
117 let n = poll(&mut fds, Some(timeout_ms as u16))?;
118
119 if n == 0 {
120 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
121 }
122
123 let err: i32 = self
125 .socket
126 .take_error()?
127 .map(|e| e.raw_os_error().unwrap_or(0))
128 .unwrap_or(0);
129 if err != 0 {
130 return Err(io::Error::from_raw_os_error(err));
131 }
132
133 self.socket.set_nonblocking(false)?;
134
135 match self.socket.try_clone() {
136 Ok(cloned_socket) => {
137 let std_stream: TcpStream = cloned_socket.into();
139 Ok(std_stream)
140 }
141 Err(e) => Err(e),
142 }
143 }
144
145 #[cfg(windows)]
146 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
147 use std::mem::size_of;
148 use std::os::windows::io::AsRawSocket;
149 use windows_sys::Win32::Networking::WinSock::{
150 getsockopt, WSAPoll, POLLWRNORM, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, WSAPOLLFD,
151 };
152
153 let sock = self.socket.as_raw_socket() as SOCKET;
154 self.socket.set_nonblocking(true)?;
155
156 match self.socket.connect(&target.into()) {
158 Ok(_) => { }
159 Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) => {}
160 Err(e) => return Err(e),
161 }
162
163 let mut fds = [WSAPOLLFD {
165 fd: sock,
166 events: POLLWRNORM,
167 revents: 0,
168 }];
169
170 let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
171 let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
172 if result == SOCKET_ERROR {
173 return Err(io::Error::last_os_error());
174 } else if result == 0 {
175 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
176 }
177
178 let mut so_error: i32 = 0;
180 let mut optlen = size_of::<i32>() as i32;
181 let ret = unsafe {
182 getsockopt(
183 sock,
184 SOL_SOCKET as i32,
185 SO_ERROR as i32,
186 &mut so_error as *mut _ as *mut _,
187 &mut optlen,
188 )
189 };
190
191 if ret == SOCKET_ERROR || so_error != 0 {
192 return Err(io::Error::from_raw_os_error(so_error));
193 }
194
195 self.socket.set_nonblocking(false)?;
196
197 let std_stream: TcpStream = self.socket.try_clone()?.into();
198 Ok(std_stream)
199 }
200
201 pub fn listen(&self, backlog: i32) -> io::Result<()> {
202 self.socket.listen(backlog)
203 }
204
205 pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
206 let (stream, addr) = self.socket.accept()?;
207 Ok((stream.into(), addr.as_socket().unwrap()))
208 }
209
210 pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
211 Ok(self.socket.into())
212 }
213
214 pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
215 Ok(self.socket.into())
216 }
217
218 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
220 self.socket.send_to(buf, &target.into())
221 }
222
223 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
225 let buf_maybe = unsafe {
227 std::slice::from_raw_parts_mut(
228 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
229 buf.len(),
230 )
231 };
232
233 let (n, addr) = self.socket.recv_from(buf_maybe)?;
234 let addr = addr
235 .as_socket()
236 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
237
238 Ok((n, addr))
239 }
240
241 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
244 self.socket.set_reuse_address(on)
245 }
246
247 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
248 self.socket.set_nodelay(on)
249 }
250
251 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
252 self.socket.set_linger(dur)
253 }
254
255 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
256 self.socket.set_ttl(ttl)
257 }
258
259 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
260 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
261 return self.socket.bind_device(Some(iface.as_bytes()));
262
263 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
264 {
265 let _ = iface;
266 Err(io::Error::new(
267 io::ErrorKind::Unsupported,
268 "bind_device not supported on this OS",
269 ))
270 }
271 }
272
273 pub fn local_addr(&self) -> io::Result<SocketAddr> {
276 self.socket
277 .local_addr()?
278 .as_socket()
279 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to retrieve local address"))
280 }
281
282 #[cfg(unix)]
283 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
284 use std::os::fd::AsRawFd;
285 self.socket.as_raw_fd()
286 }
287
288 #[cfg(windows)]
289 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
290 use std::os::windows::io::AsRawSocket;
291 self.socket.as_raw_socket()
292 }
293}