use std::collections::HashMap;
use std::fmt;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, instrument, warn};
use crate::{
error::Error,
message::{Message, Response},
};
/// Extension point implemented by every plugin managed by [`PluginRegistry`].
///
/// The registry stores plugins as `Box<dyn Plugin>`; implementors must be
/// `Send + Sync`. The async methods are expanded by `async_trait`.
#[async_trait]
pub trait Plugin: Send + Sync {
    /// Unique identifier; the registry keys registered plugins by this value.
    fn name(&self) -> &str;

    /// Version string, logged on registration and reported in [`PluginMetadata`].
    fn version(&self) -> &str;

    /// Human-readable summary. Defaults to a placeholder text.
    fn description(&self) -> &str {
        "No description provided"
    }

    /// Features this plugin offers; the registry records a hook for each entry.
    fn capabilities(&self) -> Vec<Capability>;

    /// Setup hook awaited before the plugin is inserted into a registry.
    ///
    /// # Errors
    ///
    /// Returning an error aborts registration. The default implementation does nothing.
    async fn initialize(&mut self, _config: PluginConfig) -> Result<()> {
        Ok(())
    }

    /// Handles one request and produces a response.
    ///
    /// # Errors
    ///
    /// Returns an error when processing fails; the registry's processing pipelines
    /// log such errors and continue with the next plugin.
    async fn process(&self, request: PluginRequest) -> Result<PluginResponse>;

    /// Teardown hook awaited when the plugin is unregistered. The default does nothing.
    ///
    /// # Errors
    ///
    /// Any error is propagated to the caller of `PluginRegistry::unregister`.
    async fn shutdown(&mut self) -> Result<()> {
        Ok(())
    }

    /// Whether this plugin should receive `_message` during pre-processing.
    /// Defaults to accepting every message.
    fn can_handle(&self, _message: &Message) -> bool {
        true
    }

    /// Describes the plugin from `name`, `version` and `description`; author,
    /// homepage and license are left unset by this default implementation.
    fn metadata(&self) -> PluginMetadata {
        let (name, version, description) = (
            self.name().to_owned(),
            self.version().to_owned(),
            self.description().to_owned(),
        );
        PluginMetadata {
            name,
            version,
            description,
            author: None,
            homepage: None,
            license: None,
        }
    }
}
/// A single feature a plugin advertises through `Plugin::capabilities`.
///
/// On registration the registry maps `capability_type` onto a `HookType` and
/// records the plugin under that hook.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capability {
    /// Short identifier for the capability.
    pub name: String,
    /// Kind of capability; selects the hook the plugin is recorded under.
    pub capability_type: CapabilityType,
    /// Human-readable description of the capability.
    pub description: String,
    /// Permissions the plugin declares it needs for this capability.
    // NOTE(review): the registry in this file does not consult this list; it grants
    // `Permission::All` to every registered plugin.
    pub required_permissions: Vec<Permission>,
}
/// Category of a [`Capability`]; variant names are serialized in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CapabilityType {
    /// Maps to `HookType::MessageProcessor`.
    MessageProcessor,
    /// Maps to `HookType::CommandHandler`.
    CommandHandler,
    /// Maps to `HookType::EventListener`.
    EventListener,
    /// Maps to `HookType::ToolProvider`.
    ToolProvider,
    /// Maps to `HookType::Middleware`.
    Middleware,
    /// Application-defined kind; maps to `HookType::Custom` with the same name.
    Custom(String),
}
/// A permission a plugin may declare or be granted; serialized in snake_case.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Permission {
    ReadMessages,
    WriteMessages,
    AccessContext,
    ModifyContext,
    NetworkAccess,
    FileSystemAccess,
    ExecuteCommands,
    DatabaseAccess,
    /// Wildcard: `PluginRegistry::has_permission` treats a grant of `All` as
    /// satisfying every other permission.
    All,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PluginConfig {
pub settings: HashMap<String, serde_json::Value>,
pub enabled_features: Vec<String>,
pub permissions: Vec<Permission>,
pub resource_limits: ResourceLimits,
}
/// Optional resource ceilings for a plugin.
///
/// NOTE(review): `None` presumably means "unlimited" — nothing in this file enforces
/// these limits, so confirm against the enforcing code.
/// NOTE(review): there is no `#[serde(default)]`, so serde deserializes a missing
/// `Option` field as `None` rather than the value from `ResourceLimits::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceLimits {
    /// Memory ceiling; presumably in bytes — TODO confirm.
    pub max_memory: Option<usize>,
    /// CPU ceiling; unit (percent vs. cores) not established here — TODO confirm.
    pub max_cpu: Option<f32>,
    /// Maximum wall-clock time for a single operation.
    pub max_execution_time: Option<std::time::Duration>,
    /// Maximum number of operations in flight at once.
    pub max_concurrent_ops: Option<usize>,
}
impl Default for ResourceLimits {
    /// Default ceilings: `100 * 1024 * 1024` for memory, `50.0` for CPU,
    /// 30 seconds of execution time, and 10 concurrent operations.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        let execution_budget = std::time::Duration::from_secs(30);
        Self {
            max_memory: Some(100 * MIB),
            max_cpu: Some(50.0),
            max_execution_time: Some(execution_budget),
            max_concurrent_ops: Some(10),
        }
    }
}
/// A unit of work sent to `Plugin::process`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginRequest {
    /// Request identifier; the registry generates a fresh UUID v4 string per request.
    pub id: String,
    /// What the plugin is being asked to do.
    pub request_type: RequestType,
    /// Request payload as JSON (e.g. a serialized `Message` for `ProcessMessage`).
    pub data: serde_json::Value,
    /// Extra key/value annotations; the registry sends an empty map.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// Kind of a [`PluginRequest`]; variant names are serialized in snake_case.
///
/// The registry sends `ProcessMessage` during pre-processing and
/// `Custom("post_process")` during post-processing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RequestType {
    ProcessMessage,
    ExecuteCommand,
    HandleEvent,
    InvokeTool,
    /// Application-defined request kind, identified by name.
    Custom(String),
}
/// Result of `Plugin::process`; build with `PluginResponse::success` or
/// `PluginResponse::error`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginResponse {
    /// Identifier of the request being answered.
    pub id: String,
    /// Whether the plugin handled the request; the registry only applies `data` when true.
    pub success: bool,
    /// Response payload; `Null` for error responses built by `PluginResponse::error`.
    pub data: serde_json::Value,
    /// Failure description when `success` is false.
    pub error: Option<String>,
    /// Extra key/value annotations.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl PluginResponse {
    /// Builds a successful response carrying `data`, with no error and empty metadata.
    pub fn success(id: impl Into<String>, data: serde_json::Value) -> Self {
        Self::from_parts(id.into(), true, data, None)
    }

    /// Builds a failed response: `data` is `Null` and `error` holds the `Display`
    /// rendering of `error`.
    pub fn error(id: impl Into<String>, error: impl fmt::Display) -> Self {
        Self::from_parts(
            id.into(),
            false,
            serde_json::Value::Null,
            Some(error.to_string()),
        )
    }

    /// Shared constructor; metadata always starts empty.
    fn from_parts(
        id: String,
        success: bool,
        data: serde_json::Value,
        error: Option<String>,
    ) -> Self {
        Self {
            id,
            success,
            data,
            error,
            metadata: HashMap::new(),
        }
    }
}
/// Descriptive information about a plugin, as returned by `Plugin::metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginMetadata {
    pub name: String,
    pub version: String,
    pub description: String,
    /// Left as `None` by the default `Plugin::metadata` implementation.
    pub author: Option<String>,
    /// Left as `None` by the default `Plugin::metadata` implementation.
    pub homepage: Option<String>,
    /// Left as `None` by the default `Plugin::metadata` implementation.
    pub license: Option<String>,
}
/// Owns registered plugins and tracks their hooks and granted permissions.
pub struct PluginRegistry {
    /// Registered plugins keyed by `Plugin::name`.
    plugins: HashMap<String, Box<dyn Plugin>>,
    /// Plugin names recorded per hook type, derived from each plugin's capabilities.
    hooks: HashMap<HookType, Vec<String>>,
    /// Permissions granted to each plugin, keyed by plugin name.
    permissions: HashMap<String, Vec<Permission>>,
}
impl PluginRegistry {
#[must_use]
pub fn new() -> Self {
Self {
plugins: HashMap::new(),
hooks: HashMap::new(),
permissions: HashMap::new(),
}
}
#[instrument(skip(self, plugin))]
pub fn register(&mut self, mut plugin: Box<dyn Plugin>) -> Result<()> {
let name = plugin.name().to_string();
if self.plugins.contains_key(&name) {
return Err(Error::Plugin(format!("Plugin '{name}' already registered")).into());
}
info!("Registering plugin: {} v{}", name, plugin.version());
let config = PluginConfig::default();
futures::executor::block_on(plugin.initialize(config))?;
for capability in plugin.capabilities() {
self.register_hook(&name, &capability);
}
self.plugins.insert(name.clone(), plugin);
self.permissions.insert(name, vec![Permission::All]);
Ok(())
}
#[instrument(skip(self))]
pub async fn unregister(&mut self, name: &str) -> Result<()> {
if let Some(mut plugin) = self.plugins.remove(name) {
info!("Unregistering plugin: {}", name);
plugin.shutdown().await?;
for hooks in self.hooks.values_mut() {
hooks.retain(|n| n != name);
}
self.permissions.remove(name);
Ok(())
} else {
Err(Error::NotFound(format!("Plugin '{name}' not found")).into())
}
}
pub fn get(&self, name: &str) -> Option<&dyn Plugin> {
self.plugins.get(name).map(std::convert::AsRef::as_ref)
}
pub fn list(&self) -> Vec<PluginMetadata> {
self.plugins.values().map(|p| p.metadata()).collect()
}
#[instrument(skip(self, message))]
pub async fn apply_pre_processing(&self, mut message: Message) -> Result<Message> {
for plugin in self.plugins.values() {
if plugin.can_handle(&message) {
let request = PluginRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::ProcessMessage,
data: serde_json::to_value(&message)?,
metadata: HashMap::new(),
};
match plugin.process(request).await {
Ok(response) if response.success => {
if let Ok(processed) = serde_json::from_value(response.data) {
message = processed;
}
}
Ok(response) => {
warn!(
"Plugin {} failed to process message: {:?}",
plugin.name(),
response.error
);
}
Err(e) => {
warn!("Plugin {} error: {}", plugin.name(), e);
}
}
}
}
Ok(message)
}
#[instrument(skip(self, response))]
pub async fn apply_post_processing(&self, mut response: Response) -> Result<Response> {
for plugin in self.plugins.values() {
let request = PluginRequest {
id: uuid::Uuid::new_v4().to_string(),
request_type: RequestType::Custom("post_process".to_string()),
data: serde_json::to_value(&response)?,
metadata: HashMap::new(),
};
match plugin.process(request).await {
Ok(plugin_response) if plugin_response.success => {
if let Ok(processed) = serde_json::from_value(plugin_response.data) {
response = processed;
}
}
Ok(plugin_response) => {
debug!(
"Plugin {} post-processing failed: {:?}",
plugin.name(),
plugin_response.error
);
}
Err(e) => {
debug!("Plugin {} post-processing error: {}", plugin.name(), e);
}
}
}
Ok(response)
}
pub fn has_permission(&self, plugin_name: &str, permission: &Permission) -> bool {
self.permissions
.get(plugin_name)
.is_some_and(|perms| perms.contains(permission) || perms.contains(&Permission::All))
}
fn register_hook(&mut self, plugin_name: &str, capability: &Capability) {
let hook_type = match &capability.capability_type {
CapabilityType::MessageProcessor => HookType::MessageProcessor,
CapabilityType::CommandHandler => HookType::CommandHandler,
CapabilityType::EventListener => HookType::EventListener,
CapabilityType::ToolProvider => HookType::ToolProvider,
CapabilityType::Middleware => HookType::Middleware,
CapabilityType::Custom(name) => HookType::Custom(name.clone()),
};
self.hooks
.entry(hook_type)
.or_default()
.push(plugin_name.to_string());
}
}
impl Default for PluginRegistry {
fn default() -> Self {
Self::new()
}
}
/// Registry-side hook category, used as the key of the registry's hook table.
///
/// Mirrors `CapabilityType` one-to-one (see `PluginRegistry::register_hook`), but
/// derives `Eq + Hash` so it can key a `HashMap`, and is not serializable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HookType {
    MessageProcessor,
    CommandHandler,
    EventListener,
    ToolProvider,
    Middleware,
    /// Application-defined hook, identified by name.
    Custom(String),
}
/// Minimal reference plugin that echoes message content back.
pub struct EchoPlugin {
    /// Always `"echo"`.
    name: String,
    /// Always `"1.0.0"`.
    version: String,
}

impl EchoPlugin {
    /// Creates the plugin with name `"echo"` and version `"1.0.0"`.
    #[must_use]
    pub fn new() -> Self {
        let name = String::from("echo");
        let version = String::from("1.0.0");
        Self { name, version }
    }
}

impl Default for EchoPlugin {
    /// Same as [`EchoPlugin::new`].
    fn default() -> Self {
        EchoPlugin::new()
    }
}
#[async_trait]
impl Plugin for EchoPlugin {
fn name(&self) -> &str {
&self.name
}
fn version(&self) -> &str {
&self.version
}
fn description(&self) -> &'static str {
"Simple echo plugin for testing"
}
fn capabilities(&self) -> Vec<Capability> {
vec![Capability {
name: "echo".to_string(),
capability_type: CapabilityType::MessageProcessor,
description: "Echoes messages back".to_string(),
required_permissions: vec![Permission::ReadMessages, Permission::WriteMessages],
}]
}
async fn process(&self, request: PluginRequest) -> Result<PluginResponse> {
match request.request_type {
RequestType::ProcessMessage => {
if let Ok(message) = serde_json::from_value::<Message>(request.data) {
let echo_message = Message::text(format!("Echo: {}", message.content));
Ok(PluginResponse::success(
request.id,
serde_json::to_value(echo_message)?,
))
} else {
Ok(PluginResponse::error(request.id, "Invalid message data"))
}
}
_ => Ok(PluginResponse::error(
request.id,
"Unsupported request type",
)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh boxed [`EchoPlugin`], ready to hand to [`PluginRegistry::register`].
    fn boxed_echo() -> Box<EchoPlugin> {
        Box::new(EchoPlugin::new())
    }

    #[test]
    fn test_plugin_registry() {
        let mut registry = PluginRegistry::new();
        registry
            .register(boxed_echo())
            .expect("registering the echo plugin should succeed");
        assert!(registry.get("echo").is_some());

        let listed = registry.list();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].name, "echo");
    }

    #[tokio::test]
    async fn test_plugin_unregister() {
        let mut registry = PluginRegistry::new();
        registry.register(boxed_echo()).unwrap();

        let outcome = registry.unregister("echo").await;
        assert!(outcome.is_ok());
        assert!(registry.get("echo").is_none());
    }

    #[test]
    fn test_plugin_permissions() {
        let mut registry = PluginRegistry::new();
        assert!(!registry.has_permission("echo", &Permission::All));

        registry.register(boxed_echo()).unwrap();
        for granted in [Permission::All, Permission::ReadMessages] {
            assert!(registry.has_permission("echo", &granted));
        }
        assert!(!registry.has_permission("nonexistent", &Permission::ReadMessages));
    }

    #[tokio::test]
    async fn test_echo_plugin() {
        let plugin = EchoPlugin::new();
        let payload = serde_json::to_value(Message::text("Hello, world!")).unwrap();
        let request = PluginRequest {
            id: "test-123".into(),
            request_type: RequestType::ProcessMessage,
            data: payload,
            metadata: HashMap::new(),
        };

        let response = plugin.process(request).await.unwrap();
        assert!(response.success);
        let echoed: Message = serde_json::from_value(response.data).unwrap();
        assert_eq!(echoed.content, "Echo: Hello, world!");
    }

    #[test]
    fn test_plugin_response() {
        let accepted = PluginResponse::success("test", serde_json::json!({"key": "value"}));
        assert!(accepted.success);
        assert!(accepted.error.is_none());

        let rejected = PluginResponse::error("test", "Something went wrong");
        assert!(!rejected.success);
        assert_eq!(rejected.error.as_deref(), Some("Something went wrong"));
    }
}