use {axum::http::{HeaderName,
HeaderValue,
Request},
std::{future::Future,
pin::Pin,
task::{Context,
Poll}},
tower::{Layer,
Service},
tracing::Span,
uuid::Uuid};
/// Header used both to read a caller-supplied request id from incoming
/// requests and to echo the effective id back on responses.
pub static X_REQUEST_ID: HeaderName = {
    const HEADER: &str = "x-request-id";
    HeaderName::from_static(HEADER)
};
/// Correlation id stored in a request's extensions by [`RequestIdMiddleware`].
///
/// Holds either the non-empty `x-request-id` header value sent by the client
/// or a freshly generated UUID v4 string.
#[derive(Clone, Debug)]
pub struct RequestId(String);

impl RequestId {
    /// Borrows the id as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl std::fmt::Display for RequestId {
    /// Writes the raw id verbatim; formatter fill/width options are ignored
    /// because the text goes through `write_str` rather than `pad`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw: &str = self.as_str();
        f.write_str(raw)
    }
}
/// [`Layer`] that wraps a service in [`RequestIdMiddleware`].
///
/// The layer is a unit struct and carries no configuration.
#[derive(Clone, Default)]
pub struct RequestIdLayer;

impl RequestIdLayer {
    /// Creates the layer; equivalent to [`RequestIdLayer::default`].
    pub fn new() -> Self {
        Self::default()
    }
}

impl<S> Layer<S> for RequestIdLayer {
    type Service = RequestIdMiddleware<S>;

    /// Wraps `service` in a [`RequestIdMiddleware`].
    fn layer(&self, service: S) -> Self::Service {
        RequestIdMiddleware { inner: service }
    }
}
/// Tower service created by [`RequestIdLayer`].
///
/// For each request it:
/// - determines a request id (the client's non-empty `x-request-id` header,
///   or a new UUID v4);
/// - records that id into the current tracing span;
/// - stores it as a [`RequestId`] in the request extensions;
/// - echoes it back on the successful response under `x-request-id`.
#[derive(Clone)]
pub struct RequestIdMiddleware<S> {
    inner: S,
}

impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RequestIdMiddleware<S>
where
    S: Service<Request<ReqBody>, Response = axum::response::Response<ResBody>>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    type Response = S::Response;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Tags `req` with a request id and forwards it to the wrapped service.
    ///
    /// # Errors
    ///
    /// The returned future yields the wrapped service's error unchanged.
    /// No `x-request-id` header is added in that case.
    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        // Reuse a caller-provided id when it is valid visible ASCII and
        // non-empty; otherwise mint a new one.
        let id = req
            .headers()
            .get(&X_REQUEST_ID)
            .and_then(|v| v.to_str().ok())
            .filter(|v| !v.is_empty())
            .map(String::from)
            .unwrap_or_else(|| Uuid::new_v4().to_string());

        // Per tracing's `Span::record` docs, this has no effect unless the
        // current span declares a `request_id` field.
        Span::current().record("request_id", id.as_str());

        let header_val = HeaderValue::from_str(&id).ok();
        req.extensions_mut().insert(RequestId(id));

        // `poll_ready` drove `self.inner` to readiness, so that exact instance
        // must handle this request. Calling a fresh clone would skip the
        // readiness contract (e.g. a capacity slot reserved by the original).
        // Move the ready service out, leave a clone behind for the next
        // `poll_ready`, and invoke `call` now rather than inside the future.
        let replacement = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, replacement);
        let response_future = ready.call(req);

        Box::pin(async move {
            let mut response = response_future.await?;
            if let Some(val) = header_val {
                response.headers_mut().insert(&X_REQUEST_ID, val);
            }
            Ok(response)
        })
    }
}