use axum::{
extract::Request,
http::{HeaderMap, HeaderName, HeaderValue},
middleware::Next,
response::Response,
};
use uuid::Uuid;
/// Header carrying the trace identifier; read from incoming requests and echoed on responses.
///
/// Must stay lowercase: it is passed to `HeaderName::from_static`, which panics on
/// non-lowercase or otherwise invalid names.
pub const TRACE_ID_HEADER: &str = "x-trace-id";
/// Header carrying the per-request identifier; read from incoming requests and echoed on
/// responses. Must stay lowercase for the same `HeaderName::from_static` reason.
pub const REQUEST_ID_HEADER: &str = "x-request-id";
/// Header consulted first when resolving the client IP.
const REAL_IP_HEADER: &str = "x-real-ip";
/// Fallback header for the client IP; only its first comma-separated entry is used.
const FORWARDED_FOR_HEADER: &str = "x-forwarded-for";
/// Header whose (trimmed) value is recorded as the request's user agent.
const USER_AGENT_HEADER: &str = "user-agent";
/// Per-request tracing and identity metadata.
///
/// Built from request headers by [`build_request_trace_context`], inserted into the request
/// extensions by [`trace_request`], and used to echo the trace headers on the response.
///
/// When built through [`RequestTraceContext::new`], string fields are trimmed and blank
/// values are stored as `None`. Constructing the struct literally bypasses that normalization.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RequestTraceContext {
    /// Trace identifier: the `x-trace-id` header, or a generated UUID v4 when built from a request.
    pub trace_id: Option<String>,
    /// Request identifier: the `x-request-id` header, or a generated UUID v4 when built from a request.
    pub request_id: Option<String>,
    /// Tenant identity. Not read from headers; supplied by the caller (e.g. via `with_identity`).
    pub tenant_id: Option<i64>,
    /// Operator identity. Not read from headers; supplied by the caller (e.g. via `with_identity`).
    pub operator_id: Option<i64>,
    /// Client IP: `x-real-ip`, falling back to the first `x-forwarded-for` entry.
    pub ip: Option<String>,
    /// Value of the `user-agent` header.
    pub user_agent: Option<String>,
}
impl RequestTraceContext {
    /// Creates a context, normalizing every string field.
    ///
    /// Each string is trimmed, and one that is empty after trimming becomes `None`.
    /// The numeric identity fields are stored exactly as given.
    pub fn new(
        trace_id: Option<String>,
        request_id: Option<String>,
        tenant_id: Option<i64>,
        operator_id: Option<i64>,
        ip: Option<String>,
        user_agent: Option<String>,
    ) -> Self {
        let trace_id = normalize_optional_string(trace_id);
        let request_id = normalize_optional_string(request_id);
        let ip = normalize_optional_string(ip);
        let user_agent = normalize_optional_string(user_agent);

        Self {
            trace_id,
            request_id,
            tenant_id,
            operator_id,
            ip,
            user_agent,
        }
    }

    /// Returns this context with `tenant_id` and `operator_id` replaced.
    ///
    /// Both fields are overwritten unconditionally, so passing `None` clears a previously set id.
    pub fn with_identity(self, tenant_id: Option<i64>, operator_id: Option<i64>) -> Self {
        Self {
            tenant_id,
            operator_id,
            ..self
        }
    }
}
/// Derives a [`RequestTraceContext`] from the request's headers.
///
/// - `trace_id`: taken from `x-trace-id`; a fresh UUID v4 is used when that header is missing,
///   blank or not visible ASCII.
/// - `request_id`: taken from `x-request-id`, with the same UUID v4 fallback.
/// - `ip`: resolved by [`extract_request_ip`].
/// - `user_agent`: taken from the `user-agent` header.
/// - `tenant_id` / `operator_id`: always left as `None`, for the caller to fill in.
pub fn build_request_trace_context(request: &Request) -> RequestTraceContext {
    let headers = request.headers();
    let header_or_new_uuid = |name: &str| {
        extract_header_value(headers, name).or_else(|| Some(Uuid::new_v4().to_string()))
    };

    // Evaluated in the same order as the constructor arguments below.
    let trace_id = header_or_new_uuid(TRACE_ID_HEADER);
    let request_id = header_or_new_uuid(REQUEST_ID_HEADER);
    let ip = extract_request_ip(headers);
    let user_agent = extract_header_value(headers, USER_AGENT_HEADER);

    RequestTraceContext::new(trace_id, request_id, None, None, ip, user_agent)
}
/// Returns the trimmed value of the first `name` header, if it is usable.
///
/// Returns `None` in three cases:
/// - the header is absent;
/// - its value is not visible ASCII (`HeaderValue::to_str` fails);
/// - it is empty after trimming.
pub fn extract_header_value(headers: &HeaderMap, name: &str) -> Option<String> {
    let raw = headers.get(name)?.to_str().ok()?;
    let trimmed = raw.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
/// Returns the left-most entry of the `x-forwarded-for` list, trimmed.
///
/// Returns `None` when the header is missing or unusable, or when that first entry is empty
/// (e.g. a value like `" , 10.0.0.1"`).
///
/// NOTE(review): the left-most entry is whatever the client sent first and can be spoofed
/// unless a trusted proxy rewrites the header — confirm the deployment guarantees this.
pub fn extract_forwarded_ip(headers: &HeaderMap) -> Option<String> {
    let forwarded = extract_header_value(headers, FORWARDED_FOR_HEADER)?;
    let first_hop = forwarded.split(',').next().unwrap_or_default().trim();
    match first_hop {
        "" => None,
        ip => Some(ip.to_owned()),
    }
}
/// Resolves the client IP from headers.
///
/// `x-real-ip` takes precedence. Otherwise the first `x-forwarded-for` entry is used.
pub fn extract_request_ip(headers: &HeaderMap) -> Option<String> {
    if let Some(real_ip) = extract_header_value(headers, REAL_IP_HEADER) {
        return Some(real_ip);
    }
    extract_forwarded_ip(headers)
}
pub fn append_trace_headers(response: &mut Response, trace_context: &RequestTraceContext) {
append_trace_header_values(
response,
trace_context.trace_id.as_deref(),
trace_context.request_id.as_deref(),
);
}
/// Sets the `x-trace-id` and `x-request-id` response headers from the given values.
///
/// - A `None` value leaves the corresponding header untouched.
/// - A value that is not a valid header value is silently skipped.
/// - A present value replaces any existing header of the same name (`HeaderMap::insert`).
pub fn append_trace_header_values(
    response: &mut Response,
    trace_id: Option<&str>,
    request_id: Option<&str>,
) {
    let entries = [(TRACE_ID_HEADER, trace_id), (REQUEST_ID_HEADER, request_id)];
    for (header_name, maybe_value) in entries {
        let Some(header_value) = maybe_value else {
            continue;
        };
        insert_header(response, header_name, header_value);
    }
}
/// Middleware core: attaches a trace context to the request, runs the rest of the stack, and
/// echoes the trace headers on the response.
///
/// Steps:
/// 1. Build the default context with [`build_request_trace_context`].
/// 2. Pass it through `customize`, e.g. to attach identity via `with_identity`.
/// 3. Insert the result into the request extensions.
/// 4. Run `next`.
/// 5. Write the trace headers onto the response.
///
/// Returns the final context alongside the response so callers can log or reuse it.
pub async fn trace_request<F>(
    mut request: Request,
    next: Next,
    customize: F,
) -> (RequestTraceContext, Response)
where
    F: FnOnce(&Request, RequestTraceContext) -> RequestTraceContext,
{
    let trace_context = {
        let defaults = build_request_trace_context(&request);
        customize(&request, defaults)
    };

    // A clone goes into the extensions so downstream handlers can read it; the original is
    // kept here for the response headers and the return value.
    request.extensions_mut().insert(trace_context.clone());

    let mut response = next.run(request).await;
    append_trace_headers(&mut response, &trace_context);

    (trace_context, response)
}
/// Inserts `name: value` into the response headers, replacing any existing value.
///
/// A `value` rejected by `HeaderValue::from_str` is ignored rather than treated as an error.
///
/// # Panics
///
/// Panics if `name` is not a valid lowercase header name (`HeaderName::from_static`).
fn insert_header(response: &mut Response, name: &'static str, value: &str) {
    if let Ok(header_value) = HeaderValue::from_str(value) {
        let header_name = HeaderName::from_static(name);
        response.headers_mut().insert(header_name, header_value);
    }
}
/// Trims `value` and maps an empty result to `None`.
///
/// The owned `String` is returned as-is when it has no surrounding whitespace. This is the
/// common case, because header-derived values are already trimmed by `extract_header_value`.
/// A new allocation is made only when trimming actually removes characters.
fn normalize_optional_string(value: Option<String>) -> Option<String> {
    let value = value?;
    let trimmed = value.trim();
    if trimmed.is_empty() {
        None
    } else if trimmed.len() == value.len() {
        // Nothing was trimmed: reuse the existing allocation.
        Some(value)
    } else {
        Some(trimmed.to_owned())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request as HttpRequest, header::USER_AGENT},
        response::Response,
    };

    #[test]
    fn build_request_trace_context_uses_headers_and_generates_ids() {
        let request = HttpRequest::builder()
            .uri("/health")
            .header(TRACE_ID_HEADER, " trace-id ")
            .header(USER_AGENT, " stackloom-test ")
            .body(Body::empty())
            .unwrap();
        let trace_context = build_request_trace_context(&request);
        assert_eq!(trace_context.trace_id.as_deref(), Some("trace-id"));
        assert!(Uuid::parse_str(trace_context.request_id.as_deref().unwrap()).is_ok());
        assert_eq!(trace_context.user_agent.as_deref(), Some("stackloom-test"));
    }

    /// A blank id header must be treated like a missing one, and absent optional headers
    /// must leave their fields empty.
    #[test]
    fn build_request_trace_context_replaces_blank_ids_and_leaves_missing_fields_empty() {
        let request = HttpRequest::builder()
            .uri("/health")
            .header(TRACE_ID_HEADER, "   ")
            .body(Body::empty())
            .unwrap();
        let trace_context = build_request_trace_context(&request);
        assert!(Uuid::parse_str(trace_context.trace_id.as_deref().unwrap()).is_ok());
        assert!(Uuid::parse_str(trace_context.request_id.as_deref().unwrap()).is_ok());
        assert_eq!(trace_context.ip, None);
        assert_eq!(trace_context.user_agent, None);
        assert_eq!(trace_context.tenant_id, None);
        assert_eq!(trace_context.operator_id, None);
    }

    #[test]
    fn new_trims_strings_and_drops_blank_ones() {
        let trace_context = RequestTraceContext::new(
            Some("  trace-id  ".to_string()),
            Some("   ".to_string()),
            Some(1),
            Some(2),
            Some(String::new()),
            None,
        );
        assert_eq!(trace_context.trace_id.as_deref(), Some("trace-id"));
        assert_eq!(trace_context.request_id, None);
        assert_eq!(trace_context.tenant_id, Some(1));
        assert_eq!(trace_context.operator_id, Some(2));
        assert_eq!(trace_context.ip, None);
        assert_eq!(trace_context.user_agent, None);
    }

    #[test]
    fn with_identity_overwrites_both_ids() {
        let trace_context = RequestTraceContext::default()
            .with_identity(Some(7), Some(9))
            .with_identity(None, Some(3));
        assert_eq!(trace_context.tenant_id, None);
        assert_eq!(trace_context.operator_id, Some(3));
    }

    #[test]
    fn extract_header_value_ignores_missing_blank_and_non_visible_ascii_values() {
        let mut headers = HeaderMap::new();
        assert_eq!(extract_header_value(&headers, TRACE_ID_HEADER), None);

        headers.insert(
            HeaderName::from_static(TRACE_ID_HEADER),
            HeaderValue::from_static("   "),
        );
        assert_eq!(extract_header_value(&headers, TRACE_ID_HEADER), None);

        // 0xFF is a legal header byte but not visible ASCII, so `to_str` rejects it.
        headers.insert(
            HeaderName::from_static(TRACE_ID_HEADER),
            HeaderValue::from_bytes(&[0xff]).unwrap(),
        );
        assert_eq!(extract_header_value(&headers, TRACE_ID_HEADER), None);
    }

    #[test]
    fn extract_request_ip_prefers_real_ip_then_forwarded_ip() {
        let real_ip_headers = {
            let mut headers = HeaderMap::new();
            headers.insert(
                HeaderName::from_static(REAL_IP_HEADER),
                HeaderValue::from_static("10.0.0.8"),
            );
            headers.insert(
                HeaderName::from_static(FORWARDED_FOR_HEADER),
                HeaderValue::from_static("192.168.1.10, 192.168.1.11"),
            );
            headers
        };
        assert_eq!(
            extract_request_ip(&real_ip_headers).as_deref(),
            Some("10.0.0.8")
        );
        let mut forwarded_only_headers = HeaderMap::new();
        forwarded_only_headers.insert(
            HeaderName::from_static(FORWARDED_FOR_HEADER),
            HeaderValue::from_static("192.168.1.10, 192.168.1.11"),
        );
        assert_eq!(
            extract_request_ip(&forwarded_only_headers).as_deref(),
            Some("192.168.1.10")
        );
        assert_eq!(extract_request_ip(&HeaderMap::new()), None);
    }

    #[test]
    fn append_trace_headers_sets_trace_headers() {
        let trace_context = RequestTraceContext::new(
            Some("trace-id".to_string()),
            Some("request-id".to_string()),
            None,
            None,
            None,
            None,
        );
        let mut response = Response::new(Body::empty());
        append_trace_headers(&mut response, &trace_context);
        assert_eq!(response.headers().get(TRACE_ID_HEADER).unwrap(), "trace-id");
        assert_eq!(
            response.headers().get(REQUEST_ID_HEADER).unwrap(),
            "request-id"
        );
    }

    #[test]
    fn append_trace_header_values_skips_missing_and_invalid_values() {
        let mut response = Response::new(Body::empty());
        append_trace_header_values(&mut response, None, Some("request-id"));
        assert!(response.headers().get(TRACE_ID_HEADER).is_none());
        assert_eq!(
            response.headers().get(REQUEST_ID_HEADER).unwrap(),
            "request-id"
        );

        // A newline is not allowed in a header value, so nothing may be written.
        let mut response = Response::new(Body::empty());
        append_trace_header_values(&mut response, Some("bad\nvalue"), None);
        assert!(response.headers().is_empty());
    }
}