Skip to main content

mofa_runtime/agent/plugins/
mod.rs

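//! Plugin system.
//!
//! Provides a dynamic plugin mechanism that lets users extend and control context content at runtime.
//! This module builds on the plugin abstractions from `mofa-kernel` and provides runtime example implementations.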
5
6pub use mofa_kernel::agent::plugins::{Plugin, PluginMetadata, PluginRegistry, PluginStage};
7
8use crate::agent::context::AgentContext;
9use crate::agent::error::{AgentError, AgentResult};
10use crate::agent::types::{AgentInput, AgentOutput};
11use async_trait::async_trait;
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14
15// ============================================================================
16// 运行时插件注册中心
17// ============================================================================
18
19/// 简单插件注册中心实现
20pub struct SimplePluginRegistry {
21    plugins: RwLock<HashMap<String, Arc<dyn Plugin>>>,
22}
23
24impl SimplePluginRegistry {
25    /// 创建新的插件注册中心
26    pub fn new() -> Self {
27        Self {
28            plugins: RwLock::new(HashMap::new()),
29        }
30    }
31}
32
33impl Default for SimplePluginRegistry {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl PluginRegistry for SimplePluginRegistry {
40    fn register(&self, plugin: Arc<dyn Plugin>) -> AgentResult<()> {
41        let mut plugins = self
42            .plugins
43            .write()
44            .map_err(|_| AgentError::ExecutionFailed("Failed to acquire write lock".to_string()))?;
45        plugins.insert(plugin.name().to_string(), plugin);
46        Ok(())
47    }
48
49    fn unregister(&self, name: &str) -> AgentResult<bool> {
50        let mut plugins = self
51            .plugins
52            .write()
53            .map_err(|_| AgentError::ExecutionFailed("Failed to acquire write lock".to_string()))?;
54        Ok(plugins.remove(name).is_some())
55    }
56
57    fn get(&self, name: &str) -> Option<Arc<dyn Plugin>> {
58        let plugins = self.plugins.read().ok()?;
59        plugins.get(name).cloned()
60    }
61
62    fn list(&self) -> Vec<Arc<dyn Plugin>> {
63        self.plugins
64            .read()
65            .ok()
66            .map(|plugins| plugins.values().cloned().collect())
67            .unwrap_or_default()
68    }
69
70    fn list_by_stage(&self, stage: PluginStage) -> Vec<Arc<dyn Plugin>> {
71        self.plugins
72            .read()
73            .ok()
74            .map(|plugins| {
75                plugins
76                    .values()
77                    .filter(|plugin| plugin.metadata().stages.contains(&stage))
78                    .cloned()
79                    .collect()
80            })
81            .unwrap_or_default()
82    }
83
84    fn contains(&self, name: &str) -> bool {
85        self.plugins
86            .read()
87            .ok()
88            .map(|plugins| plugins.contains_key(name))
89            .unwrap_or(false)
90    }
91
92    fn count(&self) -> usize {
93        self.plugins
94            .read()
95            .ok()
96            .map(|plugins| plugins.len())
97            .unwrap_or(0)
98    }
99}
100
101// ============================================================================
102// 插件执行器(运行时层)
103// ============================================================================
104
105/// 插件执行器
106pub struct PluginExecutor {
107    pub registry: Arc<dyn PluginRegistry>,
108}
109
110impl PluginExecutor {
111    /// 创建插件执行器
112    pub fn new(registry: Arc<dyn PluginRegistry>) -> Self {
113        Self { registry }
114    }
115
116    /// 执行指定阶段的所有插件
117    pub async fn execute_stage(&self, stage: PluginStage, ctx: &AgentContext) -> AgentResult<()> {
118        let plugins = self.registry.list_by_stage(stage);
119        for plugin in plugins {
120            match stage {
121                PluginStage::PreContext => {
122                    plugin.pre_context(ctx).await?;
123                }
124                PluginStage::PostProcess => {
125                    plugin.post_process(ctx).await?;
126                }
127                _ => {
128                    // PreRequest 和 PostResponse 需要参数,单独处理
129                    continue;
130                }
131            }
132        }
133        Ok(())
134    }
135
136    /// 执行PreRequest阶段的所有插件
137    pub async fn execute_pre_request(
138        &self,
139        input: AgentInput,
140        ctx: &AgentContext,
141    ) -> AgentResult<AgentInput> {
142        let mut result = input;
143        let plugins = self.registry.list_by_stage(PluginStage::PreRequest);
144
145        for plugin in plugins {
146            result = plugin.pre_request(result.clone(), ctx).await?;
147        }
148
149        Ok(result)
150    }
151
152    /// 执行PostResponse阶段的所有插件
153    pub async fn execute_post_response(
154        &self,
155        output: AgentOutput,
156        ctx: &AgentContext,
157    ) -> AgentResult<AgentOutput> {
158        let mut result = output;
159        let plugins = self.registry.list_by_stage(PluginStage::PostResponse);
160
161        for plugin in plugins {
162            result = plugin.post_response(result.clone(), ctx).await?;
163        }
164
165        Ok(result)
166    }
167}
168
169// ============================================================================
170// 内置插件示例 (运行时层)
171// ============================================================================
172
/// Example HTTP request plugin configured with a target URL.
pub struct HttpPlugin {
    name: String,
    description: String,
    url: String,
}

impl HttpPlugin {
    /// Creates an HTTP plugin that targets `url`.
    ///
    /// The plugin is named `"http-plugin"`.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            name: "http-plugin".to_string(),
            description: "HTTP请求插件".to_string(),
            url: url.into(),
        }
    }

    /// Returns the URL this plugin was configured with.
    pub fn url(&self) -> &str {
        &self.url
    }
}
190
191#[async_trait]
192impl Plugin for HttpPlugin {
193    fn name(&self) -> &str {
194        &self.name
195    }
196
197    fn description(&self) -> &str {
198        &self.description
199    }
200
201    fn metadata(&self) -> PluginMetadata {
202        let mut metadata = PluginMetadata::default();
203        metadata.name = self.name.clone();
204        metadata.description = self.description.clone();
205        metadata.version = "1.0.0".to_string();
206        metadata.stages = vec![PluginStage::PreContext];
207        metadata
208    }
209
210    async fn pre_context(&self, ctx: &AgentContext) -> AgentResult<()> {
211        // 这里可以实现HTTP请求逻辑,并将结果存入上下文
212        // 示例:将固定内容存入上下文
213        ctx.set("http_response", "示例HTTP响应内容").await;
214        Ok(())
215    }
216}
217
218/// 示例自定义函数插件
219pub struct CustomFunctionPlugin {
220    name: String,
221    description: String,
222    func: Arc<dyn Fn(AgentInput, &AgentContext) -> AgentResult<AgentInput> + Send + Sync + 'static>,
223}
224
225impl CustomFunctionPlugin {
226    /// 创建自定义函数插件
227    pub fn new<F>(name: impl Into<String>, desc: impl Into<String>, func: F) -> Self
228    where
229        F: Fn(AgentInput, &AgentContext) -> AgentResult<AgentInput> + Send + Sync + 'static,
230    {
231        Self {
232            name: name.into(),
233            description: desc.into(),
234            func: Arc::new(func),
235        }
236    }
237}
238
239#[async_trait]
240impl Plugin for CustomFunctionPlugin {
241    fn name(&self) -> &str {
242        &self.name
243    }
244
245    fn description(&self) -> &str {
246        &self.description
247    }
248
249    fn metadata(&self) -> PluginMetadata {
250        let mut metadata = PluginMetadata::default();
251        metadata.name = self.name.clone();
252        metadata.description = self.description.clone();
253        metadata.version = "1.0.0".to_string();
254        metadata.stages = vec![PluginStage::PreRequest];
255        metadata
256    }
257
258    async fn pre_request(&self, input: AgentInput, ctx: &AgentContext) -> AgentResult<AgentInput> {
259        (self.func)(input, ctx)
260    }
261}