by_loco/controller/middleware/
fallback.rs1use axum::{http::StatusCode, response::Html, Router as AXRouter};
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9use serde_json::json;
10use tower_http::services::ServeFile;
11
12use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
13
14#[derive(Debug)]
15pub struct StatusCodeWrapper(pub StatusCode);
16
17#[derive(Debug, Clone, Deserialize, Serialize)]
18pub struct Fallback {
19 #[serde(default)]
22 pub enable: bool,
23 #[serde(
26 default = "default_status_code",
27 serialize_with = "serialize_status_code",
28 deserialize_with = "deserialize_status_code"
29 )]
30 pub code: StatusCode,
31 pub file: Option<String>,
34 pub not_found: Option<String>,
37}
38
39fn default_status_code() -> StatusCode {
40 StatusCode::OK
41}
42
43impl Default for Fallback {
44 fn default() -> Self {
45 serde_json::from_value(json!({})).unwrap()
46 }
47}
48
49fn deserialize_status_code<'de, D>(de: D) -> Result<StatusCode, D::Error>
50where
51 D: Deserializer<'de>,
52{
53 let code: u16 = Deserialize::deserialize(de)?;
54 StatusCode::from_u16(code).map_or_else(
55 |_| {
56 Err(serde::de::Error::invalid_value(
57 serde::de::Unexpected::Unsigned(u64::from(code)),
58 &"a value between 100 and 600",
59 ))
60 },
61 Ok,
62 )
63}
64
65#[allow(clippy::trivially_copy_pass_by_ref)]
66fn serialize_status_code<S>(status: &StatusCode, ser: S) -> Result<S::Ok, S::Error>
67where
68 S: Serializer,
69{
70 ser.serialize_u16(status.as_u16())
71}
72impl MiddlewareLayer for Fallback {
73 fn name(&self) -> &'static str {
75 "fallback"
76 }
77
78 fn is_enabled(&self) -> bool {
80 self.enable
81 }
82
83 fn config(&self) -> serde_json::Result<serde_json::Value> {
84 serde_json::to_value(self)
85 }
86
87 fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
89 let app = if let Some(path) = &self.file {
90 app.fallback_service(ServeFile::new(path))
91 } else if let Some(not_found) = &self.not_found {
92 let not_found = not_found.to_string();
93 let status_code = self.code;
94 app.fallback(move || async move { (status_code, not_found) })
95 } else {
96 let content = include_str!("fallback.html");
97 let status_code = self.code;
98 app.fallback(move || async move { (status_code, Html(content)) })
99 };
100 Ok(app)
101 }
102}