Skip to main content

nidus_http/middleware/request_context.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use axum::extract::Request;
8use tower::{Layer, Service};
9
10use crate::context::{RequestContext, header_to_string};
11
/// Creates a Tower layer that enriches [`RequestContext`] request extensions.
///
/// Use this with [`crate::middleware::validated_request_id_layer`] so handlers
/// can extract [`RequestContext`]. The request ID layer chooses and stores the
/// final ID; this layer rebuilds the context from request parts so correlation,
/// trace, route, and client-kind fields reflect the current request boundary.
/// [`crate::middleware::ApiDefaults::production`] installs both layers.
///
/// If neither a prior [`RequestContext`] extension nor an `x-request-id` header
/// is present, the context uses `"unknown"` as the request ID. Prefer validated
/// request IDs for production APIs.
///
/// The layer is a stateless unit value, so this function is `const` and can be
/// used to initialize `const` or `static` items as well as at runtime.
#[must_use]
pub const fn request_context_layer() -> RequestContextLayer {
    RequestContextLayer
}

/// Tower layer that inserts request/correlation context into request extensions.
///
/// The inserted context reads:
/// - `x-request-id` from the existing [`RequestContext`] or request header,
///   falling back to `"unknown"` when neither is present
/// - `x-correlation-id`, falling back to the request ID
/// - `traceparent` trace ID
/// - `x-api-key` / `Authorization` for client classification
/// - Axum [`axum::extract::MatchedPath`] when available at this layer
#[derive(Clone, Copy, Debug, Default)]
pub struct RequestContextLayer;
37
38impl<S> Layer<S> for RequestContextLayer {
39    type Service = RequestContextService<S>;
40
41    fn layer(&self, inner: S) -> Self::Service {
42        RequestContextService { inner }
43    }
44}
45
/// Middleware service created by [`RequestContextLayer`].
///
/// It holds the wrapped service and forwards every request to it after
/// refreshing the request's context extension.
#[derive(Debug, Clone)]
pub struct RequestContextService<S> {
    /// The downstream service that receives the enriched request.
    inner: S,
}
51
52impl<S> Service<Request> for RequestContextService<S>
53where
54    S: Service<Request> + Send + 'static,
55    S::Future: Send + 'static,
56    S::Error: Send + 'static,
57{
58    type Response = S::Response;
59    type Error = S::Error;
60    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
61
62    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63        self.inner.poll_ready(cx)
64    }
65
66    fn call(&mut self, mut request: Request) -> Self::Future {
67        let (mut parts, body) = request.into_parts();
68        let request_id = parts
69            .extensions
70            .remove::<RequestContext>()
71            .map(RequestContext::into_request_id)
72            .or_else(|| header_to_string(&parts.headers, "x-request-id"))
73            .unwrap_or_else(|| "unknown".to_owned());
74        let context = RequestContext::from_parts(&parts, request_id);
75        parts.extensions.insert(context);
76        request = Request::from_parts(parts, body);
77        let future = self.inner.call(request);
78        Box::pin(future)
79    }
80}