neocrates 0.1.52

A comprehensive Rust library for various utilities and helpers
Documentation
use axum::{
    extract::Request,
    http::{HeaderMap, HeaderName, HeaderValue},
    middleware::Next,
    response::Response,
};
use uuid::Uuid;

/// Trace identifier header. Read from incoming requests by `build_request_trace_context` and
/// written to responses by `append_trace_header_values`.
///
/// Must stay a valid lowercase header name: `insert_header` converts it with
/// `HeaderName::from_static`, which panics on invalid/uppercase names.
pub const TRACE_ID_HEADER: &str = "x-trace-id";
/// Per-request identifier header. Read and echoed the same way as the trace identifier
/// (same lowercase requirement).
pub const REQUEST_ID_HEADER: &str = "x-request-id";

// Headers consulted only when reading requests: client address (`extract_request_ip` tries
// `x-real-ip` first, then the first entry of `x-forwarded-for`) and the user agent.
const REAL_IP_HEADER: &str = "x-real-ip";
const FORWARDED_FOR_HEADER: &str = "x-forwarded-for";
const USER_AGENT_HEADER: &str = "user-agent";

#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RequestTraceContext {
    pub trace_id: Option<String>,
    pub request_id: Option<String>,
    pub tenant_id: Option<i64>,
    pub operator_id: Option<i64>,
    pub ip: Option<String>,
    pub user_agent: Option<String>,
}

impl RequestTraceContext {
    pub fn new(
        trace_id: Option<String>,
        request_id: Option<String>,
        tenant_id: Option<i64>,
        operator_id: Option<i64>,
        ip: Option<String>,
        user_agent: Option<String>,
    ) -> Self {
        Self {
            trace_id: normalize_optional_string(trace_id),
            request_id: normalize_optional_string(request_id),
            tenant_id,
            operator_id,
            ip: normalize_optional_string(ip),
            user_agent: normalize_optional_string(user_agent),
        }
    }

    pub fn with_identity(mut self, tenant_id: Option<i64>, operator_id: Option<i64>) -> Self {
        self.tenant_id = tenant_id;
        self.operator_id = operator_id;
        self
    }
}

/// Derives a [`RequestTraceContext`] from the headers of an incoming request.
///
/// The trace and request identifiers come from the `x-trace-id` and `x-request-id` headers.
/// When a header is missing, blank, or not readable as text, a fresh random UUID (v4) string
/// is generated for it instead, so both identifiers are always `Some`.
///
/// The client address comes from [`extract_request_ip`] and the user agent from the
/// `user-agent` header. Tenant and operator are left unset.
pub fn build_request_trace_context(request: &Request) -> RequestTraceContext {
    let headers = request.headers();
    let header_or_new_id = |name: &str| {
        extract_header_value(headers, name).unwrap_or_else(|| Uuid::new_v4().to_string())
    };

    let trace_id = header_or_new_id(TRACE_ID_HEADER);
    let request_id = header_or_new_id(REQUEST_ID_HEADER);
    let client_ip = extract_request_ip(headers);
    let user_agent = extract_header_value(headers, USER_AGENT_HEADER);

    RequestTraceContext::new(
        Some(trace_id),
        Some(request_id),
        None,
        None,
        client_ip,
        user_agent,
    )
}

/// Returns the trimmed text of the first `name` header, or `None` in any of these cases:
/// the header is absent, its bytes are not valid visible-ASCII text, or it is blank after
/// trimming.
pub fn extract_header_value(headers: &HeaderMap, name: &str) -> Option<String> {
    let raw = headers.get(name)?;
    let text = raw.to_str().ok()?.trim();

    if text.is_empty() {
        None
    } else {
        Some(text.to_owned())
    }
}

/// Returns the first (leftmost) comma-separated entry of the `x-forwarded-for` header,
/// trimmed. Yields `None` if the header is absent/blank or the first entry is empty.
///
/// NOTE(review): the leftmost entry is whatever the first hop sent. Unless a trusted proxy
/// overwrites this header, it is client-controlled — confirm the deployment before relying
/// on it for anything security-sensitive.
pub fn extract_forwarded_ip(headers: &HeaderMap) -> Option<String> {
    let forwarded = extract_header_value(headers, FORWARDED_FOR_HEADER)?;
    // `split` always yields at least one item, so `?` here never short-circuits.
    let first_hop = forwarded.split(',').next()?.trim();

    (!first_hop.is_empty()).then(|| first_hop.to_owned())
}

/// Best-effort client IP: the `x-real-ip` header if present and non-blank, otherwise the
/// first entry of `x-forwarded-for` (see [`extract_forwarded_ip`]).
pub fn extract_request_ip(headers: &HeaderMap) -> Option<String> {
    let real_ip = extract_header_value(headers, REAL_IP_HEADER);
    if real_ip.is_some() {
        return real_ip;
    }

    extract_forwarded_ip(headers)
}

pub fn append_trace_headers(response: &mut Response, trace_context: &RequestTraceContext) {
    append_trace_header_values(
        response,
        trace_context.trace_id.as_deref(),
        trace_context.request_id.as_deref(),
    );
}

/// Sets `x-trace-id` and then `x-request-id` on `response`, replacing existing values.
///
/// A `None` argument leaves that header untouched. A value that is not a valid header value
/// is skipped silently.
pub fn append_trace_header_values(
    response: &mut Response,
    trace_id: Option<&str>,
    request_id: Option<&str>,
) {
    let candidates = [(TRACE_ID_HEADER, trace_id), (REQUEST_ID_HEADER, request_id)];

    for (header_name, maybe_value) in candidates {
        if let Some(header_value) = maybe_value {
            insert_header(response, header_name, header_value);
        }
    }
}

/// Builds, stores, and propagates request trace metadata around an Axum request.
///
/// Steps:
/// 1. Builds a context from the request headers with [`build_request_trace_context`].
/// 2. Passes the context through `customize`, which may replace or adjust it.
/// 3. Inserts the result into the request extensions before running `next`.
/// 4. Writes the trace/request identifiers onto the response.
///
/// Returns the final context together with the response.
pub async fn trace_request<F>(
    mut request: Request,
    next: Next,
    customize: F,
) -> (RequestTraceContext, Response)
where
    F: FnOnce(&Request, RequestTraceContext) -> RequestTraceContext,
{
    let header_context = build_request_trace_context(&request);
    let context = customize(&request, header_context);

    // The request takes its own copy; we keep ours to stamp the response and return it.
    request.extensions_mut().insert(context.clone());
    let mut response = next.run(request).await;

    append_trace_headers(&mut response, &context);
    (context, response)
}

/// Sets header `name` to `value` on `response`, replacing any existing value.
///
/// Best-effort: if `value` is not a valid header value, nothing is written. `name` must be a
/// valid lowercase header name, because `HeaderName::from_static` panics otherwise.
fn insert_header(response: &mut Response, name: &'static str, value: &str) {
    if let Ok(header_value) = HeaderValue::from_str(value) {
        let header_name = HeaderName::from_static(name);
        response.headers_mut().insert(header_name, header_value);
    }
}

/// Trims surrounding whitespace from an optional string.
///
/// Returns `None` when the input is `None` or blank after trimming. When the input has no
/// leading or trailing whitespace, the original `String` is returned as-is, without
/// reallocating. This is the common case here: header values from `extract_header_value`
/// are already trimmed, and generated UUIDs never contain whitespace.
fn normalize_optional_string(value: Option<String>) -> Option<String> {
    let value = value?;
    let trimmed = value.trim();

    if trimmed.is_empty() {
        return None;
    }

    if trimmed.len() == value.len() {
        // Nothing was trimmed: hand back the existing buffer instead of copying it.
        return Some(value);
    }

    Some(trimmed.to_owned())
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request as HttpRequest, header::USER_AGENT},
        response::Response,
    };

    /// Builds a header map from static `(name, value)` pairs, inserting them in order.
    fn headers_from(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(HeaderName::from_static(name), HeaderValue::from_static(value));
        }
        headers
    }

    #[test]
    fn build_request_trace_context_uses_headers_and_generates_ids() {
        let request = HttpRequest::builder()
            .uri("/health")
            .header(TRACE_ID_HEADER, " trace-id ")
            .header(USER_AGENT, " stackloom-test ")
            .body(Body::empty())
            .expect("static request parts are valid");

        let ctx = build_request_trace_context(&request);

        assert_eq!(ctx.trace_id.as_deref(), Some("trace-id"));
        let generated = ctx
            .request_id
            .as_deref()
            .expect("request id is generated when the header is absent");
        assert!(Uuid::parse_str(generated).is_ok());
        assert_eq!(ctx.user_agent.as_deref(), Some("stackloom-test"));
    }

    #[test]
    fn extract_request_ip_prefers_real_ip_then_forwarded_ip() {
        const FORWARDED: &str = "192.168.1.10, 192.168.1.11";

        let with_real_ip = headers_from(&[
            (REAL_IP_HEADER, "10.0.0.8"),
            (FORWARDED_FOR_HEADER, FORWARDED),
        ]);
        assert_eq!(extract_request_ip(&with_real_ip).as_deref(), Some("10.0.0.8"));

        let forwarded_only = headers_from(&[(FORWARDED_FOR_HEADER, FORWARDED)]);
        assert_eq!(
            extract_request_ip(&forwarded_only).as_deref(),
            Some("192.168.1.10")
        );
    }

    #[test]
    fn append_trace_headers_sets_trace_headers() {
        let ctx = RequestTraceContext::new(
            Some("trace-id".to_owned()),
            Some("request-id".to_owned()),
            None,
            None,
            None,
            None,
        );
        let mut response = Response::new(Body::empty());

        append_trace_headers(&mut response, &ctx);

        let headers = response.headers();
        for (name, expected) in [(TRACE_ID_HEADER, "trace-id"), (REQUEST_ID_HEADER, "request-id")] {
            assert_eq!(headers.get(name).expect("trace header should be set"), expected);
        }
    }
}