actix_web_reqid/
lib.rs

1use actix_web::error::ErrorBadRequest;
2use actix_web::{
3    Error, FromRequest, HttpMessage, HttpRequest,
4    dev::{self, Service, Transform},
5};
6use actix_web::{dev::ServiceRequest, dev::ServiceResponse};
7use futures_util::future::LocalBoxFuture;
8use std::future::{Ready, ready};
9use std::task::{Context, Poll};
10use uuid::Uuid;
11
/// Request ID middleware factory.
///
/// This is the value handed to actix-web's `wrap` (it implements `Transform`). Every request
/// that passes through the wrapped service gets a freshly generated random (`Uuid::new_v4`)
/// [`RequestID`] stored in its extensions, where handlers can extract it.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestIDWrapper;
14
15impl<S, B> Transform<S, ServiceRequest> for RequestIDWrapper
16where
17    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
18    B: 'static,
19{
20    type Response = ServiceResponse<B>;
21    type Error = Error;
22    type Transform = RequestIDMiddleware<S>;
23    type InitError = ();
24    type Future = Ready<Result<Self::Transform, Self::InitError>>;
25
26    fn new_transform(&self, service: S) -> Self::Future {
27        ready(Ok(RequestIDMiddleware { service }))
28    }
29}
30
/// Service produced by [`RequestIDWrapper`].
///
/// Before forwarding each request to the inner service, it inserts a new random
/// [`RequestID`] into the request extensions.
#[derive(Debug)]
pub struct RequestIDMiddleware<S> {
    /// The wrapped (next) service in the chain.
    service: S,
}
35
36impl<S, B> Service<ServiceRequest> for RequestIDMiddleware<S>
37where
38    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
39    B: 'static,
40{
41    type Response = ServiceResponse<B>;
42    type Error = Error;
43    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
44
45    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
46        self.service.poll_ready(cx)
47    }
48
49    fn call(&self, req: ServiceRequest) -> Self::Future {
50        let request_id = Uuid::new_v4();
51
52        req.extensions_mut().insert(RequestID(request_id));
53
54        let fut = self.service.call(req);
55        Box::pin(async move { fut.await })
56    }
57}
58
59/// Request ID extractor
60#[derive(Clone)]
61pub struct RequestID(pub Uuid);
62
63impl FromRequest for RequestID {
64    type Error = Error;
65    type Future = Ready<Result<Self, Self::Error>>;
66
67    fn from_request(req: &HttpRequest, _payload: &mut dev::Payload) -> Self::Future {
68        if let Some(req_id) = req.extensions().get::<RequestID>() {
69            ready(Ok(req_id.clone()))
70        } else {
71            ready(Err(ErrorBadRequest("request id is missing")))
72        }
73    }
74}
75
76impl std::fmt::Display for RequestID {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        write!(f, "{}", self.0.to_string())
79    }
80}
81
82impl std::fmt::Debug for RequestID {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        write!(f, "{}", self.0.to_string())
85    }
86}