use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tracing::{span, Level, Span};
use uuid::Uuid;
/// Correlation metadata attached to a request or operation so that related log
/// lines and spans can be tied together.
///
/// Serializable with serde. Every field except `request_id` has a serde default,
/// so only `request_id` is required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceContext {
/// Request correlation id. `TraceContext::new` fills it with a random UUID v4,
/// and it is propagated unchanged to child contexts.
pub request_id: String,
/// Optional distributed-trace id, copied to child contexts.
#[serde(default)]
pub trace_id: Option<String>,
/// Id of this span. `TraceContext::new` leaves it `None`; `child()` assigns a
/// fresh UUID v4.
#[serde(default)]
pub span_id: Option<String>,
/// `span_id` of the context this one was derived from via `child()`.
#[serde(default)]
pub parent_span_id: Option<String>,
/// Optional id of the device this work concerns, copied to child contexts.
#[serde(default)]
pub device_id: Option<String>,
/// Optional protocol name (free-form string), copied to child contexts.
#[serde(default)]
pub protocol: Option<String>,
/// Optional operation name. Not inherited by child contexts.
#[serde(default)]
pub operation: Option<String>,
/// Arbitrary extra key/value pairs, copied to child contexts and included in
/// `to_map()`.
#[serde(default)]
pub fields: HashMap<String, String>,
/// Creation time in milliseconds since the UNIX epoch. When absent from
/// serialized input, it defaults to the time of deserialization.
#[serde(default = "default_timestamp")]
pub created_at: u64,
}
/// Returns the current wall-clock time in milliseconds since the UNIX epoch.
///
/// Used as the serde default for `TraceContext::created_at` and whenever a
/// context is created. Returns `0` if the system clock reports a time before
/// the epoch. The `u128` millisecond count saturates at `u64::MAX` instead of
/// being silently truncated by a narrowing cast.
fn default_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // Clamp before narrowing so the `as` cast below can never drop bits.
        .map(|d| d.as_millis().min(u128::from(u64::MAX)) as u64)
        .unwrap_or(0)
}
impl Default for TraceContext {
fn default() -> Self {
Self::new()
}
}
impl TraceContext {
    /// Creates a root context with a freshly generated UUID v4 `request_id`.
    ///
    /// All optional metadata (`trace_id`, span ids, `device_id`, `protocol`,
    /// `operation`) starts as `None`, `fields` is empty, and `created_at` is the
    /// current wall-clock time in milliseconds since the UNIX epoch.
    pub fn new() -> Self {
        Self {
            request_id: Uuid::new_v4().to_string(),
            trace_id: None,
            span_id: None,
            parent_span_id: None,
            device_id: None,
            protocol: None,
            operation: None,
            fields: HashMap::new(),
            created_at: default_timestamp(),
        }
    }

    /// Creates a root context like [`TraceContext::new`], but with the given
    /// `request_id` instead of a generated one.
    pub fn with_request_id(request_id: impl Into<String>) -> Self {
        Self {
            request_id: request_id.into(),
            ..Self::new()
        }
    }

    /// Derives a context for a nested unit of work.
    ///
    /// The child inherits `request_id`, `trace_id`, `device_id`, `protocol`
    /// and a copy of `fields`. It gets a new UUID v4 `span_id` and takes this
    /// context's `span_id` (possibly `None`) as its `parent_span_id`. It starts
    /// with no `operation`, and its `created_at` is reset to the current time.
    pub fn child(&self) -> Self {
        Self {
            request_id: self.request_id.clone(),
            trace_id: self.trace_id.clone(),
            span_id: Some(Uuid::new_v4().to_string()),
            parent_span_id: self.span_id.clone(),
            device_id: self.device_id.clone(),
            protocol: self.protocol.clone(),
            operation: None,
            fields: self.fields.clone(),
            created_at: default_timestamp(),
        }
    }

    /// Sets `device_id` (builder style).
    pub fn with_device_id(mut self, device_id: impl Into<String>) -> Self {
        self.device_id = Some(device_id.into());
        self
    }

    /// Sets `protocol` (builder style).
    pub fn with_protocol(mut self, protocol: impl Into<String>) -> Self {
        self.protocol = Some(protocol.into());
        self
    }

    /// Sets `operation` (builder style).
    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
        self.operation = Some(operation.into());
        self
    }

    /// Sets `trace_id` (builder style).
    pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
        self.trace_id = Some(trace_id.into());
        self
    }

    /// Inserts one custom key/value pair into `fields`, replacing any previous
    /// value for that key.
    pub fn with_field(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.fields.insert(key.into(), value.into());
        self
    }

    /// Adds every pair from `fields`. On key collisions the incoming pair wins.
    pub fn with_fields(mut self, fields: impl IntoIterator<Item = (String, String)>) -> Self {
        self.fields.extend(fields);
        self
    }

    /// Creates an `INFO`-level span named `"request"`.
    ///
    /// The span always carries `request_id` and `operation = name`. It also
    /// carries `device_id`, `protocol` and `trace_id` when those are set on
    /// this context.
    pub fn create_span(&self, name: &'static str) -> Span {
        // `Span::record` silently ignores fields that were not declared when the
        // span was created, so the optional fields are declared up front as
        // `Empty` placeholders and filled in afterwards.
        let span = span!(
            Level::INFO,
            "request",
            request_id = %self.request_id,
            operation = name,
            device_id = tracing::field::Empty,
            protocol = tracing::field::Empty,
            trace_id = tracing::field::Empty,
        );
        self.record_optional_fields(&span);
        span
    }

    /// Creates a `DEBUG`-level span named `"operation"` with the same fields
    /// as [`TraceContext::create_span`].
    ///
    /// Unset optional fields are left unrecorded rather than logged as `""`.
    pub fn create_debug_span(&self, name: &'static str) -> Span {
        let span = span!(
            Level::DEBUG,
            "operation",
            request_id = %self.request_id,
            operation = name,
            device_id = tracing::field::Empty,
            protocol = tracing::field::Empty,
            trace_id = tracing::field::Empty,
        );
        self.record_optional_fields(&span);
        span
    }

    /// Records whichever of `device_id`, `protocol` and `trace_id` are set.
    ///
    /// `span` must have declared these fields, otherwise the records are
    /// dropped.
    fn record_optional_fields(&self, span: &Span) {
        if let Some(device_id) = self.device_id.as_deref() {
            span.record("device_id", device_id);
        }
        if let Some(protocol) = self.protocol.as_deref() {
            span.record("protocol", protocol);
        }
        if let Some(trace_id) = self.trace_id.as_deref() {
            span.record("trace_id", trace_id);
        }
    }

    /// Milliseconds of wall-clock time since `created_at`.
    ///
    /// Returns `0` if the clock moved backwards or `created_at` lies in the
    /// future (for example, after deserialization from another host).
    pub fn age_ms(&self) -> u64 {
        default_timestamp().saturating_sub(self.created_at)
    }

    /// Returns `true` when [`TraceContext::age_ms`] is strictly greater than `ms`.
    pub fn is_older_than_ms(&self, ms: u64) -> bool {
        self.age_ms() > ms
    }

    /// Flattens the context into string key/value pairs, e.g. for structured
    /// logging.
    ///
    /// The map always contains `request_id`. It contains `trace_id`,
    /// `device_id`, `protocol` and `operation` when set, plus all custom
    /// `fields`. Span ids are not included. When a custom field collides with
    /// one of the built-in keys, the built-in value wins, so an arbitrary field
    /// cannot mask e.g. the real `request_id`.
    pub fn to_map(&self) -> HashMap<String, String> {
        // Start from the custom fields so the built-in keys inserted below
        // overwrite any colliding custom entry.
        let mut map = self.fields.clone();
        map.insert("request_id".to_string(), self.request_id.clone());
        if let Some(ref trace_id) = self.trace_id {
            map.insert("trace_id".to_string(), trace_id.clone());
        }
        if let Some(ref device_id) = self.device_id {
            map.insert("device_id".to_string(), device_id.clone());
        }
        if let Some(ref protocol) = self.protocol {
            map.insert("protocol".to_string(), protocol.clone());
        }
        if let Some(ref operation) = self.operation {
            map.insert("operation".to_string(), operation.clone());
        }
        map
    }

    /// Builds a context from incoming propagation headers.
    ///
    /// Starts from [`TraceContext::new`] and overrides:
    /// - `request_id` from `x-request-id`, falling back to `x-correlation-id`;
    /// - `trace_id` from `x-trace-id`, falling back to the trace-id segment of
    ///   a W3C `traceparent` header. The raw value is kept if it does not parse.
    /// - `span_id` from `x-span-id`;
    /// - `device_id` from `x-device-id`.
    ///
    /// HTTP header names are case-insensitive, so names are matched ignoring
    /// ASCII case, with an exact-case key preferred. Values are trimmed. Blank
    /// values count as absent, so they cannot clobber the generated
    /// `request_id`.
    pub fn from_headers(headers: &HashMap<String, String>) -> Self {
        let mut ctx = Self::new();
        if let Some(request_id) = Self::header_value(headers, "x-request-id")
            .or_else(|| Self::header_value(headers, "x-correlation-id"))
        {
            ctx.request_id = request_id.to_string();
        }
        let trace_id = Self::header_value(headers, "x-trace-id").or_else(|| {
            Self::header_value(headers, "traceparent")
                .map(|raw| Self::traceparent_trace_id(raw).unwrap_or(raw))
        });
        if let Some(trace_id) = trace_id {
            ctx.trace_id = Some(trace_id.to_string());
        }
        if let Some(span_id) = Self::header_value(headers, "x-span-id") {
            ctx.span_id = Some(span_id.to_string());
        }
        if let Some(device_id) = Self::header_value(headers, "x-device-id") {
            ctx.device_id = Some(device_id.to_string());
        }
        ctx
    }

    /// Looks up header `name`, first by exact key and then ignoring ASCII case.
    ///
    /// Returns the trimmed value, or `None` if the header is missing or blank.
    fn header_value<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
        headers
            .get(name)
            .or_else(|| {
                headers
                    .iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
            .map(|value| value.trim())
            .filter(|value| !value.is_empty())
    }

    /// Extracts the 32-hex-digit trace-id from a W3C `traceparent` value of
    /// the form `version-traceid-parentid-flags`.
    ///
    /// Returns `None` if the value lacks that shape or the trace-id is all
    /// zeros, which the spec defines as invalid. Any segments after `flags`
    /// (possible in future versions) are ignored.
    fn traceparent_trace_id(traceparent: &str) -> Option<&str> {
        let is_hex = |s: &str, len: usize| s.len() == len && s.bytes().all(|b| b.is_ascii_hexdigit());
        let mut parts = traceparent.split('-');
        let version = parts.next()?;
        let trace_id = parts.next()?;
        let parent_id = parts.next()?;
        let flags = parts.next()?;
        let valid = is_hex(version, 2)
            && is_hex(trace_id, 32)
            && is_hex(parent_id, 16)
            && is_hex(flags, 2)
            && trace_id.bytes().any(|b| b != b'0');
        if valid {
            Some(trace_id)
        } else {
            None
        }
    }

    /// Serializes the context into outgoing propagation headers.
    ///
    /// Always emits `x-request-id`. Emits `x-trace-id`, `x-span-id` and
    /// `x-device-id` when those values are set. Protocol, operation,
    /// `parent_span_id` and custom fields are not propagated.
    pub fn to_headers(&self) -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert("x-request-id".to_string(), self.request_id.clone());
        if let Some(ref trace_id) = self.trace_id {
            headers.insert("x-trace-id".to_string(), trace_id.clone());
        }
        if let Some(ref span_id) = self.span_id {
            headers.insert("x-span-id".to_string(), span_id.clone());
        }
        if let Some(ref device_id) = self.device_id {
            headers.insert("x-device-id".to_string(), device_id.clone());
        }
        headers
    }
}
/// Per-request state: the trace metadata plus a start time, an optional
/// timeout budget, and a flag selecting debug-level spans.
#[derive(Debug, Clone)]
pub struct RequestContext {
/// Correlation metadata for this request.
pub trace: TraceContext,
/// Monotonic instant captured when the context was constructed; basis for `elapsed()`.
pub start_time: std::time::Instant,
/// Optional time budget measured from `start_time`; checked by
/// `is_timed_out()` and `remaining_timeout()`.
pub timeout: Option<std::time::Duration>,
/// When `true`, `span()` creates the trace's debug span instead of the INFO span.
pub debug_request: bool,
}
impl RequestContext {
    /// Starts a request with a brand-new root [`TraceContext`], the clock
    /// started now, no timeout, and INFO-level spans.
    pub fn new() -> Self {
        Self::with_trace(TraceContext::new())
    }

    /// Starts a request around an existing trace context. The clock starts
    /// now, with no timeout and INFO-level spans.
    pub fn with_trace(trace: TraceContext) -> Self {
        let start_time = std::time::Instant::now();
        Self {
            trace,
            start_time,
            timeout: None,
            debug_request: false,
        }
    }

    /// Applies `update` to the owned trace context and returns `self`.
    fn map_trace(mut self, update: impl FnOnce(TraceContext) -> TraceContext) -> Self {
        self.trace = update(self.trace);
        self
    }

    /// Sets the trace's `device_id` (builder style).
    pub fn device(self, device_id: impl Into<String>) -> Self {
        self.map_trace(|trace| trace.with_device_id(device_id))
    }

    /// Sets the trace's `protocol` (builder style).
    pub fn protocol(self, protocol: impl Into<String>) -> Self {
        self.map_trace(|trace| trace.with_protocol(protocol))
    }

    /// Sets the trace's `operation` (builder style).
    pub fn operation(self, operation: impl Into<String>) -> Self {
        self.map_trace(|trace| trace.with_operation(operation))
    }

    /// Sets the time budget, measured from `start_time` (builder style).
    pub fn with_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Switches `span()` to debug-level spans (builder style).
    pub fn debug(self) -> Self {
        Self {
            debug_request: true,
            ..self
        }
    }

    /// Monotonic time elapsed since `start_time`.
    pub fn elapsed(&self) -> std::time::Duration {
        self.start_time.elapsed()
    }

    /// Returns `true` only when a timeout is set and the elapsed time strictly
    /// exceeds it.
    pub fn is_timed_out(&self) -> bool {
        matches!(self.timeout, Some(limit) if self.elapsed() > limit)
    }

    /// Time left in the budget.
    ///
    /// Returns `None` both when no timeout is set and when the elapsed time
    /// already exceeds it.
    pub fn remaining_timeout(&self) -> Option<std::time::Duration> {
        let limit = self.timeout?;
        limit.checked_sub(self.elapsed())
    }

    /// Borrowed view of the trace's request id.
    pub fn request_id(&self) -> &str {
        self.trace.request_id.as_str()
    }

    /// Creates a span for `name`, at debug level when `debug_request` is set
    /// and INFO level otherwise.
    pub fn span(&self, name: &'static str) -> Span {
        match self.debug_request {
            true => self.trace.create_debug_span(name),
            false => self.trace.create_span(name),
        }
    }
}
impl Default for RequestContext {
fn default() -> Self {
Self::new()
}
}
/// Reference-counted, immutable handle to a [`TraceContext`] that can be
/// cloned cheaply and shared across threads or tasks.
pub type SharedTraceContext = Arc<TraceContext>;

/// Moves `ctx` into a new [`SharedTraceContext`].
pub fn shared_context(ctx: TraceContext) -> SharedTraceContext {
    SharedTraceContext::new(ctx)
}
/// Trace context scoped to a single device and protocol.
///
/// The constructors also copy `device_id` and `protocol` into `trace`.
/// NOTE(review): all fields are public, so the copies can diverge if a field is
/// mutated after construction.
#[derive(Debug, Clone)]
pub struct DeviceContext {
/// Device identifier used in this context's spans.
pub device_id: String,
/// Protocol name used in this context's spans.
pub protocol: String,
/// Underlying trace metadata (carries the request id).
pub trace: TraceContext,
}
impl DeviceContext {
    /// Creates a device context backed by a brand-new root [`TraceContext`].
    pub fn new(device_id: impl Into<String>, protocol: impl Into<String>) -> Self {
        Self::with_trace(device_id, protocol, TraceContext::new())
    }

    /// Creates a device context around an existing trace.
    ///
    /// The trace's `device_id` and `protocol` are overwritten with the given
    /// values.
    pub fn with_trace(
        device_id: impl Into<String>,
        protocol: impl Into<String>,
        trace: TraceContext,
    ) -> Self {
        let device_id: String = device_id.into();
        let protocol: String = protocol.into();
        let trace = trace
            .with_device_id(device_id.as_str())
            .with_protocol(protocol.as_str());
        Self {
            device_id,
            protocol,
            trace,
        }
    }

    /// Creates a `DEBUG`-level `"device_operation"` span tagged with the
    /// device, protocol, `operation` and request id.
    pub fn span(&self, operation: &'static str) -> Span {
        let request_id = &self.trace.request_id;
        span!(
            Level::DEBUG,
            "device_operation",
            device_id = %self.device_id,
            protocol = %self.protocol,
            operation = operation,
            request_id = %request_id,
        )
    }

    /// Borrowed view of the trace's request id.
    pub fn request_id(&self) -> &str {
        self.trace.request_id.as_str()
    }

    /// Derives a context for nested work on the same device.
    ///
    /// The trace comes from [`TraceContext::child`], so it keeps the request id
    /// and gets a new span id.
    pub fn child(&self) -> Self {
        let trace = self.trace.child();
        Self {
            device_id: self.device_id.clone(),
            protocol: self.protocol.clone(),
            trace,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_trace_context_creation() {
        let ctx = TraceContext::new();
        assert!(!ctx.request_id.is_empty());
        assert!(ctx.trace_id.is_none());
        assert!(ctx.span_id.is_none());
        assert!(ctx.device_id.is_none());
        assert!(ctx.protocol.is_none());
        assert!(ctx.fields.is_empty());
        // Every root context gets its own random request id.
        assert_ne!(ctx.request_id, TraceContext::new().request_id);
    }

    #[test]
    fn test_with_request_id() {
        let ctx = TraceContext::with_request_id("req-42");
        assert_eq!(ctx.request_id, "req-42");
        assert!(ctx.trace_id.is_none());
        assert!(ctx.span_id.is_none());
    }

    #[test]
    fn test_trace_context_builder() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus")
            .with_operation("read")
            .with_field("unit_id", "1")
            .with_fields(vec![("register".to_string(), "40001".to_string())]);
        assert_eq!(ctx.device_id, Some("device-001".to_string()));
        assert_eq!(ctx.protocol, Some("modbus".to_string()));
        assert_eq!(ctx.operation, Some("read".to_string()));
        assert_eq!(ctx.fields.get("unit_id"), Some(&"1".to_string()));
        assert_eq!(ctx.fields.get("register"), Some(&"40001".to_string()));
    }

    #[test]
    fn test_trace_context_child() {
        let parent = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123")
            .with_operation("read")
            .with_field("unit_id", "1");
        let child = parent.child();
        assert_eq!(child.request_id, parent.request_id);
        assert_eq!(child.trace_id, parent.trace_id);
        assert_eq!(child.device_id, parent.device_id);
        assert_eq!(child.fields, parent.fields);
        assert_eq!(child.parent_span_id, parent.span_id);
        // A child always gets its own span id and starts without an operation.
        assert!(child.span_id.is_some());
        assert!(child.operation.is_none());
        // A root has no span id, so the parent link is only meaningful one
        // level further down.
        let grandchild = child.child();
        assert_eq!(grandchild.parent_span_id, child.span_id);
        assert_ne!(grandchild.span_id, child.span_id);
    }

    #[test]
    fn test_trace_context_to_map() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_protocol("modbus")
            .with_field("unit_id", "1");
        let map = ctx.to_map();
        assert_eq!(map.get("request_id"), Some(&ctx.request_id));
        assert_eq!(map.get("device_id"), Some(&"device-001".to_string()));
        assert_eq!(map.get("protocol"), Some(&"modbus".to_string()));
        assert_eq!(map.get("unit_id"), Some(&"1".to_string()));
        // Unset optional values are omitted rather than emitted empty.
        assert!(!map.contains_key("operation"));
        assert!(!map.contains_key("trace_id"));
    }

    #[test]
    fn test_trace_context_headers() {
        let ctx = TraceContext::new()
            .with_device_id("device-001")
            .with_trace_id("trace-123");
        let headers = ctx.to_headers();
        assert!(headers.contains_key("x-request-id"));
        assert_eq!(headers.get("x-trace-id"), Some(&"trace-123".to_string()));
        assert_eq!(headers.get("x-device-id"), Some(&"device-001".to_string()));
        let parsed = TraceContext::from_headers(&headers);
        assert_eq!(parsed.request_id, ctx.request_id);
        assert_eq!(parsed.trace_id, ctx.trace_id);
        assert_eq!(parsed.device_id, ctx.device_id);
    }

    #[test]
    fn test_trace_context_headers_span_id_round_trip() {
        let ctx = TraceContext::new().with_trace_id("trace-123").child();
        let headers = ctx.to_headers();
        assert_eq!(headers.get("x-span-id"), ctx.span_id.as_ref());
        let parsed = TraceContext::from_headers(&headers);
        assert_eq!(parsed.span_id, ctx.span_id);
    }

    #[test]
    fn test_from_headers_correlation_id_fallback() {
        let mut headers = HashMap::new();
        headers.insert("x-correlation-id".to_string(), "corr-1".to_string());
        let parsed = TraceContext::from_headers(&headers);
        assert_eq!(parsed.request_id, "corr-1");
        assert!(parsed.trace_id.is_none());
    }

    #[test]
    fn test_request_context() {
        let ctx = RequestContext::new()
            .device("device-001")
            .protocol("modbus")
            .operation("read")
            .with_timeout(Duration::from_secs(5));
        assert!(!ctx.request_id().is_empty());
        assert_eq!(ctx.trace.device_id.as_deref(), Some("device-001"));
        assert_eq!(ctx.trace.protocol.as_deref(), Some("modbus"));
        assert_eq!(ctx.trace.operation.as_deref(), Some("read"));
        assert!(!ctx.is_timed_out());
        let remaining = ctx.remaining_timeout().expect("5s budget cannot be spent yet");
        assert!(remaining <= Duration::from_secs(5));
    }

    #[test]
    fn test_request_context_without_timeout() {
        let ctx = RequestContext::new();
        assert!(!ctx.is_timed_out());
        assert!(ctx.remaining_timeout().is_none());
        assert!(!ctx.debug_request);
        assert!(ctx.debug().debug_request);
    }

    #[test]
    fn test_device_context() {
        let ctx = DeviceContext::new("device-001", "modbus");
        assert_eq!(ctx.device_id, "device-001");
        assert_eq!(ctx.protocol, "modbus");
        // The constructor mirrors device and protocol into the trace.
        assert_eq!(ctx.trace.device_id.as_deref(), Some("device-001"));
        assert_eq!(ctx.trace.protocol.as_deref(), Some("modbus"));
        assert!(!ctx.request_id().is_empty());
        let child = ctx.child();
        assert_eq!(child.request_id(), ctx.request_id());
        assert_eq!(child.device_id, ctx.device_id);
        assert!(child.trace.span_id.is_some());
    }

    #[test]
    fn test_trace_context_age() {
        let ctx = TraceContext::new();
        std::thread::sleep(Duration::from_millis(10));
        assert!(ctx.age_ms() >= 10);
        assert!(ctx.is_older_than_ms(5));
        assert!(!ctx.is_older_than_ms(1000));
    }

    #[test]
    fn test_shared_context() {
        let ctx = TraceContext::new().with_device_id("device-001");
        let shared = shared_context(ctx);
        let clone = Arc::clone(&shared);
        assert_eq!(shared.device_id, Some("device-001".to_string()));
        assert_eq!(clone.request_id, shared.request_id);
    }
}