Skip to main content

gatel_core/hoops/
error_pages.rs

1use std::collections::HashMap;
2
3use http::header::CONTENT_TYPE;
4use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
5use tracing::debug;
6
/// Custom error page middleware.
///
/// When the downstream handler produces a response with a status code that
/// matches one of the configured error pages, the response body is replaced
/// with the custom content. This is applied after the inner chain completes.
#[derive(Debug)]
pub struct ErrorPagesHoop {
    /// Maps HTTP status code to the replacement page served for it.
    pages: HashMap<u16, ErrorPage>,
}

/// One configured replacement page.
#[derive(Debug)]
struct ErrorPage {
    /// Body sent in place of the downstream response body.
    body: String,
    /// Value written to the `Content-Type` header when this page is served.
    content_type: String,
}
21
22impl ErrorPagesHoop {
23    /// Create from a map of status code → body string.
24    ///
25    /// Content type is auto-detected: if the body starts with `<` it is
26    /// assumed to be HTML, otherwise plain text.
27    pub fn new(pages: HashMap<u16, String>) -> Self {
28        let pages = pages
29            .into_iter()
30            .map(|(code, body)| {
31                let content_type = if body.trim_start().starts_with('<') {
32                    "text/html; charset=utf-8".to_string()
33                } else {
34                    "text/plain; charset=utf-8".to_string()
35                };
36                (code, ErrorPage { body, content_type })
37            })
38            .collect();
39        Self { pages }
40    }
41}
42
43#[async_trait]
44impl salvo::Handler for ErrorPagesHoop {
45    async fn handle(
46        &self,
47        req: &mut Request,
48        depot: &mut Depot,
49        res: &mut Response,
50        ctrl: &mut FlowCtrl,
51    ) {
52        ctrl.call_next(req, depot, res).await;
53
54        let status = match res.status_code {
55            Some(s) => s.as_u16(),
56            None => return,
57        };
58
59        if let Some(page) = self.pages.get(&status) {
60            debug!(status, "serving custom error page");
61            res.headers_mut().insert(
62                CONTENT_TYPE,
63                page.content_type.parse().unwrap_or_else(|_| {
64                    http::HeaderValue::from_static("text/plain; charset=utf-8")
65                }),
66            );
67            res.headers_mut().insert(
68                http::header::CONTENT_LENGTH,
69                http::HeaderValue::from(page.body.len()),
70            );
71            res.body(page.body.clone());
72        }
73    }
74}