Skip to main content

libdd_common/connector/
mod.rs

1// Copyright 2021-Present Datadog, Inc. https://www.datadoghq.com/
2// SPDX-License-Identifier: Apache-2.0
3
4use futures::future::BoxFuture;
5use futures::{future, FutureExt};
6use hyper_util::client::legacy::connect;
7
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::LazyLock;
11use std::task::{Context, Poll};
12
13#[cfg(unix)]
14pub mod uds;
15
16pub mod named_pipe;
17
18pub mod errors;
19
20mod conn_stream;
21use conn_stream::{ConnStream, ConnStreamError};
22
23#[derive(Clone)]
24pub enum Connector {
25    Http(connect::HttpConnector),
26    #[cfg(feature = "https")]
27    Https(hyper_rustls::HttpsConnector<connect::HttpConnector>),
28}
29
30static DEFAULT_CONNECTOR: LazyLock<Connector> = LazyLock::new(Connector::new);
31
32impl Default for Connector {
33    fn default() -> Self {
34        DEFAULT_CONNECTOR.clone()
35    }
36}
37
38impl Connector {
39    /// Make sure this function is not called frequently. Fetching the root certificates is an
40    /// expensive operation. Access the globally cached connector via Connector::default().
41    fn new() -> Self {
42        #[cfg(feature = "https")]
43        {
44            #[cfg(feature = "use_webpki_roots")]
45            let https_connector_fn = https::build_https_connector_with_webpki_roots;
46            #[cfg(not(feature = "use_webpki_roots"))]
47            let https_connector_fn = https::build_https_connector;
48
49            match https_connector_fn() {
50                Ok(connector) => Connector::Https(connector),
51                Err(_) => Connector::Http(connect::HttpConnector::new()),
52            }
53        }
54        #[cfg(not(feature = "https"))]
55        {
56            Connector::Http(connect::HttpConnector::new())
57        }
58    }
59
60    fn build_conn_stream(
61        &mut self,
62        uri: hyper::Uri,
63        require_tls: bool,
64    ) -> BoxFuture<'static, Result<ConnStream, ConnStreamError>> {
65        match self {
66            Self::Http(c) => {
67                if require_tls {
68                    future::err::<ConnStream, ConnStreamError>(
69                        errors::Error::CannotEstablishTlsConnection.into(),
70                    )
71                    .boxed()
72                } else {
73                    ConnStream::from_http_connector_with_uri(c, uri).boxed()
74                }
75            }
76            #[cfg(feature = "https")]
77            Self::Https(c) => {
78                ConnStream::from_https_connector_with_uri(c, uri, require_tls).boxed()
79            }
80        }
81    }
82}
83
#[cfg(feature = "https")]
mod https {
    //! Construction of the rustls-backed HTTPS connector used by `Connector::new`.

    #[cfg(feature = "use_webpki_roots")]
    use hyper_rustls::ConfigBuilderExt;

    use rustls::ClientConfig;

    /// Installs a process-wide default rustls `CryptoProvider`, at most once per process.
    ///
    /// - On unix the provider is aws-lc-rs; on other platforms it is ring.
    /// - The result of `install_default` is discarded, so a provider that another component
    ///   already installed is not treated as an error.
    ///
    /// This variant is compiled when the `fips` feature is off, or under `coverage`.
    #[cfg(any(not(feature = "fips"), coverage))]
    fn ensure_crypto_provider_initialized() {
        use std::sync::Once;

        static INIT_CRYPTO_PROVIDER: Once = Once::new();

        INIT_CRYPTO_PROVIDER.call_once(|| {
            #[cfg(unix)]
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
            #[cfg(not(unix))]
            let _ = rustls::crypto::ring::default_provider().install_default();
        });
    }

    /// FIPS builds: intentionally a no-op.
    ///
    /// The application is expected to install the `CryptoProvider` itself (for example in its
    /// own `main`) before any connector is built. The fips configuration is only expected on
    /// unix platforms.
    #[cfg(all(feature = "fips", not(coverage)))]
    fn ensure_crypto_provider_initialized() {}

    /// Wraps `client_config` in a hyper-rustls connector configured with `https_or_http` and
    /// HTTP/1 enabled. Shared by both root-certificate strategies below so they cannot drift.
    fn https_connector_from_config(
        client_config: ClientConfig,
    ) -> hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector> {
        hyper_rustls::HttpsConnectorBuilder::new()
            .with_tls_config(client_config)
            .https_or_http()
            .enable_http1()
            .build()
    }

    /// Builds the HTTPS connector trusting the `webpki-roots` certificate set (through
    /// `ConfigBuilderExt::with_webpki_roots`), independent of the system certificate store.
    ///
    /// Currently never returns `Err`; the `Result` keeps the signature interchangeable with
    /// `build_https_connector`.
    #[cfg(feature = "use_webpki_roots")]
    pub(super) fn build_https_connector_with_webpki_roots() -> anyhow::Result<
        hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
    > {
        ensure_crypto_provider_initialized(); // One-time initialization of a crypto provider if needed

        let client_config = ClientConfig::builder()
            .with_webpki_roots()
            .with_no_client_auth();
        Ok(https_connector_from_config(client_config))
    }

    /// Builds the HTTPS connector trusting the platform's native root certificates.
    ///
    /// # Errors
    /// Propagates the failure from `load_root_certs` when no usable root certificate exists.
    #[cfg(not(feature = "use_webpki_roots"))]
    pub(super) fn build_https_connector() -> anyhow::Result<
        hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
    > {
        ensure_crypto_provider_initialized(); // One-time initialization of a crypto provider if needed

        let certs = load_root_certs()?;
        let client_config = ClientConfig::builder()
            .with_root_certificates(certs)
            .with_no_client_auth();
        Ok(https_connector_from_config(client_config))
    }

    /// Loads the platform's native root certificates into a `RootCertStore`.
    ///
    /// Loader errors are ignored as long as at least one certificate was found. Certificates
    /// that cannot be parsed as trust anchors are skipped.
    ///
    /// # Errors
    /// - The first loader error, when no certificate could be loaded at all.
    /// - `errors::Error::NoValidCertifacteRootsFound`, when nothing ended up in the store.
    #[cfg(not(feature = "use_webpki_roots"))]
    fn load_root_certs() -> anyhow::Result<rustls::RootCertStore> {
        use super::errors;

        let cert_result = rustls_native_certs::load_native_certs();
        if cert_result.certs.is_empty() {
            if let Some(err) = cert_result.errors.into_iter().next() {
                return Err(err.into());
            }
        }
        // TODO(paullgdfc): log errors even if there are valid certs, instead of ignoring them

        let mut roots = rustls::RootCertStore::empty();
        // Unparsable certificates are skipped (counted in `_skipped`) rather than failing the
        // whole load. TODO: log when invalid certs are skipped.
        let (_added, _skipped) = roots.add_parsable_certificates(cert_result.certs);
        if roots.is_empty() {
            return Err(errors::Error::NoValidCertifacteRootsFound.into());
        }
        Ok(roots)
    }
}
172
173impl tower_service::Service<hyper::Uri> for Connector {
174    type Response = ConnStream;
175    type Error = ConnStreamError;
176
177    // This lint gets lifted in this place in a newer version, see:
178    // https://github.com/rust-lang/rust-clippy/pull/8030
179    #[allow(clippy::type_complexity)]
180    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
181
182    fn call(&mut self, uri: hyper::Uri) -> Self::Future {
183        match uri.scheme_str() {
184            Some("unix") => conn_stream::ConnStream::from_uds_uri(uri).boxed(),
185            Some("windows") => conn_stream::ConnStream::from_named_pipe_uri(uri).boxed(),
186            Some("https") => self.build_conn_stream(uri, true),
187            _ => self.build_conn_stream(uri, false),
188        }
189    }
190
191    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192        match self {
193            Connector::Http(c) => c.poll_ready(cx).map_err(|e| e.into()),
194            #[cfg(feature = "https")]
195            Connector::Https(c) => c.poll_ready(cx),
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use crate::http_common;
    use std::env;
    use tower_service::Service;

    use super::*;

    /// Overrides an environment variable until the guard is dropped, then restores the
    /// previous state: the old value if the variable was set, otherwise the variable is
    /// removed again.
    ///
    /// Restoring from `Drop` also runs when an assertion panics. Removing instead of writing
    /// `""` matters because an empty string is still a *set* variable, which is not the same
    /// state as unset.
    ///
    /// NOTE: environment variables are process-wide, so tests running in parallel can still
    /// observe the override while the guard is alive.
    ///
    /// Compiled only for the feature sets where one of the tests below uses it.
    #[cfg(any(not(feature = "use_webpki_roots"), feature = "https"))]
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    #[cfg(any(not(feature = "use_webpki_roots"), feature = "https"))]
    impl EnvVarGuard {
        /// Records the current state of `key`, then sets it to `value`.
        fn set(key: &'static str, value: &str) -> Self {
            let previous = env::var_os(key);
            env::set_var(key, value);
            Self { key, previous }
        }
    }

    #[cfg(any(not(feature = "use_webpki_roots"), feature = "https"))]
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => env::set_var(self.key, value),
                None => env::remove_var(self.key),
            }
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(not(feature = "use_webpki_roots"))]
    /// Verify that the Connector type implements the correct bound Connect + Clone
    /// to be able to use the hyper::Client
    fn test_hyper_client_from_connector() {
        let _ = http_common::new_default_client();
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "use_webpki_roots")]
    fn test_hyper_client_from_connector_with_webpki_roots() {
        let _ = http_common::new_default_client();
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    #[cfg(not(feature = "use_webpki_roots"))]
    /// Verify that Connector will only allow non tls connections if root certificates
    /// are not found
    async fn test_missing_root_certificates_only_allow_http_connections() {
        const ENV_SSL_CERT_FILE: &str = "SSL_CERT_FILE";
        const ENV_SSL_CERT_DIR: &str = "SSL_CERT_DIR";
        // Point both certificate locations at a missing path so no native roots load. The
        // guards must be bound to named variables (not `_`) so they live to the end of the test.
        let _cert_file = EnvVarGuard::set(ENV_SSL_CERT_FILE, "this/folder/does/not/exist");
        let _cert_dir = EnvVarGuard::set(ENV_SSL_CERT_DIR, "this/folder/does/not/exist");

        let mut connector = Connector::new();
        assert!(matches!(connector, Connector::Http(_)));

        let err = connector
            .call(hyper::Uri::from_static("https://example.com"))
            .await
            .unwrap_err();

        assert_eq!(
            *err.downcast::<errors::Error>().unwrap(),
            errors::Error::CannotEstablishTlsConnection
        );
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    #[cfg(feature = "use_webpki_roots")]
    #[cfg(feature = "https")]
    /// Verify that Connector will allow tls connections if root certificates
    /// are not found but can use webpki certificates
    async fn test_missing_root_certificates_use_webpki_certificates() {
        const ENV_SSL_CERT_FILE: &str = "SSL_CERT_FILE";
        let _cert_file = EnvVarGuard::set(ENV_SSL_CERT_FILE, "this/folder/does/not/exist");

        let mut connector = Connector::new();
        assert!(matches!(connector, Connector::Https(_)));

        let stream = connector
            .call(hyper::Uri::from_static("https://example.com"))
            .await;

        assert!(stream.is_ok());
    }
}