Skip to main content

aws_smithy_http_client/client/tls/
rustls_provider.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5use crate::client::tls::Provider;
6use rustls::crypto::CryptoProvider;
7
/// Choice of underlying cryptography library (this only applies to rustls)
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum CryptoMode {
    /// Crypto based on [ring](https://github.com/briansmith/ring)
    #[cfg(feature = "rustls-ring")]
    Ring,
    /// Crypto based on [aws-lc](https://github.com/aws/aws-lc-rs)
    #[cfg(feature = "rustls-aws-lc")]
    AwsLc,
    /// FIPS compliant variant of [aws-lc](https://github.com/aws/aws-lc-rs)
    ///
    /// Building a TLS configuration with this mode panics if the provider
    /// obtained from rustls does not report FIPS support.
    #[cfg(feature = "rustls-aws-lc-fips")]
    AwsLcFips,
    /// Use a caller-supplied [`CryptoProvider`].
    ///
    /// Unlike the built-in modes, the cipher-suite restriction normally
    /// applied by smithy-rs is skipped -- the caller is expected
    /// to select the applicable cipher suites via the supplied provider.
    ///
    /// Because `CryptoProvider` does not implement `PartialEq`, a `Custom`
    /// value never compares equal to any other `CryptoMode` (not even to
    /// itself).
    ///
    /// This variant is provided behind an `aws_sdk_unstable` cfg flag,
    /// because the version of rustls may change in the future.
    #[cfg(all(aws_sdk_unstable, feature = "__rustls"))]
    Custom(CryptoProvider),
}
32
33impl std::cmp::PartialEq for CryptoMode {
34    fn eq(&self, other: &CryptoMode) -> bool {
35        match (self, other) {
36            #[cfg(feature = "rustls-ring")]
37            (Self::Ring, Self::Ring) => true,
38            #[cfg(feature = "rustls-aws-lc")]
39            (Self::AwsLc, Self::AwsLc) => true,
40            #[cfg(feature = "rustls-aws-lc-fips")]
41            (Self::AwsLcFips, Self::AwsLcFips) => true,
42            // `CryptoProvider` does not implement PartialEq, so any
43            // `CryptoMode::Custom` value will always compare not equal to
44            // any other.
45            #[allow(unreachable_patterns)]
46            _ => false,
47        }
48    }
49}
50
51#[cfg(not(all(aws_sdk_unstable, feature = "__rustls")))]
52impl Eq for CryptoMode {}
53
54impl CryptoMode {
55    fn provider(self) -> CryptoProvider {
56        match self {
57            #[cfg(feature = "rustls-aws-lc")]
58            CryptoMode::AwsLc => rustls::crypto::aws_lc_rs::default_provider(),
59
60            #[cfg(feature = "rustls-ring")]
61            CryptoMode::Ring => rustls::crypto::ring::default_provider(),
62
63            #[cfg(feature = "rustls-aws-lc-fips")]
64            CryptoMode::AwsLcFips => {
65                let provider = rustls::crypto::default_fips_provider();
66                assert!(
67                    provider.fips(),
68                    "FIPS was requested but the provider did not support FIPS"
69                );
70                provider
71            }
72            #[cfg(all(aws_sdk_unstable, feature = "__rustls"))]
73            CryptoMode::Custom(provider) => provider,
74        }
75    }
76
77    #[cfg(all(aws_sdk_unstable, feature = "__rustls"))]
78    fn is_custom(&self) -> bool {
79        matches!(self, Self::Custom(_))
80    }
81
82    #[cfg(not(all(aws_sdk_unstable, feature = "__rustls")))]
83    fn is_custom(&self) -> bool {
84        false
85    }
86}
87
88impl Provider {
89    /// Create a TLS provider based on [rustls](https://github.com/rustls/rustls)
90    /// and the given [`CryptoMode`]
91    pub fn rustls(mode: CryptoMode) -> Provider {
92        Provider::Rustls(mode)
93    }
94}
95
96pub(crate) mod build_connector {
97    use crate::client::tls::rustls_provider::CryptoMode;
98    use crate::tls::TlsContext;
99    use client::connect::HttpConnector;
100    use hyper_util::client::legacy as client;
101    use rustls::crypto::CryptoProvider;
102    use rustls_native_certs::CertificateResult;
103    use rustls_pki_types::pem::PemObject;
104    use rustls_pki_types::CertificateDer;
105    use std::sync::Arc;
106    use std::sync::LazyLock;
107
108    /// Cached native certificates
109    ///
110    /// Creating a `with_native_roots()` hyper_rustls client re-loads system certs
111    /// each invocation (which can take 300ms on OSx). Cache the loaded certs
112    /// to avoid repeatedly incurring that cost.
113    pub(crate) static NATIVE_ROOTS: LazyLock<Vec<CertificateDer<'static>>> = LazyLock::new(|| {
114        let CertificateResult { certs, errors, .. } = rustls_native_certs::load_native_certs();
115        if !errors.is_empty() {
116            tracing::warn!("native root CA certificate loading errors: {errors:?}")
117        }
118
119        if certs.is_empty() {
120            tracing::warn!("no native root CA certificates found!");
121        }
122
123        // NOTE: unlike hyper-rustls::with_native_roots we don't validate here, we'll do that later
124        // for now we have a collection of certs that may or may not be valid.
125        certs
126    });
127
128    pub(crate) fn restrict_ciphers(base: CryptoProvider) -> CryptoProvider {
129        let suites = &[
130            rustls::CipherSuite::TLS13_AES_256_GCM_SHA384,
131            rustls::CipherSuite::TLS13_AES_128_GCM_SHA256,
132            // TLS1.2 suites
133            rustls::CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
134            rustls::CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
135            rustls::CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
136            rustls::CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
137            rustls::CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
138        ];
139        let supported_suites = suites
140            .iter()
141            .flat_map(|suite| {
142                base.cipher_suites
143                    .iter()
144                    .find(|s| &s.suite() == suite)
145                    .cloned()
146            })
147            .collect::<Vec<_>>();
148        CryptoProvider {
149            cipher_suites: supported_suites,
150            ..base
151        }
152    }
153
154    impl TlsContext {
155        pub(crate) fn rustls_root_certs(&self) -> rustls::RootCertStore {
156            let mut roots = rustls::RootCertStore::empty();
157            if self.trust_store.enable_native_roots {
158                let (valid, _invalid) = roots.add_parsable_certificates(NATIVE_ROOTS.clone());
159                debug_assert!(valid > 0, "TrustStore configured to enable native roots but no valid root certificates parsed!");
160            }
161
162            for pem_cert in &self.trust_store.custom_certs {
163                let ders = CertificateDer::pem_slice_iter(&pem_cert.0)
164                    .collect::<Result<Vec<_>, _>>()
165                    .expect("valid PEM certificate");
166                for cert in ders {
167                    roots.add(cert).expect("cert parsable")
168                }
169            }
170
171            roots
172        }
173    }
174
175    /// Create a rustls ClientConfig with smithy-rs defaults
176    ///
177    /// This centralizes the rustls ClientConfig creation logic to ensure
178    /// consistency between the main HTTPS connector and tunnel handlers.
179    pub(crate) fn create_rustls_client_config(
180        crypto_mode: CryptoMode,
181        tls_context: &TlsContext,
182    ) -> rustls::ClientConfig {
183        let skip_restrict = crypto_mode.is_custom();
184        let provider = if skip_restrict {
185            crypto_mode.provider()
186        } else {
187            restrict_ciphers(crypto_mode.provider())
188        };
189        let root_certs = tls_context.rustls_root_certs();
190        rustls::ClientConfig::builder_with_provider(Arc::new(provider))
191            .with_safe_default_protocol_versions()
192            .expect("Error with the TLS configuration. Please file a bug report under https://github.com/smithy-lang/smithy-rs/issues.")
193            .with_root_certificates(root_certs)
194            .with_no_client_auth()
195    }
196
197    pub(crate) fn wrap_connector<R>(
198        mut conn: HttpConnector<R>,
199        crypto_mode: CryptoMode,
200        tls_context: &TlsContext,
201        proxy_config: crate::client::proxy::ProxyConfig,
202    ) -> super::connect::RustTlsConnector<R> {
203        let client_config = create_rustls_client_config(crypto_mode, tls_context);
204        conn.enforce_http(false);
205        let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
206            .with_tls_config(client_config.clone())
207            .https_or_http()
208            .enable_http1()
209            .enable_http2()
210            .wrap_connector(conn);
211
212        super::connect::RustTlsConnector::new(https_connector, client_config, proxy_config)
213    }
214}
215
216pub(crate) mod connect {
217    use crate::client::connect::{Conn, Connecting};
218    use crate::client::proxy::ProxyConfig;
219    use aws_smithy_runtime_api::box_error::BoxError;
220    use http_1x::uri::Scheme;
221    use http_1x::Uri;
222    use hyper::rt::{Read, ReadBufCursor, Write};
223    use hyper_rustls::MaybeHttpsStream;
224    use hyper_util::client::legacy::connect::{Connected, Connection, HttpConnector};
225    use hyper_util::client::proxy::matcher::Matcher;
226    use hyper_util::rt::TokioIo;
227    use pin_project_lite::pin_project;
228    use std::error::Error;
229    use std::sync::Arc;
230    use std::{
231        io::{self, IoSlice},
232        pin::Pin,
233        task::{Context, Poll},
234    };
235    use tokio::io::{AsyncRead, AsyncWrite};
236    use tokio::net::TcpStream;
237    use tokio_rustls::client::TlsStream;
238    use tower::Service;
239
240    #[derive(Debug, Clone)]
241    pub(crate) struct RustTlsConnector<R> {
242        https: hyper_rustls::HttpsConnector<HttpConnector<R>>,
243        tls_config: Arc<rustls::ClientConfig>,
244        proxy_matcher: Option<Arc<Matcher>>, // Pre-computed for performance
245    }
246
247    impl<R> RustTlsConnector<R> {
248        pub(super) fn new(
249            https: hyper_rustls::HttpsConnector<HttpConnector<R>>,
250            tls_config: rustls::ClientConfig,
251            proxy_config: ProxyConfig,
252        ) -> Self {
253            // Pre-compute the proxy matcher once during construction
254            let proxy_matcher = if proxy_config.is_disabled() {
255                None
256            } else {
257                Some(Arc::new(proxy_config.into_hyper_util_matcher()))
258            };
259
260            Self {
261                https,
262                tls_config: Arc::new(tls_config),
263                proxy_matcher,
264            }
265        }
266    }
267
268    impl<R> Service<Uri> for RustTlsConnector<R>
269    where
270        R: Clone + Send + Sync + 'static,
271        R: Service<hyper_util::client::legacy::connect::dns::Name>,
272        R::Response: Iterator<Item = std::net::SocketAddr>,
273        R::Future: Send,
274        R::Error: Into<Box<dyn Error + Send + Sync>>,
275    {
276        type Response = Conn;
277        type Error = BoxError;
278        type Future = Connecting;
279
280        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
281            self.https.poll_ready(cx).map_err(Into::into)
282        }
283
284        fn call(&mut self, dst: Uri) -> Self::Future {
285            // Check if this request should be proxied using pre-computed matcher
286            let proxy_intercept = if let Some(ref matcher) = self.proxy_matcher {
287                matcher.intercept(&dst)
288            } else {
289                None
290            };
291
292            if let Some(intercept) = proxy_intercept {
293                if dst.scheme() == Some(&Scheme::HTTPS) {
294                    // HTTPS through HTTP proxy: Use CONNECT tunneling + manual TLS
295                    self.handle_https_through_proxy(dst, intercept)
296                } else {
297                    // HTTP through proxy: Direct connection to proxy
298                    self.handle_http_through_proxy(dst, intercept)
299                }
300            } else {
301                // Direct connection: Use the existing HTTPS connector
302                self.handle_direct_connection(dst)
303            }
304        }
305    }
306
307    impl<R> RustTlsConnector<R>
308    where
309        R: Clone + Send + Sync + 'static,
310        R: Service<hyper_util::client::legacy::connect::dns::Name>,
311        R::Response: Iterator<Item = std::net::SocketAddr>,
312        R::Future: Send,
313        R::Error: Into<Box<dyn Error + Send + Sync>>,
314    {
315        fn handle_direct_connection(&mut self, dst: Uri) -> Connecting {
316            let fut = self.https.call(dst);
317            Box::pin(async move {
318                let conn = fut.await?;
319                Ok(Conn {
320                    inner: Box::new(conn),
321                    is_proxy: false,
322                })
323            })
324        }
325
326        fn handle_http_through_proxy(
327            &mut self,
328            _dst: Uri,
329            intercept: hyper_util::client::proxy::matcher::Intercept,
330        ) -> Connecting {
331            // For HTTP through proxy, connect to the proxy and let it handle the request
332            let proxy_uri = intercept.uri().clone();
333            let fut = self.https.call(proxy_uri);
334            Box::pin(async move {
335                let conn = fut.await?;
336                Ok(Conn {
337                    inner: Box::new(conn),
338                    is_proxy: true,
339                })
340            })
341        }
342
343        fn handle_https_through_proxy(
344            &mut self,
345            dst: Uri,
346            intercept: hyper_util::client::proxy::matcher::Intercept,
347        ) -> Connecting {
348            use rustls_pki_types::ServerName;
349            // For HTTPS through HTTP proxy, we need to:
350            // 1. Establish CONNECT tunnel using the HTTPS connector
351            // 2. Perform manual TLS handshake over the tunneled stream
352
353            let tunnel = hyper_util::client::legacy::connect::proxy::Tunnel::new(
354                intercept.uri().clone(),
355                self.https.clone(),
356            );
357
358            // Configure tunnel with authentication if present
359            let mut tunnel = if let Some(auth) = intercept.basic_auth() {
360                tunnel.with_auth(auth.clone())
361            } else {
362                tunnel
363            };
364
365            let tls_config = self.tls_config.clone();
366            let dst_clone = dst.clone();
367
368            Box::pin(async move {
369                // Establish CONNECT tunnel
370                tracing::trace!("tunneling HTTPS over proxy");
371                let tunneled = tunnel
372                    .call(dst_clone.clone())
373                    .await
374                    .map_err(|e| BoxError::from(format!("CONNECT tunnel failed: {e}")))?;
375
376                // Stage 2: Manual TLS handshake over tunneled stream
377                let host = dst_clone
378                    .host()
379                    .ok_or("missing host in URI for TLS handshake")?;
380
381                let server_name = ServerName::try_from(host.to_owned()).map_err(|e| {
382                    BoxError::from(format!("invalid server name for TLS handshake: {e}"))
383                })?;
384
385                let tls_connector = tokio_rustls::TlsConnector::from(tls_config)
386                    .connect(server_name, TokioIo::new(tunneled))
387                    .await?;
388
389                Ok(Conn {
390                    inner: Box::new(RustTlsConn {
391                        inner: TokioIo::new(tls_connector),
392                    }),
393                    is_proxy: true,
394                })
395            })
396        }
397    }
398
399    pin_project! {
400        pub(crate) struct RustTlsConn<T> {
401            #[pin] pub(super) inner: TokioIo<TlsStream<T>>
402        }
403    }
404
405    impl Connection for RustTlsConn<TokioIo<TokioIo<TcpStream>>> {
406        fn connected(&self) -> Connected {
407            if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
408                self.inner
409                    .inner()
410                    .get_ref()
411                    .0
412                    .inner()
413                    .connected()
414                    .negotiated_h2()
415            } else {
416                self.inner.inner().get_ref().0.inner().connected()
417            }
418        }
419    }
420
421    impl Connection for RustTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
422        fn connected(&self) -> Connected {
423            if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
424                self.inner
425                    .inner()
426                    .get_ref()
427                    .0
428                    .inner()
429                    .connected()
430                    .negotiated_h2()
431            } else {
432                self.inner.inner().get_ref().0.inner().connected()
433            }
434        }
435    }
436    impl<T: AsyncRead + AsyncWrite + Unpin> Read for RustTlsConn<T> {
437        fn poll_read(
438            self: Pin<&mut Self>,
439            cx: &mut Context<'_>,
440            buf: ReadBufCursor<'_>,
441        ) -> Poll<tokio::io::Result<()>> {
442            let this = self.project();
443            Read::poll_read(this.inner, cx, buf)
444        }
445    }
446
447    impl<T: AsyncRead + AsyncWrite + Unpin> Write for RustTlsConn<T> {
448        fn poll_write(
449            self: Pin<&mut Self>,
450            cx: &mut Context<'_>,
451            buf: &[u8],
452        ) -> Poll<Result<usize, tokio::io::Error>> {
453            let this = self.project();
454            Write::poll_write(this.inner, cx, buf)
455        }
456
457        fn poll_write_vectored(
458            self: Pin<&mut Self>,
459            cx: &mut Context<'_>,
460            bufs: &[IoSlice<'_>],
461        ) -> Poll<Result<usize, io::Error>> {
462            let this = self.project();
463            Write::poll_write_vectored(this.inner, cx, bufs)
464        }
465
466        fn is_write_vectored(&self) -> bool {
467            self.inner.is_write_vectored()
468        }
469
470        fn poll_flush(
471            self: Pin<&mut Self>,
472            cx: &mut Context<'_>,
473        ) -> Poll<Result<(), tokio::io::Error>> {
474            let this = self.project();
475            Write::poll_flush(this.inner, cx)
476        }
477
478        fn poll_shutdown(
479            self: Pin<&mut Self>,
480            cx: &mut Context<'_>,
481        ) -> Poll<Result<(), tokio::io::Error>> {
482            let this = self.project();
483            Write::poll_shutdown(this.inner, cx)
484        }
485    }
486}