// rok-core 0.6.0
//
// Core primitives for the rok ecosystem — errors, crypto, i18n, config, DI, and more.
// Documentation
use std::sync::Arc;

use axum::{
    body::Body,
    http::Request,
    response::{IntoResponse, Response},
};
use tower::{Layer, Service};

use crate::error::RokException;
use crate::Problem;

/// Global error handler middleware.
///
/// A [`tower::Layer`] that wraps an inner service so that an `Err` produced by
/// that service is converted into a `Box<dyn RokException>` and passed to the
/// user-supplied handler, whose return value becomes the HTTP response.
///
/// The handler is always supplied explicitly through [`CatchLayer::new`];
/// [`default_error_handler`] is available as a ready-made choice.
///
/// The inner service's error type must implement `Into<Box<dyn RokException>>`.
pub struct CatchLayer<F> {
    // Shared via `Arc` so cloning the layer (and every service it produces)
    // never requires `F: Clone`.
    handler: Arc<F>,
}

impl<F> CatchLayer<F> {
    /// Builds a layer around `handler`.
    ///
    /// The handler is moved into an [`Arc`] so it can be shared by every
    /// service this layer later produces.
    pub fn new(handler: F) -> Self {
        let handler = Arc::new(handler);
        Self { handler }
    }
}

impl<F> Clone for CatchLayer<F> {
    fn clone(&self) -> Self {
        Self {
            handler: self.handler.clone(),
        }
    }
}

impl<S, F> Layer<S> for CatchLayer<F>
where
    F: Fn(Box<dyn RokException>) -> Response + Send + Sync + 'static,
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send,
{
    type Service = CatchService<S, F>;

    /// Wraps `inner` in a [`CatchService`] that shares this layer's handler.
    fn layer(&self, inner: S) -> Self::Service {
        CatchService {
            inner,
            handler: Arc::clone(&self.handler),
            readiness_failure: std::sync::Mutex::new(None),
        }
    }
}

/// The service produced by [`CatchLayer`].
///
/// Forwards requests to the inner service and turns any error it reports,
/// either from `poll_ready` or from the response future, into a response by
/// calling the shared handler. Its own error type is therefore `Infallible`.
pub struct CatchService<S, F> {
    inner: S,
    handler: Arc<F>,
    // Response built from an error the inner service returned from
    // `poll_ready`; handed out by the next `call` instead of dispatching.
    // Kept in a `Mutex` (only ever accessed through `get_mut`, so never
    // actually locked) purely so this field cannot make the service `!Sync`
    // when `Response` itself is not `Sync`.
    readiness_failure: std::sync::Mutex<Option<Response>>,
}

impl<S: Clone, F> Clone for CatchService<S, F> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            handler: Arc::clone(&self.handler),
            // A stashed response cannot be cloned, and a clone has to be
            // driven to readiness on its own anyway, so it starts empty.
            readiness_failure: std::sync::Mutex::new(None),
        }
    }
}

impl<S, F> Service<Request<Body>> for CatchService<S, F>
where
    F: Fn(Box<dyn RokException>) -> Response + Send + Sync + 'static,
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Error: Into<Box<dyn RokException>>,
    S::Future: Send,
{
    type Response = Response;
    type Error = std::convert::Infallible;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    /// Polls the inner service for readiness.
    ///
    /// A readiness error cannot be returned (the error type is `Infallible`),
    /// so it is converted into a response through the handler, stashed, and
    /// readiness is reported; the following `call` returns that response.
    /// Previously this path hit `unreachable!()` and panicked.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        use std::task::Poll;

        let stash = self
            .readiness_failure
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // An undelivered failure response is already waiting: stay "ready"
        // without polling the failed inner service again.
        if stash.is_some() {
            return Poll::Ready(Ok(()));
        }

        match self.inner.poll_ready(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(err)) => {
                let rok_err: Box<dyn RokException> = err.into();
                *stash = Some((self.handler)(rok_err));
                Poll::Ready(Ok(()))
            }
        }
    }

    /// Dispatches `req` to the inner service, mapping an `Err` outcome to the
    /// handler's response.
    ///
    /// If the preceding `poll_ready` recorded a readiness failure, the inner
    /// service is not called and the recorded response is returned instead.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let stashed = self
            .readiness_failure
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(resp) = stashed {
            // The inner service never became ready, so it must not be called.
            return Box::pin(async move { Ok(resp) });
        }

        let handler = Arc::clone(&self.handler);
        let fut = self.inner.call(req);
        Box::pin(async move {
            match fut.await {
                Ok(resp) => Ok(resp),
                Err(err) => {
                    let rok_err: Box<dyn RokException> = err.into();
                    Ok(handler(rok_err))
                }
            }
        })
    }
}

/// Ready-made error handler that turns an exception into a response.
///
/// * When `err.self_handled()` is true, the exception's own `into_response`
///   output is returned unchanged.
/// * Otherwise the exception's `to_string()` text is passed to
///   `Problem::internal` and that problem is returned as the response.
///
/// NOTE(review): nothing is logged in this function itself; if logging is
/// expected for non-self-handled exceptions it must happen elsewhere — confirm.
pub fn default_error_handler(err: Box<dyn RokException>) -> Response {
    if !err.self_handled() {
        let detail = err.to_string();
        return Problem::internal(detail).into_response();
    }

    err.into_response()
}