use std::any::Any;
use axum::body::Body;
use axum::http::{Request, Response, StatusCode, header};
use axum::response::IntoResponse;
use minijinja::context;
/// Source of the built-in 404 page, embedded at compile time from
/// `templates/defaults/default_404.html` (path relative to this file).
pub const DEFAULT_404_HTML: &str = include_str!("templates/defaults/default_404.html");
/// Source of the built-in 500 page, embedded at compile time from
/// `templates/defaults/default_500.html` (path relative to this file).
pub const DEFAULT_500_HTML: &str = include_str!("templates/defaults/default_500.html");
/// Template name `render_not_found` falls back to when the caller supplies no
/// template and default pages are enabled.
// NOTE(review): presumably `DEFAULT_404_HTML` is registered under this name by
// the template setup code, which is not visible in this file — confirm.
pub const DEFAULT_404_TEMPLATE_NAME: &str = "__umbral__/default_404.html";
/// Template name `render_500` falls back to when the caller supplies no
/// template and default pages are enabled.
// NOTE(review): presumably `DEFAULT_500_HTML` is registered under this name by
// the template setup code, which is not visible in this file — confirm.
pub const DEFAULT_500_TEMPLATE_NAME: &str = "__umbral__/default_500.html";
/// Shared, thread-safe callback fired when a server error is being reported.
///
/// Within this file it is invoked with the error (or panic) message as the
/// first argument and the request path as the second; the panic handler
/// passes an empty string for the path.
pub type ServerErrorHook = std::sync::Arc<dyn Fn(&str, &str) + Send + Sync + 'static>;
use std::sync::OnceLock;
/// Write-once flag recording whether the built-in error pages may be used
/// when the caller configures no template of its own.
static DEFAULT_PAGES_ENABLED: OnceLock<bool> = OnceLock::new();

/// Records whether the built-in error pages are enabled.
///
/// Only the first call has any effect; later calls leave the stored value
/// untouched.
pub(crate) fn init_default_pages(enabled: bool) {
    // Initialises the cell if it is still empty and is a no-op otherwise,
    // which is exactly "first writer wins".
    DEFAULT_PAGES_ENABLED.get_or_init(|| enabled);
}

/// Returns the recorded flag, or `true` when `init_default_pages` has never
/// been called.
pub(crate) fn default_pages_enabled() -> bool {
    DEFAULT_PAGES_ENABLED.get().copied().unwrap_or(true)
}
/// Builds the `404 Not Found` response for the request path `path`.
///
/// Template selection:
/// 1. `template`, when the caller supplies one;
/// 2. otherwise [`DEFAULT_404_TEMPLATE_NAME`], when default pages are enabled;
/// 3. otherwise no template at all.
///
/// When a template is selected and renders successfully the response is
/// `text/html`. When no template is selected, or rendering fails (a warning
/// is logged), the body is the plain-text string `Not Found`.
///
/// The template context (`path`, `dev_mode`, `routes_by_plugin`) is built
/// only when a template is actually going to be rendered, so the plain-text
/// path never touches the settings or the route registry.
pub fn render_not_found(template: Option<&str>, path: &str) -> Response<Body> {
    const HTML: &str = "text/html; charset=utf-8";
    const PLAIN: &str = "text/plain; charset=utf-8";

    let effective_template =
        template.or_else(|| default_pages_enabled().then_some(DEFAULT_404_TEMPLATE_NAME));

    let rendered = effective_template.and_then(|name| {
        let ctx = not_found_context(path);
        match crate::templates::render(name, &ctx) {
            Ok(html) => Some(html),
            Err(e) => {
                tracing::warn!(
                    "error-page template `{name}` failed to render ({e}); \
                     falling back to plain text"
                );
                None
            }
        }
    });

    let (body, content_type) = match rendered {
        Some(html) => (html, HTML),
        None => ("Not Found".to_string(), PLAIN),
    };

    let mut response = Response::new(Body::from(body));
    *response.status_mut() = StatusCode::NOT_FOUND;
    // Both content types are static literals, so no runtime parsing (and no
    // `expect`) is needed to turn them into a header value.
    response.headers_mut().insert(
        header::CONTENT_TYPE,
        header::HeaderValue::from_static(content_type),
    );
    response
}

/// Assembles the template context for the 404 page.
///
/// `dev_mode` is true only when settings are available and the environment
/// is `Dev`; `routes_by_plugin` is populated only in that case and is an
/// empty list otherwise.
fn not_found_context(path: &str) -> minijinja::Value {
    let dev_mode = crate::settings::get_opt()
        .map(|s| matches!(s.environment, crate::settings::Environment::Dev))
        .unwrap_or(false);
    let routes_by_plugin = if dev_mode {
        dev_route_listing()
    } else {
        Vec::new()
    };
    context! {
        path => path,
        dev_mode => dev_mode,
        routes_by_plugin => routes_by_plugin,
    }
}

/// Converts the route registry into one `{ plugin, routes }` entry per plugin
/// that has at least one route. Returns an empty list when no registry is
/// available.
fn dev_route_listing() -> Vec<minijinja::Value> {
    crate::routes::get()
        .map(|reg| {
            reg.by_plugin
                .iter()
                .filter(|(_, specs)| !specs.is_empty())
                .map(|(plugin, specs)| {
                    let routes: Vec<minijinja::Value> = specs
                        .iter()
                        .map(|s| {
                            // A route with no explicit methods is labelled `ANY`.
                            let method_label = if s.methods.is_empty() {
                                "ANY".to_string()
                            } else {
                                s.methods.join("·")
                            };
                            minijinja::context! {
                                path => s.path.as_str(),
                                methods => s.methods.clone(),
                                method_label => method_label,
                            }
                        })
                        .collect();
                    minijinja::context! {
                        plugin => plugin.as_str(),
                        routes => routes,
                    }
                })
                .collect()
        })
        .unwrap_or_default()
}
/// Produces a cloneable fallback handler that answers every request with the
/// 404 page built by `render_not_found`.
///
/// `template` is the optional template name forwarded to `render_not_found`
/// on each call; the handler keeps its own copy and clones it per request.
pub fn not_found_fallback(
    template: Option<String>,
) -> impl Fn(
    Request<Body>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response<Body>> + Send>>
       + Clone
       + Send
       + Sync
       + 'static {
    move |request: Request<Body>| {
        // Each returned future owns its own copy of the template name.
        let configured = template.clone();
        Box::pin(async move {
            let requested = String::from(request.uri().path());
            render_not_found(configured.as_deref(), requested.as_str())
        })
    }
}
/// Flattens an error and its causes into a list of display strings.
///
/// The first element is always `top`. It is followed by the `to_string()` of
/// `source`, then of that error's own `source()`, and so on until an error
/// reports no further source.
pub fn collect_error_chain(top: &str, source: Option<&dyn std::error::Error>) -> Vec<String> {
    // `|&cause|` copies the reference out so `source()` is called on the
    // error itself and the returned borrow keeps the caller's lifetime.
    let causes =
        std::iter::successors(source, |&cause| cause.source()).map(|cause| cause.to_string());
    std::iter::once(top.to_owned()).chain(causes).collect()
}
/// Reports whether the application is running in the `Dev` environment.
///
/// Returns `false` when the global settings are not available.
fn is_dev_mode() -> bool {
    crate::settings::SETTINGS.get().map_or(false, |settings| {
        matches!(settings.environment, crate::settings::Environment::Dev)
    })
}
/// Builds the template context for the 500 page.
///
/// The context always carries the same four keys: `dev_mode`,
/// `error_display`, `error_chain` and `request_path`. When `dev` is false the
/// three detail fields are blanked (empty string / empty list) so that no
/// error details reach the rendered page.
fn build_500_context(
    error_display: &str,
    error_chain: &[String],
    request_path: &str,
    dev: bool,
) -> minijinja::Value {
    const NO_CAUSES: &[String] = &[];

    // Decide once what is allowed to be shown, then build a single context.
    let (shown_error, shown_chain, shown_path) = if dev {
        (error_display, error_chain, request_path)
    } else {
        ("", NO_CAUSES, "")
    };

    context! {
        dev_mode => dev,
        error_display => shown_error,
        error_chain => shown_chain,
        request_path => shown_path,
    }
}
/// Renders the 500 page body and returns it with its content type.
///
/// Uses `template` when given, otherwise [`DEFAULT_500_TEMPLATE_NAME`] when
/// default pages are enabled. With no template available the result is the
/// plain-text string `Internal Server Error`.
///
/// If the chosen template itself fails to render, the failure is logged and
/// a plain-text body is returned; in dev mode that body also names the
/// template and the render error.
fn render_500(template: Option<&str>, ctx: &minijinja::Value) -> (String, &'static str) {
    const PLAIN: &str = "text/plain; charset=utf-8";
    let generic = || ("Internal Server Error".to_string(), PLAIN);

    let name = match template {
        Some(custom) => custom,
        None if default_pages_enabled() => DEFAULT_500_TEMPLATE_NAME,
        None => return generic(),
    };

    let secondary = match crate::templates::render(name, ctx) {
        Ok(html) => return (html, "text/html; charset=utf-8"),
        Err(render_failure) => render_failure,
    };

    tracing::error!(
        template = %name,
        error = %secondary,
        "render_500: secondary template render failed; the configured \
         server-error template can't render itself. Likely a broken \
         `{{% extends \"wrapper.html\" %}}` chain. Falling back to \
         plain text.",
    );

    if !is_dev_mode() {
        return generic();
    }

    // Dev only: surface why the error page itself could not be produced.
    let body = format!(
        "Internal Server Error\n\n\
         (dev) The configured 500 template `{name}` itself failed \
         to render: {secondary}\n\n\
         Check the original handler error in the server logs \
         (line above this one) for the trigger."
    );
    (body, PLAIN)
}
pub fn server_error_panic_handler(
template: Option<String>,
hook: Option<ServerErrorHook>,
) -> impl Fn(Box<dyn Any + Send + 'static>) -> Response<Body> + Clone + Send + Sync + 'static {
move |err: Box<dyn Any + Send + 'static>| {
let panic_message = if let Some(s) = err.downcast_ref::<&'static str>() {
(*s).to_string()
} else if let Some(s) = err.downcast_ref::<String>() {
s.clone()
} else {
"<non-string panic payload>".to_string()
};
tracing::error!(
panic_message = %panic_message,
"handler panicked; serving 500 page",
);
if let Some(ref h) = hook {
h(&panic_message, "");
}
let dev = is_dev_mode();
let chain = vec![panic_message.clone()];
let ctx = build_500_context(&panic_message, &chain, "", dev);
let (body, content_type) = render_500(template.as_deref(), &ctx);
(
StatusCode::INTERNAL_SERVER_ERROR,
[(header::CONTENT_TYPE, content_type)],
body,
)
.into_response()
}
}
/// Invokes the server-error hook with `error_msg` and `path`.
///
/// Does nothing when `hook` is `None`.
pub fn fire_server_error_hook(hook: &Option<ServerErrorHook>, error_msg: &str, path: &str) {
    let Some(callback) = hook else {
        return;
    };
    callback(error_msg, path);
}
/// State handed to `render_500_middleware`.
#[derive(Clone)]
pub struct Render500State {
/// Template name passed to `render_500` for 500 responses; `None` leaves the
/// choice (built-in default page or plain text) to `render_500`.
pub template: Option<String>,
/// Optional callback fired with the error message and request path before
/// the error page is rendered.
pub hook: Option<ServerErrorHook>,
}
/// Middleware that replaces non-HTML `500` responses with the rendered
/// server-error page.
///
/// Responses pass through untouched unless their status is
/// `500 Internal Server Error` and their `Content-Type` does not start with
/// `text/html`. For the rest, up to 64 KiB of the original body is read as the
/// error message (an unreadable body yields an empty message), the message is
/// logged, the hook in `state` is fired, and a fresh response is built from
/// `render_500`. The headers of the original response are not carried over.
pub async fn render_500_middleware(
    axum::extract::State(state): axum::extract::State<Render500State>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response<Body> {
    const MAX_ERROR_BODY_BYTES: usize = 64 * 1024;

    // The path has to be captured before `req` is moved into the inner service.
    let path = req.uri().path().to_owned();
    let resp = next.run(req).await;

    if resp.status() != StatusCode::INTERNAL_SERVER_ERROR {
        return resp;
    }
    let already_html = resp
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .is_some_and(|ct| ct.starts_with("text/html"));
    if already_html {
        return resp;
    }

    let (_, original_body) = resp.into_parts();
    let bytes = axum::body::to_bytes(original_body, MAX_ERROR_BODY_BYTES)
        .await
        .unwrap_or_default();
    let error_msg = String::from_utf8_lossy(&bytes).into_owned();

    tracing::error!(
        error = %error_msg,
        path = %path,
        "handler returned 500; rendering server-error template",
    );
    fire_server_error_hook(&state.hook, &error_msg, &path);

    let dev = is_dev_mode();
    let ctx = build_500_context(&error_msg, std::slice::from_ref(&error_msg), &path, dev);
    let (body_str, content_type) = render_500(state.template.as_deref(), &ctx);

    let status = StatusCode::INTERNAL_SERVER_ERROR;
    (status, [(header::CONTENT_TYPE, content_type)], body_str).into_response()
}
/// State handed to `render_error_middleware`.
#[derive(Clone)]
pub struct RenderErrorState {
/// Template name to render for each HTTP status. Responses whose status has
/// no entry here are passed through unchanged by the middleware.
pub templates: std::sync::Arc<std::collections::HashMap<StatusCode, String>>,
}
/// Middleware that renders a configured template for error statuses.
///
/// A response is replaced only when all of the following hold:
/// - `state.templates` has an entry for the response status,
/// - the request's `Accept` header does not contain `application/json`,
/// - the response's `Content-Type` does not start with `text/html`.
///
/// Otherwise the response passes through untouched. When replacing, up to
/// 64 KiB of the original body is used as `message` in the template context
/// (an unreadable body yields an empty message) and the original status is
/// kept. The headers of the original response are not carried over.
pub async fn render_error_middleware(
    axum::extract::State(state): axum::extract::State<RenderErrorState>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response<Body> {
    const MAX_ERROR_BODY_BYTES: usize = 64 * 1024;

    // Everything needed from the request is read before it is moved into `next.run`.
    let path = req.uri().path().to_owned();
    let wants_json = req
        .headers()
        .get(header::ACCEPT)
        .and_then(|value| value.to_str().ok())
        .is_some_and(|accept| accept.contains("application/json"));

    let resp = next.run(req).await;
    let status = resp.status();

    let Some(template) = state.templates.get(&status) else {
        return resp;
    };
    let already_html = resp
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .is_some_and(|ct| ct.starts_with("text/html"));
    if wants_json || already_html {
        return resp;
    }

    let (_, original_body) = resp.into_parts();
    let bytes = axum::body::to_bytes(original_body, MAX_ERROR_BODY_BYTES)
        .await
        .unwrap_or_default();
    let message = String::from_utf8_lossy(&bytes).into_owned();

    let ctx = error_context(status, &message, &path, is_dev_mode());
    let (body_str, content_type) = render_error_page(template, status, &ctx);
    (status, [(header::CONTENT_TYPE, content_type)], body_str).into_response()
}
/// Builds the template context for a generic error page.
///
/// Keys: `status` (numeric code), `status_text` (canonical reason phrase, or
/// an empty string when the status has none), `message`, `request_path` and
/// `dev_mode`.
fn error_context(status: StatusCode, message: &str, path: &str, dev: bool) -> minijinja::Value {
    let code = status.as_u16();
    let reason = status.canonical_reason().unwrap_or("");
    minijinja::context! {
        status => code,
        status_text => reason,
        message => message,
        request_path => path,
        dev_mode => dev,
    }
}
/// Renders `template` with `ctx` and returns the body with its content type.
///
/// On success the result is the rendered HTML. If rendering fails, the
/// failure is logged and the body falls back to the status' canonical reason
/// phrase (or `Error` when it has none) as plain text.
fn render_error_page(
    template: &str,
    status: StatusCode,
    ctx: &minijinja::Value,
) -> (String, &'static str) {
    crate::templates::render(template, ctx)
        .map(|html| (html, "text/html; charset=utf-8"))
        .unwrap_or_else(|secondary| {
            tracing::error!(
                template = %template,
                status = %status.as_u16(),
                error = %secondary,
                "render_error_page: the configured error template failed to render; \
                 falling back to plain text",
            );
            let reason = status.canonical_reason().unwrap_or("Error");
            (reason.to_string(), "text/plain; charset=utf-8")
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The generic error context exposes status code, reason phrase, message,
    /// request path and the dev flag under the documented keys.
    #[test]
    fn error_context_carries_status_reason_message_and_path() {
        let ctx = error_context(
            StatusCode::TOO_MANY_REQUESTS,
            "slow down",
            "/p/notes",
            false,
        );
        let mut env = minijinja::Environment::new();
        env.add_template(
            "t",
            "{{ status }}|{{ status_text }}|{{ message }}|{{ request_path }}|{{ dev_mode }}",
        )
        .unwrap();
        let out = env.get_template("t").unwrap().render(ctx).unwrap();
        assert_eq!(out, "429|Too Many Requests|slow down|/p/notes|false");
    }

    /// A template that cannot be rendered degrades to the canonical reason
    /// phrase as plain text.
    #[test]
    fn render_error_page_falls_back_to_plain_text_when_template_cant_render() {
        let ctx = error_context(StatusCode::TOO_MANY_REQUESTS, "msg", "/x", false);
        let (body, ct) = render_error_page("nonexistent.html", StatusCode::TOO_MANY_REQUESTS, &ctx);
        assert!(ct.starts_with("text/plain"), "content-type: {ct}");
        assert_eq!(body, "Too Many Requests");
    }

    /// With no caller-supplied template the response is a plain-text 404.
    #[test]
    fn render_not_found_returns_plain_text_when_no_template() {
        let resp = render_not_found(None, "/missing");
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        let ct = resp.headers().get(header::CONTENT_TYPE).unwrap();
        assert!(ct.to_str().unwrap().starts_with("text/plain"));
    }

    /// A caller-supplied template that cannot be rendered still yields a 404,
    /// and the body falls back to plain text.
    #[test]
    fn render_not_found_falls_back_to_plain_text_when_template_missing() {
        let resp = render_not_found(Some("nonexistent.html"), "/x");
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        // The test name promises the plain-text fallback, so assert it.
        let ct = resp
            .headers()
            .get(header::CONTENT_TYPE)
            .expect("content-type header is always set")
            .to_str()
            .expect("content-type is ASCII");
        assert!(ct.starts_with("text/plain"), "content-type: {ct}");
    }

    /// In dev mode the built-in 404 page lists the registered routes.
    #[test]
    fn default_404_renders_route_panel_when_dev_mode_and_registry_populated() {
        let mut env = minijinja::Environment::new();
        env.set_auto_escape_callback(|_| minijinja::AutoEscape::Html);
        env.add_template("default_404.html", DEFAULT_404_HTML)
            .unwrap();
        let ctx = minijinja::context! {
            path => "/typo",
            dev_mode => true,
            routes_by_plugin => serde_json::json!([
                {
                    "plugin": "app",
                    "routes": [
                        { "path": "/", "methods": ["GET"], "method_label": "GET" },
                        { "path": "/articles", "methods": ["GET","POST"], "method_label": "GET·POST" },
                    ],
                },
                {
                    "plugin": "admin",
                    "routes": [
                        { "path": "/admin/", "methods": ["GET"], "method_label": "GET" },
                        { "path": "/admin/login", "methods": ["GET","POST"], "method_label": "GET·POST" },
                    ],
                },
            ]),
        };
        let out = env
            .get_template("default_404.html")
            .unwrap()
            .render(&ctx)
            .unwrap();
        assert!(
            out.contains("Dev only"),
            "dev-mode panel header should be in the output"
        );
        assert!(
            out.contains("/admin/login"),
            "admin route should be listed: {out}"
        );
        assert!(
            out.contains("/articles"),
            "app route should be listed: {out}"
        );
        assert!(
            out.contains("GET·POST"),
            "composite-method badge label should render: {out}"
        );
        assert!(
            out.contains("emerald"),
            "GET badge should carry the emerald tint class"
        );
    }

    /// Outside dev mode the built-in 404 page must not show the route panel.
    #[test]
    fn default_404_omits_route_panel_when_dev_mode_is_off() {
        let mut env = minijinja::Environment::new();
        env.set_auto_escape_callback(|_| minijinja::AutoEscape::Html);
        env.add_template("default_404.html", DEFAULT_404_HTML)
            .unwrap();
        let ctx = minijinja::context! {
            path => "/typo",
            dev_mode => false,
            routes_by_plugin => Vec::<minijinja::Value>::new(),
        };
        let out = env
            .get_template("default_404.html")
            .unwrap()
            .render(&ctx)
            .unwrap();
        assert!(
            !out.contains("Dev only"),
            "production response must not surface the route registry"
        );
    }

    /// With no source the chain holds only the top-level message.
    #[test]
    fn collect_error_chain_single_level() {
        let chain = collect_error_chain("top error", None);
        assert_eq!(chain, vec!["top error"]);
    }

    /// Every nested `source()` is appended after the top-level message, in
    /// order from outermost to innermost cause.
    #[test]
    fn collect_error_chain_walks_nested_sources() {
        #[derive(Debug)]
        struct Layer {
            label: &'static str,
            cause: Option<Box<Layer>>,
        }
        impl std::fmt::Display for Layer {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(self.label)
            }
        }
        impl std::error::Error for Layer {
            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                self.cause
                    .as_deref()
                    .map(|inner| inner as &(dyn std::error::Error + 'static))
            }
        }

        let err = Layer {
            label: "middle",
            cause: Some(Box::new(Layer {
                label: "root",
                cause: None,
            })),
        };
        let as_dyn: &dyn std::error::Error = &err;
        let chain = collect_error_chain("top", Some(as_dyn));
        assert_eq!(chain, vec!["top", "middle", "root"]);
    }

    /// In production mode every detail field is blanked, not just the
    /// error display.
    #[test]
    fn build_500_context_prod_mode_has_empty_fields() {
        let ctx = build_500_context("boom", &["boom".to_owned()], "/path", false);
        let json = serde_json::to_value(&ctx).expect("context serialises");
        assert_eq!(json["dev_mode"], serde_json::Value::Bool(false));
        assert_eq!(
            json["error_display"],
            serde_json::Value::String("".to_string())
        );
        assert_eq!(json["error_chain"], serde_json::json!([]));
        assert_eq!(
            json["request_path"],
            serde_json::Value::String("".to_string())
        );
    }

    /// In dev mode the error details are passed through to the template.
    #[test]
    fn build_500_context_dev_mode_has_error_info() {
        let chain = vec!["cause one".to_owned(), "cause two".to_owned()];
        let ctx = build_500_context("top error", &chain, "/api/items", true);
        let json = serde_json::to_value(&ctx).expect("context serialises");
        assert_eq!(json["dev_mode"], serde_json::Value::Bool(true));
        assert_eq!(
            json["error_display"],
            serde_json::Value::String("top error".to_string())
        );
        let arr = json["error_chain"]
            .as_array()
            .expect("error_chain is array");
        assert_eq!(arr.len(), 2);
    }
}