fluvio_async_tls/
connector.rs

1use crate::common::tls_state::TlsState;
2
3use crate::client;
4
5use futures_io::{AsyncRead, AsyncWrite};
6use rustls::{ClientConfig, ClientConnection, OwnedTrustAnchor, RootCertStore, ServerName};
7use std::io;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::{convert::TryFrom, future::Future};
12
/// The TLS connecting part. The connector drives
/// the client side of the TLS handshake process. It works
/// on any asynchronous stream.
///
/// It provides a simple interface (`connect`), returning a future
/// that will resolve when the handshake process has completed. On
/// success, it will hand you an async `TlsStream`.
///
/// To create a `TlsConnector` with a non-default configuration, create
/// a `rustls::ClientConfig` and call `.into()` on it.
///
/// ## Example
///
/// ```rust,no_run
/// use fluvio_async_tls::TlsConnector;
///
/// async_std::task::block_on(async {
///     let connector = TlsConnector::default();
///     // `TcpStream::connect` needs a `host:port` socket address.
///     let tcp_stream = async_std::net::TcpStream::connect("example.com:443").await?;
///     let encrypted_stream = connector.connect("example.com", tcp_stream).await?;
///
///     Ok(()) as async_std::io::Result<()>
/// });
/// ```
#[derive(Clone)]
pub struct TlsConnector {
    /// Shared rustls client configuration; a clone of this `Arc` is handed to
    /// every new client connection.
    inner: Arc<ClientConfig>,
    /// When `true`, `connect` starts the stream in the early-data state.
    /// Set through `early_data()`; the `From` conversions initialise it to `false`.
    #[cfg(feature = "early-data")]
    early_data: bool,
}
43
44impl From<Arc<ClientConfig>> for TlsConnector {
45    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
46        TlsConnector {
47            inner,
48            #[cfg(feature = "early-data")]
49            early_data: false,
50        }
51    }
52}
53
54impl From<ClientConfig> for TlsConnector {
55    fn from(inner: ClientConfig) -> TlsConnector {
56        TlsConnector {
57            inner: Arc::new(inner),
58            #[cfg(feature = "early-data")]
59            early_data: false,
60        }
61    }
62}
63
64impl Default for TlsConnector {
65    fn default() -> Self {
66        let mut root_store = RootCertStore::empty();
67        root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
68            OwnedTrustAnchor::from_subject_spki_name_constraints(
69                ta.subject,
70                ta.spki,
71                ta.name_constraints,
72            )
73        }));
74
75        let config = rustls::ClientConfig::builder()
76            .with_safe_defaults()
77            .with_root_certificates(root_store)
78            .with_no_client_auth();
79
80        Arc::new(config).into()
81    }
82}
83
84impl TlsConnector {
85    /// Create a new TlsConnector with default configuration.
86    ///
87    /// This is the same as calling `TlsConnector::default()`.
88    pub fn new() -> Self {
89        Default::default()
90    }
91
92    /// Enable 0-RTT.
93    ///
94    /// You must also set `enable_early_data` to `true` in `ClientConfig`.
95    #[cfg(feature = "early-data")]
96    pub fn early_data(mut self, flag: bool) -> TlsConnector {
97        self.early_data = flag;
98        self
99    }
100
101    /// Connect to a server. `stream` can be any type implementing `AsyncRead` and `AsyncWrite`,
102    /// such as TcpStreams or Unix domain sockets.
103    ///
104    /// The function will return a `Connect` Future, representing the connecting part of a Tls
105    /// handshake. It will resolve when the handshake is over.
106    #[inline]
107    pub fn connect<IO>(&self, domain: impl AsRef<str>, stream: IO) -> Connect<IO>
108    where
109        IO: AsyncRead + AsyncWrite + Unpin,
110    {
111        self.connect_with(domain, stream, |_| ())
112    }
113
114    // NOTE: Currently private, exposing ClientSession exposes rusttls
115    // Early data should be exposed differently
116    fn connect_with<IO, F>(&self, domain: impl AsRef<str>, stream: IO, f: F) -> Connect<IO>
117    where
118        IO: AsyncRead + AsyncWrite + Unpin,
119        F: FnOnce(&mut ClientConnection),
120    {
121        let domain = match ServerName::try_from(domain.as_ref()) {
122            Ok(domain) => domain,
123            Err(_) => {
124                return Connect(ConnectInner::Error(Some(io::Error::new(
125                    io::ErrorKind::InvalidInput,
126                    "invalid domain",
127                ))))
128            }
129        };
130
131        let mut session = match ClientConnection::new(self.inner.clone(), domain) {
132            Ok(conn) => conn,
133            Err(_) => {
134                return Connect(ConnectInner::Error(Some(io::Error::new(
135                    io::ErrorKind::Other,
136                    "failed to create client connection",
137                ))))
138            }
139        };
140
141        f(&mut session);
142
143        #[cfg(not(feature = "early-data"))]
144        {
145            Connect(ConnectInner::Handshake(client::MidHandshake::Handshaking(
146                client::TlsStream {
147                    session,
148                    io: stream,
149                    state: TlsState::Stream,
150                },
151            )))
152        }
153
154        #[cfg(feature = "early-data")]
155        {
156            Connect(ConnectInner::Handshake(if self.early_data {
157                client::MidHandshake::EarlyData(client::TlsStream {
158                    session,
159                    io: stream,
160                    state: TlsState::EarlyData,
161                    early_data: (0, Vec::new()),
162                })
163            } else {
164                client::MidHandshake::Handshaking(client::TlsStream {
165                    session,
166                    io: stream,
167                    state: TlsState::Stream,
168                    early_data: (0, Vec::new()),
169                })
170            }))
171        }
172    }
173}
174
/// Future returned from `TlsConnector::connect` which will resolve
/// once the connection handshake has finished.
///
/// Its output is `io::Result<client::TlsStream<IO>>`. When the connector could
/// not set up the session, the stored error is returned on the first poll;
/// polling again after that panics.
pub struct Connect<IO>(ConnectInner<IO>);
178
/// Internal state of a `Connect` future.
enum ConnectInner<IO> {
    /// Setup failed before any handshake started. The error is taken out
    /// (leaving `None`) when the future is first polled.
    Error(Option<io::Error>),
    /// A handshake is pending; polling is delegated to the wrapped `MidHandshake`.
    Handshake(client::MidHandshake<IO>),
}
183
184impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
185    type Output = io::Result<client::TlsStream<IO>>;
186
187    #[inline]
188    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189        match self.0 {
190            ConnectInner::Error(ref mut err) => {
191                Poll::Ready(Err(err.take().expect("Polled twice after being Ready")))
192            }
193            ConnectInner::Handshake(ref mut handshake) => Pin::new(handshake).poll(cx),
194        }
195    }
196}