actix_web_requestid/
lib.rs1use std::convert::Infallible;
14use std::future::{ready, Future, Ready};
15use std::pin::Pin;
16
17use actix_web::dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform};
18use actix_web::http::header::{HeaderName, HeaderValue};
19use actix_web::{Error, FromRequest, HttpMessage, HttpRequest};
20use rand::distributions::Alphanumeric;
21use rand::Rng;
22
/// Name of the response header that carries the request ID.
///
/// Must stay lowercase: it is passed to `HeaderName::from_static`, which rejects
/// non-lowercase names.
pub const REQUEST_ID_HEADER: &str = "request-id";

/// Identifier attached to a single HTTP request.
///
/// Obtain one via [`RequestIDMessage::request_id`] or by extracting `RequestID` in a handler;
/// repeated lookups on the same request yield the same value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestID {
    inner: String,
}

impl From<RequestID> for String {
    /// Consumes the ID and hands back its underlying text.
    fn from(id: RequestID) -> Self {
        let RequestID { inner } = id;
        inner
    }
}

impl std::fmt::Display for RequestID {
    /// Writes the ID text, honouring width / alignment / precision flags like a plain `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.inner)
    }
}
41
42impl FromRequest for RequestID {
53 type Error = Infallible;
54 type Future = Ready<Result<RequestID, Infallible>>;
55
56 #[inline]
57 fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
58 ready(Ok(req.request_id()))
59 }
60}
61
/// Middleware that tags every successful response with the ID of the request that
/// produced it, sent in the [`REQUEST_ID_HEADER`] header.
///
/// Register with `App::wrap(RequestIDMiddleware::default())`.
// `Default` is derived rather than hand-written: the manual impl was exactly the derived one
// (clippy `derivable_impls`). The braced `{}` form is kept so `RequestIDMiddleware {}` still
// works for existing callers.
#[derive(Debug, Clone, Default)]
pub struct RequestIDMiddleware {}
78
79impl<S, B> Transform<S, ServiceRequest> for RequestIDMiddleware
80where
81 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
82 S::Future: 'static,
83{
84 type Response = ServiceResponse<B>;
85 type Error = Error;
86 type InitError = ();
87 type Transform = RequestIDService<S>;
88 type Future = Ready<Result<Self::Transform, Self::InitError>>;
89
90 fn new_transform(&self, service: S) -> Self::Future {
91 ready(Ok(RequestIDService {
92 wrapped_service: service,
93 }))
94 }
95}
96
97pub struct RequestIDService<S> {
98 wrapped_service: S,
99}
100
101impl<S, Req> Service<ServiceRequest> for RequestIDService<S>
102where
103 S: Service<ServiceRequest, Response = ServiceResponse<Req>, Error = Error>,
104 S::Future: 'static,
105{
106 type Response = ServiceResponse<Req>;
107 type Error = S::Error;
108 #[allow(clippy::type_complexity)]
109 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
110
111 fn poll_ready(
112 &self,
113 ctx: &mut core::task::Context<'_>,
114 ) -> std::task::Poll<Result<(), Self::Error>> {
115 self.wrapped_service.poll_ready(ctx)
116 }
117
118 fn call(&self, req: actix_web::dev::ServiceRequest) -> Self::Future {
119 let id = req.request_id().inner;
120 let fut = self.wrapped_service.call(req);
121
122 Box::pin(async move {
123 let mut res = fut.await?;
124
125 res.headers_mut().append(
126 HeaderName::from_static(REQUEST_ID_HEADER),
127 HeaderValue::from_str(&id).unwrap(),
128 );
129
130 Ok(res)
131 })
132 }
133}
134
135pub trait RequestIDMessage {
136 fn request_id(&self) -> RequestID;
137}
138
139impl<T> RequestIDMessage for T
140where
141 T: HttpMessage,
142{
143 fn request_id(&self) -> RequestID {
144 if let Some(id) = self.extensions().get::<RequestID>() {
145 return id.clone();
146 }
147
148 let new_id = RequestID {
149 inner: rand::thread_rng()
150 .sample_iter(&Alphanumeric)
151 .map(char::from)
152 .take(10)
153 .collect::<_>(),
154 };
155
156 self.extensions_mut().insert(new_id.clone());
157
158 new_id
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::test::TestRequest;
    use actix_web::{http::StatusCode, test, web, App, HttpResponse};

    #[actix_rt::test]
    async fn request_id_is_consistent_for_same_request() {
        let req = TestRequest::default().to_http_request();
        let id_1 = RequestID::extract(&req).await.unwrap();
        let id_2 = RequestID::extract(&req).await.unwrap();

        assert_eq!(id_1, id_2);
    }

    #[actix_rt::test]
    async fn request_id_is_new_between_different_requests() {
        let req1 = TestRequest::default().to_http_request();
        let req2 = TestRequest::default().to_http_request();

        let req1_id = RequestID::extract(&req1).await.unwrap();
        let req2_id = RequestID::extract(&req2).await.unwrap();

        assert_ne!(req1_id, req2_id);
    }

    #[actix_rt::test]
    async fn middleware_adds_request_id_in_headers() {
        let app = test::init_service(
            App::new()
                .wrap(RequestIDMiddleware::default())
                .service(web::resource("/").to(|| async { HttpResponse::Ok().finish() })),
        )
        .await;

        let req = test::TestRequest::with_uri("/").to_request();
        let resp = test::call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let header = resp
            .headers()
            .get(REQUEST_ID_HEADER)
            .expect("middleware must set the request-id header");
        assert!(!header.is_empty());
    }

    /// The header sent back must be the same ID a handler sees when extracting `RequestID`.
    #[actix_rt::test]
    async fn middleware_header_matches_extracted_request_id() {
        let app = test::init_service(
            App::new()
                .wrap(RequestIDMiddleware::default())
                .service(web::resource("/").to(|id: RequestID| async move {
                    HttpResponse::Ok().body(String::from(id))
                })),
        )
        .await;

        let req = test::TestRequest::with_uri("/").to_request();
        let resp = test::call_service(&app, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let header = resp
            .headers()
            .get(REQUEST_ID_HEADER)
            .expect("middleware must set the request-id header")
            .to_str()
            .expect("request IDs are ASCII")
            .to_owned();
        let body = test::read_body(resp).await;

        assert_eq!(&body[..], header.as_bytes());
    }
}