#![allow(clippy::type_complexity)]
use super::body::BodyStreamExt;
pub use hyper::{body, service::Service, Body, Request, Response, Uri};
use std::{sync::Arc, time::Duration};
use tokio::sync::Mutex;
use tower::{util::BoxCloneService, Layer, ServiceExt};
#[cfg(feature = "tower-trace")]
use opentelemetry::global;
#[cfg(feature = "tower-trace")]
use opentelemetry_http::HeaderInjector;
#[cfg(all(feature = "tower-client-rls", not(feature = "tower-client-tls")))]
use rustls::{
client::{ServerCertVerified, ServerCertVerifier},
Certificate, Error as TLSError,
};
use tower_http::map_response_body::MapResponseBodyLayer;
#[cfg(feature = "tower-trace")]
use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
#[cfg(feature = "tower-trace")]
use tracing::Span;
#[cfg(feature = "tower-trace")]
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Tower Service Error.
///
/// Boxed, `Send + Sync` error trait object; it is the error type of the type-erased
/// client service stored in `Configuration::client_service`.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
/// `ConfigurationBuilder` that can be used to build a `Configuration`.
///
/// Start from `ConfigurationBuilder::new()` (or `Default`), adjust the settings with the
/// `with_*` methods and finish with one of the `build*` methods.
#[derive(Clone)]
pub struct ConfigurationBuilder {
/// Timeout for each HTTP Request.
/// `None` means that no timeout has been requested on the builder.
timeout: Option<std::time::Duration>,
/// Bearer Access Token for bearer-configured routes.
bearer_token: Option<String>,
/// OpenTel and Tracing layer.
/// Only present when the `tower-trace` feature is enabled.
#[cfg(feature = "tower-trace")]
tracing_layer: bool,
/// PEM-format certificate bytes set through `with_certificate`, if any.
certificate: Option<Vec<u8>>,
/// Request concurrency limit set through `with_concurrency_limit`, if any.
concurrency_limit: Option<usize>,
}
impl Default for ConfigurationBuilder {
fn default() -> Self {
Self {
timeout: Some(std::time::Duration::from_secs(5)),
bearer_token: None,
#[cfg(feature = "tower-trace")]
tracing_layer: true,
certificate: None,
concurrency_limit: None,
}
}
}
impl ConfigurationBuilder {
/// Return a new `Self`.
pub fn new() -> Self {
Self::default()
}
/// Enable/Disable a request timeout layer with the given request timeout.
pub fn with_timeout<O: Into<Option<Duration>>>(mut self, timeout: O) -> Self {
self.timeout = timeout.into();
self
}
/// Enable/Disable the given request bearer token.
pub fn with_bearer_token(mut self, bearer_token: Option<String>) -> Self {
self.bearer_token = bearer_token;
self
}
/// Add a request concurrency limit.
pub fn with_concurrency_limit(mut self, limit: Option<usize>) -> Self {
self.concurrency_limit = limit;
self
}
/// Add a PEM-format certificate file.
pub fn with_certificate(mut self, certificate: &[u8]) -> Self {
self.certificate = Some(certificate.to_vec());
self
}
/// Enable/Disable the telemetry and tracing layer.
#[cfg(feature = "tower-trace")]
pub fn with_tracing(mut self, tracing_layer: bool) -> Self {
self.tracing_layer = tracing_layer;
self
}
/// Build a `Configuration` from the Self parameters.
pub fn build(self, uri: hyper::Uri) -> Result<Configuration, Error> {
Configuration::new(
uri.to_string().parse().map_err(Error::UriToUrl)?,
self.timeout.unwrap(),
self.bearer_token,
self.certificate.as_ref().map(|c| &c[..]),
self.tracing_layer,
self.concurrency_limit,
)
}
/// Build a `Configuration` from the Self parameters.
pub fn build_url(self, url: url::Url) -> Result<Configuration, Error> {
Configuration::new(
url,
self.timeout.unwrap_or_else(|| Duration::from_secs(5)),
self.bearer_token,
self.certificate.as_ref().map(|c| &c[..]),
self.tracing_layer,
self.concurrency_limit,
)
}
/// Build a `Configuration` from the Self parameters.
pub fn build_with_svc<S>(
self,
uri: hyper::Uri,
client_service: S,
) -> Result<Configuration, Error>
where
S: Service<Request<Body>, Response = Response<Body>> + Sync + Send + Clone + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxedError> + std::fmt::Debug,
{
#[cfg(feature = "tower-trace")]
let tracing_layer = self.tracing_layer;
#[cfg(not(feature = "tower-trace"))]
let tracing_layer = false;
Configuration::new_with_client(
uri,
client_service,
self.timeout,
self.bearer_token,
tracing_layer,
self.concurrency_limit,
)
}
}
/// Configuration used by the `ApiClient`.
///
/// Bundles the base URI, the type-erased client service and the optional credentials.
#[derive(Clone)]
pub struct Configuration {
/// Base URI of the API. `new_with_client` sets this to the given URI with `/v0` appended.
pub base_path: hyper::Uri,
/// Optional user agent; the constructors in this file initialise it to `None`.
pub user_agent: Option<String>,
/// Type-erased client service (hyper `Body` in and out, `BoxedError` errors), shared
/// between clones of the configuration through `Arc<Mutex<..>>`.
pub client_service: Arc<Mutex<BoxCloneService<Request<Body>, Response<Body>, BoxedError>>>,
/// Optional basic authentication credentials; the constructors in this file leave it `None`.
pub basic_auth: Option<BasicAuth>,
/// Optional OAuth access token; the constructors in this file leave it `None`.
pub oauth_access_token: Option<String>,
/// Optional bearer access token, as passed to the constructors.
pub bearer_access_token: Option<String>,
/// Optional API key; the constructors in this file leave it `None`.
pub api_key: Option<ApiKey>,
}
/// Basic authentication.
///
/// NOTE(review): presumably `(username, optional password)`; the code consuming this tuple
/// is not part of this file, so confirm against it.
pub type BasicAuth = (String, Option<String>);
/// ApiKey used for ApiKey authentication.
#[derive(Debug, Clone)]
pub struct ApiKey {
/// Optional prefix for the key.
/// NOTE(review): presumably prepended to `key` when the credential is sent; the consumer
/// is not visible in this file.
pub prefix: Option<String>,
/// The API key itself.
pub key: String,
}
/// Configuration creation Error.
#[derive(Debug)]
pub enum Error {
/// The provided certificate bytes could not be parsed as PEM.
Certificate,
/// The native-tls `TlsConnector` could not be built.
TlsConnector,
/// Request tracing was requested but the `tower-trace` feature is not enabled.
NoTracingFeature,
/// Converting a `url::Url` into a `hyper::Uri` failed.
UrlToUri(hyper::http::uri::InvalidUri),
/// Converting a `hyper::Uri` into a `url::Url` failed.
UriToUrl(url::ParseError),
/// Appending the `/v0` version path to the base URI produced an invalid URI.
AddingVersionPath(hyper::http::uri::InvalidUri),
}
impl Configuration {
/// Return a new `ConfigurationBuilder`.
pub fn builder() -> ConfigurationBuilder {
ConfigurationBuilder::new()
}
/// New `Self` with a provided client.
pub fn new_with_client<S>(
mut url: hyper::Uri,
client_service: S,
timeout: Option<std::time::Duration>,
bearer_access_token: Option<String>,
trace_requests: bool,
concurrency_limit: Option<usize>,
) -> Result<Self, Error>
where
S: Service<Request<Body>, Response = Response<Body>> + Sync + Send + Clone + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxedError> + std::fmt::Debug,
{
#[cfg(feature = "tower-trace")]
let tracing_layer = tower::ServiceBuilder::new()
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<Body>| {
tracing::info_span!(
"HTTP",
http.method = %request.method(),
http.url = %request.uri(),
http.status_code = tracing::field::Empty,
otel.name = %format!("{} {}", request.method(), request.uri()),
otel.kind = "client",
otel.status_code = tracing::field::Empty,
)
})
// to silence the default trace
.on_request(|request: &Request<Body>, _span: &Span| {
tracing::trace!("started {} {}", request.method(), request.uri().path())
})
.on_response(
|response: &Response<Body>, _latency: std::time::Duration, span: &Span| {
let status = response.status();
span.record("http.status_code", status.as_u16());
if status.is_client_error() || status.is_server_error() {
span.record("otel.status_code", "ERROR");
}
},
)
.on_body_chunk(())
.on_failure(
|ec: ServerErrorsFailureClass,
_latency: std::time::Duration,
span: &Span| {
span.record("otel.status_code", "ERROR");
match ec {
ServerErrorsFailureClass::StatusCode(status) => {
span.record("http.status_code", status.as_u16());
tracing::debug!(status=%status, "failed to issue request")
}
ServerErrorsFailureClass::Error(err) => {
tracing::debug!(error=%err, "failed to issue request")
}
}
},
),
)
// injects the telemetry context in the http headers
.layer(OpenTelContext::new())
.into_inner();
url = format!("{}/v0", url.to_string().trim_end_matches('/'))
.parse()
.map_err(Error::AddingVersionPath)?;
let backend_service = tower::ServiceBuilder::new()
.option_layer(timeout.map(tower::timeout::TimeoutLayer::new))
// .option_layer(
// bearer_access_token.map(|b| tower_http::auth::AddAuthorizationLayer::bearer(&b)),
// )
.service(client_service);
let service_builder = tower::ServiceBuilder::new()
.option_layer(concurrency_limit.map(tower::limit::ConcurrencyLimitLayer::new));
match trace_requests {
false => Ok(Self::new_with_client_inner(
url,
service_builder.service(backend_service),
bearer_access_token,
)),
true => {
#[cfg(feature = "tower-trace")]
let result = Ok(Self::new_with_client_inner(
url,
service_builder
.layer(tracing_layer)
.service(backend_service),
bearer_access_token,
));
#[cfg(not(feature = "tower-trace"))]
let result = Err(Error::NoTracingFeature {});
result
}
}
}
/// New `Self`.
pub fn new(
mut url: url::Url,
timeout: std::time::Duration,
bearer_access_token: Option<String>,
certificate: Option<&[u8]>,
trace_requests: bool,
concurrency_limit: Option<usize>,
) -> Result<Self, Error> {
#[cfg(all(not(feature = "tower-client-tls"), feature = "tower-client-rls"))]
let client = {
match certificate {
None => {
let mut http = hyper::client::HttpConnector::new();
let tls = match url.scheme() == "https" {
true => {
http.enforce_http(false);
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(std::sync::Arc::new(
DisableServerCertVerifier {},
))
.with_no_client_auth()
}
false => rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(rustls::RootCertStore::empty())
.with_no_client_auth(),
};
let connector =
hyper_rustls::HttpsConnector::from((http, std::sync::Arc::new(tls)));
hyper::Client::builder().build(connector)
}
Some(bytes) => {
let mut cert_file = std::io::BufReader::new(bytes);
let mut root_store = rustls::RootCertStore::empty();
root_store.add_parsable_certificates(
&rustls_pemfile::certs(&mut cert_file).map_err(|_| Error::Certificate)?,
);
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let mut http = hyper::client::HttpConnector::new();
http.enforce_http(false);
let connector =
hyper_rustls::HttpsConnector::from((http, std::sync::Arc::new(config)));
url.set_scheme("https").ok();
hyper::Client::builder().build(connector)
}
}
};
#[cfg(feature = "tower-client-tls")]
let client = {
match certificate {
None => {
let mut http = hyper_tls::HttpsConnector::new();
if url.scheme() == "https" {
http.https_only(true);
}
let tls = hyper_tls::native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(true)
.build()
.map_err(|_| Error::TlsConnector)?;
let tls = tokio_native_tls::TlsConnector::from(tls);
let connector = hyper_tls::HttpsConnector::from((http, tls));
hyper::Client::builder().build(connector)
}
Some(bytes) => {
let certificate = hyper_tls::native_tls::Certificate::from_pem(bytes)
.map_err(|_| Error::Certificate)?;
let tls = hyper_tls::native_tls::TlsConnector::builder()
.add_root_certificate(certificate)
.danger_accept_invalid_hostnames(true)
.disable_built_in_roots(true)
.build()
.map_err(|_| Error::TlsConnector)?;
let tls = tokio_native_tls::TlsConnector::from(tls);
let mut http = hyper_tls::HttpsConnector::new();
http.https_only(true);
let connector = hyper_tls::HttpsConnector::from((http, tls));
url.set_scheme("https").ok();
hyper::Client::builder().build(connector)
}
}
};
let uri = url.to_string().parse().map_err(Error::UrlToUri)?;
Self::new_with_client(
uri,
client,
Some(timeout),
bearer_access_token,
trace_requests,
concurrency_limit,
)
}
/// New `Self` with a provided client.
pub fn new_with_client_inner<S, B>(
url: hyper::Uri,
client_service: S,
bearer_access_token: Option<String>,
) -> Self
where
S: Service<Request<Body>, Response = Response<B>> + Sync + Send + Clone + 'static,
S::Future: Send + 'static,
S::Error: Into<BoxedError> + std::fmt::Debug,
B: http_body::Body<Data = hyper::body::Bytes> + Send + 'static,
B::Error: std::error::Error + Send + Sync + 'static,
{
// Transform response body to `hyper::Body` and use type erased error to avoid type
// parameters.
let client_service = MapResponseBodyLayer::new(|b: B| Body::wrap_stream(b.into_stream()))
.layer(client_service)
.map_err(|e| e.into());
let client_service = Arc::new(Mutex::new(BoxCloneService::new(client_service)));
Self {
base_path: url,
user_agent: None,
client_service,
basic_auth: None,
oauth_access_token: None,
bearer_access_token,
api_key: None,
}
}
}
/// Layer that adds the OpenTelemetry context of the current span to the Http Headers.
///
/// It wraps services in an `OpenTelContextService`.
#[cfg(feature = "tower-trace")]
pub struct OpenTelContext {}
#[cfg(feature = "tower-trace")]
impl OpenTelContext {
    /// Create the (stateless) layer.
    fn new() -> Self {
        OpenTelContext {}
    }
}
#[cfg(feature = "tower-trace")]
impl<S> Layer<S> for OpenTelContext {
    type Service = OpenTelContextService<S>;
    /// Wrap `inner` so that the telemetry context is injected into each request it receives.
    fn layer(&self, inner: S) -> OpenTelContextService<S> {
        OpenTelContextService::new(inner)
    }
}
/// Service wrapper that injects the context of the current span into the Http Headers of
/// every request before handing it to the wrapped service.
#[cfg(feature = "tower-trace")]
#[derive(Clone)]
pub struct OpenTelContextService<S> {
    /// The wrapped service which receives the request after injection.
    service: S,
}
#[cfg(feature = "tower-trace")]
impl<S> OpenTelContextService<S> {
    /// Wrap the given service.
    fn new(service: S) -> Self {
        OpenTelContextService { service }
    }
}
#[cfg(feature = "tower-trace")]
impl<S> Service<hyper::Request<Body>> for OpenTelContextService<S>
where
    S: Service<hyper::Request<Body>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Readiness is entirely that of the wrapped service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Inject the OpenTelemetry context of the current tracing span into the request headers
    /// using the global text map propagator, then forward the request to the wrapped service.
    fn call(&mut self, mut request: hyper::Request<Body>) -> Self::Future {
        let otel_cx = Span::current().context();
        {
            // Scoped so that the mutable borrow of the headers ends before the request moves.
            let mut injector = HeaderInjector(request.headers_mut());
            global::get_text_map_propagator(|propagator| {
                propagator.inject_context(&otel_cx, &mut injector)
            });
        }
        self.service.call(request)
    }
}
/// Certificate verifier which accepts every server certificate without inspecting it.
///
/// `Configuration::new` installs it (rustls variant) for `https` URLs when no certificate
/// has been provided, which means that the server identity is not verified in that case.
#[cfg(all(feature = "tower-client-rls", not(feature = "tower-client-tls")))]
struct DisableServerCertVerifier {}
#[cfg(all(feature = "tower-client-rls", not(feature = "tower-client-tls")))]
impl ServerCertVerifier for DisableServerCertVerifier {
    /// Unconditionally report the certificate as verified; every argument is ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &Certificate,
        _intermediates: &[Certificate],
        _server_name: &rustls::ServerName,
        _scts: &mut dyn Iterator<Item = &[u8]>,
        _ocsp_response: &[u8],
        _now: std::time::SystemTime,
    ) -> Result<ServerCertVerified, TLSError> {
        let accepted = ServerCertVerified::assertion();
        Ok(accepted)
    }
}