1use std::{
2 future::Future,
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll},
6};
7
8use axum::{body::Body, response::IntoResponse};
9use http::{HeaderValue, Request, Response, StatusCode, header::HeaderName};
10use serde::Serialize;
11use tower::{Layer, Service};
12use uuid::{Uuid, Version};
13
14use crate::context::RequestContext;
15
/// Tower [`Layer`] that wraps a service in [`RequestIdService`].
///
/// The layer carries no configuration, so it is a zero-sized, freely copyable
/// marker type.
#[derive(Debug, Default, Clone, Copy)]
pub struct RequestIdLayer;
28
29impl<S> Layer<S> for RequestIdLayer {
30 type Service = RequestIdService<S>;
31
32 fn layer(&self, inner: S) -> Self::Service {
33 RequestIdService { inner }
34 }
35}
36
/// Service wrapper produced by [`RequestIdLayer`].
///
/// It holds only the wrapped service; the request-id handling itself lives in
/// its `Service` implementation.
#[derive(Debug, Clone)]
pub struct RequestIdService<S> {
    /// The wrapped service every request is forwarded to.
    inner: S,
}
42
43impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for RequestIdService<S>
44where
45 S: Service<Request<RequestBody>, Response = Response<ResponseBody>> + Send + 'static,
46 S::Future: Send + 'static,
47 S::Error: Send + 'static,
48 RequestBody: Send + 'static,
49 ResponseBody: Send + 'static,
50{
51 type Response = Response<ResponseBody>;
52 type Error = S::Error;
53 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
54
55 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56 self.inner.poll_ready(cx)
57 }
58
59 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
60 let request_id = request.headers().get(request_id_header()).cloned();
61 let future = self.inner.call(request);
62 Box::pin(async move {
63 let mut response = future.await?;
64 response
65 .headers_mut()
66 .entry(request_id_header())
67 .or_insert_with(|| request_id.unwrap_or_else(new_request_id));
68 Ok(response)
69 })
70 }
71}
72
73fn request_id_header() -> HeaderName {
74 HeaderName::from_static("x-request-id")
75}
76
77fn new_request_id() -> HeaderValue {
78 HeaderValue::from_str(&Uuid::new_v4().to_string())
79 .expect("generated request id contains only valid header characters")
80}
81
/// How [`ValidatedRequestIdService`] reacts to an inbound request id that fails
/// validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestIdMode {
    /// Replace a rejected inbound id with a freshly generated one and keep
    /// serving the request.
    Permissive,
    /// Answer the request with an error response instead of forwarding it when
    /// the inbound id is rejected.
    Strict,
}

/// Alternative name for [`RequestIdMode`].
pub type RequestIdPolicy = RequestIdMode;
106
107#[derive(Clone)]
118pub struct RequestIdConfig {
119 header_name: HeaderName,
120 mode: RequestIdMode,
121 generator: Arc<dyn Fn() -> String + Send + Sync>,
122}
123
124impl std::fmt::Debug for RequestIdConfig {
125 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 formatter
127 .debug_struct("RequestIdConfig")
128 .field("header_name", &self.header_name)
129 .field("mode", &self.mode)
130 .finish_non_exhaustive()
131 }
132}
133
134impl RequestIdConfig {
135 pub fn production() -> Self {
145 Self {
146 header_name: HeaderName::from_static("x-request-id"),
147 mode: RequestIdMode::Strict,
148 generator: Arc::new(|| Uuid::new_v4().to_string()),
149 }
150 }
151
152 pub fn development() -> Self {
159 Self::production().mode(RequestIdMode::Permissive)
160 }
161
162 pub fn header_name(mut self, header_name: HeaderName) -> Self {
164 self.header_name = header_name;
165 self
166 }
167
168 pub fn mode(mut self, mode: RequestIdMode) -> Self {
170 self.mode = mode;
171 self
172 }
173
174 pub fn generator(mut self, generator: impl Fn() -> String + Send + Sync + 'static) -> Self {
182 self.generator = Arc::new(generator);
183 self
184 }
185
186 pub fn header(&self) -> &HeaderName {
188 &self.header_name
189 }
190
191 pub const fn validation_mode(&self) -> RequestIdMode {
193 self.mode
194 }
195
196 fn generate(&self) -> String {
197 (self.generator)()
198 }
199}
200
201impl Default for RequestIdConfig {
202 fn default() -> Self {
203 Self::production()
204 }
205}
206
207pub fn validated_request_id_layer(config: RequestIdConfig) -> ValidatedRequestIdLayer {
227 ValidatedRequestIdLayer::new(config)
228}
229
230#[derive(Clone, Debug)]
243pub struct ValidatedRequestIdLayer {
244 config: RequestIdConfig,
245}
246
247impl ValidatedRequestIdLayer {
248 pub fn new(config: RequestIdConfig) -> Self {
250 Self { config }
251 }
252}
253
254impl<S> Layer<S> for ValidatedRequestIdLayer {
255 type Service = ValidatedRequestIdService<S>;
256
257 fn layer(&self, inner: S) -> Self::Service {
258 ValidatedRequestIdService {
259 inner,
260 config: self.config.clone(),
261 }
262 }
263}
264
265#[derive(Clone, Debug)]
267pub struct ValidatedRequestIdService<S> {
268 inner: S,
269 config: RequestIdConfig,
270}
271
272impl<S> Service<Request<Body>> for ValidatedRequestIdService<S>
273where
274 S: Service<Request<Body>, Response = Response<Body>> + Send + 'static,
275 S::Future: Send + 'static,
276 S::Error: Send + 'static,
277{
278 type Response = Response<Body>;
279 type Error = S::Error;
280 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
281
282 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
283 self.inner.poll_ready(cx)
284 }
285
286 fn call(&mut self, request: Request<Body>) -> Self::Future {
287 let config = self.config.clone();
288 let (mut parts, body) = request.into_parts();
289 let incoming = parts
290 .headers
291 .get(config.header())
292 .and_then(|value| value.to_str().ok())
293 .map(str::to_owned);
294 let request_id = match incoming {
295 Some(value) if is_valid_request_id(&value) => value,
296 Some(_) if config.validation_mode() == RequestIdMode::Strict => {
297 let (request_id, header_value) = match generated_request_id_header(&config) {
298 Some(generated) => generated,
299 None => {
300 return Box::pin(async move {
301 Ok(invalid_generated_request_id_response(parts.uri.path()))
302 });
303 }
304 };
305 let mut response = invalid_request_id_response(&request_id, parts.uri.path());
306 response
307 .headers_mut()
308 .insert(config.header().clone(), header_value);
309 return Box::pin(async move { Ok(response) });
310 }
311 Some(_) | None => {
312 let (request_id, header_value) = match generated_request_id_header(&config) {
313 Some(generated) => generated,
314 None => {
315 return Box::pin(async move {
316 Ok(invalid_generated_request_id_response(parts.uri.path()))
317 });
318 }
319 };
320 parts
321 .headers
322 .insert(config.header().clone(), header_value.clone());
323 let context = RequestContext::from_parts(&parts, request_id.clone());
324 parts.extensions.insert(context);
325 let future = self.inner.call(Request::from_parts(parts, body));
326
327 return Box::pin(async move {
328 let mut response = future.await?;
329 response
330 .headers_mut()
331 .entry(config.header().clone())
332 .or_insert(header_value);
333 Ok(response)
334 });
335 }
336 };
337
338 let header_value = HeaderValue::from_str(&request_id)
339 .unwrap_or_else(|_| unreachable!("accepted inbound request id came from a header"));
340 parts
341 .headers
342 .insert(config.header().clone(), header_value.clone());
343 let context = RequestContext::from_parts(&parts, request_id.clone());
344 parts.extensions.insert(context);
345 let future = self.inner.call(Request::from_parts(parts, body));
346
347 Box::pin(async move {
348 let mut response = future.await?;
349 response
350 .headers_mut()
351 .entry(config.header().clone())
352 .or_insert(header_value);
353 Ok(response)
354 })
355 }
356}
357
358fn generated_request_id_header(config: &RequestIdConfig) -> Option<(String, HeaderValue)> {
359 let request_id = config.generate();
360 let header_value = HeaderValue::from_str(&request_id).ok()?;
361 Some((request_id, header_value))
362}
363
364fn is_valid_request_id(value: &str) -> bool {
365 Uuid::parse_str(value)
366 .ok()
367 .and_then(|uuid| uuid.get_version())
368 == Some(Version::Random)
369}
370
371fn invalid_request_id_response(request_id: &str, path: &str) -> Response<Body> {
372 let timestamp = crate::error::timestamp_now();
373 (
374 StatusCode::BAD_REQUEST,
375 axum::Json(RequestIdErrorBody {
376 error: RequestIdErrorDetails {
377 status_code: StatusCode::BAD_REQUEST.as_u16(),
378 code: "invalid_request_id",
379 message: "invalid request id",
380 details: serde_json::Value::Null,
381 timestamp,
382 path: path.to_owned(),
383 request_id: Some(request_id.to_owned()),
384 },
385 }),
386 )
387 .into_response()
388}
389
390fn invalid_generated_request_id_response(path: &str) -> Response<Body> {
391 let timestamp = crate::error::timestamp_now();
392 (
393 StatusCode::INTERNAL_SERVER_ERROR,
394 axum::Json(RequestIdErrorBody {
395 error: RequestIdErrorDetails {
396 status_code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
397 code: "invalid_generated_request_id",
398 message: "generated request id was not a valid HTTP header value",
399 details: serde_json::Value::Null,
400 timestamp,
401 path: path.to_owned(),
402 request_id: None,
403 },
404 }),
405 )
406 .into_response()
407}
408
409#[derive(Debug, Serialize)]
410struct RequestIdErrorBody {
411 error: RequestIdErrorDetails,
412}
413
414#[derive(Debug, Serialize)]
415#[serde(rename_all = "camelCase")]
416struct RequestIdErrorDetails {
417 status_code: u16,
418 code: &'static str,
419 message: &'static str,
420 details: serde_json::Value,
421 timestamp: String,
422 path: String,
423 #[serde(skip_serializing_if = "Option::is_none")]
424 request_id: Option<String>,
425}