Skip to main content

nidus_http/middleware/
request_id.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    sync::Arc,
5    task::{Context, Poll},
6};
7
8use axum::{body::Body, response::IntoResponse};
9use http::{HeaderValue, Request, Response, StatusCode, header::HeaderName};
10use serde::Serialize;
11use tower::{Layer, Service};
12use uuid::{Uuid, Version};
13
14use crate::context::RequestContext;
15
/// Legacy Tower layer that fills in an `x-request-id` response header.
///
/// If the inner service already put an `x-request-id` on the response, that
/// value is kept. Otherwise the request's own `x-request-id` is copied over,
/// and when the request had none, a freshly generated UUID v4 is used.
///
/// Inbound IDs are not validated, and no [`RequestContext`] is inserted into
/// request extensions. For production defaults (UUID v4 generation,
/// strict/permissive validation, request extension insertion, and consistent
/// error responses) use [`validated_request_id_layer`] instead.
#[derive(Debug, Default, Clone, Copy)]
pub struct RequestIdLayer;
28
29impl<S> Layer<S> for RequestIdLayer {
30    type Service = RequestIdService<S>;
31
32    fn layer(&self, inner: S) -> Self::Service {
33        RequestIdService { inner }
34    }
35}
36
/// Middleware service built by [`RequestIdLayer`]; it wraps `inner` and only
/// touches the response's `x-request-id` header.
#[derive(Debug, Clone)]
pub struct RequestIdService<S> {
    inner: S,
}
42
43impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for RequestIdService<S>
44where
45    S: Service<Request<RequestBody>, Response = Response<ResponseBody>> + Send + 'static,
46    S::Future: Send + 'static,
47    S::Error: Send + 'static,
48    RequestBody: Send + 'static,
49    ResponseBody: Send + 'static,
50{
51    type Response = Response<ResponseBody>;
52    type Error = S::Error;
53    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
54
55    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56        self.inner.poll_ready(cx)
57    }
58
59    fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
60        let request_id = request.headers().get(request_id_header()).cloned();
61        let future = self.inner.call(request);
62        Box::pin(async move {
63            let mut response = future.await?;
64            response
65                .headers_mut()
66                .entry(request_id_header())
67                .or_insert_with(|| request_id.unwrap_or_else(new_request_id));
68            Ok(response)
69        })
70    }
71}
72
73fn request_id_header() -> HeaderName {
74    HeaderName::from_static("x-request-id")
75}
76
77fn new_request_id() -> HeaderValue {
78    HeaderValue::from_str(&Uuid::new_v4().to_string())
79        .expect("generated request id contains only valid header characters")
80}
81
/// How the validated middleware treats an inbound request ID header that is
/// present.
///
/// An inbound ID counts as valid only if it parses as a UUID whose version is
/// 4. Non-UUID strings and UUIDs of any other version are malformed. A missing
/// header is never an error; the configured generator supplies the ID instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestIdMode {
    /// Keep valid UUID v4 IDs and silently replace malformed ones.
    ///
    /// Handlers never observe a malformed inbound value, because it is swapped
    /// for a generated ID before the inner service is called. This suits
    /// development and integration boundaries where clients may still send
    /// legacy ID formats.
    Permissive,
    /// Keep valid UUID v4 IDs and reject malformed ones.
    ///
    /// A malformed value short-circuits with `400 Bad Request` and an
    /// `invalid_request_id` JSON error body. That rejection still carries a
    /// freshly generated ID in the configured response header.
    Strict,
}
103
104/// Compatibility alias for naming request ID validation policy.
105pub type RequestIdPolicy = RequestIdMode;
106
107/// Typed configuration for validated request ID propagation.
108///
109/// The default production config uses the `x-request-id` header, strict inbound
110/// validation, and UUID v4 generation. Custom generators are accepted, but their
111/// output must be a valid HTTP header value because generated IDs are inserted
112/// into request headers, request extensions, and response headers. If a custom
113/// generator returns an invalid header value, the middleware returns a stable
114/// `500 Internal Server Error` response with code
115/// `invalid_generated_request_id` instead of panicking or calling the inner
116/// service.
117#[derive(Clone)]
118pub struct RequestIdConfig {
119    header_name: HeaderName,
120    mode: RequestIdMode,
121    generator: Arc<dyn Fn() -> String + Send + Sync>,
122}
123
124impl std::fmt::Debug for RequestIdConfig {
125    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        formatter
127            .debug_struct("RequestIdConfig")
128            .field("header_name", &self.header_name)
129            .field("mode", &self.mode)
130            .finish_non_exhaustive()
131    }
132}
133
134impl RequestIdConfig {
135    /// Creates a production request ID policy.
136    ///
137    /// Defaults:
138    /// - header: `x-request-id`
139    /// - inbound validation: [`RequestIdMode::Strict`]
140    /// - generated IDs: UUID v4 strings
141    ///
142    /// In strict mode, a present but malformed inbound ID is rejected with
143    /// `400 Bad Request`. Missing IDs are generated and accepted.
144    pub fn production() -> Self {
145        Self {
146            header_name: HeaderName::from_static("x-request-id"),
147            mode: RequestIdMode::Strict,
148            generator: Arc::new(|| Uuid::new_v4().to_string()),
149        }
150    }
151
152    /// Creates a development request ID policy.
153    ///
154    /// Development uses the same `x-request-id` header and UUID v4 generator as
155    /// production, but switches to [`RequestIdMode::Permissive`]. Valid inbound
156    /// UUID v4 IDs are propagated; malformed inbound IDs are replaced with
157    /// generated UUID v4 IDs instead of returning `400`.
158    pub fn development() -> Self {
159        Self::production().mode(RequestIdMode::Permissive)
160    }
161
162    /// Sets the request ID header name.
163    pub fn header_name(mut self, header_name: HeaderName) -> Self {
164        self.header_name = header_name;
165        self
166    }
167
168    /// Sets request ID validation behavior for present inbound IDs.
169    pub fn mode(mut self, mode: RequestIdMode) -> Self {
170        self.mode = mode;
171        self
172    }
173
174    /// Replaces the request ID generator.
175    ///
176    /// [`RequestIdConfig::production`] and [`RequestIdConfig::development`] use
177    /// UUID v4 strings. If you provide a custom generator, keep it deterministic
178    /// enough for your tests and ensure it returns values that can be stored in
179    /// an HTTP header. Invalid generated header values return a structured
180    /// framework error response before the request reaches the inner service.
181    pub fn generator(mut self, generator: impl Fn() -> String + Send + Sync + 'static) -> Self {
182        self.generator = Arc::new(generator);
183        self
184    }
185
186    /// Returns the configured request ID header name.
187    pub fn header(&self) -> &HeaderName {
188        &self.header_name
189    }
190
191    /// Returns the configured validation mode.
192    pub const fn validation_mode(&self) -> RequestIdMode {
193        self.mode
194    }
195
196    fn generate(&self) -> String {
197        (self.generator)()
198    }
199}
200
201impl Default for RequestIdConfig {
202    fn default() -> Self {
203        Self::production()
204    }
205}
206
207/// Creates a validated request ID layer.
208///
209/// The layer validates or generates a request ID before the inner service runs,
210/// inserts that value into the configured request header, stores a
211/// [`RequestContext`] in request extensions, and mirrors the same header onto
212/// the response when the inner service has not already set it.
213///
214/// ```
215/// use axum::{Router, routing::get};
216/// use nidus_http::middleware::{
217///     RequestIdConfig, RequestIdMode, validated_request_id_layer,
218/// };
219/// # async fn handler() -> &'static str { "user" }
220///
221/// let app = Router::new()
222///     .route("/users/{id}", get(handler))
223///     .layer(validated_request_id_layer(
224///         RequestIdConfig::production().mode(RequestIdMode::Strict),
225///     ));
226/// # let _: Router = app;
227/// ```
228pub fn validated_request_id_layer(config: RequestIdConfig) -> ValidatedRequestIdLayer {
229    ValidatedRequestIdLayer::new(config)
230}
231
232/// Tower layer that validates, generates, stores, and propagates request IDs.
233///
234/// Valid inbound request IDs are UUID v4 strings. With
235/// [`RequestIdMode::Strict`], malformed inbound IDs receive `400 Bad Request`.
236/// With [`RequestIdMode::Permissive`], malformed inbound IDs are replaced with a
237/// generated ID. Generated IDs are UUID v4 by default.
238///
239/// On accepted requests, the final ID is inserted into the configured request
240/// header, added to request extensions through [`RequestContext`], and copied to
241/// the response header if the inner service did not set one. Use
242/// [`crate::middleware::request_context_layer`] after this layer when you want
243/// the context enriched with route and correlation fields before handlers run.
244#[derive(Clone, Debug)]
245pub struct ValidatedRequestIdLayer {
246    config: RequestIdConfig,
247}
248
249impl ValidatedRequestIdLayer {
250    /// Creates a validated request ID layer from typed config.
251    pub fn new(config: RequestIdConfig) -> Self {
252        Self { config }
253    }
254}
255
256impl<S> Layer<S> for ValidatedRequestIdLayer {
257    type Service = ValidatedRequestIdService<S>;
258
259    fn layer(&self, inner: S) -> Self::Service {
260        ValidatedRequestIdService {
261            inner,
262            config: self.config.clone(),
263        }
264    }
265}
266
267/// Service produced by [`ValidatedRequestIdLayer`].
268#[derive(Clone, Debug)]
269pub struct ValidatedRequestIdService<S> {
270    inner: S,
271    config: RequestIdConfig,
272}
273
274impl<S> Service<Request<Body>> for ValidatedRequestIdService<S>
275where
276    S: Service<Request<Body>, Response = Response<Body>> + Send + 'static,
277    S::Future: Send + 'static,
278    S::Error: Send + 'static,
279{
280    type Response = Response<Body>;
281    type Error = S::Error;
282    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
283
284    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
285        self.inner.poll_ready(cx)
286    }
287
288    fn call(&mut self, request: Request<Body>) -> Self::Future {
289        let config = self.config.clone();
290        let (mut parts, body) = request.into_parts();
291        let incoming = parts
292            .headers
293            .get(config.header())
294            .and_then(|value| value.to_str().ok())
295            .map(str::to_owned);
296        let request_id = match incoming {
297            Some(value) if is_valid_request_id(&value) => value,
298            Some(_) if config.validation_mode() == RequestIdMode::Strict => {
299                let (request_id, header_value) = match generated_request_id_header(&config) {
300                    Some(generated) => generated,
301                    None => {
302                        return Box::pin(async move {
303                            Ok(invalid_generated_request_id_response(parts.uri.path()))
304                        });
305                    }
306                };
307                let mut response = invalid_request_id_response(&request_id, parts.uri.path());
308                response
309                    .headers_mut()
310                    .insert(config.header().clone(), header_value);
311                return Box::pin(async move { Ok(response) });
312            }
313            Some(_) | None => {
314                let (request_id, header_value) = match generated_request_id_header(&config) {
315                    Some(generated) => generated,
316                    None => {
317                        return Box::pin(async move {
318                            Ok(invalid_generated_request_id_response(parts.uri.path()))
319                        });
320                    }
321                };
322                parts
323                    .headers
324                    .insert(config.header().clone(), header_value.clone());
325                let context = RequestContext::from_parts(&parts, request_id.clone());
326                parts.extensions.insert(context);
327                let future = self.inner.call(Request::from_parts(parts, body));
328
329                return Box::pin(async move {
330                    let mut response = future.await?;
331                    response
332                        .headers_mut()
333                        .entry(config.header().clone())
334                        .or_insert(header_value);
335                    Ok(response)
336                });
337            }
338        };
339
340        let header_value = HeaderValue::from_str(&request_id)
341            .unwrap_or_else(|_| unreachable!("accepted inbound request id came from a header"));
342        parts
343            .headers
344            .insert(config.header().clone(), header_value.clone());
345        let context = RequestContext::from_parts(&parts, request_id.clone());
346        parts.extensions.insert(context);
347        let future = self.inner.call(Request::from_parts(parts, body));
348
349        Box::pin(async move {
350            let mut response = future.await?;
351            response
352                .headers_mut()
353                .entry(config.header().clone())
354                .or_insert(header_value);
355            Ok(response)
356        })
357    }
358}
359
360fn generated_request_id_header(config: &RequestIdConfig) -> Option<(String, HeaderValue)> {
361    let request_id = config.generate();
362    let header_value = HeaderValue::from_str(&request_id).ok()?;
363    Some((request_id, header_value))
364}
365
366fn is_valid_request_id(value: &str) -> bool {
367    Uuid::parse_str(value)
368        .ok()
369        .and_then(|uuid| uuid.get_version())
370        == Some(Version::Random)
371}
372
373fn invalid_request_id_response(request_id: &str, path: &str) -> Response<Body> {
374    let timestamp = crate::error::timestamp_now();
375    (
376        StatusCode::BAD_REQUEST,
377        axum::Json(RequestIdErrorBody {
378            error: RequestIdErrorDetails {
379                status_code: StatusCode::BAD_REQUEST.as_u16(),
380                code: "invalid_request_id",
381                message: "invalid request id",
382                details: serde_json::Value::Null,
383                timestamp,
384                path: path.to_owned(),
385                request_id: Some(request_id.to_owned()),
386            },
387        }),
388    )
389        .into_response()
390}
391
392fn invalid_generated_request_id_response(path: &str) -> Response<Body> {
393    let timestamp = crate::error::timestamp_now();
394    (
395        StatusCode::INTERNAL_SERVER_ERROR,
396        axum::Json(RequestIdErrorBody {
397            error: RequestIdErrorDetails {
398                status_code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
399                code: "invalid_generated_request_id",
400                message: "generated request id was not a valid HTTP header value",
401                details: serde_json::Value::Null,
402                timestamp,
403                path: path.to_owned(),
404                request_id: None,
405            },
406        }),
407    )
408        .into_response()
409}
410
411#[derive(Debug, Serialize)]
412struct RequestIdErrorBody {
413    error: RequestIdErrorDetails,
414}
415
416#[derive(Debug, Serialize)]
417#[serde(rename_all = "camelCase")]
418struct RequestIdErrorDetails {
419    status_code: u16,
420    code: &'static str,
421    message: &'static str,
422    details: serde_json::Value,
423    timestamp: String,
424    path: String,
425    #[serde(skip_serializing_if = "Option::is_none")]
426    request_id: Option<String>,
427}