elif_http/middleware/utils/
request_id.rs1use crate::middleware::v2::{Middleware, Next, NextFuture};
7use crate::request::ElifRequest;
8use crate::response::{ElifHeaderName, ElifHeaderValue, ElifResponse};
9
10use std::sync::atomic::{AtomicU64, Ordering};
11use uuid::Uuid;
12
/// Strategy used by [`RequestIdMiddleware`] to mint a new request ID when the incoming
/// request does not supply a reusable one.
#[derive(Debug)]
pub enum RequestIdStrategy {
    /// Random UUID version 4 in hyphenated form.
    UuidV4,
    /// Millisecond Unix timestamp followed by a UUID v4 (`"<millis>-<uuid>"`).
    ///
    /// Note: despite the variant name this is *not* an RFC 4122 version-1 UUID.
    UuidV1,
    /// Counter rendered as `"req-<16 hex digits>"`, incremented on every generated ID.
    Counter(AtomicU64),
    /// UUID v4 with a caller-supplied prefix (`"<prefix>-<uuid>"`).
    PrefixedUuid(String),
    /// Caller-supplied generator function; its return value is used verbatim.
    Custom(fn() -> String),
}
27
28impl Default for RequestIdStrategy {
29 fn default() -> Self {
30 Self::UuidV4
31 }
32}
33
34impl Clone for RequestIdStrategy {
35 fn clone(&self) -> Self {
36 match self {
37 Self::UuidV4 => Self::UuidV4,
38 Self::UuidV1 => Self::UuidV1,
39 Self::Counter(counter) => {
40 Self::Counter(AtomicU64::new(counter.load(Ordering::Relaxed)))
42 }
43 Self::PrefixedUuid(prefix) => Self::PrefixedUuid(prefix.clone()),
44 Self::Custom(func) => Self::Custom(*func),
45 }
46 }
47}
48
49impl RequestIdStrategy {
50 pub fn generate(&self) -> String {
52 match self {
53 Self::UuidV4 => Uuid::new_v4().to_string(),
54 Self::UuidV1 => {
55 let timestamp = std::time::SystemTime::now()
58 .duration_since(std::time::UNIX_EPOCH)
59 .unwrap()
60 .as_millis();
61 format!("{}-{}", timestamp, Uuid::new_v4())
62 }
63 Self::Counter(counter) => {
64 let count = counter.fetch_add(1, Ordering::SeqCst);
65 format!("req-{:016x}", count)
66 }
67 Self::PrefixedUuid(prefix) => {
68 format!("{}-{}", prefix, Uuid::new_v4())
69 }
70 Self::Custom(generator) => generator(),
71 }
72 }
73}
74
75#[derive(Debug)]
77pub struct RequestIdConfig {
78 pub header_name: String,
80 pub strategy: RequestIdStrategy,
82 pub override_existing: bool,
84 pub add_to_response: bool,
86 pub log_request_id: bool,
88}
89
90impl Clone for RequestIdConfig {
91 fn clone(&self) -> Self {
92 Self {
93 header_name: self.header_name.clone(),
94 strategy: self.strategy.clone(),
95 override_existing: self.override_existing,
96 add_to_response: self.add_to_response,
97 log_request_id: self.log_request_id,
98 }
99 }
100}
101
102impl Default for RequestIdConfig {
103 fn default() -> Self {
104 Self {
105 header_name: "x-request-id".to_string(),
106 strategy: RequestIdStrategy::default(),
107 override_existing: false,
108 add_to_response: true,
109 log_request_id: true,
110 }
111 }
112}
113
114#[derive(Debug)]
116pub struct RequestIdMiddleware {
117 config: RequestIdConfig,
118}
119
120impl RequestIdMiddleware {
121 pub fn new() -> Self {
123 Self {
124 config: RequestIdConfig::default(),
125 }
126 }
127
128 pub fn with_config(config: RequestIdConfig) -> Self {
130 Self { config }
131 }
132
133 pub fn header_name(mut self, name: impl Into<String>) -> Self {
135 self.config.header_name = name.into();
136 self
137 }
138
139 pub fn strategy(mut self, strategy: RequestIdStrategy) -> Self {
141 self.config.strategy = strategy;
142 self
143 }
144
145 pub fn uuid_v4(mut self) -> Self {
147 self.config.strategy = RequestIdStrategy::UuidV4;
148 self
149 }
150
151 pub fn uuid_v1(mut self) -> Self {
153 self.config.strategy = RequestIdStrategy::UuidV1;
154 self
155 }
156
157 pub fn counter(mut self) -> Self {
159 self.config.strategy = RequestIdStrategy::Counter(AtomicU64::new(0));
160 self
161 }
162
163 pub fn prefixed(mut self, prefix: impl Into<String>) -> Self {
165 self.config.strategy = RequestIdStrategy::PrefixedUuid(prefix.into());
166 self
167 }
168
169 pub fn custom_generator(mut self, generator: fn() -> String) -> Self {
171 self.config.strategy = RequestIdStrategy::Custom(generator);
172 self
173 }
174
175 pub fn override_existing(mut self) -> Self {
177 self.config.override_existing = true;
178 self
179 }
180
181 pub fn no_response_header(mut self) -> Self {
183 self.config.add_to_response = false;
184 self
185 }
186
187 pub fn no_logging(mut self) -> Self {
189 self.config.log_request_id = false;
190 self
191 }
192
193 fn get_or_generate_request_id(&self, request: &ElifRequest) -> String {
195 if !self.config.override_existing {
197 if let Some(existing_id) = request.header(&self.config.header_name) {
198 if let Ok(id_str) = existing_id.to_str() {
199 if !id_str.trim().is_empty() {
200 return id_str.to_string();
201 }
202 }
203 }
204 }
205
206 self.config.strategy.generate()
208 }
209
210 fn add_request_id_to_request(&self, mut request: ElifRequest, request_id: &str) -> ElifRequest {
212 let header_name = match ElifHeaderName::from_str(&self.config.header_name) {
213 Ok(name) => name,
214 Err(_) => return request, };
216
217 let header_value = match ElifHeaderValue::from_str(request_id) {
218 Ok(value) => value,
219 Err(_) => return request, };
221
222 request.headers.insert(header_name, header_value);
223 request
224 }
225
226 fn add_request_id_to_response(&self, response: ElifResponse, request_id: &str) -> ElifResponse {
228 if !self.config.add_to_response {
229 return response;
230 }
231
232 let header_name = match self.config.header_name.as_str() {
233 "x-request-id" => "x-request-id",
234 "request-id" => "request-id",
235 "x-trace-id" => "x-trace-id",
236 _ => &self.config.header_name,
237 };
238
239 response
240 .header(header_name, request_id)
241 .unwrap_or_else(|_| {
242 ElifResponse::internal_server_error().json_value(serde_json::json!({
244 "error": {
245 "code": "internal_error",
246 "message": "Failed to add request ID to response"
247 }
248 }))
249 })
250 }
251
252 fn log_request_id(&self, request_id: &str, method: &axum::http::Method, path: &str) {
254 if self.config.log_request_id {
255 tracing::info!(
256 request_id = request_id,
257 method = %method,
258 path = path,
259 "Request started"
260 );
261 }
262 }
263}
264
265impl Default for RequestIdMiddleware {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
271impl Middleware for RequestIdMiddleware {
272 fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
273 let request_id = self.get_or_generate_request_id(&request);
275 let method = request.method.clone();
276 let path = request.path().to_string();
277
278 self.log_request_id(&request_id, method.to_axum(), &path);
280
281 let updated_request = self.add_request_id_to_request(request, &request_id);
283
284 let config = self.config.clone();
285 let request_id_clone = request_id.clone();
286
287 Box::pin(async move {
288 let response = next.run(updated_request).await;
290
291 let middleware = RequestIdMiddleware { config };
293 middleware.add_request_id_to_response(response, &request_id_clone)
294 })
295 }
296
297 fn name(&self) -> &'static str {
298 "RequestIdMiddleware"
299 }
300}
301
/// Convenience accessors for reading a request ID from a request.
pub trait RequestIdExt {
    /// Reads the request ID from the `x-request-id` header.
    ///
    /// Returns `None` when the header is absent or not readable as a string. This always uses
    /// `x-request-id`, even if the middleware was configured with a different header name.
    fn request_id(&self) -> Option<String>;

    /// Returns the first non-blank ID found among several common request/trace ID headers,
    /// checked in a fixed priority order.
    fn request_id_with_fallbacks(&self) -> Option<String>;
}
310
311impl RequestIdExt for ElifRequest {
312 fn request_id(&self) -> Option<String> {
313 self.header("x-request-id")
314 .and_then(|h| h.to_str().ok())
315 .map(|s| s.to_string())
316 }
317
318 fn request_id_with_fallbacks(&self) -> Option<String> {
319 let header_names = [
321 "x-request-id",
322 "request-id",
323 "x-trace-id",
324 "x-correlation-id",
325 "x-session-id",
326 ];
327
328 for header_name in &header_names {
329 if let Some(value) = self.header(header_name) {
330 if let Ok(id_str) = value.to_str() {
331 if !id_str.trim().is_empty() {
332 return Some(id_str.to_string());
333 }
334 }
335 }
336 }
337
338 None
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::ElifRequest;
    use crate::response::{ElifHeaderMap, ElifResponse};

    /// Builds a GET request for `path` carrying `headers`.
    fn get_request(path: &str, headers: ElifHeaderMap) -> ElifRequest {
        ElifRequest::new(
            crate::request::ElifMethod::GET,
            path.parse().unwrap(),
            headers,
        )
    }

    /// Builds a header map holding a single `name: value` entry.
    fn single_header(name: &str, value: &str) -> ElifHeaderMap {
        let mut headers = ElifHeaderMap::new();
        headers.insert(
            crate::response::headers::ElifHeaderName::from_str(name).unwrap(),
            value.parse().unwrap(),
        );
        headers
    }

    #[test]
    fn test_request_id_strategies() {
        // UUID v4: unique per call, canonical 36-character form.
        let uuid = RequestIdStrategy::UuidV4;
        let (first, second) = (uuid.generate(), uuid.generate());
        assert_ne!(first, second);
        assert_eq!(first.len(), 36);

        // Counter: distinct values with the "req-" prefix.
        let counter = RequestIdStrategy::Counter(AtomicU64::new(0));
        let (first, second) = (counter.generate(), counter.generate());
        assert_ne!(first, second);
        assert!(first.starts_with("req-"));
        assert!(second.starts_with("req-"));

        // Prefixed UUID: "api-" (4 chars) + 36-char UUID.
        let prefixed = RequestIdStrategy::PrefixedUuid("api".to_string()).generate();
        assert!(prefixed.starts_with("api-"));
        assert_eq!(prefixed.len(), 40);

        // Custom generator output is used verbatim.
        let custom = RequestIdStrategy::Custom(|| "custom-123".to_string()).generate();
        assert_eq!(custom, "custom-123");
    }

    #[test]
    fn test_request_id_config() {
        let config = RequestIdConfig::default();
        assert_eq!(config.header_name, "x-request-id");
        assert!(!config.override_existing);
        assert!(config.add_to_response);
        assert!(config.log_request_id);
    }

    #[tokio::test]
    async fn test_request_id_middleware_basic() {
        let middleware = RequestIdMiddleware::new();
        let request = get_request("/api/test", ElifHeaderMap::new());

        let next = Next::new(|req| {
            Box::pin(async move {
                assert!(req.request_id().is_some());
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );

        let (parts, _) = response.into_axum_response().into_parts();
        assert!(parts.headers.contains_key("x-request-id"));
    }

    #[tokio::test]
    async fn test_request_id_middleware_existing_id() {
        let middleware = RequestIdMiddleware::new();
        let request = get_request("/api/test", single_header("x-request-id", "existing-123"));

        let next = Next::new(|req| {
            Box::pin(async move {
                assert_eq!(req.request_id(), Some("existing-123".to_string()));
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;

        let (parts, _) = response.into_axum_response().into_parts();
        assert_eq!(parts.headers.get("x-request-id").unwrap(), "existing-123");
    }

    #[tokio::test]
    async fn test_request_id_middleware_override() {
        let middleware = RequestIdMiddleware::new().override_existing();
        let request = get_request("/api/test", single_header("x-request-id", "existing-123"));

        let next = Next::new(|req| {
            Box::pin(async move {
                let seen = req.request_id().unwrap();
                assert_ne!(seen, "existing-123");
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;

        let (parts, _) = response.into_axum_response().into_parts();
        let echoed = parts.headers.get("x-request-id").unwrap().to_str().unwrap();
        assert_ne!(echoed, "existing-123");
    }

    #[tokio::test]
    async fn test_request_id_custom_header() {
        let middleware = RequestIdMiddleware::new().header_name("x-trace-id");
        let request = get_request("/api/test", ElifHeaderMap::new());

        let next = Next::new(|req| {
            Box::pin(async move {
                assert!(req.header("x-trace-id").is_some());
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;

        let (parts, _) = response.into_axum_response().into_parts();
        assert!(parts.headers.contains_key("x-trace-id"));
    }

    #[tokio::test]
    async fn test_request_id_prefixed() {
        let middleware = RequestIdMiddleware::new().prefixed("api");
        let request = get_request("/api/test", ElifHeaderMap::new());

        let next = Next::new(|req| {
            Box::pin(async move {
                let seen = req.request_id().unwrap();
                assert!(seen.starts_with("api-"));
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;

        let (parts, _) = response.into_axum_response().into_parts();
        let echoed = parts.headers.get("x-request-id").unwrap().to_str().unwrap();
        assert!(echoed.starts_with("api-"));
    }

    #[tokio::test]
    async fn test_request_id_counter() {
        let middleware = RequestIdMiddleware::new().counter();
        let request = get_request("/api/test", ElifHeaderMap::new());

        let next = Next::new(|req| {
            Box::pin(async move {
                let seen = req.request_id().unwrap();
                assert!(seen.starts_with("req-"));
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    #[tokio::test]
    async fn test_request_id_no_response_header() {
        let middleware = RequestIdMiddleware::new().no_response_header();
        let request = get_request("/api/test", ElifHeaderMap::new());

        let next = Next::new(|_req| Box::pin(async move { ElifResponse::ok().text("Success") }));

        let response = middleware.handle(request, next).await;

        let (parts, _) = response.into_axum_response().into_parts();
        assert!(!parts.headers.contains_key("x-request-id"));
    }

    #[test]
    fn test_request_id_extension_trait() {
        let request = get_request("/test", single_header("x-request-id", "test-123"));
        assert_eq!(request.request_id(), Some("test-123".to_string()));

        // Only a fallback header is present; it should still be found.
        let request = get_request("/test", single_header("x-trace-id", "trace-456"));
        assert_eq!(
            request.request_id_with_fallbacks(),
            Some("trace-456".to_string())
        );
    }

    #[tokio::test]
    async fn test_builder_pattern() {
        let middleware = RequestIdMiddleware::new()
            .header_name("x-custom-id")
            .prefixed("test")
            .override_existing()
            .no_response_header()
            .no_logging();

        let config = &middleware.config;
        assert_eq!(config.header_name, "x-custom-id");
        assert!(config.override_existing);
        assert!(!config.add_to_response);
        assert!(!config.log_request_id);
        assert!(matches!(
            config.strategy,
            RequestIdStrategy::PrefixedUuid(_)
        ));
    }
}