1use crate::TLSConfig;
2
3use async_rs::traits::*;
4use cfg_if::cfg_if;
5use futures_io::{AsyncRead, AsyncWrite};
6use std::{
7 io::{self, IoSlice, IoSliceMut},
8 pin::Pin,
9 task::{Context, Poll},
10};
11
12#[cfg(feature = "native-tls-futures")]
13use crate::{NativeTlsAsyncStream, NativeTlsConnectorBuilder};
14#[cfg(feature = "openssl-futures")]
15use crate::{OpensslAsyncStream, OpensslConnector};
16#[cfg(feature = "rustls-futures")]
17use crate::{RustlsAsyncStream, RustlsConnector, RustlsConnectorConfig};
18
19#[non_exhaustive]
21pub enum AsyncTcpStream<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
22 Plain(S),
24 #[cfg(feature = "native-tls-futures")]
25 NativeTls(NativeTlsAsyncStream<S>),
27 #[cfg(feature = "openssl-futures")]
28 Openssl(OpensslAsyncStream<S>),
30 #[cfg(feature = "rustls-futures")]
31 Rustls(RustlsAsyncStream<S>),
33}
34
35impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncTcpStream<S> {
36 pub async fn connect<R: Reactor<TcpStream = S> + Sync, A: AsyncToSocketAddrs + Send>(
38 reactor: &R,
39 addr: A,
40 ) -> io::Result<Self> {
41 Ok(Self::Plain(reactor.tcp_connect(addr).await?))
42 }
43
44 pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
46 into_tls_impl(self, domain, config).await
47 }
48
49 #[cfg(feature = "native-tls-futures")]
50 pub async fn into_native_tls(
52 self,
53 connector: NativeTlsConnectorBuilder,
54 domain: &str,
55 ) -> io::Result<Self> {
56 Ok(Self::NativeTls(
57 async_native_tls::TlsConnector::from(connector)
58 .connect(domain, self.into_plain()?)
59 .await
60 .map_err(io::Error::other)?,
61 ))
62 }
63
64 #[cfg(feature = "openssl-futures")]
65 pub async fn into_openssl(
67 self,
68 connector: &OpensslConnector,
69 domain: &str,
70 ) -> io::Result<Self> {
71 let mut stream = async_openssl::SslStream::new(
72 connector.configure()?.into_ssl(domain)?,
73 self.into_plain()?,
74 )?;
75 Pin::new(&mut stream)
76 .connect()
77 .await
78 .map_err(io::Error::other)?;
79 Ok(Self::Openssl(stream))
80 }
81
82 #[cfg(feature = "rustls-futures")]
83 pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
85 Ok(Self::Rustls(
86 connector.connect_async(domain, self.into_plain()?).await?,
87 ))
88 }
89
90 #[allow(irrefutable_let_patterns, dead_code)]
91 fn into_plain(self) -> io::Result<S> {
92 if let Self::Plain(plain) = self {
93 Ok(plain)
94 } else {
95 Err(io::Error::new(
96 io::ErrorKind::AlreadyExists,
97 "already a TLS stream",
98 ))
99 }
100 }
101}
102
/// Upgrades `s` to TLS using the highest-priority backend enabled at compile time.
///
/// Backend selection follows the `cfg_if!` chain below; the first matching
/// branch wins:
/// 1. rustls with `rustls-platform-verifier`
/// 2. rustls with `rustls-native-certs`
/// 3. rustls with `rustls-webpki-roots-certs`
/// 4. rustls with `RustlsConnectorConfig::default()`
/// 5. openssl
/// 6. native-tls
///
/// If no TLS backend feature is enabled, `domain` and `config` are ignored.
/// The stream is then returned as `Plain`, via `into_plain`, which errors if it is not plain.
///
/// # Errors
///
/// Propagates errors from loading native certificates (the
/// `rustls-native-certs` branch) and from the backend-specific
/// `into_*_impl_async` helper.
async fn into_tls_impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
    s: AsyncTcpStream<S>,
    domain: &str,
    config: TLSConfig<'_, '_, '_>,
) -> io::Result<AsyncTcpStream<S>> {
    cfg_if! {
        // Order matters: rustls variants take precedence over openssl and
        // native-tls when several backends are enabled at once.
        if #[cfg(all(feature = "rustls-futures", feature = "rustls-platform-verifier"))] {
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_platform_verifier(), domain, config).await
        } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-native-certs"))] {
            // Loading native certs is fallible, hence the `?`.
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_native_certs()?, domain, config).await
        } else if #[cfg(all(feature = "rustls-futures", feature = "rustls-webpki-roots-certs"))] {
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::new_with_webpki_roots_certs(), domain, config).await
        } else if #[cfg(feature = "rustls-futures")] {
            crate::into_rustls_impl_async(s, RustlsConnectorConfig::default(), domain, config).await
        } else if #[cfg(feature = "openssl-futures")] {
            crate::into_openssl_impl_async(s, domain, config).await
        } else if #[cfg(feature = "native-tls-futures")] {
            crate::into_native_tls_impl_async(s, domain, config).await
        } else {
            // No TLS backend compiled in: silence unused-parameter warnings
            // and hand back the plain stream unchanged.
            let _ = (domain, config);
            Ok(AsyncTcpStream::Plain(s.into_plain()?))
        }
    }
}
127
/// Forwards a pinned poll-method call to the inner stream of the active variant.
///
/// `$self` is the `Pin<&mut Self>` receiver and `$method` is the method
/// invoked on `Pin::new(inner)`. Each `$args` expression is passed through unchanged.
/// Each TLS arm is compiled only when its cargo feature is enabled, mirroring
/// the cfg-gated variants of `AsyncTcpStream`. `Pin::new` requires every
/// inner stream type to be `Unpin`.
macro_rules! fwd_impl {
    ($self:ident, $method:ident, $($args:expr),*) => {
        // `get_mut` needs `Self: Unpin`; it yields `&mut Self` to match on.
        match $self.get_mut() {
            Self::Plain(plain) => Pin::new(plain).$method($($args),*),
            #[cfg(feature = "native-tls-futures")]
            Self::NativeTls(tls) => Pin::new(tls).$method($($args),*),
            #[cfg(feature = "openssl-futures")]
            Self::Openssl(tls) => Pin::new(tls).$method($($args),*),
            #[cfg(feature = "rustls-futures")]
            Self::Rustls(tls) => Pin::new(tls).$method($($args),*),
        }
    };
}
141
142impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for AsyncTcpStream<S> {
143 fn poll_read(
144 self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 buf: &mut [u8],
147 ) -> Poll<io::Result<usize>> {
148 fwd_impl!(self, poll_read, cx, buf)
149 }
150
151 fn poll_read_vectored(
152 self: Pin<&mut Self>,
153 cx: &mut Context<'_>,
154 bufs: &mut [IoSliceMut<'_>],
155 ) -> Poll<io::Result<usize>> {
156 fwd_impl!(self, poll_read_vectored, cx, bufs)
157 }
158}
159
160impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for AsyncTcpStream<S> {
161 fn poll_write(
162 self: Pin<&mut Self>,
163 cx: &mut Context<'_>,
164 buf: &[u8],
165 ) -> Poll<io::Result<usize>> {
166 fwd_impl!(self, poll_write, cx, buf)
167 }
168
169 fn poll_write_vectored(
170 self: Pin<&mut Self>,
171 cx: &mut Context<'_>,
172 bufs: &[IoSlice<'_>],
173 ) -> Poll<io::Result<usize>> {
174 fwd_impl!(self, poll_write_vectored, cx, bufs)
175 }
176
177 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178 fwd_impl!(self, poll_flush, cx)
179 }
180
181 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
182 fwd_impl!(self, poll_close, cx)
183 }
184}