use crate::service::web::response::IntoResponse;
use crate::{Request, Response};
use rama_core::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::{convert::Infallible, fmt};
/// A [`Layer`] that wraps services in an [`ErrorHandler`].
///
/// `F` is the error mapper handed to every produced [`ErrorHandler`].
/// The default, `()`, means "no custom mapper". In that case the inner
/// service's error type must implement [`IntoResponse`] itself.
pub struct ErrorHandlerLayer<F = ()> {
    /// Mapper copied (via `Clone`) or moved into each wrapped service.
    error_mapper: F,
}
/// Builds a layer with no custom error mapper (the unit mapper `()`),
/// exactly like `ErrorHandlerLayer::new`.
impl Default for ErrorHandlerLayer {
    fn default() -> Self {
        Self { error_mapper: () }
    }
}
/// Formats the layer as `ErrorHandlerLayer { error_mapper: .. }`.
impl<F> fmt::Debug for ErrorHandlerLayer<F>
where
    F: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { error_mapper } = self;
        f.debug_struct("ErrorHandlerLayer")
            .field("error_mapper", error_mapper)
            .finish()
    }
}
/// Clones the layer by cloning its error mapper.
impl<F> Clone for ErrorHandlerLayer<F>
where
    F: Clone,
{
    fn clone(&self) -> Self {
        let error_mapper = self.error_mapper.clone();
        Self { error_mapper }
    }
}
impl ErrorHandlerLayer {
    /// Creates a layer without a custom error mapper.
    ///
    /// The produced handlers convert both the inner service's responses and
    /// its errors into responses through [`IntoResponse`].
    pub const fn new() -> Self {
        ErrorHandlerLayer { error_mapper: () }
    }

    /// Consumes this layer and returns one that installs `error_mapper`
    /// in every [`ErrorHandler`] it produces.
    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandlerLayer<F> {
        // The current mapper is the unit placeholder; nothing to carry over.
        let Self { error_mapper: () } = self;
        ErrorHandlerLayer { error_mapper }
    }
}
/// Wraps a service `S` in an [`ErrorHandler`] carrying this layer's mapper.
///
/// `layer` clones the mapper so the layer stays reusable, while `into_layer`
/// moves the mapper out of the consumed layer.
impl<S, F> Layer<S> for ErrorHandlerLayer<F>
where
    F: Clone,
{
    type Service = ErrorHandler<S, F>;

    fn layer(&self, inner: S) -> Self::Service {
        let mapper = self.error_mapper.clone();
        ErrorHandler::new(inner).error_mapper(mapper)
    }

    fn into_layer(self, inner: S) -> Self::Service {
        let Self { error_mapper: mapper } = self;
        ErrorHandler::new(inner).error_mapper(mapper)
    }
}
/// A service wrapper that turns every outcome of its inner service into a
/// response, so the wrapper itself never returns an error.
///
/// `F` is the error mapper. The default, `()`, means the inner service's
/// error is converted directly with [`IntoResponse`].
pub struct ErrorHandler<S, F = ()> {
    /// The wrapped service.
    inner: S,
    /// Mapper applied to the inner service's errors (or `()` for none).
    error_mapper: F,
}
/// Formats the handler as `ErrorHandler { inner: .., error_mapper: .. }`.
impl<S, F> fmt::Debug for ErrorHandler<S, F>
where
    S: fmt::Debug,
    F: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            inner,
            error_mapper,
        } = self;
        f.debug_struct("ErrorHandler")
            .field("inner", inner)
            .field("error_mapper", error_mapper)
            .finish()
    }
}
/// Clones the inner service and then the error mapper.
impl<S, F> Clone for ErrorHandler<S, F>
where
    S: Clone,
    F: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            inner,
            error_mapper,
        } = self;
        // Clone order (inner first, then mapper) matches field order.
        let inner = inner.clone();
        let error_mapper = error_mapper.clone();
        Self {
            inner,
            error_mapper,
        }
    }
}
impl<S> ErrorHandler<S> {
    /// Wraps `inner` without a custom error mapper (the unit mapper `()`).
    ///
    /// Served this way, both the inner service's responses and its errors
    /// are converted with [`IntoResponse`].
    pub const fn new(inner: S) -> Self {
        ErrorHandler {
            inner,
            error_mapper: (),
        }
    }

    // Accessors for the wrapped service; their exact set is defined by the
    // project macro (NOTE(review): presumably ref/mut/into-inner — confirm).
    define_inner_service_accessors!();

    /// Consumes this handler and returns one around the same inner service
    /// that passes the inner service's errors through `error_mapper` before
    /// converting them into responses.
    pub fn error_mapper<F>(self, error_mapper: F) -> ErrorHandler<S, F> {
        let Self {
            inner,
            error_mapper: (),
        } = self;
        ErrorHandler {
            inner,
            error_mapper,
        }
    }
}
/// Serves requests without a custom error mapper.
///
/// Both the inner service's response and its error are converted with
/// [`IntoResponse`], so this service never fails (its error type is
/// [`Infallible`]).
impl<S, State, Body> Service<State, Request<Body>> for ErrorHandler<S, ()>
where
    S: Service<State, Request<Body>, Response: IntoResponse, Error: IntoResponse>,
    State: Clone + Send + Sync + 'static,
    Body: Send + 'static,
{
    type Response = Response;
    type Error = Infallible;

    async fn serve(
        &self,
        ctx: Context<State>,
        req: Request<Body>,
    ) -> Result<Self::Response, Self::Error> {
        let outcome = self.inner.serve(ctx, req).await;
        let response = match outcome {
            Ok(ok) => ok.into_response(),
            Err(err) => err.into_response(),
        };
        Ok(response)
    }
}
/// Serves requests using a custom error mapper.
///
/// Successful responses of the inner service are converted with
/// [`IntoResponse`]. Errors are first passed to the mapper, and whatever it
/// returns is converted with [`IntoResponse`] as well. As a result this
/// service never fails (its error type is [`Infallible`]).
///
/// The mapper is only ever invoked through `&self`, so it needs to be `Fn`,
/// `Send + Sync` (the serve future borrows it) and `'static`. It is never
/// cloned here, so no `Clone` bound is required. Cloning is only needed when
/// the handler is produced by `ErrorHandlerLayer::layer` or cloned itself.
impl<S, F, R, State, Body> Service<State, Request<Body>> for ErrorHandler<S, F>
where
    S: Service<State, Request<Body>, Response: IntoResponse>,
    F: Fn(S::Error) -> R + Send + Sync + 'static,
    R: IntoResponse + 'static,
    State: Clone + Send + Sync + 'static,
    Body: Send + 'static,
{
    type Response = Response;
    type Error = Infallible;

    async fn serve(
        &self,
        ctx: Context<State>,
        req: Request<Body>,
    ) -> Result<Self::Response, Self::Error> {
        match self.inner.serve(ctx, req).await {
            Ok(response) => Ok(response.into_response()),
            Err(error) => Ok((self.error_mapper)(error).into_response()),
        }
    }
}