use std::sync::OnceLock;
use axum::{
http::header::{HeaderName, HeaderValue},
Router as AXRouter,
};
use tower_http::set_header::SetResponseHeaderLayer;
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
/// Process-wide cache for the default identifier header value, filled on first use.
static DEFAULT_IDENT_HEADER_VALUE: OnceLock<HeaderValue> = OnceLock::new();

/// Returns a shared reference to the default identifier header value (`loco.rs`).
///
/// The value is built the first time this function is called and the same
/// instance is handed out on every later call.
fn get_default_ident_header_value() -> &'static HeaderValue {
    let make_default = || HeaderValue::from_static("loco.rs");
    DEFAULT_IDENT_HEADER_VALUE.get_or_init(make_default)
}
/// Middleware state for the `powered_by` layer.
///
/// Holds the header value to send, if any. Instances are created through
/// the module-level `new` function.
#[derive(Debug)]
pub struct Middleware {
/// Header value to emit; `None` means the middleware reports itself as
/// disabled via `is_enabled`.
ident: Option<HeaderValue>,
}
/// Creates the `powered_by` [`Middleware`] from an optional identifier string.
///
/// The resulting state depends on `ident`:
///
/// * `None` - the default identifier (`loco.rs`) is used.
/// * `Some("")` - no identifier is stored, so `is_enabled` returns `false`.
/// * `Some(value)` - `value` is converted into a header value. If the
///   conversion fails, the failure is logged at `info` level and the default
///   identifier is used instead.
#[must_use]
pub fn new(ident: Option<&str>) -> Middleware {
    let ident = match ident {
        // Nothing configured: advertise the default identifier.
        None => Some(get_default_ident_header_value().clone()),
        // An explicitly empty identifier switches the header off.
        Some("") => None,
        Some(custom) => {
            let parsed = HeaderValue::from_str(custom).unwrap_or_else(|e| {
                // Not a legal header value: report it and fall back to the default.
                tracing::info!(
                    error = e.to_string(),
                    val = custom,
                    "could not set custom ident header"
                );
                get_default_ident_header_value().clone()
            });
            Some(parsed)
        }
    };

    Middleware { ident }
}
impl MiddlewareLayer for Middleware {
    /// Returns the name of this middleware: `powered_by`.
    fn name(&self) -> &'static str {
        "powered_by"
    }

    /// Reports whether an identifier is stored; without one the middleware
    /// counts as disabled.
    fn is_enabled(&self) -> bool {
        self.ident.is_some()
    }

    /// Serialises the middleware configuration to JSON.
    ///
    /// Yields `{"ident": "<value>"}` when an identifier is stored (an empty
    /// string if the header value cannot be rendered as a string) and `{}`
    /// otherwise.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        let rendered = match &self.ident {
            Some(ident) => serde_json::json!({"ident": ident.to_str().unwrap_or_default()}),
            None => serde_json::json!({}),
        };
        Ok(rendered)
    }

    /// Wraps `app` in a layer that sets the `x-powered-by` response header.
    ///
    /// The stored identifier is used when present; otherwise the default
    /// identifier is used.
    ///
    /// NOTE(review): when no identifier is stored (`is_enabled` is `false`)
    /// this still installs the default value; callers are presumably expected
    /// to check `is_enabled` before calling `apply` - confirm.
    fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
        let header_value = match &self.ident {
            Some(ident) => ident.clone(),
            None => get_default_ident_header_value().clone(),
        };
        let powered_by = SetResponseHeaderLayer::overriding(
            HeaderName::from_static("x-powered-by"),
            header_value,
        );
        Ok(app.layer(powered_by))
    }
}