use std::{
convert::Infallible,
error::Error as StdError,
fmt::{self, Debug, Display, Formatter},
string::FromUtf8Error,
};
use headers::{ContentRange, HeaderMapExt};
use http::Method;
use crate::{http::StatusCode, IntoResponse, Response};
// Generates one public constructor function per `(Name, STATUS)` pair.
//
// Every generated function takes an arbitrary error value and wraps it in an
// `Error` through `Error::new`, using the `StatusCode` associated constant
// named by the second element of the pair.
macro_rules! define_http_error {
    ($($(#[$attr:meta])* ($ctor:ident, $code:ident);)*) => {
        $(
            $(#[$attr])*
            // The function name is taken verbatim from the pair, so it is not
            // required to be snake_case.
            #[allow(non_snake_case)]
            #[inline]
            pub fn $ctor(err: impl StdError + Send + Sync + 'static) -> Error {
                Error::new(err, StatusCode::$code)
            }
        )*
    };
}
/// An error that can describe itself as an HTTP response.
pub trait ResponseError {
    /// Returns the HTTP status code associated with this error.
    fn status(&self) -> StatusCode;

    /// Builds the response for this error.
    ///
    /// The default implementation uses [`ResponseError::status`] as the
    /// response status and the error's `Display` text as the body.
    fn as_response(&self) -> Response
    where
        Self: StdError + Send + Sync + 'static,
    {
        let status = self.status();
        let message = self.to_string();
        Response::builder().status(status).body(message)
    }
}
/// The underlying error value stored inside an `Error`.
enum ErrorSource {
    /// A boxed error trait object.
    BoxedError(Box<dyn StdError + Send + Sync>),
    /// An `anyhow::Error`; only compiled with the `anyhow` feature.
    #[cfg(feature = "anyhow")]
    Anyhow(anyhow::Error),
}

impl Debug for ErrorSource {
    /// Forwards to the `Debug` implementation of the wrapped error.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Pick the wrapped value first, then format it in a single place.
        let inner: &dyn Debug = match self {
            Self::BoxedError(boxed) => boxed,
            #[cfg(feature = "anyhow")]
            Self::Anyhow(any) => any,
        };
        Debug::fmt(inner, f)
    }
}
type BoxAsResponseFn = Box<dyn Fn(&Error) -> Response + Send + Sync + 'static>;
enum AsResponse {
Status(StatusCode),
Fn(BoxAsResponseFn),
}
impl AsResponse {
#[inline]
fn from_status(status: StatusCode) -> Self {
AsResponse::Status(status)
}
fn from_type<T: ResponseError + StdError + Send + Sync + 'static>() -> Self {
AsResponse::Fn(Box::new(|err| {
let err = err.downcast_ref::<T>().expect("valid error");
err.as_response()
}))
}
fn as_response(&self, err: &Error) -> Response {
match self {
AsResponse::Status(status) => Response::builder().status(*status).body(err.to_string()),
AsResponse::Fn(f) => f(err),
}
}
}
/// General error type: a wrapped error value paired with the strategy used
/// to turn it into a [`Response`].
pub struct Error {
    /// How the response for this error is built.
    as_response: AsResponse,
    /// The wrapped error value.
    source: ErrorSource,
}

impl Debug for Error {
    /// Prints the struct name and the `source` field; `as_response` is not
    /// included in the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Error");
        out.field("source", &self.source);
        out.finish()
    }
}
impl Display for Error {
    /// Forwards to the `Display` implementation of the wrapped error.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let inner: &dyn Display = match &self.source {
            ErrorSource::BoxedError(boxed) => boxed,
            #[cfg(feature = "anyhow")]
            ErrorSource::Anyhow(any) => any,
        };
        Display::fmt(inner, f)
    }
}
impl From<Infallible> for Error {
    /// `Infallible` has no values, so this conversion can never actually run.
    fn from(never: Infallible) -> Self {
        // An empty match on an uninhabited type is statically unreachable.
        match never {}
    }
}
impl<T: ResponseError + StdError + Send + Sync + 'static> From<T> for Error {
    /// Boxes `err` and remembers its concrete type so that the response is
    /// later built by `T`'s [`ResponseError::as_response`].
    fn from(err: T) -> Self {
        let as_response = AsResponse::from_type::<T>();
        let source = ErrorSource::BoxedError(Box::new(err));
        Self {
            as_response,
            source,
        }
    }
}
impl From<Box<dyn StdError + Send + Sync>> for Error {
    /// Wraps a boxed error with the `500 Internal Server Error` status by
    /// delegating to the `(StatusCode, Box<..>)` conversion.
    fn from(err: Box<dyn StdError + Send + Sync>) -> Self {
        Self::from((StatusCode::INTERNAL_SERVER_ERROR, err))
    }
}
impl From<(StatusCode, Box<dyn StdError + Send + Sync>)> for Error {
    /// Wraps an already-boxed error together with an explicit status code.
    fn from(pair: (StatusCode, Box<dyn StdError + Send + Sync>)) -> Self {
        let (status, boxed) = pair;
        Self {
            as_response: AsResponse::from_status(status),
            source: ErrorSource::BoxedError(boxed),
        }
    }
}
#[cfg(feature = "anyhow")]
impl From<anyhow::Error> for Error {
    /// Wraps an `anyhow::Error` with the `500 Internal Server Error` status.
    fn from(err: anyhow::Error) -> Self {
        let as_response = AsResponse::from_status(StatusCode::INTERNAL_SERVER_ERROR);
        Self {
            as_response,
            source: ErrorSource::Anyhow(err),
        }
    }
}
#[cfg(feature = "anyhow")]
impl From<(StatusCode, anyhow::Error)> for Error {
    /// Wraps an `anyhow::Error` together with an explicit status code.
    fn from(pair: (StatusCode, anyhow::Error)) -> Self {
        let (status, any) = pair;
        Self {
            as_response: AsResponse::from_status(status),
            source: ErrorSource::Anyhow(any),
        }
    }
}
impl From<StatusCode> for Error {
fn from(status: StatusCode) -> Self {
Error::from_status(status)
}
}
impl Error {
    /// Creates an error from any error value and the status code to respond
    /// with. The response body is the error's `Display` text.
    #[inline]
    pub fn new<T: StdError + Send + Sync + 'static>(err: T, status: StatusCode) -> Self {
        let as_response = AsResponse::from_status(status);
        let source = ErrorSource::BoxedError(Box::new(err));
        Self {
            as_response,
            source,
        }
    }

    /// Creates an error from a bare status code.
    ///
    /// The wrapped error is a private type whose `Display` text is the
    /// status code's own `Display` text and whose
    /// [`ResponseError::status`] returns `status`.
    pub fn from_status(status: StatusCode) -> Self {
        #[derive(Debug, thiserror::Error)]
        #[error("{0}")]
        struct StatusError(StatusCode);

        impl ResponseError for StatusError {
            fn status(&self) -> StatusCode {
                self.0
            }
        }

        Self::from(StatusError(status))
    }

    /// Creates an error whose message is `msg` and whose response status is
    /// `status`.
    pub fn from_string(msg: impl Into<String>, status: StatusCode) -> Self {
        #[derive(Debug, thiserror::Error)]
        #[error("{0}")]
        struct StringError(String);

        let message: String = msg.into();
        Self::new(StringError(message), status)
    }

    /// Returns a reference to the wrapped error if it is of type `T`.
    #[inline]
    pub fn downcast_ref<T: StdError + Send + Sync + 'static>(&self) -> Option<&T> {
        match &self.source {
            ErrorSource::BoxedError(boxed) => boxed.downcast_ref::<T>(),
            #[cfg(feature = "anyhow")]
            ErrorSource::Anyhow(any) => any.downcast_ref::<T>(),
        }
    }

    /// Attempts to extract the wrapped error as a `T`.
    ///
    /// # Errors
    ///
    /// When the wrapped error is not a `T`, the original error is rebuilt
    /// (same response strategy, same source) and returned in `Err`.
    #[inline]
    pub fn downcast<T: StdError + Send + Sync + 'static>(self) -> Result<T, Error> {
        let Self {
            as_response,
            source,
        } = self;
        match source {
            ErrorSource::BoxedError(boxed) => boxed
                .downcast::<T>()
                .map(|hit| *hit)
                .map_err(|miss| Self {
                    as_response,
                    source: ErrorSource::BoxedError(miss),
                }),
            #[cfg(feature = "anyhow")]
            ErrorSource::Anyhow(any) => any.downcast::<T>().map_err(|miss| Self {
                as_response,
                source: ErrorSource::Anyhow(miss),
            }),
        }
    }

    /// Returns `true` if the wrapped error is of type `T`.
    #[inline]
    pub fn is<T: StdError + Debug + Send + Sync + 'static>(&self) -> bool {
        match &self.source {
            ErrorSource::BoxedError(boxed) => boxed.is::<T>(),
            #[cfg(feature = "anyhow")]
            ErrorSource::Anyhow(any) => any.is::<T>(),
        }
    }

    /// Builds the response for this error using its stored strategy.
    pub fn as_response(&self) -> Response {
        let strategy = &self.as_response;
        strategy.as_response(self)
    }
}
// One constructor function per status code. Each function wraps any error
// value into an `Error` that responds with the listed status.
define_http_error!(
// 4xx statuses.
/// Wraps an error into [`Error`] with status [`StatusCode::BAD_REQUEST`].
(BadRequest, BAD_REQUEST);
/// Wraps an error into [`Error`] with status [`StatusCode::UNAUTHORIZED`].
(Unauthorized, UNAUTHORIZED);
/// Wraps an error into [`Error`] with status [`StatusCode::PAYMENT_REQUIRED`].
(PaymentRequired, PAYMENT_REQUIRED);
/// Wraps an error into [`Error`] with status [`StatusCode::FORBIDDEN`].
(Forbidden, FORBIDDEN);
/// Wraps an error into [`Error`] with status [`StatusCode::NOT_FOUND`].
(NotFound, NOT_FOUND);
/// Wraps an error into [`Error`] with status [`StatusCode::METHOD_NOT_ALLOWED`].
(MethodNotAllowed, METHOD_NOT_ALLOWED);
/// Wraps an error into [`Error`] with status [`StatusCode::NOT_ACCEPTABLE`].
(NotAcceptable, NOT_ACCEPTABLE);
/// Wraps an error into [`Error`] with status [`StatusCode::PROXY_AUTHENTICATION_REQUIRED`].
(ProxyAuthenticationRequired, PROXY_AUTHENTICATION_REQUIRED);
/// Wraps an error into [`Error`] with status [`StatusCode::REQUEST_TIMEOUT`].
(RequestTimeout, REQUEST_TIMEOUT);
/// Wraps an error into [`Error`] with status [`StatusCode::CONFLICT`].
(Conflict, CONFLICT);
/// Wraps an error into [`Error`] with status [`StatusCode::GONE`].
(Gone, GONE);
/// Wraps an error into [`Error`] with status [`StatusCode::LENGTH_REQUIRED`].
(LengthRequired, LENGTH_REQUIRED);
/// Wraps an error into [`Error`] with status [`StatusCode::PAYLOAD_TOO_LARGE`].
(PayloadTooLarge, PAYLOAD_TOO_LARGE);
/// Wraps an error into [`Error`] with status [`StatusCode::URI_TOO_LONG`].
(UriTooLong, URI_TOO_LONG);
/// Wraps an error into [`Error`] with status [`StatusCode::UNSUPPORTED_MEDIA_TYPE`].
(UnsupportedMediaType, UNSUPPORTED_MEDIA_TYPE);
/// Wraps an error into [`Error`] with status [`StatusCode::RANGE_NOT_SATISFIABLE`].
(RangeNotSatisfiable, RANGE_NOT_SATISFIABLE);
/// Wraps an error into [`Error`] with status [`StatusCode::IM_A_TEAPOT`].
(ImATeapot, IM_A_TEAPOT);
/// Wraps an error into [`Error`] with status [`StatusCode::MISDIRECTED_REQUEST`].
(MisdirectedRequest, MISDIRECTED_REQUEST);
/// Wraps an error into [`Error`] with status [`StatusCode::UNPROCESSABLE_ENTITY`].
(UnprocessableEntity, UNPROCESSABLE_ENTITY);
/// Wraps an error into [`Error`] with status [`StatusCode::LOCKED`].
(Locked, LOCKED);
/// Wraps an error into [`Error`] with status [`StatusCode::FAILED_DEPENDENCY`].
(FailedDependency, FAILED_DEPENDENCY);
/// Wraps an error into [`Error`] with status [`StatusCode::UPGRADE_REQUIRED`].
(UpgradeRequired, UPGRADE_REQUIRED);
/// Wraps an error into [`Error`] with status [`StatusCode::PRECONDITION_FAILED`].
(PreconditionFailed, PRECONDITION_FAILED);
/// Wraps an error into [`Error`] with status [`StatusCode::PRECONDITION_REQUIRED`].
(PreconditionRequired, PRECONDITION_REQUIRED);
/// Wraps an error into [`Error`] with status [`StatusCode::TOO_MANY_REQUESTS`].
(TooManyRequests, TOO_MANY_REQUESTS);
/// Wraps an error into [`Error`] with status [`StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE`].
(RequestHeaderFieldsTooLarge, REQUEST_HEADER_FIELDS_TOO_LARGE);
/// Wraps an error into [`Error`] with status [`StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS`].
(UnavailableForLegalReasons, UNAVAILABLE_FOR_LEGAL_REASONS);
/// Wraps an error into [`Error`] with status [`StatusCode::EXPECTATION_FAILED`].
(ExpectationFailed, EXPECTATION_FAILED);
// 5xx statuses.
/// Wraps an error into [`Error`] with status [`StatusCode::INTERNAL_SERVER_ERROR`].
(InternalServerError, INTERNAL_SERVER_ERROR);
/// Wraps an error into [`Error`] with status [`StatusCode::NOT_IMPLEMENTED`].
(NotImplemented, NOT_IMPLEMENTED);
/// Wraps an error into [`Error`] with status [`StatusCode::BAD_GATEWAY`].
(BadGateway, BAD_GATEWAY);
/// Wraps an error into [`Error`] with status [`StatusCode::SERVICE_UNAVAILABLE`].
(ServiceUnavailable, SERVICE_UNAVAILABLE);
/// Wraps an error into [`Error`] with status [`StatusCode::GATEWAY_TIMEOUT`].
(GatewayTimeout, GATEWAY_TIMEOUT);
/// Wraps an error into [`Error`] with status [`StatusCode::HTTP_VERSION_NOT_SUPPORTED`].
(HttpVersionNotSupported, HTTP_VERSION_NOT_SUPPORTED);
/// Wraps an error into [`Error`] with status [`StatusCode::VARIANT_ALSO_NEGOTIATES`].
(VariantAlsoNegotiates, VARIANT_ALSO_NEGOTIATES);
/// Wraps an error into [`Error`] with status [`StatusCode::INSUFFICIENT_STORAGE`].
(InsufficientStorage, INSUFFICIENT_STORAGE);
/// Wraps an error into [`Error`] with status [`StatusCode::LOOP_DETECTED`].
(LoopDetected, LOOP_DETECTED);
/// Wraps an error into [`Error`] with status [`StatusCode::NOT_EXTENDED`].
(NotExtended, NOT_EXTENDED);
/// Wraps an error into [`Error`] with status [`StatusCode::NETWORK_AUTHENTICATION_REQUIRED`].
(NetworkAuthenticationRequired, NETWORK_AUTHENTICATION_REQUIRED);
);
/// Alias for the standard `Result` whose error type defaults to [`Error`].
pub type Result<T, E = Error> = ::std::result::Result<T, E>;
/// Conversion of a value into a [`Result<T>`] whose `Ok` type implements
/// [`IntoResponse`].
pub trait IntoResult<T: IntoResponse> {
    /// Consumes `self` and produces a [`Result<T>`].
    fn into_result(self) -> Result<T>;
}

impl<T, E> IntoResult<T> for Result<T, E>
where
    T: IntoResponse,
    E: Into<Error> + Debug + Send + Sync + 'static,
{
    /// Keeps an `Ok` value as is and converts an `Err` value into [`Error`].
    #[inline]
    fn into_result(self) -> Result<T> {
        match self {
            Ok(value) => Ok(value),
            Err(err) => Err(err.into()),
        }
    }
}

impl<T: IntoResponse> IntoResult<T> for T {
    /// A plain response value is always a success.
    #[inline]
    fn into_result(self) -> Result<T> {
        Result::Ok(self)
    }
}
// Generates one unit-struct error type per `(Name, STATUS, "message")` triple.
//
// Each generated type displays as the given message and implements
// `ResponseError` by returning the `StatusCode` associated constant named by
// the second element of the triple.
macro_rules! define_simple_errors {
    ($($(#[$attr:meta])* ($ty:ident, $code:ident, $text:literal);)*) => {
        $(
            $(#[$attr])*
            #[derive(Debug, thiserror::Error, Copy, Clone, Eq, PartialEq)]
            #[error($text)]
            pub struct $ty;

            impl ResponseError for $ty {
                fn status(&self) -> StatusCode {
                    StatusCode::$code
                }
            }
        )*
    };
}
define_simple_errors!(
/// Error with the message `invalid path params`; status `BAD_REQUEST`.
(ParsePathError, BAD_REQUEST, "invalid path params");
/// Error with the message `not found`; status `NOT_FOUND`.
(NotFoundError, NOT_FOUND, "not found");
/// Error with the message `method not allowed`; status `METHOD_NOT_ALLOWED`.
(MethodNotAllowedError, METHOD_NOT_ALLOWED, "method not allowed");
/// Error with the message `unauthorized`; status `UNAUTHORIZED`.
(CorsError, UNAUTHORIZED, "unauthorized");
);
/// Errors reported while reading a body.
#[derive(Debug, thiserror::Error)]
pub enum ReadBodyError {
    /// The body has already been taken.
    #[error("the body has been taken")]
    BodyHasBeenTaken,

    /// The bytes are not valid UTF-8.
    #[error("parse utf8: {0}")]
    Utf8(#[from] FromUtf8Error),

    /// An I/O error.
    #[error("io: {0}")]
    Io(#[from] std::io::Error),
}

impl ResponseError for ReadBodyError {
    /// `BodyHasBeenTaken` maps to 500; the other variants map to 400.
    fn status(&self) -> StatusCode {
        match self {
            Self::BodyHasBeenTaken => StatusCode::INTERNAL_SERVER_ERROR,
            Self::Utf8(_) | Self::Io(_) => StatusCode::BAD_REQUEST,
        }
    }
}
/// Cookie parsing errors; only compiled with the `cookie` feature.
#[cfg(feature = "cookie")]
#[cfg_attr(docsrs, doc(cfg(feature = "cookie")))]
#[derive(Debug, thiserror::Error)]
pub enum ParseCookieError {
/// The cookie is illegal.
#[error("cookie is illegal")]
CookieIllegal,
/// The `Cookie` header is missing.
#[error("`Cookie` header is required")]
CookieHeaderRequired,
/// A `serde_json` error, reported as an illegal cookie.
#[error("cookie is illegal: {0}")]
ParseJsonValue(#[from] serde_json::Error),
}
#[cfg(feature = "cookie")]
impl ResponseError for ParseCookieError {
/// Every variant maps to `400 Bad Request`.
fn status(&self) -> StatusCode {
StatusCode::BAD_REQUEST
}
}
/// Error reporting that data of some type was not found.
///
/// The wrapped string is interpolated into the message as the type name.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
#[error("data of type `{0}` was not found.")]
pub struct GetDataError(pub &'static str);
impl ResponseError for GetDataError {
/// Always maps to `500 Internal Server Error`.
fn status(&self) -> StatusCode {
StatusCode::INTERNAL_SERVER_ERROR
}
}
/// Errors reported while parsing `application/x-www-form-urlencoded` data.
#[derive(Debug, thiserror::Error)]
pub enum ParseFormError {
    /// The content type is not `application/x-www-form-urlencoded`; carries
    /// the string shown in the message as the received content type.
    #[error("invalid content type `{0}`, expect: `application/x-www-form-urlencoded`")]
    InvalidContentType(String),

    /// No content type was provided.
    #[error("expect content type `application/x-www-form-urlencoded`")]
    ContentTypeRequired,

    /// URL decoding failed.
    #[error("url decode: {0}")]
    UrlDecode(#[from] serde_urlencoded::de::Error),
}

impl ResponseError for ParseFormError {
    /// Every variant maps to `400 Bad Request`.
    fn status(&self) -> StatusCode {
        // Kept as an exhaustive match so that a new variant forces a decision.
        match self {
            Self::InvalidContentType(_) | Self::ContentTypeRequired | Self::UrlDecode(_) => {
                StatusCode::BAD_REQUEST
            }
        }
    }
}
/// JSON parsing error wrapping a `serde_json::Error`.
#[derive(Debug, thiserror::Error)]
#[error("parse: {0}")]
pub struct ParseJsonError(#[from] pub serde_json::Error);
impl ResponseError for ParseJsonError {
/// Always maps to `400 Bad Request`.
fn status(&self) -> StatusCode {
StatusCode::BAD_REQUEST
}
}
/// Query parsing error wrapping a `serde_urlencoded` deserialization error.
///
/// The error is `transparent`: its message is that of the wrapped error.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ParseQueryError(#[from] pub serde_urlencoded::de::Error);
impl ResponseError for ParseQueryError {
/// Always maps to `400 Bad Request`.
fn status(&self) -> StatusCode {
StatusCode::BAD_REQUEST
}
}
/// Errors reported while parsing `multipart/form-data`; only compiled with
/// the `multipart` feature.
#[cfg(feature = "multipart")]
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
#[derive(Debug, thiserror::Error)]
pub enum ParseMultipartError {
    /// The content type is not `multipart/form-data`; carries the string
    /// shown in the message as the received content type.
    #[error("invalid content type `{0}`, expect: `multipart/form-data`")]
    InvalidContentType(String),

    /// No content type was provided.
    #[error("expect content type `multipart/form-data`")]
    ContentTypeRequired,

    /// An error from the `multer` parser.
    #[error("parse: {0}")]
    Multipart(#[from] multer::Error),
}

#[cfg(feature = "multipart")]
impl ResponseError for ParseMultipartError {
    /// Every variant maps to `400 Bad Request`.
    fn status(&self) -> StatusCode {
        // Kept as an exhaustive match so that a new variant forces a decision.
        match self {
            Self::InvalidContentType(_) | Self::ContentTypeRequired | Self::Multipart(_) => {
                StatusCode::BAD_REQUEST
            }
        }
    }
}
/// Errors reported while parsing a typed header.
#[derive(Debug, thiserror::Error)]
pub enum ParseTypedHeaderError {
/// A required header is missing; the string is shown in the message as the header name.
#[error("header `{0}` is required")]
HeaderRequired(String),
/// An error from the `headers` crate.
#[error("parse: {0}")]
TypedHeader(#[from] headers::Error),
}
impl ResponseError for ParseTypedHeaderError {
/// Every variant maps to `400 Bad Request`.
fn status(&self) -> StatusCode {
StatusCode::BAD_REQUEST
}
}
/// WebSocket errors; only compiled with the `websocket` feature.
#[cfg(feature = "websocket")]
#[cfg_attr(docsrs, doc(cfg(feature = "websocket")))]
#[derive(Debug, thiserror::Error)]
pub enum WebSocketError {
    /// The protocol is invalid.
    #[error("invalid protocol")]
    InvalidProtocol,

    /// An upgrade error; its message is shown unchanged.
    #[error(transparent)]
    UpgradeError(#[from] UpgradeError),
}

#[cfg(feature = "websocket")]
impl ResponseError for WebSocketError {
    /// `InvalidProtocol` maps to 400; an upgrade error reports its own status.
    fn status(&self) -> StatusCode {
        match self {
            Self::InvalidProtocol => StatusCode::BAD_REQUEST,
            Self::UpgradeError(upgrade) => upgrade.status(),
        }
    }
}
/// Errors related to connection upgrades.
#[derive(Debug, thiserror::Error)]
pub enum UpgradeError {
    /// No upgrade is available.
    #[error("no upgrade")]
    NoUpgrade,

    /// Any other failure, described by the contained message.
    #[error("{0}")]
    Other(String),
}

impl ResponseError for UpgradeError {
    /// `NoUpgrade` maps to 500; `Other` maps to 400.
    fn status(&self) -> StatusCode {
        match self {
            Self::NoUpgrade => StatusCode::INTERNAL_SERVER_ERROR,
            Self::Other(_) => StatusCode::BAD_REQUEST,
        }
    }
}
#[derive(Debug, thiserror::Error)]
pub enum StaticFileError {
#[error("method not found")]
MethodNotAllowed(Method),
#[error("invalid path")]
InvalidPath,
#[error("forbidden: {0}")]
Forbidden(String),
#[error("not found: {0}")]
NotFound(String),
#[error("precondition failed")]
PreconditionFailed,
#[error("range not satisfiable")]
RangeNotSatisfiable {
size: u64,
},
#[error("io: {0}")]
Io(#[from] std::io::Error),
}
impl ResponseError for StaticFileError {
fn status(&self) -> StatusCode {
match self {
StaticFileError::MethodNotAllowed(_) => StatusCode::METHOD_NOT_ALLOWED,
StaticFileError::InvalidPath => StatusCode::BAD_REQUEST,
StaticFileError::Forbidden(_) => StatusCode::FORBIDDEN,
StaticFileError::NotFound(_) => StatusCode::NOT_FOUND,
StaticFileError::PreconditionFailed => StatusCode::PRECONDITION_FAILED,
StaticFileError::RangeNotSatisfiable { .. } => StatusCode::RANGE_NOT_SATISFIABLE,
StaticFileError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
fn as_response(&self) -> Response {
let mut resp = Response::builder()
.status(self.status())
.body(self.to_string());
if let StaticFileError::RangeNotSatisfiable { size } = self {
resp.headers_mut()
.typed_insert(ContentRange::unsatisfied_bytes(*size));
}
resp
}
}
/// Errors reported when a size limit is applied to a request.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
pub enum SizedLimitError {
    /// The `Content-Length` header is missing.
    #[error("missing `Content-Length` header")]
    MissingContentLength,

    /// The payload is too large.
    #[error("payload too large")]
    PayloadTooLarge,
}

impl ResponseError for SizedLimitError {
    /// Maps each variant to its HTTP status code.
    ///
    /// `PayloadTooLarge` maps to `413 Payload Too Large`, the status HTTP
    /// defines for this condition; it previously fell back to a generic 400.
    /// `MissingContentLength` still maps to `400 Bad Request`.
    fn status(&self) -> StatusCode {
        match self {
            Self::MissingContentLength => StatusCode::BAD_REQUEST,
            Self::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
        }
    }
}
/// Routing errors.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
pub enum RouteError {
/// The path is invalid; the string is shown in the message.
#[error("invalid path: {0}")]
InvalidPath(String),
/// The path is a duplicate; the string is shown in the message.
#[error("duplicate path: {0}")]
Duplicate(String),
/// The path contains an invalid regex. Only `path` appears in the message.
#[error("invalid regex in path: {path}")]
InvalidRegex {
/// The path shown in the message.
path: String,
/// The regex text; not included in the message.
regex: String,
},
}
impl ResponseError for RouteError {
/// Every variant maps to `500 Internal Server Error`.
fn status(&self) -> StatusCode {
StatusCode::INTERNAL_SERVER_ERROR
}
}
/// Internationalization errors; only compiled with the `i18n` feature.
#[cfg(feature = "i18n")]
#[derive(Debug, thiserror::Error)]
pub enum I18NError {
/// Errors from `fluent`. Only the first element is shown in the message.
// NOTE(review): `.0[0]` indexes the vector, so formatting this variant
// panics if the vector is empty — confirm callers never build it empty.
#[error("fluent: {}", .0[0])]
Fluent(Vec<fluent::FluentError>),
/// Errors from the `fluent_syntax` parser. Only the first element is shown in the message.
// NOTE(review): same indexing caveat as `Fluent` for an empty vector.
#[error("fluent parser: {}", .0[0])]
FluentParser(Vec<fluent_syntax::parser::ParserError>),
/// No value is available.
#[error("no value")]
FluentNoValue,
/// A message was not found.
#[error("msg not found: `{id}`")]
FluentMessageNotFound {
/// The message id shown in the error text.
id: String,
},
/// A language identifier is invalid.
#[error("invalid language id: {0}")]
LanguageIdentifier(#[from] unic_langid::LanguageIdentifierError),
/// An I/O error.
#[error("io: {0}")]
Io(#[from] std::io::Error),
}
#[cfg(feature = "i18n")]
impl ResponseError for I18NError {
/// Every variant maps to `500 Internal Server Error`.
fn status(&self) -> StatusCode {
StatusCode::INTERNAL_SERVER_ERROR
}
}
// Unit tests for `Error`, its conversions and `IntoResult`.
#[cfg(test)]
mod tests {
use std::io::{Error as IoError, ErrorKind};
use super::*;
// `into_result` keeps plain values and `Ok` results, and converts an `Err`
// holding a `ResponseError` type into an `Error` wrapping that same type.
#[test]
fn test_into_result() {
assert!(matches!("hello".into_result(), Ok("hello")));
assert!(matches!(Ok::<_, Error>("hello").into_result(), Ok("hello")));
assert!(matches!(
Ok::<_, NotFoundError>("hello").into_result(),
Ok("hello")
));
assert!(Err::<String, _>(NotFoundError)
.into_result()
.unwrap_err()
.is::<NotFoundError>());
}
// `Error::new` keeps the wrapped error downcastable and responds with the
// status that was passed in.
#[test]
fn test_error() {
let err = Error::new(
IoError::new(ErrorKind::AlreadyExists, "aaa"),
StatusCode::BAD_GATEWAY,
);
assert!(err.is::<IoError>());
assert_eq!(
err.downcast_ref::<IoError>().unwrap().kind(),
ErrorKind::AlreadyExists
);
assert_eq!(err.as_response().status(), StatusCode::BAD_GATEWAY);
}
// Same checks for the `(StatusCode, Box<dyn Error>)` conversion.
#[test]
fn test_box_error() {
let boxed_err: Box<dyn StdError + Send + Sync> =
Box::new(IoError::new(ErrorKind::AlreadyExists, "aaa"));
let err: Error = Error::from((StatusCode::BAD_GATEWAY, boxed_err));
assert!(err.is::<IoError>());
assert_eq!(
err.downcast_ref::<IoError>().unwrap().kind(),
ErrorKind::AlreadyExists
);
assert_eq!(err.as_response().status(), StatusCode::BAD_GATEWAY);
}
// Same checks for the `(StatusCode, anyhow::Error)` conversion.
#[cfg(feature = "anyhow")]
#[test]
fn test_anyhow_error() {
let anyhow_err: anyhow::Error = IoError::new(ErrorKind::AlreadyExists, "aaa").into();
let err: Error = Error::from((StatusCode::BAD_GATEWAY, anyhow_err));
assert!(err.is::<IoError>());
assert_eq!(
err.downcast_ref::<IoError>().unwrap().kind(),
ErrorKind::AlreadyExists
);
assert_eq!(err.as_response().status(), StatusCode::BAD_GATEWAY);
}
// A type that overrides `ResponseError::as_response` has that override used
// when it is converted into `Error`.
#[tokio::test]
async fn test_custom_as_response() {
#[derive(Debug, thiserror::Error)]
#[error("my error")]
struct MyError;
impl ResponseError for MyError {
fn status(&self) -> StatusCode {
StatusCode::BAD_GATEWAY
}
fn as_response(&self) -> Response {
Response::builder()
.status(self.status())
.body("my error message")
}
}
let err = Error::from(MyError);
let resp = err.as_response();
assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
assert_eq!(
resp.into_body().into_string().await.unwrap(),
"my error message"
);
}
}