// aws_smithy_http_client/client/tls/rustls_provider.rs
use crate::client::tls::Provider;
use rustls::crypto::CryptoProvider;

/// Cryptography backend used by the rustls TLS provider.
///
/// Each built-in variant exists only when its cargo feature is enabled.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum CryptoMode {
    /// Crypto based on [ring](https://github.com/briansmith/ring)
    /// (`rustls::crypto::ring::default_provider()`).
    ///
    /// Requires the `rustls-ring` feature.
    #[cfg(feature = "rustls-ring")]
    Ring,
    /// Crypto based on [aws-lc-rs](https://github.com/aws/aws-lc-rs)
    /// (`rustls::crypto::aws_lc_rs::default_provider()`).
    ///
    /// Requires the `rustls-aws-lc` feature.
    #[cfg(feature = "rustls-aws-lc")]
    AwsLc,
    /// FIPS variant of aws-lc (`rustls::crypto::default_fips_provider()`).
    ///
    /// Resolving this mode panics if the returned provider does not report FIPS
    /// support. Requires the `rustls-aws-lc-fips` feature.
    #[cfg(feature = "rustls-aws-lc-fips")]
    AwsLcFips,
    /// A caller-supplied rustls `CryptoProvider`, used exactly as given.
    ///
    /// Unlike the built-in modes, its cipher suites are NOT narrowed to the
    /// connector's allow-list. A `Custom` mode never compares equal to any mode,
    /// including itself. Only available with `--cfg aws_sdk_unstable` and the
    /// `__rustls` feature.
    #[cfg(all(aws_sdk_unstable, feature = "__rustls"))]
    Custom(CryptoProvider),
}

33impl std::cmp::PartialEq for CryptoMode {
34 fn eq(&self, other: &CryptoMode) -> bool {
35 match (self, other) {
36 #[cfg(feature = "rustls-ring")]
37 (Self::Ring, Self::Ring) => true,
38 #[cfg(feature = "rustls-aws-lc")]
39 (Self::AwsLc, Self::AwsLc) => true,
40 #[cfg(feature = "rustls-aws-lc-fips")]
41 (Self::AwsLcFips, Self::AwsLcFips) => true,
42 #[allow(unreachable_patterns)]
46 _ => false,
47 }
48 }
49}
50
51#[cfg(not(all(aws_sdk_unstable, feature = "__rustls")))]
52impl Eq for CryptoMode {}
53
54impl CryptoMode {
55 fn provider(self) -> CryptoProvider {
56 match self {
57 #[cfg(feature = "rustls-aws-lc")]
58 CryptoMode::AwsLc => rustls::crypto::aws_lc_rs::default_provider(),
59
60 #[cfg(feature = "rustls-ring")]
61 CryptoMode::Ring => rustls::crypto::ring::default_provider(),
62
63 #[cfg(feature = "rustls-aws-lc-fips")]
64 CryptoMode::AwsLcFips => {
65 let provider = rustls::crypto::default_fips_provider();
66 assert!(
67 provider.fips(),
68 "FIPS was requested but the provider did not support FIPS"
69 );
70 provider
71 }
72 #[cfg(all(aws_sdk_unstable, feature = "__rustls"))]
73 CryptoMode::Custom(provider) => provider,
74 }
75 }
76
77 #[cfg(all(aws_sdk_unstable, feature = "__rustls"))]
78 fn is_custom(&self) -> bool {
79 matches!(self, Self::Custom(_))
80 }
81
82 #[cfg(not(all(aws_sdk_unstable, feature = "__rustls")))]
83 fn is_custom(&self) -> bool {
84 false
85 }
86}
87
88impl Provider {
89 pub fn rustls(mode: CryptoMode) -> Provider {
92 Provider::Rustls(mode)
93 }
94}
95
96pub(crate) mod build_connector {
97 use crate::client::tls::rustls_provider::CryptoMode;
98 use crate::tls::TlsContext;
99 use client::connect::HttpConnector;
100 use hyper_util::client::legacy as client;
101 use rustls::crypto::CryptoProvider;
102 use rustls_native_certs::CertificateResult;
103 use rustls_pki_types::pem::PemObject;
104 use rustls_pki_types::CertificateDer;
105 use std::sync::Arc;
106 use std::sync::LazyLock;
107
108 pub(crate) static NATIVE_ROOTS: LazyLock<Vec<CertificateDer<'static>>> = LazyLock::new(|| {
114 let CertificateResult { certs, errors, .. } = rustls_native_certs::load_native_certs();
115 if !errors.is_empty() {
116 tracing::warn!("native root CA certificate loading errors: {errors:?}")
117 }
118
119 if certs.is_empty() {
120 tracing::warn!("no native root CA certificates found!");
121 }
122
123 certs
126 });
127
128 pub(crate) fn restrict_ciphers(base: CryptoProvider) -> CryptoProvider {
129 let suites = &[
130 rustls::CipherSuite::TLS13_AES_256_GCM_SHA384,
131 rustls::CipherSuite::TLS13_AES_128_GCM_SHA256,
132 rustls::CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
134 rustls::CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
135 rustls::CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
136 rustls::CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
137 rustls::CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
138 ];
139 let supported_suites = suites
140 .iter()
141 .flat_map(|suite| {
142 base.cipher_suites
143 .iter()
144 .find(|s| &s.suite() == suite)
145 .cloned()
146 })
147 .collect::<Vec<_>>();
148 CryptoProvider {
149 cipher_suites: supported_suites,
150 ..base
151 }
152 }
153
154 impl TlsContext {
155 pub(crate) fn rustls_root_certs(&self) -> rustls::RootCertStore {
156 let mut roots = rustls::RootCertStore::empty();
157 if self.trust_store.enable_native_roots {
158 let (valid, _invalid) = roots.add_parsable_certificates(NATIVE_ROOTS.clone());
159 debug_assert!(valid > 0, "TrustStore configured to enable native roots but no valid root certificates parsed!");
160 }
161
162 for pem_cert in &self.trust_store.custom_certs {
163 let ders = CertificateDer::pem_slice_iter(&pem_cert.0)
164 .collect::<Result<Vec<_>, _>>()
165 .expect("valid PEM certificate");
166 for cert in ders {
167 roots.add(cert).expect("cert parsable")
168 }
169 }
170
171 roots
172 }
173 }
174
175 pub(crate) fn create_rustls_client_config(
180 crypto_mode: CryptoMode,
181 tls_context: &TlsContext,
182 ) -> rustls::ClientConfig {
183 let skip_restrict = crypto_mode.is_custom();
184 let provider = if skip_restrict {
185 crypto_mode.provider()
186 } else {
187 restrict_ciphers(crypto_mode.provider())
188 };
189 let root_certs = tls_context.rustls_root_certs();
190 rustls::ClientConfig::builder_with_provider(Arc::new(provider))
191 .with_safe_default_protocol_versions()
192 .expect("Error with the TLS configuration. Please file a bug report under https://github.com/smithy-lang/smithy-rs/issues.")
193 .with_root_certificates(root_certs)
194 .with_no_client_auth()
195 }
196
197 pub(crate) fn wrap_connector<R>(
198 mut conn: HttpConnector<R>,
199 crypto_mode: CryptoMode,
200 tls_context: &TlsContext,
201 proxy_config: crate::client::proxy::ProxyConfig,
202 ) -> super::connect::RustTlsConnector<R> {
203 let client_config = create_rustls_client_config(crypto_mode, tls_context);
204 conn.enforce_http(false);
205 let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
206 .with_tls_config(client_config.clone())
207 .https_or_http()
208 .enable_http1()
209 .enable_http2()
210 .wrap_connector(conn);
211
212 super::connect::RustTlsConnector::new(https_connector, client_config, proxy_config)
213 }
214}
215
216pub(crate) mod connect {
217 use crate::client::connect::{Conn, Connecting};
218 use crate::client::proxy::ProxyConfig;
219 use aws_smithy_runtime_api::box_error::BoxError;
220 use http_1x::uri::Scheme;
221 use http_1x::Uri;
222 use hyper::rt::{Read, ReadBufCursor, Write};
223 use hyper_rustls::MaybeHttpsStream;
224 use hyper_util::client::legacy::connect::{Connected, Connection, HttpConnector};
225 use hyper_util::client::proxy::matcher::Matcher;
226 use hyper_util::rt::TokioIo;
227 use pin_project_lite::pin_project;
228 use std::error::Error;
229 use std::sync::Arc;
230 use std::{
231 io::{self, IoSlice},
232 pin::Pin,
233 task::{Context, Poll},
234 };
235 use tokio::io::{AsyncRead, AsyncWrite};
236 use tokio::net::TcpStream;
237 use tokio_rustls::client::TlsStream;
238 use tower::Service;
239
240 #[derive(Debug, Clone)]
241 pub(crate) struct RustTlsConnector<R> {
242 https: hyper_rustls::HttpsConnector<HttpConnector<R>>,
243 tls_config: Arc<rustls::ClientConfig>,
244 proxy_matcher: Option<Arc<Matcher>>, }
246
247 impl<R> RustTlsConnector<R> {
248 pub(super) fn new(
249 https: hyper_rustls::HttpsConnector<HttpConnector<R>>,
250 tls_config: rustls::ClientConfig,
251 proxy_config: ProxyConfig,
252 ) -> Self {
253 let proxy_matcher = if proxy_config.is_disabled() {
255 None
256 } else {
257 Some(Arc::new(proxy_config.into_hyper_util_matcher()))
258 };
259
260 Self {
261 https,
262 tls_config: Arc::new(tls_config),
263 proxy_matcher,
264 }
265 }
266 }
267
268 impl<R> Service<Uri> for RustTlsConnector<R>
269 where
270 R: Clone + Send + Sync + 'static,
271 R: Service<hyper_util::client::legacy::connect::dns::Name>,
272 R::Response: Iterator<Item = std::net::SocketAddr>,
273 R::Future: Send,
274 R::Error: Into<Box<dyn Error + Send + Sync>>,
275 {
276 type Response = Conn;
277 type Error = BoxError;
278 type Future = Connecting;
279
280 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
281 self.https.poll_ready(cx).map_err(Into::into)
282 }
283
284 fn call(&mut self, dst: Uri) -> Self::Future {
285 let proxy_intercept = if let Some(ref matcher) = self.proxy_matcher {
287 matcher.intercept(&dst)
288 } else {
289 None
290 };
291
292 if let Some(intercept) = proxy_intercept {
293 if dst.scheme() == Some(&Scheme::HTTPS) {
294 self.handle_https_through_proxy(dst, intercept)
296 } else {
297 self.handle_http_through_proxy(dst, intercept)
299 }
300 } else {
301 self.handle_direct_connection(dst)
303 }
304 }
305 }
306
307 impl<R> RustTlsConnector<R>
308 where
309 R: Clone + Send + Sync + 'static,
310 R: Service<hyper_util::client::legacy::connect::dns::Name>,
311 R::Response: Iterator<Item = std::net::SocketAddr>,
312 R::Future: Send,
313 R::Error: Into<Box<dyn Error + Send + Sync>>,
314 {
315 fn handle_direct_connection(&mut self, dst: Uri) -> Connecting {
316 let fut = self.https.call(dst);
317 Box::pin(async move {
318 let conn = fut.await?;
319 Ok(Conn {
320 inner: Box::new(conn),
321 is_proxy: false,
322 })
323 })
324 }
325
326 fn handle_http_through_proxy(
327 &mut self,
328 _dst: Uri,
329 intercept: hyper_util::client::proxy::matcher::Intercept,
330 ) -> Connecting {
331 let proxy_uri = intercept.uri().clone();
333 let fut = self.https.call(proxy_uri);
334 Box::pin(async move {
335 let conn = fut.await?;
336 Ok(Conn {
337 inner: Box::new(conn),
338 is_proxy: true,
339 })
340 })
341 }
342
343 fn handle_https_through_proxy(
344 &mut self,
345 dst: Uri,
346 intercept: hyper_util::client::proxy::matcher::Intercept,
347 ) -> Connecting {
348 use rustls_pki_types::ServerName;
349 let tunnel = hyper_util::client::legacy::connect::proxy::Tunnel::new(
354 intercept.uri().clone(),
355 self.https.clone(),
356 );
357
358 let mut tunnel = if let Some(auth) = intercept.basic_auth() {
360 tunnel.with_auth(auth.clone())
361 } else {
362 tunnel
363 };
364
365 let tls_config = self.tls_config.clone();
366 let dst_clone = dst.clone();
367
368 Box::pin(async move {
369 tracing::trace!("tunneling HTTPS over proxy");
371 let tunneled = tunnel
372 .call(dst_clone.clone())
373 .await
374 .map_err(|e| BoxError::from(format!("CONNECT tunnel failed: {e}")))?;
375
376 let host = dst_clone
378 .host()
379 .ok_or("missing host in URI for TLS handshake")?;
380
381 let server_name = ServerName::try_from(host.to_owned()).map_err(|e| {
382 BoxError::from(format!("invalid server name for TLS handshake: {e}"))
383 })?;
384
385 let tls_connector = tokio_rustls::TlsConnector::from(tls_config)
386 .connect(server_name, TokioIo::new(tunneled))
387 .await?;
388
389 Ok(Conn {
390 inner: Box::new(RustTlsConn {
391 inner: TokioIo::new(tls_connector),
392 }),
393 is_proxy: true,
394 })
395 })
396 }
397 }
398
399 pin_project! {
400 pub(crate) struct RustTlsConn<T> {
401 #[pin] pub(super) inner: TokioIo<TlsStream<T>>
402 }
403 }
404
405 impl Connection for RustTlsConn<TokioIo<TokioIo<TcpStream>>> {
406 fn connected(&self) -> Connected {
407 if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
408 self.inner
409 .inner()
410 .get_ref()
411 .0
412 .inner()
413 .connected()
414 .negotiated_h2()
415 } else {
416 self.inner.inner().get_ref().0.inner().connected()
417 }
418 }
419 }
420
421 impl Connection for RustTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
422 fn connected(&self) -> Connected {
423 if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
424 self.inner
425 .inner()
426 .get_ref()
427 .0
428 .inner()
429 .connected()
430 .negotiated_h2()
431 } else {
432 self.inner.inner().get_ref().0.inner().connected()
433 }
434 }
435 }
436 impl<T: AsyncRead + AsyncWrite + Unpin> Read for RustTlsConn<T> {
437 fn poll_read(
438 self: Pin<&mut Self>,
439 cx: &mut Context<'_>,
440 buf: ReadBufCursor<'_>,
441 ) -> Poll<tokio::io::Result<()>> {
442 let this = self.project();
443 Read::poll_read(this.inner, cx, buf)
444 }
445 }
446
447 impl<T: AsyncRead + AsyncWrite + Unpin> Write for RustTlsConn<T> {
448 fn poll_write(
449 self: Pin<&mut Self>,
450 cx: &mut Context<'_>,
451 buf: &[u8],
452 ) -> Poll<Result<usize, tokio::io::Error>> {
453 let this = self.project();
454 Write::poll_write(this.inner, cx, buf)
455 }
456
457 fn poll_write_vectored(
458 self: Pin<&mut Self>,
459 cx: &mut Context<'_>,
460 bufs: &[IoSlice<'_>],
461 ) -> Poll<Result<usize, io::Error>> {
462 let this = self.project();
463 Write::poll_write_vectored(this.inner, cx, bufs)
464 }
465
466 fn is_write_vectored(&self) -> bool {
467 self.inner.is_write_vectored()
468 }
469
470 fn poll_flush(
471 self: Pin<&mut Self>,
472 cx: &mut Context<'_>,
473 ) -> Poll<Result<(), tokio::io::Error>> {
474 let this = self.project();
475 Write::poll_flush(this.inner, cx)
476 }
477
478 fn poll_shutdown(
479 self: Pin<&mut Self>,
480 cx: &mut Context<'_>,
481 ) -> Poll<Result<(), tokio::io::Error>> {
482 let this = self.project();
483 Write::poll_shutdown(this.inner, cx)
484 }
485 }
486}