pub use mofa_kernel::agent::plugins::{Plugin, PluginMetadata, PluginRegistry, PluginStage};
use crate::agent::context::AgentContext;
use crate::agent::error::{AgentError, AgentResult};
use crate::agent::types::{AgentInput, AgentOutput};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// In-memory plugin registry keyed by plugin name.
///
/// The map sits behind a [`RwLock`], so the registry can be used through a
/// shared reference.
#[derive(Default)]
pub struct SimplePluginRegistry {
    /// Registered plugins, indexed by the value of `Plugin::name`.
    plugins: RwLock<HashMap<String, Arc<dyn Plugin>>>,
}

impl SimplePluginRegistry {
    /// Creates a registry with no plugins registered.
    ///
    /// Equivalent to [`SimplePluginRegistry::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
impl PluginRegistry for SimplePluginRegistry {
    /// Stores `plugin` under its own name.
    ///
    /// A plugin already registered under the same name is replaced.
    ///
    /// # Errors
    /// Returns `AgentError::ExecutionFailed` when the write lock cannot be
    /// acquired (the lock is poisoned).
    fn register(&self, plugin: Arc<dyn Plugin>) -> AgentResult<()> {
        let mut guard = self
            .plugins
            .write()
            .map_err(|_| AgentError::ExecutionFailed("Failed to acquire write lock".to_string()))?;
        let key = plugin.name().to_owned();
        guard.insert(key, plugin);
        Ok(())
    }

    /// Removes the plugin registered as `name`.
    ///
    /// Returns `Ok(true)` when an entry was removed and `Ok(false)` when no
    /// plugin had that name.
    ///
    /// # Errors
    /// Returns `AgentError::ExecutionFailed` when the write lock cannot be
    /// acquired (the lock is poisoned).
    fn unregister(&self, name: &str) -> AgentResult<bool> {
        let mut guard = self
            .plugins
            .write()
            .map_err(|_| AgentError::ExecutionFailed("Failed to acquire write lock".to_string()))?;
        let removed = guard.remove(name);
        Ok(removed.is_some())
    }

    /// Looks up the plugin registered as `name`.
    ///
    /// Returns `None` when the name is unknown or the read lock is poisoned.
    fn get(&self, name: &str) -> Option<Arc<dyn Plugin>> {
        match self.plugins.read() {
            Ok(guard) => guard.get(name).cloned(),
            Err(_) => None,
        }
    }

    /// Returns every registered plugin.
    ///
    /// The order follows the hash map's iteration order and is therefore
    /// unspecified. A poisoned read lock yields an empty vector.
    fn list(&self) -> Vec<Arc<dyn Plugin>> {
        match self.plugins.read() {
            Ok(guard) => guard.values().cloned().collect(),
            Err(_) => Vec::new(),
        }
    }

    /// Returns the plugins whose metadata lists `stage`.
    ///
    /// The order follows the hash map's iteration order and is therefore
    /// unspecified. A poisoned read lock yields an empty vector.
    fn list_by_stage(&self, stage: PluginStage) -> Vec<Arc<dyn Plugin>> {
        match self.plugins.read() {
            Ok(guard) => {
                let mut matching = Vec::new();
                for candidate in guard.values() {
                    if candidate.metadata().stages.contains(&stage) {
                        matching.push(candidate.clone());
                    }
                }
                matching
            }
            Err(_) => Vec::new(),
        }
    }

    /// Reports whether a plugin is registered as `name`.
    ///
    /// A poisoned read lock is reported as `false`.
    fn contains(&self, name: &str) -> bool {
        match self.plugins.read() {
            Ok(guard) => guard.contains_key(name),
            Err(_) => false,
        }
    }

    /// Returns the number of registered plugins.
    ///
    /// A poisoned read lock is reported as `0`.
    fn count(&self) -> usize {
        match self.plugins.read() {
            Ok(guard) => guard.len(),
            Err(_) => 0,
        }
    }
}
/// Runs the hooks of the plugins held by a [`PluginRegistry`].
pub struct PluginExecutor {
    /// Registry queried for the plugins that belong to each stage.
    pub registry: Arc<dyn PluginRegistry>,
}

impl PluginExecutor {
    /// Creates an executor backed by `registry`.
    pub fn new(registry: Arc<dyn PluginRegistry>) -> Self {
        Self { registry }
    }

    /// Runs the context-only hook of every plugin registered for `stage`.
    ///
    /// `PluginStage::PreContext` invokes `pre_context` and
    /// `PluginStage::PostProcess` invokes `post_process`. Every other stage
    /// is a no-op here: the request/response stages carry a payload and are
    /// driven by [`Self::execute_pre_request`] and
    /// [`Self::execute_post_response`].
    ///
    /// Plugins run in the order the registry returns them.
    ///
    /// # Errors
    /// Returns the first error produced by a hook; plugins after the failing
    /// one are not run.
    pub async fn execute_stage(&self, stage: PluginStage, ctx: &AgentContext) -> AgentResult<()> {
        for plugin in self.registry.list_by_stage(stage) {
            match stage {
                PluginStage::PreContext => {
                    plugin.pre_context(ctx).await?;
                }
                PluginStage::PostProcess => {
                    plugin.post_process(ctx).await?;
                }
                // Payload-carrying stages have dedicated entry points.
                _ => {}
            }
        }
        Ok(())
    }

    /// Threads `input` through the `pre_request` hook of every plugin
    /// registered for `PluginStage::PreRequest`.
    ///
    /// Each plugin receives the value produced by the previous one, in the
    /// order the registry returns them. The final value is returned.
    ///
    /// # Errors
    /// Returns the first error produced by a hook; plugins after the failing
    /// one are not run.
    pub async fn execute_pre_request(
        &self,
        input: AgentInput,
        ctx: &AgentContext,
    ) -> AgentResult<AgentInput> {
        let mut current = input;
        for plugin in self.registry.list_by_stage(PluginStage::PreRequest) {
            // Move the value into the hook and rebind the result; cloning it
            // first would only create a copy that is immediately discarded.
            current = plugin.pre_request(current, ctx).await?;
        }
        Ok(current)
    }

    /// Threads `output` through the `post_response` hook of every plugin
    /// registered for `PluginStage::PostResponse`.
    ///
    /// Each plugin receives the value produced by the previous one, in the
    /// order the registry returns them. The final value is returned.
    ///
    /// # Errors
    /// Returns the first error produced by a hook; plugins after the failing
    /// one are not run.
    pub async fn execute_post_response(
        &self,
        output: AgentOutput,
        ctx: &AgentContext,
    ) -> AgentResult<AgentOutput> {
        let mut current = output;
        for plugin in self.registry.list_by_stage(PluginStage::PostResponse) {
            // Moved, not cloned: the previous value is never needed again.
            current = plugin.post_response(current, ctx).await?;
        }
        Ok(current)
    }
}
/// Sample plugin that contributes an HTTP response entry to the agent context.
pub struct HttpPlugin {
    /// Plugin name reported through `Plugin::name`.
    name: String,
    /// Human-readable description reported through `Plugin::description`.
    description: String,
    /// Target URL supplied at construction.
    // NOTE(review): stored but never read by the code visible here.
    url: String,
}

impl HttpPlugin {
    /// Creates the plugin for `url` with the fixed name `"http-plugin"`.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        let name = String::from("http-plugin");
        let description = String::from("HTTP请求插件");
        Self {
            name,
            description,
            url,
        }
    }
}
#[async_trait]
impl Plugin for HttpPlugin {
    /// Returns the plugin name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the plugin description.
    fn description(&self) -> &str {
        self.description.as_str()
    }

    /// Builds the metadata: version `1.0.0`, active in the
    /// `PluginStage::PreContext` stage only. Remaining fields keep their
    /// `PluginMetadata::default()` values.
    fn metadata(&self) -> PluginMetadata {
        let mut meta = PluginMetadata::default();
        meta.stages = vec![PluginStage::PreContext];
        meta.version = String::from("1.0.0");
        meta.description = self.description.clone();
        meta.name = self.name.clone();
        meta
    }

    /// Stores a fixed sample response under the `"http_response"` context key.
    ///
    /// No request is issued by this implementation; the stored value is a
    /// hard-coded placeholder string.
    async fn pre_context(&self, ctx: &AgentContext) -> AgentResult<()> {
        let key = "http_response";
        let value = "示例HTTP响应内容";
        ctx.set(key, value).await;
        Ok(())
    }
}
/// Signature of the callback wrapped by [`CustomFunctionPlugin`].
type PreRequestFn =
    dyn Fn(AgentInput, &AgentContext) -> AgentResult<AgentInput> + Send + Sync + 'static;

/// Plugin that delegates its `pre_request` hook to a user-supplied closure.
pub struct CustomFunctionPlugin {
    /// Plugin name reported through `Plugin::name`.
    name: String,
    /// Human-readable description reported through `Plugin::description`.
    description: String,
    /// Callback invoked with the incoming input and the agent context.
    func: Arc<PreRequestFn>,
}

impl CustomFunctionPlugin {
    /// Wraps `func` as a plugin called `name` and described by `desc`.
    pub fn new<F>(name: impl Into<String>, desc: impl Into<String>, func: F) -> Self
    where
        F: Fn(AgentInput, &AgentContext) -> AgentResult<AgentInput> + Send + Sync + 'static,
    {
        let name = name.into();
        let description = desc.into();
        let func: Arc<PreRequestFn> = Arc::new(func);
        Self {
            name,
            description,
            func,
        }
    }
}
#[async_trait]
impl Plugin for CustomFunctionPlugin {
    /// Returns the plugin name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the plugin description.
    fn description(&self) -> &str {
        self.description.as_str()
    }

    /// Builds the metadata: version `1.0.0`, active in the
    /// `PluginStage::PreRequest` stage only. Remaining fields keep their
    /// `PluginMetadata::default()` values.
    fn metadata(&self) -> PluginMetadata {
        let mut meta = PluginMetadata::default();
        meta.stages = vec![PluginStage::PreRequest];
        meta.version = String::from("1.0.0");
        meta.description = self.description.clone();
        meta.name = self.name.clone();
        meta
    }

    /// Passes `input` and `ctx` to the wrapped callback and returns its result.
    ///
    /// The callback is an ordinary (non-async) function and runs to
    /// completion inside this call.
    async fn pre_request(&self, input: AgentInput, ctx: &AgentContext) -> AgentResult<AgentInput> {
        let callback = &*self.func;
        callback(input, ctx)
    }
}