use axum::{
extract::Request, http::HeaderValue, middleware::Next, response::Response, Router as AXRouter,
};
use regex::Regex;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
/// Upper bound on the number of characters retained from a client-supplied request ID.
const MAX_LEN: usize = 255;
/// Header name used both to read an incoming request ID and to set it on the response.
const X_REQUEST_ID: &str = "x-request-id";
use std::sync::OnceLock;
/// Lazily-initialised cache for the request-ID cleanup regex.
static ID_CLEANUP: OnceLock<Regex> = OnceLock::new();

/// Returns the regex matching every character that is NOT allowed in a request ID.
///
/// Allowed characters are regex word characters (`\w`), `-` and `@`; everything
/// else is matched so callers can strip it with `replace_all`. The regex is
/// compiled on first use and reused afterwards.
///
/// # Panics
///
/// Panics only if the constant pattern fails to compile, which would be a
/// programming error rather than a runtime condition.
fn get_id_cleanup() -> &'static Regex {
    let compile = || Regex::new(r"[^\w\-@]").unwrap();
    ID_CLEANUP.get_or_init(compile)
}
/// Configuration for the request-ID middleware.
///
/// When the `enable` key is missing from the configuration it deserializes to
/// `false` (via `#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestId {
    /// Flag reported by `MiddlewareLayer::is_enabled` for this middleware.
    #[serde(default)]
    pub enable: bool,
}
impl MiddlewareLayer for RequestId {
    /// Identifier of this middleware layer.
    fn name(&self) -> &'static str {
        "request_id"
    }

    /// Mirrors the `enable` flag from the configuration.
    fn is_enabled(&self) -> bool {
        self.enable
    }

    /// Serializes this configuration into a JSON value.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }

    /// Wraps `app` with [`request_id_middleware`] as an axum `from_fn` layer.
    fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
        let layer = axum::middleware::from_fn(request_id_middleware);
        let wrapped = app.layer(layer);
        Ok(wrapped)
    }
}
/// Request ID attached to each request's extensions by [`request_id_middleware`].
#[derive(Debug, Clone)]
pub struct LocoRequestId(String);

impl LocoRequestId {
    /// Borrows the request ID as a string slice.
    #[must_use]
    pub fn get(&self) -> &str {
        let Self(inner) = self;
        inner
    }
}
/// Axum middleware that assigns a request ID to every request.
///
/// The ID is derived by `make_request_id` from the incoming `x-request-id`
/// header, falling back to a freshly generated UUID. It is stored in the
/// request extensions as a [`LocoRequestId`] before the inner service runs.
/// After the inner service responds, the same ID is written to the response's
/// `x-request-id` header. If the ID cannot be converted into a `HeaderValue`,
/// a warning is logged and the response headers are left untouched.
pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
    let request_id = make_request_id(request.headers().get(X_REQUEST_ID).cloned());

    // Make the ID available to downstream handlers and middleware.
    let extension = LocoRequestId(request_id.clone());
    request.extensions_mut().insert(extension);

    let mut response = next.run(request).await;
    match HeaderValue::from_str(&request_id) {
        Ok(header) => {
            response.headers_mut().insert(X_REQUEST_ID, header);
        }
        Err(_) => {
            tracing::warn!("could not set request ID into response headers: `{request_id}`");
        }
    }
    response
}
fn make_request_id(maybe_request_id: Option<HeaderValue>) -> String {
maybe_request_id
.and_then(|hdr| {
let id: Option<String> = hdr.to_str().ok().map(|s| {
get_id_cleanup()
.replace_all(s, "")
.chars()
.take(MAX_LEN)
.collect()
});
id.filter(|s| !s.is_empty())
})
.unwrap_or_else(|| Uuid::new_v4().to_string())
}
// Unit tests for request-ID derivation in `make_request_id`.
#[cfg(test)]
mod tests {
    use axum::http::HeaderValue;
    use insta::assert_debug_snapshot;
    use super::make_request_id;

    // Covers the sanitize, fallback and truncate paths of `make_request_id`.
    // NOTE(review): insta names unnamed snapshots by their call order in this test,
    // so reordering or inserting assertions will likely invalidate the stored snapshots.
    #[test]
    fn create_or_fetch_request_id() {
        // `=` is outside the allowed character set, so it is stripped from the supplied ID.
        let id = make_request_id(Some(HeaderValue::from_static("foo-bar=baz")));
        assert_debug_snapshot!(id);
        // An empty header sanitizes to "", so a random UUID is generated; only the
        // length is snapshotted because the value itself is random.
        let id = make_request_id(Some(HeaderValue::from_static("")));
        assert_debug_snapshot!(id.len());
        // Every character is stripped, so this also falls back to a generated UUID.
        let id = make_request_id(Some(HeaderValue::from_static("==========")));
        assert_debug_snapshot!(id.len());
        // An over-long ID is truncated to `MAX_LEN` characters.
        let long_id = "x".repeat(1000);
        let id = make_request_id(Some(HeaderValue::from_str(&long_id).unwrap()));
        assert_debug_snapshot!(id.len());
        // No header at all: a UUID is generated.
        let id = make_request_id(None);
        assert_debug_snapshot!(id.len());
    }
}