fluvio_async_tls/
connector.rs1use crate::common::tls_state::TlsState;
2
3use crate::client;
4
5use futures_io::{AsyncRead, AsyncWrite};
6use rustls::{ClientConfig, ClientConnection, OwnedTrustAnchor, RootCertStore, ServerName};
7use std::io;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::{convert::TryFrom, future::Future};
12
13#[derive(Clone)]
38pub struct TlsConnector {
39 inner: Arc<ClientConfig>,
40 #[cfg(feature = "early-data")]
41 early_data: bool,
42}
43
44impl From<Arc<ClientConfig>> for TlsConnector {
45 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
46 TlsConnector {
47 inner,
48 #[cfg(feature = "early-data")]
49 early_data: false,
50 }
51 }
52}
53
54impl From<ClientConfig> for TlsConnector {
55 fn from(inner: ClientConfig) -> TlsConnector {
56 TlsConnector {
57 inner: Arc::new(inner),
58 #[cfg(feature = "early-data")]
59 early_data: false,
60 }
61 }
62}
63
64impl Default for TlsConnector {
65 fn default() -> Self {
66 let mut root_store = RootCertStore::empty();
67 root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
68 OwnedTrustAnchor::from_subject_spki_name_constraints(
69 ta.subject,
70 ta.spki,
71 ta.name_constraints,
72 )
73 }));
74
75 let config = rustls::ClientConfig::builder()
76 .with_safe_defaults()
77 .with_root_certificates(root_store)
78 .with_no_client_auth();
79
80 Arc::new(config).into()
81 }
82}
83
84impl TlsConnector {
85 pub fn new() -> Self {
89 Default::default()
90 }
91
92 #[cfg(feature = "early-data")]
96 pub fn early_data(mut self, flag: bool) -> TlsConnector {
97 self.early_data = flag;
98 self
99 }
100
101 #[inline]
107 pub fn connect<IO>(&self, domain: impl AsRef<str>, stream: IO) -> Connect<IO>
108 where
109 IO: AsyncRead + AsyncWrite + Unpin,
110 {
111 self.connect_with(domain, stream, |_| ())
112 }
113
114 fn connect_with<IO, F>(&self, domain: impl AsRef<str>, stream: IO, f: F) -> Connect<IO>
117 where
118 IO: AsyncRead + AsyncWrite + Unpin,
119 F: FnOnce(&mut ClientConnection),
120 {
121 let domain = match ServerName::try_from(domain.as_ref()) {
122 Ok(domain) => domain,
123 Err(_) => {
124 return Connect(ConnectInner::Error(Some(io::Error::new(
125 io::ErrorKind::InvalidInput,
126 "invalid domain",
127 ))))
128 }
129 };
130
131 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
132 Ok(conn) => conn,
133 Err(_) => {
134 return Connect(ConnectInner::Error(Some(io::Error::new(
135 io::ErrorKind::Other,
136 "failed to create client connection",
137 ))))
138 }
139 };
140
141 f(&mut session);
142
143 #[cfg(not(feature = "early-data"))]
144 {
145 Connect(ConnectInner::Handshake(client::MidHandshake::Handshaking(
146 client::TlsStream {
147 session,
148 io: stream,
149 state: TlsState::Stream,
150 },
151 )))
152 }
153
154 #[cfg(feature = "early-data")]
155 {
156 Connect(ConnectInner::Handshake(if self.early_data {
157 client::MidHandshake::EarlyData(client::TlsStream {
158 session,
159 io: stream,
160 state: TlsState::EarlyData,
161 early_data: (0, Vec::new()),
162 })
163 } else {
164 client::MidHandshake::Handshaking(client::TlsStream {
165 session,
166 io: stream,
167 state: TlsState::Stream,
168 early_data: (0, Vec::new()),
169 })
170 }))
171 }
172 }
173}
174
175pub struct Connect<IO>(ConnectInner<IO>);
178
179enum ConnectInner<IO> {
180 Error(Option<io::Error>),
181 Handshake(client::MidHandshake<IO>),
182}
183
184impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
185 type Output = io::Result<client::TlsStream<IO>>;
186
187 #[inline]
188 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189 match self.0 {
190 ConnectInner::Error(ref mut err) => {
191 Poll::Ready(Err(err.take().expect("Polled twice after being Ready")))
192 }
193 ConnectInner::Handshake(ref mut handshake) => Pin::new(handshake).poll(cx),
194 }
195 }
196}