use super::Error;
use crate::{
HttpResult,
http::{FromRequestParts, IntoResponse, MapErr},
status,
};
use futures_util::future::BoxFuture;
use hyper::Uri;
use hyper::http::request::Parts;
use std::sync::Arc;
/// Type-erased arguments captured from one request for an error handler.
///
/// `call` consumes the boxed arguments, so the returned future owns everything
/// it needs and can be `'static`.
pub trait ErasedErrorArgs: Send {
    /// Consumes the captured arguments and handles `err`.
    fn call(self: Box<Self>, err: Error) -> BoxFuture<'static, HttpResult>;
}

/// An error handler that can be shared across requests.
///
/// The handler captures what it needs from a request head via
/// [`ErrorHandler::extract`]. The captured [`ErasedErrorArgs`] are later invoked
/// with the error.
pub trait ErrorHandler: Send + Sync {
    /// Whether [`ErrorHandler::extract`] has to be called for this handler.
    ///
    /// Defaults to `true`. A handler returning `false` signals that the request
    /// URI alone is sufficient.
    #[inline]
    fn needs_parts_extraction(&self) -> bool {
        true
    }

    /// Captures this handler's arguments from the request head `parts`.
    fn extract(&self, parts: &Parts) -> Box<dyn ErasedErrorArgs + Send>;
}
pub(crate) enum ErrorArgsSlot {
Uri(Uri),
Custom(Box<dyn ErasedErrorArgs + Send>),
}
impl ErrorArgsSlot {
#[inline]
pub(crate) async fn call(self, err: Error) -> HttpResult {
match self {
Self::Uri(uri) => {
let mut err = err;
if err.instance.is_none() {
err.instance = Some(uri.to_string());
}
default_error_handler(err).await
}
Self::Custom(args) => args.call(err).await,
}
}
}
/// The built-in error handler; the only request data it uses is the URI.
#[derive(Debug)]
pub(crate) struct DefaultErrorHandler;

impl ErrorHandler for DefaultErrorHandler {
    /// Extraction is unnecessary: the request URI is all this handler uses.
    #[inline]
    fn needs_parts_extraction(&self) -> bool {
        false
    }

    /// Captures just the request URI, for callers that invoke `extract` directly.
    #[inline]
    fn extract(&self, parts: &Parts) -> Box<dyn ErasedErrorArgs + Send> {
        let uri = parts.uri.clone();
        Box::new(DefaultErrorArgs { uri })
    }
}
/// An error handler built from a user-supplied function `F`.
///
/// `Args` names the values the function wants extracted from the request head,
/// and `R` is its output type. The struct itself carries no trait bounds. The
/// `impl` blocks that need `MapErr` / `IntoResponse` / `FromRequestParts` state
/// them, so the type (and its derived `Debug`) are not over-constrained.
#[derive(Debug)]
pub struct ErrorFunc<F, R, Args> {
    func: F,
    // Records `Args`/`R` without storing them. A `fn` pointer type is always
    // `Send + Sync`, so these markers never affect the struct's auto traits.
    _marker: std::marker::PhantomData<fn(Args) -> R>,
}
impl<F, R, Args> ErrorFunc<F, R, Args>
where
    F: MapErr<Args, Output = R>,
    R: IntoResponse,
    Args: FromRequestParts + Send,
{
    /// Wraps `func` so it can be installed as an error handler.
    pub(crate) fn new(func: F) -> Self {
        let marker = std::marker::PhantomData;
        Self {
            _marker: marker,
            func,
        }
    }
}
struct BoundErrorArgs<F, Args> {
func: F,
args: Args,
uri: Uri,
}
impl<F, Args> ErasedErrorArgs for BoundErrorArgs<F, Args>
where
F: MapErr<Args> + Send + 'static,
F::Output: IntoResponse + 'static,
Args: Send + 'static,
{
fn call(self: Box<Self>, mut err: Error) -> BoxFuture<'static, HttpResult> {
Box::pin(async move {
if err.instance.is_none() {
err.instance = Some(self.uri.to_string());
}
match self.func.map_err(err, self.args).await.into_response() {
Ok(resp) => Ok(resp),
Err(err) => default_error_handler(err).await,
}
})
}
}
/// Arguments for the built-in handler: only the request URI is kept.
struct DefaultErrorArgs {
    uri: Uri,
}

impl ErasedErrorArgs for DefaultErrorArgs {
    /// Fills in a missing `instance` with the request URI, then renders the
    /// error with [`default_error_handler`].
    fn call(self: Box<Self>, mut err: Error) -> BoxFuture<'static, HttpResult> {
        let Self { uri } = *self;
        Box::pin(async move {
            err.instance.get_or_insert_with(|| uri.to_string());
            default_error_handler(err).await
        })
    }
}
impl<F, R, Args> ErrorHandler for ErrorFunc<F, R, Args>
where
    F: MapErr<Args, Output = R> + Clone + 'static,
    R: IntoResponse + 'static,
    Args: FromRequestParts + Send + 'static,
{
    /// Extracts `Args` from the request head and binds them to a clone of the
    /// user function.
    ///
    /// Extraction is best-effort. If `Args::from_parts` fails, the extraction
    /// error is discarded and only the request URI is captured. That request's
    /// errors then go through the default handler instead of the user function.
    #[inline]
    fn extract(&self, parts: &Parts) -> Box<dyn ErasedErrorArgs + Send> {
        let uri = parts.uri.clone();
        if let Ok(args) = Args::from_parts(parts) {
            let func = self.func.clone();
            return Box::new(BoundErrorArgs { func, args, uri });
        }
        Box::new(DefaultErrorArgs { uri })
    }
}
/// Shared, type-erased error handler.
pub(crate) type PipelineErrorHandler = Arc<dyn ErrorHandler + Send + Sync>;

impl<F, R, Args> From<ErrorFunc<F, R, Args>> for PipelineErrorHandler
where
    F: MapErr<Args, Output = R>,
    R: IntoResponse + 'static,
    Args: FromRequestParts + Send + 'static,
{
    /// Moves an [`ErrorFunc`] behind a shared `Arc<dyn ErrorHandler>`.
    #[inline]
    fn from(func: ErrorFunc<F, R, Args>) -> Self {
        Arc::new(func) as Self
    }
}
/// Fallback error handler used when no custom handler applies.
///
/// Builds the response with `status!`, using the error's own status code
/// (`err.status`) and the error's `Display` output (`"{err}"`) as the message.
#[inline]
pub(crate) async fn default_error_handler(err: Error) -> HttpResult {
    status!(err.status.as_u16(), "{err}")
}
/// Captures what `handler` will need to handle an error for this request.
///
/// Handlers whose [`ErrorHandler::needs_parts_extraction`] returns `false` only
/// get the request URI stored, and [`ErrorHandler::extract`] is not called.
#[inline]
pub(crate) fn extract_error_args(handler: &PipelineErrorHandler, parts: &Parts) -> ErrorArgsSlot {
    if !handler.needs_parts_extraction() {
        return ErrorArgsSlot::Uri(parts.uri.clone());
    }
    ErrorArgsSlot::Custom(handler.extract(parts))
}
#[cfg(test)]
mod tests {
    use super::{
        DefaultErrorHandler, Error, ErrorArgsSlot, ErrorFunc, PipelineErrorHandler,
        default_error_handler, extract_error_args,
    };
    use crate::{error::ErrorHandler, status};
    use http_body_util::BodyExt;
    use hyper::{Request, http::request::Parts};
    use std::sync::Arc;

    /// Request head of a body-less `GET` to `uri`.
    fn get_parts(uri: &str) -> Parts {
        let (parts, _body) = Request::get(uri).body(()).unwrap().into_parts();
        parts
    }

    #[tokio::test]
    async fn default_error_handler_returns_server_error_status_code() {
        let mut response = default_error_handler(Error::server_error("Some error"))
            .await
            .expect("default handler should build a response");
        let body = response.body_mut().collect().await.unwrap().to_bytes();
        assert_eq!(response.status(), 500);
        assert_eq!(String::from_utf8_lossy(&body), "Some error");
    }

    #[tokio::test]
    async fn default_error_handler_returns_client_error_status_code() {
        let mut response = default_error_handler(Error::client_error("Some error"))
            .await
            .expect("default handler should build a response");
        let body = response.body_mut().collect().await.unwrap().to_bytes();
        assert_eq!(response.status(), 400);
        assert_eq!(String::from_utf8_lossy(&body), "Some error");
    }

    #[tokio::test]
    async fn it_create_new_error_handler() {
        let fallback = |_: Error| async { status!(403) };
        let handler = ErrorFunc::new(fallback);
        let parts = get_parts("/foo/bar?baz");
        let mut response = handler
            .extract(&parts)
            .call(Error::server_error("Some error"))
            .await
            .expect("custom handler should build a response");
        let body = response.body_mut().collect().await.unwrap().to_bytes();
        assert_eq!(response.status(), 403);
        assert!(body.is_empty());
    }

    #[tokio::test]
    async fn it_calls_error_handler_via_extract_error_args() {
        let fallback = |_: Error| async { status!(403) };
        let handler = PipelineErrorHandler::from(ErrorFunc::new(fallback));
        let slot = extract_error_args(&handler, &get_parts("/foo/bar?baz"));
        assert!(matches!(slot, ErrorArgsSlot::Custom(_)));
        let mut response = slot
            .call(Error::server_error("Some error"))
            .await
            .expect("custom handler should build a response");
        let body = response.body_mut().collect().await.unwrap().to_bytes();
        assert_eq!(response.status(), 403);
        assert!(body.is_empty());
    }

    // Nothing here is awaited, so no async runtime is needed.
    #[test]
    fn it_returns_uri_slot_for_default_handler() {
        let handler: PipelineErrorHandler = Arc::new(DefaultErrorHandler);
        let slot = extract_error_args(&handler, &get_parts("/foo/bar"));
        assert!(matches!(slot, ErrorArgsSlot::Uri(_)));
    }

    #[tokio::test]
    async fn it_calls_default_handler_via_uri_slot() {
        let handler: PipelineErrorHandler = Arc::new(DefaultErrorHandler);
        let slot = extract_error_args(&handler, &get_parts("/foo/bar"));
        let response = slot
            .call(Error::server_error("Some error"))
            .await
            .expect("default handler should build a response");
        assert_eq!(response.status(), 500);
    }
}