Skip to main content

nidus_http/
context.rs

1//! Request context primitives shared by middleware, handlers, and observers.
2
3use std::{
4    future::Future,
5    net::{IpAddr, SocketAddr},
6};
7
8use axum::extract::FromRequestParts;
9use http::{HeaderMap, Method, request::Parts};
10use serde::Serialize;
11
12/// Client classification inferred from request boundary headers.
13#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize)]
14#[serde(rename_all = "snake_case")]
15pub enum ClientKind {
16    /// Request uses an API key style credential.
17    ApiKey,
18    /// Request carries a bearer token or other authorization header.
19    Authenticated,
20    /// Request has no recognized application credential.
21    Anonymous,
22}
23
24impl ClientKind {
25    /// Returns the stable string label for this client kind.
26    pub const fn as_str(self) -> &'static str {
27        match self {
28            Self::ApiKey => "api_key",
29            Self::Authenticated => "authenticated",
30            Self::Anonymous => "anonymous",
31        }
32    }
33}
34
35/// Request/correlation context attached to request extensions.
36///
37/// `RequestContext` is available to handlers when the router uses
38/// [`crate::middleware::validated_request_id_layer`] plus
39/// [`crate::middleware::request_context_layer`], or when it is wrapped by
40/// [`crate::middleware::ApiDefaults::production`]. Extracting it without those
41/// extensions rejects the request with `500 Internal Server Error`.
42///
43/// Fields are inferred from request headers and Axum extensions:
44/// - `request_id`: the final validated/generated `x-request-id`
45/// - `correlation_id`: `x-correlation-id`, falling back to the request ID
46/// - `trace_id`: the trace-id segment from `traceparent`
47/// - `client_kind`: `x-api-key` means API key, otherwise `Authorization` means
48///   authenticated, otherwise anonymous
49/// - `route`: Axum's [`axum::extract::MatchedPath`] when it is available at the
50///   point the context layer runs
51///
52/// ```ignore
53/// use nidus_http::{Json, middleware::RequestContext};
54///
55/// async fn handler(context: RequestContext) -> Json<serde_json::Value> {
56///     Json(serde_json::json!({
57///         "requestId": context.request_id(),
58///         "correlationId": context.correlation_id(),
59///     }))
60/// }
61/// ```
62#[derive(Clone, Debug, Eq, PartialEq)]
63pub struct RequestContext {
64    request_id: String,
65    correlation_id: Option<String>,
66    method: Method,
67    route: Option<String>,
68    path: String,
69    trace_id: Option<String>,
70    span_id: Option<String>,
71    client_kind: ClientKind,
72    user_id: Option<String>,
73    tenant_id: Option<String>,
74    session_id: Option<String>,
75}
76
77impl RequestContext {
78    /// Creates a context for the current request boundary.
79    ///
80    /// This constructor is useful in tests or custom middleware. It does not
81    /// inspect headers, so optional correlation, trace, route, and client fields
82    /// remain empty/default until set explicitly or built via [`Self::from_parts`].
83    pub fn new(request_id: impl Into<String>, method: Method, path: impl Into<String>) -> Self {
84        Self {
85            request_id: request_id.into(),
86            correlation_id: None,
87            method,
88            route: None,
89            path: path.into(),
90            trace_id: None,
91            span_id: None,
92            client_kind: ClientKind::Anonymous,
93            user_id: None,
94            tenant_id: None,
95            session_id: None,
96        }
97    }
98
99    /// Creates a context from request parts.
100    ///
101    /// This reads `x-correlation-id`, `traceparent`, `x-api-key`,
102    /// `Authorization`, and [`axum::extract::MatchedPath`] from the request
103    /// boundary. The supplied `request_id` is expected to be the final ID chosen
104    /// by request ID middleware.
105    pub fn from_parts(parts: &Parts, request_id: impl Into<String>) -> Self {
106        let request_id = request_id.into();
107        let mut context = Self::new(request_id.clone(), parts.method.clone(), parts.uri.path());
108        context.correlation_id = header_to_string(&parts.headers, "x-correlation-id")
109            .or_else(|| Some(request_id).filter(|value| !value.is_empty()));
110        context.route = parts
111            .extensions
112            .get::<axum::extract::MatchedPath>()
113            .map(|path| path.as_str().to_owned());
114        context.client_kind = infer_client_kind(&parts.headers);
115        context.trace_id = header_to_string(&parts.headers, "traceparent")
116            .and_then(|value| value.split('-').nth(1).map(str::to_owned));
117        context
118    }
119
120    /// Returns the final request id.
121    ///
122    /// With [`crate::middleware::validated_request_id_layer`], this is either a
123    /// valid inbound UUID v4 or a generated ID.
124    pub fn request_id(&self) -> &str {
125        &self.request_id
126    }
127
128    pub(crate) fn into_request_id(self) -> String {
129        self.request_id
130    }
131
132    /// Returns the correlation id when available.
133    ///
134    /// [`Self::from_parts`] prefers `x-correlation-id` and falls back to the
135    /// request ID when no correlation header is present.
136    pub fn correlation_id(&self) -> Option<&str> {
137        self.correlation_id.as_deref()
138    }
139
140    /// Returns the request method.
141    pub const fn method(&self) -> &Method {
142        &self.method
143    }
144
145    /// Returns the stable matched route pattern when available.
146    ///
147    /// This depends on Axum's [`axum::extract::MatchedPath`] extension being
148    /// present before the context is built. Layer placement can affect whether
149    /// this is available for a given router shape.
150    pub fn route(&self) -> Option<&str> {
151        self.route.as_deref()
152    }
153
154    /// Returns the raw request path.
155    pub fn path(&self) -> &str {
156        &self.path
157    }
158
159    /// Returns the trace id when available.
160    ///
161    /// The value is extracted from the second segment of the W3C `traceparent`
162    /// header. Use [`crate::otel::extract_trace_context`] when the `otel`
163    /// feature is enabled and you need full trace/span validation.
164    pub fn trace_id(&self) -> Option<&str> {
165        self.trace_id.as_deref()
166    }
167
168    /// Returns the span id when available.
169    pub fn span_id(&self) -> Option<&str> {
170        self.span_id.as_deref()
171    }
172
173    /// Returns the inferred client kind.
174    ///
175    /// `x-api-key` takes precedence over `Authorization`; otherwise the request
176    /// is classified as anonymous.
177    pub const fn client_kind(&self) -> ClientKind {
178        self.client_kind
179    }
180
181    /// Returns the optional application user id.
182    pub fn user_id(&self) -> Option<&str> {
183        self.user_id.as_deref()
184    }
185
186    /// Returns the optional application tenant id.
187    pub fn tenant_id(&self) -> Option<&str> {
188        self.tenant_id.as_deref()
189    }
190
191    /// Returns the optional application session id.
192    pub fn session_id(&self) -> Option<&str> {
193        self.session_id.as_deref()
194    }
195
196    /// Sets the stable matched route pattern.
197    pub fn with_route(mut self, route: impl Into<String>) -> Self {
198        self.route = Some(route.into());
199        self
200    }
201
202    /// Sets an application user id.
203    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
204        self.user_id = Some(user_id.into());
205        self
206    }
207
208    /// Sets an application tenant id.
209    pub fn with_tenant_id(mut self, tenant_id: impl Into<String>) -> Self {
210        self.tenant_id = Some(tenant_id.into());
211        self
212    }
213
214    /// Sets an application session id.
215    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
216        self.session_id = Some(session_id.into());
217        self
218    }
219}
220
221impl<S> FromRequestParts<S> for RequestContext
222where
223    S: Send + Sync,
224{
225    type Rejection = axum::http::StatusCode;
226
227    fn from_request_parts(
228        parts: &mut Parts,
229        _state: &S,
230    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
231        let context = parts.extensions.get::<Self>().cloned();
232        async move { context.ok_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR) }
233    }
234}
235
/// Opaque identity label consumed by rate limiters and observers.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RequestIdentity(String);

impl RequestIdentity {
    /// Wraps `value` as an identity label.
    pub fn new(value: impl Into<String>) -> Self {
        let label: String = value.into();
        Self(label)
    }

    /// Borrows the identity label.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
251
252/// Extracts a rate-limit identity from request parts.
253pub trait IdentityExtractor: Clone + Send + Sync + 'static {
254    /// Returns the identity for this request.
255    fn extract(&self, parts: &Parts) -> Option<RequestIdentity>;
256}
257
258impl<F> IdentityExtractor for F
259where
260    F: Fn(&Parts) -> Option<RequestIdentity> + Clone + Send + Sync + 'static,
261{
262    fn extract(&self, parts: &Parts) -> Option<RequestIdentity> {
263        self(parts)
264    }
265}
266
267/// Builds an identity extractor that prefers user/tenant/API key context fields.
268pub fn context_identity() -> impl IdentityExtractor {
269    |parts: &Parts| {
270        if let Some(context) = parts.extensions.get::<RequestContext>()
271            && let Some(value) = context.user_id().or_else(|| context.tenant_id())
272        {
273            return Some(RequestIdentity::new(value.to_owned()));
274        }
275        header_to_string(&parts.headers, "x-api-key").map(RequestIdentity::new)
276    }
277}
278
279/// Builds an identity extractor from API key headers.
280pub fn api_key_identity() -> impl IdentityExtractor {
281    |parts: &Parts| header_to_string(&parts.headers, "x-api-key").map(RequestIdentity::new)
282}
283
284/// Builds an identity extractor from the connected client IP address.
285///
286/// This extractor uses Axum's [`axum::extract::ConnectInfo<SocketAddr>`]
287/// extension and ignores `X-Forwarded-For`. Nidus serving helpers populate
288/// `ConnectInfo` on the normal `listen`/`serve` path. If a router is exercised
289/// without peer information, the identity falls back to `"anonymous"`.
290pub fn client_ip_identity() -> impl IdentityExtractor {
291    |parts: &Parts| {
292        peer_ip(parts)
293            .map(|ip| RequestIdentity::new(ip.to_string()))
294            .or_else(|| Some(RequestIdentity::new("anonymous")))
295    }
296}
297
298/// Builds an identity extractor that trusts `X-Forwarded-For` only from known proxies.
299///
300/// Use this when Nidus runs behind a reverse proxy that rewrites or appends
301/// `X-Forwarded-For` and the direct peer address is one of the configured
302/// trusted proxy IPs. Requests from untrusted peers ignore `X-Forwarded-For`
303/// and use the direct peer IP. Requests without peer information fall back to
304/// `"anonymous"`.
305pub fn trusted_proxy_client_ip_identity(
306    trusted_proxies: impl IntoIterator<Item = IpAddr>,
307) -> impl IdentityExtractor {
308    let trusted_proxies = trusted_proxies.into_iter().collect::<Vec<_>>();
309    move |parts: &Parts| {
310        peer_ip(parts)
311            .map(|peer| {
312                if trusted_proxies.contains(&peer)
313                    && let Some(forwarded_ip) = forwarded_for_ip(&parts.headers)
314                {
315                    RequestIdentity::new(forwarded_ip.to_string())
316                } else {
317                    RequestIdentity::new(peer.to_string())
318                }
319            })
320            .or_else(|| Some(RequestIdentity::new("anonymous")))
321    }
322}
323
324fn peer_ip(parts: &Parts) -> Option<IpAddr> {
325    parts
326        .extensions
327        .get::<axum::extract::ConnectInfo<SocketAddr>>()
328        .map(|connect| connect.0.ip())
329}
330
331fn forwarded_for_ip(headers: &HeaderMap) -> Option<IpAddr> {
332    header_to_string(headers, "x-forwarded-for").and_then(|value| {
333        value
334            .split(',')
335            .next()
336            .map(str::trim)
337            .filter(|value| !value.is_empty())
338            .and_then(|value| value.parse().ok())
339    })
340}
341
342pub(crate) fn header_to_string(headers: &HeaderMap, name: &'static str) -> Option<String> {
343    headers
344        .get(name)
345        .and_then(|value| value.to_str().ok())
346        .filter(|value| !value.is_empty())
347        .map(str::to_owned)
348}
349
350fn infer_client_kind(headers: &HeaderMap) -> ClientKind {
351    if headers.contains_key("x-api-key") {
352        ClientKind::ApiKey
353    } else if headers.contains_key(http::header::AUTHORIZATION) {
354        ClientKind::Authenticated
355    } else {
356        ClientKind::Anonymous
357    }
358}
359
360#[cfg(test)]
361mod tests {
362    use super::*;
363    use http::Request;
364
365    #[test]
366    fn request_context_can_consume_request_id() {
367        let context = RequestContext::new("req-123", Method::GET, "/users");
368
369        assert_eq!(context.into_request_id(), "req-123");
370    }
371
372    #[test]
373    fn client_ip_identity_ignores_forwarded_headers_without_peer_info() {
374        let parts = request_parts(None, Some("203.0.113.10"));
375        let identity = client_ip_identity().extract(&parts).unwrap();
376
377        assert_eq!(identity.as_str(), "anonymous");
378    }
379
380    #[test]
381    fn trusted_proxy_client_ip_identity_uses_forwarded_header_from_trusted_peer() {
382        let parts = request_parts(Some("127.0.0.1:5000"), Some("203.0.113.10, 10.0.0.5"));
383        let trusted_proxy = "127.0.0.1".parse::<IpAddr>().unwrap();
384        let identity = trusted_proxy_client_ip_identity([trusted_proxy])
385            .extract(&parts)
386            .unwrap();
387
388        assert_eq!(identity.as_str(), "203.0.113.10");
389    }
390
391    #[test]
392    fn trusted_proxy_client_ip_identity_ignores_forwarded_header_from_untrusted_peer() {
393        let parts = request_parts(Some("127.0.0.1:5000"), Some("203.0.113.10"));
394        let trusted_proxy = "10.0.0.1".parse::<IpAddr>().unwrap();
395        let identity = trusted_proxy_client_ip_identity([trusted_proxy])
396            .extract(&parts)
397            .unwrap();
398
399        assert_eq!(identity.as_str(), "127.0.0.1");
400    }
401
402    fn request_parts(peer: Option<&str>, forwarded_for: Option<&str>) -> Parts {
403        let mut builder = Request::builder().uri("/");
404        if let Some(forwarded_for) = forwarded_for {
405            builder = builder.header("x-forwarded-for", forwarded_for);
406        }
407        let (mut parts, ()) = builder.body(()).unwrap().into_parts();
408        if let Some(peer) = peer {
409            parts.extensions.insert(axum::extract::ConnectInfo(
410                peer.parse::<SocketAddr>().unwrap(),
411            ));
412        }
413        parts
414    }
415}