use std::future::Future;
use std::sync::Arc;
use async_trait::async_trait;
use crate::{Handler, MiddleWareHandler, Next, Request, Response, Result, State};
/// Middleware that post-processes the outcome of the rest of the handler chain.
///
/// The wrapped closure receives the `Result<Response>` produced downstream
/// together with the request's [`State`], and returns a value that is converted
/// into the final [`Response`]. The closure is stored behind an [`Arc`], so
/// cloning an `ExceptionHandler` shares the closure instead of copying it.
#[derive(Default)]
pub struct ExceptionHandler<F> {
    handler: Arc<F>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `F: Clone` bound, but cloning only bumps the `Arc` refcount, so any `F` works
// (including closures that capture non-`Clone` data).
impl<F> Clone for ExceptionHandler<F> {
    fn clone(&self) -> Self {
        Self {
            handler: Arc::clone(&self.handler),
        }
    }
}
impl<F, Fut, T> ExceptionHandler<F>
where
    F: Fn(Result<Response>, State) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<T>> + Send + 'static,
    T: Into<Response>,
{
    /// Wraps `handler` so it can be installed as middleware.
    ///
    /// `handler` is invoked with the result of the downstream handler chain and
    /// the request state; whatever `Ok` value it yields is converted into the
    /// final [`Response`].
    pub fn new(handler: F) -> Self {
        let shared = Arc::new(handler);
        Self { handler: shared }
    }
}
#[async_trait]
impl<F, Fut, T> MiddleWareHandler for ExceptionHandler<F>
where
    Fut: Future<Output = Result<T>> + Send + 'static,
    F: Fn(Result<Response>, State) -> Fut + Send + Sync + 'static,
    T: Into<Response>,
{
    /// Runs the rest of the chain, then hands its result to the wrapped closure.
    ///
    /// The closure receives both `Ok` and `Err` outcomes from downstream, plus the
    /// request state. An `Ok` value it returns is converted into a [`Response`];
    /// an `Err` it returns is propagated unchanged.
    async fn handle(&self, req: Request, next: &Next) -> Result<Response> {
        // Grab the state first: `next.call` takes ownership of the request.
        let state = req.state();
        let outcome = next.call(req).await;
        // Call through the shared `Arc` directly; cloning it per request would
        // only add an atomic refcount increment/decrement.
        (self.handler)(outcome, state).await.map(Into::into)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_exception_handler_new() {
        let handler = ExceptionHandler::new(|result: Result<Response>, _state| async {
            match result {
                Ok(resp) => Ok(resp),
                Err(_) => Ok(Response::text("error")),
            }
        });
        let _ = handler;
    }

    #[test]
    fn test_exception_handler_new_identity() {
        let handler = ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let _ = handler;
    }

    #[test]
    fn test_exception_handler_new_always_success() {
        let handler = ExceptionHandler::new(|result: Result<Response>, _state| async {
            match result {
                Ok(resp) => Ok(resp),
                Err(_) => Ok(Response::text("caught error")),
            }
        });
        let _ = handler;
    }

    #[test]
    fn test_exception_handler_clone() {
        let handler1 = ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let handler2 = handler1.clone();
        // A clone shares the original closure allocation rather than copying it.
        assert!(Arc::ptr_eq(&handler1.handler, &handler2.handler));
    }

    #[test]
    fn test_exception_handler_clone_independent() {
        let handler1 = ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let handler2 = handler1.clone();
        assert_eq!(Arc::strong_count(&handler2.handler), 2);
        // `let _ = x;` does not drop `x`; drop explicitly to show the clone
        // keeps the closure alive on its own.
        drop(handler1);
        assert_eq!(Arc::strong_count(&handler2.handler), 1);
    }

    #[test]
    fn test_exception_handler_type() {
        let handler = ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let _handler: ExceptionHandler<_> = handler;
    }

    #[test]
    fn test_exception_handler_size() {
        use std::mem::size_of_val;
        let handler = ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        // The only field is an `Arc` to a sized closure, i.e. a single pointer.
        assert_eq!(size_of_val(&handler), std::mem::size_of::<Arc<()>>());
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_exception_handler_with_success_response() {
        use crate::route::Route;
        let exception_handler =
            ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let route = Route::new("/")
            .hook(exception_handler)
            .get(|_req: Request| async { Ok("success") });
        let route = Route::new_root().append(route);
        let req = Request::empty();
        let result: Result<Response> = route.call(req).await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_exception_handler_catches_error() {
        use crate::route::Route;
        let exception_handler = ExceptionHandler::new(|result: Result<Response>, _state| async {
            match result {
                Ok(resp) => Ok(resp),
                Err(_) => Ok(Response::text("error was caught")),
            }
        });
        let route = Route::new("/")
            .hook(exception_handler)
            .get(|_req: Request| async {
                Err::<&str, _>(crate::SilentError::business_error(
                    http::StatusCode::INTERNAL_SERVER_ERROR,
                    "test error".to_string(),
                ))
            });
        let route = Route::new_root().append(route);
        let req = Request::empty();
        let result: Result<Response> = route.call(req).await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_exception_handler_modifies_error_response() {
        use crate::route::Route;
        let exception_handler =
            ExceptionHandler::new(|result: Result<Response>, _state| async move {
                match result {
                    Ok(resp) => Ok(resp),
                    Err(e) => {
                        let mut resp = Response::text(&format!("Error: {}", e.message()));
                        resp.set_status(http::StatusCode::BAD_GATEWAY);
                        Ok(resp)
                    }
                }
            });
        let route = Route::new("/")
            .hook(exception_handler)
            .get(|_req: Request| async {
                Err::<&str, _>(crate::SilentError::business_error(
                    http::StatusCode::INTERNAL_SERVER_ERROR,
                    "original error".to_string(),
                ))
            });
        let route = Route::new_root().append(route);
        let req = Request::empty();
        let result: Result<Response> = route.call(req).await;
        assert!(result.is_ok());
        let resp = result.unwrap();
        assert_eq!(resp.status, http::StatusCode::BAD_GATEWAY);
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_exception_handler_preserves_success() {
        use crate::route::Route;
        let exception_handler =
            ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let route = Route::new("/")
            .hook(exception_handler)
            .get(|_req: Request| async {
                let mut resp = Response::text("success");
                resp.set_status(http::StatusCode::OK);
                Ok(resp)
            });
        let route = Route::new_root().append(route);
        let req = Request::empty();
        let result: Result<Response> = route.call(req).await;
        assert!(result.is_ok());
        let resp = result.unwrap();
        assert_eq!(resp.status, http::StatusCode::OK);
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_exception_handler_with_into_response() {
        use crate::route::Route;
        // The closure returns `&str` rather than `Response`, exercising the
        // `T: Into<Response>` conversion in `handle`.
        let exception_handler =
            ExceptionHandler::new(|result: Result<Response>, _state| async move {
                match result {
                    Ok(_) => Ok("converted to response"),
                    Err(_) => Ok("error converted"),
                }
            });
        let route = Route::new("/")
            .hook(exception_handler)
            .get(|_req: Request| async { Ok("original") });
        let route = Route::new_root().append(route);
        let req = Request::empty();
        let result: Result<Response> = route.call(req).await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_exception_handler_concurrent() {
        use crate::route::Route;
        let exception_handler =
            ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let route = Route::new("/")
            .hook(exception_handler)
            .get(|_req: Request| async { Ok("concurrent") });
        let route: Arc<Route> = Arc::new(Route::new_root().append(route));
        let tasks = (0..5)
            .map(|_| {
                let route = Arc::clone(&route);
                tokio::spawn(async move { route.call(Request::empty()).await })
            })
            .collect::<Vec<_>>();
        for task in tasks {
            let result: Result<Response> = task.await.unwrap();
            assert!(result.is_ok());
        }
    }

    #[test]
    fn test_exception_handler_arc_shared() {
        let handler = ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let handler1 = handler.clone();
        let handler2 = handler.clone();
        // All three handles point at one shared closure.
        assert_eq!(Arc::strong_count(&handler.handler), 3);
        assert!(Arc::ptr_eq(&handler1.handler, &handler2.handler));
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_exception_handler_empty_response() {
        use crate::route::Route;
        let exception_handler =
            ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let route = Route::new("/")
            .hook(exception_handler)
            .get(|_req: Request| async { Ok(Response::empty()) });
        let route = Route::new_root().append(route);
        let req = Request::empty();
        let result: Result<Response> = route.call(req).await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_exception_handler_chain_multiple() {
        use crate::route::Route;
        let handler1 = ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let handler2 = ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let route = Route::new("/")
            .hook(handler1)
            .hook(handler2)
            .get(|_req: Request| async { Ok("chained") });
        let route = Route::new_root().append(route);
        let req = Request::empty();
        let result: Result<Response> = route.call(req).await;
        assert!(result.is_ok());
    }

    #[cfg(feature = "server")]
    #[tokio::test]
    async fn test_exception_handler_different_http_methods() {
        use crate::route::Route;
        let exception_handler =
            ExceptionHandler::new(|result: Result<Response>, _state| async { result });
        let route = Route::new("/")
            .hook(exception_handler)
            .get(|_req: Request| async { Ok("GET") })
            .post(|_req: Request| async { Ok("POST") });
        let route = Route::new_root().append(route);
        let mut req = Request::empty();
        *req.method_mut() = http::Method::GET;
        let result: Result<Response> = route.call(req).await;
        assert!(result.is_ok());
        let mut req = Request::empty();
        *req.method_mut() = http::Method::POST;
        let result: Result<Response> = route.call(req).await;
        assert!(result.is_ok());
    }
}