actix_web_requestid/
lib.rs

//! actix-web-requestid.
//!
//! [`RequestID`] attaches a "request-id" to an HTTP request. It can be
//! used for tracing, debugging, and user error reporting.
//!
//! Wrap your app with [`RequestIDMiddleware`] to write the request id into the
//! `request-id` HTTP response header. To access the id inside a handler, use the
//! [`RequestID`] actix extractor.
//!
//! It is still usable without the middleware: the id is generated the first
//! time it is extracted, then reused for the rest of the request.
//! You can, for example, use it in a logging or tracing middleware.
13use std::convert::Infallible;
14use std::future::{ready, Future, Ready};
15use std::pin::Pin;
16
17use actix_web::dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform};
18use actix_web::http::header::{HeaderName, HeaderValue};
19use actix_web::{Error, FromRequest, HttpMessage, HttpRequest};
20use rand::distributions::Alphanumeric;
21use rand::Rng;
22
/// Name of the HTTP response header in which the middleware exposes the
/// request id.
///
/// Must stay lowercase: it is passed to `HeaderName::from_static`, which
/// panics on uppercase header names.
pub const REQUEST_ID_HEADER: &str = "request-id";
24
/// Identifier attached to a single HTTP request.
///
/// The id is a random 10-character alphanumeric string. It is generated
/// lazily the first time it is requested and then cached in the request
/// extensions, so every extraction within the same request yields the same
/// value.
///
/// Obtain it with the actix extractor in a handler, or by calling
/// [`RequestIDMessage::request_id`] on any [`HttpMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RequestID {
    inner: String,
}
29
30impl From<RequestID> for String {
31    fn from(r: RequestID) -> Self {
32        r.inner
33    }
34}
35
36impl std::fmt::Display for RequestID {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        self.inner.fmt(f)
39    }
40}
41
42/// Extractor implementation for [`RequestID`] type.
43///
44/// ```noexec
45/// # use actix_web::*;
46/// # use actix_web_requestid::{RequestID};
47///
48/// async fn index(id: RequestID) -> String {
49///     format!("Welcome! {}", id)
50/// }
51/// ```
52impl FromRequest for RequestID {
53    type Error = Infallible;
54    type Future = Ready<Result<RequestID, Infallible>>;
55
56    #[inline]
57    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
58        ready(Ok(req.request_id()))
59    }
60}
61
/// Request id middleware.
///
/// Assigns (or reuses) the request id before forwarding each request. It then
/// writes that id into the [`REQUEST_ID_HEADER`] header of the response
/// returned by the wrapped service. Errors from the wrapped service are
/// propagated unchanged.
///
/// ```
/// use actix_web::*;
/// use actix_web_requestid::RequestIDMiddleware;
///
/// let app = App::new()
///     .wrap(RequestIDMiddleware::default());
/// ```
#[derive(Debug, Clone, Default)]
pub struct RequestIDMiddleware {}
78
79impl<S, B> Transform<S, ServiceRequest> for RequestIDMiddleware
80where
81    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
82    S::Future: 'static,
83{
84    type Response = ServiceResponse<B>;
85    type Error = Error;
86    type InitError = ();
87    type Transform = RequestIDService<S>;
88    type Future = Ready<Result<Self::Transform, Self::InitError>>;
89
90    fn new_transform(&self, service: S) -> Self::Future {
91        ready(Ok(RequestIDService {
92            wrapped_service: service,
93        }))
94    }
95}
96
/// Service produced by [`RequestIDMiddleware`]. It forwards each request to
/// the wrapped service and tags the returned response with the request id.
#[derive(Debug)]
pub struct RequestIDService<S> {
    /// The next service in the chain.
    wrapped_service: S,
}
100
101impl<S, Req> Service<ServiceRequest> for RequestIDService<S>
102where
103    S: Service<ServiceRequest, Response = ServiceResponse<Req>, Error = Error>,
104    S::Future: 'static,
105{
106    type Response = ServiceResponse<Req>;
107    type Error = S::Error;
108    #[allow(clippy::type_complexity)]
109    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
110
111    fn poll_ready(
112        &self,
113        ctx: &mut core::task::Context<'_>,
114    ) -> std::task::Poll<Result<(), Self::Error>> {
115        self.wrapped_service.poll_ready(ctx)
116    }
117
118    fn call(&self, req: actix_web::dev::ServiceRequest) -> Self::Future {
119        let id = req.request_id().inner;
120        let fut = self.wrapped_service.call(req);
121
122        Box::pin(async move {
123            let mut res = fut.await?;
124
125            res.headers_mut().append(
126                HeaderName::from_static(REQUEST_ID_HEADER),
127                HeaderValue::from_str(&id).unwrap(),
128            );
129
130            Ok(res)
131        })
132    }
133}
134
135pub trait RequestIDMessage {
136    fn request_id(&self) -> RequestID;
137}
138
139impl<T> RequestIDMessage for T
140where
141    T: HttpMessage,
142{
143    fn request_id(&self) -> RequestID {
144        if let Some(id) = self.extensions().get::<RequestID>() {
145            return id.clone();
146        }
147
148        let new_id = RequestID {
149            inner: rand::thread_rng()
150                .sample_iter(&Alphanumeric)
151                .map(char::from)
152                .take(10)
153                .collect::<_>(),
154        };
155
156        self.extensions_mut().insert(new_id.clone());
157
158        new_id
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::test::TestRequest;
    use actix_web::{http::StatusCode, test, web, App, HttpResponse};

    #[actix_rt::test]
    async fn request_id_is_consistent_for_same_request() {
        let req = TestRequest::default().to_http_request();
        let id_1 = RequestID::extract(&req).await.unwrap();
        let id_2 = RequestID::extract(&req).await.unwrap();

        assert_eq!(id_1, id_2);
    }

    #[actix_rt::test]
    async fn request_id_is_new_between_different_requests() {
        let req1 = TestRequest::default().to_http_request();
        let req2 = TestRequest::default().to_http_request();

        let req1_id = RequestID::extract(&req1).await.unwrap();
        let req2_id = RequestID::extract(&req2).await.unwrap();

        assert_ne!(req1_id, req2_id);
    }

    /// The middleware converts the id into a header value, which relies on the
    /// id being plain ASCII alphanumerics.
    #[actix_rt::test]
    async fn generated_id_is_ten_ascii_alphanumerics() {
        let req = TestRequest::default().to_http_request();
        let id = String::from(RequestID::extract(&req).await.unwrap());

        assert_eq!(id.len(), 10);
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[actix_rt::test]
    async fn middleware_adds_request_id_in_headers() {
        let app = test::init_service(
            App::new()
                .wrap(RequestIDMiddleware::default())
                .service(web::resource("/").to(|| async { HttpResponse::Ok().await })),
        )
        .await;

        // Create request object
        let req = test::TestRequest::with_uri("/").to_request();

        // Execute application
        let resp = test::call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        assert!(!resp.headers().get(REQUEST_ID_HEADER).unwrap().is_empty());
    }

    /// The header emitted by the middleware must be the same id that handlers
    /// observe through the `RequestID` extractor.
    #[actix_rt::test]
    async fn middleware_header_matches_extracted_id() {
        let app = test::init_service(
            App::new()
                .wrap(RequestIDMiddleware::default())
                .service(web::resource("/").to(|id: RequestID| async move { id.to_string() })),
        )
        .await;

        let req = test::TestRequest::with_uri("/").to_request();
        let resp = test::call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let header = resp
            .headers()
            .get(REQUEST_ID_HEADER)
            .expect("middleware must set the request-id header")
            .to_str()
            .expect("request id header must be ASCII")
            .to_owned();
        let body = test::read_body(resp).await;

        assert_eq!(&body[..], header.as_bytes());
    }
}
206}