use std::marker::PhantomData;
use reqwest::header::CONTENT_TYPE;
/// A typed HTTP API request.
///
/// Implementors describe how and where the request is sent. The request value
/// itself is serialized to JSON and used as the body by [`send`].
pub trait Request {
    /// Type that the JSON response body is deserialized into.
    type Response;
    /// HTTP method used to send this request.
    fn method(&self) -> HttpMethod;
    /// Path that is joined onto the client's base URL to form the request URL.
    fn path(&self) -> String;
}
/// HTTP client bound to a base URL.
///
/// The `RequestGroup` type parameter restricts which request types the client's
/// `send` method accepts: only types implementing `InRequestGroup<RequestGroup>`.
/// The default group, [`All`], admits every type.
pub struct Client<RequestGroup = All> {
    /// URL that each request's path is joined onto.
    base_url: String,
    /// Zero-sized marker that carries the request group at the type level only.
    _p: PhantomData<RequestGroup>,
}

// `Clone` and `Debug` are implemented by hand rather than derived. `#[derive]`
// would require `RequestGroup: Clone` / `RequestGroup: Debug`, and neither `All`
// nor the marker structs generated by `request_group!` satisfy those bounds, so
// the derived impls could never actually be used.
impl<RequestGroup> Clone for Client<RequestGroup> {
    fn clone(&self) -> Self {
        Self {
            base_url: self.base_url.clone(),
            _p: PhantomData,
        }
    }
}

impl<RequestGroup> std::fmt::Debug for Client<RequestGroup> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `PhantomData<T>` is `Debug` for every `T`, so no bound on the group is needed.
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("_p", &self._p)
            .finish()
    }
}
impl<RequestGroup> Client<RequestGroup> {
pub fn new(base_url: String) -> Self {
Self {
base_url,
_p: PhantomData,
}
}
pub async fn send_to<Req>(base_url: &str, request: Req) -> Result<Req::Response, Error>
where
Req: Request + serde::Serialize + InRequestGroup<RequestGroup>,
Req::Response: for<'a> serde::Deserialize<'a>,
{
send(base_url, request).await
}
pub async fn send<Req>(&self, request: Req) -> Result<Req::Response, Error>
where
Req: Request + serde::Serialize + InRequestGroup<RequestGroup>,
Req::Response: for<'a> serde::Deserialize<'a>,
{
send(&self.base_url, request).await
}
}
/// Sends a typed [`Request`] to `base_url`.
///
/// The request URL is the request's `path()` joined onto `base_url`. The HTTP
/// method is taken from `method()`.
///
/// # Errors
/// Propagates every [`Error`] produced by [`send_custom`].
pub async fn send<Req>(base_url: &str, request: Req) -> Result<Req::Response, Error>
where
    Req: Request + serde::Serialize,
    Req::Response: for<'a> serde::Deserialize<'a>,
{
    let path = request.path();
    let full_url = join_url(base_url, path);
    let http_method = request.method();
    send_custom(full_url.as_str(), http_method, request).await
}
pub async fn send_custom<Req, Res>(
url: &str,
method: HttpMethod,
request: Req,
) -> Result<Res, Error>
where
Req: serde::Serialize,
Res: for<'a> serde::Deserialize<'a>,
{
let response = reqwest::Client::new()
.request(method.into(), url)
.body(
serde_json::to_string(&request)
.map_err(Error::SerializationError)?
.into_bytes(),
)
.header(CONTENT_TYPE, "application/json")
.send()
.await?;
let status = response.status();
if status.is_success() {
let body = response.bytes().await?;
serde_json::from_slice(&body).map_err(|error| Error::DeserializationError {
error,
response_body: body_bytes_to_str(&body),
})
} else {
let message = match response.bytes().await {
Ok(bytes) => body_bytes_to_str(&bytes),
Err(e) => format!("failed to get body: {e:?}"),
};
Err(Error::InvalidStatusCode(status.into(), message))
}
}
/// Renders a response body for error messages.
///
/// Valid UTF-8 is returned verbatim. Otherwise the result is a message that
/// describes the decoding failure.
fn body_bytes_to_str(bytes: &[u8]) -> String {
    std::str::from_utf8(bytes).map_or_else(
        |err| format!("could not read message body as a string: {err:?}"),
        str::to_owned,
    )
}
/// Declares a request-group marker type and registers request types as members.
///
/// `request_group!(pub Admin { GetUser, users::DeleteUser })` expands to a unit
/// struct `Admin` plus an empty `InRequestGroup<Admin>` impl for every listed
/// type. A `Client<Admin>` therefore only accepts those request types.
///
/// Members are matched as types rather than bare identifiers, so module paths
/// and generic types are accepted. A trailing comma is allowed.
#[macro_export]
macro_rules! request_group {
    ($viz:vis $Name:ident { $($Request:ty),* $(,)? }) => {
        $viz struct $Name;
        $(impl $crate::InRequestGroup<$Name> for $Request {})*
    };
}
/// Marker trait stating that a request type belongs to the request group `Group`.
///
/// `Client<Group>` only sends requests that implement this trait. Implementations
/// are generated by `request_group!`, and every type belongs to [`All`].
pub trait InRequestGroup<Group> {}
/// Default request group: every type is a member, via the blanket impl below.
pub struct All;
// Blanket membership, so a `Client<All>` (the default) accepts any request type.
impl<T> InRequestGroup<All> for T {}
/// Errors produced while sending a request and decoding its response.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The underlying HTTP client failed, for example while sending the request
    /// or while reading the body of a successful response.
    #[error("reqwest error: {0}")]
    ClientError(#[from] reqwest::Error),
    /// The request value could not be serialized to JSON.
    #[error("serde serialization error: {0}")]
    SerializationError(serde_json::error::Error),
    /// A successful response's body could not be deserialized into the expected type.
    #[error("serde deserialization error `{error}` while parsing response body: {response_body}")]
    DeserializationError {
        /// The underlying `serde_json` error.
        error: serde_json::error::Error,
        /// The response body as text, or a note that it was not valid UTF-8.
        response_body: String,
    },
    /// The server answered with a non-success status.
    /// Holds the numeric status code and the response body text (or a note
    /// explaining why the body could not be read).
    #[error("invalid status code {0} with response body: `{1}`")]
    InvalidStatusCode(u16, String),
}
/// HTTP request method, independent of the HTTP client library.
///
/// Converted to `reqwest::Method` when a request is built.
#[derive(Debug, Clone, Copy)]
pub enum HttpMethod {
    Options,
    Get,
    Post,
    Put,
    Delete,
    Head,
    Trace,
    Connect,
    Patch,
}
/// Maps each [`HttpMethod`] variant onto the identically named `reqwest` constant.
impl From<HttpMethod> for reqwest::Method {
    fn from(method: HttpMethod) -> Self {
        match method {
            HttpMethod::Get => Self::GET,
            HttpMethod::Post => Self::POST,
            HttpMethod::Put => Self::PUT,
            HttpMethod::Patch => Self::PATCH,
            HttpMethod::Delete => Self::DELETE,
            HttpMethod::Head => Self::HEAD,
            HttpMethod::Options => Self::OPTIONS,
            HttpMethod::Trace => Self::TRACE,
            HttpMethod::Connect => Self::CONNECT,
        }
    }
}
/// Joins `path` onto `base_url` with exactly one `/` between them.
///
/// - A separator is inserted only when neither side already provides one.
/// - When both sides provide one, the duplicate is dropped.
/// - An empty `base_url` or an empty `path` is concatenated as-is.
fn join_url(base_url: &str, path: String) -> String {
    if base_url.ends_with('/') {
        // Both sides may carry the separator; keep only the base's one to avoid `//`.
        let tail = path.strip_prefix('/').unwrap_or(&path);
        format!("{base_url}{tail}")
    } else if base_url.is_empty() || path.is_empty() || path.starts_with('/') {
        format!("{base_url}{path}")
    } else {
        format!("{base_url}/{path}")
    }
}