pulseengine_mcp_logging/
correlation.rs

1//! Request correlation and distributed tracing for MCP servers
2//!
3//! This module provides:
4//! - Request correlation IDs
5//! - Distributed trace propagation
6//! - Request context tracking
7//! - Cross-service correlation
8
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, info, warn};
15use uuid::Uuid;
16
17/// Request correlation context
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CorrelationContext {
20    /// Primary correlation ID for the entire request chain
21    pub correlation_id: String,
22
23    /// Request ID for this specific request
24    pub request_id: String,
25
26    /// Parent request ID (if this is a sub-request)
27    pub parent_request_id: Option<String>,
28
29    /// Trace ID for OpenTelemetry compatibility
30    pub trace_id: Option<String>,
31
32    /// Span ID for OpenTelemetry compatibility
33    pub span_id: Option<String>,
34
35    /// User ID associated with the request
36    pub user_id: Option<String>,
37
38    /// Session ID
39    pub session_id: Option<String>,
40
41    /// Service name that initiated the request
42    pub originating_service: String,
43
44    /// Current service processing the request
45    pub current_service: String,
46
47    /// Request start time
48    pub start_time: DateTime<Utc>,
49
50    /// Request path/breadcrumb
51    pub request_path: Vec<String>,
52
53    /// Custom context fields
54    pub custom_fields: HashMap<String, String>,
55}
56
57/// Request tracking entry
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct RequestTraceEntry {
60    /// Request context
61    pub context: CorrelationContext,
62
63    /// Request details
64    pub method: String,
65    pub params: serde_json::Value,
66    pub response: Option<serde_json::Value>,
67    pub error: Option<String>,
68
69    /// Timing information
70    pub duration_ms: Option<u64>,
71    pub end_time: Option<DateTime<Utc>>,
72
73    /// Resource usage
74    pub memory_used_bytes: Option<u64>,
75    pub cpu_time_ms: Option<u64>,
76}
77
78/// Correlation manager
79pub struct CorrelationManager {
80    /// Active requests being tracked
81    active_requests: Arc<RwLock<HashMap<String, RequestTraceEntry>>>,
82
83    /// Completed request history (limited size)
84    completed_requests: Arc<RwLock<HashMap<String, RequestTraceEntry>>>,
85
86    /// Configuration
87    config: CorrelationConfig,
88}
89
90/// Configuration for correlation tracking
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct CorrelationConfig {
93    /// Enable correlation tracking
94    pub enabled: bool,
95
96    /// Maximum number of active requests to track
97    pub max_active_requests: usize,
98
99    /// Maximum number of completed requests to keep in history
100    pub max_completed_requests: usize,
101
102    /// Request timeout for cleanup (in seconds)
103    pub request_timeout_secs: u64,
104
105    /// Enable detailed resource tracking
106    pub track_resources: bool,
107
108    /// Enable cross-service correlation
109    pub cross_service_enabled: bool,
110
111    /// Header names for correlation propagation
112    pub correlation_headers: CorrelationHeaders,
113}
114
115/// HTTP headers used for correlation propagation
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct CorrelationHeaders {
118    /// Correlation ID header
119    pub correlation_id: String,
120
121    /// Request ID header
122    pub request_id: String,
123
124    /// Parent request ID header
125    pub parent_request_id: String,
126
127    /// Trace ID header (OpenTelemetry)
128    pub trace_id: String,
129
130    /// Span ID header (OpenTelemetry)
131    pub span_id: String,
132
133    /// User ID header
134    pub user_id: String,
135
136    /// Session ID header
137    pub session_id: String,
138}
139
140impl CorrelationManager {
141    /// Create a new correlation manager
142    pub fn new(config: CorrelationConfig) -> Self {
143        Self {
144            active_requests: Arc::new(RwLock::new(HashMap::new())),
145            completed_requests: Arc::new(RwLock::new(HashMap::new())),
146            config,
147        }
148    }
149
150    /// Start correlation tracking services
151    pub async fn start(&self) {
152        if !self.config.enabled {
153            info!("Correlation tracking is disabled");
154            return;
155        }
156
157        info!("Starting correlation tracking");
158
159        // Start cleanup task
160        let active_requests = self.active_requests.clone();
161        let completed_requests = self.completed_requests.clone();
162        let config = self.config.clone();
163
164        tokio::spawn(async move {
165            Self::cleanup_expired_requests(active_requests, completed_requests, config).await;
166        });
167    }
168
169    /// Create a new correlation context
170    pub fn create_context(
171        &self,
172        service_name: &str,
173        parent_context: Option<&CorrelationContext>,
174    ) -> CorrelationContext {
175        let correlation_id = if let Some(parent) = parent_context {
176            parent.correlation_id.clone()
177        } else {
178            Uuid::new_v4().to_string()
179        };
180
181        let request_id = Uuid::new_v4().to_string();
182        let parent_request_id = parent_context.map(|ctx| ctx.request_id.clone());
183
184        let mut request_path = parent_context
185            .map(|ctx| ctx.request_path.clone())
186            .unwrap_or_default();
187        request_path.push(service_name.to_string());
188
189        CorrelationContext {
190            correlation_id,
191            request_id,
192            parent_request_id,
193            trace_id: parent_context.and_then(|ctx| ctx.trace_id.clone()),
194            span_id: parent_context.and_then(|ctx| ctx.span_id.clone()),
195            user_id: parent_context.and_then(|ctx| ctx.user_id.clone()),
196            session_id: parent_context.and_then(|ctx| ctx.session_id.clone()),
197            originating_service: parent_context
198                .map(|ctx| ctx.originating_service.clone())
199                .unwrap_or_else(|| service_name.to_string()),
200            current_service: service_name.to_string(),
201            start_time: Utc::now(),
202            request_path,
203            custom_fields: HashMap::new(),
204        }
205    }
206
207    /// Extract correlation context from HTTP headers
208    pub fn extract_from_headers(
209        &self,
210        headers: &HashMap<String, String>,
211    ) -> Option<CorrelationContext> {
212        let correlation_id = headers.get(&self.config.correlation_headers.correlation_id)?;
213        let parent_request_id = headers.get(&self.config.correlation_headers.request_id);
214
215        Some(CorrelationContext {
216            correlation_id: correlation_id.clone(),
217            request_id: Uuid::new_v4().to_string(),
218            parent_request_id: parent_request_id.cloned(),
219            trace_id: headers
220                .get(&self.config.correlation_headers.trace_id)
221                .cloned(),
222            span_id: headers
223                .get(&self.config.correlation_headers.span_id)
224                .cloned(),
225            user_id: headers
226                .get(&self.config.correlation_headers.user_id)
227                .cloned(),
228            session_id: headers
229                .get(&self.config.correlation_headers.session_id)
230                .cloned(),
231            originating_service: "unknown".to_string(),
232            current_service: "current".to_string(),
233            start_time: Utc::now(),
234            request_path: vec![],
235            custom_fields: HashMap::new(),
236        })
237    }
238
239    /// Inject correlation context into HTTP headers
240    pub fn inject_into_headers(
241        &self,
242        context: &CorrelationContext,
243        headers: &mut HashMap<String, String>,
244    ) {
245        headers.insert(
246            self.config.correlation_headers.correlation_id.clone(),
247            context.correlation_id.clone(),
248        );
249        headers.insert(
250            self.config.correlation_headers.request_id.clone(),
251            context.request_id.clone(),
252        );
253
254        if let Some(parent_id) = &context.parent_request_id {
255            headers.insert(
256                self.config.correlation_headers.parent_request_id.clone(),
257                parent_id.clone(),
258            );
259        }
260
261        if let Some(trace_id) = &context.trace_id {
262            headers.insert(
263                self.config.correlation_headers.trace_id.clone(),
264                trace_id.clone(),
265            );
266        }
267
268        if let Some(span_id) = &context.span_id {
269            headers.insert(
270                self.config.correlation_headers.span_id.clone(),
271                span_id.clone(),
272            );
273        }
274
275        if let Some(user_id) = &context.user_id {
276            headers.insert(
277                self.config.correlation_headers.user_id.clone(),
278                user_id.clone(),
279            );
280        }
281
282        if let Some(session_id) = &context.session_id {
283            headers.insert(
284                self.config.correlation_headers.session_id.clone(),
285                session_id.clone(),
286            );
287        }
288    }
289
290    /// Start tracking a request
291    pub async fn start_request_tracking(
292        &self,
293        context: CorrelationContext,
294        method: &str,
295        params: serde_json::Value,
296    ) -> Result<(), CorrelationError> {
297        if !self.config.enabled {
298            return Ok(());
299        }
300
301        let entry = RequestTraceEntry {
302            context: context.clone(),
303            method: method.to_string(),
304            params,
305            response: None,
306            error: None,
307            duration_ms: None,
308            end_time: None,
309            memory_used_bytes: None,
310            cpu_time_ms: None,
311        };
312
313        let mut active = self.active_requests.write().await;
314
315        // Check if we're at capacity
316        if active.len() >= self.config.max_active_requests {
317            warn!("Active request tracking at capacity, dropping oldest request");
318            if let Some(oldest_key) = active.keys().next().cloned() {
319                active.remove(&oldest_key);
320            }
321        }
322
323        active.insert(context.request_id.clone(), entry);
324        debug!("Started tracking request: {}", context.request_id);
325
326        Ok(())
327    }
328
329    /// Complete request tracking
330    pub async fn complete_request_tracking(
331        &self,
332        request_id: &str,
333        response: Option<serde_json::Value>,
334        error: Option<String>,
335    ) -> Result<(), CorrelationError> {
336        if !self.config.enabled {
337            return Ok(());
338        }
339
340        let mut active = self.active_requests.write().await;
341
342        if let Some(mut entry) = active.remove(request_id) {
343            let end_time = Utc::now();
344            let duration_ms = (end_time - entry.context.start_time).num_milliseconds() as u64;
345
346            entry.response = response;
347            entry.error = error;
348            entry.duration_ms = Some(duration_ms);
349            entry.end_time = Some(end_time);
350
351            // Add to completed requests
352            let mut completed = self.completed_requests.write().await;
353            if completed.len() >= self.config.max_completed_requests {
354                // Remove oldest completed request
355                if let Some(oldest_key) = completed.keys().next().cloned() {
356                    completed.remove(&oldest_key);
357                }
358            }
359            completed.insert(request_id.to_string(), entry);
360
361            debug!("Completed tracking request: {}", request_id);
362        }
363
364        Ok(())
365    }
366
367    /// Get request trace by ID
368    pub async fn get_request_trace(&self, request_id: &str) -> Option<RequestTraceEntry> {
369        // Check active requests first
370        {
371            let active = self.active_requests.read().await;
372            if let Some(entry) = active.get(request_id) {
373                return Some(entry.clone());
374            }
375        }
376
377        // Check completed requests
378        let completed = self.completed_requests.read().await;
379        completed.get(request_id).cloned()
380    }
381
382    /// Get all traces for a correlation ID
383    pub async fn get_correlation_traces(&self, correlation_id: &str) -> Vec<RequestTraceEntry> {
384        let mut traces = Vec::new();
385
386        // Check active requests
387        {
388            let active = self.active_requests.read().await;
389            for entry in active.values() {
390                if entry.context.correlation_id == correlation_id {
391                    traces.push(entry.clone());
392                }
393            }
394        }
395
396        // Check completed requests
397        {
398            let completed = self.completed_requests.read().await;
399            for entry in completed.values() {
400                if entry.context.correlation_id == correlation_id {
401                    traces.push(entry.clone());
402                }
403            }
404        }
405
406        traces.sort_by(|a, b| a.context.start_time.cmp(&b.context.start_time));
407        traces
408    }
409
410    /// Get statistics about correlation tracking
411    pub async fn get_stats(&self) -> CorrelationStats {
412        let active = self.active_requests.read().await;
413        let completed = self.completed_requests.read().await;
414
415        CorrelationStats {
416            active_requests: active.len(),
417            completed_requests: completed.len(),
418            unique_correlations: {
419                let mut correlations = std::collections::HashSet::new();
420                for entry in active.values() {
421                    correlations.insert(&entry.context.correlation_id);
422                }
423                for entry in completed.values() {
424                    correlations.insert(&entry.context.correlation_id);
425                }
426                correlations.len()
427            },
428        }
429    }
430
431    /// Cleanup expired requests
432    async fn cleanup_expired_requests(
433        active_requests: Arc<RwLock<HashMap<String, RequestTraceEntry>>>,
434        completed_requests: Arc<RwLock<HashMap<String, RequestTraceEntry>>>,
435        config: CorrelationConfig,
436    ) {
437        let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
438
439        loop {
440            interval.tick().await;
441
442            let cutoff = Utc::now() - chrono::Duration::seconds(config.request_timeout_secs as i64);
443
444            // Cleanup active requests
445            {
446                let mut active = active_requests.write().await;
447                let expired_keys: Vec<_> = active
448                    .iter()
449                    .filter(|(_, entry)| entry.context.start_time < cutoff)
450                    .map(|(key, _)| key.clone())
451                    .collect();
452
453                for key in expired_keys {
454                    if let Some(entry) = active.remove(&key) {
455                        warn!("Request {} expired without completion", key);
456
457                        // Move to completed with error
458                        let mut completed_entry = entry;
459                        completed_entry.error = Some("Request expired".to_string());
460                        completed_entry.end_time = Some(Utc::now());
461
462                        let mut completed = completed_requests.write().await;
463                        if completed.len() >= config.max_completed_requests {
464                            if let Some(oldest_key) = completed.keys().next().cloned() {
465                                completed.remove(&oldest_key);
466                            }
467                        }
468                        completed.insert(key, completed_entry);
469                    }
470                }
471            }
472
473            // Cleanup old completed requests
474            {
475                let mut completed = completed_requests.write().await;
476                let old_cutoff = Utc::now() - chrono::Duration::hours(24); // Keep for 24 hours
477
478                let expired_keys: Vec<_> = completed
479                    .iter()
480                    .filter(|(_, entry)| {
481                        entry.end_time.unwrap_or(entry.context.start_time) < old_cutoff
482                    })
483                    .map(|(key, _)| key.clone())
484                    .collect();
485
486                for key in expired_keys {
487                    completed.remove(&key);
488                }
489            }
490        }
491    }
492}
493
494/// Correlation statistics
495#[derive(Debug, Serialize, Deserialize)]
496pub struct CorrelationStats {
497    pub active_requests: usize,
498    pub completed_requests: usize,
499    pub unique_correlations: usize,
500}
501
502/// Correlation errors
503#[derive(Debug, thiserror::Error)]
504pub enum CorrelationError {
505    #[error("Correlation tracking is disabled")]
506    Disabled,
507
508    #[error("Request not found: {0}")]
509    RequestNotFound(String),
510
511    #[error("Capacity exceeded")]
512    CapacityExceeded,
513}
514
515impl Default for CorrelationConfig {
516    fn default() -> Self {
517        Self {
518            enabled: true,
519            max_active_requests: 10000,
520            max_completed_requests: 50000,
521            request_timeout_secs: 300, // 5 minutes
522            track_resources: true,
523            cross_service_enabled: true,
524            correlation_headers: CorrelationHeaders::default(),
525        }
526    }
527}
528
529impl Default for CorrelationHeaders {
530    fn default() -> Self {
531        Self {
532            correlation_id: "X-Correlation-ID".to_string(),
533            request_id: "X-Request-ID".to_string(),
534            parent_request_id: "X-Parent-Request-ID".to_string(),
535            trace_id: "X-Trace-ID".to_string(),
536            span_id: "X-Span-ID".to_string(),
537            user_id: "X-User-ID".to_string(),
538            session_id: "X-Session-ID".to_string(),
539        }
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// A root context starts a new chain rooted at the given service.
    #[tokio::test]
    async fn test_correlation_context_creation() {
        let config = CorrelationConfig::default();
        let manager = CorrelationManager::new(config);

        let context = manager.create_context("test-service", None);

        assert!(!context.correlation_id.is_empty());
        assert!(!context.request_id.is_empty());
        assert!(context.parent_request_id.is_none());
        assert_eq!(context.originating_service, "test-service");
        assert_eq!(context.current_service, "test-service");
        assert_eq!(context.request_path, vec!["test-service"]);
    }

    /// A child context stays in the parent's chain and inherits its identity fields.
    #[test]
    fn test_child_context_inherits_parent_chain() {
        let manager = CorrelationManager::new(CorrelationConfig::default());

        let mut parent = manager.create_context("gateway", None);
        parent.trace_id = Some("trace-1".to_string());
        parent.user_id = Some("user-1".to_string());
        parent.session_id = Some("session-1".to_string());

        let child = manager.create_context("worker", Some(&parent));

        assert_eq!(child.correlation_id, parent.correlation_id);
        assert_ne!(child.request_id, parent.request_id);
        assert_eq!(
            child.parent_request_id.as_deref(),
            Some(parent.request_id.as_str())
        );
        assert_eq!(child.trace_id.as_deref(), Some("trace-1"));
        assert_eq!(child.user_id.as_deref(), Some("user-1"));
        assert_eq!(child.session_id.as_deref(), Some("session-1"));
        assert_eq!(child.originating_service, "gateway");
        assert_eq!(child.current_service, "worker");
        assert_eq!(child.request_path, vec!["gateway", "worker"]);
    }

    /// A tracked request is visible while active and after completion.
    #[tokio::test]
    async fn test_request_tracking() {
        let config = CorrelationConfig::default();
        let manager = CorrelationManager::new(config);

        let context = manager.create_context("test-service", None);
        let request_id = context.request_id.clone();

        manager
            .start_request_tracking(
                context,
                "test_method",
                serde_json::json!({"param": "value"}),
            )
            .await
            .unwrap();

        let trace = manager.get_request_trace(&request_id).await;
        assert!(trace.is_some());
        assert_eq!(trace.unwrap().method, "test_method");

        manager
            .complete_request_tracking(
                &request_id,
                Some(serde_json::json!({"result": "success"})),
                None,
            )
            .await
            .unwrap();

        let trace = manager.get_request_trace(&request_id).await;
        assert!(trace.is_some());
        let trace = trace.unwrap();
        assert!(trace.response.is_some());
        assert!(trace.error.is_none());
        assert!(trace.duration_ms.is_some());
        assert!(trace.end_time.is_some());
    }

    /// Completing an ID that was never tracked succeeds and records nothing.
    #[tokio::test]
    async fn test_complete_unknown_request_is_noop() {
        let manager = CorrelationManager::new(CorrelationConfig::default());

        manager
            .complete_request_tracking("missing", None, None)
            .await
            .unwrap();

        assert!(manager.get_request_trace("missing").await.is_none());
        assert_eq!(manager.get_stats().await.completed_requests, 0);
    }

    /// With tracking disabled, starting a request records nothing.
    #[tokio::test]
    async fn test_disabled_tracking_records_nothing() {
        let config = CorrelationConfig {
            enabled: false,
            ..CorrelationConfig::default()
        };
        let manager = CorrelationManager::new(config);

        let context = manager.create_context("test-service", None);
        let request_id = context.request_id.clone();

        manager
            .start_request_tracking(context, "test_method", serde_json::Value::Null)
            .await
            .unwrap();

        assert!(manager.get_request_trace(&request_id).await.is_none());
        assert_eq!(manager.get_stats().await.active_requests, 0);
    }

    /// Correlation lookups and stats cover both the active and completed tables.
    #[tokio::test]
    async fn test_correlation_traces_and_stats() {
        let manager = CorrelationManager::new(CorrelationConfig::default());

        let root = manager.create_context("gateway", None);
        let child = manager.create_context("worker", Some(&root));
        let unrelated = manager.create_context("other", None);

        let correlation_id = root.correlation_id.clone();
        let root_id = root.request_id.clone();
        let child_id = child.request_id.clone();

        manager
            .start_request_tracking(root, "root", serde_json::Value::Null)
            .await
            .unwrap();
        manager
            .start_request_tracking(child, "child", serde_json::Value::Null)
            .await
            .unwrap();
        manager
            .start_request_tracking(unrelated, "other", serde_json::Value::Null)
            .await
            .unwrap();
        manager
            .complete_request_tracking(&child_id, None, None)
            .await
            .unwrap();

        let traces = manager.get_correlation_traces(&correlation_id).await;
        let ids: Vec<&str> = traces
            .iter()
            .map(|trace| trace.context.request_id.as_str())
            .collect();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&root_id.as_str()));
        assert!(ids.contains(&child_id.as_str()));

        let stats = manager.get_stats().await;
        assert_eq!(stats.active_requests, 2);
        assert_eq!(stats.completed_requests, 1);
        assert_eq!(stats.unique_correlations, 2);
    }

    /// Injected headers round-trip: the correlation ID is kept and the
    /// sender's request ID becomes the receiver's parent.
    #[test]
    fn test_header_injection_extraction() {
        let config = CorrelationConfig::default();
        let manager = CorrelationManager::new(config);

        let context = manager.create_context("test-service", None);
        let mut headers = HashMap::new();

        manager.inject_into_headers(&context, &mut headers);

        assert!(headers.contains_key("X-Correlation-ID"));
        assert!(headers.contains_key("X-Request-ID"));

        let extracted = manager.extract_from_headers(&headers);
        assert!(extracted.is_some());

        let extracted = extracted.unwrap();
        assert_eq!(extracted.correlation_id, context.correlation_id);
        assert_ne!(extracted.request_id, context.request_id);
        assert_eq!(
            extracted.parent_request_id.as_deref(),
            Some(context.request_id.as_str())
        );
    }
}
622        assert_eq!(extracted.correlation_id, context.correlation_id);
623    }
624}