use std::borrow::Cow;
use std::collections::HashSet;
use std::env;
use std::fmt::{self, Debug, Formatter};
use std::sync::{Arc, LazyLock};
use async_trait::async_trait;
use bytes::Bytes;
use mime::Mime;
use serde::Serialize;
use crate::handler::{Handler, WhenHoop};
use crate::http::mime::guess_accept_mime;
use crate::http::{Request, ResBody, Response, StatusCode, StatusError, header};
use crate::{Depot, FlowCtrl};
/// MIME subtypes the default error page can be rendered in natively.
///
/// Any other preferred subtype falls back to HTML in `status_error_bytes`.
static SUPPORTED_FORMATS: LazyLock<Vec<mime::Name>> =
    LazyLock::new(|| Vec::from([mime::JSON, mime::HTML, mime::XML, mime::PLAIN]));
/// Every option name accepted in the comma-separated `SALVO_STATUS_ERROR` env variable.
///
/// `*_detail` options control whether `StatusError::detail` is rendered and `*_cause`
/// options control whether `StatusError::cause` is rendered (`force`, `debug` = only in
/// debug builds, `never` = always suppressed).
static STATUS_ERROR_SETS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    [
        "force_detail",
        "debug_detail",
        "never_detail",
        "force_cause",
        "debug_cause",
        "never_cause",
    ]
    .into_iter()
    .collect()
});
/// Options parsed once from the `SALVO_STATUS_ERROR` environment variable.
///
/// The variable is split on `,`, and each item is trimmed and lower-cased. Empty items are
/// dropped silently. Items not listed in `STATUS_ERROR_SETS` are dropped with a warning.
/// An unset or non-UTF-8 variable yields an empty set.
static PARSED_ENV_SETS: LazyLock<HashSet<String>> = LazyLock::new(|| {
    let raw = env::var("SALVO_STATUS_ERROR").unwrap_or_default();
    raw.split(',')
        .map(|item| item.trim().to_lowercase())
        .filter(|item| match item.as_str() {
            "" => false,
            known if STATUS_ERROR_SETS.contains(known) => true,
            unknown => {
                tracing::warn!("unknown SALVO_STATUS_ERROR option: {}", unknown);
                false
            }
        })
        .collect::<HashSet<_>>()
});
/// HTML footer used by the default error page when no custom footer is configured.
const SALVO_LINK: &str = concat!(
    "<a href=\"https://salvo.rs\"",
    " target=\"_blank\">",
    "salvo</a>"
);
/// Handles responses that ended in an error status.
///
/// `catch` runs the `hoops` in insertion order followed by the `goal` handler, all through
/// a single [`FlowCtrl`].
pub struct Catcher {
    /// Handler placed last in the chain built by `catch`.
    goal: Arc<dyn Handler>,
    /// Handlers placed before `goal`, in insertion order.
    hoops: Vec<Arc<dyn Handler>>,
}
impl Default for Catcher {
    /// Builds a catcher whose goal is [`DefaultGoal`] and which has no hoops.
    fn default() -> Self {
        Self::new(DefaultGoal::new())
    }
}
impl Debug for Catcher {
    /// Prints only the type name.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Catcher")
    }
}
impl Catcher {
    /// Creates a catcher with `goal` as its final handler and no hoops.
    pub fn new<H: Handler>(goal: H) -> Self {
        let goal: Arc<dyn Handler> = Arc::new(goal);
        Self {
            goal,
            hoops: Vec::new(),
        }
    }
    /// Returns the hoops registered so far, in insertion order.
    #[inline]
    #[must_use]
    pub fn hoops(&self) -> &Vec<Arc<dyn Handler>> {
        &self.hoops
    }
    /// Returns a mutable handle to the hoop list.
    #[inline]
    pub fn hoops_mut(&mut self) -> &mut Vec<Arc<dyn Handler>> {
        &mut self.hoops
    }
    /// Appends `hoop` to the hoop list and returns the updated catcher.
    #[inline]
    #[must_use]
    pub fn hoop<H: Handler>(mut self, hoop: H) -> Self {
        let handler: Arc<dyn Handler> = Arc::new(hoop);
        self.hoops.push(handler);
        self
    }
    /// Appends `hoop` wrapped in a `WhenHoop` together with `filter` and returns the
    /// updated catcher (presumably the hoop only runs when `filter` returns `true` —
    /// see `WhenHoop`).
    #[inline]
    #[must_use]
    pub fn hoop_when<H, F>(mut self, hoop: H, filter: F) -> Self
    where
        H: Handler,
        F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static,
    {
        let guarded = WhenHoop {
            inner: hoop,
            filter,
        };
        self.hoops.push(Arc::new(guarded));
        self
    }
    /// Runs the hoops followed by the goal against `req`/`res` through a fresh [`FlowCtrl`].
    pub async fn catch(&self, req: &mut Request, depot: &mut Depot, res: &mut Response) {
        let mut chain: Vec<Arc<dyn Handler>> = Vec::with_capacity(self.hoops.len() + 1);
        chain.extend(self.hoops.iter().cloned());
        chain.push(Arc::clone(&self.goal));
        let mut ctrl = FlowCtrl::new(chain);
        ctrl.call_next(req, depot, res).await;
    }
}
/// Goal handler used by `Catcher::default`.
///
/// For 4xx/5xx responses without a regular body, it writes an error page in a format
/// negotiated from the request.
#[derive(Default, Debug)]
pub struct DefaultGoal {
    /// HTML inserted into the `<footer>` of the HTML error page; `None` uses the Salvo link.
    footer: Option<Cow<'static, str>>,
}
impl DefaultGoal {
    /// Creates a goal that uses the built-in footer.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Creates a goal whose HTML error page uses `footer` as its footer markup.
    #[inline]
    #[must_use]
    pub fn with_footer(footer: impl Into<Cow<'static, str>>) -> Self {
        Self {
            footer: Some(footer.into()),
        }
    }
    /// Replaces the footer markup and returns the updated goal.
    #[must_use]
    pub fn footer(mut self, footer: impl Into<Cow<'static, str>>) -> Self {
        let markup: Cow<'static, str> = footer.into();
        self.footer = Some(markup);
        self
    }
}
#[async_trait]
impl Handler for DefaultGoal {
    /// Writes a default error page when the response has a 4xx/5xx status (a missing
    /// status is treated as 404) and its body is either empty or a `ResBody::Error`.
    /// Any other response is left untouched.
    async fn handle(
        &self,
        req: &mut Request,
        _depot: &mut Depot,
        res: &mut Response,
        _ctrl: &mut FlowCtrl,
    ) {
        let status = res.status_code.unwrap_or(StatusCode::NOT_FOUND);
        let is_error_status = status.is_client_error() || status.is_server_error();
        if !is_error_status {
            return;
        }
        // Do not overwrite a body that some earlier handler already produced.
        let body_replaceable = res.body.is_none() || res.body.is_error();
        if body_replaceable {
            write_error_default(req, res, self.footer.as_deref());
        }
    }
}
fn status_error_html(
code: StatusCode,
name: &str,
brief: &str,
detail: Option<&str>,
cause: Option<&str>,
footer: Option<&str>,
) -> String {
format!(
r#"<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width">
<title>{0}: {1}</title>
<style>
:root {{
--bg-color: #fff;
--text-color: #222;
}}
body {{
background: var(--bg-color);
color: var(--text-color);
text-align: center;
}}
pre {{ text-align: left; padding: 0 1rem; }}
footer{{text-align:center;}}
@media (prefers-color-scheme: dark) {{
:root {{
--bg-color: #222;
--text-color: #ddd;
}}
a:link {{ color: red; }}
a:visited {{ color: #a8aeff; }}
a:hover {{color: #a8aeff;}}
a:active {{color: #a8aeff;}}
}}
</style>
</head>
<body>
<div><h1>{}: {}</h1><h3>{}</h3>{}{}<hr><footer>{}</footer></div>
</body>
</html>"#,
code.as_u16(),
name,
brief,
detail
.map(|detail| format!("<pre>{detail}</pre>"))
.unwrap_or_default(),
cause
.map(|cause| format!("<pre>{cause:#?}</pre>"))
.unwrap_or_default(),
footer.unwrap_or(SALVO_LINK)
)
}
#[inline]
fn status_error_json(
code: StatusCode,
name: &str,
brief: &str,
detail: Option<&str>,
cause: Option<&str>,
) -> String {
#[derive(Serialize)]
struct Data<'a> {
error: Error<'a>,
}
#[derive(Serialize)]
struct Error<'a> {
code: u16,
name: &'a str,
brief: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
detail: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
cause: Option<&'a str>,
}
let data = Data {
error: Error {
code: code.as_u16(),
name,
brief,
detail,
cause,
},
};
serde_json::to_string(&data).unwrap_or_default()
}
/// Renders a status error as plain text, one `key: value` section per field separated by
/// blank lines.
///
/// `detail` and `cause` sections are only emitted when present. `cause` is expected to be
/// pre-formatted text (see `status_error_bytes`), so it is written verbatim. Applying
/// `Debug` again would quote it and turn its newlines into literal `\n` escapes.
fn status_error_plain(
    code: StatusCode,
    name: &str,
    brief: &str,
    detail: Option<&str>,
    cause: Option<&str>,
) -> String {
    let mut text = format!("code: {}\n\nname: {name}\n\nbrief: {brief}", code.as_u16());
    if let Some(detail) = detail {
        text.push_str("\n\ndetail: ");
        text.push_str(detail);
    }
    if let Some(cause) = cause {
        text.push_str("\n\ncause: ");
        text.push_str(cause);
    }
    text
}
/// Renders a status error as XML via `serde_xml_rs`.
///
/// `detail` and `cause` elements are omitted when `None`. Serialization failure yields an
/// empty string.
fn status_error_xml(
    code: StatusCode,
    name: &str,
    brief: &str,
    detail: Option<&str>,
    cause: Option<&str>,
) -> String {
    // The struct name is kept as `Data`: serde_xml_rs derives the root element from it.
    #[derive(Serialize)]
    struct Data<'a> {
        code: u16,
        name: &'a str,
        brief: &'a str,
        #[serde(skip_serializing_if = "Option::is_none")]
        detail: Option<&'a str>,
        #[serde(skip_serializing_if = "Option::is_none")]
        cause: Option<&'a str>,
    }
    let rendered = serde_xml_rs::to_string(&Data {
        code: code.as_u16(),
        name,
        brief,
        detail,
        cause,
    });
    match rendered {
        Ok(xml) => xml,
        Err(_) => String::new(),
    }
}
/// Serializes `err` into the format named by `prefer_format`.
///
/// The subtype of `prefer_format` selects the renderer (`plain`, `json`, `xml`, otherwise
/// HTML). Subtypes outside `SUPPORTED_FORMATS` fall back to `text/html`. Whether
/// `err.detail` / `err.cause` are included is controlled by the `SALVO_STATUS_ERROR`
/// options (`never_*` wins over `force_*`, and `debug_*` applies only with
/// `debug_assertions`). `footer` only affects the HTML renderer.
///
/// Returns the MIME type actually used together with the rendered bytes.
#[doc(hidden)]
#[inline]
pub fn status_error_bytes(
    err: &StatusError,
    prefer_format: &Mime,
    footer: Option<&str>,
) -> (Mime, Bytes) {
    let format = if SUPPORTED_FORMATS.contains(&prefer_format.subtype()) {
        prefer_format.clone()
    } else {
        mime::TEXT_HTML
    };

    let env_sets = &*PARSED_ENV_SETS;
    let enabled = |never: &str, force: &str, debug: &str| -> bool {
        if env_sets.contains(never) {
            return false;
        }
        env_sets.contains(force) || (cfg!(debug_assertions) && env_sets.contains(debug))
    };
    let show_detail = enabled("never_detail", "force_detail", "debug_detail");
    let show_cause = enabled("never_cause", "force_cause", "debug_cause");

    let detail = err.detail.as_deref().filter(|_| show_detail);
    let cause_text = err
        .cause
        .as_ref()
        .filter(|_| show_cause)
        .map(|e| format!("{e:#?}"));
    let cause = cause_text.as_deref();

    let name: &str = &err.name;
    let brief: &str = &err.brief;
    let content = match format.subtype().as_ref() {
        "plain" => status_error_plain(err.code, name, brief, detail, cause),
        "json" => status_error_json(err.code, name, brief, detail, cause),
        "xml" => status_error_xml(err.code, name, brief, detail, cause),
        _ => status_error_html(err.code, name, brief, detail, cause, footer),
    };
    (format, Bytes::from(content))
}
#[doc(hidden)]
pub fn write_error_default(req: &Request, res: &mut Response, footer: Option<&str>) {
let format = guess_accept_mime(req, None);
let (format, data) = if let ResBody::Error(body) = &res.body {
status_error_bytes(body, &format, footer)
} else {
let status = res.status_code.unwrap_or(StatusCode::NOT_FOUND);
status_error_bytes(
&StatusError::from_code(status).unwrap_or_else(StatusError::internal_server_error),
&format,
footer,
)
};
res.headers_mut().insert(
header::CONTENT_TYPE,
format.to_string().parse().expect("invalid `Content-Type`"),
);
let _ = res.write_body(data);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;
    use crate::test::{ResponseExt, TestClient};

    /// Sends `GET /{name}` to `service` and returns the response body as text.
    async fn access(service: &Service, name: &str) -> String {
        TestClient::get(format!("http://127.0.0.1:8698/{name}"))
            .send(service)
            .await
            .take_string()
            .await
            .unwrap()
    }

    /// Error type whose writer produces a 500 with a fixed text body.
    struct CustomError;
    #[async_trait]
    impl Writer for CustomError {
        async fn write(self, _req: &mut Request, _depot: &mut Depot, res: &mut Response) {
            res.status_code = Some(StatusCode::INTERNAL_SERVER_ERROR);
            res.render("custom error");
        }
    }

    /// Catcher hoop that replaces 404 (or status-less) responses and stops the chain.
    #[handler]
    async fn handle404(
        &self,
        _req: &Request,
        _depot: &Depot,
        res: &mut Response,
        ctrl: &mut FlowCtrl,
    ) {
        if res.status_code.is_none() || Some(StatusCode::NOT_FOUND) == res.status_code {
            res.render("Custom 404 Error Page");
            ctrl.skip_rest();
        }
    }

    #[tokio::test]
    async fn test_handle_error() {
        #[handler]
        async fn handle_custom() -> Result<(), CustomError> {
            Err(CustomError)
        }
        let service = Service::new(
            Router::new().push(Router::with_path("custom").get(handle_custom)),
        );
        assert_eq!(access(&service, "custom").await, "custom error");
    }

    #[tokio::test]
    async fn test_custom_catcher() {
        #[handler]
        async fn hello() -> &'static str {
            "Hello World"
        }
        let catcher = Catcher::default().hoop(handle404);
        let service = Service::new(Router::new().get(hello)).catcher(catcher);
        assert_eq!(access(&service, "notfound").await, "Custom 404 Error Page");
    }
}