async_tls/
connector.rs

1use crate::common::tls_state::TlsState;
2
3use crate::client;
4
5use futures_io::{AsyncRead, AsyncWrite};
6use rustls::{ClientConfig, ClientConnection, OwnedTrustAnchor, RootCertStore, ServerName};
7use std::convert::TryFrom;
8use std::future::Future;
9use std::io;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14/// The TLS connecting part. The acceptor drives
15/// the client side of the TLS handshake process. It works
16/// on any asynchronous stream.
17///
18/// It provides a simple interface (`connect`), returning a future
19/// that will resolve when the handshake process completed. On
20/// success, it will hand you an async `TlsStream`.
21///
22/// To create a `TlsConnector` with a non-default configuation, create
23/// a `rusttls::ClientConfig` and call `.into()` on it.
24///
25/// ## Example
26///
27/// ```rust
28/// use async_tls::TlsConnector;
29///
30/// async_std::task::block_on(async {
31///     let connector = TlsConnector::default();
32///     let tcp_stream = async_std::net::TcpStream::connect("example.com").await?;
33///     let encrypted_stream = connector.connect("example.com", tcp_stream).await?;
34///
35///     Ok(()) as async_std::io::Result<()>
36/// });
37/// ```
38#[derive(Clone)]
39pub struct TlsConnector {
40    inner: Arc<ClientConfig>,
41    #[cfg(feature = "early-data")]
42    early_data: bool,
43}
44
45impl From<Arc<ClientConfig>> for TlsConnector {
46    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
47        TlsConnector {
48            inner,
49            #[cfg(feature = "early-data")]
50            early_data: false,
51        }
52    }
53}
54
55impl From<ClientConfig> for TlsConnector {
56    fn from(inner: ClientConfig) -> TlsConnector {
57        TlsConnector {
58            inner: Arc::new(inner),
59            #[cfg(feature = "early-data")]
60            early_data: false,
61        }
62    }
63}
64
65impl Default for TlsConnector {
66    fn default() -> Self {
67        let mut root_certs = RootCertStore::empty();
68        root_certs.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
69            OwnedTrustAnchor::from_subject_spki_name_constraints(
70                ta.subject,
71                ta.spki,
72                ta.name_constraints,
73            )
74        }));
75        let config = ClientConfig::builder()
76            .with_safe_defaults()
77            .with_root_certificates(root_certs)
78            .with_no_client_auth();
79        Arc::new(config).into()
80    }
81}
82
83impl TlsConnector {
84    /// Create a new TlsConnector with default configuration.
85    ///
86    /// This is the same as calling `TlsConnector::default()`.
87    pub fn new() -> Self {
88        Default::default()
89    }
90
91    /// Enable 0-RTT.
92    ///
93    /// You must also set `enable_early_data` to `true` in `ClientConfig`.
94    #[cfg(feature = "early-data")]
95    pub fn early_data(mut self, flag: bool) -> TlsConnector {
96        self.early_data = flag;
97        self
98    }
99
100    /// Connect to a server. `stream` can be any type implementing `AsyncRead` and `AsyncWrite`,
101    /// such as TcpStreams or Unix domain sockets.
102    ///
103    /// The function will return a `Connect` Future, representing the connecting part of a Tls
104    /// handshake. It will resolve when the handshake is over.
105    #[inline]
106    pub fn connect<'a, IO>(&self, domain: impl AsRef<str>, stream: IO) -> Connect<IO>
107    where
108        IO: AsyncRead + AsyncWrite + Unpin,
109    {
110        self.connect_with(domain, stream, |_| ())
111    }
112
113    // NOTE: Currently private, exposing ClientConnection exposes rusttls
114    // Early data should be exposed differently
115    fn connect_with<'a, IO, F>(&self, domain: impl AsRef<str>, stream: IO, f: F) -> Connect<IO>
116    where
117        IO: AsyncRead + AsyncWrite + Unpin,
118        F: FnOnce(&mut ClientConnection),
119    {
120        let domain = match ServerName::try_from(domain.as_ref()) {
121            Ok(domain) => domain,
122            Err(_) => {
123                return Connect(ConnectInner::Error(Some(io::Error::new(
124                    io::ErrorKind::InvalidInput,
125                    "invalid domain",
126                ))))
127            }
128        };
129
130        let mut session = match ClientConnection::new(self.inner.clone(), domain) {
131            Ok(session) => session,
132            Err(_) => {
133                return Connect(ConnectInner::Error(Some(io::Error::new(
134                    io::ErrorKind::Other,
135                    "invalid connection",
136                ))))
137            }
138        };
139
140        f(&mut session);
141
142        #[cfg(not(feature = "early-data"))]
143        {
144            Connect(ConnectInner::Handshake(client::MidHandshake::Handshaking(
145                client::TlsStream {
146                    session,
147                    io: stream,
148                    state: TlsState::Stream,
149                },
150            )))
151        }
152
153        #[cfg(feature = "early-data")]
154        {
155            Connect(ConnectInner::Handshake(if self.early_data {
156                client::MidHandshake::EarlyData(client::TlsStream {
157                    session,
158                    io: stream,
159                    state: TlsState::EarlyData,
160                    early_data: (0, Vec::new()),
161                })
162            } else {
163                client::MidHandshake::Handshaking(client::TlsStream {
164                    session,
165                    io: stream,
166                    state: TlsState::Stream,
167                    early_data: (0, Vec::new()),
168                })
169            }))
170        }
171    }
172}
173
174/// Future returned from `TlsConnector::connect` which will resolve
175/// once the connection handshake has finished.
176pub struct Connect<IO>(ConnectInner<IO>);
177
178enum ConnectInner<IO> {
179    Error(Option<io::Error>),
180    Handshake(client::MidHandshake<IO>),
181}
182
183impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
184    type Output = io::Result<client::TlsStream<IO>>;
185
186    #[inline]
187    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
188        match self.0 {
189            ConnectInner::Error(ref mut err) => {
190                Poll::Ready(Err(err.take().expect("Polled twice after being Ready")))
191            }
192            ConnectInner::Handshake(ref mut handshake) => Pin::new(handshake).poll(cx),
193        }
194    }
195}