Skip to main content

tcp_stream/
futures.rs

1use crate::TLSConfig;
2
3use async_rs::traits::*;
4use cfg_if::cfg_if;
5use futures_io::{AsyncRead, AsyncWrite};
6use std::{
7    fmt,
8    io::{self, IoSlice, IoSliceMut},
9    pin::Pin,
10    task::{Context, Poll},
11};
12
13#[cfg(feature = "native-tls-futures")]
14use crate::{NativeTlsAsyncStream, NativeTlsConnectorBuilder};
15#[cfg(feature = "openssl-futures")]
16use crate::{OpensslAsyncStream, OpensslConnector};
17#[cfg(feature = "rustls-futures")]
18use crate::{RustlsAsyncStream, RustlsConnector};
19
20/// Wrapper around plain or TLS async TCP streams
21#[non_exhaustive]
22pub enum AsyncTcpStream<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> {
23    /// Wrapper around plain async TCP stream
24    Plain(S),
25    #[cfg(feature = "native-tls-futures")]
26    /// Wrapper around a TLS async stream handled by native-tls
27    NativeTls(NativeTlsAsyncStream<S>),
28    #[cfg(feature = "openssl-futures")]
29    /// Wrapper around a TLS async stream handled by openssl
30    Openssl(OpensslAsyncStream<S>),
31    #[cfg(feature = "rustls-futures")]
32    /// Wrapper around a TLS async stream handled by rustls
33    Rustls(RustlsAsyncStream<S>),
34}
35
36impl<S: AsyncRead + AsyncWrite + fmt::Debug + Send + Unpin + 'static> fmt::Debug
37    for AsyncTcpStream<S>
38{
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        f.debug_struct("AsyncTcpStream").finish_non_exhaustive()
41    }
42}
43
44impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncTcpStream<S> {
45    /// Wrapper around `reactor_trait::TcpReactor::connect`
46    pub async fn connect<R: Reactor<TcpStream = S> + Sync, A: AsyncToSocketAddrs + Send>(
47        reactor: &R,
48        addr: A,
49    ) -> io::Result<Self> {
50        Ok(Self::Plain(reactor.tcp_connect(addr).await?))
51    }
52
53    /// Upgrade this plain stream to TLS using the feature-selected backend (rustls by default).
54    ///
55    /// Unlike the synchronous [`TcpStream::into_tls`](crate::TcpStream::into_tls), this completes the full handshake
56    /// asynchronously and returns `io::Result<Self>` directly — there is no
57    /// mid-handshake state to retry.
58    pub async fn into_tls(self, domain: &str, config: TLSConfig<'_, '_, '_>) -> io::Result<Self> {
59        into_tls_impl(self, domain, config).await
60    }
61
62    #[cfg(feature = "native-tls-futures")]
63    /// Enable TLS using native-tls
64    pub async fn into_native_tls(
65        self,
66        connector: NativeTlsConnectorBuilder,
67        domain: &str,
68    ) -> io::Result<Self> {
69        Ok(Self::NativeTls(
70            async_native_tls::TlsConnector::from(connector)
71                .connect(domain, self.into_plain()?)
72                .await
73                .map_err(io::Error::other)?,
74        ))
75    }
76
77    #[cfg(feature = "openssl-futures")]
78    /// Enable TLS using openssl
79    pub async fn into_openssl(
80        self,
81        connector: &OpensslConnector,
82        domain: &str,
83    ) -> io::Result<Self> {
84        let mut stream = async_openssl::SslStream::new(
85            connector.configure()?.into_ssl(domain)?,
86            self.into_plain()?,
87        )?;
88        Pin::new(&mut stream)
89            .connect()
90            .await
91            .map_err(io::Error::other)?;
92        Ok(Self::Openssl(stream))
93    }
94
95    #[cfg(feature = "rustls-futures")]
96    /// Enable TLS using rustls
97    pub async fn into_rustls(self, connector: &RustlsConnector, domain: &str) -> io::Result<Self> {
98        Ok(Self::Rustls(
99            connector.connect_async(domain, self.into_plain()?).await?,
100        ))
101    }
102
103    #[allow(irrefutable_let_patterns, dead_code)]
104    fn into_plain(self) -> io::Result<S> {
105        if let Self::Plain(plain) = self {
106            Ok(plain)
107        } else {
108            Err(io::Error::new(
109                io::ErrorKind::AlreadyExists,
110                "already a TLS stream",
111            ))
112        }
113    }
114}
115
116async fn into_tls_impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
117    s: AsyncTcpStream<S>,
118    domain: &str,
119    config: TLSConfig<'_, '_, '_>,
120) -> io::Result<AsyncTcpStream<S>> {
121    cfg_if! {
122        if #[cfg(feature = "rustls-futures")] {
123            crate::into_rustls_impl_async(s, domain, config).await
124        } else if #[cfg(feature = "openssl-futures")] {
125            crate::into_openssl_impl_async(s, domain, config).await
126        } else if #[cfg(feature = "native-tls-futures")] {
127            crate::into_native_tls_impl_async(s, domain, config).await
128        } else {
129            let _ = (domain, config);
130            Ok(AsyncTcpStream::Plain(s.into_plain()?))
131        }
132    }
133}
134
// `AsyncRead` for the plain/TLS wrapper. Every method delegates to the
// `fwd_pin_impl!` macro, which is defined outside this file. Judging by its name
// and arguments, it forwards the call (with the same arguments) to the inner
// stream of the active variant — NOTE(review): confirm against the macro.
impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for AsyncTcpStream<S> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        fwd_pin_impl!(self, poll_read, cx, buf)
    }

    // Overrides the trait's default vectored read, presumably so the inner
    // stream's own vectored implementation is used.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        fwd_pin_impl!(self, poll_read_vectored, cx, bufs)
    }
}
152
// `AsyncWrite` for the plain/TLS wrapper. Every method delegates to the
// `fwd_pin_impl!` macro, which is defined outside this file. It presumably
// forwards to the inner stream of the active variant — NOTE(review): confirm
// against the macro.
impl<S: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for AsyncTcpStream<S> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        fwd_pin_impl!(self, poll_write, cx, buf)
    }

    // Overrides the trait's default vectored write, presumably so the inner
    // stream's own vectored implementation is used.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        fwd_pin_impl!(self, poll_write_vectored, cx, bufs)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        fwd_pin_impl!(self, poll_flush, cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        fwd_pin_impl!(self, poll_close, cx)
    }
}