Skip to main content

mofa_foundation/agent/context/
rich.rs

//! Rich Agent Context - extended context
//!
//! Provides business-specific functionality on top of the kernel's CoreAgentContext.
4
5use mofa_kernel::agent::context::AgentContext;
6use serde::{Serialize, de::DeserializeOwned};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11/// 组件输出记录
12#[derive(Debug, Clone)]
13pub struct ComponentOutput {
14    /// 组件名称
15    pub component: String,
16    /// 输出内容
17    pub output: serde_json::Value,
18    /// 时间戳
19    pub timestamp_ms: u64,
20}
21
/// Execution metrics collected over the lifetime of a single agent run.
///
/// Timestamps are wall-clock milliseconds since the Unix epoch, as reported by
/// [`std::time::SystemTime`].
#[derive(Debug, Clone, Default)]
pub struct ExecutionMetrics {
    /// Start time, in milliseconds since the Unix epoch.
    pub start_time_ms: u64,
    /// End time, in milliseconds since the Unix epoch; `None` while the run is unfinished.
    pub end_time_ms: Option<u64>,
    /// Number of calls per component, keyed by component name.
    pub component_calls: HashMap<String, u64>,
    /// Total tokens used.
    pub total_tokens: u64,
    /// Number of tool calls.
    pub tool_calls: u64,
}

impl ExecutionMetrics {
    /// Creates empty metrics whose start time is the current wall-clock time.
    ///
    /// Note that `ExecutionMetrics::default()` leaves `start_time_ms` at `0`;
    /// use this constructor when the start time should be "now".
    pub fn new() -> Self {
        Self {
            start_time_ms: Self::now_ms(),
            ..Default::default()
        }
    }

    /// Returns the elapsed time in milliseconds.
    ///
    /// Measured from `start_time_ms` to `end_time_ms`, or to the current time if
    /// no end time has been recorded yet.
    ///
    /// Returns `0` instead of underflowing when the end point precedes the start.
    /// `SystemTime` is not monotonic, so this can happen if the system clock is
    /// set back, or if `end_time_ms` was assigned an earlier value.
    pub fn duration_ms(&self) -> u64 {
        let end = self.end_time_ms.unwrap_or_else(Self::now_ms);
        end.saturating_sub(self.start_time_ms)
    }

    /// Current wall-clock time in milliseconds since the Unix epoch
    /// (`0` if the clock reports a time before the epoch).
    fn now_ms() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }
}
61
62/// 扩展的 Agent 上下文
63///
64/// 提供业务特定的功能:
65/// - 组件输出记录
66/// - 执行指标跟踪
67/// - 委托所有核心功能到 CoreAgentContext
68///
69/// # 示例
70///
71/// ```rust,ignore
72/// use mofa_foundation::agent::context::RichAgentContext;
73/// use mofa_kernel::agent::context::CoreAgentContext;
74///
75/// let core_ctx = CoreAgentContext::new("execution-123");
76/// let rich_ctx = RichAgentContext::from(core_ctx);
77///
78/// // 业务特定功能
79/// rich_ctx.record_output("llm", serde_json::json!("response")).await;
80/// rich_ctx.increment_component_calls("llm").await;
81///
82/// // 核心功能委托
83/// rich_ctx.set("key", "value").await;
84/// ```
85#[derive(Clone)]
86pub struct RichAgentContext {
87    /// 内核上下文 (委托核心功能)
88    inner: Arc<AgentContext>,
89    /// 累积输出
90    outputs: Arc<RwLock<Vec<ComponentOutput>>>,
91    /// 执行指标
92    metrics: Arc<RwLock<ExecutionMetrics>>,
93}
94
95impl RichAgentContext {
96    /// 从 CoreAgentContext 创建 RichAgentContext
97    pub fn new(inner: AgentContext) -> Self {
98        Self {
99            inner: Arc::new(inner),
100            outputs: Arc::new(RwLock::new(Vec::new())),
101            metrics: Arc::new(RwLock::new(ExecutionMetrics::new())),
102        }
103    }
104
105    /// 记录组件输出
106    pub async fn record_output(&self, component: impl Into<String>, output: serde_json::Value) {
107        let now = std::time::SystemTime::now()
108            .duration_since(std::time::UNIX_EPOCH)
109            .unwrap_or_default()
110            .as_millis() as u64;
111
112        let mut outputs = self.outputs.write().await;
113        outputs.push(ComponentOutput {
114            component: component.into(),
115            output,
116            timestamp_ms: now,
117        });
118    }
119
120    /// 获取所有组件输出
121    pub async fn get_outputs(&self) -> Vec<ComponentOutput> {
122        let outputs = self.outputs.read().await;
123        outputs.clone()
124    }
125
126    /// 增加组件调用计数
127    pub async fn increment_component_calls(&self, component: &str) {
128        let mut metrics = self.metrics.write().await;
129        *metrics
130            .component_calls
131            .entry(component.to_string())
132            .or_insert(0) += 1;
133    }
134
135    /// 增加 Token 使用
136    pub async fn add_tokens(&self, tokens: u64) {
137        let mut metrics = self.metrics.write().await;
138        metrics.total_tokens += tokens;
139    }
140
141    /// 增加工具调用计数
142    pub async fn increment_tool_calls(&self) {
143        let mut metrics = self.metrics.write().await;
144        metrics.tool_calls += 1;
145    }
146
147    /// 获取执行指标
148    pub async fn get_metrics(&self) -> ExecutionMetrics {
149        let metrics = self.metrics.read().await;
150        metrics.clone()
151    }
152
153    /// 结束执行 (记录结束时间)
154    pub async fn finish(&self) {
155        let now = std::time::SystemTime::now()
156            .duration_since(std::time::UNIX_EPOCH)
157            .unwrap_or_default()
158            .as_millis() as u64;
159
160        let mut metrics = self.metrics.write().await;
161        metrics.end_time_ms = Some(now);
162    }
163
164    /// 获取执行时长 (毫秒)
165    pub async fn duration_ms(&self) -> u64 {
166        let metrics = self.metrics.read().await;
167        metrics.duration_ms()
168    }
169
170    // ===== 核心功能委托 =====
171
172    /// 获取值
173    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
174        self.inner.get(key).await
175    }
176
177    /// 设置值
178    pub async fn set<T: Serialize>(&self, key: &str, value: T) {
179        self.inner.set(key, value).await
180    }
181
182    /// 删除值
183    pub async fn remove(&self, key: &str) -> Option<serde_json::Value> {
184        self.inner.remove(key).await
185    }
186
187    /// 检查是否存在值
188    pub async fn contains(&self, key: &str) -> bool {
189        self.inner.contains(key).await
190    }
191
192    /// 获取所有键
193    pub async fn keys(&self) -> Vec<String> {
194        self.inner.keys().await
195    }
196
197    /// 从父上下文查找值
198    pub async fn find<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
199        self.inner.find(key).await
200    }
201
202    /// 获取执行 ID
203    pub fn execution_id(&self) -> &str {
204        &self.inner.execution_id
205    }
206
207    /// 获取会话 ID
208    pub fn session_id(&self) -> Option<&str> {
209        self.inner.session_id.as_deref()
210    }
211
212    /// 获取父上下文
213    pub fn parent(&self) -> Option<&Arc<AgentContext>> {
214        self.inner.parent()
215    }
216
217    /// 检查是否被中断
218    pub fn is_interrupted(&self) -> bool {
219        self.inner.is_interrupted()
220    }
221
222    /// 触发中断
223    pub fn trigger_interrupt(&self) {
224        self.inner.trigger_interrupt()
225    }
226
227    /// 清除中断状态
228    pub fn clear_interrupt(&self) {
229        self.inner.clear_interrupt()
230    }
231
232    /// 获取配置
233    pub fn config(&self) -> &mofa_kernel::agent::context::ContextConfig {
234        self.inner.config()
235    }
236
237    /// 发送事件
238    pub async fn emit_event(&self, event: mofa_kernel::agent::context::AgentEvent) {
239        self.inner.emit_event(event).await
240    }
241
242    /// 订阅事件
243    pub async fn subscribe(
244        &self,
245        event_type: &str,
246    ) -> tokio::sync::mpsc::Receiver<mofa_kernel::agent::context::AgentEvent> {
247        self.inner.subscribe(event_type).await
248    }
249
250    /// 获取内部核心上下文的引用
251    pub fn inner(&self) -> &AgentContext {
252        &self.inner
253    }
254}
255
// ===== Conversion impls =====
257
258impl From<AgentContext> for RichAgentContext {
259    fn from(inner: AgentContext) -> Self {
260        Self::new(inner)
261    }
262}
263
264impl From<RichAgentContext> for AgentContext {
265    fn from(rich: RichAgentContext) -> Self {
266        // 注意:这会克隆内部上下文,丢失 RichAgentContext 的扩展状态
267        // 在实际使用中,应该通过 AsRef trait 来获取引用
268        (*rich.inner).clone()
269    }
270}
271
272impl AsRef<AgentContext> for RichAgentContext {
273    fn as_ref(&self) -> &AgentContext {
274        &self.inner
275    }
276}