use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use parking_lot::RwLock;
use std::sync::Arc;
use crate::cancel::CancelToken;
use crate::observability::logging::ContextLogger;
/// Deserialization-side mirror of [`Identity`].
///
/// `Identity` is deserialized through this type (`#[serde(from = "IdentityRaw")]`
/// on `Identity`), which lets `roles` and `attrs` be absent from the input and
/// fall back to empty collections.
#[derive(Deserialize)]
struct IdentityRaw {
/// Identifier; required in the input.
id: String,
/// Read from the `type` key (a Rust keyword, hence the longer field name).
#[serde(rename = "type")]
identity_type: String,
/// Empty `Vec` when the key is missing.
#[serde(default)]
roles: Vec<String>,
/// Empty map when the key is missing.
#[serde(default)]
attrs: HashMap<String, serde_json::Value>,
}
impl From<IdentityRaw> for Identity {
fn from(raw: IdentityRaw) -> Self {
Identity {
id: raw.id,
identity_type: raw.identity_type,
roles: raw.roles,
attrs: raw.attrs,
}
}
}
/// Identity on whose behalf a [`Context`] runs (`Context::anonymous` has none).
///
/// Serializes with the keys `id`, `type`, `roles` and `attrs`. Deserialization
/// goes through [`IdentityRaw`], so `roles` and `attrs` may be omitted and then
/// default to empty. Fields are private; read them through the accessor methods.
///
/// `Hash` is implemented by hand elsewhere in this file because `HashMap`
/// (the type of `attrs`) does not implement `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(from = "IdentityRaw")]
// `identity_type` repeats the struct name, which trips `clippy::struct_field_names`;
// the wire key is `type`, a Rust keyword, so the prefixed field name is used instead.
#[allow(clippy::struct_field_names)] pub struct Identity {
id: String,
/// Serialized under the `type` key.
#[serde(rename = "type")]
identity_type: String,
roles: Vec<String>,
attrs: HashMap<String, serde_json::Value>,
}
impl Identity {
pub fn new(
id: String,
identity_type: String,
roles: Vec<String>,
attrs: HashMap<String, serde_json::Value>,
) -> Self {
Self {
id,
identity_type,
roles,
attrs,
}
}
pub fn id(&self) -> &str {
&self.id
}
pub fn identity_type(&self) -> &str {
&self.identity_type
}
pub fn roles(&self) -> &[String] {
&self.roles
}
pub fn attrs(&self) -> &HashMap<String, serde_json::Value> {
&self.attrs
}
}
impl std::hash::Hash for Identity {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.id.hash(state);
self.identity_type.hash(state);
self.roles.hash(state);
let mut pairs: Vec<_> = self.attrs.iter().collect();
pairs.sort_by_key(|(k, _)| k.as_str());
for (k, v) in pairs {
k.hash(state);
v.to_string().hash(state);
}
}
}
/// Key/value store attached to a [`Context`], behind an `Arc<RwLock<..>>`.
///
/// `Context::child` clones the `Arc`, so a parent and its children read and
/// write the *same* map rather than independent copies.
pub type SharedData = Arc<RwLock<HashMap<String, serde_json::Value>>>;
/// Execution context propagated from module to module via [`Context::child`].
///
/// Root contexts come from `Context::new`, `Context::anonymous` or
/// `Context::create`. The type derives serde `Serialize`/`Deserialize`; fields
/// marked `#[serde(skip)]` are runtime-only, are never serialized, and come back
/// as `None` when deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context<T> {
/// Trace identifier: constructors generate a random UUID v4 string and
/// `child` copies it unchanged.
pub trace_id: String,
/// `None` for anonymous contexts; omitted from serde output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub identity: Option<Identity>,
/// Caller-supplied services; `T::default()` when built by `new`/`anonymous`.
pub services: T,
/// In contexts produced by `child`, the last module id of the parent's
/// `call_chain` (`None` if that chain was empty).
#[serde(skip_serializing_if = "Option::is_none")]
pub caller_id: Option<String>,
/// Shared key/value store; `child` shares this same `Arc` with the parent.
/// The derived `Serialize` writes every key. `to_json` and `serialize` drop
/// keys that start with `_`. Missing in input => empty map.
#[serde(
serialize_with = "serialize_shared_data",
deserialize_with = "deserialize_shared_data",
default = "default_shared_data"
)]
pub data: SharedData,
/// Module ids from outermost to innermost; `child` appends its target id.
#[serde(default)]
pub call_chain: Vec<String>,
/// NOTE(review): only ever reset to `None` in this file; presumably a redacted
/// copy of the current module's inputs for logging/tracing — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub redacted_inputs: Option<HashMap<String, serde_json::Value>>,
/// Cancellation token, cloned into children by `child`. Not serialized.
#[serde(skip)]
pub cancel_token: Option<CancelToken>,
/// Deadline copied into children by `child`. Not serialized.
/// NOTE(review): units/epoch are not visible here (presumably seconds) — confirm.
#[serde(skip)]
pub global_deadline: Option<f64>,
/// Type-erased handle cloned into children by `child`. Not serialized.
/// NOTE(review): presumably the executor driving this context — confirm.
#[serde(skip)]
pub executor: Option<Arc<dyn std::any::Any + Send + Sync>>,
}
fn default_shared_data() -> SharedData {
Arc::new(RwLock::new(HashMap::new()))
}
fn serialize_shared_data<S>(data: &SharedData, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let map = data.read();
map.serialize(serializer)
}
/// Serde `deserialize_with` hook for `Context::data`: read a string-keyed JSON
/// map and wrap it in a new, unshared `Arc<RwLock<..>>`.
fn deserialize_shared_data<'de, D>(deserializer: D) -> Result<SharedData, D::Error>
where
    D: serde::Deserializer<'de>,
{
    HashMap::<String, serde_json::Value>::deserialize(deserializer)
        .map(|entries| Arc::new(RwLock::new(entries)))
}
impl<T: Default> Context<T> {
pub fn new(identity: Identity) -> Self {
Self {
trace_id: uuid::Uuid::new_v4().to_string(),
identity: Some(identity),
services: T::default(),
caller_id: None,
data: default_shared_data(),
call_chain: vec![],
redacted_inputs: None,
cancel_token: None,
global_deadline: None,
executor: None,
}
}
pub fn anonymous() -> Self {
Self {
trace_id: uuid::Uuid::new_v4().to_string(),
identity: None,
services: T::default(),
caller_id: None,
data: default_shared_data(),
call_chain: vec![],
redacted_inputs: None,
cancel_token: None,
global_deadline: None,
executor: None,
}
}
#[must_use]
pub fn child(&self, target_module_id: &str) -> Context<T>
where
T: Clone,
{
let caller_id = self.call_chain.last().cloned();
let mut call_chain = self.call_chain.clone();
call_chain.push(target_module_id.to_string());
Context {
trace_id: self.trace_id.clone(),
identity: self.identity.clone(),
services: self.services.clone(),
caller_id,
data: Arc::clone(&self.data),
call_chain,
redacted_inputs: None,
cancel_token: self.cancel_token.clone(),
global_deadline: self.global_deadline,
executor: self.executor.clone(),
}
}
pub fn to_json(&self) -> serde_json::Value
where
T: Serialize,
{
let mut value = serde_json::to_value(self).unwrap_or_else(|_| serde_json::json!({}));
if let Some(obj) = value.as_object_mut() {
if let Some(data_val) = obj.get_mut("data") {
if let Some(data_obj) = data_val.as_object_mut() {
let internal_keys: Vec<String> = data_obj
.keys()
.filter(|k| k.starts_with('_'))
.cloned()
.collect();
for key in internal_keys {
data_obj.remove(&key);
}
}
}
}
value
}
pub fn from_json(
data: serde_json::Value,
) -> Result<Context<serde_json::Value>, crate::errors::ModuleError> {
let ctx: Context<serde_json::Value> = serde_json::from_value(data)?;
Ok(ctx)
}
pub fn serialize(&self) -> serde_json::Value {
let mut result = serde_json::json!({
"_context_version": 1,
"trace_id": self.trace_id,
"caller_id": self.caller_id,
"call_chain": self.call_chain,
});
if let Some(ref identity) = self.identity {
result["identity"] = serde_json::json!({
"id": identity.id(),
"type": identity.identity_type(),
"roles": identity.roles(),
"attrs": identity.attrs(),
});
} else {
result["identity"] = serde_json::Value::Null;
}
if let Some(ref redacted) = self.redacted_inputs {
result["redacted_inputs"] = serde_json::to_value(redacted).unwrap_or_default();
}
let filtered: HashMap<String, serde_json::Value> = self
.data
.read()
.iter()
.filter(|(k, _)| !k.starts_with('_'))
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
result["data"] = serde_json::to_value(filtered).unwrap_or_default();
result
}
#[allow(clippy::needless_pass_by_value)] pub fn deserialize(value: serde_json::Value) -> Result<Self, serde_json::Error>
where
T: Default,
{
let obj = value
.as_object()
.ok_or_else(|| serde::de::Error::custom("expected JSON object"))?;
let version = obj
.get("_context_version")
.and_then(serde_json::Value::as_i64)
.unwrap_or(1);
if version > 1 {
tracing::warn!(
version = version,
"Unknown _context_version (expected 1). \
Proceeding with best-effort deserialization."
);
}
let identity: Option<Identity> = obj.get("identity").and_then(|v| {
if v.is_null() {
None
} else {
serde_json::from_value(v.clone()).ok()
}
});
let data_map: HashMap<String, serde_json::Value> = obj
.get("data")
.and_then(|v| serde_json::from_value(v.clone()).ok())
.unwrap_or_default();
let call_chain: Vec<String> = obj
.get("call_chain")
.and_then(|v| serde_json::from_value(v.clone()).ok())
.unwrap_or_default();
let redacted_inputs: Option<HashMap<String, serde_json::Value>> = obj
.get("redacted_inputs")
.and_then(|v| serde_json::from_value(v.clone()).ok());
Ok(Context {
trace_id: obj
.get("trace_id")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
caller_id: obj
.get("caller_id")
.and_then(|v| v.as_str())
.map(std::string::ToString::to_string),
call_chain,
identity,
redacted_inputs,
data: Arc::new(RwLock::new(data_map)),
services: T::default(),
cancel_token: None,
global_deadline: None,
executor: None,
})
}
pub fn logger(&self) -> ContextLogger {
let module_id = self.call_chain.last().cloned();
let caller_id = self.caller_id.clone();
ContextLogger {
name: "apcore".to_string(),
level: "info".to_string(),
format: crate::observability::logging::LogFormat::Json,
trace_id: Some(self.trace_id.clone()),
module_id,
caller_id,
}
}
pub fn create(
identity: Identity,
services: T,
caller_id: Option<String>,
data: Option<HashMap<String, serde_json::Value>>,
) -> Self {
Self {
trace_id: uuid::Uuid::new_v4().to_string(),
identity: Some(identity),
services,
caller_id,
data: Arc::new(RwLock::new(data.unwrap_or_default())),
call_chain: vec![],
redacted_inputs: None,
cancel_token: None,
global_deadline: None,
executor: None,
}
}
}
/// Asynchronous factory for [`Context`] values whose services are raw JSON.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ContextFactory: Send + Sync {
/// Create a root context for `identity` and the given JSON `services`.
/// NOTE(review): `None` identity presumably yields an anonymous context —
/// confirm with implementors.
async fn create(
&self,
identity: Option<Identity>,
services: serde_json::Value,
) -> Result<Context<serde_json::Value>, crate::errors::ModuleError>;
/// Alias for [`ContextFactory::create`]; the default implementation just
/// forwards to it.
async fn create_context(
&self,
identity: Option<Identity>,
services: serde_json::Value,
) -> Result<Context<serde_json::Value>, crate::errors::ModuleError> {
self.create(identity, services).await
}
/// Derive the context used to call `module_name` from `parent`.
/// NOTE(review): implementations presumably build on `Context::child` — confirm.
async fn create_child(
&self,
parent: &Context<serde_json::Value>,
module_name: &str,
) -> Result<Context<serde_json::Value>, crate::errors::ModuleError>;
}