1use crate::socket::to_socket_protocol;
2use crate::socket::{IpVersion, SocketOption};
3use async_io::{Async, Timer};
4use futures_lite::future::FutureExt;
5use socket2::{SockAddr, Socket as SystemSocket};
6use std::io::{self, Read, Write};
7use std::mem::MaybeUninit;
8use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream, UdpSocket};
9use std::sync::Arc;
10use std::time::Duration;
11
12#[derive(Clone, Debug)]
14pub struct AsyncSocket {
15 inner: Arc<Async<SystemSocket>>,
16}
17
18impl AsyncSocket {
19 pub fn new(socket_option: SocketOption) -> io::Result<AsyncSocket> {
21 let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
22 SystemSocket::new(
23 socket_option.ip_version.to_domain(),
24 socket_option.socket_type.to_type(),
25 Some(to_socket_protocol(protocol)),
26 )?
27 } else {
28 SystemSocket::new(
29 socket_option.ip_version.to_domain(),
30 socket_option.socket_type.to_type(),
31 None,
32 )?
33 };
34 socket.set_nonblocking(true)?;
35 Ok(AsyncSocket {
36 inner: Arc::new(Async::new(socket)?),
37 })
38 }
39 pub async fn new_with_async_connect(addr: &SocketAddr) -> io::Result<AsyncSocket> {
41 let stream = Async::<TcpStream>::connect(*addr).await?;
42 let socket = SystemSocket::from(stream.into_inner()?);
45 socket.set_nonblocking(true)?;
46 Ok(AsyncSocket {
47 inner: Arc::new(Async::new(socket)?),
48 })
49 }
50 pub async fn new_with_async_connect_timeout(
52 addr: &SocketAddr,
53 timeout: Duration,
54 ) -> io::Result<AsyncSocket> {
55 let stream = Async::<TcpStream>::connect(*addr)
56 .or(async {
57 Timer::after(timeout).await;
58 Err(io::ErrorKind::TimedOut.into())
59 })
60 .await?;
61 let socket = SystemSocket::from(stream.into_inner()?);
64 socket.set_nonblocking(true)?;
65 Ok(AsyncSocket {
66 inner: Arc::new(Async::new(socket)?),
67 })
68 }
69 pub fn new_with_connect(
72 socket_option: SocketOption,
73 addr: &SocketAddr,
74 ) -> io::Result<AsyncSocket> {
75 let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
76 SystemSocket::new(
77 socket_option.ip_version.to_domain(),
78 socket_option.socket_type.to_type(),
79 Some(to_socket_protocol(protocol)),
80 )?
81 } else {
82 SystemSocket::new(
83 socket_option.ip_version.to_domain(),
84 socket_option.socket_type.to_type(),
85 None,
86 )?
87 };
88 let addr: SockAddr = SockAddr::from(*addr);
89 socket.connect(&addr)?;
90 socket.set_nonblocking(true)?;
91 Ok(AsyncSocket {
92 inner: Arc::new(Async::new(socket)?),
93 })
94 }
95 pub fn new_with_connect_timeout(
98 socket_option: SocketOption,
99 addr: &SocketAddr,
100 timeout: Duration,
101 ) -> io::Result<AsyncSocket> {
102 let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
103 SystemSocket::new(
104 socket_option.ip_version.to_domain(),
105 socket_option.socket_type.to_type(),
106 Some(to_socket_protocol(protocol)),
107 )?
108 } else {
109 SystemSocket::new(
110 socket_option.ip_version.to_domain(),
111 socket_option.socket_type.to_type(),
112 None,
113 )?
114 };
115 let addr: SockAddr = SockAddr::from(*addr);
116 socket.connect_timeout(&addr, timeout)?;
117 socket.set_nonblocking(true)?;
118 Ok(AsyncSocket {
119 inner: Arc::new(Async::new(socket)?),
120 })
121 }
122 pub fn new_with_listener(
124 socket_option: SocketOption,
125 addr: &SocketAddr,
126 ) -> io::Result<AsyncSocket> {
127 let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
128 SystemSocket::new(
129 socket_option.ip_version.to_domain(),
130 socket_option.socket_type.to_type(),
131 Some(to_socket_protocol(protocol)),
132 )?
133 } else {
134 SystemSocket::new(
135 socket_option.ip_version.to_domain(),
136 socket_option.socket_type.to_type(),
137 None,
138 )?
139 };
140 socket.set_nonblocking(true)?;
141 let addr: SockAddr = SockAddr::from(*addr);
142 socket.bind(&addr)?;
143 socket.listen(1024)?;
144 Ok(AsyncSocket {
145 inner: Arc::new(Async::new(socket)?),
146 })
147 }
148 pub fn new_with_bind(
150 socket_option: SocketOption,
151 addr: &SocketAddr,
152 ) -> io::Result<AsyncSocket> {
153 let socket: SystemSocket = if let Some(protocol) = socket_option.protocol {
154 SystemSocket::new(
155 socket_option.ip_version.to_domain(),
156 socket_option.socket_type.to_type(),
157 Some(to_socket_protocol(protocol)),
158 )?
159 } else {
160 SystemSocket::new(
161 socket_option.ip_version.to_domain(),
162 socket_option.socket_type.to_type(),
163 None,
164 )?
165 };
166 socket.set_nonblocking(true)?;
167 let addr: SockAddr = SockAddr::from(*addr);
168 socket.bind(&addr)?;
169 Ok(AsyncSocket {
170 inner: Arc::new(Async::new(socket)?),
171 })
172 }
173 pub fn from_tcp_stream(tcp_stream: TcpStream) -> io::Result<AsyncSocket> {
176 let socket = SystemSocket::from(tcp_stream);
177 socket.set_nonblocking(true)?;
178 Ok(AsyncSocket {
179 inner: Arc::new(Async::new(socket)?),
180 })
181 }
182 pub fn from_tcp_listener(tcp_listener: TcpListener) -> io::Result<AsyncSocket> {
184 let socket = SystemSocket::from(tcp_listener);
185 socket.set_nonblocking(true)?;
186 Ok(AsyncSocket {
187 inner: Arc::new(Async::new(socket)?),
188 })
189 }
190 pub fn from_udp_socket(udp_socket: UdpSocket) -> io::Result<AsyncSocket> {
192 let socket = SystemSocket::from(udp_socket);
193 socket.set_nonblocking(true)?;
194 Ok(AsyncSocket {
195 inner: Arc::new(Async::new(socket)?),
196 })
197 }
198 pub async fn bind(&self, addr: SocketAddr) -> io::Result<()> {
200 let addr: SockAddr = SockAddr::from(addr);
201 self.inner.write_with(|inner| inner.bind(&addr)).await
203 }
204 pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
206 loop {
207 self.inner.writable().await?;
208 match self.inner.write_with(|inner| inner.send(buf)).await {
209 Ok(n) => return Ok(n),
210 Err(_) => continue,
211 }
212 }
213 }
214 pub async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
216 let target: SockAddr = SockAddr::from(target);
217 loop {
218 self.inner.writable().await?;
219 match self
220 .inner
221 .write_with(|inner| inner.send_to(buf, &target))
222 .await
223 {
224 Ok(n) => return Ok(n),
225 Err(_) => continue,
226 }
227 }
228 }
229 pub async fn receive(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
231 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
232 loop {
233 self.inner.readable().await?;
234 match self.inner.read_with(|inner| inner.recv(recv_buf)).await {
235 Ok(result) => return Ok(result),
236 Err(_) => continue,
237 }
238 }
239 }
240 pub async fn receive_from(&self, buf: &mut Vec<u8>) -> io::Result<(usize, SocketAddr)> {
242 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
243 loop {
244 self.inner.readable().await?;
245 match self
246 .inner
247 .read_with(|inner| inner.recv_from(recv_buf))
248 .await
249 {
250 Ok(result) => {
251 let (n, addr) = result;
252 match addr.as_socket() {
253 Some(addr) => return Ok((n, addr)),
254 None => continue,
255 }
256 }
257 Err(_) => continue,
258 }
259 }
260 }
261 pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
264 loop {
265 self.inner.writable().await?;
266 match self.inner.write_with(|inner| inner.send(buf)).await {
267 Ok(n) => return Ok(n),
268 Err(_) => continue,
269 }
270 }
271 }
272 pub async fn write_timeout(&self, buf: &[u8], timeout: Duration) -> io::Result<usize> {
275 loop {
276 self.inner.writable().await?;
277 match self
278 .inner
279 .write_with(|inner| {
280 match inner.set_write_timeout(Some(timeout)) {
281 Ok(_) => {}
282 Err(e) => return Err(e),
283 }
284 inner.send(buf)
285 })
286 .await
287 {
288 Ok(n) => return Ok(n),
289 Err(_) => continue,
290 }
291 }
292 }
293 pub async fn read(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
296 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
297 loop {
298 self.inner.readable().await?;
299 match self.inner.read_with(|inner| inner.recv(recv_buf)).await {
300 Ok(result) => return Ok(result),
301 Err(_) => continue,
302 }
303 }
304 }
305 pub async fn read_timeout(&self, buf: &mut Vec<u8>, timeout: Duration) -> io::Result<usize> {
308 let recv_buf = unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [MaybeUninit<u8>]) };
309 loop {
310 self.inner.readable().await?;
311 match self
312 .inner
313 .read_with(|inner| {
314 match inner.set_read_timeout(Some(timeout)) {
315 Ok(_) => {}
316 Err(e) => return Err(e),
317 }
318 inner.recv(recv_buf)
319 })
320 .await
321 {
322 Ok(result) => return Ok(result),
323 Err(_) => continue,
324 }
325 }
326 }
327 pub async fn ttl(&self, ip_version: IpVersion) -> io::Result<u32> {
329 match ip_version {
330 IpVersion::V4 => self.inner.read_with(|inner| inner.ttl()).await,
331 IpVersion::V6 => self.inner.read_with(|inner| inner.unicast_hops_v6()).await,
332 }
333 }
334 pub async fn set_ttl(&self, ttl: u32, ip_version: IpVersion) -> io::Result<()> {
336 match ip_version {
337 IpVersion::V4 => self.inner.write_with(|inner| inner.set_ttl(ttl)).await,
338 IpVersion::V6 => {
339 self.inner
340 .write_with(|inner| inner.set_unicast_hops_v6(ttl))
341 .await
342 }
343 }
344 }
345 pub async fn tos(&self) -> io::Result<u32> {
347 self.inner.read_with(|inner| inner.tos()).await
348 }
349 pub async fn set_tos(&self, tos: u32) -> io::Result<()> {
351 self.inner.write_with(|inner| inner.set_tos(tos)).await
352 }
353 pub async fn receive_tos(&self) -> io::Result<bool> {
355 self.inner.read_with(|inner| inner.recv_tos()).await
356 }
357 pub async fn set_receive_tos(&self, receive_tos: bool) -> io::Result<()> {
359 self.inner
360 .write_with(|inner| inner.set_recv_tos(receive_tos))
361 .await
362 }
363 pub async fn connect(&mut self, addr: &SocketAddr) -> io::Result<()> {
365 let addr: SockAddr = SockAddr::from(*addr);
366 self.inner.write_with(|inner| inner.connect(&addr)).await
367 }
368 pub async fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
371 let addr: SockAddr = SockAddr::from(*addr);
372 self.inner
373 .write_with(|inner| inner.connect_timeout(&addr, timeout))
374 .await
375 }
376 pub async fn listen(&self, backlog: i32) -> io::Result<()> {
378 self.inner.write_with(|inner| inner.listen(backlog)).await
379 }
380 pub async fn accept(&self) -> io::Result<(AsyncSocket, SocketAddr)> {
382 match self.inner.read_with(|inner| inner.accept()).await {
383 Ok((socket, addr)) => {
384 let socket = AsyncSocket {
385 inner: Arc::new(Async::new(socket)?),
386 };
387 Ok((socket, addr.as_socket().unwrap()))
388 }
389 Err(e) => Err(e),
390 }
391 }
392 pub async fn local_addr(&self) -> io::Result<SocketAddr> {
394 match self.inner.read_with(|inner| inner.local_addr()).await {
395 Ok(addr) => Ok(addr.as_socket().unwrap()),
396 Err(e) => Err(e),
397 }
398 }
399 pub async fn peer_addr(&self) -> io::Result<SocketAddr> {
401 match self.inner.read_with(|inner| inner.peer_addr()).await {
402 Ok(addr) => Ok(addr.as_socket().unwrap()),
403 Err(e) => Err(e),
404 }
405 }
406 pub async fn socket_type(&self) -> io::Result<crate::socket::SocketType> {
408 match self.inner.read_with(|inner| inner.r#type()).await {
409 Ok(socktype) => Ok(crate::socket::SocketType::from_type(socktype)),
410 Err(e) => Err(e),
411 }
412 }
413 pub async fn try_clone(&self) -> io::Result<AsyncSocket> {
415 match self.inner.read_with(|inner| inner.try_clone()).await {
416 Ok(socket) => Ok(AsyncSocket {
417 inner: Arc::new(Async::new(socket)?),
418 }),
419 Err(e) => Err(e),
420 }
421 }
422
423 #[cfg(not(target_os = "windows"))]
425 pub async fn is_nonblocking(&self) -> io::Result<bool> {
426 self.inner.read_with(|inner| inner.nonblocking()).await
427 }
428 pub async fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
430 self.inner
431 .write_with(|inner| inner.set_nonblocking(nonblocking))
432 .await
433 }
434 pub async fn shutdown(&self, how: Shutdown) -> io::Result<()> {
436 self.inner.write_with(|inner| inner.shutdown(how)).await
437 }
438 pub async fn is_broadcast(&self) -> io::Result<bool> {
440 self.inner.read_with(|inner| inner.broadcast()).await
441 }
442 pub async fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
446 self.inner
447 .write_with(|inner| inner.set_broadcast(broadcast))
448 .await
449 }
450 pub async fn get_error(&self) -> io::Result<Option<io::Error>> {
452 self.inner.read_with(|inner| inner.take_error()).await
453 }
454 pub async fn is_keepalive(&self) -> io::Result<bool> {
456 self.inner.read_with(|inner| inner.keepalive()).await
457 }
458 pub async fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
462 self.inner
463 .write_with(|inner| inner.set_keepalive(keepalive))
464 .await
465 }
466 pub async fn linger(&self) -> io::Result<Option<Duration>> {
468 self.inner.read_with(|inner| inner.linger()).await
469 }
470 pub async fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
472 self.inner.write_with(|inner| inner.set_linger(dur)).await
473 }
474 pub async fn receive_buffer_size(&self) -> io::Result<usize> {
476 self.inner.read_with(|inner| inner.recv_buffer_size()).await
477 }
478 pub async fn set_receive_buffer_size(&self, size: usize) -> io::Result<()> {
482 self.inner
483 .write_with(|inner| inner.set_recv_buffer_size(size))
484 .await
485 }
486 pub async fn receive_timeout(&self) -> io::Result<Option<Duration>> {
488 self.inner.read_with(|inner| inner.read_timeout()).await
489 }
490 pub async fn set_receive_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
492 self.inner
493 .write_with(|inner| inner.set_read_timeout(duration))
494 .await
495 }
496 pub async fn reuse_address(&self) -> io::Result<bool> {
498 self.inner.read_with(|inner| inner.reuse_address()).await
499 }
500 pub async fn set_reuse_address(&self, reuse: bool) -> io::Result<()> {
504 self.inner
505 .write_with(|inner| inner.set_reuse_address(reuse))
506 .await
507 }
508 pub async fn send_buffer_size(&self) -> io::Result<usize> {
510 self.inner.read_with(|inner| inner.send_buffer_size()).await
511 }
512 pub async fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
516 self.inner
517 .write_with(|inner| inner.set_send_buffer_size(size))
518 .await
519 }
520 pub async fn send_timeout(&self) -> io::Result<Option<Duration>> {
522 self.inner.read_with(|inner| inner.write_timeout()).await
523 }
524 pub async fn set_send_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
528 self.inner
529 .write_with(|inner| inner.set_write_timeout(duration))
530 .await
531 }
532 pub async fn is_ip_header_included(&self) -> io::Result<bool> {
534 self.inner.read_with(|inner| inner.header_included()).await
535 }
536 pub async fn set_ip_header_included(&self, include: bool) -> io::Result<()> {
538 self.inner
539 .write_with(|inner| inner.set_header_included(include))
540 .await
541 }
542 pub async fn is_nodelay(&self) -> io::Result<bool> {
544 self.inner.read_with(|inner| inner.nodelay()).await
545 }
546 pub async fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
550 self.inner
551 .write_with(|inner| inner.set_nodelay(nodelay))
552 .await
553 }
554 pub fn into_tcp_stream(&self) -> io::Result<TcpStream> {
557 let socket = Arc::try_unwrap(self.inner.clone())
558 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Failed to unwrap Arc"))?
559 .into_inner()?;
560 let tcp_stream = TcpStream::from(socket);
561 Ok(tcp_stream)
562 }
563 pub fn into_tcp_listener(&self) -> io::Result<TcpListener> {
566 let socket = Arc::try_unwrap(self.inner.clone())
567 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Failed to unwrap Arc"))?
568 .into_inner()?;
569 let tcp_listener = TcpListener::from(socket);
570 Ok(tcp_listener)
571 }
572 pub fn into_udp_socket(&self) -> io::Result<UdpSocket> {
575 let socket = Arc::try_unwrap(self.inner.clone())
576 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Failed to unwrap Arc"))?
577 .into_inner()?;
578 let udp_socket = UdpSocket::from(socket);
579 Ok(udp_socket)
580 }
581}
582
583#[derive(Clone, Debug)]
585pub struct AsyncTcpStream {
586 inner: Arc<Async<TcpStream>>,
587}
588
589impl AsyncTcpStream {
590 pub async fn connect(addr: SocketAddr) -> io::Result<Self> {
592 let stream = Async::<TcpStream>::connect(addr).await?;
593 Ok(AsyncTcpStream {
594 inner: Arc::new(stream),
595 })
596 }
597
598 pub async fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<Self> {
600 let stream = Async::<TcpStream>::connect(*addr)
601 .or(async {
602 Timer::after(timeout).await;
603 Err(std::io::ErrorKind::TimedOut.into())
604 })
605 .await?;
606 Ok(AsyncTcpStream {
607 inner: Arc::new(stream),
608 })
609 }
610
611 pub async fn local_addr(&self) -> io::Result<SocketAddr> {
613 self.inner.read_with(|inner| inner.local_addr()).await
614 }
615
616 pub async fn peer_addr(&self) -> io::Result<SocketAddr> {
618 self.inner.read_with(|inner| inner.peer_addr()).await
619 }
620
621 pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
623 self.inner.write_with(|mut inner| inner.write(buf)).await
624 }
625
626 pub async fn write_all(&self, buf: &[u8]) -> io::Result<()> {
628 self.inner
629 .write_with(|mut inner| inner.write_all(buf))
630 .await
631 }
632
633 pub async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
635 self.inner.read_with(|mut inner| inner.read(buf)).await
636 }
637
638 pub async fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
640 self.inner
641 .read_with(|mut inner| inner.read_to_end(buf))
642 .await
643 }
644
645 pub async fn read_to_end_timeout(
649 &self,
650 buf: &mut Vec<u8>,
651 timeout: Duration,
652 ) -> io::Result<usize> {
653 let mut io_error: io::Error = io::Error::new(io::ErrorKind::Other, "No response");
654 match self
655 .read_to_end(buf)
656 .or(async {
657 Timer::after(timeout).await;
658 Err(std::io::ErrorKind::TimedOut.into())
659 })
660 .await
661 {
662 Ok(_) => {}
663 Err(e) => {
664 io_error = e;
665 }
666 }
667 if buf.is_empty() {
668 Err(io_error)
669 } else {
670 Ok(buf.len())
671 }
672 }
673
674 pub async fn shutdown(&self, how: Shutdown) -> io::Result<()> {
676 self.inner.write_with(|inner| inner.shutdown(how)).await
677 }
678
679 pub async fn take_error(&self) -> io::Result<Option<io::Error>> {
681 self.inner.read_with(|inner| inner.take_error()).await
682 }
683 pub async fn try_clone(&self) -> io::Result<Self> {
685 let stream = self.inner.read_with(|inner| inner.try_clone()).await?;
686 Ok(AsyncTcpStream {
687 inner: Arc::new(Async::new(stream)?),
688 })
689 }
690
691 pub async fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
693 self.inner
694 .write_with(|inner| inner.set_read_timeout(dur))
695 .await
696 }
697
698 pub async fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
700 self.inner
701 .write_with(|inner| inner.set_write_timeout(dur))
702 .await
703 }
704
705 pub async fn read_timeout(&self) -> io::Result<Option<Duration>> {
707 self.inner.read_with(|inner| inner.read_timeout()).await
708 }
709
710 pub async fn write_timeout(&self) -> io::Result<Option<Duration>> {
712 self.inner.read_with(|inner| inner.write_timeout()).await
713 }
714
715 pub async fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
717 self.inner
718 .write_with(|inner| inner.set_nodelay(nodelay))
719 .await
720 }
721
722 pub async fn nodelay(&self) -> io::Result<bool> {
724 self.inner.read_with(|inner| inner.nodelay()).await
725 }
726
727 pub async fn set_ttl(&self, ttl: u32) -> io::Result<()> {
729 self.inner.write_with(|inner| inner.set_ttl(ttl)).await
730 }
731
732 pub async fn ttl(&self) -> io::Result<u32> {
734 self.inner.read_with(|inner| inner.ttl()).await
735 }
736
737 pub async fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
739 self.inner
740 .write_with(|inner| inner.set_nonblocking(nonblocking))
741 .await
742 }
743}