elif_http/middleware/utils/request_id.rs

1//! # Request ID Middleware
2//!
3//! Provides unique request ID generation and tracking for distributed systems.
4//! Supports X-Request-ID header forwarding and custom ID generation strategies.
5
6use crate::middleware::v2::{Middleware, Next, NextFuture};
7use crate::request::ElifRequest;
8use crate::response::{ElifHeaderName, ElifHeaderValue, ElifResponse};
9
10use std::sync::atomic::{AtomicU64, Ordering};
11use uuid::Uuid;
12
/// Strategy used to mint a new request ID when none is reused from the request.
///
/// The default strategy is [`RequestIdStrategy::UuidV4`].
#[derive(Debug)]
pub enum RequestIdStrategy {
    /// Random UUID v4, e.g. `550e8400-e29b-41d4-a716-446655440000`.
    UuidV4,
    /// Millisecond Unix timestamp followed by a random UUID v4,
    /// i.e. `{unix_millis}-{uuid_v4}`.
    ///
    /// Despite the name, this is NOT a version-1 UUID (those need a node ID);
    /// the timestamp prefix only makes IDs roughly ordered by creation time.
    UuidV1,
    /// Incrementing counter rendered as `req-{:016x}`.
    ///
    /// The counter lives in this process only, so IDs are not unique across
    /// instances (not suitable for distributed systems). Cloning the strategy
    /// copies the current value into a new, independent counter.
    Counter(AtomicU64),
    /// Caller-supplied prefix followed by a random UUID v4 (`{prefix}-{uuid_v4}`).
    PrefixedUuid(String),
    /// Caller-supplied function, invoked once per generated ID.
    Custom(fn() -> String),
}
27
28impl Default for RequestIdStrategy {
29    fn default() -> Self {
30        Self::UuidV4
31    }
32}
33
34impl Clone for RequestIdStrategy {
35    fn clone(&self) -> Self {
36        match self {
37            Self::UuidV4 => Self::UuidV4,
38            Self::UuidV1 => Self::UuidV1,
39            Self::Counter(counter) => {
40                // Create new counter starting from current value
41                Self::Counter(AtomicU64::new(counter.load(Ordering::Relaxed)))
42            }
43            Self::PrefixedUuid(prefix) => Self::PrefixedUuid(prefix.clone()),
44            Self::Custom(func) => Self::Custom(*func),
45        }
46    }
47}
48
49impl RequestIdStrategy {
50    /// Generate a new request ID using this strategy
51    pub fn generate(&self) -> String {
52        match self {
53            Self::UuidV4 => Uuid::new_v4().to_string(),
54            Self::UuidV1 => {
55                // UUID v1 requires MAC address and timestamp
56                // For simplicity, we'll use v4 with timestamp prefix
57                let timestamp = std::time::SystemTime::now()
58                    .duration_since(std::time::UNIX_EPOCH)
59                    .unwrap()
60                    .as_millis();
61                format!("{}-{}", timestamp, Uuid::new_v4())
62            }
63            Self::Counter(counter) => {
64                let count = counter.fetch_add(1, Ordering::SeqCst);
65                format!("req-{:016x}", count)
66            }
67            Self::PrefixedUuid(prefix) => {
68                format!("{}-{}", prefix, Uuid::new_v4())
69            }
70            Self::Custom(generator) => generator(),
71        }
72    }
73}
74
75/// Configuration for request ID middleware
76#[derive(Debug)]
77pub struct RequestIdConfig {
78    /// Header name for request ID (default: "x-request-id")
79    pub header_name: String,
80    /// Request ID generation strategy
81    pub strategy: RequestIdStrategy,
82    /// Whether to generate new ID if one already exists
83    pub override_existing: bool,
84    /// Whether to add request ID to response headers
85    pub add_to_response: bool,
86    /// Whether to log request ID
87    pub log_request_id: bool,
88}
89
90impl Clone for RequestIdConfig {
91    fn clone(&self) -> Self {
92        Self {
93            header_name: self.header_name.clone(),
94            strategy: self.strategy.clone(),
95            override_existing: self.override_existing,
96            add_to_response: self.add_to_response,
97            log_request_id: self.log_request_id,
98        }
99    }
100}
101
102impl Default for RequestIdConfig {
103    fn default() -> Self {
104        Self {
105            header_name: "x-request-id".to_string(),
106            strategy: RequestIdStrategy::default(),
107            override_existing: false,
108            add_to_response: true,
109            log_request_id: true,
110        }
111    }
112}
113
114/// Middleware for request ID generation and tracking
115#[derive(Debug)]
116pub struct RequestIdMiddleware {
117    config: RequestIdConfig,
118}
119
120impl RequestIdMiddleware {
121    /// Create new request ID middleware with default configuration
122    pub fn new() -> Self {
123        Self {
124            config: RequestIdConfig::default(),
125        }
126    }
127
128    /// Create request ID middleware with custom configuration
129    pub fn with_config(config: RequestIdConfig) -> Self {
130        Self { config }
131    }
132
133    /// Set custom header name for request ID
134    pub fn header_name(mut self, name: impl Into<String>) -> Self {
135        self.config.header_name = name.into();
136        self
137    }
138
139    /// Set request ID generation strategy
140    pub fn strategy(mut self, strategy: RequestIdStrategy) -> Self {
141        self.config.strategy = strategy;
142        self
143    }
144
145    /// Use UUID v4 strategy (default)
146    pub fn uuid_v4(mut self) -> Self {
147        self.config.strategy = RequestIdStrategy::UuidV4;
148        self
149    }
150
151    /// Use UUID v1 strategy (timestamp-based)
152    pub fn uuid_v1(mut self) -> Self {
153        self.config.strategy = RequestIdStrategy::UuidV1;
154        self
155    }
156
157    /// Use counter strategy (not recommended for distributed systems)
158    pub fn counter(mut self) -> Self {
159        self.config.strategy = RequestIdStrategy::Counter(AtomicU64::new(0));
160        self
161    }
162
163    /// Use prefixed UUID strategy
164    pub fn prefixed(mut self, prefix: impl Into<String>) -> Self {
165        self.config.strategy = RequestIdStrategy::PrefixedUuid(prefix.into());
166        self
167    }
168
169    /// Use custom ID generation function
170    pub fn custom_generator(mut self, generator: fn() -> String) -> Self {
171        self.config.strategy = RequestIdStrategy::Custom(generator);
172        self
173    }
174
175    /// Override existing request ID if present
176    pub fn override_existing(mut self) -> Self {
177        self.config.override_existing = true;
178        self
179    }
180
181    /// Don't add request ID to response headers
182    pub fn no_response_header(mut self) -> Self {
183        self.config.add_to_response = false;
184        self
185    }
186
187    /// Disable request ID logging
188    pub fn no_logging(mut self) -> Self {
189        self.config.log_request_id = false;
190        self
191    }
192
193    /// Extract or generate request ID from request
194    fn get_or_generate_request_id(&self, request: &ElifRequest) -> String {
195        // Check if request already has a request ID
196        if !self.config.override_existing {
197            if let Some(existing_id) = request.header(&self.config.header_name) {
198                if let Ok(id_str) = existing_id.to_str() {
199                    if !id_str.trim().is_empty() {
200                        return id_str.to_string();
201                    }
202                }
203            }
204        }
205
206        // Generate new request ID
207        self.config.strategy.generate()
208    }
209
210    /// Add request ID to request headers
211    fn add_request_id_to_request(&self, mut request: ElifRequest, request_id: &str) -> ElifRequest {
212        let header_name = match ElifHeaderName::from_str(&self.config.header_name) {
213            Ok(name) => name,
214            Err(_) => return request, // Invalid header name, skip
215        };
216
217        let header_value = match ElifHeaderValue::from_str(request_id) {
218            Ok(value) => value,
219            Err(_) => return request, // Invalid header value, skip
220        };
221
222        request.headers.insert(header_name, header_value);
223        request
224    }
225
226    /// Add request ID to response headers
227    fn add_request_id_to_response(&self, response: ElifResponse, request_id: &str) -> ElifResponse {
228        if !self.config.add_to_response {
229            return response;
230        }
231
232        let header_name = match self.config.header_name.as_str() {
233            "x-request-id" => "x-request-id",
234            "request-id" => "request-id",
235            "x-trace-id" => "x-trace-id",
236            _ => &self.config.header_name,
237        };
238
239        response
240            .header(header_name, request_id)
241            .unwrap_or_else(|_| {
242                // If we can't add the header for some reason, return a new response with error
243                ElifResponse::internal_server_error().json_value(serde_json::json!({
244                    "error": {
245                        "code": "internal_error",
246                        "message": "Failed to add request ID to response"
247                    }
248                }))
249            })
250    }
251
252    /// Log request ID if enabled
253    fn log_request_id(&self, request_id: &str, method: &axum::http::Method, path: &str) {
254        if self.config.log_request_id {
255            tracing::info!(
256                request_id = request_id,
257                method = %method,
258                path = path,
259                "Request started"
260            );
261        }
262    }
263}
264
265impl Default for RequestIdMiddleware {
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
271impl Middleware for RequestIdMiddleware {
272    fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
273        // Generate or extract request ID
274        let request_id = self.get_or_generate_request_id(&request);
275        let method = request.method.clone();
276        let path = request.path().to_string();
277
278        // Log request ID
279        self.log_request_id(&request_id, method.to_axum(), &path);
280
281        // Add request ID to request headers
282        let updated_request = self.add_request_id_to_request(request, &request_id);
283
284        let config = self.config.clone();
285        let request_id_clone = request_id.clone();
286
287        Box::pin(async move {
288            // Execute next middleware/handler
289            let response = next.run(updated_request).await;
290
291            // Add request ID to response headers
292            let middleware = RequestIdMiddleware { config };
293            middleware.add_request_id_to_response(response, &request_id_clone)
294        })
295    }
296
297    fn name(&self) -> &'static str {
298        "RequestIdMiddleware"
299    }
300}
301
/// Convenience accessors for reading a request ID back off a request.
pub trait RequestIdExt {
    /// Return the request ID carried in the request headers, if any.
    ///
    /// The `ElifRequest` implementation reads the fixed `x-request-id` header;
    /// it does not know about a custom middleware header name.
    fn request_id(&self) -> Option<String>;

    /// Like [`RequestIdExt::request_id`], but also consults other common
    /// correlation headers (e.g. `x-trace-id`, `x-correlation-id`) in order.
    fn request_id_with_fallbacks(&self) -> Option<String>;
}
310
311impl RequestIdExt for ElifRequest {
312    fn request_id(&self) -> Option<String> {
313        self.header("x-request-id")
314            .and_then(|h| h.to_str().ok())
315            .map(|s| s.to_string())
316    }
317
318    fn request_id_with_fallbacks(&self) -> Option<String> {
319        // Try common request ID header names
320        let header_names = [
321            "x-request-id",
322            "request-id",
323            "x-trace-id",
324            "x-correlation-id",
325            "x-session-id",
326        ];
327
328        for header_name in &header_names {
329            if let Some(value) = self.header(header_name) {
330                if let Ok(id_str) = value.to_str() {
331                    if !id_str.trim().is_empty() {
332                        return Some(id_str.to_string());
333                    }
334                }
335            }
336        }
337
338        None
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::ElifRequest;
    use crate::response::{ElifHeaderMap, ElifResponse};

    /// Each built-in strategy produces IDs of the expected shape.
    #[test]
    fn test_request_id_strategies() {
        // UUID v4
        let uuid_strategy = RequestIdStrategy::UuidV4;
        let id1 = uuid_strategy.generate();
        let id2 = uuid_strategy.generate();
        assert_ne!(id1, id2);
        assert_eq!(id1.len(), 36); // Standard UUID length

        // Counter
        let counter_strategy = RequestIdStrategy::Counter(AtomicU64::new(0));
        let id1 = counter_strategy.generate();
        let id2 = counter_strategy.generate();
        assert_ne!(id1, id2);
        assert!(id1.starts_with("req-"));
        assert!(id2.starts_with("req-"));

        // Prefixed UUID
        let prefixed_strategy = RequestIdStrategy::PrefixedUuid("api".to_string());
        let id = prefixed_strategy.generate();
        assert!(id.starts_with("api-"));
        assert_eq!(id.len(), 40); // "api-" + 36-char UUID

        // Custom
        let custom_strategy = RequestIdStrategy::Custom(|| "custom-123".to_string());
        let id = custom_strategy.generate();
        assert_eq!(id, "custom-123");
    }

    /// "UuidV1" is `{unix_millis}-{uuid_v4}`, not an RFC version-1 UUID.
    #[test]
    fn test_uuid_v1_strategy_format() {
        let id = RequestIdStrategy::UuidV1.generate();
        let (millis, uuid) = id
            .split_once('-')
            .expect("UuidV1 IDs start with a timestamp and a '-' separator");
        assert!(millis.parse::<u128>().is_ok());
        assert_eq!(uuid.len(), 36);
    }

    /// Cloning a counter strategy snapshots the value into an independent counter.
    #[test]
    fn test_counter_clone_is_independent() {
        let original = RequestIdStrategy::Counter(AtomicU64::new(5));
        let copy = original.clone();
        assert_eq!(original.generate(), "req-0000000000000005");
        // The clone does not observe the increment made on the original.
        assert_eq!(copy.generate(), "req-0000000000000005");
        assert_eq!(original.generate(), "req-0000000000000006");
    }

    /// Default configuration values.
    #[test]
    fn test_request_id_config() {
        let config = RequestIdConfig::default();
        assert_eq!(config.header_name, "x-request-id");
        assert!(!config.override_existing);
        assert!(config.add_to_response);
        assert!(config.log_request_id);
    }

    /// A request without an ID gets one on both the request and the response.
    #[tokio::test]
    async fn test_request_id_middleware_basic() {
        let middleware = RequestIdMiddleware::new();

        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/test".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next = Next::new(|req| {
            Box::pin(async move {
                // Verify request has request ID
                assert!(req.request_id().is_some());
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );

        // Check response has request ID header
        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();
        assert!(parts.headers.contains_key("x-request-id"));
    }

    /// An incoming ID is preserved when overriding is disabled.
    #[tokio::test]
    async fn test_request_id_middleware_existing_id() {
        let middleware = RequestIdMiddleware::new();

        let mut headers = crate::response::headers::ElifHeaderMap::new();
        headers.insert(
            crate::response::headers::ElifHeaderName::from_str("x-request-id").unwrap(),
            "existing-123".parse().unwrap(),
        );
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/test".parse().unwrap(),
            headers,
        );

        let next = Next::new(|req| {
            Box::pin(async move {
                // Should preserve existing request ID
                assert_eq!(req.request_id(), Some("existing-123".to_string()));
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;

        // Response should have the same request ID
        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();
        assert_eq!(parts.headers.get("x-request-id").unwrap(), "existing-123");
    }

    /// With `override_existing`, an incoming ID is replaced by a fresh one.
    #[tokio::test]
    async fn test_request_id_middleware_override() {
        let middleware = RequestIdMiddleware::new().override_existing();

        let mut headers = ElifHeaderMap::new();
        headers.insert(
            crate::response::headers::ElifHeaderName::from_str("x-request-id").unwrap(),
            "existing-123".parse().unwrap(),
        );
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/test".parse().unwrap(),
            headers,
        );

        let next = Next::new(|req| {
            Box::pin(async move {
                // Should have new request ID, not the existing one
                let request_id = req.request_id().unwrap();
                assert_ne!(request_id, "existing-123");
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;

        // Response should have new request ID
        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();
        let response_id = parts.headers.get("x-request-id").unwrap().to_str().unwrap();
        assert_ne!(response_id, "existing-123");
    }

    /// A custom header name is used on both the request and the response.
    #[tokio::test]
    async fn test_request_id_custom_header() {
        let middleware = RequestIdMiddleware::new().header_name("x-trace-id");

        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/test".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next = Next::new(|req| {
            Box::pin(async move {
                // Check custom header name
                assert!(req.header("x-trace-id").is_some());
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;

        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();
        assert!(parts.headers.contains_key("x-trace-id"));
    }

    /// The prefixed strategy's IDs appear on both the request and the response.
    #[tokio::test]
    async fn test_request_id_prefixed() {
        let middleware = RequestIdMiddleware::new().prefixed("api");

        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/test".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next = Next::new(|req| {
            Box::pin(async move {
                let request_id = req.request_id().unwrap();
                assert!(request_id.starts_with("api-"));
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;

        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();
        let response_id = parts.headers.get("x-request-id").unwrap().to_str().unwrap();
        assert!(response_id.starts_with("api-"));
    }

    /// The counter strategy works end-to-end through the middleware.
    #[tokio::test]
    async fn test_request_id_counter() {
        let middleware = RequestIdMiddleware::new().counter();

        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/test".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next = Next::new(|req| {
            Box::pin(async move {
                let request_id = req.request_id().unwrap();
                assert!(request_id.starts_with("req-"));
                ElifResponse::ok().text("Success")
            })
        });

        let response = middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    /// `no_response_header` keeps the ID off the response.
    #[tokio::test]
    async fn test_request_id_no_response_header() {
        let middleware = RequestIdMiddleware::new().no_response_header();

        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/api/test".parse().unwrap(),
            ElifHeaderMap::new(),
        );

        let next = Next::new(|_req| Box::pin(async move { ElifResponse::ok().text("Success") }));

        let response = middleware.handle(request, next).await;

        let axum_response = response.into_axum_response();
        let (parts, _) = axum_response.into_parts();
        assert!(!parts.headers.contains_key("x-request-id"));
    }

    /// The extension trait reads `x-request-id` and falls back to other headers.
    #[test]
    fn test_request_id_extension_trait() {
        let mut headers = ElifHeaderMap::new();
        headers.insert(
            crate::response::headers::ElifHeaderName::from_str("x-request-id").unwrap(),
            "test-123".parse().unwrap(),
        );
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/test".parse().unwrap(),
            headers,
        );

        assert_eq!(request.request_id(), Some("test-123".to_string()));

        // Test with fallbacks
        let mut headers = ElifHeaderMap::new();
        headers.insert(
            crate::response::headers::ElifHeaderName::from_str("x-trace-id").unwrap(),
            "trace-456".parse().unwrap(),
        );
        let request = ElifRequest::new(
            crate::request::ElifMethod::GET,
            "/test".parse().unwrap(),
            headers,
        );

        assert_eq!(
            request.request_id_with_fallbacks(),
            Some("trace-456".to_string())
        );
    }

    /// Builder methods set the corresponding config fields.
    #[tokio::test]
    async fn test_builder_pattern() {
        let middleware = RequestIdMiddleware::new()
            .header_name("x-custom-id")
            .prefixed("test")
            .override_existing()
            .no_response_header()
            .no_logging();

        assert_eq!(middleware.config.header_name, "x-custom-id");
        assert!(middleware.config.override_existing);
        assert!(!middleware.config.add_to_response);
        assert!(!middleware.config.log_request_id);
        assert!(matches!(
            middleware.config.strategy,
            RequestIdStrategy::PrefixedUuid(_)
        ));
    }
}