use axum::extract::Request;
use axum::http::HeaderName;
use axum::http::header::{HeaderValue, InvalidHeaderValue};
use axum::middleware::Next;
use axum::response::Response;
use tracing::Instrument;
use uuid::Uuid;
/// Name of the header that carries the correlation id.
///
/// It is read from incoming requests and set on outgoing responses. It must stay
/// lowercase, because the middleware passes it to `HeaderName::from_static`, which
/// panics on names containing uppercase characters.
pub(crate) const CORRELATION_HEADER: &str = "x-request-id";

/// Shortest caller-supplied id (in bytes, after trimming) that is adopted verbatim.
const MIN_CALLER_ID_LEN: usize = 8;

/// Longest caller-supplied id (in bytes, after trimming) that is adopted verbatim.
const MAX_CALLER_ID_LEN: usize = 128;
/// Axum middleware that tags every request with a correlation id.
///
/// The id is the caller's [`CORRELATION_HEADER`] value when it passes validation
/// (see [`extract_or_mint`]). Otherwise a fresh UUIDv7 is minted. The resolved id is then:
///
/// * written back into the request's [`CORRELATION_HEADER`], replacing whatever the
///   caller sent, so downstream middleware and handlers observe the same id as the logs;
/// * recorded on an `http_request` tracing span, together with the request method and
///   URI, which instruments the rest of the middleware stack and the handler;
/// * echoed on the response under [`CORRELATION_HEADER`].
pub(crate) async fn correlation_id(mut req: Request, next: Next) -> Response {
    let id = extract_or_mint(&req);
    let header_name = HeaderName::from_static(CORRELATION_HEADER);
    // The id is either a caller value that passed validation or a UUID string, so the
    // conversion is not expected to fail. Stay best-effort rather than panicking if it does.
    let header_value = HeaderValue::from_str(&id).ok();

    if let Some(val) = &header_value {
        // Normalise the inbound header. Without this, a handler that calls
        // `extract_or_mint` on a request whose header was missing or rejected would mint
        // a second, different id from the one in the span and the response.
        req.headers_mut().insert(header_name.clone(), val.clone());
    }

    // NOTE(review): `uri` includes the query string. If query parameters can carry
    // secrets, consider recording only `req.uri().path()`.
    let span = tracing::info_span!(
        "http_request",
        correlation_id = %id,
        method = %req.method(),
        uri = %req.uri(),
    );
    let mut response = next.run(req).instrument(span).await;

    if let Some(val) = header_value {
        response.headers_mut().insert(header_name, val);
    }
    response
}
pub(crate) fn extract_or_mint(req: &Request) -> String {
if let Some(raw) = req.headers().get(CORRELATION_HEADER)
&& let Ok(s) = raw.to_str()
{
let trimmed = s.trim();
if is_acceptable_caller_id(trimmed) {
return trimmed.to_string();
}
}
mint_uuid_v7()
}
/// Generates a new UUIDv7 from the current time.
///
/// The result is rendered in the canonical 36-character hyphenated lowercase form.
fn mint_uuid_v7() -> String {
    let fresh = Uuid::now_v7();
    fresh.hyphenated().to_string()
}
/// Returns `true` if a caller-supplied correlation id may be adopted verbatim.
///
/// An accepted id meets two conditions:
///
/// * its length is between [`MIN_CALLER_ID_LEN`] and [`MAX_CALLER_ID_LEN`] bytes,
///   inclusive;
/// * it consists solely of visible ASCII characters (`!` through `~`).
///
/// Whitespace, including the plain space character, is rejected. The id comes from an
/// untrusted request header and is recorded as a `%`-formatted tracing span field. With
/// space-delimited text log formats, an embedded space could let a client append fake
/// `key=value` fields to every log line for the request.
fn is_acceptable_caller_id(s: &str) -> bool {
    (MIN_CALLER_ID_LEN..=MAX_CALLER_ID_LEN).contains(&s.len())
        && s.bytes().all(|b| b.is_ascii_graphic())
}
#[allow(dead_code)]
pub(crate) fn _header_value(id: &str) -> Result<HeaderValue, InvalidHeaderValue> {
HeaderValue::from_str(id)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    /// Builds a `/v1/healthz` request without a correlation header.
    fn bare_request() -> Request<Body> {
        Request::builder()
            .uri("/v1/healthz")
            .body(Body::empty())
            .unwrap()
    }

    /// Builds a `/v1/healthz` request carrying `value` in [`CORRELATION_HEADER`].
    fn request_with_id(value: &str) -> Request<Body> {
        Request::builder()
            .uri("/v1/healthz")
            .header(CORRELATION_HEADER, value)
            .body(Body::empty())
            .unwrap()
    }

    /// Asserts that `id` is a canonical hyphenated UUID of version 7, i.e. a freshly minted id.
    fn assert_is_minted_v7(id: &str) {
        assert_eq!(id.len(), 36, "expected UUIDv7, got {id}");
        assert_eq!(id.matches('-').count(), 4, "expected UUIDv7, got {id}");
        let parsed =
            Uuid::parse_str(id).unwrap_or_else(|e| panic!("expected UUIDv7, got {id}: {e}"));
        assert_eq!(parsed.get_version_num(), 7, "expected UUIDv7, got {id}");
    }

    #[test]
    fn reuses_valid_caller_id() {
        let req = request_with_id("req-abc12345");
        assert_eq!(extract_or_mint(&req), "req-abc12345");
    }

    #[test]
    fn trims_whitespace_around_caller_id() {
        let req = request_with_id("  req-abc12345  ");
        assert_eq!(extract_or_mint(&req), "req-abc12345");
    }

    #[test]
    fn mints_when_header_missing() {
        assert_is_minted_v7(&extract_or_mint(&bare_request()));
    }

    #[test]
    fn rejects_too_short_caller_id() {
        let id = extract_or_mint(&request_with_id("short-1"));
        assert_ne!(id, "short-1");
        assert_is_minted_v7(&id);
    }

    #[test]
    fn rejects_too_long_caller_id() {
        let too_long = "x".repeat(MAX_CALLER_ID_LEN + 1);
        let id = extract_or_mint(&request_with_id(&too_long));
        assert_ne!(id, too_long);
        assert_is_minted_v7(&id);
    }

    #[test]
    fn accepts_ids_at_length_bounds() {
        let shortest = "x".repeat(MIN_CALLER_ID_LEN);
        let longest = "x".repeat(MAX_CALLER_ID_LEN);
        assert_eq!(extract_or_mint(&request_with_id(&shortest)), shortest);
        assert_eq!(extract_or_mint(&request_with_id(&longest)), longest);
    }

    #[test]
    fn mints_when_header_is_not_visible_ascii() {
        let mut req = bare_request();
        req.headers_mut().insert(
            CORRELATION_HEADER,
            HeaderValue::from_bytes(b"req-\xffabcdef").unwrap(),
        );
        assert_is_minted_v7(&extract_or_mint(&req));
    }

    #[test]
    fn two_minted_ids_differ() {
        let a = mint_uuid_v7();
        let b = mint_uuid_v7();
        assert_ne!(a, b);
    }

    #[test]
    fn accepts_opentelemetry_style_id() {
        let caller = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
        assert_eq!(extract_or_mint(&request_with_id(caller)), caller);
    }

    #[test]
    fn rejects_control_chars() {
        assert!(!is_acceptable_caller_id("abcdef\x01gh"));
        assert!(!is_acceptable_caller_id("abc\tdefg"));
    }
}