use core::fmt::{self, Debug};
use std::{
error::{self, Error},
future::Future,
};
use http::{HeaderMap, HeaderName, HeaderValue};
use reqwest::{Body, Client, Method, RequestBuilder, Response};
use serde::de::DeserializeOwned;
/// A per-request option accepted by [`HttpClient::request`] and
/// [`HttpClient::request_json`].
#[derive(Clone, Debug)]
pub enum HttpRequestOption {
    /// Sets a header on the request. A later header with the same name
    /// replaces an earlier one.
    Header(HeaderName, HeaderValue),
    /// Uses this URL instead of [`HttpClient::get_base_url`] as the prefix
    /// that the request path is appended to.
    BaseUrl(String),
    /// Skips both the [`HttpClient::before_request`] and
    /// [`HttpClient::after_request`] hooks for this request.
    NoBeforeAfter,
    /// Not interpreted by the provided [`HttpClient`] methods; it is only
    /// forwarded to [`HttpClient::before_request`] with the other options.
    /// NOTE(review): presumably lets an implementation skip authentication —
    /// confirm against the `before_request` overrides.
    Anonymous,
}
/// Boxed, thread-safe error type returned by every [`HttpClient`] request method.
pub type ClientError = Box<dyn error::Error + Send + Sync + 'static>;
#[derive(Debug)]
pub struct FailedRequestError {
response: Response,
}
impl From<Response> for FailedRequestError {
fn from(value: Response) -> Self {
Self { response: value }
}
}
/// Renders as `http request failed with status code <status>`, where
/// `<status>` is the `Display` form of the wrapped response's status.
impl fmt::Display for FailedRequestError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let status = self.response.status();
        write!(f, "http request failed with status code {status}")
    }
}
impl error::Error for FailedRequestError {}
/// An asynchronous HTTP client built on top of a reqwest [`Client`].
///
/// Implementors supply the underlying client ([`HttpClient::get_http_client`])
/// and a base URL ([`HttpClient::get_base_url`]). The provided methods build
/// and send requests. The [`HttpClient::before_request`] and
/// [`HttpClient::after_request`] hooks may be overridden to customise every
/// request and response; by default they pass values through unchanged.
pub trait HttpClient {
    /// Sends a `method` request to `path`, resolved against the base URL.
    ///
    /// `path` is appended to the base URL. The base URL is
    /// [`HttpClient::get_base_url`], or the last [`HttpRequestOption::BaseUrl`]
    /// in `options` if there is one. If the base URL ends with `/` and `path`
    /// starts with `/`, only one slash is kept so the URL never contains `//`
    /// at the join point.
    ///
    /// Every [`HttpRequestOption::Header`] is set on the request; a later header
    /// with the same name replaces an earlier one. Unless
    /// [`HttpRequestOption::NoBeforeAfter`] is present, two hooks run:
    /// - the request builder goes through [`HttpClient::before_request`]
    ///   (together with `path` and all `options`);
    /// - the response goes through [`HttpClient::after_request`].
    ///
    /// The response is returned whatever its status code.
    ///
    /// # Errors
    ///
    /// Returns an error if a hook fails, the request cannot be built, or sending
    /// it fails.
    fn request(
        &self,
        method: Method,
        path: &str,
        body: impl Into<Body> + Send,
        options: Vec<HttpRequestOption>,
    ) -> impl Future<Output = Result<Response, ClientError>> + Send
    where
        Self: Sync,
    {
        async move {
            let mut base_url: String = self.get_base_url().to_string();
            let mut header_map = HeaderMap::new();
            let mut no_before_after = false;
            for option in &options {
                match option {
                    HttpRequestOption::Header(key, value) => {
                        header_map.insert(key.clone(), value.clone());
                    }
                    HttpRequestOption::BaseUrl(url) => {
                        base_url.clone_from(url);
                    }
                    HttpRequestOption::NoBeforeAfter => {
                        no_before_after = true;
                    }
                    // Not interpreted here; it still reaches `before_request`
                    // through `options`.
                    HttpRequestOption::Anonymous => {}
                }
            }
            // Avoid producing `//` when both sides contribute a slash.
            let url = match base_url.strip_suffix('/') {
                Some(trimmed) if path.starts_with('/') => format!("{trimmed}{path}"),
                _ => format!("{base_url}{path}"),
            };
            let client = self.get_http_client();
            let mut req_builder = client.request(method, url).headers(header_map).body(body);
            if !no_before_after {
                req_builder = self.before_request(req_builder, path, options).await?;
            }
            let req = req_builder.build()?;
            let response = client.execute(req).await?;
            if no_before_after {
                Ok(response)
            } else {
                self.after_request(response).await
            }
        }
    }

    /// Sends a request like [`HttpClient::request`] and deserializes the JSON
    /// response body into `T`.
    ///
    /// A `content-type: application/json` header is added unless `options`
    /// already contains a `content-type` header.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - [`HttpClient::request`] fails (its error is passed through);
    /// - the response status is not a success (a [`FailedRequestError`]);
    /// - the body cannot be read;
    /// - the body is not valid JSON for `T`.
    fn request_json<T>(
        &self,
        method: Method,
        path: &str,
        body: impl Into<Body> + Send,
        options: Vec<HttpRequestOption>,
    ) -> impl Future<Output = Result<T, ClientError>> + Send
    where
        Self: Sync,
        T: DeserializeOwned + Debug,
    {
        async move {
            let mut options = options;
            // `HeaderName`'s comparison with `str` is case-insensitive.
            let has_content_type = options.iter().any(|opt| {
                matches!(opt, HttpRequestOption::Header(name, _) if name == "content-type")
            });
            if !has_content_type {
                options.push(HttpRequestOption::Header(
                    HeaderName::from_static("content-type"),
                    HeaderValue::from_static("application/json"),
                ));
            }
            let response = self.request(method, path, body, options).await?;
            if !response.status().is_success() {
                return Err(Box::new(FailedRequestError::from(response)) as ClientError);
            }
            let bytes = response.bytes().await?;
            Ok(serde_json::from_slice(&bytes)?)
        }
    }

    /// Hook run on every request builder before the request is built, unless
    /// [`HttpRequestOption::NoBeforeAfter`] is set.
    ///
    /// Receives the request path and all options passed to the request. The
    /// default implementation returns the builder unchanged.
    fn before_request(
        &self,
        req_builder: RequestBuilder,
        _path: &str,
        _options: Vec<HttpRequestOption>,
    ) -> impl Future<Output = Result<RequestBuilder, ClientError>> + Send {
        async { Ok(req_builder) }
    }

    /// Hook run on every response, whatever its status code, unless
    /// [`HttpRequestOption::NoBeforeAfter`] is set.
    ///
    /// The default implementation returns the response unchanged.
    fn after_request(
        &self,
        response: Response,
    ) -> impl Future<Output = Result<Response, ClientError>> + Send {
        async { Ok(response) }
    }

    /// Returns the reqwest client used to build and execute requests.
    fn get_http_client(&self) -> &Client;

    /// Returns the default URL prefix that request paths are appended to.
    fn get_base_url(&self) -> &str;
}