// nidus_http/middleware/request_id.rs

use std::{
    future::Future,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use axum::{body::Body, response::IntoResponse};
use http::{HeaderMap, HeaderValue, Request, Response, StatusCode, header::HeaderName};
use serde::Serialize;
use tower::{Layer, Service};
use uuid::{Uuid, Version};

use crate::context::RequestContext;

/// Legacy Tower layer that guarantees an `x-request-id` header on responses.
///
/// When the request carries `x-request-id`, that value is copied to the
/// response; otherwise a newly generated UUID v4 is used. In either case a
/// response header already set by the inner service takes precedence and is
/// left untouched.
///
/// Inbound IDs are not validated, the request is forwarded unchanged, and no
/// [`RequestContext`] is inserted. For strict/permissive validation, a
/// configurable generator, request extension insertion, and structured error
/// responses, use [`validated_request_id_layer`] instead.
#[derive(Debug, Default, Clone, Copy)]
pub struct RequestIdLayer;

29impl<S> Layer<S> for RequestIdLayer {
30    type Service = RequestIdService<S>;
31
32    fn layer(&self, inner: S) -> Self::Service {
33        RequestIdService { inner }
34    }
35}
36
/// Service built by [`RequestIdLayer`]; fills in a missing `x-request-id`
/// response header around the wrapped service.
#[derive(Debug, Clone)]
pub struct RequestIdService<S> {
    /// The wrapped service that actually handles the request.
    inner: S,
}

43impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for RequestIdService<S>
44where
45    S: Service<Request<RequestBody>, Response = Response<ResponseBody>> + Send + 'static,
46    S::Future: Send + 'static,
47    S::Error: Send + 'static,
48    RequestBody: Send + 'static,
49    ResponseBody: Send + 'static,
50{
51    type Response = Response<ResponseBody>;
52    type Error = S::Error;
53    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
54
55    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56        self.inner.poll_ready(cx)
57    }
58
59    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
60        let request_id = request.headers().get(request_id_header()).cloned();
61        let future = self.inner.call(request);
62        Box::pin(async move {
63            let mut response = future.await?;
64            response
65                .headers_mut()
66                .entry(request_id_header())
67                .or_insert_with(|| request_id.unwrap_or_else(new_request_id));
68            Ok(response)
69        })
70    }
71}
72
73fn request_id_header() -> HeaderName {
74    HeaderName::from_static("x-request-id")
75}
76
77fn new_request_id() -> HeaderValue {
78    HeaderValue::from_str(&Uuid::new_v4().to_string())
79        .expect("generated request id contains only valid header characters")
80}
81
/// How the validated middleware treats an `x-request-id` value that is present.
///
/// An inbound ID counts as valid only if it parses as a UUID v4. Values with
/// invalid header syntax, strings that are not UUIDs, and UUIDs of any other
/// version are all malformed. A missing header is never an error; the
/// configured generator supplies an ID instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestIdMode {
    /// Keep valid UUID v4 IDs; quietly swap malformed ones for generated IDs.
    ///
    /// Intended for development and for integration boundaries where clients
    /// may still send legacy IDs. Handlers never see the malformed value: the
    /// generated replacement is installed before the inner service is called.
    Permissive,
    /// Keep valid UUID v4 IDs; reject malformed ones.
    ///
    /// A malformed value produces `400 Bad Request` with an
    /// `invalid_request_id` JSON error body. The rejection still carries a
    /// generated request ID in the configured response header.
    Strict,
}

104/// Compatibility alias for naming request ID validation policy.
105pub type RequestIdPolicy = RequestIdMode;
106
107/// Typed configuration for validated request ID propagation.
108///
109/// The default production config uses the `x-request-id` header, strict inbound
110/// validation, and UUID v4 generation. Custom generators are accepted, but their
111/// output must be a valid HTTP header value because generated IDs are inserted
112/// into request headers, request extensions, and response headers. If a custom
113/// generator returns an invalid header value, the middleware returns a stable
114/// `500 Internal Server Error` response with code
115/// `invalid_generated_request_id` instead of panicking or calling the inner
116/// service.
117#[derive(Clone)]
118pub struct RequestIdConfig {
119    header_name: HeaderName,
120    mode: RequestIdMode,
121    generator: Arc<dyn Fn() -> String + Send + Sync>,
122}
123
124impl std::fmt::Debug for RequestIdConfig {
125    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        formatter
127            .debug_struct("RequestIdConfig")
128            .field("header_name", &self.header_name)
129            .field("mode", &self.mode)
130            .finish_non_exhaustive()
131    }
132}
133
134impl RequestIdConfig {
135    /// Creates a production request ID policy.
136    ///
137    /// Defaults:
138    /// - header: `x-request-id`
139    /// - inbound validation: [`RequestIdMode::Strict`]
140    /// - generated IDs: UUID v4 strings
141    ///
142    /// In strict mode, a present but malformed inbound ID is rejected with
143    /// `400 Bad Request`. Missing IDs are generated and accepted.
144    pub fn production() -> Self {
145        Self {
146            header_name: HeaderName::from_static("x-request-id"),
147            mode: RequestIdMode::Strict,
148            generator: Arc::new(|| Uuid::new_v4().to_string()),
149        }
150    }
151
152    /// Creates a development request ID policy.
153    ///
154    /// Development uses the same `x-request-id` header and UUID v4 generator as
155    /// production, but switches to [`RequestIdMode::Permissive`]. Valid inbound
156    /// UUID v4 IDs are propagated; malformed inbound IDs are replaced with
157    /// generated UUID v4 IDs instead of returning `400`.
158    pub fn development() -> Self {
159        Self::production().mode(RequestIdMode::Permissive)
160    }
161
162    /// Sets the request ID header name.
163    pub fn header_name(mut self, header_name: HeaderName) -> Self {
164        self.header_name = header_name;
165        self
166    }
167
168    /// Sets request ID validation behavior for present inbound IDs.
169    pub fn mode(mut self, mode: RequestIdMode) -> Self {
170        self.mode = mode;
171        self
172    }
173
174    /// Replaces the request ID generator.
175    ///
176    /// [`RequestIdConfig::production`] and [`RequestIdConfig::development`] use
177    /// UUID v4 strings. If you provide a custom generator, keep it deterministic
178    /// enough for your tests and ensure it returns values that can be stored in
179    /// an HTTP header. Invalid generated header values return a structured
180    /// framework error response before the request reaches the inner service.
181    pub fn generator(mut self, generator: impl Fn() -> String + Send + Sync + 'static) -> Self {
182        self.generator = Arc::new(generator);
183        self
184    }
185
186    /// Returns the configured request ID header name.
187    pub fn header(&self) -> &HeaderName {
188        &self.header_name
189    }
190
191    /// Returns the configured validation mode.
192    pub const fn validation_mode(&self) -> RequestIdMode {
193        self.mode
194    }
195
196    fn generate(&self) -> String {
197        (self.generator)()
198    }
199}
200
201impl Default for RequestIdConfig {
202    fn default() -> Self {
203        Self::production()
204    }
205}
206
207/// Creates a validated request ID layer.
208///
209/// The layer validates or generates a request ID before the inner service runs,
210/// inserts that value into the configured request header, stores a
211/// [`RequestContext`] in request extensions, and mirrors the same header onto
212/// the response when the inner service has not already set it.
213///
214/// ```ignore
215/// use axum::{Router, routing::get};
216/// use nidus_http::middleware::{
217///     RequestIdConfig, RequestIdMode, validated_request_id_layer,
218/// };
219///
220/// let app = Router::new()
221///     .route("/users/:id", get(handler))
222///     .layer(validated_request_id_layer(
223///         RequestIdConfig::production().mode(RequestIdMode::Strict),
224///     ));
225/// ```
226pub fn validated_request_id_layer(config: RequestIdConfig) -> ValidatedRequestIdLayer {
227    ValidatedRequestIdLayer::new(config)
228}
229
230/// Tower layer that validates, generates, stores, and propagates request IDs.
231///
232/// Valid inbound request IDs are UUID v4 strings. With
233/// [`RequestIdMode::Strict`], malformed inbound IDs receive `400 Bad Request`.
234/// With [`RequestIdMode::Permissive`], malformed inbound IDs are replaced with a
235/// generated ID. Generated IDs are UUID v4 by default.
236///
237/// On accepted requests, the final ID is inserted into the configured request
238/// header, added to request extensions through [`RequestContext`], and copied to
239/// the response header if the inner service did not set one. Use
240/// [`crate::middleware::request_context_layer`] after this layer when you want
241/// the context enriched with route and correlation fields before handlers run.
242#[derive(Clone, Debug)]
243pub struct ValidatedRequestIdLayer {
244    config: RequestIdConfig,
245}
246
247impl ValidatedRequestIdLayer {
248    /// Creates a validated request ID layer from typed config.
249    pub fn new(config: RequestIdConfig) -> Self {
250        Self { config }
251    }
252}
253
254impl<S> Layer<S> for ValidatedRequestIdLayer {
255    type Service = ValidatedRequestIdService<S>;
256
257    fn layer(&self, inner: S) -> Self::Service {
258        ValidatedRequestIdService {
259            inner,
260            config: self.config.clone(),
261        }
262    }
263}
264
265/// Service produced by [`ValidatedRequestIdLayer`].
266#[derive(Clone, Debug)]
267pub struct ValidatedRequestIdService<S> {
268    inner: S,
269    config: RequestIdConfig,
270}
271
272impl<S> Service<Request<Body>> for ValidatedRequestIdService<S>
273where
274    S: Service<Request<Body>, Response = Response<Body>> + Send + 'static,
275    S::Future: Send + 'static,
276    S::Error: Send + 'static,
277{
278    type Response = Response<Body>;
279    type Error = S::Error;
280    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
281
282    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
283        self.inner.poll_ready(cx)
284    }
285
286    fn call(&mut self, request: Request<Body>) -> Self::Future {
287        let config = self.config.clone();
288        let (mut parts, body) = request.into_parts();
289        let incoming = parts
290            .headers
291            .get(config.header())
292            .and_then(|value| value.to_str().ok())
293            .map(str::to_owned);
294        let request_id = match incoming {
295            Some(value) if is_valid_request_id(&value) => value,
296            Some(_) if config.validation_mode() == RequestIdMode::Strict => {
297                let (request_id, header_value) = match generated_request_id_header(&config) {
298                    Some(generated) => generated,
299                    None => {
300                        return Box::pin(async move {
301                            Ok(invalid_generated_request_id_response(parts.uri.path()))
302                        });
303                    }
304                };
305                let mut response = invalid_request_id_response(&request_id, parts.uri.path());
306                response
307                    .headers_mut()
308                    .insert(config.header().clone(), header_value);
309                return Box::pin(async move { Ok(response) });
310            }
311            Some(_) | None => {
312                let (request_id, header_value) = match generated_request_id_header(&config) {
313                    Some(generated) => generated,
314                    None => {
315                        return Box::pin(async move {
316                            Ok(invalid_generated_request_id_response(parts.uri.path()))
317                        });
318                    }
319                };
320                parts
321                    .headers
322                    .insert(config.header().clone(), header_value.clone());
323                let context = RequestContext::from_parts(&parts, request_id.clone());
324                parts.extensions.insert(context);
325                let future = self.inner.call(Request::from_parts(parts, body));
326
327                return Box::pin(async move {
328                    let mut response = future.await?;
329                    response
330                        .headers_mut()
331                        .entry(config.header().clone())
332                        .or_insert(header_value);
333                    Ok(response)
334                });
335            }
336        };
337
338        let header_value = HeaderValue::from_str(&request_id)
339            .unwrap_or_else(|_| unreachable!("accepted inbound request id came from a header"));
340        parts
341            .headers
342            .insert(config.header().clone(), header_value.clone());
343        let context = RequestContext::from_parts(&parts, request_id.clone());
344        parts.extensions.insert(context);
345        let future = self.inner.call(Request::from_parts(parts, body));
346
347        Box::pin(async move {
348            let mut response = future.await?;
349            response
350                .headers_mut()
351                .entry(config.header().clone())
352                .or_insert(header_value);
353            Ok(response)
354        })
355    }
356}
357
358fn generated_request_id_header(config: &RequestIdConfig) -> Option<(String, HeaderValue)> {
359    let request_id = config.generate();
360    let header_value = HeaderValue::from_str(&request_id).ok()?;
361    Some((request_id, header_value))
362}
363
364fn is_valid_request_id(value: &str) -> bool {
365    Uuid::parse_str(value)
366        .ok()
367        .and_then(|uuid| uuid.get_version())
368        == Some(Version::Random)
369}
370
371fn invalid_request_id_response(request_id: &str, path: &str) -> Response<Body> {
372    let timestamp = crate::error::timestamp_now();
373    (
374        StatusCode::BAD_REQUEST,
375        axum::Json(RequestIdErrorBody {
376            error: RequestIdErrorDetails {
377                status_code: StatusCode::BAD_REQUEST.as_u16(),
378                code: "invalid_request_id",
379                message: "invalid request id",
380                details: serde_json::Value::Null,
381                timestamp,
382                path: path.to_owned(),
383                request_id: Some(request_id.to_owned()),
384            },
385        }),
386    )
387        .into_response()
388}
389
390fn invalid_generated_request_id_response(path: &str) -> Response<Body> {
391    let timestamp = crate::error::timestamp_now();
392    (
393        StatusCode::INTERNAL_SERVER_ERROR,
394        axum::Json(RequestIdErrorBody {
395            error: RequestIdErrorDetails {
396                status_code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
397                code: "invalid_generated_request_id",
398                message: "generated request id was not a valid HTTP header value",
399                details: serde_json::Value::Null,
400                timestamp,
401                path: path.to_owned(),
402                request_id: None,
403            },
404        }),
405    )
406        .into_response()
407}
408
409#[derive(Debug, Serialize)]
410struct RequestIdErrorBody {
411    error: RequestIdErrorDetails,
412}
413
414#[derive(Debug, Serialize)]
415#[serde(rename_all = "camelCase")]
416struct RequestIdErrorDetails {
417    status_code: u16,
418    code: &'static str,
419    message: &'static str,
420    details: serde_json::Value,
421    timestamp: String,
422    path: String,
423    #[serde(skip_serializing_if = "Option::is_none")]
424    request_id: Option<String>,
425}