use crate::service::web::response::{ErrorResponse, IntoResponse};
use crate::{Request, Response};
use rama_core::{Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::convert::Infallible;
/// [`Layer`] that wraps a service in an [`ErrorHandler`].
///
/// By default (`F = ()`) the produced handler converts the wrapped service's
/// errors via `Into<ErrorResponse>`; call [`ErrorHandlerLayer::error_mapper`]
/// to map them with a custom function instead.
#[derive(Debug, Clone)]
pub struct ErrorHandlerLayer<F = ()> {
    /// `()` for the default conversion, or the user-supplied error mapper.
    error_mapper: F,
}

impl Default for ErrorHandlerLayer {
    fn default() -> Self {
        Self::new()
    }
}

impl ErrorHandlerLayer {
    /// Creates a layer that uses the default error conversion.
    #[must_use]
    pub const fn new() -> Self {
        Self { error_mapper: () }
    }

    /// Sets the function used to map the inner service's errors into a response.
    ///
    /// This consumes the layer and returns one with a different type parameter,
    /// so ignoring the result discards the configuration.
    #[must_use]
    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandlerLayer<F> {
        ErrorHandlerLayer { error_mapper }
    }
}
impl<S, F: Clone> Layer<S> for ErrorHandlerLayer<F> {
type Service = ErrorHandler<S, F>;
fn layer(&self, inner: S) -> Self::Service {
ErrorHandler::new(inner).error_mapper(self.error_mapper.clone())
}
fn into_layer(self, inner: S) -> Self::Service {
ErrorHandler::new(inner).error_mapper(self.error_mapper)
}
}
/// Middleware service that renders both the output and the error of an inner
/// service as a `Response`, so that the handler itself never fails.
///
/// With the default `F = ()`, inner errors are converted via
/// `Into<ErrorResponse>`. After [`ErrorHandler::error_mapper`] is called, the
/// configured function maps inner errors instead.
#[derive(Debug)]
#[derive(Clone)]
pub struct ErrorHandler<S, F = ()> {
    /// The wrapped service whose result gets rendered.
    inner: S,
    /// `()` for the default conversion, or a user-supplied error mapper.
    error_mapper: F,
}
impl<S> ErrorHandler<S> {
    /// Wraps `inner` in an [`ErrorHandler`] that uses the default error
    /// conversion: when served, inner errors go through `Into<ErrorResponse>`.
    #[must_use]
    pub const fn new(inner: S) -> Self {
        Self {
            inner,
            error_mapper: (),
        }
    }

    define_inner_service_accessors!();

    /// Replaces the default error conversion with `error_mapper`, which is
    /// called with the inner service's error to produce the response.
    ///
    /// This consumes `self` and returns a handler with a different type
    /// parameter, so ignoring the result discards the configuration.
    #[must_use]
    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandler<S, F> {
        ErrorHandler {
            inner: self.inner,
            error_mapper,
        }
    }
}
/// Default error handling: the inner output is rendered with [`IntoResponse`],
/// and an inner error is first converted into an [`ErrorResponse`], which is
/// then rendered. Serving therefore never fails (`Error = Infallible`).
impl<S, Body> Service<Request<Body>> for ErrorHandler<S, ()>
where
    S: Service<Request<Body>, Output: IntoResponse, Error: Into<ErrorResponse>>,
    Body: Send + 'static,
{
    type Output = Response;
    type Error = Infallible;

    async fn serve(&self, req: Request<Body>) -> Result<Self::Output, Self::Error> {
        let rendered = self.inner.serve(req).await.map_or_else(
            |err| {
                let err: ErrorResponse = err.into();
                err.into_response()
            },
            |output| output.into_response(),
        );
        Ok(rendered)
    }
}
/// Custom error handling: the inner output is rendered with [`IntoResponse`],
/// and an inner error is passed to the configured mapper, whose result is then
/// rendered. Serving therefore never fails (`Error = Infallible`).
///
/// The mapper is only invoked through `&self`, so `F` does not need to be
/// `Clone` here. Only the borrowing `Layer::layer` path has to clone it.
impl<S, F, R, Body> Service<Request<Body>> for ErrorHandler<S, F>
where
    S: Service<Request<Body>, Output: IntoResponse>,
    F: Fn(S::Error) -> R + Send + Sync + 'static,
    R: IntoResponse + 'static,
    Body: Send + 'static,
{
    type Output = Response;
    type Error = Infallible;

    async fn serve(&self, req: Request<Body>) -> Result<Self::Output, Self::Error> {
        match self.inner.serve(req).await {
            Ok(response) => Ok(response.into_response()),
            Err(error) => Ok((self.error_mapper)(error).into_response()),
        }
    }
}