use axum::{http::StatusCode, response::Html, Router as AXRouter};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::json;
use tower_http::services::ServeFile;
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
/// Newtype wrapper around an axum/http [`StatusCode`].
///
/// NOTE(review): not referenced anywhere in the visible code (the `Fallback`
/// config uses custom `serialize_with`/`deserialize_with` functions instead);
/// confirm whether it is used elsewhere before removing it.
#[derive(Debug)]
pub struct StatusCodeWrapper(pub StatusCode);
/// Configuration for the `fallback` middleware layer, which installs a
/// handler on the router's fallback (requests that match no route).
///
/// Exactly one response source is used, in this order of precedence:
/// `file`, then `not_found`, then a built-in `fallback.html` page.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Fallback {
/// Whether the layer is active. Defaults to `false` when omitted.
#[serde(default)]
pub enable: bool,
/// Status code sent with the `not_found` body or with the built-in HTML
/// page. Defaults to `200 OK` and is (de)serialized as a plain integer.
/// It is not applied when `file` is set.
#[serde(
default = "default_status_code",
serialize_with = "serialize_status_code",
deserialize_with = "deserialize_status_code"
)]
pub code: StatusCode,
/// Path of a file served (via `ServeFile`) for every fallback request.
/// Takes precedence over `not_found`.
pub file: Option<String>,
/// Text body returned together with `code` when `file` is not set.
pub not_found: Option<String>,
}
/// Serde default for [`Fallback::code`]: a plain `200 OK`.
fn default_status_code() -> StatusCode {
    StatusCode::OK
}
impl Default for Fallback {
    /// Returns the value obtained by deserializing an empty JSON object.
    ///
    /// Routing through serde (rather than building the struct by hand) keeps
    /// `Default` in lock-step with the `#[serde(default ...)]` attributes.
    ///
    /// # Panics
    ///
    /// Only if a field without a serde default is added to [`Fallback`];
    /// that would be a programming error, not a runtime condition.
    fn default() -> Self {
        // `enable` has `#[serde(default)]`, `code` has a `default = ...`
        // function, and `file`/`not_found` are `Option`s (absent => `None`),
        // so an empty object always deserializes.
        serde_json::from_value(json!({}))
            .expect("an empty object must deserialize into Fallback: every field has a default")
    }
}
fn deserialize_status_code<'de, D>(de: D) -> Result<StatusCode, D::Error>
where
D: Deserializer<'de>,
{
let code: u16 = Deserialize::deserialize(de)?;
StatusCode::from_u16(code).map_or_else(
|_| {
Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Unsigned(u64::from(code)),
&"a value between 100 and 600",
))
},
Ok,
)
}
/// Serializes a [`StatusCode`] as its bare numeric value (e.g. `404`).
// serde's `serialize_with` passes the field by reference (`&T`), so this
// signature must take `&StatusCode` even though `StatusCode` is `Copy`.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn serialize_status_code<S>(status: &StatusCode, ser: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let numeric = status.as_u16();
    ser.serialize_u16(numeric)
}
impl MiddlewareLayer for Fallback {
    /// The fixed name of this layer: `"fallback"`.
    fn name(&self) -> &'static str {
        "fallback"
    }

    /// Mirrors the `enable` config flag.
    fn is_enabled(&self) -> bool {
        self.enable
    }

    /// Serializes this configuration to a JSON value.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }

    /// Installs the fallback handler on `app`.
    ///
    /// - `file` set: the file is served through `ServeFile` (`code` unused).
    /// - otherwise `not_found` set: that text is returned with `code`.
    /// - otherwise: the bundled `fallback.html` is returned with `code`.
    fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
        if let Some(path) = &self.file {
            return Ok(app.fallback_service(ServeFile::new(path)));
        }

        // `StatusCode` is `Copy`; capture it by value in the handlers below.
        let status_code = self.code;
        let router = match &self.not_found {
            Some(body) => {
                // Owned copy so the handler closure is `'static`.
                let body = body.clone();
                app.fallback(move || async move { (status_code, body) })
            }
            None => {
                let page = include_str!("fallback.html");
                app.fallback(move || async move { (status_code, Html(page)) })
            }
        };
        Ok(router)
    }
}