use async_trait::async_trait;
use serde_json::Value;
#[derive(Debug, Clone)]
pub struct HttpRequest {
pub method: String,
pub url: String,
pub headers: Vec<(String, String)>,
pub form: Vec<(String, String)>,
pub json: Option<Value>,
pub basic_auth: Option<(String, Option<String>)>,
pub bearer_auth: Option<String>,
}
#[derive(Debug, Clone)]
pub struct HttpResponse {
pub status: u16,
pub body: Value,
}
#[async_trait]
pub trait HttpClient: Send + Sync {
async fn execute(&self, req: HttpRequest) -> Result<HttpResponse, crate::error::ConnectError>;
}
pub trait HttpClientExt {
fn get(&self, url: impl Into<String>) -> RequestBuilder<'_>;
fn post(&self, url: impl Into<String>) -> RequestBuilder<'_>;
}
impl HttpClientExt for dyn HttpClient + '_ {
fn get(&self, url: impl Into<String>) -> RequestBuilder<'_> {
RequestBuilder::new(self, "GET".to_string(), url.into())
}
fn post(&self, url: impl Into<String>) -> RequestBuilder<'_> {
RequestBuilder::new(self, "POST".to_string(), url.into())
}
}
pub struct RequestBuilder<'a> {
client: &'a dyn HttpClient,
req: HttpRequest,
}
impl<'a> RequestBuilder<'a> {
pub fn new(client: &'a dyn HttpClient, method: String, url: String) -> Self {
Self {
client,
req: HttpRequest {
method,
url,
headers: vec![],
form: vec![],
json: None,
basic_auth: None,
bearer_auth: None,
},
}
}
pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.req.headers.push((key.into(), value.into()));
self
}
pub fn bearer_auth(mut self, token: &str) -> Self {
self.req.bearer_auth = Some(token.to_string());
self
}
pub fn basic_auth(mut self, username: &str, password: Option<&str>) -> Self {
self.req.basic_auth = Some((username.to_string(), password.map(|s| s.to_string())));
self
}
pub fn json(mut self, value: &Value) -> Self {
self.req.json = Some(value.clone());
self
}
pub fn form<K, V>(mut self, form: &[(K, V)]) -> Self
where
K: AsRef<str>,
V: AsRef<str>,
{
for (k, v) in form {
self.req
.form
.push((k.as_ref().to_string(), v.as_ref().to_string()));
}
self
}
pub async fn send(self) -> Result<ResponseWrapper, crate::error::ConnectError> {
let res = self.client.execute(self.req).await?;
Ok(ResponseWrapper { res })
}
}
pub struct ResponseWrapper {
res: HttpResponse,
}
impl ResponseWrapper {
pub fn error_for_status(self) -> Result<Self, crate::error::ConnectError> {
if self.res.status >= 400 {
tracing::error!("HTTP status {} received", self.res.status);
let mut code = format!("HTTP_{}", self.res.status);
let mut message = "Unknown error".to_string();
if let Some(obj) = self.res.body.as_object() {
if let Some(err) = obj.get("error").and_then(|v| v.as_str()) {
code = err.to_string();
}
if let Some(desc) = obj.get("error_description").and_then(|v| v.as_str()) {
message = desc.to_string();
} else if let Some(msg) = obj.get("message").and_then(|v| v.as_str()) {
message = msg.to_string();
} else {
message = self.res.body.to_string();
}
} else if let Some(s) = self.res.body.as_str() {
message = s.to_string();
}
Err(crate::error::ConnectError::ProviderApiError { code, message })
} else {
Ok(self)
}
}
pub async fn json<T>(self) -> Result<T, crate::error::ConnectError>
where
T: serde::de::DeserializeOwned,
{
let t = serde_json::from_value(self.res.body)?;
Ok(t)
}
}
pub struct ReqwestClient {
#[cfg(not(feature = "retry"))]
client: reqwest::Client,
#[cfg(feature = "retry")]
client: reqwest_middleware::ClientWithMiddleware,
}
impl ReqwestClient {
pub fn new() -> Self {
let reqwest_client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
#[cfg(feature = "retry")]
{
let retry_policy =
reqwest_retry::policies::ExponentialBackoff::builder().build_with_max_retries(3);
let client = reqwest_middleware::ClientBuilder::new(reqwest_client)
.with(reqwest_retry::RetryTransientMiddleware::new_with_policy(
retry_policy,
))
.build();
Self { client }
}
#[cfg(not(feature = "retry"))]
Self {
client: reqwest_client,
}
}
#[cfg(feature = "retry")]
pub fn new_with_retry(max_retries: u32) -> Self {
let reqwest_client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.build()
.unwrap_or_else(|_| reqwest::Client::new());
let retry_policy = reqwest_retry::policies::ExponentialBackoff::builder()
.build_with_max_retries(max_retries);
let client = reqwest_middleware::ClientBuilder::new(reqwest_client)
.with(reqwest_retry::RetryTransientMiddleware::new_with_policy(
retry_policy,
))
.build();
Self { client }
}
}
impl Default for ReqwestClient {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl HttpClient for ReqwestClient {
#[tracing::instrument(skip(self, req), fields(method = %req.method, url = %req.url))]
async fn execute(&self, req: HttpRequest) -> Result<HttpResponse, crate::error::ConnectError> {
tracing::debug!("Executing HTTP request");
let method = match req.method.as_str() {
"GET" => reqwest::Method::GET,
"POST" => reqwest::Method::POST,
_ => reqwest::Method::GET,
};
#[cfg(not(feature = "retry"))]
let res = {
let mut builder = self.client.request(method, &req.url);
for (k, v) in &req.headers {
builder = builder.header(k, v);
}
if let Some(token) = &req.bearer_auth {
builder = builder.bearer_auth(token);
}
if let Some((user, pass)) = &req.basic_auth {
builder = builder.basic_auth(user, pass.as_deref());
}
if !req.form.is_empty() {
builder = builder.form(&req.form);
} else if let Some(j) = &req.json {
builder = builder.json(j);
}
builder
.send()
.await
.map_err(crate::error::ConnectError::Reqwest)?
};
#[cfg(feature = "retry")]
let res = {
let mut builder = self.client.request(method, &req.url);
for (k, v) in &req.headers {
builder = builder.header(k, v);
}
if let Some(token) = &req.bearer_auth {
builder = builder.bearer_auth(token);
}
if let Some((user, pass)) = &req.basic_auth {
builder = builder.basic_auth(user, pass.as_deref());
}
if !req.form.is_empty() {
let body = serde_urlencoded::to_string(&req.form).unwrap_or_default();
builder = builder.body(body).header(
reqwest::header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
);
} else if let Some(j) = &req.json {
let body = serde_json::to_string(j).unwrap_or_default();
builder = builder
.body(body)
.header(reqwest::header::CONTENT_TYPE, "application/json");
}
builder.send().await.map_err(|e| {
if let reqwest_middleware::Error::Reqwest(err) = e {
crate::error::ConnectError::Reqwest(err)
} else {
crate::error::ConnectError::Provider(e.to_string())
}
})?
};
let status = res.status().as_u16();
tracing::debug!(status = %status, "Received HTTP response");
let text = res
.text()
.await
.map_err(crate::error::ConnectError::Reqwest)?;
let body = serde_json::from_str(&text).unwrap_or(Value::String(text));
Ok(HttpResponse { status, body })
}
}