use std::sync::Arc;
use axum::{
body::Body,
http::Request,
response::{IntoResponse, Response},
};
use tower::{Layer, Service};
use crate::error::RokException;
use crate::Problem;
/// Tower layer that wraps a service so that errors it reports are turned into
/// responses by a shared handler.
///
/// The handler lives behind an [`Arc`], so cloning the layer only bumps a
/// reference count and never requires `F: Clone`.
pub struct CatchLayer<F> {
    handler: Arc<F>,
}

impl<F> CatchLayer<F> {
    /// Builds a layer whose wrapped services hand every error to `handler`.
    pub fn new(handler: F) -> Self {
        let handler = Arc::new(handler);
        Self { handler }
    }
}

// Written by hand instead of derived: `#[derive(Clone)]` would add an `F: Clone`
// bound, while only the `Arc` actually needs cloning.
impl<F> Clone for CatchLayer<F> {
    fn clone(&self) -> Self {
        let handler = Arc::clone(&self.handler);
        Self { handler }
    }
}
impl<S, F> Layer<S> for CatchLayer<F>
where
    F: Fn(Box<dyn RokException>) -> Response + Send + Sync + 'static,
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send,
{
    type Service = CatchService<S, F>;

    /// Wraps `inner` in a [`CatchService`] that shares this layer's handler.
    fn layer(&self, inner: S) -> Self::Service {
        CatchService {
            inner,
            handler: Arc::clone(&self.handler),
            pending: std::sync::Mutex::new(None),
        }
    }
}

/// Service produced by [`CatchLayer`].
///
/// Forwards requests to `inner` and converts any error `inner` reports into a
/// `Response` via the shared handler, so this wrapper itself never fails
/// (`Error = Infallible`).
pub struct CatchService<S, F> {
    inner: S,
    handler: Arc<F>,
    /// Response built by the handler when `inner.poll_ready` failed.
    ///
    /// Because `Self::Error` is `Infallible`, a readiness error cannot be reported
    /// from `poll_ready`. It is parked here instead and returned by the next `call`,
    /// which then does not invoke `inner`.
    ///
    /// The `Mutex` exists only so `CatchService` stays `Sync` when `S` is (the
    /// response body type may not be `Sync`). It is only ever accessed through
    /// `&mut self` via `get_mut`, so it is never actually locked.
    pending: std::sync::Mutex<Option<Response>>,
}

impl<S: Clone, F> Clone for CatchService<S, F> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            handler: Arc::clone(&self.handler),
            // Readiness is tracked per instance: a clone must drive its own
            // `poll_ready`, so it starts without a parked failure.
            pending: std::sync::Mutex::new(None),
        }
    }
}

impl<S, F> Service<Request<Body>> for CatchService<S, F>
where
    F: Fn(Box<dyn RokException>) -> Response + Send + Sync + 'static,
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Error: Into<Box<dyn RokException>>,
    S::Future: Send,
{
    type Response = Response;
    type Error = std::convert::Infallible;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    /// Polls `inner` for readiness.
    ///
    /// If `inner` reports a readiness error, the handler turns it into a response
    /// that is parked until the next `call`, and this service reports itself ready.
    /// No panic is possible here.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let pending = self
            .pending
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if pending.is_some() {
            // A failure is already waiting to be delivered; don't poll `inner` again.
            return std::task::Poll::Ready(Ok(()));
        }
        match self.inner.poll_ready(cx) {
            std::task::Poll::Pending => std::task::Poll::Pending,
            std::task::Poll::Ready(Ok(())) => std::task::Poll::Ready(Ok(())),
            std::task::Poll::Ready(Err(err)) => {
                let rok_err: Box<dyn RokException> = err.into();
                *pending = Some((self.handler)(rok_err));
                std::task::Poll::Ready(Ok(()))
            }
        }
    }

    /// Dispatches `req` to `inner`, mapping any error to a response via the handler.
    ///
    /// If the preceding `poll_ready` parked a readiness failure, that response is
    /// returned instead and `inner` is not called.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let pending = self
            .pending
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(resp) = pending {
            return Box::pin(async move { Ok::<_, std::convert::Infallible>(resp) });
        }

        let handler = Arc::clone(&self.handler);
        // Create the inner future here, outside the async block, so the returned
        // `'static` future does not borrow `self`.
        let fut = self.inner.call(req);
        Box::pin(async move {
            match fut.await {
                Ok(resp) => Ok(resp),
                Err(err) => {
                    let rok_err: Box<dyn RokException> = err.into();
                    Ok(handler(rok_err))
                }
            }
        })
    }
}
/// Default error handler whose signature fits [`CatchLayer::new`].
///
/// Errors whose `self_handled()` returns `true` render themselves through their
/// own `into_response`. Every other error is wrapped in `Problem::internal`, using
/// the error's `Display` text as the detail.
///
/// NOTE(review): that `Display` text presumably ends up in the client-facing
/// response body. Confirm it never carries internal details that shouldn't be
/// exposed.
pub fn default_error_handler(err: Box<dyn RokException>) -> Response {
    if !err.self_handled() {
        let detail = err.to_string();
        return Problem::internal(detail).into_response();
    }
    err.into_response()
}