1use std::collections::HashMap;
33use std::sync::Arc;
34
35use serde::{Deserialize, Serialize};
36use tracing::{span, Level, Span};
37use uuid::Uuid;
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TraceContext {
45 pub request_id: String,
47
48 #[serde(default)]
50 pub trace_id: Option<String>,
51
52 #[serde(default)]
54 pub span_id: Option<String>,
55
56 #[serde(default)]
58 pub parent_span_id: Option<String>,
59
60 #[serde(default)]
62 pub device_id: Option<String>,
63
64 #[serde(default)]
66 pub protocol: Option<String>,
67
68 #[serde(default)]
70 pub operation: Option<String>,
71
72 #[serde(default)]
74 pub fields: HashMap<String, String>,
75
76 #[serde(default = "default_timestamp")]
78 pub created_at: u64,
79}
80
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. The
/// `u128` millisecond count is clamped to `u64::MAX` instead of being
/// silently truncated by an `as` cast.
fn default_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| {
            let millis = elapsed.as_millis().min(u128::from(u64::MAX));
            // Lossless: `millis` was clamped to the `u64` range just above.
            millis as u64
        })
}
87
88impl Default for TraceContext {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl TraceContext {
95 pub fn new() -> Self {
97 Self {
98 request_id: Uuid::new_v4().to_string(),
99 trace_id: None,
100 span_id: None,
101 parent_span_id: None,
102 device_id: None,
103 protocol: None,
104 operation: None,
105 fields: HashMap::new(),
106 created_at: default_timestamp(),
107 }
108 }
109
110 pub fn with_request_id(request_id: impl Into<String>) -> Self {
112 Self {
113 request_id: request_id.into(),
114 ..Self::new()
115 }
116 }
117
118 pub fn child(&self) -> Self {
120 Self {
121 request_id: self.request_id.clone(),
122 trace_id: self.trace_id.clone(),
123 span_id: Some(Uuid::new_v4().to_string()),
124 parent_span_id: self.span_id.clone(),
125 device_id: self.device_id.clone(),
126 protocol: self.protocol.clone(),
127 operation: None,
128 fields: self.fields.clone(),
129 created_at: default_timestamp(),
130 }
131 }
132
133 pub fn with_device_id(mut self, device_id: impl Into<String>) -> Self {
135 self.device_id = Some(device_id.into());
136 self
137 }
138
139 pub fn with_protocol(mut self, protocol: impl Into<String>) -> Self {
141 self.protocol = Some(protocol.into());
142 self
143 }
144
145 pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
147 self.operation = Some(operation.into());
148 self
149 }
150
151 pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
153 self.trace_id = Some(trace_id.into());
154 self
155 }
156
157 pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
159 self.fields.insert(key.into(), value.into());
160 self
161 }
162
163 pub fn with_fields(mut self, fields: impl IntoIterator<Item = (String, String)>) -> Self {
165 self.fields.extend(fields);
166 self
167 }
168
169 pub fn create_span(&self, name: &'static str) -> Span {
171 let span = span!(
172 Level::INFO,
173 "request",
174 request_id = %self.request_id,
175 operation = name,
176 );
177
178 if let Some(ref device_id) = self.device_id {
180 span.record("device_id", device_id.as_str());
181 }
182 if let Some(ref protocol) = self.protocol {
183 span.record("protocol", protocol.as_str());
184 }
185 if let Some(ref trace_id) = self.trace_id {
186 span.record("trace_id", trace_id.as_str());
187 }
188
189 span
190 }
191
192 pub fn create_debug_span(&self, name: &'static str) -> Span {
194 span!(
195 Level::DEBUG,
196 "operation",
197 request_id = %self.request_id,
198 operation = name,
199 device_id = self.device_id.as_deref().unwrap_or(""),
200 )
201 }
202
203 pub fn age_ms(&self) -> u64 {
205 default_timestamp().saturating_sub(self.created_at)
206 }
207
208 pub fn is_older_than_ms(&self, ms: u64) -> bool {
210 self.age_ms() > ms
211 }
212
213 pub fn to_map(&self) -> HashMap<String, String> {
215 let mut map = HashMap::new();
216 map.insert("request_id".to_string(), self.request_id.clone());
217
218 if let Some(ref trace_id) = self.trace_id {
219 map.insert("trace_id".to_string(), trace_id.clone());
220 }
221 if let Some(ref device_id) = self.device_id {
222 map.insert("device_id".to_string(), device_id.clone());
223 }
224 if let Some(ref protocol) = self.protocol {
225 map.insert("protocol".to_string(), protocol.clone());
226 }
227 if let Some(ref operation) = self.operation {
228 map.insert("operation".to_string(), operation.clone());
229 }
230
231 map.extend(self.fields.clone());
232 map
233 }
234
235 pub fn from_headers(headers: &HashMap<String, String>) -> Self {
237 let mut ctx = Self::new();
238
239 if let Some(request_id) = headers.get("x-request-id").or(headers.get("x-correlation-id")) {
240 ctx.request_id = request_id.clone();
241 }
242 if let Some(trace_id) = headers.get("x-trace-id").or(headers.get("traceparent")) {
243 ctx.trace_id = Some(trace_id.clone());
244 }
245 if let Some(span_id) = headers.get("x-span-id") {
246 ctx.span_id = Some(span_id.clone());
247 }
248 if let Some(device_id) = headers.get("x-device-id") {
249 ctx.device_id = Some(device_id.clone());
250 }
251
252 ctx
253 }
254
255 pub fn to_headers(&self) -> HashMap<String, String> {
257 let mut headers = HashMap::new();
258
259 headers.insert("x-request-id".to_string(), self.request_id.clone());
260
261 if let Some(ref trace_id) = self.trace_id {
262 headers.insert("x-trace-id".to_string(), trace_id.clone());
263 }
264 if let Some(ref span_id) = self.span_id {
265 headers.insert("x-span-id".to_string(), span_id.clone());
266 }
267 if let Some(ref device_id) = self.device_id {
268 headers.insert("x-device-id".to_string(), device_id.clone());
269 }
270
271 headers
272 }
273}
274
275#[derive(Debug, Clone)]
279pub struct RequestContext {
280 pub trace: TraceContext,
282
283 pub start_time: std::time::Instant,
285
286 pub timeout: Option<std::time::Duration>,
288
289 pub debug_request: bool,
291}
292
293impl RequestContext {
294 pub fn new() -> Self {
296 Self {
297 trace: TraceContext::new(),
298 start_time: std::time::Instant::now(),
299 timeout: None,
300 debug_request: false,
301 }
302 }
303
304 pub fn with_trace(trace: TraceContext) -> Self {
306 Self {
307 trace,
308 start_time: std::time::Instant::now(),
309 timeout: None,
310 debug_request: false,
311 }
312 }
313
314 pub fn device(mut self, device_id: impl Into<String>) -> Self {
316 self.trace = self.trace.with_device_id(device_id);
317 self
318 }
319
320 pub fn protocol(mut self, protocol: impl Into<String>) -> Self {
322 self.trace = self.trace.with_protocol(protocol);
323 self
324 }
325
326 pub fn operation(mut self, operation: impl Into<String>) -> Self {
328 self.trace = self.trace.with_operation(operation);
329 self
330 }
331
332 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
334 self.timeout = Some(timeout);
335 self
336 }
337
338 pub fn debug(mut self) -> Self {
340 self.debug_request = true;
341 self
342 }
343
344 pub fn elapsed(&self) -> std::time::Duration {
346 self.start_time.elapsed()
347 }
348
349 pub fn is_timed_out(&self) -> bool {
351 self.timeout
352 .map(|t| self.elapsed() > t)
353 .unwrap_or(false)
354 }
355
356 pub fn remaining_timeout(&self) -> Option<std::time::Duration> {
358 self.timeout.and_then(|t| t.checked_sub(self.elapsed()))
359 }
360
361 pub fn request_id(&self) -> &str {
363 &self.trace.request_id
364 }
365
366 pub fn span(&self, name: &'static str) -> Span {
368 if self.debug_request {
369 self.trace.create_debug_span(name)
370 } else {
371 self.trace.create_span(name)
372 }
373 }
374}
375
376impl Default for RequestContext {
377 fn default() -> Self {
378 Self::new()
379 }
380}
381
382pub type SharedTraceContext = Arc<TraceContext>;
384
385pub fn shared_context(ctx: TraceContext) -> SharedTraceContext {
387 Arc::new(ctx)
388}
389
390#[derive(Debug, Clone)]
392pub struct DeviceContext {
393 pub device_id: String,
395
396 pub protocol: String,
398
399 pub trace: TraceContext,
401}
402
403impl DeviceContext {
404 pub fn new(device_id: impl Into<String>, protocol: impl Into<String>) -> Self {
406 let device_id = device_id.into();
407 let protocol = protocol.into();
408
409 Self {
410 device_id: device_id.clone(),
411 protocol: protocol.clone(),
412 trace: TraceContext::new()
413 .with_device_id(device_id)
414 .with_protocol(protocol),
415 }
416 }
417
418 pub fn with_trace(
420 device_id: impl Into<String>,
421 protocol: impl Into<String>,
422 trace: TraceContext,
423 ) -> Self {
424 let device_id = device_id.into();
425 let protocol = protocol.into();
426
427 Self {
428 device_id: device_id.clone(),
429 protocol: protocol.clone(),
430 trace: trace.with_device_id(device_id).with_protocol(protocol),
431 }
432 }
433
434 pub fn span(&self, operation: &'static str) -> Span {
436 span!(
437 Level::DEBUG,
438 "device_operation",
439 device_id = %self.device_id,
440 protocol = %self.protocol,
441 operation = operation,
442 request_id = %self.trace.request_id,
443 )
444 }
445
446 pub fn request_id(&self) -> &str {
448 &self.trace.request_id
449 }
450
451 pub fn child(&self) -> Self {
453 Self {
454 device_id: self.device_id.clone(),
455 protocol: self.protocol.clone(),
456 trace: self.trace.child(),
457 }
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trace_context_creation() {
        let ctx = TraceContext::new();
        assert!(!ctx.request_id.is_empty());
        assert!(ctx.trace_id.is_none());
        assert!(ctx.span_id.is_none());
        assert!(ctx.device_id.is_none());
        assert!(ctx.protocol.is_none());
        assert!(ctx.fields.is_empty());
    }

    #[test]
    fn test_trace_context_with_request_id() {
        let ctx = TraceContext::with_request_id("req-42");
        assert_eq!(ctx.request_id, "req-42");
        assert!(ctx.device_id.is_none());
    }

    #[test]
    fn test_trace_context_builder() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus")
            .with_operation("read")
            .with_field("unit_id", "1")
            .with_fields(vec![("slave".to_string(), "7".to_string())]);

        assert_eq!(ctx.device_id, Some("device-001".to_string()));
        assert_eq!(ctx.protocol, Some("modbus".to_string()));
        assert_eq!(ctx.operation, Some("read".to_string()));
        assert_eq!(ctx.fields.get("unit_id"), Some(&"1".to_string()));
        assert_eq!(ctx.fields.get("slave"), Some(&"7".to_string()));
    }

    #[test]
    fn test_trace_context_child() {
        let parent = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123")
            .with_operation("read");

        let child = parent.child();

        assert_eq!(child.request_id, parent.request_id);
        assert_eq!(child.trace_id, parent.trace_id);
        assert_eq!(child.device_id, parent.device_id);
        assert_eq!(child.parent_span_id, parent.span_id);
        assert!(child.span_id.is_some());
        assert_ne!(child.span_id, parent.span_id);
        assert!(child.operation.is_none());

        // The parent link must point at a real span id, not just None == None.
        let grandchild = child.child();
        assert!(grandchild.parent_span_id.is_some());
        assert_eq!(grandchild.parent_span_id, child.span_id);
        assert_ne!(grandchild.span_id, child.span_id);
    }

    #[test]
    fn test_trace_context_to_map() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus")
            .with_field("unit_id", "1");

        let map = ctx.to_map();
        assert_eq!(map.get("request_id"), Some(&ctx.request_id));
        assert_eq!(map.get("device_id"), Some(&"device-001".to_string()));
        assert_eq!(map.get("protocol"), Some(&"modbus".to_string()));
        assert_eq!(map.get("unit_id"), Some(&"1".to_string()));
        assert!(!map.contains_key("operation"));
    }

    #[test]
    fn test_trace_context_headers() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123");

        let headers = ctx.to_headers();
        assert!(headers.contains_key("x-request-id"));
        assert_eq!(headers.get("x-trace-id"), Some(&"trace-123".to_string()));
        assert_eq!(headers.get("x-device-id"), Some(&"device-001".to_string()));

        let parsed = TraceContext::from_headers(&headers);
        assert_eq!(parsed.request_id, ctx.request_id);
        assert_eq!(parsed.trace_id, ctx.trace_id);
        assert_eq!(parsed.device_id, ctx.device_id);
    }

    #[test]
    fn test_trace_context_headers_span_round_trip() {
        let ctx = TraceContext::new().child();
        assert!(ctx.span_id.is_some());

        let parsed = TraceContext::from_headers(&ctx.to_headers());
        assert_eq!(parsed.span_id, ctx.span_id);
    }

    #[test]
    fn test_request_context() {
        let ctx = RequestContext::new()
            .device("device-001")
            .protocol("modbus")
            .operation("read")
            .with_timeout(std::time::Duration::from_secs(5));

        assert!(!ctx.request_id().is_empty());
        assert!(!ctx.is_timed_out());
        assert!(ctx.remaining_timeout().is_some());
        assert_eq!(ctx.trace.device_id.as_deref(), Some("device-001"));
        assert_eq!(ctx.trace.protocol.as_deref(), Some("modbus"));
        assert_eq!(ctx.trace.operation.as_deref(), Some("read"));
    }

    #[test]
    fn test_request_context_default() {
        let ctx = RequestContext::default();
        assert!(ctx.timeout.is_none());
        assert!(!ctx.is_timed_out());
        assert!(ctx.remaining_timeout().is_none());
        assert!(!ctx.debug_request);
        assert!(ctx.debug().debug_request);
    }

    #[test]
    fn test_device_context() {
        let ctx = DeviceContext::new("device-001", "modbus");

        assert_eq!(ctx.device_id, "device-001");
        assert_eq!(ctx.protocol, "modbus");
        assert!(!ctx.request_id().is_empty());
        assert_eq!(ctx.trace.device_id.as_deref(), Some("device-001"));
        assert_eq!(ctx.trace.protocol.as_deref(), Some("modbus"));

        let child = ctx.child();
        assert_eq!(child.request_id(), ctx.request_id());
        assert_eq!(child.device_id, ctx.device_id);
        assert_eq!(child.protocol, ctx.protocol);
        assert!(child.trace.span_id.is_some());
    }

    #[test]
    fn test_trace_context_age() {
        let ctx = TraceContext::new();
        std::thread::sleep(std::time::Duration::from_millis(10));

        assert!(ctx.age_ms() >= 10);
        assert!(ctx.is_older_than_ms(5));
        assert!(!ctx.is_older_than_ms(1000));
    }

    #[test]
    fn test_shared_context() {
        let ctx = TraceContext::new().with_device_id("device-001");
        let shared = shared_context(ctx);

        assert_eq!(shared.device_id, Some("device-001".to_string()));
    }
}