1use std::collections::HashMap;
33use std::sync::Arc;
34
35use serde::{Deserialize, Serialize};
36use tracing::{span, Level, Span};
37use uuid::Uuid;
38
/// Correlation metadata that travels with a unit of work.
///
/// Carries a mandatory `request_id` plus optional trace/span/device
/// identifiers and arbitrary string `fields`. The struct is serde
/// (de)serializable; every field except `request_id` may be absent from the
/// input and is then filled with its default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceContext {
    /// Identifier shared by all contexts derived from the same request.
    /// Required when deserializing.
    pub request_id: String,

    /// Optional distributed-trace identifier, inherited by child contexts.
    #[serde(default)]
    pub trace_id: Option<String>,

    /// Identifier of this context's span; `child()` assigns a fresh UUID v4.
    #[serde(default)]
    pub span_id: Option<String>,

    /// `span_id` of the context this one was derived from via `child()`.
    #[serde(default)]
    pub parent_span_id: Option<String>,

    /// Optional identifier of the device this work concerns.
    #[serde(default)]
    pub device_id: Option<String>,

    /// Optional protocol name (the tests use values such as `"modbus"`).
    #[serde(default)]
    pub protocol: Option<String>,

    /// Optional name of the current operation; not inherited by `child()`.
    #[serde(default)]
    pub operation: Option<String>,

    /// Arbitrary extra key/value pairs, inherited by `child()` and included
    /// in `to_map()`.
    #[serde(default)]
    pub fields: HashMap<String, String>,

    /// Creation time in milliseconds since the Unix epoch. When missing from
    /// deserialized input, it defaults to the time of deserialization.
    #[serde(default = "default_timestamp")]
    pub created_at: u64,
}
80
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned
/// instead of failing.
fn default_timestamp() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
87
88impl Default for TraceContext {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl TraceContext {
95 pub fn new() -> Self {
97 Self {
98 request_id: Uuid::new_v4().to_string(),
99 trace_id: None,
100 span_id: None,
101 parent_span_id: None,
102 device_id: None,
103 protocol: None,
104 operation: None,
105 fields: HashMap::new(),
106 created_at: default_timestamp(),
107 }
108 }
109
110 pub fn with_request_id(request_id: impl Into<String>) -> Self {
112 Self {
113 request_id: request_id.into(),
114 ..Self::new()
115 }
116 }
117
118 pub fn child(&self) -> Self {
120 Self {
121 request_id: self.request_id.clone(),
122 trace_id: self.trace_id.clone(),
123 span_id: Some(Uuid::new_v4().to_string()),
124 parent_span_id: self.span_id.clone(),
125 device_id: self.device_id.clone(),
126 protocol: self.protocol.clone(),
127 operation: None,
128 fields: self.fields.clone(),
129 created_at: default_timestamp(),
130 }
131 }
132
133 pub fn with_device_id(mut self, device_id: impl Into<String>) -> Self {
135 self.device_id = Some(device_id.into());
136 self
137 }
138
139 pub fn with_protocol(mut self, protocol: impl Into<String>) -> Self {
141 self.protocol = Some(protocol.into());
142 self
143 }
144
145 pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
147 self.operation = Some(operation.into());
148 self
149 }
150
151 pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
153 self.trace_id = Some(trace_id.into());
154 self
155 }
156
157 pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
159 self.fields.insert(key.into(), value.into());
160 self
161 }
162
163 pub fn with_fields(mut self, fields: impl IntoIterator<Item = (String, String)>) -> Self {
165 self.fields.extend(fields);
166 self
167 }
168
169 pub fn create_span(&self, name: &'static str) -> Span {
171 let span = span!(
172 Level::INFO,
173 "request",
174 request_id = %self.request_id,
175 operation = name,
176 );
177
178 if let Some(ref device_id) = self.device_id {
180 span.record("device_id", device_id.as_str());
181 }
182 if let Some(ref protocol) = self.protocol {
183 span.record("protocol", protocol.as_str());
184 }
185 if let Some(ref trace_id) = self.trace_id {
186 span.record("trace_id", trace_id.as_str());
187 }
188
189 span
190 }
191
192 pub fn create_debug_span(&self, name: &'static str) -> Span {
194 span!(
195 Level::DEBUG,
196 "operation",
197 request_id = %self.request_id,
198 operation = name,
199 device_id = self.device_id.as_deref().unwrap_or(""),
200 )
201 }
202
203 pub fn age_ms(&self) -> u64 {
205 default_timestamp().saturating_sub(self.created_at)
206 }
207
208 pub fn is_older_than_ms(&self, ms: u64) -> bool {
210 self.age_ms() > ms
211 }
212
213 pub fn to_map(&self) -> HashMap<String, String> {
215 let mut map = HashMap::new();
216 map.insert("request_id".to_string(), self.request_id.clone());
217
218 if let Some(ref trace_id) = self.trace_id {
219 map.insert("trace_id".to_string(), trace_id.clone());
220 }
221 if let Some(ref device_id) = self.device_id {
222 map.insert("device_id".to_string(), device_id.clone());
223 }
224 if let Some(ref protocol) = self.protocol {
225 map.insert("protocol".to_string(), protocol.clone());
226 }
227 if let Some(ref operation) = self.operation {
228 map.insert("operation".to_string(), operation.clone());
229 }
230
231 map.extend(self.fields.clone());
232 map
233 }
234
235 pub fn from_headers(headers: &HashMap<String, String>) -> Self {
237 let mut ctx = Self::new();
238
239 if let Some(request_id) = headers
240 .get("x-request-id")
241 .or(headers.get("x-correlation-id"))
242 {
243 ctx.request_id = request_id.clone();
244 }
245 if let Some(trace_id) = headers.get("x-trace-id").or(headers.get("traceparent")) {
246 ctx.trace_id = Some(trace_id.clone());
247 }
248 if let Some(span_id) = headers.get("x-span-id") {
249 ctx.span_id = Some(span_id.clone());
250 }
251 if let Some(device_id) = headers.get("x-device-id") {
252 ctx.device_id = Some(device_id.clone());
253 }
254
255 ctx
256 }
257
258 pub fn to_headers(&self) -> HashMap<String, String> {
260 let mut headers = HashMap::new();
261
262 headers.insert("x-request-id".to_string(), self.request_id.clone());
263
264 if let Some(ref trace_id) = self.trace_id {
265 headers.insert("x-trace-id".to_string(), trace_id.clone());
266 }
267 if let Some(ref span_id) = self.span_id {
268 headers.insert("x-span-id".to_string(), span_id.clone());
269 }
270 if let Some(ref device_id) = self.device_id {
271 headers.insert("x-device-id".to_string(), device_id.clone());
272 }
273
274 headers
275 }
276}
277
/// Per-request state: the trace context, a start instant for elapsed-time
/// tracking, an optional time budget and a debug flag.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Correlation metadata for this request.
    pub trace: TraceContext,

    /// Monotonic instant at which this context was created. `elapsed()`,
    /// `is_timed_out()` and `remaining_timeout()` measure from here.
    pub start_time: std::time::Instant,

    /// Optional time budget measured from `start_time`; `None` means no limit.
    pub timeout: Option<std::time::Duration>,

    /// When `true`, `span()` builds the `DEBUG`-level span from
    /// `TraceContext::create_debug_span` instead of the `INFO`-level one.
    pub debug_request: bool,
}
295
296impl RequestContext {
297 pub fn new() -> Self {
299 Self {
300 trace: TraceContext::new(),
301 start_time: std::time::Instant::now(),
302 timeout: None,
303 debug_request: false,
304 }
305 }
306
307 pub fn with_trace(trace: TraceContext) -> Self {
309 Self {
310 trace,
311 start_time: std::time::Instant::now(),
312 timeout: None,
313 debug_request: false,
314 }
315 }
316
317 pub fn device(mut self, device_id: impl Into<String>) -> Self {
319 self.trace = self.trace.with_device_id(device_id);
320 self
321 }
322
323 pub fn protocol(mut self, protocol: impl Into<String>) -> Self {
325 self.trace = self.trace.with_protocol(protocol);
326 self
327 }
328
329 pub fn operation(mut self, operation: impl Into<String>) -> Self {
331 self.trace = self.trace.with_operation(operation);
332 self
333 }
334
335 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
337 self.timeout = Some(timeout);
338 self
339 }
340
341 pub fn debug(mut self) -> Self {
343 self.debug_request = true;
344 self
345 }
346
347 pub fn elapsed(&self) -> std::time::Duration {
349 self.start_time.elapsed()
350 }
351
352 pub fn is_timed_out(&self) -> bool {
354 self.timeout.map(|t| self.elapsed() > t).unwrap_or(false)
355 }
356
357 pub fn remaining_timeout(&self) -> Option<std::time::Duration> {
359 self.timeout.and_then(|t| t.checked_sub(self.elapsed()))
360 }
361
362 pub fn request_id(&self) -> &str {
364 &self.trace.request_id
365 }
366
367 pub fn span(&self, name: &'static str) -> Span {
369 if self.debug_request {
370 self.trace.create_debug_span(name)
371 } else {
372 self.trace.create_span(name)
373 }
374 }
375}
376
377impl Default for RequestContext {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
383pub type SharedTraceContext = Arc<TraceContext>;
385
386pub fn shared_context(ctx: TraceContext) -> SharedTraceContext {
388 Arc::new(ctx)
389}
390
/// Context for work on one device. It holds the device id and protocol and a
/// [`TraceContext`] that its constructors populate with the same values.
#[derive(Debug, Clone)]
pub struct DeviceContext {
    /// Identifier of the device; also copied into `trace.device_id` by the
    /// constructors.
    pub device_id: String,

    /// Protocol name; also copied into `trace.protocol` by the constructors.
    pub protocol: String,

    /// Correlation metadata for the device operation.
    pub trace: TraceContext,
}
403
404impl DeviceContext {
405 pub fn new(device_id: impl Into<String>, protocol: impl Into<String>) -> Self {
407 let device_id = device_id.into();
408 let protocol = protocol.into();
409
410 Self {
411 device_id: device_id.clone(),
412 protocol: protocol.clone(),
413 trace: TraceContext::new()
414 .with_device_id(device_id)
415 .with_protocol(protocol),
416 }
417 }
418
419 pub fn with_trace(
421 device_id: impl Into<String>,
422 protocol: impl Into<String>,
423 trace: TraceContext,
424 ) -> Self {
425 let device_id = device_id.into();
426 let protocol = protocol.into();
427
428 Self {
429 device_id: device_id.clone(),
430 protocol: protocol.clone(),
431 trace: trace.with_device_id(device_id).with_protocol(protocol),
432 }
433 }
434
435 pub fn span(&self, operation: &'static str) -> Span {
437 span!(
438 Level::DEBUG,
439 "device_operation",
440 device_id = %self.device_id,
441 protocol = %self.protocol,
442 operation = operation,
443 request_id = %self.trace.request_id,
444 )
445 }
446
447 pub fn request_id(&self) -> &str {
449 &self.trace.request_id
450 }
451
452 pub fn child(&self) -> Self {
454 Self {
455 device_id: self.device_id.clone(),
456 protocol: self.protocol.clone(),
457 trace: self.trace.child(),
458 }
459 }
460}
461
#[cfg(test)]
mod tests {
    // Unit tests for the trace, request and device context helpers.
    use super::*;

    #[test]
    fn test_trace_context_creation() {
        let fresh = TraceContext::new();

        assert!(!fresh.request_id.is_empty());
        assert_eq!(fresh.device_id, None);
        assert_eq!(fresh.protocol, None);
    }

    #[test]
    fn test_trace_context_builder() {
        let built = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus")
            .with_operation("read")
            .with_field("unit_id", "1");

        assert_eq!(built.device_id.as_deref(), Some("device-001"));
        assert_eq!(built.protocol.as_deref(), Some("modbus"));
        assert_eq!(built.operation.as_deref(), Some("read"));
        assert_eq!(built.fields.get("unit_id").map(String::as_str), Some("1"));
    }

    #[test]
    fn test_trace_context_child() {
        let root = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123");

        let nested = root.child();

        assert_eq!(nested.request_id, root.request_id);
        assert_eq!(nested.trace_id, root.trace_id);
        assert_eq!(nested.device_id, root.device_id);
        assert_eq!(nested.parent_span_id, root.span_id);
    }

    #[test]
    fn test_trace_context_to_map() {
        let flat = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus")
            .to_map();

        assert!(flat.contains_key("request_id"));
        assert_eq!(flat.get("device_id").map(String::as_str), Some("device-001"));
        assert_eq!(flat.get("protocol").map(String::as_str), Some("modbus"));
    }

    #[test]
    fn test_trace_context_headers() {
        let outgoing = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123");

        let wire = outgoing.to_headers();
        assert!(wire.contains_key("x-request-id"));
        assert_eq!(wire.get("x-trace-id").map(String::as_str), Some("trace-123"));
        assert_eq!(wire.get("x-device-id").map(String::as_str), Some("device-001"));

        // Round trip: parsing our own headers restores the propagated ids.
        let incoming = TraceContext::from_headers(&wire);
        assert_eq!(incoming.request_id, outgoing.request_id);
        assert_eq!(incoming.trace_id, outgoing.trace_id);
        assert_eq!(incoming.device_id, outgoing.device_id);
    }

    #[test]
    fn test_request_context() {
        let request = RequestContext::new()
            .device("device-001")
            .protocol("modbus")
            .operation("read")
            .with_timeout(std::time::Duration::from_secs(5));

        assert!(!request.request_id().is_empty());
        assert!(!request.is_timed_out());
        assert!(request.remaining_timeout().is_some());
    }

    #[test]
    fn test_device_context() {
        let device = DeviceContext::new("device-001", "modbus");

        assert_eq!(device.device_id, "device-001");
        assert_eq!(device.protocol, "modbus");
        assert!(!device.request_id().is_empty());

        let derived = device.child();
        assert_eq!(derived.request_id(), device.request_id());
    }

    #[test]
    fn test_trace_context_age() {
        let aging = TraceContext::new();
        std::thread::sleep(std::time::Duration::from_millis(10));

        assert!(aging.age_ms() >= 10);
        assert!(aging.is_older_than_ms(5));
        assert!(!aging.is_older_than_ms(1000));
    }

    #[test]
    fn test_shared_context() {
        let handle = shared_context(TraceContext::new().with_device_id("device-001"));

        assert_eq!(handle.device_id.as_deref(), Some("device-001"));
    }
}