use crate::types::RequestId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use tokio::task_local;
// Task-local slot holding the request context that is active for the current
// task. It is installed by `RequestContext::run` (via `scope`) and read by
// `RequestContext::current` (via `try_with`). The declaration is compiled out
// on wasm32 targets, where those two methods take their fallback branches.
#[cfg(not(target_arch = "wasm32"))]
task_local! {
static REQUEST_CONTEXT: Arc<RequestContext>;
}
/// Per-request tracing and identity data that travels with a request.
///
/// The type is serializable with serde and can be converted to and from a
/// string header map with `to_headers` / `from_headers`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestContext {
/// Identifier of the request this context belongs to.
pub request_id: RequestId,
/// Trace identifier shared by a context and the children derived from it.
/// `new` fills it with a freshly generated v4 UUID string.
pub trace_id: String,
/// Span id of the context this one was derived from; `None` for a context
/// created with `new`.
pub parent_span_id: Option<String>,
/// Identifier of this context's own span (a v4 UUID string when generated
/// by `new` or `child`).
pub span_id: String,
/// UTC time at which this context value was created.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Optional user identifier; emitted as the `x-user-id` header.
pub user_id: Option<String>,
/// Optional session identifier; emitted as the `x-session-id` header.
pub session_id: Option<String>,
/// Optional description of the calling client. Not written to headers.
pub client_info: Option<ClientInfo>,
/// Free-form JSON values keyed by name. Not written to headers.
pub metadata: HashMap<String, serde_json::Value>,
/// String key/value pairs; each entry is emitted as a `baggage-<key>` header.
pub baggage: HashMap<String, String>,
}
/// Details about the client that issued a request, as stored in
/// `RequestContext::client_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientInfo {
/// Identifier of the client; the only mandatory field.
pub client_id: String,
/// Client version, if supplied (presumably a software version string;
/// confirm with callers).
pub version: Option<String>,
/// Client IP address in textual form, if supplied. The value is stored
/// as-is; nothing in this file validates it.
pub ip_address: Option<String>,
/// Client user-agent string, if supplied.
pub user_agent: Option<String>,
}
impl RequestContext {
    /// Creates a root context for `request_id`.
    ///
    /// `trace_id` and `span_id` are freshly generated v4 UUID strings,
    /// `timestamp` is the current UTC time, and every optional field and map
    /// starts out empty.
    pub fn new(request_id: RequestId) -> Self {
        Self {
            request_id,
            trace_id: uuid::Uuid::new_v4().to_string(),
            parent_span_id: None,
            span_id: uuid::Uuid::new_v4().to_string(),
            timestamp: chrono::Utc::now(),
            user_id: None,
            session_id: None,
            client_info: None,
            metadata: HashMap::new(),
            baggage: HashMap::new(),
        }
    }

    /// Derives a child context from `self`.
    ///
    /// The child keeps the request id, trace id, user/session/client data,
    /// metadata and baggage of `self`, records `self.span_id` as its
    /// `parent_span_id`, and gets a new random `span_id` and a fresh timestamp.
    pub fn child(&self) -> Self {
        Self {
            request_id: self.request_id.clone(),
            trace_id: self.trace_id.clone(),
            parent_span_id: Some(self.span_id.clone()),
            span_id: uuid::Uuid::new_v4().to_string(),
            timestamp: chrono::Utc::now(),
            user_id: self.user_id.clone(),
            session_id: self.session_id.clone(),
            client_info: self.client_info.clone(),
            metadata: self.metadata.clone(),
            baggage: self.baggage.clone(),
        }
    }

    /// Builder-style setter: inserts `value` under `key` in `metadata`,
    /// replacing any previous value for that key.
    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Builder-style setter: inserts `value` under `key` in `baggage`,
    /// replacing any previous value for that key.
    pub fn with_baggage(mut self, key: String, value: String) -> Self {
        self.baggage.insert(key, value);
        self
    }

    /// Builder-style setter for `user_id`.
    pub fn with_user_id(mut self, user_id: String) -> Self {
        self.user_id = Some(user_id);
        self
    }

    /// Builder-style setter for `session_id`.
    pub fn with_session_id(mut self, session_id: String) -> Self {
        self.session_id = Some(session_id);
        self
    }

    /// Builder-style setter for `client_info`.
    pub fn with_client_info(mut self, client_info: ClientInfo) -> Self {
        self.client_info = Some(client_info);
        self
    }

    /// Returns the context installed for the current task by [`Self::run`].
    ///
    /// Yields `None` when called outside such a scope, and always on wasm32
    /// targets, where no task-local storage is compiled in.
    pub fn current() -> Option<Arc<Self>> {
        #[cfg(not(target_arch = "wasm32"))]
        {
            // `try_with` fails (rather than panics) when no scope is active.
            REQUEST_CONTEXT.try_with(|ctx| Arc::clone(ctx)).ok()
        }
        #[cfg(target_arch = "wasm32")]
        {
            None
        }
    }

    /// Drives the future `f` to completion with `self` installed as the
    /// current context, and returns the future's output.
    ///
    /// On wasm32 targets the future is simply awaited and
    /// [`Self::current`] keeps returning `None`.
    pub async fn run<F, R>(self, f: F) -> R
    where
        F: std::future::Future<Output = R>,
    {
        #[cfg(not(target_arch = "wasm32"))]
        {
            REQUEST_CONTEXT.scope(Arc::new(self), f).await
        }
        #[cfg(target_arch = "wasm32")]
        {
            f.await
        }
    }

    /// Serializes the propagated parts of the context into a header map.
    ///
    /// Emitted keys:
    /// * `traceparent` — `00-<trace id without hyphens>-<first 16 bytes of the
    ///   span id without hyphens>-01`
    /// * `x-request-id` — the request id rendered with `to_string`
    /// * `x-user-id` / `x-session-id` — only when the field is set
    /// * `baggage-<key>` — one entry per baggage item
    ///
    /// `client_info`, `metadata`, `timestamp` and `parent_span_id` are not
    /// written. This method never panics: a span id that cannot be cut at 16
    /// bytes is emitted in full instead.
    pub fn to_headers(&self) -> HashMap<String, String> {
        let mut headers = HashMap::new();

        let trace_id_hex = self.trace_id.replace('-', "");
        let span_compact = self.span_id.replace('-', "");
        // `span_id` is a public, deserializable field, so it may be shorter
        // than 16 bytes or contain multi-byte characters. Checked slicing
        // avoids the panic that `[..16]` would raise in those cases.
        let span_id_hex = span_compact.get(..16).unwrap_or(span_compact.as_str());

        headers.insert(
            "traceparent".to_string(),
            format!("00-{}-{}-01", trace_id_hex, span_id_hex),
        );
        headers.insert("x-request-id".to_string(), self.request_id.to_string());

        if let Some(user_id) = &self.user_id {
            headers.insert("x-user-id".to_string(), user_id.clone());
        }
        if let Some(session_id) = &self.session_id {
            headers.insert("x-session-id".to_string(), session_id.clone());
        }

        headers.extend(
            self.baggage
                .iter()
                .map(|(key, value)| (format!("baggage-{}", key), value.clone())),
        );
        headers
    }

    /// Rebuilds a context from a header map such as the one produced by
    /// [`Self::to_headers`].
    ///
    /// Returns `None` when the `traceparent` entry is missing or has fewer
    /// than four `-`-separated fields. Header names are looked up exactly as
    /// written here (lower case).
    ///
    /// * The trace id field becomes `trace_id`; a 32-byte value is
    ///   re-hyphenated into UUID layout.
    /// * The span id field becomes `parent_span_id`; a 16-byte value is
    ///   padded into UUID layout. A new random `span_id` is generated.
    /// * `x-request-id` is used when it parses as `i64`; otherwise a random
    ///   request id is generated.
    /// * `x-user-id`, `x-session-id` and every `baggage-<key>` entry are
    ///   copied over.
    ///
    /// Header values are treated as untrusted: malformed ids are stored
    /// verbatim instead of causing a panic.
    pub fn from_headers(headers: &HashMap<String, String>) -> Option<Self> {
        let traceparent = headers.get("traceparent")?;
        let parts: Vec<&str> = traceparent.split('-').collect();
        let (trace_hex, span_hex) = match parts.as_slice() {
            [_version, trace_hex, span_hex, _flags, ..] => (*trace_hex, *span_hex),
            _ => return None,
        };

        let request_id = headers
            .get("x-request-id")
            .and_then(|id| id.parse::<i64>().ok())
            .map_or_else(
                // Only a random id is needed here, so truncating the UUID to
                // 64 bits is intentional.
                || RequestId::from(uuid::Uuid::new_v4().as_u128() as i64),
                RequestId::from,
            );

        // `Self::new` already assigns a fresh random `span_id` for this hop.
        let mut context = Self::new(request_id);
        context.trace_id = Self::trace_id_from_hex(trace_hex);
        context.parent_span_id = Some(Self::span_id_from_hex(span_hex));
        context.user_id = headers.get("x-user-id").cloned();
        context.session_id = headers.get("x-session-id").cloned();
        context
            .baggage
            .extend(headers.iter().filter_map(|(key, value)| {
                key.strip_prefix("baggage-")
                    .map(|name| (name.to_string(), value.clone()))
            }));

        Some(context)
    }

    /// Re-inserts UUID-style hyphens (8-4-4-4-12) into a 32-byte trace id.
    /// Any value that is not 32 bytes long, or cannot be cut at those
    /// offsets, is returned unchanged.
    fn trace_id_from_hex(hex: &str) -> String {
        if hex.len() == 32 {
            // `str::get` returns `None` instead of panicking when an offset
            // falls inside a multi-byte character.
            if let (Some(a), Some(b), Some(c), Some(d), Some(e)) = (
                hex.get(0..8),
                hex.get(8..12),
                hex.get(12..16),
                hex.get(16..20),
                hex.get(20..32),
            ) {
                return format!("{}-{}-{}-{}-{}", a, b, c, d, e);
            }
        }
        hex.to_string()
    }

    /// Expands a 16-byte span id into UUID layout, zero-filling the last two
    /// groups. Any value that is not 16 bytes long, or cannot be cut at the
    /// group offsets, is returned unchanged.
    fn span_id_from_hex(hex: &str) -> String {
        if hex.len() == 16 {
            if let (Some(a), Some(b), Some(c)) =
                (hex.get(0..8), hex.get(8..12), hex.get(12..16))
            {
                return format!("{}-{}-{}-{}-{}", a, b, c, "0000", "000000000000");
            }
        }
        hex.to_string()
    }
}
/// Stateless namespace type whose associated functions convert between a
/// `RequestContext` and a string header map.
#[derive(Debug)]
pub struct ContextPropagator;
impl ContextPropagator {
/// Builds a `RequestContext` from `headers` by delegating to
/// `RequestContext::from_headers`; returns whatever that call returns.
pub fn extract(headers: &HashMap<String, String>) -> Option<RequestContext> {
RequestContext::from_headers(headers)
}
/// Produces the header map for `context` by delegating to
/// `RequestContext::to_headers`.
pub fn inject(context: &RequestContext) -> HashMap<String, String> {
context.to_headers()
}
}
/// Runs `$body` inside an `async move` block passed to `$ctx.run(..)` and
/// awaits the result.
///
/// The expansion ends in `.await`, so the macro can only be used inside an
/// async context. `$ctx` is presumably a `RequestContext`; any expression
/// with a compatible `run` method is accepted by the expansion.
#[macro_export]
macro_rules! with_context {
($ctx:expr, $body:expr) => {
$ctx.run(async move { $body }).await
};
}
/// Evaluates to an owned `RequestContext`: a clone of the current one when
/// `RequestContext::current()` returns `Some`, otherwise a new context built
/// from `$request_id`.
///
/// `$request_id` is only evaluated in the fallback case, because it sits
/// inside the `unwrap_or_else` closure.
///
/// NOTE(review): the expansion names `RequestContext` without a `$crate::`
/// path, so the type must be in scope at every call site.
#[macro_export]
macro_rules! context_or_new {
($request_id:expr) => {
RequestContext::current()
.map(|ctx| (*ctx).clone())
.unwrap_or_else(|| RequestContext::new($request_id))
};
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A context installed through `run` must be observable via `current`
    /// from inside the wrapped future, and the future's output is returned.
    #[tokio::test]
    async fn test_context_propagation() {
        let ctx = RequestContext::new(RequestId::from(123i64))
            .with_user_id(String::from("user123"))
            .with_baggage(String::from("key1"), String::from("value1"));

        let scoped = ctx.clone().run(async {
            let active = RequestContext::current().unwrap();
            assert_eq!(active.user_id.as_deref(), Some("user123"));
            assert_eq!(
                active.baggage.get("key1").map(String::as_str),
                Some("value1")
            );
            42
        });

        let result = scoped.await;
        assert_eq!(result, 42);
    }

    /// A child shares request and trace ids with its parent, points back at
    /// the parent's span, and owns a distinct span id.
    #[tokio::test]
    async fn test_child_context() {
        let root = RequestContext::new(RequestId::from(123i64));
        let derived = root.child();

        assert_eq!(root.request_id, derived.request_id);
        assert_eq!(root.trace_id, derived.trace_id);
        assert_eq!(
            derived.parent_span_id.as_deref(),
            Some(root.span_id.as_str())
        );
        assert_ne!(root.span_id, derived.span_id);
    }

    /// `to_headers` emits the expected entries and `from_headers` restores
    /// the propagated fields from them.
    #[tokio::test]
    async fn test_headers_conversion() {
        let original = RequestContext::new(RequestId::from(123i64))
            .with_user_id(String::from("user123"))
            .with_session_id(String::from("session456"))
            .with_baggage(String::from("env"), String::from("prod"));

        let headers = original.to_headers();
        assert!(headers.contains_key("traceparent"));

        let expected_entries = [
            ("x-request-id", "123"),
            ("x-user-id", "user123"),
            ("x-session-id", "session456"),
            ("baggage-env", "prod"),
        ];
        for &(name, expected) in expected_entries.iter() {
            assert_eq!(headers.get(name).map(String::as_str), Some(expected));
        }

        let round_tripped = RequestContext::from_headers(&headers).unwrap();
        assert_eq!(round_tripped.trace_id, original.trace_id);
        assert_eq!(round_tripped.user_id, original.user_id);
        assert_eq!(round_tripped.session_id, original.session_id);
        assert_eq!(
            round_tripped.baggage.get("env").map(String::as_str),
            Some("prod")
        );
    }
}