use self::RejectionKind::*;
use super::{FullResponse, Response, StatusCode};
use crate::{
error::Error,
request::{Context, RequestContext},
trace::TraceContext,
validation::Validation,
warn, SharedString,
};
/// A rejected request.
///
/// Pairs the reason for the rejection (`kind`) with the optional request
/// and trace context that `Rejection::context` attaches. Every constructor
/// in this file starts with both contexts set to `None`.
#[derive(Debug)]
pub struct Rejection {
/// Why the request was rejected; determines the status code.
kind: RejectionKind,
/// Request context, if any; filled in by `Rejection::context`.
context: Option<Context>,
/// Trace context, if any; filled in by `Rejection::context`.
trace_context: Option<TraceContext>,
}
/// The reason a request was rejected.
///
/// Each variant corresponds to one HTTP status code (see
/// `Rejection::status_code`). `BadRequest` carries a `Validation`; every
/// other variant carries an `Error`.
#[derive(Debug)]
#[non_exhaustive]
enum RejectionKind {
/// 400 Bad Request.
BadRequest(Validation),
/// 401 Unauthorized.
Unauthorized(Error),
/// 403 Forbidden.
Forbidden(Error),
/// 404 Not Found.
NotFound(Error),
/// 405 Method Not Allowed.
MethodNotAllowed(Error),
/// 409 Conflict.
Conflict(Error),
/// 500 Internal Server Error.
InternalServerError(Error),
/// 503 Service Unavailable.
ServiceUnavailable(Error),
}
impl Rejection {
#[inline]
pub fn bad_request(validation: Validation) -> Self {
Self {
kind: BadRequest(validation),
context: None,
trace_context: None,
}
}
#[inline]
pub fn unauthorized(err: impl Into<Error>) -> Self {
Self {
kind: Unauthorized(err.into()),
context: None,
trace_context: None,
}
}
#[inline]
pub fn forbidden(err: impl Into<Error>) -> Self {
Self {
kind: Forbidden(err.into()),
context: None,
trace_context: None,
}
}
#[inline]
pub fn not_found(err: impl Into<Error>) -> Self {
Self {
kind: NotFound(err.into()),
context: None,
trace_context: None,
}
}
#[inline]
pub fn method_not_allowed(err: impl Into<Error>) -> Self {
Self {
kind: MethodNotAllowed(err.into()),
context: None,
trace_context: None,
}
}
#[inline]
pub fn conflict(err: impl Into<Error>) -> Self {
Self {
kind: Conflict(err.into()),
context: None,
trace_context: None,
}
}
#[inline]
pub fn internal_server_error(err: impl Into<Error>) -> Self {
Self {
kind: InternalServerError(err.into()),
context: None,
trace_context: None,
}
}
#[inline]
pub fn service_unavailable(err: impl Into<Error>) -> Self {
Self {
kind: ServiceUnavailable(err.into()),
context: None,
trace_context: None,
}
}
#[inline]
pub fn from_validation_entry(key: impl Into<SharedString>, err: impl Into<Error>) -> Self {
let validation = Validation::from_entry(key, err);
Self::bad_request(validation)
}
pub fn from_error(err: impl Into<Error>) -> Self {
let err = err.into();
let message = err.message();
if message.starts_with("401 Unauthorized") {
Self::unauthorized(err)
} else if message.starts_with("403 Forbidden") {
Self::forbidden(err)
} else if message.starts_with("404 Not Found") {
Self::not_found(err)
} else if message.starts_with("405 Method Not Allowed") {
Self::method_not_allowed(err)
} else if message.starts_with("409 Conflict") {
Self::conflict(err)
} else if message.starts_with("503 Service Unavailable") {
Self::service_unavailable(err)
} else {
Self::internal_server_error(err)
}
}
#[inline]
pub fn with_message(message: impl Into<SharedString>) -> Self {
Self::from_error(Error::new(message))
}
#[inline]
pub fn context<T: RequestContext + ?Sized>(mut self, ctx: &T) -> Self {
self.context = ctx.get_context();
self.trace_context = Some(ctx.new_trace_context());
self
}
#[inline]
pub fn status_code(&self) -> u16 {
match &self.kind {
BadRequest(_) => 400,
Unauthorized(_) => 401,
Forbidden(_) => 403,
NotFound(_) => 404,
MethodNotAllowed(_) => 405,
Conflict(_) => 409,
InternalServerError(_) => 500,
ServiceUnavailable(_) => 503,
}
}
}
impl From<Rejection> for Response<StatusCode> {
fn from(rejection: Rejection) -> Self {
let mut res = match rejection.kind {
BadRequest(validation) => {
let mut res = Response::new(StatusCode::BAD_REQUEST);
res.set_validation_data(validation);
res
}
Unauthorized(err) => {
let mut res = Response::new(StatusCode::UNAUTHORIZED);
res.set_error_message(err);
res
}
Forbidden(err) => {
let mut res = Response::new(StatusCode::FORBIDDEN);
res.set_error_message(err);
res
}
NotFound(err) => {
let mut res = Response::new(StatusCode::NOT_FOUND);
res.set_error_message(err);
res
}
MethodNotAllowed(err) => {
let mut res = Response::new(StatusCode::METHOD_NOT_ALLOWED);
res.set_error_message(err);
res
}
Conflict(err) => {
let mut res = Response::new(StatusCode::CONFLICT);
res.set_error_message(err);
res
}
InternalServerError(err) => {
let mut res = Response::new(StatusCode::INTERNAL_SERVER_ERROR);
res.set_error_message(err);
res
}
ServiceUnavailable(err) => {
let mut res = Response::new(StatusCode::SERVICE_UNAVAILABLE);
res.set_error_message(err);
res
}
};
if let Some(ctx) = rejection.context {
res.set_instance(ctx.instance().to_owned());
res.set_start_time(ctx.start_time());
res.set_request_id(ctx.request_id());
}
res.set_trace_context(rejection.trace_context);
res
}
}
impl From<Rejection> for FullResponse {
    /// Converts a rejection into a `FullResponse` by first building the
    /// intermediate `Response` and then converting that.
    #[inline]
    fn from(rejection: Rejection) -> Self {
        let response = Response::from(rejection);
        response.into()
    }
}
/// Conversion of a fallible or optional value into `Result<T, Rejection>`.
///
/// Implemented in this file for `Option<T>`, `Result<T, Validation>`,
/// `Result<T, E>` and `Result<Option<T>, E>` where `E: Into<Error>`.
pub trait ExtractRejection<T> {
/// Consumes `self` and returns the extracted value, or a `Rejection`
/// describing why there is none. The implementations in this file attach
/// `ctx` to the rejection via `Rejection::context`.
fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection>;
}
impl<T> ExtractRejection<T> for Option<T> {
    /// Returns the contained value, or a not-found rejection carrying `ctx`
    /// when `self` is `None`.
    #[inline]
    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
        match self {
            Some(value) => Ok(value),
            None => {
                // The error is only created when there is nothing to return.
                let err = warn!("resource does not exist");
                Err(Rejection::not_found(err).context(ctx))
            }
        }
    }
}
impl<T> ExtractRejection<T> for Result<T, Validation> {
    /// Returns the `Ok` value, or turns the failed validation into a
    /// bad-request rejection carrying `ctx`.
    #[inline]
    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
        match self {
            Ok(value) => Ok(value),
            Err(validation) => Err(Rejection::bad_request(validation).context(ctx)),
        }
    }
}
impl<T, E: Into<Error>> ExtractRejection<T> for Result<T, E> {
    /// Returns the `Ok` value, or converts the error with
    /// `Rejection::from_error` and attaches `ctx`.
    #[inline]
    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
        match self {
            Ok(value) => Ok(value),
            Err(err) => Err(Rejection::from_error(err).context(ctx)),
        }
    }
}
impl<T, E: Into<Error>> ExtractRejection<T> for Result<Option<T>, E> {
    /// Returns the value of `Ok(Some(_))`.
    ///
    /// `Err(_)` is converted with `Rejection::from_error`; `Ok(None)` becomes
    /// a not-found rejection. Both carry `ctx`.
    #[inline]
    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
        match self {
            Ok(Some(value)) => Ok(value),
            Ok(None) => {
                let err = warn!("resource does not exist");
                Err(Rejection::not_found(err).context(ctx))
            }
            Err(err) => Err(Rejection::from_error(err).context(ctx)),
        }
    }
}
/// Returns early from the enclosing function with `Err(rejection.into())`.
///
/// The first argument is always the identifier of a request context; it is
/// passed to `Rejection::context` by reference. The remaining arguments select
/// how the `Rejection` is built (see the per-arm comments below).
///
/// Notes grounded in the expansion:
/// - Every arm expands to a `return Err(..)`, so the macro can only be used
///   inside a function whose error type can be produced from a `Rejection`
///   via `.into()`.
/// - The expansion names `Rejection`, `Error` and `warn!` without a path, so
///   those names must be in scope at the call site.
/// - Arms are tried top to bottom; literal arguments are matched before the
///   more general expression arms.
#[macro_export]
macro_rules! reject {
// `reject!(ctx, validation)`: bad-request rejection from a validation expression.
($ctx:ident, $validation:expr $(,)?) => {{
return Err(Rejection::bad_request($validation).context(&$ctx).into());
}};
// `reject!(ctx, "key", "message")`: builds an `Error` from the literal message,
// invokes `warn!` with a formatted description, and rejects with a single
// validation entry for `key`.
($ctx:ident, $key:literal, $message:literal $(,)?) => {{
let err = Error::new($message);
warn!("invalid value for `{}`: {}", $key, $message);
return Err(Rejection::from_validation_entry($key, err).context(&$ctx).into());
}};
// `reject!(ctx, "key", err)`: same as above but with a caller-supplied error
// expression; `warn!` is not invoked in this arm.
($ctx:ident, $key:literal, $err:expr $(,)?) => {{
return Err(Rejection::from_validation_entry($key, $err).context(&$ctx).into());
}};
// `reject!(ctx, kind, "message")`: `kind` is the name of a `Rejection`
// associated function (it is expanded as `Rejection::$kind`); the error value
// is whatever `warn!` yields for the literal message.
($ctx:ident, $kind:ident, $message:literal $(,)?) => {{
let err = warn!($message);
return Err(Rejection::$kind(err).context(&$ctx).into());
}};
// `reject!(ctx, kind, err)`: as above with a caller-supplied error expression.
($ctx:ident, $kind:ident, $err:expr $(,)?) => {{
return Err(Rejection::$kind($err).context(&$ctx).into());
}};
// `reject!(ctx, kind, fmt, args..)`: as above, forwarding the format string and
// its arguments to `warn!` to produce the error value.
($ctx:ident, $kind:ident, $fmt:expr, $($arg:tt)+) => {{
let err = warn!($fmt, $($arg)+);
return Err(Rejection::$kind(err).context(&$ctx).into());
}};
}