1use std::{
2 future::Future,
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll},
6};
7
8use axum::{body::Body, response::IntoResponse};
9use http::{HeaderValue, Request, Response, StatusCode, header::HeaderName};
10use serde::Serialize;
11use tower::{Layer, Service};
12use uuid::{Uuid, Version};
13
14use crate::context::RequestContext;
15
/// Stateless tower layer that wraps a service in [`RequestIdService`].
#[derive(Debug, Default, Clone, Copy)]
pub struct RequestIdLayer;
28
29impl<S> Layer<S> for RequestIdLayer {
30 type Service = RequestIdService<S>;
31
32 fn layer(&self, inner: S) -> Self::Service {
33 RequestIdService { inner }
34 }
35}
36
/// Service wrapper created by [`RequestIdLayer`].
#[derive(Debug, Clone)]
pub struct RequestIdService<S> {
    /// The wrapped service that every request is forwarded to.
    inner: S,
}
42
43impl<S, RequestBody, ResponseBody> Service<Request<RequestBody>> for RequestIdService<S>
44where
45 S: Service<Request<RequestBody>, Response = Response<ResponseBody>> + Send + 'static,
46 S::Future: Send + 'static,
47 S::Error: Send + 'static,
48 RequestBody: Send + 'static,
49 ResponseBody: Send + 'static,
50{
51 type Response = Response<ResponseBody>;
52 type Error = S::Error;
53 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
54
55 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
56 self.inner.poll_ready(cx)
57 }
58
59 fn call(&mut self, request: Request<RequestBody>) -> Self::Future {
60 let request_id = request.headers().get(request_id_header()).cloned();
61 let future = self.inner.call(request);
62 Box::pin(async move {
63 let mut response = future.await?;
64 response
65 .headers_mut()
66 .entry(request_id_header())
67 .or_insert_with(|| request_id.unwrap_or_else(new_request_id));
68 Ok(response)
69 })
70 }
71}
72
73fn request_id_header() -> HeaderName {
74 HeaderName::from_static("x-request-id")
75}
76
77fn new_request_id() -> HeaderValue {
78 HeaderValue::from_str(&Uuid::new_v4().to_string())
79 .expect("generated request id contains only valid header characters")
80}
81
/// How [`ValidatedRequestIdService`] treats an inbound request id that fails validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestIdMode {
    /// An unacceptable inbound id is replaced with a generated one and the request proceeds.
    Permissive,
    /// An unacceptable inbound id causes the request to be rejected.
    Strict,
}

/// Alternative name for [`RequestIdMode`].
pub type RequestIdPolicy = RequestIdMode;
106
107#[derive(Clone)]
118pub struct RequestIdConfig {
119 header_name: HeaderName,
120 mode: RequestIdMode,
121 generator: Arc<dyn Fn() -> String + Send + Sync>,
122}
123
124impl std::fmt::Debug for RequestIdConfig {
125 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 formatter
127 .debug_struct("RequestIdConfig")
128 .field("header_name", &self.header_name)
129 .field("mode", &self.mode)
130 .finish_non_exhaustive()
131 }
132}
133
134impl RequestIdConfig {
135 pub fn production() -> Self {
145 Self {
146 header_name: HeaderName::from_static("x-request-id"),
147 mode: RequestIdMode::Strict,
148 generator: Arc::new(|| Uuid::new_v4().to_string()),
149 }
150 }
151
152 pub fn development() -> Self {
159 Self::production().mode(RequestIdMode::Permissive)
160 }
161
162 pub fn header_name(mut self, header_name: HeaderName) -> Self {
164 self.header_name = header_name;
165 self
166 }
167
168 pub fn mode(mut self, mode: RequestIdMode) -> Self {
170 self.mode = mode;
171 self
172 }
173
174 pub fn generator(mut self, generator: impl Fn() -> String + Send + Sync + 'static) -> Self {
182 self.generator = Arc::new(generator);
183 self
184 }
185
186 pub fn header(&self) -> &HeaderName {
188 &self.header_name
189 }
190
191 pub const fn validation_mode(&self) -> RequestIdMode {
193 self.mode
194 }
195
196 fn generate(&self) -> String {
197 (self.generator)()
198 }
199}
200
201impl Default for RequestIdConfig {
202 fn default() -> Self {
203 Self::production()
204 }
205}
206
207pub fn validated_request_id_layer(config: RequestIdConfig) -> ValidatedRequestIdLayer {
229 ValidatedRequestIdLayer::new(config)
230}
231
232#[derive(Clone, Debug)]
245pub struct ValidatedRequestIdLayer {
246 config: RequestIdConfig,
247}
248
249impl ValidatedRequestIdLayer {
250 pub fn new(config: RequestIdConfig) -> Self {
252 Self { config }
253 }
254}
255
256impl<S> Layer<S> for ValidatedRequestIdLayer {
257 type Service = ValidatedRequestIdService<S>;
258
259 fn layer(&self, inner: S) -> Self::Service {
260 ValidatedRequestIdService {
261 inner,
262 config: self.config.clone(),
263 }
264 }
265}
266
267#[derive(Clone, Debug)]
269pub struct ValidatedRequestIdService<S> {
270 inner: S,
271 config: RequestIdConfig,
272}
273
274impl<S> Service<Request<Body>> for ValidatedRequestIdService<S>
275where
276 S: Service<Request<Body>, Response = Response<Body>> + Send + 'static,
277 S::Future: Send + 'static,
278 S::Error: Send + 'static,
279{
280 type Response = Response<Body>;
281 type Error = S::Error;
282 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
283
284 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
285 self.inner.poll_ready(cx)
286 }
287
288 fn call(&mut self, request: Request<Body>) -> Self::Future {
289 let config = self.config.clone();
290 let (mut parts, body) = request.into_parts();
291 let incoming = parts
292 .headers
293 .get(config.header())
294 .and_then(|value| value.to_str().ok())
295 .map(str::to_owned);
296 let request_id = match incoming {
297 Some(value) if is_valid_request_id(&value) => value,
298 Some(_) if config.validation_mode() == RequestIdMode::Strict => {
299 let (request_id, header_value) = match generated_request_id_header(&config) {
300 Some(generated) => generated,
301 None => {
302 return Box::pin(async move {
303 Ok(invalid_generated_request_id_response(parts.uri.path()))
304 });
305 }
306 };
307 let mut response = invalid_request_id_response(&request_id, parts.uri.path());
308 response
309 .headers_mut()
310 .insert(config.header().clone(), header_value);
311 return Box::pin(async move { Ok(response) });
312 }
313 Some(_) | None => {
314 let (request_id, header_value) = match generated_request_id_header(&config) {
315 Some(generated) => generated,
316 None => {
317 return Box::pin(async move {
318 Ok(invalid_generated_request_id_response(parts.uri.path()))
319 });
320 }
321 };
322 parts
323 .headers
324 .insert(config.header().clone(), header_value.clone());
325 let context = RequestContext::from_parts(&parts, request_id.clone());
326 parts.extensions.insert(context);
327 let future = self.inner.call(Request::from_parts(parts, body));
328
329 return Box::pin(async move {
330 let mut response = future.await?;
331 response
332 .headers_mut()
333 .entry(config.header().clone())
334 .or_insert(header_value);
335 Ok(response)
336 });
337 }
338 };
339
340 let header_value = HeaderValue::from_str(&request_id)
341 .unwrap_or_else(|_| unreachable!("accepted inbound request id came from a header"));
342 parts
343 .headers
344 .insert(config.header().clone(), header_value.clone());
345 let context = RequestContext::from_parts(&parts, request_id.clone());
346 parts.extensions.insert(context);
347 let future = self.inner.call(Request::from_parts(parts, body));
348
349 Box::pin(async move {
350 let mut response = future.await?;
351 response
352 .headers_mut()
353 .entry(config.header().clone())
354 .or_insert(header_value);
355 Ok(response)
356 })
357 }
358}
359
360fn generated_request_id_header(config: &RequestIdConfig) -> Option<(String, HeaderValue)> {
361 let request_id = config.generate();
362 let header_value = HeaderValue::from_str(&request_id).ok()?;
363 Some((request_id, header_value))
364}
365
366fn is_valid_request_id(value: &str) -> bool {
367 Uuid::parse_str(value)
368 .ok()
369 .and_then(|uuid| uuid.get_version())
370 == Some(Version::Random)
371}
372
373fn invalid_request_id_response(request_id: &str, path: &str) -> Response<Body> {
374 let timestamp = crate::error::timestamp_now();
375 (
376 StatusCode::BAD_REQUEST,
377 axum::Json(RequestIdErrorBody {
378 error: RequestIdErrorDetails {
379 status_code: StatusCode::BAD_REQUEST.as_u16(),
380 code: "invalid_request_id",
381 message: "invalid request id",
382 details: serde_json::Value::Null,
383 timestamp,
384 path: path.to_owned(),
385 request_id: Some(request_id.to_owned()),
386 },
387 }),
388 )
389 .into_response()
390}
391
392fn invalid_generated_request_id_response(path: &str) -> Response<Body> {
393 let timestamp = crate::error::timestamp_now();
394 (
395 StatusCode::INTERNAL_SERVER_ERROR,
396 axum::Json(RequestIdErrorBody {
397 error: RequestIdErrorDetails {
398 status_code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
399 code: "invalid_generated_request_id",
400 message: "generated request id was not a valid HTTP header value",
401 details: serde_json::Value::Null,
402 timestamp,
403 path: path.to_owned(),
404 request_id: None,
405 },
406 }),
407 )
408 .into_response()
409}
410
411#[derive(Debug, Serialize)]
412struct RequestIdErrorBody {
413 error: RequestIdErrorDetails,
414}
415
416#[derive(Debug, Serialize)]
417#[serde(rename_all = "camelCase")]
418struct RequestIdErrorDetails {
419 status_code: u16,
420 code: &'static str,
421 message: &'static str,
422 details: serde_json::Value,
423 timestamp: String,
424 path: String,
425 #[serde(skip_serializing_if = "Option::is_none")]
426 request_id: Option<String>,
427}