Skip to main content

mofa_kernel/agent/
context.rs

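//! Agent context definitions.
//!
//! A unified execution context used to pass state between an agent and its components.
//!
//! # Core principle
//!
//! `AgentContext` (the core agent context) contains only kernel primitives:
//! - basic state storage (a key/value store)
//! - an interrupt signal
//! - an event bus
//! - configuration
//! - parent/child context relationships
//!
//! Business logic (e.g. metrics collection, output recording) belongs in the foundation
//! layer's `RichAgentContext`.
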
16use serde::{Serialize, de::DeserializeOwned};
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::sync::atomic::{AtomicBool, Ordering};
20use tokio::sync::{RwLock, mpsc};
21
22// ============================================================================
23// Agent 上下文
24// ============================================================================
25
26/// 核心执行上下文 (Core Agent Context)
27///
28/// 提供最小的内核原语用于 Agent 执行:
29/// - 执行 ID 和会话 ID
30/// - 父子上下文关系(用于嵌套执行)
31/// - 通用键值存储
32/// - 中断信号
33/// - 事件总线
34/// - 配置
35///
36/// # 示例
37///
38/// ```rust,ignore
39/// use mofa_kernel::agent::context::CoreAgentContext;
40///
41/// let ctx = CoreAgentContext::new("execution-123");
42/// ctx.set("user_id", "user-456").await;
43/// let value: Option<String> = ctx.get("user_id").await;
44/// ```
45#[derive(Clone)]
46pub struct AgentContext {
47    /// 执行 ID (唯一标识本次执行)
48    pub execution_id: String,
49    /// 会话 ID (用于多轮对话)
50    pub session_id: Option<String>,
51    /// 父上下文 (用于层级执行)
52    parent: Option<Arc<AgentContext>>,
53    /// 共享状态 (通用键值存储)
54    state: Arc<RwLock<HashMap<String, serde_json::Value>>>,
55    /// 中断信号
56    interrupt: Arc<InterruptSignal>,
57    /// 事件总线
58    event_bus: Arc<EventBus>,
59    /// 配置
60    config: Arc<ContextConfig>,
61}
62
63/// 上下文配置
64#[derive(Debug, Clone, Default)]
65pub struct ContextConfig {
66    /// 超时时间 (毫秒)
67    pub timeout_ms: Option<u64>,
68    /// 最大重试次数
69    pub max_retries: u32,
70    /// 是否启用追踪
71    pub enable_tracing: bool,
72    /// 自定义配置
73    pub custom: HashMap<String, serde_json::Value>,
74}
75
76impl AgentContext {
77    /// 创建新的上下文
78    pub fn new(execution_id: impl Into<String>) -> Self {
79        Self {
80            execution_id: execution_id.into(),
81            session_id: None,
82            parent: None,
83            state: Arc::new(RwLock::new(HashMap::new())),
84            interrupt: Arc::new(InterruptSignal::new()),
85            event_bus: Arc::new(EventBus::new()),
86            config: Arc::new(ContextConfig::default()),
87        }
88    }
89
90    /// 创建带会话 ID 的上下文
91    pub fn with_session(execution_id: impl Into<String>, session_id: impl Into<String>) -> Self {
92        let mut ctx = Self::new(execution_id);
93        ctx.session_id = Some(session_id.into());
94        ctx
95    }
96
97    /// 创建子上下文 (用于子任务执行)
98    pub fn child(&self, execution_id: impl Into<String>) -> Self {
99        Self {
100            execution_id: execution_id.into(),
101            session_id: self.session_id.clone(),
102            parent: Some(Arc::new(self.clone())),
103            state: Arc::new(RwLock::new(HashMap::new())),
104            interrupt: self.interrupt.clone(), // 共享中断信号
105            event_bus: self.event_bus.clone(), // 共享事件总线
106            config: self.config.clone(),
107        }
108    }
109
110    /// 设置配置
111    pub fn with_config(mut self, config: ContextConfig) -> Self {
112        self.config = Arc::new(config);
113        self
114    }
115
116    /// 获取值
117    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
118        let state = self.state.read().await;
119        state
120            .get(key)
121            .and_then(|v| serde_json::from_value(v.clone()).ok())
122    }
123
124    /// 设置值
125    pub async fn set<T: Serialize>(&self, key: &str, value: T) {
126        if let Ok(v) = serde_json::to_value(value) {
127            let mut state = self.state.write().await;
128            state.insert(key.to_string(), v);
129        }
130    }
131
132    /// 删除值
133    pub async fn remove(&self, key: &str) -> Option<serde_json::Value> {
134        let mut state = self.state.write().await;
135        state.remove(key)
136    }
137
138    /// 检查是否存在值
139    pub async fn contains(&self, key: &str) -> bool {
140        let state = self.state.read().await;
141        state.contains_key(key)
142    }
143
144    /// 获取所有键
145    pub async fn keys(&self) -> Vec<String> {
146        let state = self.state.read().await;
147        state.keys().cloned().collect()
148    }
149
150    /// 检查是否被中断
151    pub fn is_interrupted(&self) -> bool {
152        self.interrupt.is_triggered()
153    }
154
155    /// 触发中断
156    pub fn trigger_interrupt(&self) {
157        self.interrupt.trigger();
158    }
159
160    /// 清除中断状态
161    pub fn clear_interrupt(&self) {
162        self.interrupt.clear();
163    }
164
165    /// 获取配置
166    pub fn config(&self) -> &ContextConfig {
167        &self.config
168    }
169
170    /// 获取父上下文
171    pub fn parent(&self) -> Option<&Arc<AgentContext>> {
172        self.parent.as_ref()
173    }
174
175    /// 发送事件
176    pub async fn emit_event(&self, event: AgentEvent) {
177        self.event_bus.emit(event).await;
178    }
179
180    /// 订阅事件
181    pub async fn subscribe(&self, event_type: &str) -> EventReceiver {
182        self.event_bus.subscribe(event_type).await
183    }
184
185    /// 从父上下文查找值 (递归向上查找)
186    pub async fn find<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
187        // 先在当前上下文查找
188        if let Some(value) = self.get::<T>(key).await {
189            return Some(value);
190        }
191
192        // 递归查找父上下文
193        if let Some(parent) = &self.parent {
194            return Box::pin(parent.find::<T>(key)).await;
195        }
196
197        None
198    }
199}
200
// ============================================================================
// Interrupt signal
// ============================================================================

/// A thread-safe, resettable interrupt flag.
///
/// The flag is a single [`AtomicBool`] accessed with `SeqCst` ordering. It is not
/// latched: [`InterruptSignal::clear`] resets it, so callers observe whatever state the
/// flag is in when they call [`InterruptSignal::is_triggered`].
#[derive(Debug)]
pub struct InterruptSignal {
    triggered: AtomicBool,
}

impl InterruptSignal {
    /// Creates a new, untriggered signal.
    ///
    /// This is a `const fn`, so a signal can also be placed in a `static`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            triggered: AtomicBool::new(false),
        }
    }

    /// Returns `true` if the signal is currently triggered.
    #[must_use]
    pub fn is_triggered(&self) -> bool {
        self.triggered.load(Ordering::SeqCst)
    }

    /// Triggers the signal. Triggering an already-triggered signal has no further effect.
    pub fn trigger(&self) {
        self.triggered.store(true, Ordering::SeqCst);
    }

    /// Resets the signal to the untriggered state.
    pub fn clear(&self) {
        self.triggered.store(false, Ordering::SeqCst);
    }
}

impl Default for InterruptSignal {
    /// Equivalent to [`InterruptSignal::new`].
    fn default() -> Self {
        Self::new()
    }
}
239
240// ============================================================================
241// 事件总线
242// ============================================================================
243
244/// Agent 事件
245#[derive(Debug, Clone)]
246pub struct AgentEvent {
247    /// 事件类型
248    pub event_type: String,
249    /// 事件数据
250    pub data: serde_json::Value,
251    /// 时间戳
252    pub timestamp_ms: u64,
253    /// 来源
254    pub source: Option<String>,
255}
256
257impl AgentEvent {
258    /// 创建新事件
259    pub fn new(event_type: impl Into<String>, data: serde_json::Value) -> Self {
260        let now = std::time::SystemTime::now()
261            .duration_since(std::time::UNIX_EPOCH)
262            .unwrap_or_default()
263            .as_millis() as u64;
264
265        Self {
266            event_type: event_type.into(),
267            data,
268            timestamp_ms: now,
269            source: None,
270        }
271    }
272
273    /// 设置来源
274    pub fn with_source(mut self, source: impl Into<String>) -> Self {
275        self.source = Some(source.into());
276        self
277    }
278}
279
280/// 事件接收器
281pub type EventReceiver = mpsc::Receiver<AgentEvent>;
282
283/// 事件总线
284pub struct EventBus {
285    subscribers: RwLock<HashMap<String, Vec<mpsc::Sender<AgentEvent>>>>,
286}
287
288impl EventBus {
289    /// 创建新的事件总线
290    pub fn new() -> Self {
291        Self {
292            subscribers: RwLock::new(HashMap::new()),
293        }
294    }
295
296    /// 发送事件
297    pub async fn emit(&self, event: AgentEvent) {
298        let subscribers = self.subscribers.read().await;
299
300        // 发送给类型特定订阅者
301        if let Some(senders) = subscribers.get(&event.event_type) {
302            for sender in senders {
303                let _ = sender.send(event.clone()).await;
304            }
305        }
306
307        // 发送给通配订阅者
308        if let Some(senders) = subscribers.get("*") {
309            for sender in senders {
310                let _ = sender.send(event.clone()).await;
311            }
312        }
313    }
314
315    /// 订阅事件
316    pub async fn subscribe(&self, event_type: &str) -> EventReceiver {
317        let (tx, rx) = mpsc::channel(100);
318        let mut subscribers = self.subscribers.write().await;
319        subscribers
320            .entry(event_type.to_string())
321            .or_insert_with(Vec::new)
322            .push(tx);
323        rx
324    }
325}
326
327impl Default for EventBus {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_context_basic() {
        let ctx = AgentContext::new("test-execution");

        ctx.set("key1", "value1").await;
        let value: Option<String> = ctx.get("key1").await;
        assert_eq!(value, Some("value1".to_string()));
    }

    #[tokio::test]
    async fn test_context_state_operations() {
        let ctx = AgentContext::new("test-execution");
        ctx.set("n", 42).await;

        assert!(ctx.contains("n").await);
        assert_eq!(ctx.keys().await, vec!["n".to_string()]);
        // Values are stored as JSON, so reading with a mismatched type yields `None`.
        assert_eq!(ctx.get::<String>("n").await, None);
        assert_eq!(ctx.get::<i64>("n").await, Some(42));

        assert_eq!(ctx.remove("n").await, Some(serde_json::json!(42)));
        assert!(!ctx.contains("n").await);
        assert_eq!(ctx.get::<i64>("n").await, None);
    }

    #[tokio::test]
    async fn test_context_child() {
        let parent = AgentContext::new("parent");
        parent.set("parent_key", "parent_value").await;

        let child = parent.child("child");
        child.set("child_key", "child_value").await;

        // The child can read its own values.
        let child_value: Option<String> = child.get("child_key").await;
        assert_eq!(child_value, Some("child_value".to_string()));

        // `get` only looks at the child's own state...
        let direct: Option<String> = child.get("parent_key").await;
        assert_eq!(direct, None);

        // ...while `find` walks up to the parent.
        let parent_value: Option<String> = child.find("parent_key").await;
        assert_eq!(parent_value, Some("parent_value".to_string()));

        // Child values never leak into the parent.
        let leaked: Option<String> = parent.find("child_key").await;
        assert_eq!(leaked, None);

        assert!(child.parent().is_some());
        assert!(parent.parent().is_none());
    }

    #[tokio::test]
    async fn test_find_walks_multiple_levels() {
        let root = AgentContext::new("root");
        root.set("depth", 0).await;
        let grandchild = root.child("child").child("grandchild");

        assert_eq!(grandchild.find::<i32>("depth").await, Some(0));
        assert_eq!(grandchild.find::<i32>("missing").await, None);
    }

    #[tokio::test]
    async fn test_interrupt_signal() {
        let ctx = AgentContext::new("test");

        assert!(!ctx.is_interrupted());
        ctx.trigger_interrupt();
        assert!(ctx.is_interrupted());
        ctx.clear_interrupt();
        assert!(!ctx.is_interrupted());
    }

    #[tokio::test]
    async fn test_child_inherits_session_and_shares_interrupt() {
        let parent = AgentContext::with_session("parent", "session-1");
        let child = parent.child("child");
        assert_eq!(child.session_id.as_deref(), Some("session-1"));

        assert!(!child.is_interrupted());
        parent.trigger_interrupt();
        assert!(child.is_interrupted());
        child.clear_interrupt();
        assert!(!parent.is_interrupted());
    }

    #[tokio::test]
    async fn test_with_config_is_inherited_by_children() {
        let config = ContextConfig {
            max_retries: 3,
            ..Default::default()
        };
        let ctx = AgentContext::new("test").with_config(config);
        assert_eq!(ctx.config().max_retries, 3);
        assert_eq!(ctx.child("child").config().max_retries, 3);
    }

    #[tokio::test]
    async fn test_event_bus() {
        let ctx = AgentContext::new("test");

        let mut rx = ctx.subscribe("test_event").await;

        ctx.emit_event(AgentEvent::new(
            "test_event",
            serde_json::json!({"msg": "hello"}),
        ))
        .await;

        let event = rx.recv().await.unwrap();
        assert_eq!(event.event_type, "test_event");
    }

    #[tokio::test]
    async fn test_event_bus_wildcard() {
        let ctx = AgentContext::new("test");
        let mut all = ctx.subscribe("*").await;
        let mut other = ctx.subscribe("other_event").await;

        ctx.emit_event(AgentEvent::new("test_event", serde_json::json!(1)).with_source("unit"))
            .await;

        let event = all.recv().await.unwrap();
        assert_eq!(event.event_type, "test_event");
        assert_eq!(event.source.as_deref(), Some("unit"));
        // Subscribers of a different event type receive nothing.
        assert!(other.try_recv().is_err());
    }
}