Skip to main content

mabi_core/logging/
context.rs

1//! Structured context propagation for distributed tracing.
2//!
3//! This module provides utilities for propagating context information
4//! across async boundaries and between services. It supports:
5//!
6//! - Request/correlation IDs for tracing requests across components
7//! - Device context for device-specific operations
8//! - Protocol context for protocol-specific logging
9//! - Custom context fields for application-specific needs
10//!
11//! # Example
12//!
13//! ```rust,ignore
14//! use mabi_core::logging::context::{TraceContext, RequestContext};
15//!
16//! // Create a trace context for a request
//! let ctx = TraceContext::with_request_id("req-12345")
//!     .with_device_id("device-001")
//!     .with_protocol("modbus");
21//!
22//! // Use in a span
23//! let span = ctx.create_span("handle_request");
24//! let _guard = span.enter();
25//!
26//! // Or use the request_span! macro
27//! request_span!(ctx, "handle_request", {
28//!     // Your code here
29//! });
30//! ```
31
32use std::collections::HashMap;
33use std::sync::Arc;
34
35use serde::{Deserialize, Serialize};
36use tracing::{span, Level, Span};
37use uuid::Uuid;
38
39/// Trace context for distributed tracing.
40///
41/// This struct carries context information that should be propagated
42/// across async boundaries and potentially across service boundaries.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TraceContext {
45    /// Unique request/correlation ID.
46    pub request_id: String,
47
48    /// Trace ID (for distributed tracing systems).
49    #[serde(default)]
50    pub trace_id: Option<String>,
51
52    /// Span ID (for distributed tracing systems).
53    #[serde(default)]
54    pub span_id: Option<String>,
55
56    /// Parent span ID.
57    #[serde(default)]
58    pub parent_span_id: Option<String>,
59
60    /// Device ID (if applicable).
61    #[serde(default)]
62    pub device_id: Option<String>,
63
64    /// Protocol name (if applicable).
65    #[serde(default)]
66    pub protocol: Option<String>,
67
68    /// Operation name.
69    #[serde(default)]
70    pub operation: Option<String>,
71
72    /// Custom fields.
73    #[serde(default)]
74    pub fields: HashMap<String, String>,
75
76    /// Timestamp when context was created.
77    #[serde(default = "default_timestamp")]
78    pub created_at: u64,
79}
80
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
/// Saturates at `u64::MAX` instead of silently truncating the `u128`
/// millisecond count with an `as` cast.
fn default_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis().min(u128::from(u64::MAX)) as u64)
        .unwrap_or(0)
}
87
88impl Default for TraceContext {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl TraceContext {
95    /// Create a new trace context with a generated request ID.
96    pub fn new() -> Self {
97        Self {
98            request_id: Uuid::new_v4().to_string(),
99            trace_id: None,
100            span_id: None,
101            parent_span_id: None,
102            device_id: None,
103            protocol: None,
104            operation: None,
105            fields: HashMap::new(),
106            created_at: default_timestamp(),
107        }
108    }
109
110    /// Create a trace context with a specific request ID.
111    pub fn with_request_id(request_id: impl Into<String>) -> Self {
112        Self {
113            request_id: request_id.into(),
114            ..Self::new()
115        }
116    }
117
118    /// Create a child context (new span under same trace).
119    pub fn child(&self) -> Self {
120        Self {
121            request_id: self.request_id.clone(),
122            trace_id: self.trace_id.clone(),
123            span_id: Some(Uuid::new_v4().to_string()),
124            parent_span_id: self.span_id.clone(),
125            device_id: self.device_id.clone(),
126            protocol: self.protocol.clone(),
127            operation: None,
128            fields: self.fields.clone(),
129            created_at: default_timestamp(),
130        }
131    }
132
133    /// Set the device ID.
134    pub fn with_device_id(mut self, device_id: impl Into<String>) -> Self {
135        self.device_id = Some(device_id.into());
136        self
137    }
138
139    /// Set the protocol.
140    pub fn with_protocol(mut self, protocol: impl Into<String>) -> Self {
141        self.protocol = Some(protocol.into());
142        self
143    }
144
145    /// Set the operation name.
146    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
147        self.operation = Some(operation.into());
148        self
149    }
150
151    /// Set the trace ID (for integration with distributed tracing).
152    pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
153        self.trace_id = Some(trace_id.into());
154        self
155    }
156
157    /// Add a custom field.
158    pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
159        self.fields.insert(key.into(), value.into());
160        self
161    }
162
163    /// Add multiple custom fields.
164    pub fn with_fields(mut self, fields: impl IntoIterator<Item = (String, String)>) -> Self {
165        self.fields.extend(fields);
166        self
167    }
168
169    /// Create a tracing span with this context.
170    pub fn create_span(&self, name: &'static str) -> Span {
171        let span = span!(
172            Level::INFO,
173            "request",
174            request_id = %self.request_id,
175            operation = name,
176        );
177
178        // Add optional fields
179        if let Some(ref device_id) = self.device_id {
180            span.record("device_id", device_id.as_str());
181        }
182        if let Some(ref protocol) = self.protocol {
183            span.record("protocol", protocol.as_str());
184        }
185        if let Some(ref trace_id) = self.trace_id {
186            span.record("trace_id", trace_id.as_str());
187        }
188
189        span
190    }
191
192    /// Create a debug-level span (for less critical operations).
193    pub fn create_debug_span(&self, name: &'static str) -> Span {
194        span!(
195            Level::DEBUG,
196            "operation",
197            request_id = %self.request_id,
198            operation = name,
199            device_id = self.device_id.as_deref().unwrap_or(""),
200        )
201    }
202
203    /// Get the age of this context in milliseconds.
204    pub fn age_ms(&self) -> u64 {
205        default_timestamp().saturating_sub(self.created_at)
206    }
207
208    /// Check if this context is older than the given milliseconds.
209    pub fn is_older_than_ms(&self, ms: u64) -> bool {
210        self.age_ms() > ms
211    }
212
213    /// Convert to a map for logging or serialization.
214    pub fn to_map(&self) -> HashMap<String, String> {
215        let mut map = HashMap::new();
216        map.insert("request_id".to_string(), self.request_id.clone());
217
218        if let Some(ref trace_id) = self.trace_id {
219            map.insert("trace_id".to_string(), trace_id.clone());
220        }
221        if let Some(ref device_id) = self.device_id {
222            map.insert("device_id".to_string(), device_id.clone());
223        }
224        if let Some(ref protocol) = self.protocol {
225            map.insert("protocol".to_string(), protocol.clone());
226        }
227        if let Some(ref operation) = self.operation {
228            map.insert("operation".to_string(), operation.clone());
229        }
230
231        map.extend(self.fields.clone());
232        map
233    }
234
235    /// Parse from HTTP headers (for distributed tracing).
236    pub fn from_headers(headers: &HashMap<String, String>) -> Self {
237        let mut ctx = Self::new();
238
239        if let Some(request_id) = headers.get("x-request-id").or(headers.get("x-correlation-id")) {
240            ctx.request_id = request_id.clone();
241        }
242        if let Some(trace_id) = headers.get("x-trace-id").or(headers.get("traceparent")) {
243            ctx.trace_id = Some(trace_id.clone());
244        }
245        if let Some(span_id) = headers.get("x-span-id") {
246            ctx.span_id = Some(span_id.clone());
247        }
248        if let Some(device_id) = headers.get("x-device-id") {
249            ctx.device_id = Some(device_id.clone());
250        }
251
252        ctx
253    }
254
255    /// Convert to HTTP headers (for distributed tracing).
256    pub fn to_headers(&self) -> HashMap<String, String> {
257        let mut headers = HashMap::new();
258
259        headers.insert("x-request-id".to_string(), self.request_id.clone());
260
261        if let Some(ref trace_id) = self.trace_id {
262            headers.insert("x-trace-id".to_string(), trace_id.clone());
263        }
264        if let Some(ref span_id) = self.span_id {
265            headers.insert("x-span-id".to_string(), span_id.clone());
266        }
267        if let Some(ref device_id) = self.device_id {
268            headers.insert("x-device-id".to_string(), device_id.clone());
269        }
270
271        headers
272    }
273}
274
275/// Request context for protocol operations.
276///
277/// This is a specialized context for protocol-level requests.
278#[derive(Debug, Clone)]
279pub struct RequestContext {
280    /// Base trace context.
281    pub trace: TraceContext,
282
283    /// Request start time.
284    pub start_time: std::time::Instant,
285
286    /// Request timeout (if set).
287    pub timeout: Option<std::time::Duration>,
288
289    /// Whether this request should be logged at debug level.
290    pub debug_request: bool,
291}
292
293impl RequestContext {
294    /// Create a new request context.
295    pub fn new() -> Self {
296        Self {
297            trace: TraceContext::new(),
298            start_time: std::time::Instant::now(),
299            timeout: None,
300            debug_request: false,
301        }
302    }
303
304    /// Create with an existing trace context.
305    pub fn with_trace(trace: TraceContext) -> Self {
306        Self {
307            trace,
308            start_time: std::time::Instant::now(),
309            timeout: None,
310            debug_request: false,
311        }
312    }
313
314    /// Set the device ID.
315    pub fn device(mut self, device_id: impl Into<String>) -> Self {
316        self.trace = self.trace.with_device_id(device_id);
317        self
318    }
319
320    /// Set the protocol.
321    pub fn protocol(mut self, protocol: impl Into<String>) -> Self {
322        self.trace = self.trace.with_protocol(protocol);
323        self
324    }
325
326    /// Set the operation.
327    pub fn operation(mut self, operation: impl Into<String>) -> Self {
328        self.trace = self.trace.with_operation(operation);
329        self
330    }
331
332    /// Set a timeout.
333    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
334        self.timeout = Some(timeout);
335        self
336    }
337
338    /// Mark as a debug request (logged at debug level).
339    pub fn debug(mut self) -> Self {
340        self.debug_request = true;
341        self
342    }
343
344    /// Get elapsed time since request started.
345    pub fn elapsed(&self) -> std::time::Duration {
346        self.start_time.elapsed()
347    }
348
349    /// Check if the request has timed out.
350    pub fn is_timed_out(&self) -> bool {
351        self.timeout
352            .map(|t| self.elapsed() > t)
353            .unwrap_or(false)
354    }
355
356    /// Get remaining time before timeout.
357    pub fn remaining_timeout(&self) -> Option<std::time::Duration> {
358        self.timeout.and_then(|t| t.checked_sub(self.elapsed()))
359    }
360
361    /// Get the request ID.
362    pub fn request_id(&self) -> &str {
363        &self.trace.request_id
364    }
365
366    /// Create a span for this request.
367    pub fn span(&self, name: &'static str) -> Span {
368        if self.debug_request {
369            self.trace.create_debug_span(name)
370        } else {
371            self.trace.create_span(name)
372        }
373    }
374}
375
376impl Default for RequestContext {
377    fn default() -> Self {
378        Self::new()
379    }
380}
381
382/// Shared trace context for passing across threads.
383pub type SharedTraceContext = Arc<TraceContext>;
384
385/// Create a shared trace context.
386pub fn shared_context(ctx: TraceContext) -> SharedTraceContext {
387    Arc::new(ctx)
388}
389
390/// Device context for device-specific operations.
391#[derive(Debug, Clone)]
392pub struct DeviceContext {
393    /// Device ID.
394    pub device_id: String,
395
396    /// Protocol.
397    pub protocol: String,
398
399    /// Base trace context.
400    pub trace: TraceContext,
401}
402
403impl DeviceContext {
404    /// Create a new device context.
405    pub fn new(device_id: impl Into<String>, protocol: impl Into<String>) -> Self {
406        let device_id = device_id.into();
407        let protocol = protocol.into();
408
409        Self {
410            device_id: device_id.clone(),
411            protocol: protocol.clone(),
412            trace: TraceContext::new()
413                .with_device_id(device_id)
414                .with_protocol(protocol),
415        }
416    }
417
418    /// Create with an existing trace context.
419    pub fn with_trace(
420        device_id: impl Into<String>,
421        protocol: impl Into<String>,
422        trace: TraceContext,
423    ) -> Self {
424        let device_id = device_id.into();
425        let protocol = protocol.into();
426
427        Self {
428            device_id: device_id.clone(),
429            protocol: protocol.clone(),
430            trace: trace.with_device_id(device_id).with_protocol(protocol),
431        }
432    }
433
434    /// Create a span for a device operation.
435    pub fn span(&self, operation: &'static str) -> Span {
436        span!(
437            Level::DEBUG,
438            "device_operation",
439            device_id = %self.device_id,
440            protocol = %self.protocol,
441            operation = operation,
442            request_id = %self.trace.request_id,
443        )
444    }
445
446    /// Get the request ID.
447    pub fn request_id(&self) -> &str {
448        &self.trace.request_id
449    }
450
451    /// Create a child context for a sub-operation.
452    pub fn child(&self) -> Self {
453        Self {
454            device_id: self.device_id.clone(),
455            protocol: self.protocol.clone(),
456            trace: self.trace.child(),
457        }
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::Duration;

    #[test]
    fn test_trace_context_creation() {
        let ctx = TraceContext::new();
        assert!(!ctx.request_id.is_empty());
        assert!(ctx.device_id.is_none());
        assert!(ctx.protocol.is_none());
    }

    #[test]
    fn test_with_request_id() {
        let ctx = TraceContext::with_request_id("req-12345").with_protocol("modbus");
        assert_eq!(ctx.request_id, "req-12345");
        assert_eq!(ctx.protocol.as_deref(), Some("modbus"));
    }

    #[test]
    fn test_trace_context_builder() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus")
            .with_operation("read")
            .with_field("unit_id", "1");

        assert_eq!(ctx.device_id, Some("device-001".to_string()));
        assert_eq!(ctx.protocol, Some("modbus".to_string()));
        assert_eq!(ctx.operation, Some("read".to_string()));
        assert_eq!(ctx.fields.get("unit_id"), Some(&"1".to_string()));
    }

    #[test]
    fn test_trace_context_child() {
        let parent = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123")
            .with_operation("read");

        let child = parent.child();

        assert_eq!(child.request_id, parent.request_id);
        assert_eq!(child.trace_id, parent.trace_id);
        assert_eq!(child.device_id, parent.device_id);
        assert_eq!(child.parent_span_id, parent.span_id);
        // A child always gets its own span and does not inherit the operation.
        assert!(child.span_id.is_some());
        assert!(child.operation.is_none());

        // `parent` has no span ID, so the parent-link check above compares
        // `None == None`. A second generation proves the link really points
        // at the previous span.
        let grandchild = child.child();
        assert_eq!(grandchild.parent_span_id, child.span_id);
        assert_ne!(grandchild.span_id, child.span_id);
        assert_eq!(grandchild.request_id, parent.request_id);
    }

    #[test]
    fn test_trace_context_to_map() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus");

        let map = ctx.to_map();
        assert!(map.contains_key("request_id"));
        assert_eq!(map.get("device_id"), Some(&"device-001".to_string()));
        assert_eq!(map.get("protocol"), Some(&"modbus".to_string()));
    }

    #[test]
    fn test_trace_context_headers() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123");

        let headers = ctx.to_headers();
        assert!(headers.contains_key("x-request-id"));
        assert_eq!(headers.get("x-trace-id"), Some(&"trace-123".to_string()));
        assert_eq!(headers.get("x-device-id"), Some(&"device-001".to_string()));

        // Round-trip
        let parsed = TraceContext::from_headers(&headers);
        assert_eq!(parsed.request_id, ctx.request_id);
        assert_eq!(parsed.trace_id, ctx.trace_id);
        assert_eq!(parsed.device_id, ctx.device_id);
    }

    #[test]
    fn test_from_headers_fallbacks() {
        let mut headers = HashMap::new();
        headers.insert("x-correlation-id".to_string(), "corr-1".to_string());
        let ctx = TraceContext::from_headers(&headers);
        assert_eq!(ctx.request_id, "corr-1");
        assert!(ctx.trace_id.is_none());

        // x-request-id takes priority over x-correlation-id.
        headers.insert("x-request-id".to_string(), "req-1".to_string());
        assert_eq!(TraceContext::from_headers(&headers).request_id, "req-1");

        let empty = TraceContext::from_headers(&HashMap::new());
        assert!(!empty.request_id.is_empty());
        assert!(empty.trace_id.is_none());
        assert!(empty.span_id.is_none());
        assert!(empty.device_id.is_none());
    }

    #[test]
    fn test_span_creation_with_all_fields() {
        let trace = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus")
            .with_trace_id("trace-123");

        // Smoke test: creating and entering every kind of span must not panic,
        // with or without a subscriber installed.
        let info = trace.create_span("handle_request");
        let _info_guard = info.enter();
        let debug = trace.create_debug_span("handle_request");
        let _debug_guard = debug.enter();
        let request = RequestContext::with_trace(trace.clone()).debug().span("read");
        let _request_guard = request.enter();
        let device = DeviceContext::new("device-001", "modbus").span("read");
        let _device_guard = device.enter();
    }

    #[test]
    fn test_request_context() {
        let ctx = RequestContext::new()
            .device("device-001")
            .protocol("modbus")
            .operation("read")
            .with_timeout(Duration::from_secs(5));

        assert!(!ctx.request_id().is_empty());
        assert!(!ctx.is_timed_out());
        assert!(ctx.remaining_timeout().is_some());
        assert_eq!(ctx.trace.device_id.as_deref(), Some("device-001"));
        assert_eq!(ctx.trace.protocol.as_deref(), Some("modbus"));
        assert_eq!(ctx.trace.operation.as_deref(), Some("read"));
        assert!(!ctx.debug_request);
        assert!(ctx.debug().debug_request);
    }

    #[test]
    fn test_request_context_timeouts() {
        let unbounded = RequestContext::default();
        assert!(!unbounded.is_timed_out());
        assert!(unbounded.remaining_timeout().is_none());

        let expired = RequestContext::new().with_timeout(Duration::from_secs(0));
        std::thread::sleep(Duration::from_millis(2));
        assert!(expired.is_timed_out());
        assert!(expired.remaining_timeout().is_none());
    }

    #[test]
    fn test_device_context() {
        let ctx = DeviceContext::new("device-001", "modbus");

        assert_eq!(ctx.device_id, "device-001");
        assert_eq!(ctx.protocol, "modbus");
        assert!(!ctx.request_id().is_empty());

        let child = ctx.child();
        assert_eq!(child.request_id(), ctx.request_id());
    }

    #[test]
    fn test_device_context_with_trace() {
        let trace = TraceContext::with_request_id("req-1").with_device_id("other");
        let ctx = DeviceContext::with_trace("device-001", "modbus", trace);

        assert_eq!(ctx.request_id(), "req-1");
        assert_eq!(ctx.trace.device_id.as_deref(), Some("device-001"));
        assert_eq!(ctx.trace.protocol.as_deref(), Some("modbus"));
    }

    #[test]
    fn test_trace_context_age() {
        let ctx = TraceContext::new();
        std::thread::sleep(Duration::from_millis(10));

        assert!(ctx.age_ms() >= 10);
        assert!(ctx.is_older_than_ms(5));
        assert!(!ctx.is_older_than_ms(1000));
    }

    #[test]
    fn test_shared_context() {
        let ctx = TraceContext::new().with_device_id("device-001");
        let shared = shared_context(ctx);

        assert_eq!(shared.device_id, Some("device-001".to_string()));
    }
}