mofa_kernel/agent/
context.rs1use serde::{Serialize, de::DeserializeOwned};
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::sync::atomic::{AtomicBool, Ordering};
20use tokio::sync::{RwLock, mpsc};
21
22#[derive(Clone)]
46pub struct AgentContext {
47 pub execution_id: String,
49 pub session_id: Option<String>,
51 parent: Option<Arc<AgentContext>>,
53 state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
55 interrupt: Arc<InterruptSignal>,
57 event_bus: Arc<EventBus>,
59 config: Arc<ContextConfig>,
61}
62
63#[derive(Debug, Clone, Default)]
65pub struct ContextConfig {
66 pub timeout_ms: Option<u64>,
68 pub max_retries: u32,
70 pub enable_tracing: bool,
72 pub custom: HashMap<String, serde_json::Value>,
74}
75
76impl AgentContext {
77 pub fn new(execution_id: impl Into<String>) -> Self {
79 Self {
80 execution_id: execution_id.into(),
81 session_id: None,
82 parent: None,
83 state: Arc::new(RwLock::new(HashMap::new())),
84 interrupt: Arc::new(InterruptSignal::new()),
85 event_bus: Arc::new(EventBus::new()),
86 config: Arc::new(ContextConfig::default()),
87 }
88 }
89
90 pub fn with_session(execution_id: impl Into<String>, session_id: impl Into<String>) -> Self {
92 let mut ctx = Self::new(execution_id);
93 ctx.session_id = Some(session_id.into());
94 ctx
95 }
96
97 pub fn child(&self, execution_id: impl Into<String>) -> Self {
99 Self {
100 execution_id: execution_id.into(),
101 session_id: self.session_id.clone(),
102 parent: Some(Arc::new(self.clone())),
103 state: Arc::new(RwLock::new(HashMap::new())),
104 interrupt: self.interrupt.clone(), event_bus: self.event_bus.clone(), config: self.config.clone(),
107 }
108 }
109
110 pub fn with_config(mut self, config: ContextConfig) -> Self {
112 self.config = Arc::new(config);
113 self
114 }
115
116 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
118 let state = self.state.read().await;
119 state
120 .get(key)
121 .and_then(|v| serde_json::from_value(v.clone()).ok())
122 }
123
124 pub async fn set<T: Serialize>(&self, key: &str, value: T) {
126 if let Ok(v) = serde_json::to_value(value) {
127 let mut state = self.state.write().await;
128 state.insert(key.to_string(), v);
129 }
130 }
131
132 pub async fn remove(&self, key: &str) -> Option<serde_json::Value> {
134 let mut state = self.state.write().await;
135 state.remove(key)
136 }
137
138 pub async fn contains(&self, key: &str) -> bool {
140 let state = self.state.read().await;
141 state.contains_key(key)
142 }
143
144 pub async fn keys(&self) -> Vec<String> {
146 let state = self.state.read().await;
147 state.keys().cloned().collect()
148 }
149
150 pub fn is_interrupted(&self) -> bool {
152 self.interrupt.is_triggered()
153 }
154
155 pub fn trigger_interrupt(&self) {
157 self.interrupt.trigger();
158 }
159
160 pub fn clear_interrupt(&self) {
162 self.interrupt.clear();
163 }
164
165 pub fn config(&self) -> &ContextConfig {
167 &self.config
168 }
169
170 pub fn parent(&self) -> Option<&Arc<AgentContext>> {
172 self.parent.as_ref()
173 }
174
175 pub async fn emit_event(&self, event: AgentEvent) {
177 self.event_bus.emit(event).await;
178 }
179
180 pub async fn subscribe(&self, event_type: &str) -> EventReceiver {
182 self.event_bus.subscribe(event_type).await
183 }
184
185 pub async fn find<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
187 if let Some(value) = self.get::<T>(key).await {
189 return Some(value);
190 }
191
192 if let Some(parent) = &self.parent {
194 return Box::pin(parent.find::<T>(key)).await;
195 }
196
197 None
198 }
199}
200
/// A settable/clearable boolean flag used to request interruption.
///
/// Setting the flag does not stop anything by itself; callers observe it by polling
/// [`InterruptSignal::is_triggered`]. All accesses use `Ordering::SeqCst`.
/// `Default` is equivalent to [`InterruptSignal::new`] (flag cleared).
#[derive(Debug, Default)]
pub struct InterruptSignal {
    /// `true` once `trigger` has been called and `clear` has not been called since.
    triggered: AtomicBool,
}

impl InterruptSignal {
    /// Creates a cleared signal.
    ///
    /// This is a `const fn`, so a signal can also be placed in a `static`.
    pub const fn new() -> Self {
        Self {
            triggered: AtomicBool::new(false),
        }
    }

    /// Returns `true` if the signal is currently set.
    pub fn is_triggered(&self) -> bool {
        self.triggered.load(Ordering::SeqCst)
    }

    /// Sets the signal. Setting an already-set signal has no further effect.
    pub fn trigger(&self) {
        self.triggered.store(true, Ordering::SeqCst);
    }

    /// Clears the signal.
    pub fn clear(&self) {
        self.triggered.store(false, Ordering::SeqCst);
    }
}
239
240#[derive(Debug, Clone)]
246pub struct AgentEvent {
247 pub event_type: String,
249 pub data: serde_json::Value,
251 pub timestamp_ms: u64,
253 pub source: Option<String>,
255}
256
257impl AgentEvent {
258 pub fn new(event_type: impl Into<String>, data: serde_json::Value) -> Self {
260 let now = std::time::SystemTime::now()
261 .duration_since(std::time::UNIX_EPOCH)
262 .unwrap_or_default()
263 .as_millis() as u64;
264
265 Self {
266 event_type: event_type.into(),
267 data,
268 timestamp_ms: now,
269 source: None,
270 }
271 }
272
273 pub fn with_source(mut self, source: impl Into<String>) -> Self {
275 self.source = Some(source.into());
276 self
277 }
278}
279
280pub type EventReceiver = mpsc::Receiver<AgentEvent>;
282
283pub struct EventBus {
285 subscribers: RwLock<HashMap<String, Vec<mpsc::Sender<AgentEvent>>>>,
286}
287
288impl EventBus {
289 pub fn new() -> Self {
291 Self {
292 subscribers: RwLock::new(HashMap::new()),
293 }
294 }
295
296 pub async fn emit(&self, event: AgentEvent) {
298 let subscribers = self.subscribers.read().await;
299
300 if let Some(senders) = subscribers.get(&event.event_type) {
302 for sender in senders {
303 let _ = sender.send(event.clone()).await;
304 }
305 }
306
307 if let Some(senders) = subscribers.get("*") {
309 for sender in senders {
310 let _ = sender.send(event.clone()).await;
311 }
312 }
313 }
314
315 pub async fn subscribe(&self, event_type: &str) -> EventReceiver {
317 let (tx, rx) = mpsc::channel(100);
318 let mut subscribers = self.subscribers.write().await;
319 subscribers
320 .entry(event_type.to_string())
321 .or_insert_with(Vec::new)
322 .push(tx);
323 rx
324 }
325}
326
327impl Default for EventBus {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_context_basic() {
        let ctx = AgentContext::new("test-execution");

        ctx.set("key1", "value1").await;
        let value: Option<String> = ctx.get("key1").await;
        assert_eq!(value, Some("value1".to_string()));
    }

    #[tokio::test]
    async fn test_context_remove_contains_keys() {
        let ctx = AgentContext::new("test");
        ctx.set("a", 1).await;
        ctx.set("b", 2).await;

        assert!(ctx.contains("a").await);
        let mut keys = ctx.keys().await;
        keys.sort();
        assert_eq!(keys, vec!["a".to_string(), "b".to_string()]);

        assert_eq!(ctx.remove("a").await, Some(serde_json::json!(1)));
        assert!(!ctx.contains("a").await);
        assert_eq!(ctx.remove("a").await, None);
    }

    #[tokio::test]
    async fn test_get_type_mismatch_and_missing() {
        let ctx = AgentContext::new("test");
        ctx.set("n", 42u32).await;

        let as_string: Option<String> = ctx.get("n").await;
        assert_eq!(as_string, None);
        let as_number: Option<u32> = ctx.get("n").await;
        assert_eq!(as_number, Some(42));
        let missing: Option<u32> = ctx.get("missing").await;
        assert_eq!(missing, None);
    }

    #[tokio::test]
    async fn test_context_child() {
        let parent = AgentContext::new("parent");
        parent.set("parent_key", "parent_value").await;

        let child = parent.child("child");
        child.set("child_key", "child_value").await;

        let child_value: Option<String> = child.get("child_key").await;
        assert_eq!(child_value, Some("child_value".to_string()));

        let parent_value: Option<String> = child.find("parent_key").await;
        assert_eq!(parent_value, Some("parent_value".to_string()));

        // The child's own state is separate from the parent's.
        assert!(!parent.contains("child_key").await);
        let not_local: Option<String> = child.get("parent_key").await;
        assert_eq!(not_local, None);
    }

    #[tokio::test]
    async fn test_find_prefers_nearest_and_walks_ancestors() {
        let root = AgentContext::new("root");
        root.set("k", "root").await;
        root.set("only_root", 1).await;
        let mid = root.child("mid");
        mid.set("k", "mid").await;
        let leaf = mid.child("leaf");

        let nearest: Option<String> = leaf.find("k").await;
        assert_eq!(nearest, Some("mid".to_string()));
        let deep: Option<i32> = leaf.find("only_root").await;
        assert_eq!(deep, Some(1));
        let absent: Option<i32> = leaf.find("absent").await;
        assert_eq!(absent, None);
    }

    #[tokio::test]
    async fn test_child_shares_session_interrupt_and_config() {
        let config = ContextConfig {
            max_retries: 3,
            ..Default::default()
        };
        let parent = AgentContext::with_session("p", "s1").with_config(config);
        let child = parent.child("c");

        assert_eq!(child.execution_id, "c");
        assert_eq!(child.session_id.as_deref(), Some("s1"));
        assert!(child.parent().is_some());
        assert!(parent.parent().is_none());
        assert_eq!(child.config().max_retries, 3);

        child.trigger_interrupt();
        assert!(parent.is_interrupted());
        parent.clear_interrupt();
        assert!(!child.is_interrupted());
    }

    #[tokio::test]
    async fn test_interrupt_signal() {
        let ctx = AgentContext::new("test");

        assert!(!ctx.is_interrupted());
        ctx.trigger_interrupt();
        assert!(ctx.is_interrupted());
        ctx.clear_interrupt();
        assert!(!ctx.is_interrupted());
    }

    #[tokio::test]
    async fn test_event_bus() {
        let ctx = AgentContext::new("test");

        let mut rx = ctx.subscribe("test_event").await;

        ctx.emit_event(AgentEvent::new(
            "test_event",
            serde_json::json!({"msg": "hello"}),
        ))
        .await;

        let event = rx.recv().await.unwrap();
        assert_eq!(event.event_type, "test_event");
    }

    #[tokio::test]
    async fn test_event_bus_wildcard_and_filtering() {
        let ctx = AgentContext::new("test");
        let mut all = ctx.subscribe("*").await;
        let mut only_b = ctx.subscribe("b").await;

        ctx.emit_event(AgentEvent::new("a", serde_json::json!(1)).with_source("src"))
            .await;
        ctx.emit_event(AgentEvent::new("b", serde_json::json!(2)))
            .await;

        let first = all.recv().await.unwrap();
        assert_eq!(first.event_type, "a");
        assert_eq!(first.source.as_deref(), Some("src"));
        assert_eq!(all.recv().await.unwrap().event_type, "b");

        // "a" was emitted first, so if it had leaked to the "b" subscriber it would arrive first.
        let filtered = only_b.recv().await.unwrap();
        assert_eq!(filtered.event_type, "b");
        assert_eq!(filtered.data, serde_json::json!(2));
    }
}
389}