1use crate::tcp::TcpConfig;
2use socket2::{Domain, Protocol, Socket, Type as SockType};
3use std::io;
4use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
5use std::time::Duration;
6use tokio::net::{TcpListener, TcpStream};
7
8#[derive(Debug)]
10pub struct AsyncTcpSocket {
11 socket: Socket,
12}
13
14impl AsyncTcpSocket {
15 pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
17 config.validate()?;
18
19 let socket = Socket::new(
20 config.socket_family.to_domain(),
21 config.socket_type.to_sock_type(),
22 Some(Protocol::TCP),
23 )?;
24
25 socket.set_nonblocking(true)?;
26
27 if let Some(flag) = config.reuseaddr {
29 socket.set_reuse_address(flag)?;
30 }
31 #[cfg(any(
32 target_os = "android",
33 target_os = "dragonfly",
34 target_os = "freebsd",
35 target_os = "fuchsia",
36 target_os = "ios",
37 target_os = "linux",
38 target_os = "macos",
39 target_os = "netbsd",
40 target_os = "openbsd",
41 target_os = "tvos",
42 target_os = "visionos",
43 target_os = "watchos"
44 ))]
45 if let Some(flag) = config.reuseport {
46 socket.set_reuse_port(flag)?;
47 }
48 if let Some(flag) = config.nodelay {
49 socket.set_nodelay(flag)?;
50 }
51 if let Some(ttl) = config.ttl {
52 socket.set_ttl(ttl)?;
53 }
54 if let Some(hoplimit) = config.hoplimit {
55 socket.set_unicast_hops_v6(hoplimit)?;
56 }
57 if let Some(keepalive) = config.keepalive {
58 socket.set_keepalive(keepalive)?;
59 }
60 if let Some(timeout) = config.read_timeout {
61 socket.set_read_timeout(Some(timeout))?;
62 }
63 if let Some(timeout) = config.write_timeout {
64 socket.set_write_timeout(Some(timeout))?;
65 }
66 if let Some(size) = config.recv_buffer_size {
67 socket.set_recv_buffer_size(size)?;
68 }
69 if let Some(size) = config.send_buffer_size {
70 socket.set_send_buffer_size(size)?;
71 }
72 if let Some(tos) = config.tos {
73 socket.set_tos(tos)?;
74 }
75 #[cfg(any(
76 target_os = "android",
77 target_os = "dragonfly",
78 target_os = "freebsd",
79 target_os = "fuchsia",
80 target_os = "ios",
81 target_os = "linux",
82 target_os = "macos",
83 target_os = "netbsd",
84 target_os = "openbsd",
85 target_os = "tvos",
86 target_os = "visionos",
87 target_os = "watchos"
88 ))]
89 if let Some(tclass) = config.tclass_v6 {
90 socket.set_tclass_v6(tclass)?;
91 }
92 if let Some(only_v6) = config.only_v6 {
93 socket.set_only_v6(only_v6)?;
94 }
95
96 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
98 if let Some(iface) = &config.bind_device {
99 socket.bind_device(Some(iface.as_bytes()))?;
100 }
101
102 if let Some(addr) = config.bind_addr {
104 socket.bind(&addr.into())?;
105 }
106
107 Ok(Self { socket })
108 }
109
110 pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
112 let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
113 socket.set_nonblocking(true)?;
114 Ok(Self { socket })
115 }
116
117 pub fn v4_stream() -> io::Result<Self> {
119 Self::new(Domain::IPV4, SockType::STREAM)
120 }
121
122 pub fn v6_stream() -> io::Result<Self> {
124 Self::new(Domain::IPV6, SockType::STREAM)
125 }
126
127 pub fn raw_v4() -> io::Result<Self> {
129 Self::new(Domain::IPV4, SockType::RAW)
130 }
131
132 pub fn raw_v6() -> io::Result<Self> {
134 Self::new(Domain::IPV6, SockType::RAW)
135 }
136
137 pub async fn connect(self, target: SocketAddr) -> io::Result<TcpStream> {
139 match self.socket.connect(&target.into()) {
141 Ok(_) => {
142 let std_stream: StdTcpStream = self.socket.into();
144 return TcpStream::from_std(std_stream);
145 }
146 Err(e)
147 if e.kind() == io::ErrorKind::WouldBlock
148 || e.raw_os_error() == Some(libc::EINPROGRESS) =>
149 {
150 let std_stream: StdTcpStream = self.socket.into();
152 let stream = TcpStream::from_std(std_stream)?;
153 stream.writable().await?;
154
155 if let Some(err) = stream.take_error()? {
157 return Err(err);
158 }
159
160 return Ok(stream);
161 }
162 Err(e) => {
163 return Err(e);
164 }
165 }
166 }
167
168 pub async fn connect_timeout(
170 self,
171 target: SocketAddr,
172 timeout: Duration,
173 ) -> io::Result<TcpStream> {
174 match tokio::time::timeout(timeout, self.connect(target)).await {
175 Ok(result) => result,
176 Err(_) => Err(io::Error::new(
177 io::ErrorKind::TimedOut,
178 "connection timed out",
179 )),
180 }
181 }
182
183 pub fn listen(self, backlog: i32) -> io::Result<TcpListener> {
185 self.socket.listen(backlog)?;
186
187 let std_listener: StdTcpListener = self.socket.into();
188 TcpListener::from_std(std_listener)
189 }
190
191 pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
193 self.socket.send_to(buf, &target.into())
194 }
195
196 pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
198 let buf_maybe = unsafe {
200 std::slice::from_raw_parts_mut(
201 buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
202 buf.len(),
203 )
204 };
205
206 let (n, addr) = self.socket.recv_from(buf_maybe)?;
207 let addr = addr
208 .as_socket()
209 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
210
211 Ok((n, addr))
212 }
213
214 pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
216 self.socket.shutdown(how)
217 }
218
219 pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
221 self.socket.set_reuse_address(on)
222 }
223
224 pub fn reuseaddr(&self) -> io::Result<bool> {
226 self.socket.reuse_address()
227 }
228
229 #[cfg(any(
231 target_os = "android",
232 target_os = "dragonfly",
233 target_os = "freebsd",
234 target_os = "fuchsia",
235 target_os = "ios",
236 target_os = "linux",
237 target_os = "macos",
238 target_os = "netbsd",
239 target_os = "openbsd",
240 target_os = "tvos",
241 target_os = "visionos",
242 target_os = "watchos"
243 ))]
244 pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
245 self.socket.set_reuse_port(on)
246 }
247
248 #[cfg(any(
250 target_os = "android",
251 target_os = "dragonfly",
252 target_os = "freebsd",
253 target_os = "fuchsia",
254 target_os = "ios",
255 target_os = "linux",
256 target_os = "macos",
257 target_os = "netbsd",
258 target_os = "openbsd",
259 target_os = "tvos",
260 target_os = "visionos",
261 target_os = "watchos"
262 ))]
263 pub fn reuseport(&self) -> io::Result<bool> {
264 self.socket.reuse_port()
265 }
266
267 pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
269 self.socket.set_nodelay(on)
270 }
271
272 pub fn nodelay(&self) -> io::Result<bool> {
274 self.socket.nodelay()
275 }
276
277 pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
279 self.socket.set_linger(dur)
280 }
281
282 pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
284 self.socket.set_ttl(ttl)
285 }
286
287 pub fn ttl(&self) -> io::Result<u32> {
289 self.socket.ttl()
290 }
291
292 pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
294 self.socket.set_unicast_hops_v6(hops)
295 }
296
297 pub fn hoplimit(&self) -> io::Result<u32> {
299 self.socket.unicast_hops_v6()
300 }
301
302 pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
304 self.socket.set_keepalive(on)
305 }
306
307 pub fn keepalive(&self) -> io::Result<bool> {
309 self.socket.keepalive()
310 }
311
312 pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
314 self.socket.set_recv_buffer_size(size)
315 }
316
317 pub fn recv_buffer_size(&self) -> io::Result<usize> {
319 self.socket.recv_buffer_size()
320 }
321
322 pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
324 self.socket.set_send_buffer_size(size)
325 }
326
327 pub fn send_buffer_size(&self) -> io::Result<usize> {
329 self.socket.send_buffer_size()
330 }
331
332 pub fn set_tos(&self, tos: u32) -> io::Result<()> {
334 self.socket.set_tos(tos)
335 }
336
337 pub fn tos(&self) -> io::Result<u32> {
339 self.socket.tos()
340 }
341
342 #[cfg(any(
344 target_os = "android",
345 target_os = "dragonfly",
346 target_os = "freebsd",
347 target_os = "fuchsia",
348 target_os = "ios",
349 target_os = "linux",
350 target_os = "macos",
351 target_os = "netbsd",
352 target_os = "openbsd",
353 target_os = "tvos",
354 target_os = "visionos",
355 target_os = "watchos"
356 ))]
357 pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
358 self.socket.set_tclass_v6(tclass)
359 }
360
361 #[cfg(any(
363 target_os = "android",
364 target_os = "dragonfly",
365 target_os = "freebsd",
366 target_os = "fuchsia",
367 target_os = "ios",
368 target_os = "linux",
369 target_os = "macos",
370 target_os = "netbsd",
371 target_os = "openbsd",
372 target_os = "tvos",
373 target_os = "visionos",
374 target_os = "watchos"
375 ))]
376 pub fn tclass_v6(&self) -> io::Result<u32> {
377 self.socket.tclass_v6()
378 }
379
380 pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
382 self.socket.set_only_v6(only_v6)
383 }
384
385 pub fn only_v6(&self) -> io::Result<bool> {
387 self.socket.only_v6()
388 }
389
390 pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
392 #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
393 return self.socket.bind_device(Some(iface.as_bytes()));
394
395 #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
396 {
397 let _ = iface;
398 Err(io::Error::new(
399 io::ErrorKind::Unsupported,
400 "bind_device is not supported on this platform",
401 ))
402 }
403 }
404
405 pub fn local_addr(&self) -> io::Result<SocketAddr> {
407 self.socket
408 .local_addr()?
409 .as_socket()
410 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
411 }
412
413 pub fn into_tokio_stream(self) -> io::Result<TcpStream> {
415 let std_stream: StdTcpStream = self.socket.into();
416 TcpStream::from_std(std_stream)
417 }
418
419 pub fn from_socket(socket: Socket) -> Self {
421 Self { socket }
422 }
423
424 pub fn socket(&self) -> &Socket {
426 &self.socket
427 }
428
429 pub fn into_socket(self) -> Socket {
431 self.socket
432 }
433
434 #[cfg(unix)]
436 pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
437 use std::os::fd::AsRawFd;
438 self.socket.as_raw_fd()
439 }
440
441 #[cfg(windows)]
443 pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
444 use std::os::windows::io::AsRawSocket;
445 self.socket.as_raw_socket()
446 }
447}