use reqwest::header::{
HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE, USER_AGENT,
};
use reqwest::{Client, Method, StatusCode, Url};
use serde::Serialize;
use serde_json::Value;
use std::borrow::Cow;
use thiserror::Error;
/// How a [`RequestFactory`] authenticates the requests it builds.
#[derive(Debug, Clone)]
pub enum AuthStrategy {
/// Add no authentication header.
None,
/// Send the contained token as `Authorization: Bearer <token>`.
Bearer(String),
/// Send an arbitrary, caller-chosen header on every request.
Header {
/// Name of the header to add.
name: HeaderName,
/// Value sent for that header.
value: HeaderValue,
},
}
/// Builds and sends HTTP requests against a single base URL.
#[derive(Debug, Clone)]
pub struct RequestFactory {
// Underlying reqwest client used to send every request.
client: Client,
// URL that request paths are resolved against.
base_url: Url,
// Authentication applied to each request.
auth: AuthStrategy,
// Headers attached to each request before authentication is added.
default_headers: HeaderMap,
}
/// Raw body of a successful response together with its declared content type.
#[derive(Debug, Clone)]
pub struct ResponseBytes {
/// Value of the `Content-Type` response header; `None` when the header is
/// absent or cannot be represented as a string.
pub content_type: Option<String>,
/// Response body bytes, unmodified.
pub body: Vec<u8>,
}
/// Errors produced while building, sending, or decoding an HTTP request.
#[derive(Debug, Error)]
pub enum HttpError {
/// The request could not be sent, the response could not be read, or the
/// server answered with a non-success status.
#[error("{message}")]
Request {
/// Human-readable description; also used as the `Display` output.
message: String,
/// Response status, or `None` when no response was received.
status: Option<StatusCode>,
/// Response body text, when one was read.
body: Option<String>,
},
/// The HTTP client or the request URL could not be constructed.
#[error("failed to build request: {0}")]
Build(String),
/// A successful response body could not be parsed as the expected JSON.
#[error("failed to parse response JSON: {0}")]
Decode(String),
}
impl HttpError {
pub fn request(
message: impl Into<String>,
status: Option<StatusCode>,
body: Option<String>,
) -> Self {
Self::Request {
message: message.into(),
status,
body,
}
}
}
impl RequestFactory {
pub fn new(base_url: impl AsRef<str>) -> Result<Self, HttpError> {
let client = Client::builder()
.user_agent("xbp")
.build()
.map_err(|error| HttpError::Build(error.to_string()))?;
let base_url =
Url::parse(base_url.as_ref()).map_err(|error| HttpError::Build(error.to_string()))?;
let mut default_headers = HeaderMap::new();
default_headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
default_headers.insert(USER_AGENT, HeaderValue::from_static("xbp"));
Ok(Self {
client,
base_url,
auth: AuthStrategy::None,
default_headers,
})
}
pub fn with_auth(mut self, auth: AuthStrategy) -> Self {
self.auth = auth;
self
}
pub fn with_default_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
self.default_headers.insert(name, value);
self
}
pub async fn get_json<T, Q>(&self, path: &str, query: Option<&Q>) -> Result<T, HttpError>
where
T: serde::de::DeserializeOwned,
Q: Serialize + ?Sized,
{
self.send_json(Method::GET, path, query, Option::<&Value>::None)
.await
}
pub async fn delete_json<T, Q>(&self, path: &str, query: Option<&Q>) -> Result<T, HttpError>
where
T: serde::de::DeserializeOwned,
Q: Serialize + ?Sized,
{
self.send_json(Method::DELETE, path, query, Option::<&Value>::None)
.await
}
pub async fn delete_json_with_body<T, Q, B>(
&self,
path: &str,
query: Option<&Q>,
body: &B,
) -> Result<T, HttpError>
where
T: serde::de::DeserializeOwned,
Q: Serialize + ?Sized,
B: Serialize + ?Sized,
{
self.send_json(Method::DELETE, path, query, Some(body))
.await
}
pub async fn post_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
where
T: serde::de::DeserializeOwned,
B: Serialize + ?Sized,
{
self.send_json(Method::POST, path, Option::<&Value>::None, Some(body))
.await
}
pub async fn put_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
where
T: serde::de::DeserializeOwned,
B: Serialize + ?Sized,
{
self.send_json(Method::PUT, path, Option::<&Value>::None, Some(body))
.await
}
pub async fn patch_json<T, B>(&self, path: &str, body: &B) -> Result<T, HttpError>
where
T: serde::de::DeserializeOwned,
B: Serialize + ?Sized,
{
self.send_json(Method::PATCH, path, Option::<&Value>::None, Some(body))
.await
}
pub async fn post_bytes(
&self,
path: &str,
bytes: Vec<u8>,
content_type: &'static str,
) -> Result<ResponseBytes, HttpError> {
let response = self
.request(Method::POST, path)?
.header(CONTENT_TYPE, content_type)
.body(bytes)
.send()
.await
.map_err(|error| HttpError::request(error.to_string(), None, None))?;
self.read_bytes_response(response).await
}
pub async fn get_bytes<Q>(
&self,
path: &str,
query: Option<&Q>,
) -> Result<ResponseBytes, HttpError>
where
Q: Serialize + ?Sized,
{
let mut request = self.request(Method::GET, path)?;
if let Some(query) = query {
request = request.query(query);
}
let response = request
.send()
.await
.map_err(|error| HttpError::request(error.to_string(), None, None))?;
self.read_bytes_response(response).await
}
async fn send_json<T, Q, B>(
&self,
method: Method,
path: &str,
query: Option<&Q>,
body: Option<&B>,
) -> Result<T, HttpError>
where
T: serde::de::DeserializeOwned,
Q: Serialize + ?Sized,
B: Serialize + ?Sized,
{
let mut request = self.request(method, path)?;
if let Some(query) = query {
request = request.query(query);
}
if let Some(body) = body {
request = request.json(body);
}
let response = request
.send()
.await
.map_err(|error| HttpError::request(error.to_string(), None, None))?;
let status = response.status();
let body = response
.text()
.await
.map_err(|error| HttpError::request(error.to_string(), Some(status), None))?;
if !status.is_success() {
let message = extract_cloudflare_error_message(&body)
.or_else(|| extract_github_error_message(&body))
.unwrap_or_else(|| format!("HTTP {}", status));
return Err(HttpError::request(message, Some(status), Some(body)));
}
serde_json::from_str(&body).map_err(|error| HttpError::Decode(error.to_string()))
}
fn request(&self, method: Method, path: &str) -> Result<reqwest::RequestBuilder, HttpError> {
let mut url = self
.base_url
.join(path)
.map_err(|error| HttpError::Build(error.to_string()))?;
if path.starts_with('/') {
let joined = format!("{}{}", self.base_url.as_str().trim_end_matches('/'), path);
url = Url::parse(&joined).map_err(|error| HttpError::Build(error.to_string()))?;
}
let mut builder = self.client.request(method, url);
builder = builder.headers(self.default_headers.clone());
match &self.auth {
AuthStrategy::None => {}
AuthStrategy::Bearer(token) => {
builder = builder.header(AUTHORIZATION, format!("Bearer {}", token));
}
AuthStrategy::Header { name, value } => {
builder = builder.header(name, value);
}
}
Ok(builder)
}
async fn read_bytes_response(
&self,
response: reqwest::Response,
) -> Result<ResponseBytes, HttpError> {
let status = response.status();
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map(str::to_string);
let bytes = response
.bytes()
.await
.map_err(|error| HttpError::request(error.to_string(), Some(status), None))?;
if !status.is_success() {
let body = String::from_utf8_lossy(&bytes).to_string();
let message = extract_cloudflare_error_message(&body)
.or_else(|| extract_github_error_message(&body))
.unwrap_or_else(|| format!("HTTP {}", status));
return Err(HttpError::request(message, Some(status), Some(body)));
}
Ok(ResponseBytes {
content_type,
body: bytes.to_vec(),
})
}
}
/// Extracts the top-level `"message"` string from a JSON error body.
///
/// Returns `None` when `body` is not valid JSON, has no string `"message"`
/// field, or that field is empty after trimming whitespace.
pub fn extract_github_error_message(body: &str) -> Option<String> {
    let document: Value = serde_json::from_str(body.trim()).ok()?;
    let message = document.get("message")?.as_str()?.trim();
    if message.is_empty() {
        None
    } else {
        Some(message.to_owned())
    }
}
/// Extracts a combined message from the `"errors"` array of a JSON error body.
///
/// Each array entry with a non-empty string `"message"` contributes one part:
/// `"<message> (<code>)"` when the entry has an integer `"code"`, otherwise
/// just the trimmed message. Parts are joined with `"; "`.
///
/// Returns `None` when `body` is not valid JSON, has no `"errors"` array, or
/// no entry yields a message.
pub fn extract_cloudflare_error_message(body: &str) -> Option<String> {
    let parsed = serde_json::from_str::<Value>(body.trim()).ok()?;
    let errors = parsed.get("errors")?.as_array()?;
    let messages: Vec<Cow<'_, str>> = errors
        .iter()
        .filter_map(|entry| {
            let message = entry.get("message").and_then(Value::as_str)?.trim();
            if message.is_empty() {
                return None;
            }
            Some(match entry.get("code").and_then(Value::as_i64) {
                Some(code) => Cow::Owned(format!("{} ({})", message, code)),
                // No code to append, so the borrowed text is used as-is.
                None => Cow::Borrowed(message),
            })
        })
        .collect();
    if messages.is_empty() {
        None
    } else {
        // `join` accepts `Cow<str>` items directly, so no second collection
        // into `Vec<String>` is needed.
        Some(messages.join("; "))
    }
}