mofa_runtime/agent/plugins/
mod.rs1pub use mofa_kernel::agent::plugins::{Plugin, PluginMetadata, PluginRegistry, PluginStage};
7
8use crate::agent::context::AgentContext;
9use crate::agent::error::{AgentError, AgentResult};
10use crate::agent::types::{AgentInput, AgentOutput};
11use async_trait::async_trait;
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14
15pub struct SimplePluginRegistry {
21 plugins: RwLock<HashMap<String, Arc<dyn Plugin>>>,
22}
23
24impl SimplePluginRegistry {
25 pub fn new() -> Self {
27 Self {
28 plugins: RwLock::new(HashMap::new()),
29 }
30 }
31}
32
33impl Default for SimplePluginRegistry {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl PluginRegistry for SimplePluginRegistry {
40 fn register(&self, plugin: Arc<dyn Plugin>) -> AgentResult<()> {
41 let mut plugins = self
42 .plugins
43 .write()
44 .map_err(|_| AgentError::ExecutionFailed("Failed to acquire write lock".to_string()))?;
45 plugins.insert(plugin.name().to_string(), plugin);
46 Ok(())
47 }
48
49 fn unregister(&self, name: &str) -> AgentResult<bool> {
50 let mut plugins = self
51 .plugins
52 .write()
53 .map_err(|_| AgentError::ExecutionFailed("Failed to acquire write lock".to_string()))?;
54 Ok(plugins.remove(name).is_some())
55 }
56
57 fn get(&self, name: &str) -> Option<Arc<dyn Plugin>> {
58 let plugins = self.plugins.read().ok()?;
59 plugins.get(name).cloned()
60 }
61
62 fn list(&self) -> Vec<Arc<dyn Plugin>> {
63 self.plugins
64 .read()
65 .ok()
66 .map(|plugins| plugins.values().cloned().collect())
67 .unwrap_or_default()
68 }
69
70 fn list_by_stage(&self, stage: PluginStage) -> Vec<Arc<dyn Plugin>> {
71 self.plugins
72 .read()
73 .ok()
74 .map(|plugins| {
75 plugins
76 .values()
77 .filter(|plugin| plugin.metadata().stages.contains(&stage))
78 .cloned()
79 .collect()
80 })
81 .unwrap_or_default()
82 }
83
84 fn contains(&self, name: &str) -> bool {
85 self.plugins
86 .read()
87 .ok()
88 .map(|plugins| plugins.contains_key(name))
89 .unwrap_or(false)
90 }
91
92 fn count(&self) -> usize {
93 self.plugins
94 .read()
95 .ok()
96 .map(|plugins| plugins.len())
97 .unwrap_or(0)
98 }
99}
100
101pub struct PluginExecutor {
107 pub registry: Arc<dyn PluginRegistry>,
108}
109
110impl PluginExecutor {
111 pub fn new(registry: Arc<dyn PluginRegistry>) -> Self {
113 Self { registry }
114 }
115
116 pub async fn execute_stage(&self, stage: PluginStage, ctx: &AgentContext) -> AgentResult<()> {
118 let plugins = self.registry.list_by_stage(stage);
119 for plugin in plugins {
120 match stage {
121 PluginStage::PreContext => {
122 plugin.pre_context(ctx).await?;
123 }
124 PluginStage::PostProcess => {
125 plugin.post_process(ctx).await?;
126 }
127 _ => {
128 continue;
130 }
131 }
132 }
133 Ok(())
134 }
135
136 pub async fn execute_pre_request(
138 &self,
139 input: AgentInput,
140 ctx: &AgentContext,
141 ) -> AgentResult<AgentInput> {
142 let mut result = input;
143 let plugins = self.registry.list_by_stage(PluginStage::PreRequest);
144
145 for plugin in plugins {
146 result = plugin.pre_request(result.clone(), ctx).await?;
147 }
148
149 Ok(result)
150 }
151
152 pub async fn execute_post_response(
154 &self,
155 output: AgentOutput,
156 ctx: &AgentContext,
157 ) -> AgentResult<AgentOutput> {
158 let mut result = output;
159 let plugins = self.registry.list_by_stage(PluginStage::PostResponse);
160
161 for plugin in plugins {
162 result = plugin.post_response(result.clone(), ctx).await?;
163 }
164
165 Ok(result)
166 }
167}
168
/// Plugin configured with a target URL, registered under the name `"http-plugin"`.
pub struct HttpPlugin {
    /// Registry key; always `"http-plugin"` when built via [`HttpPlugin::new`].
    name: String,
    /// Human-readable description (the literal below means "HTTP request plugin").
    description: String,
    /// Target URL supplied by the caller.
    url: String,
}

impl HttpPlugin {
    /// Creates the plugin for `url`.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            name: "http-plugin".to_string(),
            description: "HTTP请求插件".to_string(),
            url: url.into(),
        }
    }

    /// Returns the URL this plugin was configured with.
    ///
    /// Without this accessor the stored URL was never read anywhere (it
    /// triggered rustc's dead-code warning and was unreachable for callers).
    pub fn url(&self) -> &str {
        &self.url
    }
}
190
191#[async_trait]
192impl Plugin for HttpPlugin {
193 fn name(&self) -> &str {
194 &self.name
195 }
196
197 fn description(&self) -> &str {
198 &self.description
199 }
200
201 fn metadata(&self) -> PluginMetadata {
202 let mut metadata = PluginMetadata::default();
203 metadata.name = self.name.clone();
204 metadata.description = self.description.clone();
205 metadata.version = "1.0.0".to_string();
206 metadata.stages = vec![PluginStage::PreContext];
207 metadata
208 }
209
210 async fn pre_context(&self, ctx: &AgentContext) -> AgentResult<()> {
211 ctx.set("http_response", "示例HTTP响应内容").await;
214 Ok(())
215 }
216}
217
218pub struct CustomFunctionPlugin {
220 name: String,
221 description: String,
222 func: Arc<dyn Fn(AgentInput, &AgentContext) -> AgentResult<AgentInput> + Send + Sync + 'static>,
223}
224
225impl CustomFunctionPlugin {
226 pub fn new<F>(name: impl Into<String>, desc: impl Into<String>, func: F) -> Self
228 where
229 F: Fn(AgentInput, &AgentContext) -> AgentResult<AgentInput> + Send + Sync + 'static,
230 {
231 Self {
232 name: name.into(),
233 description: desc.into(),
234 func: Arc::new(func),
235 }
236 }
237}
238
239#[async_trait]
240impl Plugin for CustomFunctionPlugin {
241 fn name(&self) -> &str {
242 &self.name
243 }
244
245 fn description(&self) -> &str {
246 &self.description
247 }
248
249 fn metadata(&self) -> PluginMetadata {
250 let mut metadata = PluginMetadata::default();
251 metadata.name = self.name.clone();
252 metadata.description = self.description.clone();
253 metadata.version = "1.0.0".to_string();
254 metadata.stages = vec![PluginStage::PreRequest];
255 metadata
256 }
257
258 async fn pre_request(&self, input: AgentInput, ctx: &AgentContext) -> AgentResult<AgentInput> {
259 (self.func)(input, ctx)
260 }
261}