tcp_stream/
futures.rs

1use crate::TLSConfig;
2
3use cfg_if::cfg_if;
4use futures_io::{AsyncRead, AsyncWrite};
5use reactor_trait::{AsyncIOHandle, AsyncToSocketAddrs, TcpReactor};
6use std::{
7    io::{self, IoSlice, IoSliceMut},
8    pin::{Pin, pin},
9    task::{Context, Poll},
10};
11
12#[cfg(feature = "native-tls-futures")]
13use crate::NativeTlsConnectorBuilder;
14#[cfg(feature = "openssl-futures")]
15use crate::OpenSslConnector;
16#[cfg(feature = "rustls-futures")]
17use crate::{RustlsConnector, RustlsConnectorConfig};
18
/// Type-erased async I/O handle: a pinned, heap-allocated `dyn AsyncIOHandle` that is `Send`.
type AsyncStream = Pin<Box<dyn AsyncIOHandle + Send>>;
20
/// Wrapper around plain or TLS async TCP streams
///
/// Both variants hold the same boxed stream type; the variant only records whether a TLS
/// layer has been put on top of the connection.
pub enum AsyncTcpStream {
    /// Wrapper around plain async TCP stream, as returned by [`AsyncTcpStream::connect`]
    Plain(AsyncStream),
    /// Wrapper around a TLS async TCP stream, as produced by the feature-gated
    /// `into_native_tls`, `into_openssl` and `into_rustls` upgrades
    TLS(AsyncStream),
}
28
29impl AsyncTcpStream {
30    /// Wrapper around `reactor_trait::TcpReactor::connect`
31    pub async fn connect<R: TcpReactor, A: AsyncToSocketAddrs>(
32        _reactor: R,
33        addr: A,
34    ) -> io::Result<Self> {
35        let addrs = addr.to_socket_addrs().await?;
36        let mut err = None;
37        for addr in addrs {
38            match R::connect(addr).await {
39                Ok(stream) => return Ok(Self::Plain(stream.into())),
40                Err(e) => err = Some(e),
41            }
42        }
43        Err(err.unwrap_or_else(|| {
44            io::Error::new(io::ErrorKind::AddrNotAvailable, "couldn't resolve host")
45        }))
46    }
47
48    /// Enable TLS
49    pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
50        into_tls_impl(self, domain, config).await
51    }
52
53    #[cfg(feature = "native-tls-futures")]
54    /// Enable TLS using native-tls
55    pub async fn into_native_tls(
56        self,
57        connector: NativeTlsConnectorBuilder,
58        domain: &str,
59    ) -> io::Result<Self> {
60        Ok(Self::TLS(Box::pin(
61            async_native_tls::TlsConnector::from(connector)
62                .connect(domain, self.into_plain()?)
63                .await
64                .map_err(io::Error::other)?,
65        )))
66    }
67
68    #[cfg(feature = "openssl-futures")]
69    /// Enable TLS using openssl
70    pub async fn into_openssl(
71        self,
72        connector: &OpenSslConnector,
73        domain: &str,
74    ) -> io::Result<Self> {
75        let mut stream = async_openssl::SslStream::new(
76            connector.configure()?.into_ssl(domain)?,
77            self.into_plain()?,
78        )?;
79        Pin::new(&mut stream)
80            .connect()
81            .await
82            .map_err(io::Error::other)?;
83        Ok(Self::TLS(Box::pin(stream)))
84    }
85
86    #[cfg(feature = "rustls-futures")]
87    /// Enable TLS using rustls
88    pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
89        Ok(Self::TLS(Box::pin(
90            connector.connect_async(domain, self.into_plain()?).await?,
91        )))
92    }
93
94    #[allow(irrefutable_let_patterns, dead_code)]
95    fn into_plain(self) -> io::Result<AsyncStream> {
96        if let AsyncTcpStream::Plain(plain) = self {
97            Ok(plain)
98        } else {
99            Err(io::Error::new(
100                io::ErrorKind::AlreadyExists,
101                "already a TLS stream",
102            ))
103        }
104    }
105}
106
// Compile-time selection of the backend behind `AsyncTcpStream::into_tls`.
// Exactly one `into_tls_impl` is compiled: the first branch whose features are enabled wins,
// so rustls (native certs, then webpki roots, then its default config) takes precedence over
// openssl, which takes precedence over native-tls.
cfg_if! {
    if #[cfg(all(feature = "rustls-futures", feature = "rustls-native-certs"))] {
        // rustls with a connector config built by `new_with_native_certs` (fallible, hence `?`).
        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config).await
        }
    } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-webpki-roots-certs"))] {
        // rustls with a connector config built by `new_with_webpki_roots_certs`.
        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config).await
        }
    } else if #[cfg(feature = "rustls-futures")] {
        // rustls with the default connector config.
        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::default(), domain, config).await
        }
    } else if #[cfg(feature = "openssl-futures")] {
        // openssl backend.
        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
            crate::into_openssl_impl_async(s, domain, config).await
        }
    } else if #[cfg(feature = "native-tls-futures")] {
        // native-tls backend.
        async fn into_tls_impl(s: AsyncTcpStream, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
            crate::into_native_tls_impl_async(s, domain, config).await
        }
    } else {
        // No async TLS feature enabled: the domain and config are ignored and a plain stream
        // is handed back as `Plain`; an already-TLS stream yields the `into_plain` error.
        // NOTE(review): no TLS is negotiated here even though the caller asked for it —
        // confirm that this silent fallback is intended.
        async fn into_tls_impl(s: AsyncTcpStream, _domain: &str, _: TLSConfig<'_, '_, '_>) -> io::Result<AsyncTcpStream> {
            Ok(AsyncTcpStream::Plain(s.into_plain()?))
        }
    }
}
134
135impl AsyncRead for AsyncTcpStream {
136    fn poll_read(
137        self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139        buf: &mut [u8],
140    ) -> Poll<io::Result<usize>> {
141        match self.get_mut() {
142            Self::Plain(plain) => pin!(plain).poll_read(cx, buf),
143            Self::TLS(tls) => pin!(tls).poll_read(cx, buf),
144        }
145    }
146
147    fn poll_read_vectored(
148        self: Pin<&mut Self>,
149        cx: &mut Context<'_>,
150        bufs: &mut [IoSliceMut<'_>],
151    ) -> Poll<io::Result<usize>> {
152        match self.get_mut() {
153            Self::Plain(plain) => pin!(plain).poll_read_vectored(cx, bufs),
154            Self::TLS(tls) => pin!(tls).poll_read_vectored(cx, bufs),
155        }
156    }
157}
158
159impl AsyncWrite for AsyncTcpStream {
160    fn poll_write(
161        self: Pin<&mut Self>,
162        cx: &mut Context<'_>,
163        buf: &[u8],
164    ) -> Poll<io::Result<usize>> {
165        match self.get_mut() {
166            Self::Plain(plain) => pin!(plain).poll_write(cx, buf),
167            Self::TLS(tls) => pin!(tls).poll_write(cx, buf),
168        }
169    }
170
171    fn poll_write_vectored(
172        self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174        bufs: &[IoSlice<'_>],
175    ) -> Poll<io::Result<usize>> {
176        match self.get_mut() {
177            Self::Plain(plain) => pin!(plain).poll_write_vectored(cx, bufs),
178            Self::TLS(tls) => pin!(tls).poll_write_vectored(cx, bufs),
179        }
180    }
181
182    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
183        match self.get_mut() {
184            Self::Plain(plain) => pin!(plain).poll_flush(cx),
185            Self::TLS(tls) => pin!(tls).poll_flush(cx),
186        }
187    }
188
189    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
190        match self.get_mut() {
191            Self::Plain(plain) => pin!(plain).poll_close(cx),
192            Self::TLS(tls) => pin!(tls).poll_close(cx),
193        }
194    }
195}