by_loco/controller/middleware/
fallback.rs

//! Fallback Middleware
//!
//! This middleware handles requests for which no application route matches.
//! Depending on the configuration, it serves a file, a custom not-found
//! message, or the default HTML fallback page (checked in that order).
6
7use axum::{http::StatusCode, response::Html, Router as AXRouter};
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9use serde_json::json;
10use tower_http::services::ServeFile;
11
12use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
13
14#[derive(Debug)]
15pub struct StatusCodeWrapper(pub StatusCode);
16
17#[derive(Debug, Clone, Deserialize, Serialize)]
18pub struct Fallback {
19    /// By default when enabled, returns a prebaked 404 not found page optimized
20    /// for development. For production set something else (see fields below)
21    #[serde(default)]
22    pub enable: bool,
23    /// For the unlikely reason to return something different than `404`, you
24    /// can set it here
25    #[serde(
26        default = "default_status_code",
27        serialize_with = "serialize_status_code",
28        deserialize_with = "deserialize_status_code"
29    )]
30    pub code: StatusCode,
31    /// Returns content from a file pointed to by this field with a `404` status
32    /// code.
33    pub file: Option<String>,
34    /// Returns a "404 not found" with a single message string. This sets the
35    /// message.
36    pub not_found: Option<String>,
37}
38
39fn default_status_code() -> StatusCode {
40    StatusCode::OK
41}
42
43impl Default for Fallback {
44    fn default() -> Self {
45        serde_json::from_value(json!({})).unwrap()
46    }
47}
48
49fn deserialize_status_code<'de, D>(de: D) -> Result<StatusCode, D::Error>
50where
51    D: Deserializer<'de>,
52{
53    let code: u16 = Deserialize::deserialize(de)?;
54    StatusCode::from_u16(code).map_or_else(
55        |_| {
56            Err(serde::de::Error::invalid_value(
57                serde::de::Unexpected::Unsigned(u64::from(code)),
58                &"a value between 100 and 600",
59            ))
60        },
61        Ok,
62    )
63}
64
65#[allow(clippy::trivially_copy_pass_by_ref)]
66fn serialize_status_code<S>(status: &StatusCode, ser: S) -> Result<S::Ok, S::Error>
67where
68    S: Serializer,
69{
70    ser.serialize_u16(status.as_u16())
71}
72impl MiddlewareLayer for Fallback {
73    /// Returns the name of the middleware
74    fn name(&self) -> &'static str {
75        "fallback"
76    }
77
78    /// Returns whether the middleware is enabled or not
79    fn is_enabled(&self) -> bool {
80        self.enable
81    }
82
83    fn config(&self) -> serde_json::Result<serde_json::Value> {
84        serde_json::to_value(self)
85    }
86
87    /// Applies the fallback middleware to the application router.
88    fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
89        let app = if let Some(path) = &self.file {
90            app.fallback_service(ServeFile::new(path))
91        } else if let Some(not_found) = &self.not_found {
92            let not_found = not_found.to_string();
93            let status_code = self.code;
94            app.fallback(move || async move { (status_code, not_found) })
95        } else {
96            let content = include_str!("fallback.html");
97            let status_code = self.code;
98            app.fallback(move || async move { (status_code, Html(content)) })
99        };
100        Ok(app)
101    }
102}