use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use parking_lot::RwLock;
use std::sync::Arc;
use crate::cancel::CancelToken;
use crate::observability::logging::ContextLogger;
use crate::trace_context::TraceParent;
/// All-zero trace id. `accept_or_regenerate_trace_id` rejects it as not
/// W3C-valid and starts a new trace instead.
const TRACE_ID_ZEROS: &str = "00000000000000000000000000000000";
/// All-`f` trace id. `accept_or_regenerate_trace_id` also rejects it and
/// starts a new trace.
const TRACE_ID_FFFF: &str = "ffffffffffffffffffffffffffffffff";
/// Produces a brand-new random trace id.
///
/// The id is a version-4 UUID rendered in its hyphen-less `simple` form,
/// i.e. a 32-character hex string.
#[must_use]
fn generate_trace_id() -> String {
    let fresh = uuid::Uuid::new_v4();
    fresh.simple().to_string()
}
/// Returns `incoming` unchanged when it is a usable trace id, otherwise logs a
/// warning and returns a freshly generated one.
///
/// "Usable" means exactly 32 lowercase hexadecimal characters and not one of
/// the reserved all-zero / all-`f` values.
fn accept_or_regenerate_trace_id(incoming: &str) -> String {
    let lowercase_hex = |byte: u8| matches!(byte, b'0'..=b'9' | b'a'..=b'f');
    let well_formed = incoming.len() == 32 && incoming.bytes().all(lowercase_hex);
    let reserved = [TRACE_ID_ZEROS, TRACE_ID_FFFF].contains(&incoming);

    if !well_formed || reserved {
        tracing::warn!(
            "Invalid trace_id format in trace_parent: {:?}. Restarting trace.",
            incoming
        );
        return generate_trace_id();
    }
    incoming.to_owned()
}
/// Deserialization-only mirror of [`Identity`], used via
/// `#[serde(from = "IdentityRaw")]` on `Identity`.
///
/// `roles` and `attrs` carry `#[serde(default)]`, so payloads that omit them
/// still deserialize and yield an empty list / map.
#[derive(Deserialize)]
struct IdentityRaw {
    /// Identifier of the principal.
    id: String,
    /// Kind of principal; read from the JSON key `"type"`.
    #[serde(rename = "type")]
    identity_type: String,
    /// Role names; empty when absent from the input.
    #[serde(default)]
    roles: Vec<String>,
    /// Free-form attributes; empty when absent from the input.
    #[serde(default)]
    attrs: HashMap<String, serde_json::Value>,
}
impl From<IdentityRaw> for Identity {
fn from(raw: IdentityRaw) -> Self {
Identity {
id: raw.id,
identity_type: raw.identity_type,
roles: raw.roles,
attrs: raw.attrs,
}
}
}
/// A principal (caller) attached to a [`Context`].
///
/// Fields are private and exposed through accessor methods. Serialization
/// emits `id`, `type`, `roles` and `attrs`. Deserialization goes through
/// `IdentityRaw`, so `roles` and `attrs` may be omitted from the input.
/// `Hash` is implemented by hand because `HashMap` (used for `attrs`) does
/// not implement `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(from = "IdentityRaw")]
// `identity_type` contains the struct name, which `clippy::struct_field_names`
// flags; the field maps to the serialized key "type".
#[allow(clippy::struct_field_names)] pub struct Identity {
    /// Identifier of the principal.
    id: String,
    /// Kind of principal; serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    identity_type: String,
    /// Role names held by the principal.
    roles: Vec<String>,
    /// Free-form attributes.
    attrs: HashMap<String, serde_json::Value>,
}
impl Identity {
#[must_use]
pub fn new(
id: String,
identity_type: String,
roles: Vec<String>,
attrs: HashMap<String, serde_json::Value>,
) -> Self {
Self {
id,
identity_type,
roles,
attrs,
}
}
#[must_use]
pub fn id(&self) -> &str {
&self.id
}
#[must_use]
pub fn identity_type(&self) -> &str {
&self.identity_type
}
#[must_use]
pub fn roles(&self) -> &[String] {
&self.roles
}
#[must_use]
pub fn attrs(&self) -> &HashMap<String, serde_json::Value> {
&self.attrs
}
#[must_use]
pub fn get_attr(&self, key: &str) -> Option<&serde_json::Value> {
self.attrs.get(key)
}
}
impl std::hash::Hash for Identity {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.id.hash(state);
self.identity_type.hash(state);
self.roles.hash(state);
let mut pairs: Vec<_> = self.attrs.iter().collect();
pairs.sort_by_key(|(k, _)| k.as_str());
for (k, v) in pairs {
k.hash(state);
v.to_string().hash(state);
}
}
}
/// Reference-counted, lock-protected key/value map. `Context::child` shares
/// it (via `Arc::clone`) between a context and the contexts derived from it.
pub type SharedData = Arc<RwLock<HashMap<String, serde_json::Value>>>;
/// Execution context propagated along a chain of module calls.
///
/// `T` is the type of the `services` payload.
///
/// Two serialization paths exist:
/// - The derived `Serialize`/`Deserialize` impls include `services` and the
///   full `data` map.
/// - The inherent `Context::serialize` / `Context::deserialize` use a
///   versioned wire format that omits `services` and drops `_`-prefixed
///   `data` keys.
///
/// The runtime-only fields (`cancel_token`, `global_deadline`, `executor`)
/// are skipped by both.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context<T> {
    /// Trace identifier; copied unchanged into every child context.
    pub trace_id: String,
    /// Principal making the call, or `None` for anonymous contexts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub identity: Option<Identity>,
    /// Caller-supplied services payload.
    pub services: T,
    /// Module that made the current call. `Context::child` sets it to the
    /// parent's last call-chain entry.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub caller_id: Option<String>,
    /// Key/value store shared by reference with child contexts; writes are
    /// visible through every context holding the same `Arc`.
    #[serde(
        serialize_with = "serialize_shared_data",
        deserialize_with = "deserialize_shared_data",
        default = "default_shared_data"
    )]
    pub data: SharedData,
    /// Module ids from the root call to the current module, in call order.
    #[serde(default)]
    pub call_chain: Vec<String>,
    /// Redacted inputs of the current call, if recorded.
    /// NOTE(review): populated outside this file — confirm the producer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redacted_inputs: Option<HashMap<String, serde_json::Value>>,
    /// Redacted output of the current call, if recorded.
    /// NOTE(review): populated outside this file — confirm the producer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redacted_output: Option<HashMap<String, serde_json::Value>>,
    /// Cancellation token inherited by child contexts; never serialized.
    #[serde(skip)]
    pub cancel_token: Option<CancelToken>,
    /// Deadline inherited by child contexts; never serialized.
    /// NOTE(review): unit/epoch of this `f64` is not visible here — confirm.
    #[serde(skip)]
    pub global_deadline: Option<f64>,
    /// Type-erased executor set via `bind_executor`; never serialized.
    #[serde(skip)]
    pub executor: Option<Arc<dyn std::any::Any + Send + Sync>>,
}
fn default_shared_data() -> SharedData {
Arc::new(RwLock::new(HashMap::new()))
}
fn serialize_shared_data<S>(data: &SharedData, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let map = data.read();
map.serialize(serializer)
}
/// Serde `deserialize_with` hook for `Context::data`: reads a plain JSON map
/// and wraps it in a fresh `Arc<RwLock<_>>`.
fn deserialize_shared_data<'de, D>(deserializer: D) -> Result<SharedData, D::Error>
where
    D: serde::Deserializer<'de>,
{
    HashMap::<String, serde_json::Value>::deserialize(deserializer)
        .map(|entries| Arc::new(RwLock::new(entries)))
}
impl<T: Default> Context<T> {
    /// Creates a root context for `identity`.
    ///
    /// A fresh trace id is generated and `services` is `T::default()`. The
    /// shared data map and call chain start empty. No cancel token, global
    /// deadline or executor is attached.
    #[must_use]
    pub fn new(identity: Identity) -> Self {
        Self {
            trace_id: generate_trace_id(),
            identity: Some(identity),
            services: T::default(),
            caller_id: None,
            data: default_shared_data(),
            call_chain: vec![],
            redacted_inputs: None,
            redacted_output: None,
            cancel_token: None,
            global_deadline: None,
            executor: None,
        }
    }

    /// Creates a root context with no identity. Otherwise it matches
    /// [`Context::new`].
    #[must_use]
    pub fn anonymous() -> Self {
        Self {
            trace_id: generate_trace_id(),
            identity: None,
            services: T::default(),
            caller_id: None,
            data: default_shared_data(),
            call_chain: vec![],
            redacted_inputs: None,
            redacted_output: None,
            cancel_token: None,
            global_deadline: None,
            executor: None,
        }
    }

    /// Derives the context used to invoke `target_module_id` from this one.
    ///
    /// The child inherits the trace id, identity, services, cancel token,
    /// global deadline and executor. It shares this context's data map
    /// (`Arc::clone`), so writes through either are visible to both.
    /// `target_module_id` is appended to the call chain. `caller_id` becomes
    /// the last entry of this context's call chain (`None` when it is empty).
    /// Redacted inputs and outputs are not inherited.
    #[must_use]
    pub fn child(&self, target_module_id: &str) -> Context<T>
    where
        T: Clone,
    {
        let caller_id = self.call_chain.last().cloned();
        let mut call_chain = self.call_chain.clone();
        call_chain.push(target_module_id.to_string());
        Context {
            trace_id: self.trace_id.clone(),
            identity: self.identity.clone(),
            services: self.services.clone(),
            caller_id,
            data: Arc::clone(&self.data),
            call_chain,
            redacted_inputs: None,
            redacted_output: None,
            cancel_token: self.cancel_token.clone(),
            global_deadline: self.global_deadline,
            executor: self.executor.clone(),
        }
    }

    /// Returns the versioned JSON wire form; identical to [`Context::serialize`].
    pub fn to_json(&self) -> serde_json::Value {
        self.serialize()
    }

    /// Variant of [`Context::to_json`] with a `Result` signature.
    ///
    /// # Errors
    /// None at present. [`Context::serialize`] logs internal conversion
    /// failures and substitutes fallback values, so this always returns `Ok`.
    pub fn to_json_checked(&self) -> Result<serde_json::Value, crate::errors::ModuleError> {
        Ok(self.serialize())
    }

    /// Parses a context via the derived `Deserialize` impl into a
    /// `Context<serde_json::Value>`.
    ///
    /// [`Context::to_json`] never emits a `services` field, but the derived
    /// deserializer rejects a missing `serde_json::Value` field. When
    /// `services` is absent from an object it is therefore filled with `null`,
    /// so `from_json(ctx.to_json())` round-trips. Unknown keys such as
    /// `_context_version` are ignored.
    ///
    /// # Errors
    /// Returns a `ModuleError`, converted from `serde_json::Error`, when `data`
    /// does not match the derived shape (for example, `trace_id` is missing
    /// or is not a string).
    pub fn from_json(
        mut data: serde_json::Value,
    ) -> Result<Context<serde_json::Value>, crate::errors::ModuleError> {
        if let Some(obj) = data.as_object_mut() {
            obj.entry("services").or_insert(serde_json::Value::Null);
        }
        let ctx: Context<serde_json::Value> = serde_json::from_value(data)?;
        Ok(ctx)
    }

    /// Produces the versioned (`_context_version: 1`) JSON wire form.
    ///
    /// The output contains:
    /// - `trace_id`, `caller_id` and `call_chain`.
    /// - `identity`, either as an object (`id`, `type`, `roles`, `attrs`) or
    ///   `null`.
    /// - `redacted_inputs` / `redacted_output`, only when set.
    /// - `data`, with every key starting with `_` removed.
    ///
    /// `services` and the runtime-only fields are not emitted. Conversion
    /// failures are logged and replaced by `null` (or `{}` for `data`) instead
    /// of being returned.
    pub fn serialize(&self) -> serde_json::Value {
        let mut result = serde_json::json!({
            "_context_version": 1,
            "trace_id": self.trace_id,
            "caller_id": self.caller_id,
            "call_chain": self.call_chain,
        });
        if let Some(ref identity) = self.identity {
            result["identity"] = serde_json::json!({
                "id": identity.id(),
                "type": identity.identity_type(),
                "roles": identity.roles(),
                "attrs": identity.attrs(),
            });
        } else {
            result["identity"] = serde_json::Value::Null;
        }
        if let Some(ref redacted) = self.redacted_inputs {
            result["redacted_inputs"] = serde_json::to_value(redacted).unwrap_or_else(|e| {
                tracing::error!(
                    trace_id = %self.trace_id,
                    error = %e,
                    "Context::serialize failed to serialize redacted_inputs"
                );
                serde_json::Value::Null
            });
        }
        if let Some(ref redacted) = self.redacted_output {
            result["redacted_output"] = serde_json::to_value(redacted).unwrap_or_else(|e| {
                tracing::error!(
                    trace_id = %self.trace_id,
                    error = %e,
                    "Context::serialize failed to serialize redacted_output"
                );
                serde_json::Value::Null
            });
        }
        // `_`-prefixed keys are excluded from the wire form.
        let filtered: HashMap<String, serde_json::Value> = self
            .data
            .read()
            .iter()
            .filter(|(k, _)| !k.starts_with('_'))
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        result["data"] = serde_json::to_value(filtered).unwrap_or_else(|e| {
            tracing::error!(
                trace_id = %self.trace_id,
                error = %e,
                "Context::serialize failed to serialize data; returning empty object"
            );
            serde_json::json!({})
        });
        result
    }

    /// Reconstructs a context from the wire form produced by
    /// [`Context::serialize`].
    ///
    /// Parsing is lenient:
    /// - A `_context_version` greater than 1 only logs a warning.
    /// - A missing or non-string `trace_id` becomes an empty string.
    /// - A missing or malformed `caller_id`, `call_chain`, `data`,
    ///   `redacted_inputs` or `redacted_output` falls back to `None` or empty.
    /// - A malformed `identity` is dropped with a warning, leaving an
    ///   anonymous context.
    ///
    /// `services` is `T::default()`, and the cancel token, global deadline and
    /// executor start unset.
    ///
    /// # Errors
    /// Returns an error only when `value` is not a JSON object.
    pub fn deserialize(value: serde_json::Value) -> Result<Self, serde_json::Error>
    where
        T: Default,
    {
        // Own the object so each field can be moved out with `remove` instead
        // of deep-cloning it before conversion.
        let serde_json::Value::Object(mut obj) = value else {
            return Err(serde::de::Error::custom("expected JSON object"));
        };
        let version = obj
            .get("_context_version")
            .and_then(serde_json::Value::as_i64)
            .unwrap_or(1);
        if version > 1 {
            tracing::warn!(
                version = version,
                "Unknown _context_version (expected 1). \
                 Proceeding with best-effort deserialization."
            );
        }
        let trace_id = match obj.remove("trace_id") {
            Some(serde_json::Value::String(trace_id)) => trace_id,
            _ => String::new(),
        };
        let caller_id = match obj.remove("caller_id") {
            Some(serde_json::Value::String(caller_id)) => Some(caller_id),
            _ => None,
        };
        let identity: Option<Identity> = match obj.remove("identity") {
            None | Some(serde_json::Value::Null) => None,
            Some(raw) => match serde_json::from_value(raw) {
                Ok(identity) => Some(identity),
                Err(e) => {
                    // Losing the identity changes who the call runs as, so
                    // make it visible instead of silently going anonymous.
                    tracing::warn!(
                        trace_id = %trace_id,
                        error = %e,
                        "Context::deserialize dropped malformed identity; context will be anonymous"
                    );
                    None
                }
            },
        };
        let call_chain: Vec<String> = obj
            .remove("call_chain")
            .and_then(|v| serde_json::from_value(v).ok())
            .unwrap_or_default();
        let data_map: HashMap<String, serde_json::Value> = obj
            .remove("data")
            .and_then(|v| serde_json::from_value(v).ok())
            .unwrap_or_default();
        let redacted_inputs: Option<HashMap<String, serde_json::Value>> = obj
            .remove("redacted_inputs")
            .and_then(|v| serde_json::from_value(v).ok());
        let redacted_output: Option<HashMap<String, serde_json::Value>> = obj
            .remove("redacted_output")
            .and_then(|v| serde_json::from_value(v).ok());
        Ok(Context {
            trace_id,
            caller_id,
            call_chain,
            identity,
            redacted_inputs,
            redacted_output,
            data: Arc::new(RwLock::new(data_map)),
            services: T::default(),
            cancel_token: None,
            global_deadline: None,
            executor: None,
        })
    }

    /// Builds an `"apcore"` [`ContextLogger`] tagged with this context's trace
    /// id, caller id, and current module (the last call-chain entry).
    pub fn logger(&self) -> ContextLogger {
        let module_id = self.call_chain.last().cloned();
        let caller_id = self.caller_id.clone();
        let mut logger = ContextLogger::new("apcore");
        logger.trace_id = Some(self.trace_id.clone());
        logger.module_id = module_id;
        logger.caller_id = caller_id;
        logger
    }

    /// Creates a root context from explicit parts, optionally continuing an
    /// incoming trace.
    ///
    /// When `trace_parent` is given:
    /// - Its trace id is validated; an invalid one is replaced by a fresh id.
    /// - Its trace flags (two lowercase hex digits) are stored in `data` under
    ///   `TRACE_FLAGS_KEY`.
    /// - A non-empty tracestate is stored under `TRACE_STATE_KEY` as an array
    ///   of `[key, value]` pairs.
    ///
    /// Entries already present in the caller-supplied `data` take precedence.
    /// Without a trace parent, a fresh trace id is generated.
    pub fn create(
        identity: Option<Identity>,
        trace_parent: Option<TraceParent>,
        cancel_token: Option<CancelToken>,
        data: Option<HashMap<String, serde_json::Value>>,
        services: T,
        global_deadline: Option<f64>,
    ) -> Self {
        let mut initial_data = data.unwrap_or_default();
        let trace_id = if let Some(tp) = trace_parent {
            let tid = accept_or_regenerate_trace_id(&tp.trace_id);
            initial_data
                .entry(crate::trace_context::TRACE_FLAGS_KEY.to_string())
                .or_insert_with(|| serde_json::Value::String(format!("{:02x}", tp.trace_flags)));
            if !tp.tracestate.is_empty() {
                initial_data
                    .entry(crate::trace_context::TRACE_STATE_KEY.to_string())
                    .or_insert_with(|| {
                        serde_json::Value::Array(
                            tp.tracestate
                                .into_iter()
                                .map(|(k, v)| {
                                    serde_json::Value::Array(vec![
                                        serde_json::Value::String(k),
                                        serde_json::Value::String(v),
                                    ])
                                })
                                .collect(),
                        )
                    });
            }
            tid
        } else {
            generate_trace_id()
        };
        Self {
            trace_id,
            identity,
            services,
            caller_id: None,
            data: Arc::new(RwLock::new(initial_data)),
            call_chain: vec![],
            redacted_inputs: None,
            redacted_output: None,
            cancel_token,
            global_deadline,
            executor: None,
        }
    }

    /// Returns a [`ContextBuilder`] with every option unset.
    #[must_use]
    pub fn builder() -> ContextBuilder<T> {
        ContextBuilder::new()
    }
}
/// Step-by-step constructor for a root [`Context`]; finish with `build()`.
///
/// Every setting is optional. Unset values fall back to what `build()`
/// chooses (generated trace id, `T::default()` services, empty data).
pub struct ContextBuilder<T> {
    /// Incoming trace to continue, if any.
    trace_parent: Option<TraceParent>,
    /// Principal for the context; `None` yields an anonymous context.
    identity: Option<Identity>,
    /// Services payload; `None` means `T::default()` at build time.
    services: Option<T>,
    /// Caller id to record on the root context.
    caller_id: Option<String>,
    /// Initial contents of the shared data map.
    data: Option<HashMap<String, serde_json::Value>>,
    /// Deadline copied into the context unchanged.
    global_deadline: Option<f64>,
    /// Cancellation token copied into the context unchanged.
    cancel_token: Option<CancelToken>,
}
impl<T> Default for ContextBuilder<T> {
fn default() -> Self {
Self {
trace_parent: None,
identity: None,
services: None,
caller_id: None,
data: None,
global_deadline: None,
cancel_token: None,
}
}
}
impl<T> ContextBuilder<T> {
    /// Starts a builder with every setting unset (same as `Default`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the incoming trace to continue (`None` clears it).
    #[must_use]
    pub fn trace_parent(self, trace_parent: Option<TraceParent>) -> Self {
        Self {
            trace_parent,
            ..self
        }
    }

    /// Replaces the principal (`None` makes the context anonymous).
    #[must_use]
    pub fn identity(self, identity: Option<Identity>) -> Self {
        Self { identity, ..self }
    }

    /// Sets the services payload.
    #[must_use]
    pub fn services(self, services: T) -> Self {
        Self {
            services: Some(services),
            ..self
        }
    }

    /// Replaces the caller id (`None` clears it).
    #[must_use]
    pub fn caller_id(self, caller_id: Option<String>) -> Self {
        Self { caller_id, ..self }
    }

    /// Sets the initial contents of the shared data map.
    #[must_use]
    pub fn data(self, data: HashMap<String, serde_json::Value>) -> Self {
        Self {
            data: Some(data),
            ..self
        }
    }

    /// Replaces the global deadline (`None` clears it).
    #[must_use]
    pub fn global_deadline(self, deadline: Option<f64>) -> Self {
        Self {
            global_deadline: deadline,
            ..self
        }
    }

    /// Replaces the cancellation token (`None` clears it).
    #[must_use]
    pub fn cancel_token(self, token: Option<CancelToken>) -> Self {
        Self {
            cancel_token: token,
            ..self
        }
    }
}
impl<T: Default> ContextBuilder<T> {
    /// Consumes the builder and produces a root [`Context`].
    ///
    /// Without a trace parent, a fresh trace id is generated. With one:
    /// - Its trace id is validated by `accept_or_regenerate_trace_id`.
    /// - Its trace flags (two lowercase hex digits) are seeded into the data
    ///   map under `TRACE_FLAGS_KEY`.
    /// - A non-empty tracestate is seeded under `TRACE_STATE_KEY` as an array
    ///   of `[key, value]` pairs.
    ///
    /// Keys already present in the builder's `data` take precedence. Unset
    /// `services` fall back to `T::default()`. The call chain starts empty and
    /// no executor is bound.
    #[must_use]
    pub fn build(self) -> Context<T> {
        let mut initial_data = self.data.unwrap_or_default();
        // Consume the trace parent once so its tracestate entries are moved
        // into the data map rather than cloned (mirrors `Context::create`).
        let trace_id = match self.trace_parent {
            Some(tp) => {
                let trace_id = accept_or_regenerate_trace_id(&tp.trace_id);
                initial_data
                    .entry(crate::trace_context::TRACE_FLAGS_KEY.to_string())
                    .or_insert_with(|| {
                        serde_json::Value::String(format!("{:02x}", tp.trace_flags))
                    });
                if !tp.tracestate.is_empty() {
                    initial_data
                        .entry(crate::trace_context::TRACE_STATE_KEY.to_string())
                        .or_insert_with(|| {
                            serde_json::Value::Array(
                                tp.tracestate
                                    .into_iter()
                                    .map(|(k, v)| {
                                        serde_json::Value::Array(vec![
                                            serde_json::Value::String(k),
                                            serde_json::Value::String(v),
                                        ])
                                    })
                                    .collect(),
                            )
                        });
                }
                trace_id
            }
            None => generate_trace_id(),
        };
        Context {
            trace_id,
            identity: self.identity,
            services: self.services.unwrap_or_default(),
            caller_id: self.caller_id,
            data: Arc::new(RwLock::new(initial_data)),
            call_chain: vec![],
            redacted_inputs: None,
            redacted_output: None,
            cancel_token: self.cancel_token,
            global_deadline: self.global_deadline,
            executor: None,
        }
    }
}
impl<T> Context<T> {
    /// Binds this context to `executor`, at most once.
    ///
    /// Re-binding the same `Arc` (compared with `Arc::ptr_eq`) is a no-op
    /// success. Binding a different instance once one is set is rejected and
    /// leaves the existing binding in place.
    ///
    /// # Errors
    /// Returns a `ModuleError` with `ErrorCode::ContextBindingError` when the
    /// context is already bound to a different executor.
    #[doc(hidden)]
    pub fn bind_executor(
        &mut self,
        executor: Arc<dyn std::any::Any + Send + Sync>,
    ) -> Result<(), crate::errors::ModuleError> {
        if let Some(bound) = &self.executor {
            if Arc::ptr_eq(bound, &executor) {
                return Ok(());
            }
            return Err(crate::errors::ModuleError::new(
                crate::errors::ErrorCode::ContextBindingError,
                "Context already bound to a different Executor instance",
            ));
        }
        self.executor = Some(executor);
        Ok(())
    }
}
/// Asynchronous factory for `Context<serde_json::Value>` instances.
///
/// Implementors provide `create` and `create_child`. `create_context` is a
/// provided method whose default implementation forwards to `create`.
#[async_trait]
pub trait ContextFactory: Send + Sync {
    /// Creates a context for `identity` with the given `services` payload.
    ///
    /// # Errors
    /// Implementation-defined `ModuleError`.
    async fn create(
        &self,
        identity: Option<Identity>,
        services: serde_json::Value,
    ) -> Result<Context<serde_json::Value>, crate::errors::ModuleError>;
    /// Alias for [`ContextFactory::create`]; the default body simply awaits it.
    ///
    /// # Errors
    /// Whatever `create` returns.
    async fn create_context(
        &self,
        identity: Option<Identity>,
        services: serde_json::Value,
    ) -> Result<Context<serde_json::Value>, crate::errors::ModuleError> {
        self.create(identity, services).await
    }
    /// Creates a context for calling `module_name` from `parent`.
    /// NOTE(review): presumably mirrors `Context::child` — confirm with implementors.
    ///
    /// # Errors
    /// Implementation-defined `ModuleError`.
    async fn create_child(
        &self,
        parent: &Context<serde_json::Value>,
        module_name: &str,
    ) -> Result<Context<serde_json::Value>, crate::errors::ModuleError>;
}