nex_socket/tcp/
sync_impl.rs1use socket2::{Domain, Protocol, Socket, Type as SockType};
2use std::io;
3use std::net::{SocketAddr, TcpStream, TcpListener};
4use std::time::Duration;
5
6use crate::tcp::TcpConfig;
7
8#[cfg(unix)]
9use std::os::fd::AsRawFd;
10
11#[cfg(unix)]
12use nix::poll::{poll, PollFd, PollFlags};
13
14#[derive(Debug)]
16pub struct TcpSocket {
17 socket: Socket,
18}
19
20impl TcpSocket {
21 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
23 let socket = Socket::new(config.domain, config.sock_type, Some(Protocol::TCP))?;
24
25 if let Some(flag) = config.reuseaddr {
27 socket.set_reuse_address(flag)?;
28 }
29 if let Some(flag) = config.nodelay {
30 socket.set_nodelay(flag)?;
31 }
32 if let Some(dur) = config.linger {
33 socket.set_linger(Some(dur))?;
34 }
35 if let Some(ttl) = config.ttl {
36 socket.set_ttl(ttl)?;
37 }
38
39 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
40 if let Some(iface) = &config.bind_device {
41 socket.bind_device(Some(iface.as_bytes()))?;
42 }
43
44 if let Some(addr) = config.bind_addr {
46 socket.bind(&addr.into())?;
47 }
48
49 socket.set_nonblocking(config.nonblocking)?;
51
52 Ok(Self { socket })
53 }
54
55 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
57 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
58 socket.set_nonblocking(false)?;
59 Ok(Self { socket })
60 }
61
62 pub fn v4_stream() -> io::Result<Self> {
64 Self::new(Domain::IPV4, SockType::STREAM)
65 }
66
67 pub fn v6_stream() -> io::Result<Self> {
69 Self::new(Domain::IPV6, SockType::STREAM)
70 }
71
72 pub fn raw_v4() -> io::Result<Self> {
74 Self::new(Domain::IPV4, SockType::RAW)
75 }
76
77 pub fn raw_v6() -> io::Result<Self> {
79 Self::new(Domain::IPV6, SockType::RAW)
80 }
81
82 pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
85 self.socket.bind(&addr.into())
86 }
87
88 pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
89 self.socket.connect(&addr.into())
90 }
91
92 #[cfg(unix)]
93 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
94 let raw_fd = self.socket.as_raw_fd();
95 self.socket.set_nonblocking(true)?;
96
97 match self.socket.connect(&target.into()) {
99 Ok(_) => { }
100 Err(err) if err.kind() == io::ErrorKind::WouldBlock || err.raw_os_error() == Some(libc::EINPROGRESS) => {
101 }
103 Err(e) => return Err(e),
104 }
105
106 let timeout_ms = timeout.as_millis() as i32;
108 use std::os::unix::io::BorrowedFd;
109 let mut fds = [PollFd::new(unsafe { BorrowedFd::borrow_raw(raw_fd) }, PollFlags::POLLOUT)];
111 let n = poll(&mut fds, Some(timeout_ms as u16))?;
112
113 if n == 0 {
114 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
115 }
116
117 let err: i32 = self.socket.take_error()?.map(|e| e.raw_os_error().unwrap_or(0)).unwrap_or(0);
119 if err != 0 {
120 return Err(io::Error::from_raw_os_error(err));
121 }
122
123 self.socket.set_nonblocking(false)?;
124
125 match self.socket.try_clone() {
126 Ok(cloned_socket) => {
127 let std_stream: TcpStream = cloned_socket.into();
129 Ok(std_stream)
130 }
131 Err(e) => Err(e),
132 }
133 }
134
135 #[cfg(windows)]
136 pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
137 use std::os::windows::io::AsRawSocket;
138 use windows_sys::Win32::Networking::WinSock::{
139 WSAPOLLFD, WSAPoll, POLLWRNORM, SOCKET_ERROR, SO_ERROR, SOL_SOCKET,
140 getsockopt, SOCKET,
141 };
142 use std::mem::size_of;
143
144 let sock = self.socket.as_raw_socket() as SOCKET;
145 self.socket.set_nonblocking(true)?;
146
147 match self.socket.connect(&target.into()) {
149 Ok(_) => { }
150 Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) => {}
151 Err(e) => return Err(e),
152 }
153
154 let mut fds = [WSAPOLLFD {
156 fd: sock,
157 events: POLLWRNORM,
158 revents: 0,
159 }];
160
161 let timeout_ms = timeout.as_millis().clamp(0, i32::MAX as u128) as i32;
162 let result = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
163 if result == SOCKET_ERROR {
164 return Err(io::Error::last_os_error());
165 } else if result == 0 {
166 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
167 }
168
169 let mut so_error: i32 = 0;
171 let mut optlen = size_of::<i32>() as i32;
172 let ret = unsafe {
173 getsockopt(
174 sock,
175 SOL_SOCKET as i32,
176 SO_ERROR as i32,
177 &mut so_error as *mut _ as *mut _,
178 &mut optlen,
179 )
180 };
181
182 if ret == SOCKET_ERROR || so_error != 0 {
183 return Err(io::Error::from_raw_os_error(so_error));
184 }
185
186 self.socket.set_nonblocking(false)?;
187
188 let std_stream: TcpStream = self.socket.try_clone()?.into();
189 Ok(std_stream)
190 }
191
192 pub fn listen(&self, backlog: i32) -> io::Result<()> {
193 self.socket.listen(backlog)
194 }
195
196 pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
197 let (stream, addr) = self.socket.accept()?;
198 Ok((stream.into(), addr.as_socket().unwrap()))
199 }
200
201 pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
202 Ok(self.socket.into())
203 }
204
205 pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
206 Ok(self.socket.into())
207 }
208
209 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
211 self.socket.send_to(buf, &target.into())
212 }
213
214 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
216 let buf_maybe = unsafe {
218 std::slice::from_raw_parts_mut(
219 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
220 buf.len(),
221 )
222 };
223
224 let (n, addr) = self.socket.recv_from(buf_maybe)?;
225 let addr = addr.as_socket().ok_or_else(|| {
226 io::Error::new(io::ErrorKind::InvalidData, "invalid address format")
227 })?;
228
229 Ok((n, addr))
230 }
231
232 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
235 self.socket.set_reuse_address(on)
236 }
237
238 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
239 self.socket.set_nodelay(on)
240 }
241
242 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
243 self.socket.set_linger(dur)
244 }
245
246 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
247 self.socket.set_ttl(ttl)
248 }
249
250 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
251 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
252 return self.socket.bind_device(Some(iface.as_bytes()));
253
254 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
255 {
256 let _ = iface;
257 Err(io::Error::new(io::ErrorKind::Unsupported, "bind_device not supported on this OS"))
258 }
259 }
260
261 pub fn local_addr(&self) -> io::Result<SocketAddr> {
264 self.socket.local_addr()?.as_socket().ok_or_else(|| {
265 io::Error::new(io::ErrorKind::Other, "Failed to retrieve local address")
266 })
267 }
268
269 #[cfg(unix)]
270 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
271 use std::os::fd::AsRawFd;
272 self.socket.as_raw_fd()
273 }
274
275 #[cfg(windows)]
276 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
277 use std::os::windows::io::AsRawSocket;
278 self.socket.as_raw_socket()
279 }
280}