use car_eventlog::EventLog;
use car_state::StateStore;
use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex;
// Task-local slot holding the `AgentContext` of the agent whose future is
// currently executing. It is populated for the duration of a future by
// `TaskScope::run` (via `AGENT_CTX.scope`) and read through
// `TaskScope::try_with`; outside such a scope the lookup fails and the
// accessors return `None`.
tokio::task_local! {
static AGENT_CTX: AgentContext;
}
/// Per-agent view over a shared parent [`StateStore`].
///
/// Writes go to an agent-local store; reads consult the local store first and
/// fall back to the parent. Local entries only reach the parent when
/// [`AgentContext::merge_to_parent`] is called.
///
/// Cloning is shallow: every store/log field is an `Arc`, so clones share the
/// same local state, local log and parent state.
#[derive(Clone)]
pub struct AgentContext {
/// Name of the agent; passed as the third argument to every
/// `StateStore::set` call made through this context
/// (presumably an author/source tag; confirm against `StateStore`).
pub agent_name: String,
/// Agent-local store, created empty by [`AgentContext::new`]. Target of
/// [`AgentContext::set`] and first lookup location for [`AgentContext::get`].
pub local_state: Arc<StateStore>,
/// Agent-local event log behind a Tokio mutex, created empty by
/// [`AgentContext::new`]. Not otherwise used in this part of the file.
pub local_log: Arc<TokioMutex<EventLog>>,
// Shared parent store: read fallback for `get` and destination of
// `merge_to_parent`. Private, so callers cannot write to it through the
// context except via an explicit merge.
parent_state: Arc<StateStore>,
}
impl AgentContext {
    /// Builds a context for `agent_name` layered on top of `parent_state`.
    ///
    /// The context starts with a brand-new, empty local store and a brand-new
    /// local event log; only the parent store is shared with the caller.
    pub fn new(agent_name: &str, parent_state: Arc<StateStore>) -> Self {
        let name = agent_name.to_owned();
        let fresh_state = Arc::new(StateStore::new());
        let fresh_log = Arc::new(TokioMutex::new(EventLog::new()));
        Self {
            agent_name: name,
            local_state: fresh_state,
            local_log: fresh_log,
            parent_state,
        }
    }

    /// Returns the value stored under `key`.
    ///
    /// The agent-local store takes precedence; the parent store is consulted
    /// only when the local store has no entry for `key`. Returns `None` when
    /// neither store has the key.
    pub fn get(&self, key: &str) -> Option<serde_json::Value> {
        let local_hit = self.local_state.get(key);
        if local_hit.is_some() {
            return local_hit;
        }
        // Nothing local: defer to whatever the parent holds.
        self.parent_state.get(key)
    }

    /// Stores `value` under `key` in the agent-local store only.
    ///
    /// The parent store is left untouched; the write is attributed to this
    /// context's `agent_name`.
    pub fn set(&self, key: &str, value: serde_json::Value) {
        let author = &self.agent_name;
        self.local_state.set(key, value, author);
    }

    /// Copies every entry of the agent-local store into the parent store.
    ///
    /// Entries are written one key at a time under this context's
    /// `agent_name`, replacing any value the parent already has for the same
    /// key. A key that no longer resolves to a value by the time it is read
    /// is skipped. The local store itself is not modified by this method.
    pub fn merge_to_parent(&self) {
        for key in self.local_state.keys() {
            match self.local_state.get(&key) {
                Some(value) => {
                    self.parent_state.set(&key, value, &self.agent_name);
                }
                None => {}
            }
        }
    }
}
/// Namespace for installing and querying the task-local [`AgentContext`].
pub struct TaskScope;

impl TaskScope {
    /// Drives `f` to completion with `ctx` installed as the task-local agent
    /// context, and returns the future's output.
    ///
    /// While `f` is being polled, [`TaskScope::try_with`] and
    /// [`TaskScope::agent_name`] observe `ctx`.
    pub async fn run<F, T>(ctx: AgentContext, f: F) -> T
    where
        F: std::future::Future<Output = T>,
    {
        let scoped = AGENT_CTX.scope(ctx, f);
        scoped.await
    }

    /// Calls `f` with the currently installed agent context.
    ///
    /// Returns `Some` with the closure's result, or `None` when the
    /// task-local lookup fails (e.g. when called outside of
    /// [`TaskScope::run`]); in that case `f` is not invoked.
    pub fn try_with<F, R>(f: F) -> Option<R>
    where
        F: FnOnce(&AgentContext) -> R,
    {
        match AGENT_CTX.try_with(f) {
            Ok(result) => Some(result),
            Err(_) => None,
        }
    }

    /// Returns an owned copy of the current agent's name, or `None` when no
    /// agent context is installed.
    pub fn agent_name() -> Option<String> {
        AGENT_CTX
            .try_with(|active| active.agent_name.clone())
            .ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Local writes shadow the parent and stay invisible to it until
    /// `merge_to_parent` is called.
    #[tokio::test]
    async fn test_agent_context_isolation() {
        let root = Arc::new(StateStore::new());
        root.set("shared_key", json!("parent_value"), "test");
        let agent = AgentContext::new("agent_a", Arc::clone(&root));

        // Reads fall through to the parent when nothing is stored locally.
        assert_eq!(agent.get("shared_key"), Some(json!("parent_value")));

        // A local write is visible to the agent but not to the parent.
        agent.set("local_key", json!("local_value"));
        assert_eq!(agent.get("local_key"), Some(json!("local_value")));
        assert!(root.get("local_key").is_none());

        // Overriding a parent key locally leaves the parent value intact.
        agent.set("shared_key", json!("overridden"));
        assert_eq!(agent.get("shared_key"), Some(json!("overridden")));
        assert_eq!(root.get("shared_key"), Some(json!("parent_value")));

        // After merging, the parent reflects both local entries.
        agent.merge_to_parent();
        assert_eq!(root.get("local_key"), Some(json!("local_value")));
        assert_eq!(root.get("shared_key"), Some(json!("overridden")));
    }

    /// The agent name is only observable while running inside `TaskScope::run`.
    #[tokio::test]
    async fn test_task_scope() {
        let root = Arc::new(StateStore::new());
        let agent = AgentContext::new("scoped_agent", Arc::clone(&root));

        // No context is installed outside of a scope.
        assert!(TaskScope::agent_name().is_none());

        let observed = TaskScope::run(agent, async { TaskScope::agent_name().unwrap() }).await;
        assert_eq!(observed, "scoped_agent");
    }

    /// Two agents on separate tasks keep independent local state over the
    /// same parent store.
    #[tokio::test]
    async fn test_parallel_isolation() {
        let root = Arc::new(StateStore::new());
        root.set("counter", json!(0), "init");
        let first = AgentContext::new("agent_a", Arc::clone(&root));
        let second = AgentContext::new("agent_b", Arc::clone(&root));

        let task_a = {
            let ctx = first.clone();
            tokio::spawn(async move {
                TaskScope::run(ctx.clone(), async {
                    ctx.set("counter", json!(1));
                    ctx.set("agent_a_only", json!("a_data"));
                })
                .await;
                ctx
            })
        };
        let task_b = {
            let ctx = second.clone();
            tokio::spawn(async move {
                TaskScope::run(ctx.clone(), async {
                    ctx.set("counter", json!(2));
                    ctx.set("agent_b_only", json!("b_data"));
                })
                .await;
                ctx
            })
        };

        let first = task_a.await.unwrap();
        let second = task_b.await.unwrap();

        // Each agent sees its own counter; the parent still has the initial one.
        assert_eq!(first.get("counter"), Some(json!(1)));
        assert_eq!(second.get("counter"), Some(json!(2)));
        assert_eq!(root.get("counter"), Some(json!(0)));

        first.merge_to_parent();
        second.merge_to_parent();
        assert_eq!(root.get("agent_a_only"), Some(json!("a_data")));
        assert_eq!(root.get("agent_b_only"), Some(json!("b_data")));
    }
}