use async_trait::async_trait;
use reqwest::Client;
use reqwest::RequestBuilder;
use serde::de::DeserializeOwned;
use thiserror::Error;
/// Placeholder meaning "nothing to send": an alias for the unit type `()`.
///
/// The impls that follow in this file make it a pass-through `Headers`,
/// `Parameters` and `Body`.
// NOTE(review): the alias shares its name with `Option::None`; confirm this is intentional.
type None = ();
/// The unit placeholder contributes no headers.
impl Headers for None {
    /// Hands the request builder back exactly as it was received.
    fn write_headers(&self, builder: RequestBuilder) -> RequestBuilder {
        builder
    }
}
/// The unit placeholder contributes no query parameters.
impl Parameters for None {
    /// Hands the request builder back exactly as it was received.
    fn write_parameters(&self, builder: RequestBuilder) -> RequestBuilder {
        builder
    }
}
/// The unit placeholder contributes no request body.
impl Body for None {
    /// Hands the request builder back exactly as it was received.
    fn write_body(&self, builder: RequestBuilder) -> RequestBuilder {
        builder
    }
}
use serde::Deserialize;
/// Deserialized description of a failed call.
///
/// `S` is the representation of the status: a raw `u16` as it is first
/// deserialized, or a typed code once `ErrorCodes::from_status` has
/// recognised it.
#[derive(Debug, Deserialize)]
pub struct FailureStatus<S: DeserializeOwned + std::fmt::Display + std::fmt::Debug + 'static> {
    /// Optional error text; included in the `Display` output when present.
    pub error: Option<String>,
    /// Status of the failure.
    #[serde(bound(deserialize = "S: DeserializeOwned"))]
    pub status: S,
    /// Message accompanying the failure.
    pub message: String,
}
/// Human-readable rendering of a failure.
impl<S> std::fmt::Display for FailureStatus<S>
where
    S: DeserializeOwned + std::fmt::Display + std::fmt::Debug + 'static,
{
    /// Writes the status and message, plus the error text when one is present.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self.error.as_ref() {
            Some(error) => write!(
                f,
                "Encountered error with code {}, error {}, and message {}",
                self.status, error, self.message
            ),
            None => write!(
                f,
                "Encountered error with code {}, and message {}",
                self.status, self.message
            ),
        }
    }
}
impl<S> std::error::Error for FailureStatus<S> where
S: DeserializeOwned + std::fmt::Display + std::fmt::Debug + 'static
{
}
/// Classifies a raw failure using the endpoint's `ErrorCodes` implementation.
impl<E: ErrorCodes> From<FailureStatus<u16>> for RequestError<E> {
    /// Produces `KnownErrorStatus` when `E::from_status` recognises the
    /// status, and `UnkownErrorStatus` carrying the raw failure otherwise.
    fn from(failure: FailureStatus<u16>) -> Self {
        // First argument handles `Err` (unrecognised), second handles `Ok` (recognised).
        E::from_status(failure).map_or_else(
            RequestError::UnkownErrorStatus,
            RequestError::KnownErrorStatus,
        )
    }
}
/// Either the expected payload or a failure description.
///
/// Deserialized without a tag: per serde's untagged representation the
/// variants are tried in declaration order, so `Response` is attempted
/// before `Failure`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum PossibleResponse<R: Response + 'static> {
    /// The body deserialized as the expected response type.
    #[serde(bound(deserialize = "R: DeserializeOwned"))]
    Response(R),
    /// The body deserialized as a failure with a raw `u16` status.
    Failure(FailureStatus<u16>),
}
impl<R> PossibleResponse<R>
where
R: Response + 'static,
{
fn into_result(self) -> Result<R, FailureStatus<u16>> {
match self {
Self::Response(r) => Ok(r),
Self::Failure(f) => Err(f),
}
}
}
#[derive(Debug, Error)]
pub enum RequestError<C: ErrorCodes + 'static> {
#[error("You must provide valid authorization to this endpoint")]
MissingAuth,
#[error("Request Malformed with message: {0}")]
MalformedRequest(String),
#[error("Did not have user scopes required {0:?}")]
ScopesError(Vec<String>),
#[error("Known Error enountered: {0}")]
KnownErrorStatus(FailureStatus<C>),
#[error("Unknown Error enountered: {0}")]
UnkownErrorStatus(FailureStatus<u16>),
#[error("Reqwest encountered an error: {0}")]
ReqwestError(#[from] reqwest::Error),
#[error("Unknown Error encountered {0:?}")]
UnknownError(#[from] Box<dyn std::error::Error>),
}
/// A typed set of status codes that an endpoint knows how to interpret.
pub trait ErrorCodes: std::error::Error + Sized + DeserializeOwned + Copy {
    /// Tries to turn a failure with a raw `u16` status into one carrying
    /// the typed code.
    ///
    /// Returns `Ok` with the typed failure when the status is recognised,
    /// or `Err` with a failure of the original raw type when it is not.
    fn from_status(
        failure: FailureStatus<u16>,
    ) -> Result<FailureStatus<Self>, FailureStatus<u16>>;
}
/// Status codes shared by many endpoints.
///
/// The `response_codes!` invocation in this file maps 400, 401 and 500 to
/// these variants.
#[derive(Debug, Clone, Copy, Error, Deserialize)]
pub enum CommonResponseCodes {
    /// Status 400.
    #[error("400: Malformed Request")]
    BadRequestCode,
    /// Status 401.
    #[error("401: Authorization Error")]
    AuthErrorCode,
    /// Status 500.
    #[error("500: Server Error")]
    ServerErrorCode,
}
/// Implements `ErrorCodes` for a type from a list of `status => variant` pairs.
///
/// ```ignore
/// response_codes!(MyCodes: [
///     404 => MyCodes::NotFound,
///     409 => MyCodes::Conflict,
/// ]);
/// ```
///
/// The generated `from_status` matches the raw `u16` status against each
/// listed value. On a match it returns `Ok` with a `FailureStatus<Self>`
/// that carries the mapped variant and the original `error` and `message`;
/// otherwise it returns `Err` with the untouched input.
///
/// At least one mapping is required. A trailing comma after the last
/// mapping is accepted.
///
/// The expansion names `ErrorCodes` and `FailureStatus` without a path, so
/// both must be in scope where the macro is invoked.
#[macro_export]
macro_rules! response_codes {
    ($for:ty : [$($val:expr => $item:path),+ $(,)?]) => {
        impl ErrorCodes for $for {
            fn from_status(codes: FailureStatus<u16>) -> Result<FailureStatus<Self>, FailureStatus<u16>> {
                match codes.status {
                    $(
                        // Recognised status: rebuild the failure around the typed code.
                        $val => Ok(FailureStatus::<Self> {
                            error: codes.error,
                            status: $item,
                            message: codes.message,
                        }),
                    )+
                    // Unrecognised status: hand the raw failure back unchanged.
                    _ => Err(codes),
                }
            }
        }
    };
}
// Wire the shared status codes to their numeric HTTP statuses.
response_codes! {
    CommonResponseCodes: [
        400 => CommonResponseCodes::BadRequestCode,
        401 => CommonResponseCodes::AuthErrorCode,
        500 => CommonResponseCodes::ServerErrorCode
    ]
}
/// Something that can add its headers to an outgoing request.
pub trait Headers {
    /// Takes the request builder, applies this value's headers, and
    /// returns the resulting builder.
    fn write_headers(&self, builder: RequestBuilder) -> RequestBuilder;
}
/// Convenience route to `Headers`: expose the headers as name/value pairs
/// and the blanket impl in this file writes them onto the request.
pub trait HeadersExt {
    /// Borrows the headers as a slice of `(name, value)` string pairs.
    fn as_ref(&self) -> &[(&str, &str)];
}
impl<T: HeadersExt> Headers for T {
fn write_headers<'a>(&'a self, mut req: RequestBuilder) -> RequestBuilder {
for (a, b) in self.as_ref() {
req = req.header(*a, *b);
}
req
}
}
/// Something that can add its query parameters to an outgoing request.
pub trait Parameters {
    /// Takes the request builder, applies this value's parameters, and
    /// returns the resulting builder.
    fn write_parameters(&self, builder: RequestBuilder) -> RequestBuilder;
}
/// Opt-in marker: a `serde::Serialize` type that implements this gets
/// `Parameters` through the blanket impl in this file.
pub trait ParametersExt: serde::Serialize {}
/// Every `ParametersExt` type is automatically a `Parameters`.
impl<T: ParametersExt> Parameters for T {
    /// Passes `self` to `RequestBuilder::query` and returns the builder.
    fn write_parameters(&self, builder: RequestBuilder) -> RequestBuilder {
        builder.query(self)
    }
}
/// Something that can add its payload to an outgoing request.
pub trait Body {
    /// Takes the request builder, applies this value's body, and returns
    /// the resulting builder.
    fn write_body(&self, builder: RequestBuilder) -> RequestBuilder;
}
/// Opt-in marker: a `serde::Serialize` type that implements this gets
/// `Body` through the blanket impl in this file.
pub trait BodyExt: serde::Serialize {}
/// Every `BodyExt` type is automatically a `Body`.
impl<T: BodyExt> Body for T {
    /// Passes `self` to `RequestBuilder::json` and returns the builder.
    fn write_body(&self, builder: RequestBuilder) -> RequestBuilder {
        builder.json(self)
    }
}
#[async_trait]
#[cfg_attr(feature = "nightly", doc(spotlight))]
pub trait Request {
const ENDPOINT: &'static str;
type Headers: Headers;
type Parameters: Parameters;
type Body: Body;
type Response: Response + 'static;
type ErrorCodes: ErrorCodes + 'static;
const METHOD: reqwest::Method;
fn builder() -> Self;
fn headers(&self) -> &Self::Headers;
fn parameters(&self) -> &Self::Parameters;
fn body(&self) -> &Self::Body;
fn ready(&self) -> Result<(), RequestError<Self::ErrorCodes>>;
async fn make_request<C>(
&self,
client: C,
) -> Result<Self::Response, RequestError<Self::ErrorCodes>>
where
C: std::borrow::Borrow<Client> + Send,
{
self.ready()?;
let mut req = client.borrow().request(Self::METHOD, Self::ENDPOINT);
req = self.headers().write_headers(req);
req = self.parameters().write_parameters(req);
req = self.body().write_body(req);
log::info!("Making request {:#?}", req);
let resp = req.send().await?;
log::info!("Got response {:#?}", resp);
resp.json::<PossibleResponse<Self::Response>>()
.await?
.into_result()
.map_err(FailureStatus::into)
}
}
/// Marker for types usable as a response payload; requires
/// `DeserializeOwned` and `Sized`.
pub trait Response: DeserializeOwned + Sized {}
/// Blanket impl: every `DeserializeOwned` type is a `Response`.
impl<T: DeserializeOwned> Response for T {}