1use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, info, warn};
15use uuid::Uuid;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CorrelationContext {
20 pub correlation_id: String,
22
23 pub request_id: String,
25
26 pub parent_request_id: Option<String>,
28
29 pub trace_id: Option<String>,
31
32 pub span_id: Option<String>,
34
35 pub user_id: Option<String>,
37
38 pub session_id: Option<String>,
40
41 pub originating_service: String,
43
44 pub current_service: String,
46
47 pub start_time: DateTime<Utc>,
49
50 pub request_path: Vec<String>,
52
53 pub custom_fields: HashMap<String, String>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct RequestTraceEntry {
60 pub context: CorrelationContext,
62
63 pub method: String,
65 pub params: serde_json::Value,
66 pub response: Option<serde_json::Value>,
67 pub error: Option<String>,
68
69 pub duration_ms: Option<u64>,
71 pub end_time: Option<DateTime<Utc>>,
72
73 pub memory_used_bytes: Option<u64>,
75 pub cpu_time_ms: Option<u64>,
76}
77
78pub struct CorrelationManager {
80 active_requests: Arc<RwLock<HashMap<String, RequestTraceEntry>>>,
82
83 completed_requests: Arc<RwLock<HashMap<String, RequestTraceEntry>>>,
85
86 config: CorrelationConfig,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct CorrelationConfig {
93 pub enabled: bool,
95
96 pub max_active_requests: usize,
98
99 pub max_completed_requests: usize,
101
102 pub request_timeout_secs: u64,
104
105 pub track_resources: bool,
107
108 pub cross_service_enabled: bool,
110
111 pub correlation_headers: CorrelationHeaders,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct CorrelationHeaders {
118 pub correlation_id: String,
120
121 pub request_id: String,
123
124 pub parent_request_id: String,
126
127 pub trace_id: String,
129
130 pub span_id: String,
132
133 pub user_id: String,
135
136 pub session_id: String,
138}
139
140impl CorrelationManager {
141 pub fn new(config: CorrelationConfig) -> Self {
143 Self {
144 active_requests: Arc::new(RwLock::new(HashMap::new())),
145 completed_requests: Arc::new(RwLock::new(HashMap::new())),
146 config,
147 }
148 }
149
150 pub async fn start(&self) {
152 if !self.config.enabled {
153 info!("Correlation tracking is disabled");
154 return;
155 }
156
157 info!("Starting correlation tracking");
158
159 let active_requests = self.active_requests.clone();
161 let completed_requests = self.completed_requests.clone();
162 let config = self.config.clone();
163
164 tokio::spawn(async move {
165 Self::cleanup_expired_requests(active_requests, completed_requests, config).await;
166 });
167 }
168
169 pub fn create_context(
171 &self,
172 service_name: &str,
173 parent_context: Option<&CorrelationContext>,
174 ) -> CorrelationContext {
175 let correlation_id = if let Some(parent) = parent_context {
176 parent.correlation_id.clone()
177 } else {
178 Uuid::new_v4().to_string()
179 };
180
181 let request_id = Uuid::new_v4().to_string();
182 let parent_request_id = parent_context.map(|ctx| ctx.request_id.clone());
183
184 let mut request_path = parent_context
185 .map(|ctx| ctx.request_path.clone())
186 .unwrap_or_default();
187 request_path.push(service_name.to_string());
188
189 CorrelationContext {
190 correlation_id,
191 request_id,
192 parent_request_id,
193 trace_id: parent_context.and_then(|ctx| ctx.trace_id.clone()),
194 span_id: parent_context.and_then(|ctx| ctx.span_id.clone()),
195 user_id: parent_context.and_then(|ctx| ctx.user_id.clone()),
196 session_id: parent_context.and_then(|ctx| ctx.session_id.clone()),
197 originating_service: parent_context
198 .map(|ctx| ctx.originating_service.clone())
199 .unwrap_or_else(|| service_name.to_string()),
200 current_service: service_name.to_string(),
201 start_time: Utc::now(),
202 request_path,
203 custom_fields: HashMap::new(),
204 }
205 }
206
207 pub fn extract_from_headers(
209 &self,
210 headers: &HashMap<String, String>,
211 ) -> Option<CorrelationContext> {
212 let correlation_id = headers.get(&self.config.correlation_headers.correlation_id)?;
213 let parent_request_id = headers.get(&self.config.correlation_headers.request_id);
214
215 Some(CorrelationContext {
216 correlation_id: correlation_id.clone(),
217 request_id: Uuid::new_v4().to_string(),
218 parent_request_id: parent_request_id.cloned(),
219 trace_id: headers
220 .get(&self.config.correlation_headers.trace_id)
221 .cloned(),
222 span_id: headers
223 .get(&self.config.correlation_headers.span_id)
224 .cloned(),
225 user_id: headers
226 .get(&self.config.correlation_headers.user_id)
227 .cloned(),
228 session_id: headers
229 .get(&self.config.correlation_headers.session_id)
230 .cloned(),
231 originating_service: "unknown".to_string(),
232 current_service: "current".to_string(),
233 start_time: Utc::now(),
234 request_path: vec![],
235 custom_fields: HashMap::new(),
236 })
237 }
238
239 pub fn inject_into_headers(
241 &self,
242 context: &CorrelationContext,
243 headers: &mut HashMap<String, String>,
244 ) {
245 headers.insert(
246 self.config.correlation_headers.correlation_id.clone(),
247 context.correlation_id.clone(),
248 );
249 headers.insert(
250 self.config.correlation_headers.request_id.clone(),
251 context.request_id.clone(),
252 );
253
254 if let Some(parent_id) = &context.parent_request_id {
255 headers.insert(
256 self.config.correlation_headers.parent_request_id.clone(),
257 parent_id.clone(),
258 );
259 }
260
261 if let Some(trace_id) = &context.trace_id {
262 headers.insert(
263 self.config.correlation_headers.trace_id.clone(),
264 trace_id.clone(),
265 );
266 }
267
268 if let Some(span_id) = &context.span_id {
269 headers.insert(
270 self.config.correlation_headers.span_id.clone(),
271 span_id.clone(),
272 );
273 }
274
275 if let Some(user_id) = &context.user_id {
276 headers.insert(
277 self.config.correlation_headers.user_id.clone(),
278 user_id.clone(),
279 );
280 }
281
282 if let Some(session_id) = &context.session_id {
283 headers.insert(
284 self.config.correlation_headers.session_id.clone(),
285 session_id.clone(),
286 );
287 }
288 }
289
290 pub async fn start_request_tracking(
292 &self,
293 context: CorrelationContext,
294 method: &str,
295 params: serde_json::Value,
296 ) -> Result<(), CorrelationError> {
297 if !self.config.enabled {
298 return Ok(());
299 }
300
301 let entry = RequestTraceEntry {
302 context: context.clone(),
303 method: method.to_string(),
304 params,
305 response: None,
306 error: None,
307 duration_ms: None,
308 end_time: None,
309 memory_used_bytes: None,
310 cpu_time_ms: None,
311 };
312
313 let mut active = self.active_requests.write().await;
314
315 if active.len() >= self.config.max_active_requests {
317 warn!("Active request tracking at capacity, dropping oldest request");
318 if let Some(oldest_key) = active.keys().next().cloned() {
319 active.remove(&oldest_key);
320 }
321 }
322
323 active.insert(context.request_id.clone(), entry);
324 debug!("Started tracking request: {}", context.request_id);
325
326 Ok(())
327 }
328
329 pub async fn complete_request_tracking(
331 &self,
332 request_id: &str,
333 response: Option<serde_json::Value>,
334 error: Option<String>,
335 ) -> Result<(), CorrelationError> {
336 if !self.config.enabled {
337 return Ok(());
338 }
339
340 let mut active = self.active_requests.write().await;
341
342 if let Some(mut entry) = active.remove(request_id) {
343 let end_time = Utc::now();
344 let duration_ms = (end_time - entry.context.start_time).num_milliseconds() as u64;
345
346 entry.response = response;
347 entry.error = error;
348 entry.duration_ms = Some(duration_ms);
349 entry.end_time = Some(end_time);
350
351 let mut completed = self.completed_requests.write().await;
353 if completed.len() >= self.config.max_completed_requests {
354 if let Some(oldest_key) = completed.keys().next().cloned() {
356 completed.remove(&oldest_key);
357 }
358 }
359 completed.insert(request_id.to_string(), entry);
360
361 debug!("Completed tracking request: {}", request_id);
362 }
363
364 Ok(())
365 }
366
367 pub async fn get_request_trace(&self, request_id: &str) -> Option<RequestTraceEntry> {
369 {
371 let active = self.active_requests.read().await;
372 if let Some(entry) = active.get(request_id) {
373 return Some(entry.clone());
374 }
375 }
376
377 let completed = self.completed_requests.read().await;
379 completed.get(request_id).cloned()
380 }
381
382 pub async fn get_correlation_traces(&self, correlation_id: &str) -> Vec<RequestTraceEntry> {
384 let mut traces = Vec::new();
385
386 {
388 let active = self.active_requests.read().await;
389 for entry in active.values() {
390 if entry.context.correlation_id == correlation_id {
391 traces.push(entry.clone());
392 }
393 }
394 }
395
396 {
398 let completed = self.completed_requests.read().await;
399 for entry in completed.values() {
400 if entry.context.correlation_id == correlation_id {
401 traces.push(entry.clone());
402 }
403 }
404 }
405
406 traces.sort_by(|a, b| a.context.start_time.cmp(&b.context.start_time));
407 traces
408 }
409
410 pub async fn get_stats(&self) -> CorrelationStats {
412 let active = self.active_requests.read().await;
413 let completed = self.completed_requests.read().await;
414
415 CorrelationStats {
416 active_requests: active.len(),
417 completed_requests: completed.len(),
418 unique_correlations: {
419 let mut correlations = std::collections::HashSet::new();
420 for entry in active.values() {
421 correlations.insert(&entry.context.correlation_id);
422 }
423 for entry in completed.values() {
424 correlations.insert(&entry.context.correlation_id);
425 }
426 correlations.len()
427 },
428 }
429 }
430
431 async fn cleanup_expired_requests(
433 active_requests: Arc<RwLock<HashMap<String, RequestTraceEntry>>>,
434 completed_requests: Arc<RwLock<HashMap<String, RequestTraceEntry>>>,
435 config: CorrelationConfig,
436 ) {
437 let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
438
439 loop {
440 interval.tick().await;
441
442 let cutoff = Utc::now() - chrono::Duration::seconds(config.request_timeout_secs as i64);
443
444 {
446 let mut active = active_requests.write().await;
447 let expired_keys: Vec<_> = active
448 .iter()
449 .filter(|(_, entry)| entry.context.start_time < cutoff)
450 .map(|(key, _)| key.clone())
451 .collect();
452
453 for key in expired_keys {
454 if let Some(entry) = active.remove(&key) {
455 warn!("Request {} expired without completion", key);
456
457 let mut completed_entry = entry;
459 completed_entry.error = Some("Request expired".to_string());
460 completed_entry.end_time = Some(Utc::now());
461
462 let mut completed = completed_requests.write().await;
463 if completed.len() >= config.max_completed_requests {
464 if let Some(oldest_key) = completed.keys().next().cloned() {
465 completed.remove(&oldest_key);
466 }
467 }
468 completed.insert(key, completed_entry);
469 }
470 }
471 }
472
473 {
475 let mut completed = completed_requests.write().await;
476 let old_cutoff = Utc::now() - chrono::Duration::hours(24); let expired_keys: Vec<_> = completed
479 .iter()
480 .filter(|(_, entry)| {
481 entry.end_time.unwrap_or(entry.context.start_time) < old_cutoff
482 })
483 .map(|(key, _)| key.clone())
484 .collect();
485
486 for key in expired_keys {
487 completed.remove(&key);
488 }
489 }
490 }
491 }
492}
493
494#[derive(Debug, Serialize, Deserialize)]
496pub struct CorrelationStats {
497 pub active_requests: usize,
498 pub completed_requests: usize,
499 pub unique_correlations: usize,
500}
501
502#[derive(Debug, thiserror::Error)]
504pub enum CorrelationError {
505 #[error("Correlation tracking is disabled")]
506 Disabled,
507
508 #[error("Request not found: {0}")]
509 RequestNotFound(String),
510
511 #[error("Capacity exceeded")]
512 CapacityExceeded,
513}
514
515impl Default for CorrelationConfig {
516 fn default() -> Self {
517 Self {
518 enabled: true,
519 max_active_requests: 10000,
520 max_completed_requests: 50000,
521 request_timeout_secs: 300, track_resources: true,
523 cross_service_enabled: true,
524 correlation_headers: CorrelationHeaders::default(),
525 }
526 }
527}
528
529impl Default for CorrelationHeaders {
530 fn default() -> Self {
531 Self {
532 correlation_id: "X-Correlation-ID".to_string(),
533 request_id: "X-Request-ID".to_string(),
534 parent_request_id: "X-Parent-Request-ID".to_string(),
535 trace_id: "X-Trace-ID".to_string(),
536 span_id: "X-Span-ID".to_string(),
537 user_id: "X-User-ID".to_string(),
538 session_id: "X-Session-ID".to_string(),
539 }
540 }
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_correlation_context_creation() {
        let config = CorrelationConfig::default();
        let manager = CorrelationManager::new(config);

        let context = manager.create_context("test-service", None);

        assert!(!context.correlation_id.is_empty());
        assert!(!context.request_id.is_empty());
        assert_eq!(context.originating_service, "test-service");
        assert_eq!(context.current_service, "test-service");
        assert_eq!(context.request_path, vec!["test-service"]);
    }

    /// A child context shares the correlation, links to its parent and extends the path.
    #[test]
    fn test_child_context_inherits_parent() {
        let manager = CorrelationManager::new(CorrelationConfig::default());

        let parent = manager.create_context("gateway", None);
        let child = manager.create_context("orders", Some(&parent));

        assert_eq!(child.correlation_id, parent.correlation_id);
        assert_ne!(child.request_id, parent.request_id);
        assert_eq!(
            child.parent_request_id.as_deref(),
            Some(parent.request_id.as_str())
        );
        assert_eq!(child.originating_service, "gateway");
        assert_eq!(child.current_service, "orders");
        assert_eq!(child.request_path, vec!["gateway", "orders"]);
    }

    #[tokio::test]
    async fn test_request_tracking() {
        let config = CorrelationConfig::default();
        let manager = CorrelationManager::new(config);

        let context = manager.create_context("test-service", None);
        let request_id = context.request_id.clone();

        manager
            .start_request_tracking(
                context,
                "test_method",
                serde_json::json!({"param": "value"}),
            )
            .await
            .unwrap();

        let trace = manager.get_request_trace(&request_id).await;
        assert!(trace.is_some());
        assert_eq!(trace.unwrap().method, "test_method");

        manager
            .complete_request_tracking(
                &request_id,
                Some(serde_json::json!({"result": "success"})),
                None,
            )
            .await
            .unwrap();

        let trace = manager.get_request_trace(&request_id).await;
        assert!(trace.is_some());
        let trace = trace.unwrap();
        assert!(trace.response.is_some());
        assert!(trace.duration_ms.is_some());
    }

    /// With tracking disabled, nothing is recorded.
    #[tokio::test]
    async fn test_disabled_manager_records_nothing() {
        let config = CorrelationConfig {
            enabled: false,
            ..CorrelationConfig::default()
        };
        let manager = CorrelationManager::new(config);

        let context = manager.create_context("test-service", None);
        let request_id = context.request_id.clone();

        manager
            .start_request_tracking(context, "test_method", serde_json::json!({}))
            .await
            .unwrap();

        assert!(manager.get_request_trace(&request_id).await.is_none());
        assert_eq!(manager.get_stats().await.active_requests, 0);
    }

    /// Trace lookup spans both maps and excludes other correlations; stats agree.
    #[tokio::test]
    async fn test_correlation_traces_and_stats() {
        let manager = CorrelationManager::new(CorrelationConfig::default());

        let root = manager.create_context("gateway", None);
        let child = manager.create_context("orders", Some(&root));
        let unrelated = manager.create_context("billing", None);

        let root_id = root.request_id.clone();
        let child_id = child.request_id.clone();
        let correlation_id = root.correlation_id.clone();

        for context in [root, child, unrelated] {
            manager
                .start_request_tracking(context, "op", serde_json::Value::Null)
                .await
                .unwrap();
        }
        manager
            .complete_request_tracking(&root_id, None, Some("boom".to_string()))
            .await
            .unwrap();

        let traces = manager.get_correlation_traces(&correlation_id).await;
        let mut ids: Vec<&str> = traces
            .iter()
            .map(|trace| trace.context.request_id.as_str())
            .collect();
        ids.sort_unstable();
        let mut expected = vec![root_id.as_str(), child_id.as_str()];
        expected.sort_unstable();
        assert_eq!(ids, expected);

        let stats = manager.get_stats().await;
        assert_eq!(stats.active_requests, 2);
        assert_eq!(stats.completed_requests, 1);
        assert_eq!(stats.unique_correlations, 2);
    }

    /// Completing an id that was never tracked is a silent no-op.
    #[tokio::test]
    async fn test_completing_unknown_request_is_ok() {
        let manager = CorrelationManager::new(CorrelationConfig::default());

        manager
            .complete_request_tracking("missing", None, None)
            .await
            .unwrap();

        assert_eq!(manager.get_stats().await.completed_requests, 0);
    }

    #[test]
    fn test_header_injection_extraction() {
        let config = CorrelationConfig::default();
        let manager = CorrelationManager::new(config);

        let context = manager.create_context("test-service", None);
        let mut headers = HashMap::new();

        manager.inject_into_headers(&context, &mut headers);

        assert!(headers.contains_key("X-Correlation-ID"));
        assert!(headers.contains_key("X-Request-ID"));

        let extracted = manager.extract_from_headers(&headers);
        assert!(extracted.is_some());

        let extracted = extracted.unwrap();
        assert_eq!(extracted.correlation_id, context.correlation_id);
        // The sender's request id becomes the receiver's parent; the receiver gets a new id.
        assert_eq!(
            extracted.parent_request_id.as_deref(),
            Some(context.request_id.as_str())
        );
        assert_ne!(extracted.request_id, context.request_id);
        assert!(extracted.trace_id.is_none());
    }

    /// Injecting a child context also writes its parent id header.
    #[test]
    fn test_child_injection_carries_parent_id() {
        let manager = CorrelationManager::new(CorrelationConfig::default());
        let parent = manager.create_context("gateway", None);
        let child = manager.create_context("orders", Some(&parent));

        let mut headers = HashMap::new();
        manager.inject_into_headers(&child, &mut headers);

        assert_eq!(headers.get("X-Parent-Request-ID"), Some(&parent.request_id));
        assert_eq!(headers.get("X-Request-ID"), Some(&child.request_id));
        assert!(!headers.contains_key("X-Trace-ID"));
    }

    /// Without the correlation-id header no context is extracted.
    #[test]
    fn test_extraction_requires_correlation_header() {
        let manager = CorrelationManager::new(CorrelationConfig::default());

        let mut headers = HashMap::new();
        headers.insert("X-Request-ID".to_string(), "abc".to_string());

        assert!(manager.extract_from_headers(&headers).is_none());
    }
}