use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use axum::http::{HeaderValue, Request, Response};
use http::header::HeaderName;
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use uuid::Uuid;
/// Name of the response header that carries the generated request ID.
// Kept as a `static` rather than a `const`: a custom `HeaderName` wraps `Bytes`,
// so a `const` would be re-instantiated at every use site.
static X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone, Debug)]
pub struct RequestId(Uuid);
impl RequestId {
#[must_use]
pub const fn as_uuid(&self) -> Uuid {
self.0
}
}
impl fmt::Display for RequestId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Tower layer that wraps a service in a [`RequestIdService`].
///
/// The layer carries no state, so it is `Copy` and `Default`. It can be
/// written either as the unit expression `RequestIdLayer` or as
/// `RequestIdLayer::default()` (handy for builder-style APIs).
#[derive(Clone, Copy, Debug, Default)]
pub struct RequestIdLayer;
impl<S> Layer<S> for RequestIdLayer {
    type Service = RequestIdService<S>;

    /// Wraps `service` in a [`RequestIdService`].
    fn layer(&self, service: S) -> Self::Service {
        RequestIdService { inner: service }
    }
}
/// Middleware service produced by [`RequestIdLayer`].
///
/// On each call it generates a random (v4) UUID and stores it in the request
/// extensions as a [`RequestId`]. The returned future then copies that ID
/// into the `x-request-id` response header.
#[derive(Debug, Clone)]
pub struct RequestIdService<S> {
    /// The wrapped service that actually handles the request.
    inner: S,
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RequestIdService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = RequestIdFuture<S::Future>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Tags `request` with a new v4 UUID and forwards it to the wrapped service.
    ///
    /// The ID goes into the request extensions, replacing any `RequestId`
    /// already stored there. The returned future also carries it, so it can
    /// be copied into the response headers.
    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
        let request_id = RequestId(Uuid::new_v4());
        request.extensions_mut().insert(request_id.clone());
        let inner = self.inner.call(request);
        RequestIdFuture {
            inner,
            request_id: Some(request_id),
        }
    }
}
pin_project! {
/// Future returned by [`RequestIdService`].
///
/// It resolves to the wrapped service's output. On a successful response, it
/// adds an `x-request-id` header carrying the request's ID.
pub struct RequestIdFuture<F> {
// The wrapped service's response future; structurally pinned.
#[pin]
inner: F,
// ID to write into the response header. It is `take()`n when the response
// arrives, so it is written at most once.
request_id: Option<RequestId>,
}
}
impl<F, ResBody, E> Future for RequestIdFuture<F>
where
F: Future<Output = Result<Response<ResBody>, E>>,
{
type Output = Result<Response<ResBody>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.inner.poll(cx) {
Poll::Ready(Ok(mut response)) => {
if let Some(id) = this.request_id.take() {
let mut buf = [0u8; uuid::fmt::Hyphenated::LENGTH];
let s = id.0.as_hyphenated().encode_lower(&mut buf);
if let Ok(value) = HeaderValue::from_bytes(s.as_bytes()) {
response.headers_mut().insert(X_REQUEST_ID.clone(), value);
}
}
Poll::Ready(Ok(response))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::extract::Extension;
    use axum::routing::get;
    use tower::ServiceExt;

    /// Builds a `GET /` request with an empty body.
    fn root_request() -> Request<Body> {
        Request::builder()
            .uri("/")
            .body(Body::empty())
            .expect("static request parts are valid")
    }

    /// Router with a trivial `"ok"` handler wrapped in `RequestIdLayer`.
    fn ok_router() -> Router {
        Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(RequestIdLayer)
    }

    #[tokio::test]
    async fn response_has_request_id_header() {
        let response = ok_router().oneshot(root_request()).await.unwrap();
        let header = response
            .headers()
            .get("x-request-id")
            .expect("x-request-id header should be present");
        let id_str = header.to_str().unwrap();
        assert!(Uuid::parse_str(id_str).is_ok());
    }

    #[tokio::test]
    async fn each_request_gets_unique_id() {
        let router = ok_router();
        let r1 = router.clone().oneshot(root_request()).await.unwrap();
        let r2 = router.oneshot(root_request()).await.unwrap();
        let id1 = r1.headers()["x-request-id"].to_str().unwrap();
        let id2 = r2.headers()["x-request-id"].to_str().unwrap();
        assert_ne!(id1, id2);
    }

    #[tokio::test]
    async fn request_id_available_in_extensions() {
        async fn handler(Extension(id): Extension<RequestId>) -> String {
            id.to_string()
        }
        let router = Router::new().route("/", get(handler)).layer(RequestIdLayer);
        let response = router.oneshot(root_request()).await.unwrap();
        assert_eq!(response.status(), axum::http::StatusCode::OK);
        // Capture the header before `into_body` consumes the response.
        let header_id = response.headers()["x-request-id"]
            .to_str()
            .unwrap()
            .to_owned();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let body_str = String::from_utf8(body.to_vec()).unwrap();
        let body_uuid = Uuid::parse_str(&body_str).expect("handler should echo a UUID");
        let header_uuid = Uuid::parse_str(&header_id).expect("header should hold a UUID");
        // The ID seen by the handler and the one echoed in the header must match.
        assert_eq!(body_uuid, header_uuid);
    }

    #[test]
    fn request_id_display() {
        let id = RequestId(Uuid::nil());
        assert_eq!(id.to_string(), "00000000-0000-0000-0000-000000000000");
    }
}