use std::collections::HashMap;
use http::header::CONTENT_TYPE;
use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
use tracing::debug;
/// Middleware ("hoop") that substitutes a configured custom page for the
/// response body when the rest of the handler chain finishes with a status
/// code that has an entry in `pages`.
pub struct ErrorPagesHoop {
    /// Pre-processed custom pages, keyed by numeric HTTP status code.
    pages: HashMap<u16, ErrorPage>,
}
/// A single custom error page together with the media type it is served as.
struct ErrorPage {
    /// Response body, sent verbatim.
    body: String,
    /// Value written to the `Content-Type` response header.
    content_type: String,
}
impl ErrorPagesHoop {
    /// Builds the hoop from a map of HTTP status code to page body.
    ///
    /// The `Content-Type` of each page is inferred once, here, from its body.
    /// A body whose first non-whitespace character is `<` is served as
    /// `text/html; charset=utf-8`; anything else is served as
    /// `text/plain; charset=utf-8`. A leading UTF-8 byte-order mark is ignored
    /// during that check, so HTML files saved with a BOM are still detected as
    /// HTML. Bodies themselves are stored unchanged.
    pub fn new(pages: HashMap<u16, String>) -> Self {
        let pages = pages
            .into_iter()
            .map(|(code, body)| {
                let content_type = Self::infer_content_type(&body).to_string();
                (code, ErrorPage { body, content_type })
            })
            .collect();
        Self { pages }
    }

    /// Guesses the media type of an error-page body (HTML vs. plain text).
    fn infer_content_type(body: &str) -> &'static str {
        // `str::trim_start` does not treat U+FEFF (the BOM) as whitespace, so it
        // must be stripped explicitly before looking for the opening `<`.
        let head = body.trim_start_matches('\u{feff}').trim_start();
        if head.starts_with('<') {
            "text/html; charset=utf-8"
        } else {
            "text/plain; charset=utf-8"
        }
    }
}
#[async_trait]
impl salvo::Handler for ErrorPagesHoop {
    /// Runs the rest of the handler chain first. If the final response has a
    /// status code with a configured page, the body is replaced with that page.
    ///
    /// Responses without a status code, or whose status has no configured page,
    /// are passed through untouched.
    async fn handle(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        ctrl: &mut FlowCtrl,
    ) {
        ctrl.call_next(req, depot, res).await;

        let status = match res.status_code {
            Some(s) => s.as_u16(),
            None => return,
        };
        let page = match self.pages.get(&status) {
            Some(page) => page,
            None => return,
        };

        debug!(status, "serving custom error page");
        let headers = res.headers_mut();
        // The downstream body is being discarded. A Content-Encoding header it
        // set (e.g. gzip from an inner compression hoop) would now mislabel our
        // unencoded page and make clients fail to decode it.
        headers.remove(http::header::CONTENT_ENCODING);
        headers.insert(
            CONTENT_TYPE,
            // `content_type` is always one of the literals chosen in `new`, so the
            // fallback is defensive only.
            page.content_type.parse().unwrap_or_else(|_| {
                http::HeaderValue::from_static("text/plain; charset=utf-8")
            }),
        );
        headers.insert(
            http::header::CONTENT_LENGTH,
            http::HeaderValue::from(page.body.len()),
        );
        res.body(page.body.clone());
    }
}