arta_async_std/net/
tcp_stream.rs

1use crate::AsyncStdGlobalRuntime;
2use arta::net::RuntimeTcpStream;
3use cfg_if::cfg_if;
4use futures::{prelude::Future, AsyncRead, AsyncWrite, TryFutureExt};
5use socket2::SockRef;
6use std::{
7    net::SocketAddr,
8    pin::Pin,
9    task::{Context, Poll},
10    time::Duration,
11};
12
13cfg_if! {
14    if #[cfg(windows)] {
15        impl std::os::windows::io::AsRawSocket for AsyncStdTcpStream {
16            fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
17                self.inner.as_raw_socket()
18            }
19        }
20
21        impl std::os::windows::io::AsSocket for AsyncStdTcpStream {
22            fn as_socket(&self) -> std::os::windows::io::BorrowedSocket<'_> {
23                let raw_socket = std::os::windows::io::AsRawSocket::as_raw_socket(self);
24                unsafe { std::os::windows::io::BorrowedSocket::borrow_raw(raw_socket) }
25            }
26        }
27
28        impl From<std::os::windows::io::OwnedSocket> for AsyncStdTcpStream {
29            fn from(socket: std::os::windows::io::OwnedSocket) -> Self {
30                Self {
31                    inner: async_std::net::TcpStream::from(std::net::TcpStream::from(socket))
32                }
33            }
34        }
35    } else if #[cfg(any(unix, target_os = "wasi"))] {
36        impl std::os::fd::AsRawFd for AsyncStdTcpStream {
37            fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
38                self.inner.as_raw_fd()
39            }
40        }
41
42        impl std::os::fd::AsFd for AsyncStdTcpStream {
43            fn as_fd(&self) -> std::os::unix::prelude::BorrowedFd<'_> {
44                let raw_fd = std::os::fd::AsRawFd::as_raw_fd(self);
45                unsafe { std::os::fd::BorrowedFd::borrow_raw(raw_fd) }
46            }
47        }
48
49        impl From<std::os::fd::OwnedFd> for AsyncStdTcpStream {
50            fn from(fd: std::os::fd::OwnedFd) -> Self {
51                Self {
52                    inner: async_std::net::TcpStream::from(std::net::TcpStream::from(fd))
53                }
54            }
55        }
56    }
57}
58
59/// Async-std specific [`RuntimeTcpStream`] implementation.
60pub struct AsyncStdTcpStream {
61    pub(super) inner: async_std::net::TcpStream,
62}
63
64impl RuntimeTcpStream for AsyncStdTcpStream {
65    type Runtime = AsyncStdGlobalRuntime;
66
67    fn connect(
68        runtime: &Self::Runtime,
69        addr: impl arta::net::ToSocketAddrs<Self::Runtime>,
70    ) -> impl Future<Output = std::io::Result<Self>> + Send
71    where
72        Self: Sized,
73    {
74        addr.for_each_resolved_addr_until_success(runtime, |addr| {
75            async_std::net::TcpStream::connect(addr).map_ok(|stream| Self { inner: stream })
76        })
77    }
78
79    fn local_addr(&self) -> std::io::Result<SocketAddr> {
80        self.inner.local_addr()
81    }
82
83    fn peer_addr(&self) -> std::io::Result<SocketAddr> {
84        self.inner.peer_addr()
85    }
86
87    #[cfg(not(target_os = "wasi"))]
88    fn linger(&self) -> std::io::Result<Option<Duration>> {
89        SockRef::from(self).linger()
90    }
91
92    #[cfg(not(target_os = "wasi"))]
93    fn set_linger(&self, linger: Option<Duration>) -> std::io::Result<()> {
94        SockRef::from(self).set_linger(linger)
95    }
96
97    fn nodelay(&self) -> std::io::Result<bool> {
98        self.inner.nodelay()
99    }
100
101    fn set_nodelay(&self, is_enabled: bool) -> std::io::Result<()> {
102        self.inner.set_nodelay(is_enabled)
103    }
104
105    fn ttl(&self) -> std::io::Result<u32> {
106        self.inner.ttl()
107    }
108
109    fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
110        self.inner.set_ttl(ttl)
111    }
112
113    fn peek(&self, buf: &mut [u8]) -> impl Future<Output = std::io::Result<usize>> + Send {
114        self.inner.peek(buf)
115    }
116
117    fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
118        SockRef::from(self).take_error()
119    }
120}
121
122impl AsyncRead for AsyncStdTcpStream {
123    fn poll_read(
124        mut self: Pin<&mut Self>,
125        cx: &mut Context<'_>,
126        buf: &mut [u8],
127    ) -> Poll<std::io::Result<usize>> {
128        async_std::io::Read::poll_read(Pin::new(&mut self.inner), cx, buf)
129    }
130
131    fn poll_read_vectored(
132        mut self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134        bufs: &mut [std::io::IoSliceMut<'_>],
135    ) -> Poll<std::io::Result<usize>> {
136        async_std::io::Read::poll_read_vectored(Pin::new(&mut self.inner), cx, bufs)
137    }
138}
139
140impl AsyncWrite for AsyncStdTcpStream {
141    fn poll_write(
142        mut self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144        buf: &[u8],
145    ) -> Poll<std::io::Result<usize>> {
146        async_std::io::Write::poll_write(Pin::new(&mut self.inner), cx, buf)
147    }
148
149    fn poll_write_vectored(
150        mut self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        bufs: &[std::io::IoSlice<'_>],
153    ) -> Poll<std::io::Result<usize>> {
154        async_std::io::Write::poll_write_vectored(Pin::new(&mut self.inner), cx, bufs)
155    }
156
157    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
158        async_std::io::Write::poll_flush(Pin::new(&mut self.inner), cx)
159    }
160
161    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
162        async_std::io::Write::poll_close(Pin::new(&mut self.inner), cx)
163    }
164}