1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8#[derive(Debug)]
10pub struct AsyncTcpSocket {
11 socket: Socket,
12}
13
14impl AsyncTcpSocket {
15 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17 let socket = Socket::new(
18 config.socket_family.to_domain(),
19 config.socket_type.to_sock_type(),
20 Some(Protocol::TCP),
21 )?;
22
23 socket.set_nonblocking(true)?;
24
25 if let Some(flag) = config.reuseaddr {
27 socket.set_reuse_address(flag)?;
28 }
29 #[cfg(any(
30 target_os = "android",
31 target_os = "dragonfly",
32 target_os = "freebsd",
33 target_os = "fuchsia",
34 target_os = "ios",
35 target_os = "linux",
36 target_os = "macos",
37 target_os = "netbsd",
38 target_os = "openbsd",
39 target_os = "tvos",
40 target_os = "visionos",
41 target_os = "watchos"
42 ))]
43 if let Some(flag) = config.reuseport {
44 socket.set_reuse_port(flag)?;
45 }
46 if let Some(flag) = config.nodelay {
47 socket.set_nodelay(flag)?;
48 }
49 if let Some(ttl) = config.ttl {
50 socket.set_ttl(ttl)?;
51 }
52 if let Some(hoplimit) = config.hoplimit {
53 socket.set_unicast_hops_v6(hoplimit)?;
54 }
55 if let Some(keepalive) = config.keepalive {
56 socket.set_keepalive(keepalive)?;
57 }
58 if let Some(timeout) = config.read_timeout {
59 socket.set_read_timeout(Some(timeout))?;
60 }
61 if let Some(timeout) = config.write_timeout {
62 socket.set_write_timeout(Some(timeout))?;
63 }
64 if let Some(size) = config.recv_buffer_size {
65 socket.set_recv_buffer_size(size)?;
66 }
67 if let Some(size) = config.send_buffer_size {
68 socket.set_send_buffer_size(size)?;
69 }
70 if let Some(tos) = config.tos {
71 socket.set_tos(tos)?;
72 }
73 #[cfg(any(
74 target_os = "android",
75 target_os = "dragonfly",
76 target_os = "freebsd",
77 target_os = "fuchsia",
78 target_os = "ios",
79 target_os = "linux",
80 target_os = "macos",
81 target_os = "netbsd",
82 target_os = "openbsd",
83 target_os = "tvos",
84 target_os = "visionos",
85 target_os = "watchos"
86 ))]
87 if let Some(tclass) = config.tclass_v6 {
88 socket.set_tclass_v6(tclass)?;
89 }
90 if let Some(only_v6) = config.only_v6 {
91 socket.set_only_v6(only_v6)?;
92 }
93
94 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
96 if let Some(iface) = &config.bind_device {
97 socket.bind_device(Some(iface.as_bytes()))?;
98 }
99
100 if let Some(addr) = config.bind_addr {
102 socket.bind(&addr.into())?;
103 }
104
105 Ok(Self { socket })
106 }
107
108 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
110 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
111 socket.set_nonblocking(true)?;
112 Ok(Self { socket })
113 }
114
115 pub fn v4_stream() -> io::Result<Self> {
117 Self::new(Domain::IPV4, SockType::STREAM)
118 }
119
120 pub fn v6_stream() -> io::Result<Self> {
122 Self::new(Domain::IPV6, SockType::STREAM)
123 }
124
125 pub fn raw_v4() -> io::Result<Self> {
127 Self::new(Domain::IPV4, SockType::RAW)
128 }
129
130 pub fn raw_v6() -> io::Result<Self> {
132 Self::new(Domain::IPV6, SockType::RAW)
133 }
134
135 pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
137 match self.socket.connect(&target.into()) {
139 Ok(_) => {
140 let std_stream: StdTcpStream = self.socket.into();
142 return TcpStream::from_std(std_stream);
143 }
144 Err(e)
145 if e.kind() == io::ErrorKind::WouldBlock
146 || e.raw_os_error() == Some(libc::EINPROGRESS) =>
147 {
148 let std_stream: StdTcpStream = self.socket.into();
150 let stream = TcpStream::from_std(std_stream)?;
151 stream.writable().await?;
152
153 if let Some(err) = stream.take_error()? {
155 return Err(err);
156 }
157
158 return Ok(stream);
159 }
160 Err(e) => {
161 return Err(e);
162 }
163 }
164 }
165
166 pub async fn connect_timeout(
168 self,
169 target: SocketAddr,
170 timeout: Duration,
171 ) -> io::Result<TcpStream> {
172 match tokio::time::timeout(timeout, self.connect(target)).await {
173 Ok(result) => result,
174 Err(_) => Err(io::Error::new(
175 io::ErrorKind::TimedOut,
176 "connection timed out",
177 )),
178 }
179 }
180
181 pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
183 self.socket.listen(backlog)?;
184
185 let std_listener: StdTcpListener = self.socket.into();
186 TcpListener::from_std(std_listener)
187 }
188
189 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
191 self.socket.send_to(buf, &target.into())
192 }
193
194 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
196 let buf_maybe = unsafe {
198 std::slice::from_raw_parts_mut(
199 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
200 buf.len(),
201 )
202 };
203
204 let (n, addr) = self.socket.recv_from(buf_maybe)?;
205 let addr = addr
206 .as_socket()
207 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
208
209 Ok((n, addr))
210 }
211
212 pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
214 self.socket.shutdown(how)
215 }
216
217 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
219 self.socket.set_reuse_address(on)
220 }
221
222 pub fn reuseaddr(&self) -> io::Result<bool> {
224 self.socket.reuse_address()
225 }
226
227 #[cfg(any(
229 target_os = "android",
230 target_os = "dragonfly",
231 target_os = "freebsd",
232 target_os = "fuchsia",
233 target_os = "ios",
234 target_os = "linux",
235 target_os = "macos",
236 target_os = "netbsd",
237 target_os = "openbsd",
238 target_os = "tvos",
239 target_os = "visionos",
240 target_os = "watchos"
241 ))]
242 pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
243 self.socket.set_reuse_port(on)
244 }
245
246 #[cfg(any(
248 target_os = "android",
249 target_os = "dragonfly",
250 target_os = "freebsd",
251 target_os = "fuchsia",
252 target_os = "ios",
253 target_os = "linux",
254 target_os = "macos",
255 target_os = "netbsd",
256 target_os = "openbsd",
257 target_os = "tvos",
258 target_os = "visionos",
259 target_os = "watchos"
260 ))]
261 pub fn reuseport(&self) -> io::Result<bool> {
262 self.socket.reuse_port()
263 }
264
265 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
267 self.socket.set_nodelay(on)
268 }
269
270 pub fn nodelay(&self) -> io::Result<bool> {
272 self.socket.nodelay()
273 }
274
275 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
277 self.socket.set_linger(dur)
278 }
279
280 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
282 self.socket.set_ttl(ttl)
283 }
284
285 pub fn ttl(&self) -> io::Result<u32> {
287 self.socket.ttl()
288 }
289
290 pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
292 self.socket.set_unicast_hops_v6(hops)
293 }
294
295 pub fn hoplimit(&self) -> io::Result<u32> {
297 self.socket.unicast_hops_v6()
298 }
299
300 pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
302 self.socket.set_keepalive(on)
303 }
304
305 pub fn keepalive(&self) -> io::Result<bool> {
307 self.socket.keepalive()
308 }
309
310 pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
312 self.socket.set_recv_buffer_size(size)
313 }
314
315 pub fn recv_buffer_size(&self) -> io::Result<usize> {
317 self.socket.recv_buffer_size()
318 }
319
320 pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
322 self.socket.set_send_buffer_size(size)
323 }
324
325 pub fn send_buffer_size(&self) -> io::Result<usize> {
327 self.socket.send_buffer_size()
328 }
329
330 pub fn set_tos(&self, tos: u32) -> io::Result<()> {
332 self.socket.set_tos(tos)
333 }
334
335 pub fn tos(&self) -> io::Result<u32> {
337 self.socket.tos()
338 }
339
340 #[cfg(any(
342 target_os = "android",
343 target_os = "dragonfly",
344 target_os = "freebsd",
345 target_os = "fuchsia",
346 target_os = "ios",
347 target_os = "linux",
348 target_os = "macos",
349 target_os = "netbsd",
350 target_os = "openbsd",
351 target_os = "tvos",
352 target_os = "visionos",
353 target_os = "watchos"
354 ))]
355 pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
356 self.socket.set_tclass_v6(tclass)
357 }
358
359 #[cfg(any(
361 target_os = "android",
362 target_os = "dragonfly",
363 target_os = "freebsd",
364 target_os = "fuchsia",
365 target_os = "ios",
366 target_os = "linux",
367 target_os = "macos",
368 target_os = "netbsd",
369 target_os = "openbsd",
370 target_os = "tvos",
371 target_os = "visionos",
372 target_os = "watchos"
373 ))]
374 pub fn tclass_v6(&self) -> io::Result<u32> {
375 self.socket.tclass_v6()
376 }
377
378 pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
380 self.socket.set_only_v6(only_v6)
381 }
382
383 pub fn only_v6(&self) -> io::Result<bool> {
385 self.socket.only_v6()
386 }
387
388 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
390 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
391 return self.socket.bind_device(Some(iface.as_bytes()));
392
393 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
394 {
395 let _ = iface;
396 Err(io::Error::new(
397 io::ErrorKind::Unsupported,
398 "bind_device is not supported on this platform",
399 ))
400 }
401 }
402
403 pub fn local_addr(&self) -> io::Result<SocketAddr> {
405 self.socket
406 .local_addr()?
407 .as_socket()
408 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
409 }
410
411 pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
413 let std_stream: StdTcpStream = self.socket.into();
414 TcpStream::from_std(std_stream)
415 }
416
417 pub fn from_socket(socket: Socket) -> Self {
419 Self { socket }
420 }
421
422 pub fn socket(&self) -> &Socket {
424 &self.socket
425 }
426
427 pub fn into_socket(self) -> Socket {
429 self.socket
430 }
431
432 #[cfg(unix)]
434 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
435 use std::os::fd::AsRawFd;
436 self.socket.as_raw_fd()
437 }
438
439 #[cfg(windows)]
441 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
442 use std::os::windows::io::AsRawSocket;
443 self.socket.as_raw_socket()
444 }
445}