conjure_runtime/raw/
default.rs

1// Copyright 2020 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::raw::{BuildRawClient, RawBody, Service};
15use crate::service::proxy::connector::ProxyConnectorLayer;
16use crate::service::proxy::{ProxyConfig, ProxyConnectorService};
17use crate::service::timeout::{TimeoutLayer, TimeoutService};
18use crate::service::tls_metrics::{TlsMetricsLayer, TlsMetricsService};
19use crate::{builder, Builder};
20use bytes::Bytes;
21use conjure_error::Error;
22use http::{Request, Response};
23use http_body::{Body, Frame, SizeHint};
24use hyper::body::Incoming;
25use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
26use hyper_util::client::legacy::connect::HttpConnector;
27use hyper_util::client::legacy::Client;
28use hyper_util::rt::{TokioExecutor, TokioTimer};
29use pin_project::pin_project;
30use rustls::crypto::ring;
31use rustls::pki_types::{CertificateDer, PrivateKeyDer};
32use rustls::{ClientConfig, RootCertStore};
33use rustls_pemfile::Item;
34use std::error;
35use std::fmt;
36use std::fs::File;
37use std::io::BufReader;
38use std::marker::PhantomPinned;
39use std::path::Path;
40use std::pin::Pin;
41use std::sync::Arc;
42use std::task::{Context, Poll};
43use std::time::Duration;
44use tower_layer::Layer;
45use webpki_roots::TLS_SERVER_ROOTS;
46
/// Duration handed to `HttpConnector::set_keepalive` for every outgoing socket.
///
/// Three minutes is a fairly arbitrary choice taken from a Cloudflare blog post rather than a
/// value tuned specifically for this client.
const TCP_KEEPALIVE: Duration = Duration::from_secs(180);

/// Idle timeout for pooled HTTP connections.
///
/// Most servers close idle connections after 60 seconds, so the client gives them up slightly
/// earlier to avoid reusing a connection the server is about to close.
const HTTP_KEEPALIVE: Duration = Duration::from_secs(55);
51
52type ConjureConnector =
53    TlsMetricsService<HttpsConnector<ProxyConnectorService<TimeoutService<HttpConnector>>>>;
54
/// The default raw client builder used by `conjure_runtime`.
///
/// This is a stateless, zero-sized marker. All configuration is read from the `Builder` passed to
/// `BuildRawClient::build_raw_client`. It derives `Debug` and `Default` so it can be printed and
/// constructed generically, like other public types.
#[derive(Debug, Default, Copy, Clone)]
pub struct DefaultRawClientBuilder;
58
59impl BuildRawClient for DefaultRawClientBuilder {
60    type RawClient = DefaultRawClient;
61
62    fn build_raw_client(
63        &self,
64        builder: &Builder<builder::Complete<Self>>,
65    ) -> Result<Self::RawClient, Error> {
66        let mut connector = HttpConnector::new();
67        connector.enforce_http(false);
68        connector.set_nodelay(true);
69        connector.set_keepalive(Some(TCP_KEEPALIVE));
70        connector.set_connect_timeout(Some(builder.get_connect_timeout()));
71
72        let proxy = ProxyConfig::from_config(builder.get_proxy())?;
73
74        let connector = TimeoutLayer::new(builder).layer(connector);
75        let connector = ProxyConnectorLayer::new(&proxy).layer(connector);
76
77        let mut roots = RootCertStore::empty();
78        roots.extend(TLS_SERVER_ROOTS.iter().cloned());
79
80        if let Some(ca_file) = builder.get_security().ca_file() {
81            let certs = load_certs_file(ca_file)?;
82            roots.add_parsable_certificates(certs);
83        }
84        let client_config = ClientConfig::builder_with_provider(Arc::new(ring::default_provider()))
85            .with_safe_default_protocol_versions()
86            .map_err(Error::internal_safe)?
87            .with_root_certificates(roots);
88
89        let client_config = match (
90            builder.get_security().cert_file(),
91            builder.get_security().key_file(),
92        ) {
93            (Some(cert_file), Some(key_file)) => {
94                let cert_chain = load_certs_file(cert_file)?;
95                let private_key = load_private_key(key_file)?;
96
97                client_config
98                    .with_client_auth_cert(cert_chain, private_key)
99                    .map_err(Error::internal_safe)?
100            }
101            (None, None) => client_config.with_no_client_auth(),
102            _ => {
103                return Err(Error::internal_safe(
104                    "neither or both of key-file and cert-file must be set in the client \
105                    security config",
106                ));
107            }
108        };
109
110        let connector = HttpsConnectorBuilder::new()
111            .with_tls_config(client_config)
112            .https_or_http()
113            .enable_all_versions()
114            .wrap_connector(connector);
115        let connector = TlsMetricsLayer::new(builder).layer(connector);
116
117        let client = Client::builder(TokioExecutor::new())
118            .pool_idle_timeout(HTTP_KEEPALIVE)
119            .pool_timer(TokioTimer::new())
120            .timer(TokioTimer::new())
121            .build(connector);
122
123        Ok(DefaultRawClient(client))
124    }
125}
126
127fn load_certs_file(path: &Path) -> Result<Vec<CertificateDer<'static>>, Error> {
128    let file = File::open(path).map_err(Error::internal_safe)?;
129    let mut file = BufReader::new(file);
130    rustls_pemfile::certs(&mut file)
131        .collect::<Result<Vec<_>, _>>()
132        .map_err(Error::internal_safe)
133}
134
135fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, Error> {
136    let file = File::open(path).map_err(Error::internal_safe)?;
137    let mut reader = BufReader::new(file);
138
139    let mut items = rustls_pemfile::read_all(&mut reader)
140        .collect::<Result<Vec<_>, _>>()
141        .map_err(Error::internal_safe)?;
142
143    if items.len() != 1 {
144        return Err(Error::internal_safe(
145            "expected exactly one private key in key file",
146        ));
147    }
148
149    match items.pop().unwrap() {
150        Item::Pkcs1Key(key) => Ok(key.into()),
151        Item::Pkcs8Key(key) => Ok(key.into()),
152        Item::Sec1Key(key) => Ok(key.into()),
153        _ => Err(Error::internal_safe(
154            "expected a PKCS#1, PKCS#8, or Sec1 private key",
155        )),
156    }
157}
158
159/// The default raw client implementation used by `conjure_runtime`.
160///
161/// This is currently implemented with `hyper` and `rustls`, but that is subject to change at any time.
162pub struct DefaultRawClient(Client<ConjureConnector, RawBody>);
163
164impl Service<Request<RawBody>> for DefaultRawClient {
165    type Response = Response<DefaultRawBody>;
166    type Error = DefaultRawError;
167
168    async fn call(&self, req: Request<RawBody>) -> Result<Self::Response, Self::Error> {
169        self.0
170            .request(req)
171            .await
172            .map(|r| {
173                r.map(|inner| DefaultRawBody {
174                    inner,
175                    _p: PhantomPinned,
176                })
177            })
178            .map_err(DefaultRawError::new)
179    }
180}
181
182/// The body type used by `DefaultRawClient`.
183#[pin_project]
184pub struct DefaultRawBody {
185    #[pin]
186    inner: Incoming,
187    #[pin]
188    _p: PhantomPinned,
189}
190
191impl Body for DefaultRawBody {
192    type Data = Bytes;
193    type Error = DefaultRawError;
194
195    fn poll_frame(
196        self: Pin<&mut Self>,
197        cx: &mut Context<'_>,
198    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
199        self.project()
200            .inner
201            .poll_frame(cx)
202            .map(|o| o.map(|r| r.map_err(DefaultRawError::new)))
203    }
204
205    fn is_end_stream(&self) -> bool {
206        self.inner.is_end_stream()
207    }
208
209    fn size_hint(&self) -> SizeHint {
210        self.inner.size_hint()
211    }
212}
213
/// The opaque error type produced by `DefaultRawClient` and `DefaultRawBody`.
///
/// This is a transparent wrapper. Its `Display` output is the wrapped error's, and its `source`
/// is the wrapped error's own source.
#[derive(Debug)]
pub struct DefaultRawError(Box<dyn error::Error + Sync + Send>);

impl DefaultRawError {
    /// Boxes any error-like value into a `DefaultRawError`.
    fn new<E>(cause: E) -> Self
    where
        E: Into<Box<dyn error::Error + Sync + Send>>,
    {
        let boxed: Box<dyn error::Error + Sync + Send> = cause.into();
        Self(boxed)
    }
}

impl fmt::Display for DefaultRawError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pass the caller's formatter through so any width or flags reach the wrapped error.
        let inner: &(dyn error::Error + Sync + Send) = &*self.0;
        fmt::Display::fmt(inner, f)
    }
}

impl error::Error for DefaultRawError {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        // Skip the wrapped error itself: its message is already this error's `Display` output,
        // so report its underlying cause instead.
        let inner: &(dyn error::Error + Sync + Send) = &*self.0;
        error::Error::source(inner)
    }
}