use serde::{Serialize, de::DeserializeOwned};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::{RwLock, mpsc};
/// Execution-scoped context handed to an agent run.
///
/// The parent link, key/value state, interrupt signal, event bus and config
/// all sit behind `Arc`, so a clone shares them with the original; only the
/// two id strings are copied.
#[derive(Clone)]
pub struct AgentContext {
    /// Identifier of the execution this context belongs to.
    pub execution_id: String,
    /// Optional session identifier; copied into child contexts.
    pub session_id: Option<String>,
    /// The context this one was derived from via `child`, if any.
    parent: Option<Arc<AgentContext>>,
    /// Key/value state owned by this context (each child starts with its own map).
    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
    /// Interrupt flag shared with the parent and child contexts.
    interrupt: Arc<InterruptSignal>,
    /// Event bus shared with the parent and child contexts.
    event_bus: Arc<EventBus>,
    /// Configuration, shared by `Arc` with children created from this context.
    config: Arc<ContextConfig>,
}
/// Tunables attached to an [`AgentContext`].
///
/// `Default` yields `timeout_ms: None`, `max_retries: 0`,
/// `enable_tracing: false` and an empty `custom` map. Nothing in this file
/// enforces these settings; they are only stored and handed out.
#[derive(Debug, Clone, Default)]
pub struct ContextConfig {
    /// Optional timeout; milliseconds judging by the name (consumer-defined — confirm).
    pub timeout_ms: Option<u64>,
    /// Retry budget for consumers that honour it.
    pub max_retries: u32,
    /// Whether tracing is requested.
    pub enable_tracing: bool,
    /// Free-form extra settings keyed by name.
    pub custom: HashMap<String, serde_json::Value>,
}
impl AgentContext {
pub fn new(execution_id: impl Into<String>) -> Self {
Self {
execution_id: execution_id.into(),
session_id: None,
parent: None,
state: Arc::new(RwLock::new(HashMap::new())),
interrupt: Arc::new(InterruptSignal::new()),
event_bus: Arc::new(EventBus::new()),
config: Arc::new(ContextConfig::default()),
}
}
pub fn with_session(execution_id: impl Into<String>, session_id: impl Into<String>) -> Self {
let mut ctx = Self::new(execution_id);
ctx.session_id = Some(session_id.into());
ctx
}
pub fn child(&self, execution_id: impl Into<String>) -> Self {
Self {
execution_id: execution_id.into(),
session_id: self.session_id.clone(),
parent: Some(Arc::new(self.clone())),
state: Arc::new(RwLock::new(HashMap::new())),
interrupt: self.interrupt.clone(), event_bus: self.event_bus.clone(), config: self.config.clone(),
}
}
pub fn with_config(mut self, config: ContextConfig) -> Self {
self.config = Arc::new(config);
self
}
pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
let state = self.state.read().await;
state
.get(key)
.and_then(|v| serde_json::from_value(v.clone()).ok())
}
pub async fn set<T: Serialize>(&self, key: &str, value: T) {
if let Ok(v) = serde_json::to_value(value) {
let mut state = self.state.write().await;
state.insert(key.to_string(), v);
}
}
pub async fn remove(&self, key: &str) -> Option<serde_json::Value> {
let mut state = self.state.write().await;
state.remove(key)
}
pub async fn contains(&self, key: &str) -> bool {
let state = self.state.read().await;
state.contains_key(key)
}
pub async fn keys(&self) -> Vec<String> {
let state = self.state.read().await;
state.keys().cloned().collect()
}
pub fn is_interrupted(&self) -> bool {
self.interrupt.is_triggered()
}
pub fn trigger_interrupt(&self) {
self.interrupt.trigger();
}
pub fn clear_interrupt(&self) {
self.interrupt.clear();
}
pub fn config(&self) -> &ContextConfig {
&self.config
}
pub fn parent(&self) -> Option<&Arc<AgentContext>> {
self.parent.as_ref()
}
pub async fn emit_event(&self, event: AgentEvent) {
self.event_bus.emit(event).await;
}
pub async fn subscribe(&self, event_type: &str) -> EventReceiver {
self.event_bus.subscribe(event_type).await
}
pub async fn find<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
if let Some(value) = self.get::<T>(key).await {
return Some(value);
}
if let Some(parent) = &self.parent {
return Box::pin(parent.find::<T>(key)).await;
}
None
}
}
/// A resettable cancellation flag.
///
/// Backed by an [`AtomicBool`] accessed with `SeqCst` ordering, so it can be
/// shared (e.g. behind an `Arc`) and toggled through `&self` without a lock.
#[derive(Debug, Default)]
pub struct InterruptSignal {
    /// `true` after [`InterruptSignal::trigger`] until the next [`InterruptSignal::clear`].
    triggered: AtomicBool,
}

impl InterruptSignal {
    /// Creates a signal in the untriggered state.
    #[must_use]
    pub fn new() -> Self {
        Self {
            triggered: AtomicBool::new(false),
        }
    }

    /// Returns whether the signal is currently set.
    #[must_use]
    pub fn is_triggered(&self) -> bool {
        self.triggered.load(Ordering::SeqCst)
    }

    /// Sets the signal; calling it again while set has no further effect.
    pub fn trigger(&self) {
        self.triggered.store(true, Ordering::SeqCst);
    }

    /// Resets the signal to the untriggered state.
    pub fn clear(&self) {
        self.triggered.store(false, Ordering::SeqCst);
    }
}
/// An event published on an [`EventBus`].
#[derive(Debug, Clone)]
pub struct AgentEvent {
    /// Routing key used by the event bus to pick subscribers.
    pub event_type: String,
    /// JSON payload.
    pub data: serde_json::Value,
    /// Creation time in milliseconds since the Unix epoch (0 if the clock reads before it).
    pub timestamp_ms: u64,
    /// Optional identifier of the emitter, set via [`AgentEvent::with_source`].
    pub source: Option<String>,
}

impl AgentEvent {
    /// Creates an event stamped with the current wall-clock time and no source.
    ///
    /// The timestamp falls back to 0 if the system clock reports a time before
    /// the Unix epoch, and saturates at `u64::MAX` rather than wrapping.
    pub fn new(event_type: impl Into<String>, data: serde_json::Value) -> Self {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        // `as_millis()` is u128 and `as u64` would silently keep only the low
        // 64 bits; build the count in saturating u64 arithmetic instead.
        let timestamp_ms = since_epoch
            .as_secs()
            .saturating_mul(1000)
            .saturating_add(u64::from(since_epoch.subsec_millis()));
        Self {
            event_type: event_type.into(),
            data,
            timestamp_ms,
            source: None,
        }
    }

    /// Returns the event with `source` recorded as its emitter.
    #[must_use]
    pub fn with_source(mut self, source: impl Into<String>) -> Self {
        self.source = Some(source.into());
        self
    }
}
/// Receiving half returned by [`EventBus::subscribe`].
pub type EventReceiver = mpsc::Receiver<AgentEvent>;

/// A publish/subscribe hub for [`AgentEvent`]s keyed by `event_type`.
///
/// Each subscription owns a bounded channel. Subscribing to `"*"` receives
/// every event regardless of its type.
pub struct EventBus {
    /// Senders grouped by the `event_type` they subscribed to (`"*"` = all events).
    subscribers: RwLock<HashMap<String, Vec<mpsc::Sender<AgentEvent>>>>,
}

impl EventBus {
    /// Buffer size of each subscriber's channel; `emit` waits while a buffer is full.
    const SUBSCRIBER_BUFFER: usize = 100;

    /// Creates a bus with no subscribers.
    pub fn new() -> Self {
        Self {
            subscribers: RwLock::new(HashMap::new()),
        }
    }

    /// Delivers `event` to subscribers of its `event_type`, then to `"*"` subscribers.
    ///
    /// Waits until every matching subscriber has accepted the event (i.e. while
    /// a subscriber's buffer is full). Subscribers whose receiver was dropped
    /// are skipped and removed. An event whose type is literally `"*"` is
    /// delivered to each wildcard subscriber once.
    pub async fn emit(&self, event: AgentEvent) {
        // Snapshot the matching senders and release the lock *before* awaiting:
        // `send` can wait on a full buffer, and holding the read guard through
        // that wait would block `subscribe` and (tokio's RwLock being fair)
        // every later `emit` until the slow subscriber drains.
        let targets: Vec<mpsc::Sender<AgentEvent>> = {
            let subscribers = self.subscribers.read().await;
            let wildcard = if event.event_type == "*" {
                None
            } else {
                subscribers.get("*")
            };
            subscribers
                .get(&event.event_type)
                .into_iter()
                .chain(wildcard)
                .flatten()
                .cloned()
                .collect()
        };

        let mut saw_closed = false;
        for sender in &targets {
            // `send` only fails when the receiver has been dropped.
            if sender.send(event.clone()).await.is_err() {
                saw_closed = true;
            }
        }
        if saw_closed {
            self.prune_closed().await;
        }
    }

    /// Registers a new subscriber for `event_type` (`"*"` for every event).
    pub async fn subscribe(&self, event_type: &str) -> EventReceiver {
        let (tx, rx) = mpsc::channel(Self::SUBSCRIBER_BUFFER);
        let mut subscribers = self.subscribers.write().await;
        let senders = subscribers.entry(event_type.to_string()).or_default();
        // Drop senders whose receivers are gone so the list cannot grow
        // without bound when nobody emits this event type.
        senders.retain(|sender| !sender.is_closed());
        senders.push(tx);
        rx
    }

    /// Removes senders whose receiver has been dropped, and any now-empty topics.
    async fn prune_closed(&self) {
        let mut subscribers = self.subscribers.write().await;
        subscribers.retain(|_, senders| {
            senders.retain(|sender| !sender.is_closed());
            !senders.is_empty()
        });
    }
}

impl Default for EventBus {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_context_basic() {
        let ctx = AgentContext::new("test-execution");
        ctx.set("key1", "value1").await;
        let value: Option<String> = ctx.get("key1").await;
        assert_eq!(value, Some("value1".to_string()));
    }

    #[tokio::test]
    async fn test_context_child() {
        let parent = AgentContext::new("parent");
        parent.set("parent_key", "parent_value").await;
        let child = parent.child("child");
        child.set("child_key", "child_value").await;
        let child_value: Option<String> = child.get("child_key").await;
        assert_eq!(child_value, Some("child_value".to_string()));
        let parent_value: Option<String> = child.find("parent_key").await;
        assert_eq!(parent_value, Some("parent_value".to_string()));
    }

    #[tokio::test]
    async fn test_state_remove_contains_keys() {
        let ctx = AgentContext::new("state");
        assert!(!ctx.contains("a").await);
        ctx.set("a", 1u32).await;
        ctx.set("b", vec![1, 2, 3]).await;
        assert!(ctx.contains("a").await);

        let mut keys = ctx.keys().await;
        keys.sort();
        assert_eq!(keys, vec!["a".to_string(), "b".to_string()]);

        // A stored value of the wrong shape reads back as `None`.
        let wrong: Option<String> = ctx.get("b").await;
        assert_eq!(wrong, None);

        assert_eq!(ctx.remove("a").await, Some(serde_json::json!(1u32)));
        assert!(!ctx.contains("a").await);
        assert_eq!(ctx.remove("a").await, None);
    }

    #[tokio::test]
    async fn test_find_walks_all_ancestors_and_respects_shadowing() {
        let root = AgentContext::new("root");
        root.set("shared", "from_root").await;
        let child = root.child("child");
        let grandchild = child.child("grandchild");

        // `get` only consults the local scope.
        let local: Option<String> = grandchild.get("shared").await;
        assert_eq!(local, None);

        // `find` walks up through every ancestor.
        let found: Option<String> = grandchild.find("shared").await;
        assert_eq!(found, Some("from_root".to_string()));

        // A nearer scope shadows farther ones, even when written later.
        child.set("shared", "from_child").await;
        let found: Option<String> = grandchild.find("shared").await;
        assert_eq!(found, Some("from_child".to_string()));

        let missing: Option<String> = grandchild.find("missing").await;
        assert_eq!(missing, None);
        assert_eq!(
            grandchild.parent().map(|p| p.execution_id.as_str()),
            Some("child")
        );
    }

    #[tokio::test]
    async fn test_interrupt_signal() {
        let ctx = AgentContext::new("test");
        assert!(!ctx.is_interrupted());
        ctx.trigger_interrupt();
        assert!(ctx.is_interrupted());
        ctx.clear_interrupt();
        assert!(!ctx.is_interrupted());
    }

    #[tokio::test]
    async fn test_child_shares_session_interrupt_and_config() {
        let parent = AgentContext::with_session("exec", "session-1").with_config(ContextConfig {
            max_retries: 3,
            timeout_ms: Some(500),
            ..Default::default()
        });
        let child = parent.child("child");

        assert_eq!(child.session_id.as_deref(), Some("session-1"));
        assert_eq!(child.config().max_retries, 3);
        assert_eq!(child.config().timeout_ms, Some(500));
        assert!(!child.config().enable_tracing);

        child.trigger_interrupt();
        assert!(parent.is_interrupted());
        parent.clear_interrupt();
        assert!(!child.is_interrupted());
    }

    #[tokio::test]
    async fn test_event_bus() {
        let ctx = AgentContext::new("test");
        let mut rx = ctx.subscribe("test_event").await;
        ctx.emit_event(AgentEvent::new(
            "test_event",
            serde_json::json!({"msg": "hello"}),
        ))
        .await;
        let event = rx.recv().await.unwrap();
        assert_eq!(event.event_type, "test_event");
    }

    #[tokio::test]
    async fn test_wildcard_subscriber_receives_every_event() {
        let ctx = AgentContext::new("test");
        let mut all = ctx.subscribe("*").await;
        let mut only_a = ctx.subscribe("a").await;

        ctx.emit_event(AgentEvent::new("a", serde_json::json!(1)))
            .await;
        ctx.emit_event(AgentEvent::new("b", serde_json::json!(2)).with_source("unit"))
            .await;

        assert_eq!(all.recv().await.unwrap().event_type, "a");
        let b = all.recv().await.unwrap();
        assert_eq!(b.event_type, "b");
        assert_eq!(b.source.as_deref(), Some("unit"));
        assert_eq!(only_a.recv().await.unwrap().event_type, "a");
    }

    #[tokio::test]
    async fn test_emit_survives_dropped_subscriber() {
        let ctx = AgentContext::new("test");
        drop(ctx.subscribe("gone").await);
        ctx.emit_event(AgentEvent::new("gone", serde_json::Value::Null))
            .await;

        let mut rx = ctx.subscribe("gone").await;
        ctx.emit_event(AgentEvent::new("gone", serde_json::Value::Null))
            .await;
        assert_eq!(rx.recv().await.unwrap().event_type, "gone");
    }
}