use std::borrow::Cow;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::fmt::Result as FmtResult;
use std::future::Future;
use std::str::from_utf8;
use http::request::Builder as HttpRequestBuilder;
use http::HeaderMap;
use http::HeaderValue;
use http::Request;
use http::Response;
use http_endpoint::Endpoint;
use hyper::body::to_bytes;
use hyper::body::Bytes;
use hyper::client::Builder as HttpClientBuilder;
use hyper::client::HttpConnector;
use hyper::Body;
use hyper::Client as HttpClient;
use hyper::Error as HyperError;
use hyper_tls::HttpsConnector;
use tracing::debug;
use tracing::field::debug;
use tracing::field::DebugValue;
use tracing::instrument;
use tracing::span;
use tracing::trace;
use tracing::Level;
use tracing_futures::Instrument;
use url::Url;
use crate::api::HDR_KEY_ID;
use crate::api::HDR_SECRET;
use crate::api_info::ApiInfo;
use crate::error::RequestError;
use crate::subscribable::Subscribable;
use crate::Error;
/// `Debug` adapter for a header map that hides credential values.
///
/// Headers whose names match the `HDR_KEY_ID` or `HDR_SECRET` constants are
/// printed with the value `<masked>`; every other header is printed as-is.
struct DebugHeaders<'h> {
    headers: &'h HeaderMap<HeaderValue>,
}
impl<'h> Debug for DebugHeaders<'h> {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        static MASKED: HeaderValue = HeaderValue::from_static("<masked>");

        let mut map = f.debug_map();
        for (name, value) in self.headers {
            // Keep the API credentials out of any debug/log output.
            let is_secret = name == HDR_KEY_ID || name == HDR_SECRET;
            let shown = if is_secret { &MASKED } else { value };
            let _ = map.entry(name, shown);
        }
        map.finish()
    }
}
/// `Debug` adapter for an outgoing request.
///
/// Prints the HTTP version, the headers (credential headers masked through
/// `DebugHeaders`) and the body, under the struct name `Request`.
struct DebugRequest<'r> {
    request: &'r Request<Body>,
}
impl<'r> Debug for DebugRequest<'r> {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let req = self.request;
        let masked_headers = DebugHeaders {
            headers: req.headers(),
        };
        let version = req.version();

        let mut out = f.debug_struct("Request");
        let _ = out.field("version", &version);
        let _ = out.field("headers", &masked_headers);
        let _ = out.field("body", req.body());
        out.finish()
    }
}
/// Wrap `request` so it can be recorded as a `tracing` field using its
/// credential-masking `Debug` representation.
fn debug_request(request: &Request<Body>) -> DebugValue<DebugRequest<'_>> {
    let wrapper = DebugRequest { request };
    debug(wrapper)
}
/// A builder for configuring and creating a [`Client`].
///
/// Wraps a `hyper` client builder so that connection-pool settings can be
/// adjusted before the client is constructed.
#[derive(Debug)]
pub struct Builder {
    builder: HttpClientBuilder,
}
impl Builder {
    /// Set the maximum number of idle connections the connection pool
    /// keeps per host.
    #[inline]
    pub fn max_idle_per_host(&mut self, max_idle: usize) -> &mut Self {
        let _ = self.builder.pool_max_idle_per_host(max_idle);
        self
    }

    /// Build a [`Client`] that uses `api_info` and connects through an
    /// `HttpsConnector`.
    // NOTE(review): `HttpsConnector::new` appears to panic if the TLS
    // backend cannot be initialized — confirm against hyper-tls docs.
    pub fn build(&self, api_info: ApiInfo) -> Client {
        let connector = HttpsConnector::new();
        Client {
            client: self.builder.build(connector),
            api_info,
        }
    }
}
impl Default for Builder {
    /// Create a builder with `hyper`'s default client settings.
    ///
    /// In test builds, idle-connection pooling is additionally disabled via
    /// `pool_max_idle_per_host(0)`.
    // NOTE(review): presumably this avoids reusing pooled connections across
    // the separate runtimes individual async tests create — confirm.
    #[inline]
    fn default() -> Self {
        let mut builder = HttpClient::builder();
        if cfg!(test) {
            let _ = builder.pool_max_idle_per_host(0);
        }
        Self { builder }
    }
}
/// A client for interacting with the API.
///
/// Holds the API information (base URL and credentials) attached to every
/// request, together with the underlying `hyper` HTTP client.
#[derive(Debug)]
pub struct Client {
// Base URL plus key ID / secret that `request` attaches to each request.
api_info: ApiInfo,
// HTTP client created by `Builder::build` with an `HttpsConnector`.
client: HttpClient<HttpsConnector<HttpConnector>, Body>,
}
impl Client {
    /// Create a [`Builder`] for configuring a [`Client`].
    #[inline]
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Create a new [`Client`] with the default configuration, using the
    /// provided API information.
    #[inline]
    pub fn new(api_info: ApiInfo) -> Self {
        Builder::default().build(api_info)
    }

    /// Ask the server for a gzip-compressed response by setting
    /// `Accept-Encoding: gzip` on `request`.
    #[cfg(feature = "gzip")]
    fn maybe_add_gzip_header(request: &mut Request<Body>) {
        use http::header::ACCEPT_ENCODING;

        let _ = request
            .headers_mut()
            .insert(ACCEPT_ENCODING, HeaderValue::from_static("gzip"));
    }

    /// No-op counterpart used when the `gzip` feature is disabled.
    #[cfg(not(feature = "gzip"))]
    fn maybe_add_gzip_header(_request: &mut Request<Body>) {}

    /// Build the HTTP request for endpoint `R` from `input`.
    ///
    /// The URL is the endpoint's own base URL if it defines one, otherwise
    /// the configured API base URL. The endpoint's path and query are then
    /// applied to it. The API key ID and secret are attached as headers.
    ///
    /// # Errors
    /// Propagates errors from the endpoint's query/body conversion and from
    /// assembling the `http` request.
    ///
    /// # Panics
    /// Panics if the endpoint's `base_url` is not a valid URL. That is a bug
    /// in the endpoint definition, not a runtime condition.
    fn request<R>(&self, input: &R::Input) -> Result<Request<Body>, R::Error>
    where
        R: Endpoint,
    {
        let mut url = R::base_url()
            .map(|url| Url::parse(url.as_ref()).expect("endpoint definition contains invalid URL"))
            .unwrap_or_else(|| self.api_info.api_base_url.clone());
        url.set_path(&R::path(input));
        url.set_query(R::query(input)?.as_ref().map(AsRef::as_ref));

        let mut request = HttpRequestBuilder::new()
            .method(R::method())
            .uri(url.as_str())
            .header(HDR_KEY_ID, self.api_info.key_id.as_str())
            .header(HDR_SECRET, self.api_info.secret.as_str())
            .body(Body::from(
                R::body(input)?.unwrap_or(Cow::Borrowed(&[0; 0])),
            ))?;

        Self::maybe_add_gzip_header(&mut request);
        Ok(request)
    }

    /// Collect the complete response body, exactly as received, into memory.
    async fn retrieve_raw_body(response: Body) -> Result<Bytes, HyperError> {
        to_bytes(response).await
    }

    /// Collect the complete response body into memory. If the server
    /// declared a gzip content coding, the body is decompressed first.
    #[cfg(feature = "gzip")]
    async fn retrieve_body<E>(response: Response<Body>) -> Result<Bytes, RequestError<E>> {
        use async_compression::futures::bufread::GzipDecoder;
        use futures::AsyncReadExt as _;
        use http::header::CONTENT_ENCODING;

        let (parts, body) = response.into_parts();
        // Content codings are case-insensitive, and "x-gzip" is an alias for
        // "gzip" (RFC 9110, section 8.4.1). Comparing against the exact byte
        // string "gzip" would hand still-compressed data to the endpoint.
        let is_gzip = matches!(
            parts.headers.get(CONTENT_ENCODING),
            Some(value) if value.as_bytes().eq_ignore_ascii_case(b"gzip")
                || value.as_bytes().eq_ignore_ascii_case(b"x-gzip")
        );

        let bytes = Self::retrieve_raw_body(body).await?;
        let bytes = if is_gzip {
            let mut buffer = Vec::new();
            let _count = GzipDecoder::new(&*bytes).read_to_end(&mut buffer).await?;
            buffer.into()
        } else {
            bytes
        };
        Ok(bytes)
    }

    /// Collect the complete response body into memory.
    #[cfg(not(feature = "gzip"))]
    async fn retrieve_body<E>(response: Response<Body>) -> Result<Bytes, RequestError<E>> {
        let bytes = Self::retrieve_raw_body(response.into_body()).await?;
        Ok(bytes)
    }

    /// Create and issue a request for endpoint `R`, then evaluate the
    /// response.
    ///
    /// The request is built eagerly, before the future is returned, so the
    /// future only borrows `self` and not `input`.
    ///
    /// # Errors
    /// Returns `RequestError::Endpoint` if the request cannot be built or
    /// the endpoint rejects the response. Transport and body-reading
    /// failures are converted into `RequestError` via `From`.
    pub fn issue<R>(
        &self,
        input: &R::Input,
    ) -> impl Future<Output = Result<R::Output, RequestError<R::Error>>> + '_
    where
        R: Endpoint,
    {
        let result = self.request::<R>(input);
        async move {
            let request = result.map_err(RequestError::Endpoint)?;
            let span = span!(
                Level::INFO,
                "issue",
                method = display(request.method()),
                uri = display(request.uri())
            );
            self.issue_::<R>(request).instrument(span).await
        }
    }

    /// Send an already built `request`, trace the exchange, and hand the
    /// status and body to `R::evaluate`.
    // The `tracing` macros below expand to branching code, which, as far as
    // we can tell, is what trips clippy's cognitive-complexity lint on this
    // otherwise linear function.
    #[allow(clippy::cognitive_complexity)]
    async fn issue_<R>(&self, request: Request<Body>) -> Result<R::Output, RequestError<R::Error>>
    where
        R: Endpoint,
    {
        debug!("requesting");
        trace!(request = debug_request(&request));

        let result = self.client.request(request).await?;
        let status = result.status();
        debug!(status = debug(&status));
        trace!(response = debug(&result));

        let bytes = Self::retrieve_body::<R::Error>(result).await?;
        let body = bytes.as_ref();

        match from_utf8(body) {
            Ok(s) => trace!(body = display(&s)),
            // The body is not valid UTF-8. Record the raw bytes through their
            // `Debug` representation rather than the UTF-8 decoding error, so
            // that the `body` field actually contains the body.
            Err(_) => trace!(body = debug(&bytes)),
        }

        R::evaluate(status, body).map_err(RequestError::Endpoint)
    }

    /// Connect to the subscribable `S` using this client's API information.
    #[instrument(level = "debug", skip(self))]
    pub async fn subscribe<S>(&self) -> Result<(S::Stream, S::Subscription), Error>
    where
        S: Subscribable<Input = ApiInfo>,
    {
        S::connect(&self.api_info).await
    }

    /// Retrieve the API information this client was created with.
    #[inline]
    pub fn api_info(&self) -> &ApiInfo {
        &self.api_info
    }
}
// Integration-style tests for `Client`.
// NOTE(review): these talk to the live API and need credentials in the
// environment (`ApiInfo::from_env`), so they fail without network access.
#[cfg(test)]
mod tests {
use super::*;
use http::StatusCode;
// Shadows the built-in `#[test]` so `#[test(tokio::test)]` below resolves to
// `test_log`'s attribute (presumably to set up log/trace output per test).
use test_log::test;
use crate::endpoint::ApiError;
use crate::Str;
// A deliberately bogus endpoint at path "/v2/foobarbaz", used to provoke a
// "not found" response from the server. The empty `[]` lists presumably mean
// no status codes are mapped to specific outcomes — confirm against the
// `Endpoint!` macro definition.
Endpoint! {
GetNotFound(()),
Ok => (), [],
Err => GetNotFoundError, []
fn path(_input: &Self::Input) -> Str {
"/v2/foobarbaz".into()
}
}
/// Check that requesting a non-existent endpoint yields an
/// `UnexpectedStatus` error carrying the 404 status and the server's
/// decoded error payload.
#[test(tokio::test)]
async fn unexpected_status_code_return() {
let api_info = ApiInfo::from_env().unwrap();
// Explicitly disable idle-connection pooling. The test-build
// `Builder::default` already does this as well.
let client = Client::builder().max_idle_per_host(0).build(api_info);
let result = client.issue::<GetNotFound>(&()).await;
let err = result.unwrap_err();
match err {
RequestError::Endpoint(GetNotFoundError::UnexpectedStatus(status, message)) => {
// `message` is compared against `Ok(ApiError { .. })`, i.e. the
// successfully decoded error body returned by the server.
let expected = ApiError {
code: 40410000,
message: "endpoint not found".to_string(),
};
assert_eq!(message, Ok(expected));
assert_eq!(status, StatusCode::NOT_FOUND);
},
_ => panic!("Received unexpected error: {err:?}"),
};
}
}