1use std::fmt;
61
62use crate::{
63 Request, Response,
64 header::{HeaderName, HeaderValue},
65};
66use nanoid::nanoid;
67use rama_core::{Context, Layer, Service};
68use rama_utils::macros::define_inner_service_accessors;
69use uuid::Uuid;
70
71pub(crate) const REQUEST_ID: HeaderName = HeaderName::from_static("request-id");
73
74pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
75
76pub trait MakeRequestId: Send + Sync + 'static {
80 fn make_request_id<B>(&self, request: &Request<B>) -> Option<RequestId>;
82}
83
84#[derive(Debug, Clone)]
86pub struct RequestId(HeaderValue);
87
88impl RequestId {
89 pub const fn new(header_value: HeaderValue) -> Self {
91 Self(header_value)
92 }
93
94 pub fn header_value(&self) -> &HeaderValue {
96 &self.0
97 }
98
99 pub fn into_header_value(self) -> HeaderValue {
101 self.0
102 }
103}
104
105impl From<HeaderValue> for RequestId {
106 fn from(value: HeaderValue) -> Self {
107 Self::new(value)
108 }
109}
110
111pub struct SetRequestIdLayer<M> {
117 header_name: HeaderName,
118 make_request_id: M,
119}
120
121impl<M: fmt::Debug> fmt::Debug for SetRequestIdLayer<M> {
122 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
123 f.debug_struct("SetRequestIdLayer")
124 .field("header_name", &self.header_name)
125 .field("make_request_id", &self.make_request_id)
126 .finish()
127 }
128}
129
130impl<M: Clone> Clone for SetRequestIdLayer<M> {
131 fn clone(&self) -> Self {
132 Self {
133 header_name: self.header_name.clone(),
134 make_request_id: self.make_request_id.clone(),
135 }
136 }
137}
138
139impl<M> SetRequestIdLayer<M> {
140 pub const fn new(header_name: HeaderName, make_request_id: M) -> Self
142 where
143 M: MakeRequestId,
144 {
145 SetRequestIdLayer {
146 header_name,
147 make_request_id,
148 }
149 }
150
151 pub const fn request_id(make_request_id: M) -> Self
153 where
154 M: MakeRequestId,
155 {
156 SetRequestIdLayer::new(REQUEST_ID, make_request_id)
157 }
158
159 pub const fn x_request_id(make_request_id: M) -> Self
161 where
162 M: MakeRequestId,
163 {
164 SetRequestIdLayer::new(X_REQUEST_ID, make_request_id)
165 }
166}
167
168impl<S, M> Layer<S> for SetRequestIdLayer<M>
169where
170 M: Clone + MakeRequestId,
171{
172 type Service = SetRequestId<S, M>;
173
174 fn layer(&self, inner: S) -> Self::Service {
175 SetRequestId::new(
176 inner,
177 self.header_name.clone(),
178 self.make_request_id.clone(),
179 )
180 }
181
182 fn into_layer(self, inner: S) -> Self::Service {
183 SetRequestId::new(inner, self.header_name, self.make_request_id)
184 }
185}
186
187pub struct SetRequestId<S, M> {
197 inner: S,
198 header_name: HeaderName,
199 make_request_id: M,
200}
201
202impl<S: fmt::Debug, M: fmt::Debug> fmt::Debug for SetRequestId<S, M> {
203 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204 f.debug_struct("SetRequestId")
205 .field("inner", &self.inner)
206 .field("header_name", &self.header_name)
207 .field("make_request_id", &self.make_request_id)
208 .finish()
209 }
210}
211
212impl<S: Clone, M: Clone> Clone for SetRequestId<S, M> {
213 fn clone(&self) -> Self {
214 SetRequestId {
215 inner: self.inner.clone(),
216 header_name: self.header_name.clone(),
217 make_request_id: self.make_request_id.clone(),
218 }
219 }
220}
221
222impl<S, M> SetRequestId<S, M> {
223 pub const fn new(inner: S, header_name: HeaderName, make_request_id: M) -> Self
225 where
226 M: MakeRequestId,
227 {
228 Self {
229 inner,
230 header_name,
231 make_request_id,
232 }
233 }
234
235 pub const fn request_id(inner: S, make_request_id: M) -> Self
237 where
238 M: MakeRequestId,
239 {
240 Self::new(inner, REQUEST_ID, make_request_id)
241 }
242
243 pub const fn x_request_id(inner: S, make_request_id: M) -> Self
245 where
246 M: MakeRequestId,
247 {
248 Self::new(inner, X_REQUEST_ID, make_request_id)
249 }
250
251 define_inner_service_accessors!();
252}
253
254impl<State, S, M, ReqBody, ResBody> Service<State, Request<ReqBody>> for SetRequestId<S, M>
255where
256 State: Clone + Send + Sync + 'static,
257 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
258 M: MakeRequestId,
259 ReqBody: Send + 'static,
260 ResBody: Send + 'static,
261{
262 type Response = S::Response;
263 type Error = S::Error;
264
265 async fn serve(
266 &self,
267 ctx: Context<State>,
268 mut req: Request<ReqBody>,
269 ) -> Result<Self::Response, Self::Error> {
270 if let Some(request_id) = req.headers().get(&self.header_name) {
271 if req.extensions().get::<RequestId>().is_none() {
272 let request_id = request_id.clone();
273 req.extensions_mut().insert(RequestId::new(request_id));
274 }
275 } else if let Some(request_id) = self.make_request_id.make_request_id(&req) {
276 req.extensions_mut().insert(request_id.clone());
277 req.headers_mut()
278 .insert(self.header_name.clone(), request_id.0);
279 }
280
281 self.inner.serve(ctx, req).await
282 }
283}
284
285#[derive(Debug, Clone)]
291pub struct PropagateRequestIdLayer {
292 header_name: HeaderName,
293}
294
295impl PropagateRequestIdLayer {
296 pub const fn new(header_name: HeaderName) -> Self {
298 PropagateRequestIdLayer { header_name }
299 }
300
301 pub const fn request_id() -> Self {
303 Self::new(REQUEST_ID)
304 }
305
306 pub const fn x_request_id() -> Self {
308 Self::new(X_REQUEST_ID)
309 }
310}
311
312impl<S> Layer<S> for PropagateRequestIdLayer {
313 type Service = PropagateRequestId<S>;
314
315 fn layer(&self, inner: S) -> Self::Service {
316 PropagateRequestId::new(inner, self.header_name.clone())
317 }
318}
319
320pub struct PropagateRequestId<S> {
327 inner: S,
328 header_name: HeaderName,
329}
330
331impl<S> PropagateRequestId<S> {
332 pub const fn new(inner: S, header_name: HeaderName) -> Self {
334 Self { inner, header_name }
335 }
336
337 pub const fn request_id(inner: S) -> Self {
339 Self::new(inner, REQUEST_ID)
340 }
341
342 pub const fn x_request_id(inner: S) -> Self {
344 Self::new(inner, X_REQUEST_ID)
345 }
346
347 define_inner_service_accessors!();
348}
349
350impl<S: fmt::Debug> fmt::Debug for PropagateRequestId<S> {
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 f.debug_struct("PropagateRequestId")
353 .field("inner", &self.inner)
354 .field("header_name", &self.header_name)
355 .finish()
356 }
357}
358
359impl<S: Clone> Clone for PropagateRequestId<S> {
360 fn clone(&self) -> Self {
361 PropagateRequestId {
362 inner: self.inner.clone(),
363 header_name: self.header_name.clone(),
364 }
365 }
366}
367
368impl<State, S, ReqBody, ResBody> Service<State, Request<ReqBody>> for PropagateRequestId<S>
369where
370 State: Clone + Send + Sync + 'static,
371 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
372 ReqBody: Send + 'static,
373 ResBody: Send + 'static,
374{
375 type Response = S::Response;
376 type Error = S::Error;
377
378 async fn serve(
379 &self,
380 ctx: Context<State>,
381 req: Request<ReqBody>,
382 ) -> Result<Self::Response, Self::Error> {
383 let request_id = req
384 .headers()
385 .get(&self.header_name)
386 .cloned()
387 .map(RequestId::new);
388
389 let mut response = self.inner.serve(ctx, req).await?;
390
391 if let Some(current_id) = response.headers().get(&self.header_name) {
392 if response.extensions().get::<RequestId>().is_none() {
393 let current_id = current_id.clone();
394 response.extensions_mut().insert(RequestId::new(current_id));
395 }
396 } else if let Some(request_id) = request_id {
397 response
398 .headers_mut()
399 .insert(self.header_name.clone(), request_id.0.clone());
400 response.extensions_mut().insert(request_id);
401 }
402
403 Ok(response)
404 }
405}
406
407#[derive(Debug, Clone, Copy, Default)]
409pub struct MakeRequestUuid;
410
411impl MakeRequestId for MakeRequestUuid {
412 fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
413 let request_id = Uuid::new_v4().to_string().parse().unwrap();
414 Some(RequestId::new(request_id))
415 }
416}
417
418#[derive(Debug, Clone, Copy, Default)]
420pub struct MakeRequestNanoid;
421
422impl MakeRequestId for MakeRequestNanoid {
423 fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
424 let request_id = nanoid!().parse().unwrap();
425 Some(RequestId::new(request_id))
426 }
427}
428
#[cfg(test)]
mod tests {
    use crate::layer::set_header;
    use crate::{Body, Response};
    use rama_core::Layer;
    use rama_core::service::service_fn;
    use std::{
        convert::Infallible,
        sync::{
            Arc,
            atomic::{AtomicU64, Ordering},
        },
    };

    #[allow(unused_imports)]
    use super::*;

    #[tokio::test]
    async fn basic() {
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
        )
            .into_layer(service_fn(handler));

        // Generated ids are propagated to the response.
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "0");

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "1");

        // An id supplied by the client is kept as-is.
        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "foo");

        // The id is also exposed as a response extension.
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
    }

    #[tokio::test]
    async fn basic_with_request_id() {
        let svc = (
            SetRequestIdLayer::request_id(Counter::default()),
            PropagateRequestIdLayer::request_id(),
        )
            .into_layer(service_fn(handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["request-id"], "0");

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["request-id"], "1");

        let req = Request::builder()
            .header("request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["request-id"], "foo");

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "2");
    }

    #[tokio::test]
    async fn other_middleware_setting_request_id_on_response() {
        // The inner middleware sets a response id ("bar") that differs from the
        // incoming one ("foo"), so the assertions can tell which one survives:
        // the inner service's header must win and be mirrored in the extension.
        let svc = (
            SetRequestIdLayer::x_request_id(Counter::default()),
            PropagateRequestIdLayer::x_request_id(),
            set_header::SetResponseHeaderLayer::overriding(
                HeaderName::from_static("x-request-id"),
                HeaderValue::from_str("bar").unwrap(),
            ),
        )
            .into_layer(service_fn(handler));

        let req = Request::builder()
            .header("x-request-id", "foo")
            .body(Body::empty())
            .unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers()["x-request-id"], "bar");
        assert_eq!(res.extensions().get::<RequestId>().unwrap().0, "bar");
    }

    #[tokio::test]
    async fn propagate_without_incoming_id_leaves_response_untouched() {
        let svc = PropagateRequestIdLayer::x_request_id().into_layer(service_fn(handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.serve(Context::default(), req).await.unwrap();
        assert!(res.headers().get("x-request-id").is_none());
        assert!(res.extensions().get::<RequestId>().is_none());
    }

    /// Deterministic id generator: yields "0", "1", "2", ...
    #[derive(Clone, Default)]
    struct Counter(Arc<AtomicU64>);

    impl MakeRequestId for Counter {
        fn make_request_id<B>(&self, _request: &Request<B>) -> Option<RequestId> {
            let id =
                HeaderValue::from_str(&self.0.fetch_add(1, Ordering::AcqRel).to_string()).unwrap();
            Some(RequestId::new(id))
        }
    }

    async fn handler(_: Request<Body>) -> Result<Response<Body>, Infallible> {
        Ok(Response::new(Body::empty()))
    }

    #[tokio::test]
    async fn uuid() {
        let svc = (
            SetRequestIdLayer::x_request_id(MakeRequestUuid),
            PropagateRequestIdLayer::x_request_id(),
        )
            .into_layer(service_fn(handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let mut res = svc.serve(Context::default(), req).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        id.to_str().unwrap().parse::<Uuid>().unwrap();
    }

    #[tokio::test]
    async fn nanoid() {
        let svc = (
            SetRequestIdLayer::x_request_id(MakeRequestNanoid),
            PropagateRequestIdLayer::x_request_id(),
        )
            .into_layer(service_fn(handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let mut res = svc.serve(Context::default(), req).await.unwrap();
        let id = res.headers_mut().remove("x-request-id").unwrap();
        assert_eq!(id.to_str().unwrap().chars().count(), 21);
    }
}