// mabi_core/logging/context.rs

//! Structured context propagation for distributed tracing.
//!
//! This module provides utilities for propagating context information
//! across async boundaries and between services. It supports:
//!
//! - Request/correlation IDs for tracing requests across components
//! - Device context for device-specific operations
//! - Protocol context for protocol-specific logging
//! - Custom context fields for application-specific needs
//!
//! # Example
//!
//! ```rust,ignore
//! use mabi_core::logging::context::{TraceContext, RequestContext};
//!
//! // Create a trace context for a request. `with_request_id` is a
//! // constructor (it takes no `self`), so it starts the builder chain.
//! let ctx = TraceContext::with_request_id("req-12345")
//!     .with_device_id("device-001")
//!     .with_protocol("modbus");
//!
//! // Use in a span
//! let span = ctx.create_span("handle_request");
//! let _guard = span.enter();
//!
//! // Or use the request_span! macro
//! request_span!(ctx, "handle_request", {
//!     // Your code here
//! });
//! ```
31
32use std::collections::HashMap;
33use std::sync::Arc;
34
35use serde::{Deserialize, Serialize};
36use tracing::{span, Level, Span};
37use uuid::Uuid;
38
39/// Trace context for distributed tracing.
40///
41/// This struct carries context information that should be propagated
42/// across async boundaries and potentially across service boundaries.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TraceContext {
45    /// Unique request/correlation ID.
46    pub request_id: String,
47
48    /// Trace ID (for distributed tracing systems).
49    #[serde(default)]
50    pub trace_id: Option<String>,
51
52    /// Span ID (for distributed tracing systems).
53    #[serde(default)]
54    pub span_id: Option<String>,
55
56    /// Parent span ID.
57    #[serde(default)]
58    pub parent_span_id: Option<String>,
59
60    /// Device ID (if applicable).
61    #[serde(default)]
62    pub device_id: Option<String>,
63
64    /// Protocol name (if applicable).
65    #[serde(default)]
66    pub protocol: Option<String>,
67
68    /// Operation name.
69    #[serde(default)]
70    pub operation: Option<String>,
71
72    /// Custom fields.
73    #[serde(default)]
74    pub fields: HashMap<String, String>,
75
76    /// Timestamp when context was created.
77    #[serde(default = "default_timestamp")]
78    pub created_at: u64,
79}
80
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn default_timestamp() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
87
88impl Default for TraceContext {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl TraceContext {
95    /// Create a new trace context with a generated request ID.
96    pub fn new() -> Self {
97        Self {
98            request_id: Uuid::new_v4().to_string(),
99            trace_id: None,
100            span_id: None,
101            parent_span_id: None,
102            device_id: None,
103            protocol: None,
104            operation: None,
105            fields: HashMap::new(),
106            created_at: default_timestamp(),
107        }
108    }
109
110    /// Create a trace context with a specific request ID.
111    pub fn with_request_id(request_id: impl Into<String>) -> Self {
112        Self {
113            request_id: request_id.into(),
114            ..Self::new()
115        }
116    }
117
118    /// Create a child context (new span under same trace).
119    pub fn child(&self) -> Self {
120        Self {
121            request_id: self.request_id.clone(),
122            trace_id: self.trace_id.clone(),
123            span_id: Some(Uuid::new_v4().to_string()),
124            parent_span_id: self.span_id.clone(),
125            device_id: self.device_id.clone(),
126            protocol: self.protocol.clone(),
127            operation: None,
128            fields: self.fields.clone(),
129            created_at: default_timestamp(),
130        }
131    }
132
133    /// Set the device ID.
134    pub fn with_device_id(mut self, device_id: impl Into<String>) -> Self {
135        self.device_id = Some(device_id.into());
136        self
137    }
138
139    /// Set the protocol.
140    pub fn with_protocol(mut self, protocol: impl Into<String>) -> Self {
141        self.protocol = Some(protocol.into());
142        self
143    }
144
145    /// Set the operation name.
146    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
147        self.operation = Some(operation.into());
148        self
149    }
150
151    /// Set the trace ID (for integration with distributed tracing).
152    pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
153        self.trace_id = Some(trace_id.into());
154        self
155    }
156
157    /// Add a custom field.
158    pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
159        self.fields.insert(key.into(), value.into());
160        self
161    }
162
163    /// Add multiple custom fields.
164    pub fn with_fields(mut self, fields: impl IntoIterator<Item = (String, String)>) -> Self {
165        self.fields.extend(fields);
166        self
167    }
168
169    /// Create a tracing span with this context.
170    pub fn create_span(&self, name: &'static str) -> Span {
171        let span = span!(
172            Level::INFO,
173            "request",
174            request_id = %self.request_id,
175            operation = name,
176        );
177
178        // Add optional fields
179        if let Some(ref device_id) = self.device_id {
180            span.record("device_id", device_id.as_str());
181        }
182        if let Some(ref protocol) = self.protocol {
183            span.record("protocol", protocol.as_str());
184        }
185        if let Some(ref trace_id) = self.trace_id {
186            span.record("trace_id", trace_id.as_str());
187        }
188
189        span
190    }
191
192    /// Create a debug-level span (for less critical operations).
193    pub fn create_debug_span(&self, name: &'static str) -> Span {
194        span!(
195            Level::DEBUG,
196            "operation",
197            request_id = %self.request_id,
198            operation = name,
199            device_id = self.device_id.as_deref().unwrap_or(""),
200        )
201    }
202
203    /// Get the age of this context in milliseconds.
204    pub fn age_ms(&self) -> u64 {
205        default_timestamp().saturating_sub(self.created_at)
206    }
207
208    /// Check if this context is older than the given milliseconds.
209    pub fn is_older_than_ms(&self, ms: u64) -> bool {
210        self.age_ms() > ms
211    }
212
213    /// Convert to a map for logging or serialization.
214    pub fn to_map(&self) -> HashMap<String, String> {
215        let mut map = HashMap::new();
216        map.insert("request_id".to_string(), self.request_id.clone());
217
218        if let Some(ref trace_id) = self.trace_id {
219            map.insert("trace_id".to_string(), trace_id.clone());
220        }
221        if let Some(ref device_id) = self.device_id {
222            map.insert("device_id".to_string(), device_id.clone());
223        }
224        if let Some(ref protocol) = self.protocol {
225            map.insert("protocol".to_string(), protocol.clone());
226        }
227        if let Some(ref operation) = self.operation {
228            map.insert("operation".to_string(), operation.clone());
229        }
230
231        map.extend(self.fields.clone());
232        map
233    }
234
235    /// Parse from HTTP headers (for distributed tracing).
236    pub fn from_headers(headers: &HashMap<String, String>) -> Self {
237        let mut ctx = Self::new();
238
239        if let Some(request_id) = headers
240            .get("x-request-id")
241            .or(headers.get("x-correlation-id"))
242        {
243            ctx.request_id = request_id.clone();
244        }
245        if let Some(trace_id) = headers.get("x-trace-id").or(headers.get("traceparent")) {
246            ctx.trace_id = Some(trace_id.clone());
247        }
248        if let Some(span_id) = headers.get("x-span-id") {
249            ctx.span_id = Some(span_id.clone());
250        }
251        if let Some(device_id) = headers.get("x-device-id") {
252            ctx.device_id = Some(device_id.clone());
253        }
254
255        ctx
256    }
257
258    /// Convert to HTTP headers (for distributed tracing).
259    pub fn to_headers(&self) -> HashMap<String, String> {
260        let mut headers = HashMap::new();
261
262        headers.insert("x-request-id".to_string(), self.request_id.clone());
263
264        if let Some(ref trace_id) = self.trace_id {
265            headers.insert("x-trace-id".to_string(), trace_id.clone());
266        }
267        if let Some(ref span_id) = self.span_id {
268            headers.insert("x-span-id".to_string(), span_id.clone());
269        }
270        if let Some(ref device_id) = self.device_id {
271            headers.insert("x-device-id".to_string(), device_id.clone());
272        }
273
274        headers
275    }
276}
277
278/// Request context for protocol operations.
279///
280/// This is a specialized context for protocol-level requests.
281#[derive(Debug, Clone)]
282pub struct RequestContext {
283    /// Base trace context.
284    pub trace: TraceContext,
285
286    /// Request start time.
287    pub start_time: std::time::Instant,
288
289    /// Request timeout (if set).
290    pub timeout: Option<std::time::Duration>,
291
292    /// Whether this request should be logged at debug level.
293    pub debug_request: bool,
294}
295
296impl RequestContext {
297    /// Create a new request context.
298    pub fn new() -> Self {
299        Self {
300            trace: TraceContext::new(),
301            start_time: std::time::Instant::now(),
302            timeout: None,
303            debug_request: false,
304        }
305    }
306
307    /// Create with an existing trace context.
308    pub fn with_trace(trace: TraceContext) -> Self {
309        Self {
310            trace,
311            start_time: std::time::Instant::now(),
312            timeout: None,
313            debug_request: false,
314        }
315    }
316
317    /// Set the device ID.
318    pub fn device(mut self, device_id: impl Into<String>) -> Self {
319        self.trace = self.trace.with_device_id(device_id);
320        self
321    }
322
323    /// Set the protocol.
324    pub fn protocol(mut self, protocol: impl Into<String>) -> Self {
325        self.trace = self.trace.with_protocol(protocol);
326        self
327    }
328
329    /// Set the operation.
330    pub fn operation(mut self, operation: impl Into<String>) -> Self {
331        self.trace = self.trace.with_operation(operation);
332        self
333    }
334
335    /// Set a timeout.
336    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
337        self.timeout = Some(timeout);
338        self
339    }
340
341    /// Mark as a debug request (logged at debug level).
342    pub fn debug(mut self) -> Self {
343        self.debug_request = true;
344        self
345    }
346
347    /// Get elapsed time since request started.
348    pub fn elapsed(&self) -> std::time::Duration {
349        self.start_time.elapsed()
350    }
351
352    /// Check if the request has timed out.
353    pub fn is_timed_out(&self) -> bool {
354        self.timeout.map(|t| self.elapsed() > t).unwrap_or(false)
355    }
356
357    /// Get remaining time before timeout.
358    pub fn remaining_timeout(&self) -> Option<std::time::Duration> {
359        self.timeout.and_then(|t| t.checked_sub(self.elapsed()))
360    }
361
362    /// Get the request ID.
363    pub fn request_id(&self) -> &str {
364        &self.trace.request_id
365    }
366
367    /// Create a span for this request.
368    pub fn span(&self, name: &'static str) -> Span {
369        if self.debug_request {
370            self.trace.create_debug_span(name)
371        } else {
372            self.trace.create_span(name)
373        }
374    }
375}
376
377impl Default for RequestContext {
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
383/// Shared trace context for passing across threads.
384pub type SharedTraceContext = Arc<TraceContext>;
385
386/// Create a shared trace context.
387pub fn shared_context(ctx: TraceContext) -> SharedTraceContext {
388    Arc::new(ctx)
389}
390
391/// Device context for device-specific operations.
392#[derive(Debug, Clone)]
393pub struct DeviceContext {
394    /// Device ID.
395    pub device_id: String,
396
397    /// Protocol.
398    pub protocol: String,
399
400    /// Base trace context.
401    pub trace: TraceContext,
402}
403
404impl DeviceContext {
405    /// Create a new device context.
406    pub fn new(device_id: impl Into<String>, protocol: impl Into<String>) -> Self {
407        let device_id = device_id.into();
408        let protocol = protocol.into();
409
410        Self {
411            device_id: device_id.clone(),
412            protocol: protocol.clone(),
413            trace: TraceContext::new()
414                .with_device_id(device_id)
415                .with_protocol(protocol),
416        }
417    }
418
419    /// Create with an existing trace context.
420    pub fn with_trace(
421        device_id: impl Into<String>,
422        protocol: impl Into<String>,
423        trace: TraceContext,
424    ) -> Self {
425        let device_id = device_id.into();
426        let protocol = protocol.into();
427
428        Self {
429            device_id: device_id.clone(),
430            protocol: protocol.clone(),
431            trace: trace.with_device_id(device_id).with_protocol(protocol),
432        }
433    }
434
435    /// Create a span for a device operation.
436    pub fn span(&self, operation: &'static str) -> Span {
437        span!(
438            Level::DEBUG,
439            "device_operation",
440            device_id = %self.device_id,
441            protocol = %self.protocol,
442            operation = operation,
443            request_id = %self.trace.request_id,
444        )
445    }
446
447    /// Get the request ID.
448    pub fn request_id(&self) -> &str {
449        &self.trace.request_id
450    }
451
452    /// Create a child context for a sub-operation.
453    pub fn child(&self) -> Self {
454        Self {
455            device_id: self.device_id.clone(),
456            protocol: self.protocol.clone(),
457            trace: self.trace.child(),
458        }
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trace_context_creation() {
        let ctx = TraceContext::new();
        assert!(!ctx.request_id.is_empty());
        assert!(ctx.device_id.is_none());
        assert!(ctx.protocol.is_none());
    }

    #[test]
    fn test_with_request_id_constructor() {
        // `with_request_id` is an associated constructor, not a builder method.
        let ctx = TraceContext::with_request_id("req-12345").with_device_id("device-001");
        assert_eq!(ctx.request_id, "req-12345");
        assert_eq!(ctx.device_id.as_deref(), Some("device-001"));
    }

    #[test]
    fn test_trace_context_builder() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus")
            .with_operation("read")
            .with_field("unit_id", "1");

        assert_eq!(ctx.device_id, Some("device-001".to_string()));
        assert_eq!(ctx.protocol, Some("modbus".to_string()));
        assert_eq!(ctx.operation, Some("read".to_string()));
        assert_eq!(ctx.fields.get("unit_id"), Some(&"1".to_string()));
    }

    #[test]
    fn test_with_fields() {
        let ctx = TraceContext::new().with_fields(vec![
            ("unit_id".to_string(), "1".to_string()),
            ("register".to_string(), "40001".to_string()),
        ]);
        assert_eq!(ctx.fields.len(), 2);
        assert_eq!(
            ctx.fields.get("register").map(String::as_str),
            Some("40001")
        );
    }

    #[test]
    fn test_trace_context_child() {
        let parent = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123");

        let child = parent.child();

        assert_eq!(child.request_id, parent.request_id);
        assert_eq!(child.trace_id, parent.trace_id);
        assert_eq!(child.device_id, parent.device_id);
        assert_eq!(child.parent_span_id, parent.span_id);
        // A root context has no span ID, so the check above is only
        // `None == None`. The child must get its own span ID...
        assert!(child.span_id.is_some());
        assert_ne!(child.span_id, parent.span_id);

        // ...which becomes the parent of the next generation.
        let grandchild = child.child();
        assert_eq!(grandchild.parent_span_id, child.span_id);
        assert_ne!(grandchild.span_id, child.span_id);
        assert_eq!(grandchild.request_id, parent.request_id);
        assert_eq!(grandchild.trace_id, parent.trace_id);
    }

    #[test]
    fn test_trace_context_to_map() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus");

        let map = ctx.to_map();
        assert!(map.contains_key("request_id"));
        assert_eq!(map.get("device_id"), Some(&"device-001".to_string()));
        assert_eq!(map.get("protocol"), Some(&"modbus".to_string()));
    }

    #[test]
    fn test_trace_context_headers() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123");

        let headers = ctx.to_headers();
        assert!(headers.contains_key("x-request-id"));
        assert_eq!(headers.get("x-trace-id"), Some(&"trace-123".to_string()));
        assert_eq!(headers.get("x-device-id"), Some(&"device-001".to_string()));

        // Round-trip
        let parsed = TraceContext::from_headers(&headers);
        assert_eq!(parsed.request_id, ctx.request_id);
        assert_eq!(parsed.trace_id, ctx.trace_id);
        assert_eq!(parsed.device_id, ctx.device_id);
    }

    #[test]
    fn test_from_headers_correlation_id_fallback() {
        let mut headers = HashMap::new();
        headers.insert("x-correlation-id".to_string(), "corr-1".to_string());

        let ctx = TraceContext::from_headers(&headers);
        assert_eq!(ctx.request_id, "corr-1");
        assert!(ctx.trace_id.is_none());
        assert!(ctx.device_id.is_none());
    }

    #[test]
    fn test_from_headers_empty_generates_request_id() {
        let ctx = TraceContext::from_headers(&HashMap::new());
        assert!(!ctx.request_id.is_empty());
        assert!(ctx.span_id.is_none());
    }

    #[test]
    fn test_request_context() {
        let ctx = RequestContext::new()
            .device("device-001")
            .protocol("modbus")
            .operation("read")
            .with_timeout(std::time::Duration::from_secs(5));

        assert!(!ctx.request_id().is_empty());
        assert!(!ctx.is_timed_out());
        assert!(ctx.remaining_timeout().is_some());
    }

    #[test]
    fn test_request_context_without_timeout() {
        let ctx = RequestContext::default();
        assert!(!ctx.is_timed_out());
        assert!(ctx.remaining_timeout().is_none());
        assert!(!ctx.debug_request);
    }

    #[test]
    fn test_device_context() {
        let ctx = DeviceContext::new("device-001", "modbus");

        assert_eq!(ctx.device_id, "device-001");
        assert_eq!(ctx.protocol, "modbus");
        assert!(!ctx.request_id().is_empty());
        assert_eq!(ctx.trace.device_id.as_deref(), Some("device-001"));
        assert_eq!(ctx.trace.protocol.as_deref(), Some("modbus"));

        let child = ctx.child();
        assert_eq!(child.request_id(), ctx.request_id());
    }

    #[test]
    fn test_trace_context_age() {
        let ctx = TraceContext::new();
        std::thread::sleep(std::time::Duration::from_millis(10));

        assert!(ctx.age_ms() >= 10);
        assert!(ctx.is_older_than_ms(5));
        assert!(!ctx.is_older_than_ms(1000));
    }

    #[test]
    fn test_shared_context() {
        let ctx = TraceContext::new().with_device_id("device-001");
        let shared = shared_context(ctx);

        assert_eq!(shared.device_id, Some("device-001".to_string()));
    }
}