use std::sync::{Arc, Mutex};
use axum::Router;
use axum::body::Body;
use axum::http::{Method, Request, StatusCode, header};
use axum::response::IntoResponse;
use axum::routing::get;
use tower::ServiceExt;
use umbral_core::errors::{
DEFAULT_404_TEMPLATE_NAME, DEFAULT_500_TEMPLATE_NAME, ServerErrorHook, collect_error_chain,
fire_server_error_hook, not_found_fallback, render_not_found, server_error_panic_handler,
};
/// Sends a single empty-bodied `method` request for `path` through `router`
/// and returns the resulting response.
///
/// Panics if the request cannot be built (e.g. `path` is not a valid URI) or
/// if the service call returns an error.
async fn oneshot(router: Router, method: Method, path: &str) -> axum::http::Response<Body> {
    let builder = Request::builder().method(method).uri(path);
    let request = builder.body(Body::empty()).unwrap();
    let outcome = router.oneshot(request).await;
    outcome.unwrap()
}
/// Consumes `resp`, returning its status code together with the whole body
/// decoded as UTF-8 (invalid sequences are replaced, not rejected).
///
/// Panics if the body stream yields an error.
async fn read_body(resp: axum::http::Response<Body>) -> (StatusCode, String) {
    let (parts, body) = resp.into_parts();
    let collected = axum::body::to_bytes(body, usize::MAX).await.unwrap();
    let text = String::from_utf8_lossy(&collected).into_owned();
    (parts.status, text)
}
#[tokio::test]
async fn handler_returning_err_produces_500() {
let router = Router::new().route(
"/fail",
get(|| async { (StatusCode::INTERNAL_SERVER_ERROR, "handler error body").into_response() }),
);
let resp = oneshot(router, Method::GET, "/fail").await;
let (status, body) = read_body(resp).await;
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(body, "handler error body");
}
/// A panic inside a handler, wrapped by `CatchPanicLayer` with the umbral
/// panic handler, must be turned into a 500 response rather than an abort.
#[tokio::test]
async fn panic_in_handler_produces_500_not_abort() {
    // The explicit return type satisfies axum's `IntoResponse` bound without
    // needing unreachable tail code after the panic.
    async fn always_panics() -> &'static str {
        panic!("gap-35 panic test")
    }

    let on_panic = server_error_panic_handler(None, None);
    let app = Router::new()
        .route("/panic", get(always_panics))
        .layer(tower_http::catch_panic::CatchPanicLayer::custom(on_panic));

    let response = oneshot(app, Method::GET, "/panic").await;
    let (status, _) = read_body(response).await;
    assert_eq!(
        status,
        StatusCode::INTERNAL_SERVER_ERROR,
        "panicking handler must produce 500, not an abort"
    );
}
/// When a handler panics, the configured `ServerErrorHook` is invoked exactly
/// once and receives the panic message.
#[tokio::test]
async fn on_server_error_hook_fires_on_panic() {
    // Named handler so the panic needs no unreachable tail expression.
    async fn always_panics() -> &'static str {
        panic!("hook-test panic")
    }

    // Every error message the hook sees is appended here.
    let recorded: Arc<Mutex<Vec<String>>> = Arc::default();
    let sink = Arc::clone(&recorded);
    let hook: ServerErrorHook = Arc::new(move |err, _path| {
        sink.lock().unwrap().push(err.to_string());
    });

    let app = Router::new()
        .route("/panic", get(always_panics))
        .layer(tower_http::catch_panic::CatchPanicLayer::custom(
            server_error_panic_handler(None, Some(hook)),
        ));

    let response = oneshot(app, Method::GET, "/panic").await;
    assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

    let messages = recorded.lock().unwrap();
    assert_eq!(messages.len(), 1, "hook must fire exactly once");
    let first = &messages[0];
    assert!(
        first.contains("hook-test panic"),
        "hook receives the panic message; got: {:?}",
        first
    );
}
/// `fire_server_error_hook` forwards the error message and request path to a
/// configured hook, exactly once.
#[test]
fn fire_server_error_hook_calls_hook_when_set() {
    // Log of every (error, path) pair the hook observes.
    let seen: Arc<Mutex<Vec<(String, String)>>> = Arc::default();
    let recorder = Arc::clone(&seen);
    let hook: ServerErrorHook = Arc::new(move |err, path| {
        let entry = (err.to_string(), path.to_string());
        recorder.lock().unwrap().push(entry);
    });

    fire_server_error_hook(&Some(hook), "boom", "/api/items");

    let log = seen.lock().unwrap();
    assert_eq!(log.len(), 1);
    let (message, path) = &log[0];
    assert_eq!(message, "boom");
    assert_eq!(path, "/api/items");
}
/// With no hook configured, `fire_server_error_hook` must simply return
/// without panicking.
#[test]
fn fire_server_error_hook_is_silent_when_none() {
    let no_hook: Option<ServerErrorHook> = None;
    fire_server_error_hook(&no_hook, "boom", "/");
}
/// With neither a template nor a hook configured, the panic handler still
/// yields a 500 status.
#[tokio::test]
async fn panic_handler_returns_500_status_with_no_template() {
    // Explicit return type keeps axum's handler bounds happy without a dead
    // tail expression after the panic.
    async fn boom() -> &'static str {
        panic!("test")
    }

    let catch_layer =
        tower_http::catch_panic::CatchPanicLayer::custom(server_error_panic_handler(None, None));
    let app = Router::new().route("/boom", get(boom)).layer(catch_layer);

    let status = oneshot(app, Method::GET, "/boom").await.status();
    assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
}
/// `collect_error_chain` called with no `source` yields a chain containing only
/// the top-level message.
///
/// NOTE(review): despite its name, this test neither builds a 500 context nor
/// depends on dev mode; it only exercises `collect_error_chain`. Consider
/// renaming it.
#[test]
fn build_500_context_shows_chain_in_dev_mode() {
    let chain = collect_error_chain("outer error", None);
    assert_eq!(chain, vec!["outer error"]);
}
/// `collect_error_chain` follows `Error::source()` links, listing the outer
/// message first and then each cause in order.
#[test]
fn collect_error_chain_walks_source_chain() {
    use std::{error::Error, fmt};

    // Leaf error: has no further source.
    #[derive(Debug)]
    struct Inner;

    impl fmt::Display for Inner {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("inner cause")
        }
    }

    impl Error for Inner {}

    // Wrapper error whose `source()` points at the wrapped `Inner`.
    #[derive(Debug)]
    struct Outer(Inner);

    impl fmt::Display for Outer {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("outer error")
        }
    }

    impl Error for Outer {
        fn source(&self) -> Option<&(dyn Error + 'static)> {
            let Outer(inner) = self;
            Some(inner)
        }
    }

    let outer = Outer(Inner);
    let headline = outer.to_string();
    let chain = collect_error_chain(&headline, outer.source());
    assert_eq!(chain, vec!["outer error", "inner cause"]);
}
/// Requests that match no route fall through to `not_found_fallback`, which
/// answers with 404.
#[tokio::test]
async fn not_found_fallback_returns_404_status() {
    let app = Router::new().route("/existing", get(|| async { "ok" }));
    let app = app.fallback(not_found_fallback(None));

    let status = oneshot(app, Method::GET, "/does-not-exist").await.status();
    assert_eq!(status, StatusCode::NOT_FOUND);
}
/// Installing `not_found_fallback` must not intercept requests that match a
/// registered route.
#[tokio::test]
async fn not_found_fallback_passes_matched_routes_through() {
    let app = Router::new().route("/existing", get(|| async { "found" }));
    let app = app.fallback(not_found_fallback(None));

    let response = oneshot(app, Method::GET, "/existing").await;
    let (status, body) = read_body(response).await;
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "found");
}
/// `render_not_found` without a template returns a 404 whose content type is
/// either plain text or HTML.
#[test]
fn render_not_found_with_none_returns_correct_status() {
    let response = render_not_found(None, "/missing-page");
    assert_eq!(response.status(), StatusCode::NOT_FOUND);

    let content_type = response
        .headers()
        .get(header::CONTENT_TYPE)
        .unwrap()
        .to_str()
        .unwrap();
    let acceptable = ["text/plain", "text/html"]
        .iter()
        .any(|&prefix| content_type.starts_with(prefix));
    assert!(acceptable, "unexpected content-type: {content_type}");
}
/// Built-in error template names live under the reserved `__umbral__/`
/// namespace so they cannot collide with user templates.
#[test]
fn default_template_name_constants_have_reserved_prefix() {
    let defaults = [
        (DEFAULT_404_TEMPLATE_NAME, "404"),
        (DEFAULT_500_TEMPLATE_NAME, "500"),
    ];
    for (template_name, code) in defaults {
        assert!(
            template_name.starts_with("__umbral__/"),
            "{code} template must use __umbral__/ prefix to avoid collisions"
        );
    }
}
/// A `ServerErrorHook` can be cheaply shared via `Arc::clone`, and both the
/// original and the clone remain callable.
#[test]
fn server_error_hook_can_be_cloned() {
    let original: ServerErrorHook = Arc::new(|_err, _path| {});
    let shared = Arc::clone(&original);
    original("error", "/");
    shared("error2", "/path");
}
/// A plain-text 500 produced by a handler passes through `render_500_middleware`
/// as a 500. The configured `on_server_error` hook receives the handler's body
/// as the error message, plus the request path.
///
/// NOTE(review): `template` is `None` here, so this test does not verify the
/// template re-render its name describes. Once a template can be built in
/// tests, consider asserting on the rendered body and content type as well.
#[tokio::test]
async fn render_500_middleware_re_renders_plain_text_500_as_template() {
    use umbral_core::errors::{Render500State, render_500_middleware};

    // Captures the (error, path) pair the hook is invoked with.
    let hook_fired: Arc<Mutex<Option<(String, String)>>> = Arc::new(Mutex::new(None));
    let hook_fired_clone = Arc::clone(&hook_fired);
    let hook: ServerErrorHook = Arc::new(move |err, path| {
        *hook_fired_clone.lock().unwrap() = Some((err.to_string(), path.to_string()));
    });
    let state = Render500State {
        template: None,
        hook: Some(hook),
    };
    let router = Router::new()
        .route(
            "/fail",
            get(|| async {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "umbral templates: invalid operation",
                )
                    .into_response()
            }),
        )
        .layer(axum::middleware::from_fn_with_state(
            state,
            render_500_middleware,
        ));

    let resp = oneshot(router, Method::GET, "/fail").await;
    // Drain the body so the whole response is produced. The body itself is
    // not asserted because no template is configured.
    let (status, _body) = read_body(resp).await;
    assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);

    let fired = hook_fired.lock().unwrap();
    let (err, path) = fired.as_ref().expect("on_server_error hook should fire");
    assert!(
        err.contains("umbral templates: invalid operation"),
        "hook got the handler-Err body as the error message; got: {err}"
    );
    assert_eq!(path, "/fail");
}
/// A 500 whose body the handler already rendered as HTML must leave
/// `render_500_middleware` byte-for-byte unchanged.
#[tokio::test]
async fn render_500_middleware_passes_html_500_through() {
    use umbral_core::errors::{Render500State, render_500_middleware};

    // Body of the handler's own pre-rendered HTML 500 page.
    const ALREADY_HTML: &str = "<!DOCTYPE html><h1>my own 500</h1>";

    let layer = axum::middleware::from_fn_with_state(
        Render500State {
            template: None,
            hook: None,
        },
        render_500_middleware,
    );
    let app = Router::new()
        .route(
            "/fail",
            get(|| async {
                axum::response::Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .header(header::CONTENT_TYPE, "text/html; charset=utf-8")
                    .body(Body::from(ALREADY_HTML))
                    .unwrap()
            }),
        )
        .layer(layer);

    let response = oneshot(app, Method::GET, "/fail").await;
    let (status, body) = read_body(response).await;
    assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    assert_eq!(
        body, ALREADY_HTML,
        "HTML 500s must pass through the middleware unchanged"
    );
}
/// Successful (non-500) responses must pass through `render_500_middleware`
/// with status and body untouched.
#[tokio::test]
async fn render_500_middleware_leaves_non_500_responses_alone() {
    use umbral_core::errors::{Render500State, render_500_middleware};

    let layer = axum::middleware::from_fn_with_state(
        Render500State {
            template: None,
            hook: None,
        },
        render_500_middleware,
    );
    let app = Router::new()
        .route("/ok", get(|| async { String::from("ok body") }))
        .layer(layer);

    let response = oneshot(app, Method::GET, "/ok").await;
    let (status, body) = read_body(response).await;
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "ok body");
}