use std::future::{Ready, ready};
use std::rc::Rc;
use actix_web::{
Error, HttpMessage, HttpRequest,
dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
http::header::{HeaderName, HeaderValue},
};
use futures_util::future::LocalBoxFuture;
use tracing_actix_web::RequestId;
/// Name of the response header that carries the per-request id (`x-request-id`).
const HEADER_NAME: HeaderName = HeaderName::from_static("x-request-id");
/// Returns the request id attached to `req` as a string.
///
/// The id is read from the request extensions, where `tracing_actix_web`'s
/// `TracingLogger` middleware stores a [`RequestId`]. When no `RequestId` is
/// present (e.g. `TracingLogger` is not registered), an empty `String` is
/// returned rather than an error.
pub fn request_id_from_request(req: &HttpRequest) -> String {
    let extensions = req.extensions();
    match extensions.get::<RequestId>() {
        Some(request_id) => request_id.to_string(),
        None => String::new(),
    }
}
/// Middleware factory that echoes the request id into an `x-request-id` response header.
///
/// The id is read from the `tracing_actix_web::RequestId` stored in the request
/// extensions, so `TracingLogger` must run before this middleware. In actix-web the
/// last `.wrap(...)` is the outermost layer, so register this first:
/// `.wrap(RequestIdHeader).wrap(TracingLogger::default())`. Without `TracingLogger`
/// no header is added.
#[derive(Debug, Default, Clone, Copy)]
pub struct RequestIdHeader;
impl<S, B> Transform<S, ServiceRequest> for RequestIdHeader
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Transform = RequestIdHeaderMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a [`RequestIdHeaderMiddleware`]; construction never fails.
    fn new_transform(&self, service: S) -> Self::Future {
        // `Rc` lets each per-request future hold its own handle to the inner service.
        let shared = Rc::new(service);
        let middleware = RequestIdHeaderMiddleware { service: shared };
        ready(Ok(middleware))
    }
}
/// Per-service middleware produced by [`RequestIdHeader`]; adds the `x-request-id`
/// response header after the wrapped service has produced a response.
pub struct RequestIdHeaderMiddleware<S> {
    // Shared via `Rc` so `call` can move a clone into its `'static` boxed future.
    service: Rc<S>,
}
impl<S, B> Service<ServiceRequest> for RequestIdHeaderMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    // Readiness is delegated unchanged to the wrapped service.
    forward_ready!(service);

    /// Calls the wrapped service, then stamps its response with the request id.
    ///
    /// The id comes from the `RequestId` in the response's request extensions and
    /// is written with `insert`, replacing any existing `x-request-id` value. If no
    /// `RequestId` is present, or it cannot be turned into a `HeaderValue`, the
    /// response is returned untouched. An `Err` from the inner service is
    /// propagated as-is, so no header is added on that path.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let inner = Rc::clone(&self.service);
        Box::pin(async move {
            let mut response = inner.call(req).await?;
            // The extensions borrow ends with this statement, before headers are mutated.
            let header_value = response
                .request()
                .extensions()
                .get::<RequestId>()
                .map(|id| id.to_string())
                .and_then(|text| HeaderValue::from_str(&text).ok());
            if let Some(value) = header_value {
                response.headers_mut().insert(HEADER_NAME, value);
            }
            Ok(response)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::{RequestIdHeader, request_id_from_request};
    use actix_web::{
        App, HttpRequest, HttpResponse,
        dev::ServiceResponse,
        http::StatusCode,
        test::{TestRequest, call_service, init_service, read_body},
        web,
    };
    use tracing_actix_web::TracingLogger;

    async fn ok_handler() -> HttpResponse {
        HttpResponse::Ok().body("ok")
    }

    async fn fail_handler() -> HttpResponse {
        HttpResponse::InternalServerError().json(serde_json::json!({"code": "BOOM"}))
    }

    /// Responds with whatever `request_id_from_request` sees for this request.
    async fn echo_id_handler(req: HttpRequest) -> HttpResponse {
        HttpResponse::Ok().body(request_id_from_request(&req))
    }

    /// Panics unless `value` is a lowercase, hyphenated (canonical) UUID string.
    fn assert_canonical_uuid(value: &str) {
        let uuid_re =
            regex::Regex::new(r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$")
                .expect("valid uuid regex");
        assert!(
            uuid_re.is_match(value),
            "X-Request-ID should be a canonical UUID, got: {value}"
        );
    }

    /// Returns the `x-request-id` header of `res` as an owned string.
    ///
    /// Panics, mentioning `what`, if the header is missing or not valid UTF-8.
    fn request_id_header<B>(res: &ServiceResponse<B>, what: &str) -> String {
        res.headers()
            .get("x-request-id")
            .unwrap_or_else(|| panic!("{what} should carry the X-Request-ID header"))
            .to_str()
            .expect("header value should be valid utf-8")
            .to_owned()
    }

    #[actix_web::test]
    async fn header_is_present_on_success_response() {
        let app = init_service(
            App::new()
                .wrap(RequestIdHeader)
                .wrap(TracingLogger::default())
                .route("/ok", web::get().to(ok_handler)),
        )
        .await;
        let res = call_service(&app, TestRequest::get().uri("/ok").to_request()).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_canonical_uuid(&request_id_header(&res, "success response"));
    }

    #[actix_web::test]
    async fn header_is_present_on_error_response() {
        let app = init_service(
            App::new()
                .wrap(RequestIdHeader)
                .wrap(TracingLogger::default())
                .route("/boom", web::get().to(fail_handler)),
        )
        .await;
        let res = call_service(&app, TestRequest::get().uri("/boom").to_request()).await;
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(res.headers().contains_key("x-request-id"));
    }

    #[actix_web::test]
    async fn header_value_matches_uuid_per_request_and_differs_across_requests() {
        let app = init_service(
            App::new()
                .wrap(RequestIdHeader)
                .wrap(TracingLogger::default())
                .route("/ok", web::get().to(ok_handler)),
        )
        .await;
        let res_a = call_service(&app, TestRequest::get().uri("/ok").to_request()).await;
        let res_b = call_service(&app, TestRequest::get().uri("/ok").to_request()).await;
        let value_a = request_id_header(&res_a, "first response");
        let value_b = request_id_header(&res_b, "second response");
        // The test name promises a per-request UUID check, so verify both values.
        assert_canonical_uuid(&value_a);
        assert_canonical_uuid(&value_b);
        assert_ne!(
            value_a, value_b,
            "every request should receive a fresh request id"
        );
    }

    #[actix_web::test]
    async fn header_matches_request_id_seen_by_handler() {
        let app = init_service(
            App::new()
                .wrap(RequestIdHeader)
                .wrap(TracingLogger::default())
                .route("/echo", web::get().to(echo_id_handler)),
        )
        .await;
        let res = call_service(&app, TestRequest::get().uri("/echo").to_request()).await;
        assert_eq!(res.status(), StatusCode::OK);
        // Copy the header out before `read_body` consumes the response.
        let header_value = request_id_header(&res, "echo response");
        let body = read_body(res).await;
        assert_eq!(
            &body[..],
            header_value.as_bytes(),
            "request_id_from_request should agree with the X-Request-ID header"
        );
    }

    #[actix_web::test]
    async fn header_is_omitted_when_tracing_logger_is_absent() {
        let app = init_service(
            App::new()
                .wrap(RequestIdHeader)
                .route("/ok", web::get().to(ok_handler)),
        )
        .await;
        let res = call_service(&app, TestRequest::get().uri("/ok").to_request()).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(
            !res.headers().contains_key("x-request-id"),
            "header should be absent when TracingLogger is not registered"
        );
    }

    #[actix_web::test]
    async fn request_id_from_request_is_empty_without_tracing_logger() {
        let app = init_service(
            App::new()
                .wrap(RequestIdHeader)
                .route("/echo", web::get().to(echo_id_handler)),
        )
        .await;
        let res = call_service(&app, TestRequest::get().uri("/echo").to_request()).await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = read_body(res).await;
        assert!(
            body.is_empty(),
            "request_id_from_request should return an empty string without TracingLogger"
        );
    }
}